TravelContentCreator/utils/content_generator.py

570 lines
24 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import time
import logging
import random
import traceback
import simplejson as json
from datetime import datetime
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from core.ai_agent import AI_Agent
class ContentGenerator:
"""
海报文本内容生成器
使用AI_Agent代替直接管理OpenAI客户端简化代码结构
"""
def __init__(self,
output_dir="/root/autodl-tmp/poster_generate_result",
model_name="qwenQWQ",
base_url="http://localhost:8000/v1",
api_key="EMPTY",
temperature=0.7,
top_p=0.8,
presence_penalty=1.2):
"""
初始化内容生成器
参数:
output_dir: 输出结果保存目录
temperature: 生成温度参数
top_p: top_p参数
presence_penalty: 惩罚参数
"""
self.output_dir = output_dir
self.temperature = temperature
self.top_p = top_p
self.presence_penalty = presence_penalty
self.add_description = ""
self.model_name = model_name
self.base_url = base_url
self.api_key = api_key
# 设置日志
logging.basicConfig(level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
self.logger = logging.getLogger(__name__)
def load_infomation(self, info_directory_path):
"""
加载额外描述文件
参数:
info_directory_path: 信息目录路径列表
"""
self.add_description = "" # 重置描述文本
for path in info_directory_path:
try:
with open(path, "r", encoding="utf-8") as f:
self.add_description += f.read()
self.logger.info(f"成功加载描述文件: {path}")
except Exception as e:
self.logger.warning(f"加载描述文件失败: {path}, 错误: {e}")
self.add_description = ""
def split_content(self, content):
"""
分割结果, 返回去除
```json
```的json内容
参数:
content: 需要分割的内容
返回:
分割后的json内容
"""
try:
# 记录原始内容的前200个字符用于调试
self.logger.debug(f"解析内容原始内容前200字符: {content[:200]}")
# 首先尝试直接解析整个内容,以防已经是干净的 JSON
try:
parsed_data = json.loads(content)
# 验证解析后的数据格式
if isinstance(parsed_data, list):
# 如果是列表,验证每个元素是否符合预期结构
for item in parsed_data:
if isinstance(item, dict) and ('main_title' in item or 'texts' in item):
# 至少有一个元素符合海报配置结构
self.logger.info("成功直接解析为JSON格式列表符合预期结构")
return parsed_data
# 如果到这里,说明列表内没有符合结构的元素
if len(parsed_data) > 0 and isinstance(parsed_data[0], str):
self.logger.warning(f"解析到JSON列表但内容是字符串列表: {parsed_data}")
# 将字符串列表返回供后续修复
return parsed_data
self.logger.warning("解析到JSON列表但结构不符合预期")
elif isinstance(parsed_data, dict) and ('main_title' in parsed_data or 'texts' in parsed_data):
# 单个字典结构符合预期
self.logger.info("成功直接解析为JSON字典符合预期结构")
return parsed_data
# 如果结构不符合预期,记录但仍返回解析结果,交给后续函数修复
self.logger.warning(f"解析到JSON但结构不完全符合预期: {parsed_data}")
return parsed_data
except json.JSONDecodeError:
# 不是完整有效的JSON继续尝试提取
self.logger.debug("直接JSON解析失败尝试提取结构化内容")
# 常规模式:查找 ```json 和 ``` 之间的内容
if "```json" in content:
json_str = content.split("```json")[1].split("```")[0].strip()
try:
parsed_json = json.loads(json_str)
self.logger.info("成功从```json```代码块提取JSON")
return parsed_json
except json.JSONDecodeError as e:
self.logger.warning(f"从```json```提取的内容解析失败: {e}, 尝试其他方法")
# 备用模式1查找连续的 [ 开头和 ] 结尾的部分
import re
json_pattern = r'(\[(?:\s*\{.*?\}\s*,?)+\s*\])' # 更严格的模式,要求[]内至少有一个{}对象
json_matches = re.findall(json_pattern, content, re.DOTALL)
for match in json_matches:
try:
result = json.loads(match)
if isinstance(result, list) and len(result) > 0:
# 验证结构
for item in result:
if isinstance(item, dict) and ('main_title' in item or 'texts' in item):
self.logger.info("成功从正则表达式提取JSON数组")
return result
self.logger.warning("从正则表达式提取的JSON数组不符合预期结构")
except Exception as e:
self.logger.warning(f"解析正则匹配的内容失败: {e}")
continue
# 备用模式2查找 [ 开头 和 ] 结尾,并尝试解析
content = content.strip()
square_bracket_start = content.find('[')
square_bracket_end = content.rfind(']')
if square_bracket_start != -1 and square_bracket_end != -1 and square_bracket_end > square_bracket_start:
potential_json = content[square_bracket_start:square_bracket_end + 1]
try:
result = json.loads(potential_json)
if isinstance(result, list):
# 检查列表内容
self.logger.info(f"成功从方括号内容提取列表: {result}")
return result
except Exception as e:
self.logger.warning(f"尝试提取方括号内容失败: {e}")
# 最后一种尝试:查找所有可能的 JSON 结构并尝试解析
json_structures = re.findall(r'({.*?})', content, re.DOTALL)
if json_structures:
items = []
for i, struct in enumerate(json_structures):
try:
item = json.loads(struct)
# 验证结构包含预期字段
if isinstance(item, dict) and ('main_title' in item or 'texts' in item):
items.append(item)
except Exception as e:
self.logger.warning(f"解析可能的JSON结构 {i+1} 失败: {e}")
continue
if items:
self.logger.info(f"成功从文本中提取 {len(items)} 个JSON对象")
return items
# 如果以上所有方法都失败,尝试简单字符串处理
if "|" in content or "必打卡" in content or "性价比" in content:
# 这可能是一个简单的标签字符串
self.logger.warning(f"无法提取标准JSON但发现可能的标签字符串: {content}")
return content.strip()
# 都失败了,打印错误并引发异常
self.logger.error(f"无法解析内容,返回原始文本: {content[:200]}...")
raise ValueError("无法从响应中提取有效的 JSON 格式")
except Exception as e:
self.logger.error(f"解析内容时出错: {e}")
self.logger.debug(f"原始内容: {content[:200]}...") # 仅显示前200个字符
return content.strip() # 返回原始内容,让后续验证函数处理
def _preprocess_for_json(self, text):
"""预处理文本,将换行符转换为\\n形式保证JSON安全"""
if not isinstance(text, str):
return text
# 将所有实际换行符替换为\\n字符串
return text.replace('\n', '\\n').replace('\r', '\\r')
def generate_posters(self, poster_num, content_data_list, system_prompt=None,
api_url=None, model_name=None, api_key=None, timeout=120, max_retries=3):
"""
生成海报配置
Args:
poster_num: 生成的海报数量
content_data_list: 内容数据列表
system_prompt: 系统提示词(可选)
api_url: API基础URL可选
model_name: 模型名称(可选)
api_key: API密钥可选
Returns:
str: 生成的配置JSON字符串
"""
# 更新API设置
if api_url:
self.base_url = api_url
if model_name:
self.model_name = model_name
if api_key:
self.api_key = api_key
# 使用系统提示或默认提示
if system_prompt:
self.system_prompt = system_prompt
elif not self.system_prompt:
self.system_prompt = """你是一名专业的旅游景点海报文案创作专家。你的任务是根据提供的旅游景点信息和推文内容生成海报文案配置。你的回复必须是一个JSON数组每一项表示一个海报配置包含'index''main_title''texts'三个字段,其中'texts'是一个字符串数组。海报文案要简洁有力,突出景点特色和吸引力。"""
# 提取内容文本(如果是列表内容数据)
tweet_content = ""
if isinstance(content_data_list, list):
for item in content_data_list:
if isinstance(item, dict):
# 对标题和内容进行预处理,替换换行符
title = self._preprocess_for_json(item.get('title', ''))
content = self._preprocess_for_json(item.get('content', ''))
tweet_content += f"<title>\n{title}\n</title>\n<content>\n{content}\n</content>\n\n"
elif isinstance(item, str):
tweet_content += self._preprocess_for_json(item) + "\n\n"
elif isinstance(content_data_list, str):
tweet_content = self._preprocess_for_json(content_data_list)
# 构建用户提示
if self.add_description:
# 预处理景点描述
processed_description = self._preprocess_for_json(self.add_description)
user_content = f"""
以下是需要你处理的信息:
关于景点的描述:
{processed_description}
推文内容:
{tweet_content}
请根据这些信息,生成{poster_num}个海报文案配置以JSON数组格式返回。
"""
else:
user_content = f"""
以下是需要你处理的推文内容:
{tweet_content}
请根据这些信息,生成{poster_num}个海报文案配置以JSON数组格式返回。
"""
self.logger.info(f"正在生成{poster_num}个海报文案配置")
# 创建AI_Agent实例
ai_agent = AI_Agent(
self.base_url,
self.model_name,
self.api_key,
timeout=timeout,
max_retries=max_retries,
stream_chunk_timeout=30 # 流式块超时时间
)
full_response = ""
try:
# 使用AI_Agent的non-streaming方法
self.logger.info(f"调用AI生成海报配置模型: {self.model_name}")
full_response, tokens, time_cost = ai_agent.work(
self.system_prompt,
user_content,
"", # 历史消息(空)
self.temperature,
self.top_p,
self.presence_penalty
)
self.logger.info(f"AI生成完成耗时: {time_cost:.2f}s, 预估令牌数: {tokens}")
if not full_response:
self.logger.warning("AI返回空响应使用备用内容")
full_response = self._generate_fallback_content(poster_num)
except Exception as e:
self.logger.exception(f"AI生成过程发生错误: {e}")
full_response = self._generate_fallback_content(poster_num)
finally:
# 确保关闭AI Agent
ai_agent.close()
return full_response
def _generate_fallback_content(self, poster_num):
"""生成备用内容当API调用失败时使用"""
self.logger.info("生成备用内容")
default_configs = []
for i in range(poster_num):
default_configs.append({
"index": i + 1,
"main_title": "",
"texts": ["", ""]
})
return json.dumps(default_configs, ensure_ascii=False)
def save_result(self, full_response, custom_output_dir=None):
"""
保存生成结果到文件
参数:
full_response: 生成的完整响应内容
custom_output_dir: 自定义输出目录(可选)
返回:
结果文件路径
"""
# 生成时间戳
date_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
output_dir = custom_output_dir or self.output_dir
try:
# 解析内容为JSON格式
parsed_data = self.split_content(full_response)
# 验证内容格式并修复
validated_data = self._validate_and_fix_data(parsed_data)
# 创建结果文件路径
result_path = os.path.join(output_dir, f"{date_time}.json")
os.makedirs(os.path.dirname(result_path), exist_ok=True)
# 保存结果到文件
with open(result_path, "w", encoding="utf-8") as f:
json.dump(validated_data, f, ensure_ascii=False, indent=4, ignore_nan=True)
self.logger.info(f"结果已保存到: {result_path}")
return result_path
except Exception as e:
self.logger.error(f"保存结果到文件时出错: {e}")
# 尝试创建一个简单的备用配置
fallback_data = [{"main_title": "", "texts": ["", ""], "index": 1}]
# 保存备用数据
result_path = os.path.join(output_dir, f"{date_time}_fallback.json")
os.makedirs(os.path.dirname(result_path), exist_ok=True)
with open(result_path, "w", encoding="utf-8") as f:
json.dump(fallback_data, f, ensure_ascii=False, indent=4, ignore_nan=True)
self.logger.info(f"出错后已保存备用数据到: {result_path}")
return result_path
def _validate_and_fix_data(self, data):
"""
验证并修复从AI返回的数据确保其符合期望的结构
Args:
data: 需要验证的数据
Returns:
list: 修复后的数据列表
"""
fixed_data = []
self.logger.info(f"验证并修复数据: {type(data)}")
# 尝试处理字符串类型 (通常是JSON字符串)
if isinstance(data, str):
try:
# 尝试将字符串解析为JSON对象
parsed_data = json.loads(data)
# 递归调用本函数处理解析后的数据
return self._validate_and_fix_data(parsed_data)
except json.JSONDecodeError as e:
self.logger.warning(f"JSON解析失败: {e}")
# 可以选择尝试清理和再次解析
try:
# 寻找字符串中第一个 [ 和最后一个 ] 之间的内容
start_idx = data.find('[')
end_idx = data.rfind(']')
if start_idx >= 0 and end_idx > start_idx:
json_part = data[start_idx:end_idx+1]
self.logger.info(f"尝试从字符串中提取JSON部分: {json_part[:100]}...")
parsed_data = json.loads(json_part)
return self._validate_and_fix_data(parsed_data)
except:
self.logger.warning("无法从字符串中提取有效的JSON部分")
fixed_data.append({
"index": 1,
"main_title": self._preprocess_for_json("默认标题"), # 应用预处理
"texts": [self._preprocess_for_json("默认副标题1"), self._preprocess_for_json("默认副标题2")] # 应用预处理
})
# 处理列表类型
elif isinstance(data, list):
for idx, item in enumerate(data):
# 如果是字典,检查必须字段
if isinstance(item, dict):
fixed_item = {}
# 设置索引
fixed_item["index"] = item.get("index", idx + 1)
# 处理主标题
if "main_title" in item and item["main_title"]:
# 应用预处理,确保所有换行符被正确转义
fixed_item["main_title"] = self._preprocess_for_json(item["main_title"])
else:
fixed_item["main_title"] = "默认标题"
# 处理文本列表
if "texts" in item and isinstance(item["texts"], list) and len(item["texts"]) > 0:
# 对文本列表中的每个元素应用预处理
fixed_item["texts"] = [self._preprocess_for_json(text) if text else "" for text in item["texts"]]
# 确保至少有两个元素
while len(fixed_item["texts"]) < 2:
fixed_item["texts"].append("")
else:
fixed_item["texts"] = ["默认副标题1", "默认副标题2"]
fixed_data.append(fixed_item)
# 如果是字符串,转换为默认格式
elif isinstance(item, str):
fixed_data.append({
"index": idx + 1,
"main_title": self._preprocess_for_json(item), # 应用预处理
"texts": ["", ""]
})
# 其他类型,使用默认值
else:
fixed_data.append({
"index": idx + 1,
"main_title": "默认标题",
"texts": ["", ""]
})
# 处理字典类型 (单个配置项)
elif isinstance(data, dict):
# 处理主标题
main_title = self._preprocess_for_json(data.get("main_title", "默认标题")) # 应用预处理
# 处理文本列表
texts = []
if "texts" in data and isinstance(data["texts"], list):
texts = [self._preprocess_for_json(text) if text else "" for text in data["texts"]] # 应用预处理
# 确保文本列表至少有两个元素
while len(texts) < 2:
texts.append("")
fixed_data.append({
"index": data.get("index", 1),
"main_title": main_title,
"texts": texts
})
# 如果数据是其他格式
else:
self.logger.warning(f"数据格式不支持: {type(data)},将使用默认值")
fixed_data.append({
"index": 1,
"main_title": "",
"texts": ["", ""]
})
# 确保至少有一个配置项
if not fixed_data:
fixed_data.append({
"index": 1,
"main_title": "",
"texts": ["", ""]
})
self.logger.info(f"修复后的数据: {fixed_data}")
return fixed_data
def run(self, info_directory, poster_num, content_data, system_prompt=None,
api_url="http://localhost:8000/v1", model_name="qwenQWQ", api_key="EMPTY", timeout=120):
"""
运行海报内容生成流程,并返回生成的配置数据。
参数:
info_directory: 信息目录路径列表 (e.g., ['/path/to/description.txt'])
poster_num: 需要生成的海报配置数量
content_data: 用于生成内容的文章内容(可以是字符串或字典列表)
system_prompt: 系统提示词默认为None使用内置提示词
api_url: API基础URL
model_name: 使用的模型名称
api_key: API密钥
返回:
list | dict | None: 生成的海报配置数据 (通常是列表),如果生成或解析失败则返回 None。
"""
try:
# 加载描述信息
self.load_infomation(info_directory)
# 生成海报内容
self.logger.info(f"开始生成海报内容,数量: {poster_num}")
full_response = self.generate_posters(
poster_num,
content_data,
system_prompt,
api_url,
model_name,
api_key,
timeout=timeout,
)
# 检查生成是否失败
if not isinstance(full_response, str) or not full_response.strip():
self.logger.error("海报内容生成失败或返回空响应")
return None
# 从原始响应字符串中提取JSON数据
result_data = self.split_content(full_response)
# 验证并修复数据
fixed_data = self._validate_and_fix_data(result_data)
self.logger.info(f"成功生成并修复海报配置数据,包含 {len(fixed_data) if isinstance(fixed_data, list) else 1} 个项目")
return fixed_data
except Exception as e:
self.logger.exception(f"海报内容生成过程中发生错误: {e}")
traceback.print_exc()
# 失败后创建一个默认配置
self.logger.info("创建默认海报配置数据")
default_configs = []
for i in range(poster_num):
default_configs.append({
"index": i + 1,
"main_title": "",
"texts": ["", ""]
})
return default_configs
def set_temperature(self, temperature):
"""设置温度参数"""
self.temperature = temperature
def set_top_p(self, top_p):
"""设置top_p参数"""
self.top_p = top_p
def set_presence_penalty(self, presence_penalty):
"""设置存在惩罚参数"""
self.presence_penalty = presence_penalty
def set_model_para(self, temperature, top_p, presence_penalty):
"""一次性设置所有模型参数"""
self.temperature = temperature
self.top_p = top_p
self.presence_penalty = presence_penalty