TravelContentCreator/utils/content_generator.py

625 lines
27 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import time
import logging
import random
import traceback
import simplejson as json
from datetime import datetime
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from core.ai_agent import AI_Agent
class ContentGenerator:
"""
海报文本内容生成器
使用AI_Agent代替直接管理OpenAI客户端简化代码结构
"""
def __init__(self,
output_dir="/root/autodl-tmp/poster_generate_result",
model_name="qwenQWQ",
base_url="http://localhost:8000/v1",
api_key="EMPTY",
temperature=0.7,
top_p=0.8,
presence_penalty=1.2):
"""
初始化内容生成器
参数:
output_dir: 输出结果保存目录
temperature: 生成温度参数
top_p: top_p参数
presence_penalty: 惩罚参数
"""
self.output_dir = output_dir
self.temperature = temperature
self.top_p = top_p
self.presence_penalty = presence_penalty
self.add_description = ""
self.model_name = model_name
self.base_url = base_url
self.api_key = api_key
# 设置日志
logging.basicConfig(level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
self.logger = logging.getLogger(__name__)
def load_infomation(self, info_directory_path):
"""
加载额外描述文件
参数:
info_directory_path: 信息目录路径列表
"""
self.add_description = "" # 重置描述文本
for path in info_directory_path:
try:
with open(path, "r", encoding="utf-8") as f:
self.add_description += f.read()
self.logger.info(f"成功加载描述文件: {path}")
except Exception as e:
self.logger.warning(f"加载描述文件失败: {path}, 错误: {e}")
self.add_description = ""
def split_content(self, content):
"""
分割结果, 返回去除
```json
```的json内容
参数:
content: 需要分割的内容
返回:
分割后的json内容
"""
try:
# 记录原始内容的前200个字符用于调试
self.logger.debug(f"解析内容原始内容前200字符: {content[:200]}")
# 首先尝试直接解析整个内容,以防已经是干净的 JSON
try:
parsed_data = json.loads(content)
# 验证解析后的数据格式
if isinstance(parsed_data, list):
# 如果是列表,验证每个元素是否符合预期结构
for item in parsed_data:
if isinstance(item, dict) and ('main_title' in item or 'texts' in item):
# 至少有一个元素符合海报配置结构
self.logger.info("成功直接解析为JSON格式列表符合预期结构")
return parsed_data
# 如果到这里,说明列表内没有符合结构的元素
if len(parsed_data) > 0 and isinstance(parsed_data[0], str):
self.logger.warning(f"解析到JSON列表但内容是字符串列表: {parsed_data}")
# 将字符串列表返回供后续修复
return parsed_data
self.logger.warning("解析到JSON列表但结构不符合预期")
elif isinstance(parsed_data, dict) and ('main_title' in parsed_data or 'texts' in parsed_data):
# 单个字典结构符合预期
self.logger.info("成功直接解析为JSON字典符合预期结构")
return parsed_data
# 如果结构不符合预期,记录但仍返回解析结果,交给后续函数修复
self.logger.warning(f"解析到JSON但结构不完全符合预期: {parsed_data}")
return parsed_data
except json.JSONDecodeError:
# 不是完整有效的JSON继续尝试提取
self.logger.debug("直接JSON解析失败尝试提取结构化内容")
# 常规模式:查找 ```json 和 ``` 之间的内容
if "```json" in content:
json_str = content.split("```json")[1].split("```")[0].strip()
try:
parsed_json = json.loads(json_str)
self.logger.info("成功从```json```代码块提取JSON")
return parsed_json
except json.JSONDecodeError as e:
self.logger.warning(f"从```json```提取的内容解析失败: {e}, 尝试其他方法")
# 备用模式1查找连续的 [ 开头和 ] 结尾的部分
import re
json_pattern = r'(\[(?:\s*\{.*?\}\s*,?)+\s*\])' # 更严格的模式,要求[]内至少有一个{}对象
json_matches = re.findall(json_pattern, content, re.DOTALL)
for match in json_matches:
try:
result = json.loads(match)
if isinstance(result, list) and len(result) > 0:
# 验证结构
for item in result:
if isinstance(item, dict) and ('main_title' in item or 'texts' in item):
self.logger.info("成功从正则表达式提取JSON数组")
return result
self.logger.warning("从正则表达式提取的JSON数组不符合预期结构")
except Exception as e:
self.logger.warning(f"解析正则匹配的内容失败: {e}")
continue
# 备用模式2查找 [ 开头 和 ] 结尾,并尝试解析
content = content.strip()
square_bracket_start = content.find('[')
square_bracket_end = content.rfind(']')
if square_bracket_start != -1 and square_bracket_end != -1 and square_bracket_end > square_bracket_start:
potential_json = content[square_bracket_start:square_bracket_end + 1]
try:
result = json.loads(potential_json)
if isinstance(result, list):
# 检查列表内容
self.logger.info(f"成功从方括号内容提取列表: {result}")
return result
except Exception as e:
self.logger.warning(f"尝试提取方括号内容失败: {e}")
# 最后一种尝试:查找所有可能的 JSON 结构并尝试解析
json_structures = re.findall(r'({.*?})', content, re.DOTALL)
if json_structures:
items = []
for i, struct in enumerate(json_structures):
try:
item = json.loads(struct)
# 验证结构包含预期字段
if isinstance(item, dict) and ('main_title' in item or 'texts' in item):
items.append(item)
except Exception as e:
self.logger.warning(f"解析可能的JSON结构 {i+1} 失败: {e}")
continue
if items:
self.logger.info(f"成功从文本中提取 {len(items)} 个JSON对象")
return items
# 如果以上所有方法都失败,尝试简单字符串处理
if "|" in content or "必打卡" in content or "性价比" in content:
# 这可能是一个简单的标签字符串
self.logger.warning(f"无法提取标准JSON但发现可能的标签字符串: {content}")
return content.strip()
# 都失败了,打印错误并引发异常
self.logger.error(f"无法解析内容,返回原始文本: {content[:200]}...")
raise ValueError("无法从响应中提取有效的 JSON 格式")
except Exception as e:
self.logger.error(f"解析内容时出错: {e}")
self.logger.debug(f"原始内容: {content[:200]}...") # 仅显示前200个字符
return content.strip() # 返回原始内容,让后续验证函数处理
def generate_posters(self,
poster_num,
content_data_list,
system_prompt=None,
api_url=None,
model_name=None,
api_key=None,
timeout=60,
max_retries=3):
"""
生成海报内容
参数:
poster_num: 海报数量
content_data_list: 内容数据列表(字典或字符串)
system_prompt: 系统提示默认为None则使用预设提示
api_url: API基础URL
model_name: 使用的模型名称
api_key: API密钥
timeout: 请求超时时间
max_retries: 最大重试次数
返回:
生成的海报内容
"""
# 构建默认系统提示词
if not system_prompt:
system_prompt = """
你是一名资深海报设计师有丰富的爆款海报设计经验你现在要为旅游景点做宣传在小红书上发布大量宣传海报。你的主要工作目标有2个
1、你要根据我给你的图片描述和笔记推文内容设计图文匹配的海报。
2、为海报设计文案文案的<第一个小标题>和<第二个小标题>之间你需要检查是否逻辑关系合理,你将通过先去生成<第二个小标题>关于景区亮点的部分,再去综合判断<第一个小标题>应该如何搭配组合更符合两个小标题的逻辑再生成<第一个小标题>。
其中,生成三类标题文案的通用性要求如下:
1、生成的<大标题>字数必须小于8个字符
2、生成的<第一个小标题>字数和<第二个小标题>字数两者都必须小8个字符
3、标题和文案都应符合中国社会主义核心价值观
接下来先开始生成<大标题>部分,由于海报是用来宣传旅游景点,生成的海报<大标题>必须使用以下8种格式之一
①地名+景点名(例如福建厦门鼓浪屿/厦门鼓浪屿);
②地名+景点名+plog
③拿捏+地名+景点名;
④地名+景点名+攻略;
⑤速通+地名+景点名
⑥推荐!+地名+景点名
⑦勇闯!+地名+景点名
⑧收藏!+地名+景点名
你需要随机挑选一种格式生成对应景点的文案但是格式除了上面8种不可以有其他任何格式同时尽量保证每一种格式出现的频率均衡。
接下来先去生成<第二个小标题><第二个小标题>文案的创作必须遵循以下原则:
请根据笔记内容和图片识别用极简的文字概括这篇笔记和图片中景点的特色亮点其中你可以参考以下词汇进行创作这段文案字数控制6-8字符以内
特色亮点可能会出现的词汇不完全举例:非遗、古建、绝佳山水、祈福圣地、研学圣地、解压天堂、中国小瑞士、秘境竹筏游等等类型词汇
接下来再去生成<第一个小标题><第一个小标题>文案的创作必须遵循以下原则:
这部分文案创作公式有5种分别为
①<受众人群画像>+<痛点词>
②<受众人群画像>
③<痛点词>
④<受众人群画像>+ | +<痛点词>
⑤<痛点词>+ | +<受众人群画像>
请你根据实际笔记内容,结合这部分文案创作公式,需要结合<受众人群画像>和<痛点词>时,必须根据<第二个小标题>的景点特征和所对应的完整笔记推文内容主旨,特征挑选对应<受众人群画像>和<痛点词>。
我给你提供受众人群画像库和痛点词库如下:
1、受众人群画像库情侣党、亲子游、合家游、银发族、亲子研学、学生党、打工人、周边游、本地人、穷游党、性价比、户外人、美食党、出片
2、痛点词库3天2夜、必去、看了都哭了、不能错过、一定要来、问爆了、超全攻略、必打卡、强推、懒人攻略、必游榜、小众打卡、狂喜等等。
你需要为每个请求至少生成{poster_num}个海报设计。请使用JSON格式输出结果结构如下
[
{
"index": 1,
"main_title": "主标题内容",
"texts": ["第一个小标题", "第二个小标题"]
},
{
"index": 2,
"main_title": "主标题内容",
"texts": ["第一个小标题", "第二个小标题"]
},
// ... 更多海报
]
确保生成的数量与用户要求的数量一致。只生成上述JSON格式内容不要有其他任何额外内容。
"""
# 提取内容文本(如果是列表内容数据)
tweet_content = ""
if isinstance(content_data_list, list):
for item in content_data_list:
if isinstance(item, dict):
title = item.get('title', '')
content = item.get('content', '')
tweet_content += f"<title>\n{title}\n</title>\n<content>\n{content}\n</content>\n\n"
elif isinstance(item, str):
tweet_content += item + "\n\n"
elif isinstance(content_data_list, str):
tweet_content = content_data_list
# 构建用户提示
if self.add_description:
user_content = f"""
以下是需要你处理的信息:
关于景点的描述:
{self.add_description}
推文内容:
{tweet_content}
请根据这些信息,生成{poster_num}个海报文案配置以JSON数组格式返回。
"""
else:
user_content = f"""
以下是需要你处理的推文内容:
{tweet_content}
请根据这些信息,生成{poster_num}个海报文案配置以JSON数组格式返回。
"""
self.logger.info(f"正在生成{poster_num}个海报文案配置")
# 创建AI_Agent实例
ai_agent = AI_Agent(
self.base_url,
self.model_name,
self.api_key,
timeout=timeout,
max_retries=max_retries,
stream_chunk_timeout=30 # 流式块超时时间
)
full_response = ""
try:
# 使用AI_Agent的non-streaming方法
self.logger.info(f"调用AI生成海报配置模型: {self.model_name}")
full_response, tokens, time_cost = ai_agent.work(
system_prompt,
user_content,
"", # 历史消息(空)
self.temperature,
self.top_p,
self.presence_penalty
)
self.logger.info(f"AI生成完成耗时: {time_cost:.2f}s, 预估令牌数: {tokens}")
if not full_response:
self.logger.warning("AI返回空响应使用备用内容")
full_response = self._generate_fallback_content(poster_num)
except Exception as e:
self.logger.exception(f"AI生成过程发生错误: {e}")
full_response = self._generate_fallback_content(poster_num)
finally:
# 确保关闭AI Agent
ai_agent.close()
return full_response
def _generate_fallback_content(self, poster_num):
"""生成备用内容当API调用失败时使用"""
self.logger.info("生成备用内容")
default_configs = []
for i in range(poster_num):
default_configs.append({
"index": i + 1,
"main_title": "",
"texts": ["", ""]
})
return json.dumps(default_configs, ensure_ascii=False)
def save_result(self, full_response, custom_output_dir=None):
"""
保存生成结果到文件
参数:
full_response: 生成的完整响应内容
custom_output_dir: 自定义输出目录(可选)
返回:
结果文件路径
"""
# 生成时间戳
date_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
output_dir = custom_output_dir or self.output_dir
try:
# 解析内容为JSON格式
parsed_data = self.split_content(full_response)
# 验证内容格式并修复
validated_data = self._validate_and_fix_data(parsed_data)
# 创建结果文件路径
result_path = os.path.join(output_dir, f"{date_time}.json")
os.makedirs(os.path.dirname(result_path), exist_ok=True)
# 保存结果到文件
with open(result_path, "w", encoding="utf-8") as f:
json.dump(validated_data, f, ensure_ascii=False, indent=4, ignore_nan=True)
self.logger.info(f"结果已保存到: {result_path}")
return result_path
except Exception as e:
self.logger.error(f"保存结果到文件时出错: {e}")
# 尝试创建一个简单的备用配置
fallback_data = [{"main_title": "", "texts": ["", ""], "index": 1}]
# 保存备用数据
result_path = os.path.join(output_dir, f"{date_time}_fallback.json")
os.makedirs(os.path.dirname(result_path), exist_ok=True)
with open(result_path, "w", encoding="utf-8") as f:
json.dump(fallback_data, f, ensure_ascii=False, indent=4, ignore_nan=True)
self.logger.info(f"出错后已保存备用数据到: {result_path}")
return result_path
def _validate_and_fix_data(self, data):
"""
验证并修复数据格式,确保符合预期结构
参数:
data: 需要验证的数据
返回:
修复后的数据
"""
fixed_data = []
# 记录原始数据格式信息
self.logger.info(f"验证和修复数据,原始数据类型: {type(data)}")
if isinstance(data, list):
self.logger.info(f"原始数据是列表,长度: {len(data)}")
if len(data) > 0:
self.logger.info(f"第一个元素类型: {type(data[0])}")
elif isinstance(data, str):
self.logger.info(f"原始数据是字符串: {data[:100]}")
else:
self.logger.info(f"原始数据是其他类型: {data}")
# 如果数据是列表
if isinstance(data, list):
for i, item in enumerate(data):
# 检查项目是否为字典
if isinstance(item, dict):
# 确保必需字段存在
fixed_item = {
"index": item.get("index", i + 1),
"main_title": item.get("main_title", ""),
"texts": item.get("texts", ["", ""])
}
# 确保texts是列表格式
if not isinstance(fixed_item["texts"], list):
if isinstance(fixed_item["texts"], str):
fixed_item["texts"] = [fixed_item["texts"], ""]
else:
fixed_item["texts"] = ["", ""]
# 限制texts最多包含两个元素
if len(fixed_item["texts"]) > 2:
fixed_item["texts"] = fixed_item["texts"][:2]
elif len(fixed_item["texts"]) < 2:
while len(fixed_item["texts"]) < 2:
fixed_item["texts"].append("")
fixed_data.append(fixed_item)
# 如果项目是字符串可能是错误格式的texts值
elif isinstance(item, str):
self.logger.warning(f"配置项 {i+1} 是字符串格式: '{item}',将转换为标准格式")
# 尝试解析字符串格式,例如"性价比|必打卡"
texts = []
if "|" in item:
texts = item.split("|")
else:
texts = [item, ""]
fixed_item = {
"index": i + 1,
"main_title": "",
"texts": texts
}
fixed_data.append(fixed_item)
else:
self.logger.warning(f"配置项 {i+1} 格式不支持: {type(item)},将使用默认值")
fixed_data.append({
"index": i + 1,
"main_title": "",
"texts": ["", ""]
})
# 如果数据是字典
elif isinstance(data, dict):
fixed_item = {
"index": data.get("index", 1),
"main_title": data.get("main_title", ""),
"texts": data.get("texts", ["", ""])
}
# 确保texts是列表格式
if not isinstance(fixed_item["texts"], list):
if isinstance(fixed_item["texts"], str):
fixed_item["texts"] = [fixed_item["texts"], ""]
else:
fixed_item["texts"] = ["", ""]
# 限制texts最多包含两个元素
if len(fixed_item["texts"]) > 2:
fixed_item["texts"] = fixed_item["texts"][:2]
elif len(fixed_item["texts"]) < 2:
while len(fixed_item["texts"]) < 2:
fixed_item["texts"].append("")
fixed_data.append(fixed_item)
# 如果数据是字符串
elif isinstance(data, str):
self.logger.warning(f"数据是字符串格式: '{data}',尝试转换为标准格式")
# 尝试解析字符串格式,例如"性价比|必打卡"
texts = []
if "|" in data:
texts = data.split("|")
else:
texts = [data, ""]
fixed_data.append({
"index": 1,
"main_title": "",
"texts": texts
})
# 如果数据是其他格式
else:
self.logger.warning(f"数据格式不支持: {type(data)},将使用默认值")
fixed_data.append({
"index": 1,
"main_title": "",
"texts": ["", ""]
})
# 确保至少有一个配置项
if not fixed_data:
fixed_data.append({
"index": 1,
"main_title": "",
"texts": ["", ""]
})
self.logger.info(f"修复后的数据: {fixed_data}")
return fixed_data
def run(self, info_directory, poster_num, content_data, system_prompt=None,
api_url="http://localhost:8000/v1", model_name="qwenQWQ", api_key="EMPTY", timeout=120):
"""
运行海报内容生成流程,并返回生成的配置数据。
参数:
info_directory: 信息目录路径列表 (e.g., ['/path/to/description.txt'])
poster_num: 需要生成的海报配置数量
content_data: 用于生成内容的文章内容(可以是字符串或字典列表)
system_prompt: 系统提示词默认为None使用内置提示词
api_url: API基础URL
model_name: 使用的模型名称
api_key: API密钥
返回:
list | dict | None: 生成的海报配置数据 (通常是列表),如果生成或解析失败则返回 None。
"""
try:
# 加载描述信息
self.load_infomation(info_directory)
# 生成海报内容
self.logger.info(f"开始生成海报内容,数量: {poster_num}")
full_response = self.generate_posters(
poster_num,
content_data,
system_prompt,
api_url,
model_name,
api_key,
timeout=timeout,
)
# 检查生成是否失败
if not isinstance(full_response, str) or not full_response.strip():
self.logger.error("海报内容生成失败或返回空响应")
return None
# 从原始响应字符串中提取JSON数据
result_data = self.split_content(full_response)
# 验证并修复数据
fixed_data = self._validate_and_fix_data(result_data)
self.logger.info(f"成功生成并修复海报配置数据,包含 {len(fixed_data) if isinstance(fixed_data, list) else 1} 个项目")
return fixed_data
except Exception as e:
self.logger.exception(f"海报内容生成过程中发生错误: {e}")
traceback.print_exc()
# 失败后创建一个默认配置
self.logger.info("创建默认海报配置数据")
default_configs = []
for i in range(poster_num):
default_configs.append({
"index": i + 1,
"main_title": "",
"texts": ["", ""]
})
return default_configs
def set_temperature(self, temperature):
"""设置温度参数"""
self.temperature = temperature
def set_top_p(self, top_p):
"""设置top_p参数"""
self.top_p = top_p
def set_presence_penalty(self, presence_penalty):
"""设置存在惩罚参数"""
self.presence_penalty = presence_penalty
def set_model_para(self, temperature, top_p, presence_penalty):
"""一次性设置所有模型参数"""
self.temperature = temperature
self.top_p = top_p
self.presence_penalty = presence_penalty