TravelContentCreator/utils/content_generator.py

570 lines
24 KiB
Python
Raw Normal View History

2025-04-24 20:35:25 +08:00
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import time
2025-04-24 20:35:25 +08:00
import logging
import random
2025-04-24 20:35:25 +08:00
import traceback
import simplejson as json
2025-04-24 20:35:25 +08:00
from datetime import datetime
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from core.ai_agent import AI_Agent
class ContentGenerator:
"""
海报文本内容生成器
使用AI_Agent代替直接管理OpenAI客户端简化代码结构
"""
def __init__(self,
output_dir="/root/autodl-tmp/poster_generate_result",
model_name="qwenQWQ",
base_url="http://localhost:8000/v1",
api_key="EMPTY",
2025-04-24 20:35:25 +08:00
temperature=0.7,
top_p=0.8,
presence_penalty=1.2):
"""
初始化内容生成器
参数:
output_dir: 输出结果保存目录
temperature: 生成温度参数
top_p: top_p参数
presence_penalty: 惩罚参数
"""
self.output_dir = output_dir
self.temperature = temperature
self.top_p = top_p
self.presence_penalty = presence_penalty
self.add_description = ""
self.model_name = model_name
self.base_url = base_url
self.api_key = api_key
2025-04-24 20:35:25 +08:00
# 设置日志
logging.basicConfig(level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
self.logger = logging.getLogger(__name__)
def load_infomation(self, info_directory_path):
"""
加载额外描述文件
参数:
info_directory_path: 信息目录路径列表
"""
self.add_description = "" # 重置描述文本
for path in info_directory_path:
try:
with open(path, "r", encoding="utf-8") as f:
self.add_description += f.read()
self.logger.info(f"成功加载描述文件: {path}")
except Exception as e:
self.logger.warning(f"加载描述文件失败: {path}, 错误: {e}")
self.add_description = ""
def split_content(self, content):
"""
分割结果, 返回去除
```json
```的json内容
参数:
content: 需要分割的内容
返回:
分割后的json内容
"""
try:
# 记录原始内容的前200个字符用于调试
self.logger.debug(f"解析内容原始内容前200字符: {content[:200]}")
2025-04-24 20:35:25 +08:00
# 首先尝试直接解析整个内容,以防已经是干净的 JSON
try:
parsed_data = json.loads(content)
# 验证解析后的数据格式
if isinstance(parsed_data, list):
# 如果是列表,验证每个元素是否符合预期结构
for item in parsed_data:
if isinstance(item, dict) and ('main_title' in item or 'texts' in item):
# 至少有一个元素符合海报配置结构
self.logger.info("成功直接解析为JSON格式列表符合预期结构")
return parsed_data
# 如果到这里,说明列表内没有符合结构的元素
if len(parsed_data) > 0 and isinstance(parsed_data[0], str):
self.logger.warning(f"解析到JSON列表但内容是字符串列表: {parsed_data}")
# 将字符串列表返回供后续修复
return parsed_data
self.logger.warning("解析到JSON列表但结构不符合预期")
elif isinstance(parsed_data, dict) and ('main_title' in parsed_data or 'texts' in parsed_data):
# 单个字典结构符合预期
self.logger.info("成功直接解析为JSON字典符合预期结构")
return parsed_data
# 如果结构不符合预期,记录但仍返回解析结果,交给后续函数修复
self.logger.warning(f"解析到JSON但结构不完全符合预期: {parsed_data}")
return parsed_data
2025-04-24 20:35:25 +08:00
except json.JSONDecodeError:
# 不是完整有效的JSON继续尝试提取
self.logger.debug("直接JSON解析失败尝试提取结构化内容")
2025-04-24 20:35:25 +08:00
# 常规模式:查找 ```json 和 ``` 之间的内容
if "```json" in content:
json_str = content.split("```json")[1].split("```")[0].strip()
try:
parsed_json = json.loads(json_str)
self.logger.info("成功从```json```代码块提取JSON")
return parsed_json
2025-04-24 20:35:25 +08:00
except json.JSONDecodeError as e:
self.logger.warning(f"从```json```提取的内容解析失败: {e}, 尝试其他方法")
2025-04-24 20:35:25 +08:00
# 备用模式1查找连续的 [ 开头和 ] 结尾的部分
import re
json_pattern = r'(\[(?:\s*\{.*?\}\s*,?)+\s*\])' # 更严格的模式,要求[]内至少有一个{}对象
2025-04-24 20:35:25 +08:00
json_matches = re.findall(json_pattern, content, re.DOTALL)
for match in json_matches:
try:
result = json.loads(match)
if isinstance(result, list) and len(result) > 0:
# 验证结构
for item in result:
if isinstance(item, dict) and ('main_title' in item or 'texts' in item):
self.logger.info("成功从正则表达式提取JSON数组")
return result
self.logger.warning("从正则表达式提取的JSON数组不符合预期结构")
except Exception as e:
self.logger.warning(f"解析正则匹配的内容失败: {e}")
continue
2025-04-24 20:35:25 +08:00
# 备用模式2查找 [ 开头 和 ] 结尾,并尝试解析
content = content.strip()
square_bracket_start = content.find('[')
square_bracket_end = content.rfind(']')
if square_bracket_start != -1 and square_bracket_end != -1 and square_bracket_end > square_bracket_start:
2025-04-24 20:35:25 +08:00
potential_json = content[square_bracket_start:square_bracket_end + 1]
try:
result = json.loads(potential_json)
if isinstance(result, list):
# 检查列表内容
self.logger.info(f"成功从方括号内容提取列表: {result}")
return result
except Exception as e:
self.logger.warning(f"尝试提取方括号内容失败: {e}")
2025-04-24 20:35:25 +08:00
# 最后一种尝试:查找所有可能的 JSON 结构并尝试解析
json_structures = re.findall(r'({.*?})', content, re.DOTALL)
if json_structures:
items = []
for i, struct in enumerate(json_structures):
try:
item = json.loads(struct)
# 验证结构包含预期字段
if isinstance(item, dict) and ('main_title' in item or 'texts' in item):
2025-04-24 20:35:25 +08:00
items.append(item)
except Exception as e:
self.logger.warning(f"解析可能的JSON结构 {i+1} 失败: {e}")
2025-04-24 20:35:25 +08:00
continue
if items:
self.logger.info(f"成功从文本中提取 {len(items)} 个JSON对象")
2025-04-24 20:35:25 +08:00
return items
# 如果以上所有方法都失败,尝试简单字符串处理
if "|" in content or "必打卡" in content or "性价比" in content:
# 这可能是一个简单的标签字符串
self.logger.warning(f"无法提取标准JSON但发现可能的标签字符串: {content}")
return content.strip()
2025-04-24 20:35:25 +08:00
# 都失败了,打印错误并引发异常
self.logger.error(f"无法解析内容,返回原始文本: {content[:200]}...")
raise ValueError("无法从响应中提取有效的 JSON 格式")
except Exception as e:
self.logger.error(f"解析内容时出错: {e}")
self.logger.debug(f"原始内容: {content[:200]}...") # 仅显示前200个字符
return content.strip() # 返回原始内容,让后续验证函数处理
2025-04-24 20:35:25 +08:00
def _preprocess_for_json(self, text):
"""预处理文本,将换行符转换为\\n形式保证JSON安全"""
if not isinstance(text, str):
return text
# 将所有实际换行符替换为\\n字符串
return text.replace('\n', '\\n').replace('\r', '\\r')
def generate_posters(self, poster_num, content_data_list, system_prompt=None,
api_url=None, model_name=None, api_key=None, timeout=120, max_retries=3):
2025-04-24 20:35:25 +08:00
"""
生成海报配置
2025-04-24 20:35:25 +08:00
Args:
poster_num: 生成的海报数量
content_data_list: 内容数据列表
system_prompt: 系统提示词可选
api_url: API基础URL可选
model_name: 模型名称可选
api_key: API密钥可选
Returns:
str: 生成的配置JSON字符串
2025-04-24 20:35:25 +08:00
"""
# 更新API设置
if api_url:
self.base_url = api_url
if model_name:
self.model_name = model_name
if api_key:
self.api_key = api_key
# 使用系统提示或默认提示
if system_prompt:
self.system_prompt = system_prompt
elif not self.system_prompt:
self.system_prompt = """你是一名专业的旅游景点海报文案创作专家。你的任务是根据提供的旅游景点信息和推文内容生成海报文案配置。你的回复必须是一个JSON数组每一项表示一个海报配置包含'index''main_title''texts'三个字段,其中'texts'是一个字符串数组。海报文案要简洁有力,突出景点特色和吸引力。"""
2025-04-24 20:35:25 +08:00
# 提取内容文本(如果是列表内容数据)
tweet_content = ""
if isinstance(content_data_list, list):
for item in content_data_list:
if isinstance(item, dict):
# 对标题和内容进行预处理,替换换行符
title = self._preprocess_for_json(item.get('title', ''))
content = self._preprocess_for_json(item.get('content', ''))
2025-04-24 20:35:25 +08:00
tweet_content += f"<title>\n{title}\n</title>\n<content>\n{content}\n</content>\n\n"
elif isinstance(item, str):
tweet_content += self._preprocess_for_json(item) + "\n\n"
2025-04-24 20:35:25 +08:00
elif isinstance(content_data_list, str):
tweet_content = self._preprocess_for_json(content_data_list)
2025-04-24 20:35:25 +08:00
# 构建用户提示
if self.add_description:
# 预处理景点描述
processed_description = self._preprocess_for_json(self.add_description)
2025-04-24 20:35:25 +08:00
user_content = f"""
以下是需要你处理的信息
关于景点的描述:
{processed_description}
2025-04-24 20:35:25 +08:00
推文内容:
{tweet_content}
请根据这些信息生成{poster_num}个海报文案配置以JSON数组格式返回
"""
else:
user_content = f"""
以下是需要你处理的推文内容:
{tweet_content}
请根据这些信息生成{poster_num}个海报文案配置以JSON数组格式返回
"""
self.logger.info(f"正在生成{poster_num}个海报文案配置")
# 创建AI_Agent实例
ai_agent = AI_Agent(
self.base_url,
self.model_name,
self.api_key,
2025-04-24 20:35:25 +08:00
timeout=timeout,
max_retries=max_retries,
stream_chunk_timeout=30 # 流式块超时时间
)
full_response = ""
try:
# 使用AI_Agent的non-streaming方法
self.logger.info(f"调用AI生成海报配置模型: {self.model_name}")
2025-04-24 20:35:25 +08:00
full_response, tokens, time_cost = ai_agent.work(
self.system_prompt,
2025-04-24 20:35:25 +08:00
user_content,
"", # 历史消息(空)
self.temperature,
self.top_p,
self.presence_penalty
)
self.logger.info(f"AI生成完成耗时: {time_cost:.2f}s, 预估令牌数: {tokens}")
if not full_response:
self.logger.warning("AI返回空响应使用备用内容")
full_response = self._generate_fallback_content(poster_num)
except Exception as e:
self.logger.exception(f"AI生成过程发生错误: {e}")
full_response = self._generate_fallback_content(poster_num)
finally:
# 确保关闭AI Agent
ai_agent.close()
return full_response
def _generate_fallback_content(self, poster_num):
"""生成备用内容当API调用失败时使用"""
self.logger.info("生成备用内容")
default_configs = []
for i in range(poster_num):
default_configs.append({
"index": i + 1,
"main_title": "",
"texts": ["", ""]
2025-04-24 20:35:25 +08:00
})
return json.dumps(default_configs, ensure_ascii=False)
def save_result(self, full_response, custom_output_dir=None):
"""
保存生成结果到文件
参数:
full_response: 生成的完整响应内容
custom_output_dir: 自定义输出目录可选
返回:
结果文件路径
"""
# 生成时间戳
date_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
output_dir = custom_output_dir or self.output_dir
try:
# 解析内容为JSON格式
parsed_data = self.split_content(full_response)
# 验证内容格式并修复
validated_data = self._validate_and_fix_data(parsed_data)
2025-04-24 20:35:25 +08:00
# 创建结果文件路径
result_path = os.path.join(output_dir, f"{date_time}.json")
os.makedirs(os.path.dirname(result_path), exist_ok=True)
# 保存结果到文件
with open(result_path, "w", encoding="utf-8") as f:
json.dump(validated_data, f, ensure_ascii=False, indent=4, ignore_nan=True)
2025-04-24 20:35:25 +08:00
self.logger.info(f"结果已保存到: {result_path}")
return result_path
except Exception as e:
self.logger.error(f"保存结果到文件时出错: {e}")
# 尝试创建一个简单的备用配置
fallback_data = [{"main_title": "", "texts": ["", ""], "index": 1}]
2025-04-24 20:35:25 +08:00
# 保存备用数据
result_path = os.path.join(output_dir, f"{date_time}_fallback.json")
os.makedirs(os.path.dirname(result_path), exist_ok=True)
with open(result_path, "w", encoding="utf-8") as f:
json.dump(fallback_data, f, ensure_ascii=False, indent=4, ignore_nan=True)
2025-04-24 20:35:25 +08:00
self.logger.info(f"出错后已保存备用数据到: {result_path}")
return result_path
def _validate_and_fix_data(self, data):
"""
验证并修复从AI返回的数据确保其符合期望的结构
Args:
2025-04-24 20:35:25 +08:00
data: 需要验证的数据
Returns:
list: 修复后的数据列表
2025-04-24 20:35:25 +08:00
"""
fixed_data = []
self.logger.info(f"验证并修复数据: {type(data)}")
2025-04-24 20:35:25 +08:00
# 尝试处理字符串类型 (通常是JSON字符串)
if isinstance(data, str):
try:
# 尝试将字符串解析为JSON对象
parsed_data = json.loads(data)
# 递归调用本函数处理解析后的数据
return self._validate_and_fix_data(parsed_data)
except json.JSONDecodeError as e:
self.logger.warning(f"JSON解析失败: {e}")
# 可以选择尝试清理和再次解析
try:
# 寻找字符串中第一个 [ 和最后一个 ] 之间的内容
start_idx = data.find('[')
end_idx = data.rfind(']')
if start_idx >= 0 and end_idx > start_idx:
json_part = data[start_idx:end_idx+1]
self.logger.info(f"尝试从字符串中提取JSON部分: {json_part[:100]}...")
parsed_data = json.loads(json_part)
return self._validate_and_fix_data(parsed_data)
except:
self.logger.warning("无法从字符串中提取有效的JSON部分")
fixed_data.append({
"index": 1,
"main_title": self._preprocess_for_json("默认标题"), # 应用预处理
"texts": [self._preprocess_for_json("默认副标题1"), self._preprocess_for_json("默认副标题2")] # 应用预处理
})
# 处理列表类型
elif isinstance(data, list):
for idx, item in enumerate(data):
# 如果是字典,检查必须字段
2025-04-24 20:35:25 +08:00
if isinstance(item, dict):
fixed_item = {}
# 设置索引
fixed_item["index"] = item.get("index", idx + 1)
# 处理主标题
if "main_title" in item and item["main_title"]:
# 应用预处理,确保所有换行符被正确转义
fixed_item["main_title"] = self._preprocess_for_json(item["main_title"])
else:
fixed_item["main_title"] = "默认标题"
2025-04-24 20:35:25 +08:00
# 处理文本列表
if "texts" in item and isinstance(item["texts"], list) and len(item["texts"]) > 0:
# 对文本列表中的每个元素应用预处理
fixed_item["texts"] = [self._preprocess_for_json(text) if text else "" for text in item["texts"]]
# 确保至少有两个元素
2025-04-24 20:35:25 +08:00
while len(fixed_item["texts"]) < 2:
fixed_item["texts"].append("")
else:
fixed_item["texts"] = ["默认副标题1", "默认副标题2"]
2025-04-24 20:35:25 +08:00
fixed_data.append(fixed_item)
# 如果是字符串,转换为默认格式
elif isinstance(item, str):
fixed_data.append({
"index": idx + 1,
"main_title": self._preprocess_for_json(item), # 应用预处理
"texts": ["", ""]
})
# 其他类型,使用默认值
2025-04-24 20:35:25 +08:00
else:
fixed_data.append({
"index": idx + 1,
"main_title": "默认标题",
"texts": ["", ""]
2025-04-24 20:35:25 +08:00
})
# 处理字典类型 (单个配置项)
2025-04-24 20:35:25 +08:00
elif isinstance(data, dict):
# 处理主标题
main_title = self._preprocess_for_json(data.get("main_title", "默认标题")) # 应用预处理
2025-04-24 20:35:25 +08:00
# 处理文本列表
texts = []
if "texts" in data and isinstance(data["texts"], list):
texts = [self._preprocess_for_json(text) if text else "" for text in data["texts"]] # 应用预处理
# 确保文本列表至少有两个元素
while len(texts) < 2:
texts.append("")
fixed_data.append({
"index": data.get("index", 1),
"main_title": main_title,
"texts": texts
})
# 如果数据是其他格式
2025-04-24 20:35:25 +08:00
else:
self.logger.warning(f"数据格式不支持: {type(data)},将使用默认值")
fixed_data.append({
"index": 1,
"main_title": "",
"texts": ["", ""]
2025-04-24 20:35:25 +08:00
})
# 确保至少有一个配置项
if not fixed_data:
fixed_data.append({
"index": 1,
"main_title": "",
"texts": ["", ""]
2025-04-24 20:35:25 +08:00
})
self.logger.info(f"修复后的数据: {fixed_data}")
2025-04-24 20:35:25 +08:00
return fixed_data
def run(self, info_directory, poster_num, content_data, system_prompt=None,
api_url="http://localhost:8000/v1", model_name="qwenQWQ", api_key="EMPTY", timeout=120):
2025-04-24 20:35:25 +08:00
"""
运行海报内容生成流程并返回生成的配置数据
参数:
info_directory: 信息目录路径列表 (e.g., ['/path/to/description.txt'])
poster_num: 需要生成的海报配置数量
content_data: 用于生成内容的文章内容可以是字符串或字典列表
system_prompt: 系统提示词默认为None使用内置提示词
api_url: API基础URL
model_name: 使用的模型名称
api_key: API密钥
返回:
list | dict | None: 生成的海报配置数据 (通常是列表)如果生成或解析失败则返回 None
"""
try:
# 加载描述信息
self.load_infomation(info_directory)
# 生成海报内容
self.logger.info(f"开始生成海报内容,数量: {poster_num}")
full_response = self.generate_posters(
poster_num,
content_data,
system_prompt,
api_url,
model_name,
api_key,
timeout=timeout,
)
2025-04-24 20:35:25 +08:00
# 检查生成是否失败
if not isinstance(full_response, str) or not full_response.strip():
self.logger.error("海报内容生成失败或返回空响应")
return None
# 从原始响应字符串中提取JSON数据
result_data = self.split_content(full_response)
# 验证并修复数据
fixed_data = self._validate_and_fix_data(result_data)
2025-04-24 20:35:25 +08:00
self.logger.info(f"成功生成并修复海报配置数据,包含 {len(fixed_data) if isinstance(fixed_data, list) else 1} 个项目")
2025-04-24 20:35:25 +08:00
return fixed_data
except Exception as e:
self.logger.exception(f"海报内容生成过程中发生错误: {e}")
traceback.print_exc()
# 失败后创建一个默认配置
self.logger.info("创建默认海报配置数据")
default_configs = []
for i in range(poster_num):
default_configs.append({
"index": i + 1,
"main_title": "",
"texts": ["", ""]
2025-04-24 20:35:25 +08:00
})
return default_configs
def set_temperature(self, temperature):
"""设置温度参数"""
self.temperature = temperature
def set_top_p(self, top_p):
"""设置top_p参数"""
self.top_p = top_p
def set_presence_penalty(self, presence_penalty):
"""设置存在惩罚参数"""
self.presence_penalty = presence_penalty
def set_model_para(self, temperature, top_p, presence_penalty):
"""一次性设置所有模型参数"""
self.temperature = temperature
self.top_p = top_p
self.presence_penalty = presence_penalty