重构了content_gen, 用了ai_agent

This commit is contained in:
jinye_huang 2025-04-24 20:35:25 +08:00
parent cca622d34f
commit e8f8d52f5c
7 changed files with 738 additions and 662 deletions

View File

@ -1,704 +1,107 @@
import os
from openai import OpenAI
import pandas as pd
from datetime import datetime
import cv2
import time
import random
import json
import logging
from utils.content_generator import ContentGenerator as NewContentGenerator
# 为了向后兼容我们保留ContentGenerator类但内部使用新的实现
class ContentGenerator:
def __init__(self,
model_name="qwenQWQ",
api_base_url="http://localhost:8000/v1",
api_key="EMPTY",
output_dir="/root/autodl-tmp/poster_generate_result",
):
output_dir="/root/autodl-tmp/poster_generate_result"):
"""
初始化海报生成器
初始化海报生成器内部使用新的实现
参数:
csv_path: CSV文件路径
img_base_dir: 图片基础目录
output_dir: 输出结果保存目录
model_name: 使用的模型名称
api_base_url: API基础URL
api_key: API密钥
"""
self.output_dir = output_dir
# 创建新的实现实例
self._impl = NewContentGenerator(
output_dir=output_dir,
temperature=0.7,
top_p=0.8,
presence_penalty=1.2
)
# 存储API参数
self.model_name = model_name
self.api_base_url = api_base_url
self.api_key = api_key
# 不在初始化时创建OpenAI客户端而是在需要时临时创建
self.client = None
# 初始化数据
self.df = None
self.all_images_info = []
self.structured_prompt = ""
self.current_img_info = None
self.add_description = ""
self.temperature = 0.7
self.top_p = 0.8
self.presence_penalty = 1.2
# 设置日志
logging.basicConfig(level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
self.logger = logging.getLogger(__name__)
def load_infomation(self, info_directory_path):
"""
加载额外描述文件
参数:
info_directory_path: 信息目录路径
"""
## 读取路径下的所有文件
for path in info_directory_path:
# file_extend_path = os.path.join(self.img_base_dir, info_directory, "description.txt")
try:
with open(path, "r") as f:
self.add_description += f.read()
except:
# 为向后兼容保留的字段
self.df = None
self.all_images_info = []
self.structured_prompt = ""
self.current_img_info = None
self.add_description = ""
def _create_temp_client(self):
"""
创建临时OpenAI客户端
返回:
OpenAI客户端实例
"""
try:
import gc
# 强制垃圾回收
gc.collect()
# 创建新的客户端实例
print(f"创建临时OpenAI客户端API URL: {self.api_base_url}")
client = OpenAI(
base_url=self.api_base_url,
api_key=self.api_key
)
return client
except Exception as e:
print(f"创建OpenAI客户端失败: {str(e)}")
return None
def _close_client(self, client):
"""
关闭并清理OpenAI客户端
参数:
client: 需要关闭的客户端实例
"""
try:
# OpenAI客户端可能没有显式的close方法
# 将引用设为None让Python垃圾回收处理
client = None
import gc
gc.collect()
print("OpenAI客户端资源已释放")
except Exception as e:
print(f"关闭客户端失败: {str(e)}")
# 将方法委托给新实现
def load_infomation(self, info_directory_path):
"""加载额外描述文件"""
self._impl.load_infomation(info_directory_path)
self.add_description = self._impl.add_description
def split_content(self, content):
"""
分割结果, 返回去除
```json
```的json内容
参数:
content: 需要分割的内容
返回:
分割后的json内容
"""
try:
# 首先尝试直接解析整个内容,以防已经是干净的 JSON
try:
return json.loads(content)
except json.JSONDecodeError:
pass # 不是干净的 JSON继续处理
# 常规模式:查找 ```json 和 ``` 之间的内容
if "```json" in content:
json_str = content.split("```json")[1].split("```")[0].strip()
try:
return json.loads(json_str)
except json.JSONDecodeError as e:
print(f"常规格式解析失败: {e}, 尝试其他方法")
# 备用模式1查找连续的 { 开头和 } 结尾的部分
import re
json_pattern = r'(\[.*?\])'
json_matches = re.findall(json_pattern, content, re.DOTALL)
if json_matches:
for match in json_matches:
try:
result = json.loads(match)
if isinstance(result, list) and len(result) > 0:
return result
except:
continue
# 备用模式2查找 { 开头 和 } 结尾,并尝试解析
content = content.strip()
square_bracket_start = content.find('[')
square_bracket_end = content.rfind(']')
if square_bracket_start != -1 and square_bracket_end != -1:
potential_json = content[square_bracket_start:square_bracket_end + 1]
try:
return json.loads(potential_json)
except:
print("尝试提取方括号内容失败")
# 最后一种尝试:查找所有可能的 JSON 结构并尝试解析
json_structures = re.findall(r'({.*?})', content, re.DOTALL)
if json_structures:
items = []
for i, struct in enumerate(json_structures):
try:
item = json.loads(struct)
# 验证结构包含预期字段
if 'main_title' in item and ('texts' in item or 'index' in item):
items.append(item)
except:
continue
if items:
return items
# 都失败了,打印错误并引发异常
print(f"无法解析内容,返回原始文本: {content[:200]}...")
raise ValueError("无法从响应中提取有效的 JSON 格式")
except Exception as e:
print(f"解析内容时出错: {e}")
print(f"原始内容: {content[:200]}...") # 仅显示前200个字符
raise e
"""分割JSON内容"""
return self._impl.split_content(content)
def generate_posters(self, poster_num, tweet_content, system_prompt=None, max_retries=3):
"""
生成海报内容
参数:
poster_num: 海报数量
tweet_content: 推文内容
system_prompt: 系统提示默认为None则使用预设提示
max_retries: 最大重试次数
返回:
生成的海报内容
"""
full_response = ""
timeout = 60 # 请求超时时间(秒)
if not system_prompt:
# 使用默认系统提示词
system_prompt = """
你是一名资深海报设计师有丰富的爆款海报设计经验你现在要为旅游景点做宣传在小红书上发布大量宣传海报你的主要工作目标有2个
1你要根据我给你的图片描述和笔记推文内容设计图文匹配的海报
2为海报设计文案文案的<第一个小标题><第二个小标题>之间你需要检查是否逻辑关系合理你将通过先去生成<第二个小标题>关于景区亮点的部分再去综合判断<第一个小标题>应该如何搭配组合更符合两个小标题的逻辑再生成<第一个小标题>
其中生成三类标题文案的通用性要求如下
1生成的<大标题>字数必须小于8个字符
2生成的<第一个小标题>字数和<第二个小标题>字数两者都必须小8个字符
3标题和文案都应符合中国社会主义核心价值观
接下来先开始生成<大标题>部分由于海报是用来宣传旅游景点生成的海报<大标题>必须使用以下8种格式之一
地名景点名例如福建厦门鼓浪屿/厦门鼓浪屿
地名+景点名+plog
拿捏+地名+景点名
地名+景点名+攻略
速通+地名+景点名
推荐+地名+景点名
勇闯+地名+景点名
收藏+地名+景点名
你需要随机挑选一种格式生成对应景点的文案但是格式除了上面8种不可以有其他任何格式同时尽量保证每一种格式出现的频率均衡
接下来先去生成<第二个小标题><第二个小标题>文案的创作必须遵循以下原则
请根据笔记内容和图片识别用极简的文字概括这篇笔记和图片中景点的特色亮点其中你可以参考以下词汇进行创作这段文案字数控制6-8字符以内
特色亮点可能会出现的词汇不完全举例非遗古建绝佳山水祈福圣地研学圣地解压天堂中国小瑞士秘境竹筏游等等类型词汇
接下来再去生成<第一个小标题><第一个小标题>文案的创作必须遵循以下原则
这部分文案创作公式有5种分别为
<受众人群画像>+<痛点词>
<受众人群画像>
<痛点词>
<受众人群画像>+ | +<痛点词>
<痛点词>+ | +<受众人群画像>
请你根据实际笔记内容结合这部分文案创作公式需要结合<受众人群画像><痛点词>必须根据<第二个小标题>的景点特征和所对应的完整笔记推文内容主旨特征挑选对应<受众人群画像><痛点词>
我给你提供受众人群画像库和痛点词库如下
1受众人群画像库情侣党亲子游合家游银发族亲子研学学生党打工人周边游本地人穷游党性价比户外人美食党出片
2痛点词库3天2夜必去看了都哭了不能错过一定要来问爆了超全攻略必打卡强推懒人攻略必游榜小众打卡狂喜等等
你需要为每个请求至少生成{poster_num}个海报设计请使用JSON格式输出结果结构如下
[
{
"index": 1,
"main_title": "主标题内容",
"texts": ["第一个小标题", "第二个小标题"]
},
{
"index": 2,
"main_title": "主标题内容",
"texts": ["第一个小标题", "第二个小标题"]
},
// ... 更多海报
]
确保生成的数量与用户要求的数量一致只生成上述JSON格式内容不要有其他任何额外内容
"""
if self.add_description:
# 创建用户内容包括info信息和tweet_content
user_content = f"""
以下是需要你处理的信息
关于景点的描述:
{self.add_description}
推文内容:
{tweet_content}
请根据这些信息生成{poster_num}个海报文案配置以JSON数组格式返回
"""
else:
# 仅使用tweet_content
user_content = f"""
以下是需要你处理的推文内容:
{tweet_content}
请根据这些信息生成{poster_num}个海报文案配置以JSON数组格式返回
"""
self.logger.info(f"正在生成{poster_num}个海报文案配置")
# 创建临时客户端
temp_client = self._create_temp_client()
if temp_client:
# 重试逻辑
for retry in range(max_retries):
try:
self.logger.info(f"尝试生成内容 (尝试 {retry+1}/{max_retries})")
# 定义流式响应处理回调函数
def handle_stream_chunk(chunk, is_last=False, is_timeout=False, is_error=False, error=None):
nonlocal full_response
if chunk:
full_response += chunk
# 实时输出到控制台
print(chunk, end="", flush=True)
if is_last:
print("\n") # 输出完成后换行
if is_timeout:
print("警告: 响应流超时")
if is_error:
print(f"错误: {error}")
# 使用AI_Agent的新回调方式
from core.ai_agent import AI_Agent
ai_agent = AI_Agent(
self.api_base_url,
self.model_name,
self.api_key,
timeout=timeout,
max_retries=max_retries,
stream_chunk_timeout=30 # 流式块超时时间
)
# 使用回调方式处理流式响应
try:
full_response = ai_agent.generate_text_stream_with_callback(
"""生成海报内容"""
return self._impl.generate_posters(
poster_num,
tweet_content,
system_prompt,
user_content,
callback=handle_stream_chunk,
temperature=self.temperature,
top_p=self.top_p,
presence_penalty=self.presence_penalty
api_url=self.api_base_url,
model_name=self.model_name,
api_key=self.api_key,
timeout=60,
max_retries=max_retries
)
# 如果成功生成内容,跳出重试循环
ai_agent.close()
break
except Exception as e:
error_msg = str(e)
self.logger.error(f"AI生成错误: {error_msg}")
ai_agent.close()
# 继续重试逻辑
if retry + 1 >= max_retries:
self.logger.warning("已达到最大重试次数,使用备用方案...")
# 生成备用内容
full_response = self._generate_fallback_content(poster_num)
else:
self.logger.info(f"将在稍后重试,还剩 {max_retries - retry - 1} 次重试机会")
except Exception as e:
error_msg = str(e)
self.logger.error(f"API连接错误 (尝试 {retry+1}/{max_retries}): {error_msg}")
# 如果已经达到最大重试次数
if retry + 1 >= max_retries:
self.logger.warning("已达到最大重试次数,使用备用方案...")
# 生成备用内容(简单模板)
full_response = self._generate_fallback_content(poster_num)
else:
self.logger.info(f"将在稍后重试,还剩 {max_retries - retry - 1} 次重试机会")
# 关闭临时客户端
self._close_client(temp_client)
return full_response
def _generate_fallback_content(self, poster_num):
"""生成备用内容当API调用失败时使用"""
self.logger.info("生成备用内容")
default_configs = []
for i in range(poster_num):
default_configs.append({
"main_title": f"景点风光 {i+1}",
"texts": ["自然美景", "人文体验"]
})
return json.dumps(default_configs, ensure_ascii=False)
"""生成备用内容"""
return self._impl._generate_fallback_content(poster_num)
def save_result(self, full_response):
"""
保存生成结果到文件
参数:
full_response: 生成的完整响应内容
返回:
结果文件路径
"""
# 生成时间戳
print(full_response)
date_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
try:
# 解析内容为JSON格式
parsed_data = self.split_content(full_response)
# 验证内容格式并修复
validated_data = self._validate_and_fix_data(parsed_data)
# 创建结果文件路径
result_path = os.path.join(self.output_dir, f"{date_time}.json")
os.makedirs(os.path.dirname(result_path), exist_ok=True)
# 保存结果到文件
with open(result_path, "w", encoding="utf-8") as f:
json.dump(validated_data, f, ensure_ascii=False, indent=4)
print(f"结果已保存到: {result_path}")
return result_path
except Exception as e:
self.logger.error(f"保存结果到文件时出错: {e}")
# 尝试创建一个简单的备用配置
fallback_data = [{"main_title": "景点风光", "texts": ["自然美景", "人文体验"], "index": 1}]
# 保存备用数据
result_path = os.path.join(self.output_dir, f"{date_time}_fallback.json")
os.makedirs(os.path.dirname(result_path), exist_ok=True)
with open(result_path, "w", encoding="utf-8") as f:
json.dump(fallback_data, f, ensure_ascii=False, indent=4)
print(f"出错后已保存备用数据到: {result_path}")
return result_path
"""保存生成结果到文件"""
return self._impl.save_result(full_response)
def _validate_and_fix_data(self, data):
"""
验证并修复数据格式确保符合预期结构
参数:
data: 需要验证的数据
返回:
修复后的数据
"""
fixed_data = []
# 如果数据是列表
if isinstance(data, list):
for i, item in enumerate(data):
# 检查项目是否为字典
if isinstance(item, dict):
# 确保必需字段存在
fixed_item = {
"index": item.get("index", i + 1),
"main_title": item.get("main_title", f"景点风光 {i+1}"),
"texts": item.get("texts", ["自然美景", "人文体验"])
}
# 确保texts是列表格式
if not isinstance(fixed_item["texts"], list):
if isinstance(fixed_item["texts"], str):
fixed_item["texts"] = [fixed_item["texts"], "美景体验"]
else:
fixed_item["texts"] = ["自然美景", "人文体验"]
# 限制texts最多包含两个元素
if len(fixed_item["texts"]) > 2:
fixed_item["texts"] = fixed_item["texts"][:2]
elif len(fixed_item["texts"]) < 2:
while len(fixed_item["texts"]) < 2:
fixed_item["texts"].append("美景体验")
fixed_data.append(fixed_item)
# 如果项目是字符串可能是错误格式的texts值
elif isinstance(item, str):
self.logger.warning(f"配置项 {i+1} 是字符串格式,将转换为标准格式")
fixed_item = {
"index": i + 1,
"main_title": f"景点风光 {i+1}",
"texts": [item, "美景体验"]
}
fixed_data.append(fixed_item)
else:
self.logger.warning(f"配置项 {i+1} 格式不支持: {type(item)},将使用默认值")
fixed_data.append({
"index": i + 1,
"main_title": f"景点风光 {i+1}",
"texts": ["自然美景", "人文体验"]
})
# 如果数据是字典
elif isinstance(data, dict):
fixed_item = {
"index": data.get("index", 1),
"main_title": data.get("main_title", "景点风光"),
"texts": data.get("texts", ["自然美景", "人文体验"])
}
# 确保texts是列表格式
if not isinstance(fixed_item["texts"], list):
if isinstance(fixed_item["texts"], str):
fixed_item["texts"] = [fixed_item["texts"], "美景体验"]
else:
fixed_item["texts"] = ["自然美景", "人文体验"]
# 限制texts最多包含两个元素
if len(fixed_item["texts"]) > 2:
fixed_item["texts"] = fixed_item["texts"][:2]
elif len(fixed_item["texts"]) < 2:
while len(fixed_item["texts"]) < 2:
fixed_item["texts"].append("美景体验")
fixed_data.append(fixed_item)
# 如果数据是字符串或其他格式
else:
self.logger.warning(f"数据格式不支持: {type(data)},将使用默认值")
fixed_data.append({
"index": 1,
"main_title": "景点风光",
"texts": ["自然美景", "人文体验"]
})
# 确保至少有一个配置项
if not fixed_data:
fixed_data.append({
"index": 1,
"main_title": "景点风光",
"texts": ["自然美景", "人文体验"]
})
return fixed_data
"""验证并修复数据格式"""
return self._impl._validate_and_fix_data(data)
def run(self, info_directory, poster_num, tweet_content, system_prompt=None):
"""
运行海报内容生成流程并返回生成的配置数据
参数:
info_directory: 信息目录路径列表 (e.g., ['/path/to/description.txt'])
poster_num: 需要生成的海报配置数量
tweet_content: 用于生成内容的推文/文章内容
返回:
list | dict | None: 生成的海报配置数据 (通常是列表)如果生成或解析失败则返回 None
"""
self.load_infomation(info_directory)
# Generate the raw string response from AI
full_response = self.generate_posters(poster_num, tweet_content, system_prompt)
# Check if generation failed (indicated by return code 404 or other markers)
if full_response == 404 or not isinstance(full_response, str) or not full_response.strip():
logging.error("Poster content generation failed or returned empty response.")
return None
# Extract the JSON data from the raw response string
try:
result_data = self.split_content(full_response) # This should return the list/dict
# 验证并修复结果数据格式
fixed_data = []
# 如果结果是列表,检查每个项目
if isinstance(result_data, list):
for i, item in enumerate(result_data):
# 如果项目是字典并且有required_fields按原样添加或修复
if isinstance(item, dict):
# 检查并确保必需字段存在
if 'main_title' not in item:
item['main_title'] = f"景点标题 {i+1}"
logging.warning(f"配置项 {i+1} 缺少 main_title 字段,已添加默认值")
if 'texts' not in item:
item['texts'] = ["景点特色", "游玩体验"]
logging.warning(f"配置项 {i+1} 缺少 texts 字段,已添加默认值")
if 'index' not in item:
item['index'] = i + 1
logging.warning(f"配置项 {i+1} 缺少 index 字段,已添加默认值")
fixed_data.append(item)
# 如果项目是字符串可能是错误格式的texts值
elif isinstance(item, str):
logging.warning(f"配置项 {i+1} 是字符串格式,将转换为标准格式")
fixed_item = {
"index": i + 1,
"main_title": f"景点风光 {i+1}",
"texts": [item, "美景体验"]
}
fixed_data.append(fixed_item)
else:
logging.warning(f"配置项 {i+1} 格式不支持: {type(item)},将使用默认值")
fixed_data.append({
"index": i + 1,
"main_title": f"景点风光 {i+1}",
"texts": ["自然美景", "人文体验"]
})
# 如果处理后的列表为空(极端情况),则使用默认值
if not fixed_data:
logging.warning("处理后的配置列表为空,使用默认值")
for i in range(poster_num):
fixed_data.append({
"index": i + 1,
"main_title": f"景点风光 {i+1}",
"texts": ["自然美景", "人文体验"]
})
logging.info(f"成功生成并修复海报配置数据,包含 {len(fixed_data)} 个项目")
return fixed_data
# 如果结果是单个字典(不常见但可能),将其转换为列表
elif isinstance(result_data, dict):
logging.warning(f"生成的配置数据是单个字典格式,将转换为列表")
# 检查并确保必需字段存在
if 'main_title' not in result_data:
result_data['main_title'] = "景点风光"
if 'texts' not in result_data:
result_data['texts'] = ["自然美景", "人文体验"]
if 'index' not in result_data:
result_data['index'] = 1
fixed_data = [result_data]
return fixed_data
# 如果结果是其他格式(如字符串),创建默认配置
else:
logging.warning(f"生成的配置数据格式不支持: {type(result_data)},将使用默认值")
for i in range(poster_num):
fixed_data.append({
"index": i + 1,
"main_title": f"景点风光 {i+1}",
"texts": ["自然美景", "人文体验"]
})
return fixed_data
except Exception as e:
logging.exception(f"Failed to parse JSON from AI response in ContentGenerator: {e}\nRaw Response:\n{full_response[:500]}...") # Log error and partial response
# 失败后创建一个默认配置
logging.info("创建默认海报配置数据")
default_configs = []
for i in range(poster_num):
default_configs.append({
"index": i + 1,
"main_title": f"景点风光 {i+1}",
"texts": ["自然美景", "人文体验"]
})
return default_configs
"""运行海报内容生成流程"""
return self._impl.run(
info_directory,
poster_num,
tweet_content,
system_prompt,
api_url=self.api_base_url,
model_name=self.model_name,
api_key=self.api_key
)
def set_temperature(self, temperature):
self.temperature = temperature
"""设置温度参数"""
self._impl.set_temperature(temperature)
def set_top_p(self, top_p):
self.top_p = top_p
"""设置top_p参数"""
self._impl.set_top_p(top_p)
def set_presence_penalty(self, presence_penalty):
self.presence_penalty = presence_penalty
"""设置存在惩罚参数"""
self._impl.set_presence_penalty(presence_penalty)
def set_model_para(self, temperature, top_p, presence_penalty):
self.temperature = temperature
self.top_p = top_p
self.presence_penalty = presence_penalty
# def main():
# # 配置参数
# info_directory = [
# "/root/autodl-tmp/sanming_img/相机/甘露寺/description.txt"
# ] # 信息目录
# poster_num = 4 # 海报数量
# # 推文内容
# tweet_content = """<title>
# 🌿清明遛娃天花板!悬空古寺+非遗探秘
# </title>
# <content>
# 清明假期带娃哪里玩?泰宁甘露寺藏着明代建筑奇迹!一柱擎天的悬空阁楼+状元祈福传说,让孩子边玩边涨知识✨
# 🎒行程亮点:
# ✅ 安全科普第一站:讲解"一柱插地"千年不倒的秘密,用乐高积木模型让孩子理解力学原理
# ✅ 文化沉浸体验:穿汉服听"叶状元还愿建寺"故事触摸3.38米粗的"状元柱"许愿
# ✅ 自然探索路线:连接金湖栈道徒步,观察丹霞地貌与古建筑的巧妙融合
# 📌实用攻略:
# 📍位置:福建省三明市泰宁县金湖西路(导航搜"甘露岩寺"
# 🕒最佳时段上午10点前抵达避开人流下午可衔接参观明清园80元/人)
# ⚠️注意事项:悬空栈道设置儿童安全绳租赁点,建议穿防滑鞋
# 💡亲子彩蛋:
# 1⃣ 在"右鼓左钟"景观区玩声音实验,敲击不同岩石听回声差异
# 2⃣ 领取任务卡完成"寻找建筑中的T形拱"小游戏,集章兑换非遗木雕书签
# 3⃣ 结合清明节俗,用竹简模板书写祈福语系在古松枝头
# 周边推荐:游览完可直奔尚书第明代古民居群,对比不同时期建筑特色,晚餐推荐尝泰宁特色"灯盏糕",亲子套票更划算!
# 清明带着孩子来场穿越850年的建筑探险把课本里的力学知识变成触手可及的历史课堂🌸
# #清明节周边游 #亲子科普游 #福建遛娃 #泰宁旅行攻略 #建筑启蒙
# </content>
# """
# # 创建海报生成器
# generator = ContentGenerator()
# # 运行生成流程
# generator.run(info_directory, poster_num, tweet_content)
# if __name__ == "__main__":
# main()
"""一次性设置所有模型参数"""
self._impl.set_model_para(temperature, top_p, presence_penalty)

View File

@ -0,0 +1,26 @@
[
{
"index": 1,
"main_title": "泰宁甘露寺攻略",
"texts": [
"悬空古建探秘",
"亲子研学首选"
]
},
{
"index": 2,
"main_title": "推荐!泰宁甘露寺",
"texts": [
"千年古刹秘境",
"自驾打卡胜地"
]
},
{
"index": 3,
"main_title": "泰宁甘露寺",
"texts": [
"非遗文化沉浸",
"状元祈福之旅"
]
}
]

118
test_content_generator.py Normal file
View File

@ -0,0 +1,118 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# 导入新旧实现
from core.contentGen import ContentGenerator as OldContentGenerator
from utils.content_generator import ContentGenerator as NewContentGenerator
def test_both_implementations():
"""测试新旧实现的兼容性"""
# 创建测试文本内容
test_content = """<title>
🌿清明遛娃天花板悬空古寺+非遗探秘
</title>
<content>
清明假期带娃哪里玩泰宁甘露寺藏着明代建筑奇迹一柱擎天的悬空阁楼+状元祈福传说让孩子边玩边涨知识
🎒行程亮点
安全科普第一站讲解"一柱插地"千年不倒的秘密用乐高积木模型让孩子理解力学原理
文化沉浸体验穿汉服听"叶状元还愿建寺"故事触摸3.38米粗的"状元柱"许愿
</content>
"""
print("=" * 50)
print("测试新旧ContentGenerator实现")
print("=" * 50)
# 创建输出目录
import os
os.makedirs("./test_output", exist_ok=True)
# 测试参数
api_url = "http://localhost:8000/v1" # 替换为实际URL
model_name = "qwenQWQ" # 替换为实际模型
api_key = "EMPTY" # 替换为实际密钥
poster_num = 2 # 生成2个海报配置
# 1. 测试旧实现(现在委托给新实现)
print("\n1. 测试旧实现 (core.contentGen.ContentGenerator)")
old_generator = OldContentGenerator(
model_name=model_name,
api_base_url=api_url,
api_key=api_key,
output_dir="./test_output/old"
)
# 设置生成参数
old_generator.set_model_para(0.7, 0.8, 1.2)
# 运行生成
print("正在使用旧实现生成海报配置...")
old_result = old_generator.run([], poster_num, test_content)
if old_result:
print(f"旧实现成功生成 {len(old_result)} 个配置项")
for i, config in enumerate(old_result):
print(f" 配置 {i+1}: {config.get('main_title')} - {config.get('texts')}")
else:
print("旧实现生成失败")
# 2. 测试新实现
print("\n2. 测试新实现 (utils.content_generator.ContentGenerator)")
new_generator = NewContentGenerator(
output_dir="./test_output/new",
temperature=0.7,
top_p=0.8,
presence_penalty=1.2
)
# 运行生成
print("正在使用新实现生成海报配置...")
new_result = new_generator.run(
[],
poster_num,
test_content,
api_url=api_url,
model_name=model_name,
api_key=api_key
)
if new_result:
print(f"新实现成功生成 {len(new_result)} 个配置项")
for i, config in enumerate(new_result):
print(f" 配置 {i+1}: {config.get('main_title')} - {config.get('texts')}")
else:
print("新实现生成失败")
print("\n3. 比较结果")
if old_result and new_result:
import json
# 格式化输出结果比较
print("\n旧实现结果:")
print(json.dumps(old_result, ensure_ascii=False, indent=2))
print("\n新实现结果:")
print(json.dumps(new_result, ensure_ascii=False, indent=2))
# 结构比较
old_format = [type(config) for config in old_result]
new_format = [type(config) for config in new_result]
print(f"\n结构比较: 旧实现: {old_format}, 新实现: {new_format}")
print(f"数量比较: 旧实现: {len(old_result)}, 新实现: {len(new_result)}")
# 检查所有必要的字段
field_check = True
for result in [old_result, new_result]:
for config in result:
if not all(key in config for key in ["index", "main_title", "texts"]):
field_check = False
break
print(f"字段检查: {'通过' if field_check else '失败'}")
print("\n测试完成!")
if __name__ == "__main__":
test_both_implementations()

Binary file not shown.

529
utils/content_generator.py Normal file
View File

@ -0,0 +1,529 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import json
import logging
import traceback
from datetime import datetime
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from core.ai_agent import AI_Agent
class ContentGenerator:
"""
海报文本内容生成器
使用AI_Agent代替直接管理OpenAI客户端简化代码结构
"""
def __init__(self,
output_dir="/root/autodl-tmp/poster_generate_result",
temperature=0.7,
top_p=0.8,
presence_penalty=1.2):
"""
初始化内容生成器
参数:
output_dir: 输出结果保存目录
temperature: 生成温度参数
top_p: top_p参数
presence_penalty: 惩罚参数
"""
self.output_dir = output_dir
self.temperature = temperature
self.top_p = top_p
self.presence_penalty = presence_penalty
self.add_description = ""
# 设置日志
logging.basicConfig(level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
self.logger = logging.getLogger(__name__)
def load_infomation(self, info_directory_path):
"""
加载额外描述文件
参数:
info_directory_path: 信息目录路径列表
"""
self.add_description = "" # 重置描述文本
for path in info_directory_path:
try:
with open(path, "r", encoding="utf-8") as f:
self.add_description += f.read()
self.logger.info(f"成功加载描述文件: {path}")
except Exception as e:
self.logger.warning(f"加载描述文件失败: {path}, 错误: {e}")
self.add_description = ""
def split_content(self, content):
"""
分割结果, 返回去除
```json
```的json内容
参数:
content: 需要分割的内容
返回:
分割后的json内容
"""
try:
# 首先尝试直接解析整个内容,以防已经是干净的 JSON
try:
return json.loads(content)
except json.JSONDecodeError:
pass # 不是干净的 JSON继续处理
# 常规模式:查找 ```json 和 ``` 之间的内容
if "```json" in content:
json_str = content.split("```json")[1].split("```")[0].strip()
try:
return json.loads(json_str)
except json.JSONDecodeError as e:
self.logger.warning(f"常规格式解析失败: {e}, 尝试其他方法")
# 备用模式1查找连续的 [ 开头和 ] 结尾的部分
import re
json_pattern = r'(\[.*?\])'
json_matches = re.findall(json_pattern, content, re.DOTALL)
if json_matches:
for match in json_matches:
try:
result = json.loads(match)
if isinstance(result, list) and len(result) > 0:
return result
except:
continue
# 备用模式2查找 [ 开头 和 ] 结尾,并尝试解析
content = content.strip()
square_bracket_start = content.find('[')
square_bracket_end = content.rfind(']')
if square_bracket_start != -1 and square_bracket_end != -1:
potential_json = content[square_bracket_start:square_bracket_end + 1]
try:
return json.loads(potential_json)
except:
self.logger.warning("尝试提取方括号内容失败")
# 最后一种尝试:查找所有可能的 JSON 结构并尝试解析
json_structures = re.findall(r'({.*?})', content, re.DOTALL)
if json_structures:
items = []
for i, struct in enumerate(json_structures):
try:
item = json.loads(struct)
# 验证结构包含预期字段
if 'main_title' in item and ('texts' in item or 'index' in item):
items.append(item)
except:
continue
if items:
return items
# 都失败了,打印错误并引发异常
self.logger.error(f"无法解析内容,返回原始文本: {content[:200]}...")
raise ValueError("无法从响应中提取有效的 JSON 格式")
except Exception as e:
self.logger.error(f"解析内容时出错: {e}")
self.logger.debug(f"原始内容: {content[:200]}...") # 仅显示前200个字符
raise e
def generate_posters(self,
poster_num,
content_data_list,
system_prompt=None,
api_url="http://localhost:8000/v1",
model_name="qwenQWQ",
api_key="EMPTY",
timeout=60,
max_retries=3):
"""
生成海报内容
参数:
poster_num: 海报数量
content_data_list: 内容数据列表字典或字符串
system_prompt: 系统提示默认为None则使用预设提示
api_url: API基础URL
model_name: 使用的模型名称
api_key: API密钥
timeout: 请求超时时间
max_retries: 最大重试次数
返回:
生成的海报内容
"""
# 构建默认系统提示词
if not system_prompt:
system_prompt = """
你是一名资深海报设计师有丰富的爆款海报设计经验你现在要为旅游景点做宣传在小红书上发布大量宣传海报你的主要工作目标有2个
1你要根据我给你的图片描述和笔记推文内容设计图文匹配的海报
2为海报设计文案文案的<第一个小标题><第二个小标题>之间你需要检查是否逻辑关系合理你将通过先去生成<第二个小标题>关于景区亮点的部分再去综合判断<第一个小标题>应该如何搭配组合更符合两个小标题的逻辑再生成<第一个小标题>
其中生成三类标题文案的通用性要求如下
1生成的<大标题>字数必须小于8个字符
2生成的<第一个小标题>字数和<第二个小标题>字数两者都必须小8个字符
3标题和文案都应符合中国社会主义核心价值观
接下来先开始生成<大标题>部分由于海报是用来宣传旅游景点生成的海报<大标题>必须使用以下8种格式之一
地名景点名例如福建厦门鼓浪屿/厦门鼓浪屿
地名+景点名+plog
拿捏+地名+景点名
地名+景点名+攻略
速通+地名+景点名
推荐+地名+景点名
勇闯+地名+景点名
收藏+地名+景点名
你需要随机挑选一种格式生成对应景点的文案但是格式除了上面8种不可以有其他任何格式同时尽量保证每一种格式出现的频率均衡
接下来先去生成<第二个小标题><第二个小标题>文案的创作必须遵循以下原则
请根据笔记内容和图片识别用极简的文字概括这篇笔记和图片中景点的特色亮点其中你可以参考以下词汇进行创作这段文案字数控制6-8字符以内
特色亮点可能会出现的词汇不完全举例非遗古建绝佳山水祈福圣地研学圣地解压天堂中国小瑞士秘境竹筏游等等类型词汇
接下来再去生成<第一个小标题><第一个小标题>文案的创作必须遵循以下原则
这部分文案创作公式有5种分别为
<受众人群画像>+<痛点词>
<受众人群画像>
<痛点词>
<受众人群画像>+ | +<痛点词>
<痛点词>+ | +<受众人群画像>
请你根据实际笔记内容结合这部分文案创作公式需要结合<受众人群画像><痛点词>必须根据<第二个小标题>的景点特征和所对应的完整笔记推文内容主旨特征挑选对应<受众人群画像><痛点词>
我给你提供受众人群画像库和痛点词库如下
1受众人群画像库情侣党亲子游合家游银发族亲子研学学生党打工人周边游本地人穷游党性价比户外人美食党出片
2痛点词库3天2夜必去看了都哭了不能错过一定要来问爆了超全攻略必打卡强推懒人攻略必游榜小众打卡狂喜等等
你需要为每个请求至少生成{poster_num}个海报设计请使用JSON格式输出结果结构如下
[
{
"index": 1,
"main_title": "主标题内容",
"texts": ["第一个小标题", "第二个小标题"]
},
{
"index": 2,
"main_title": "主标题内容",
"texts": ["第一个小标题", "第二个小标题"]
},
// ... 更多海报
]
确保生成的数量与用户要求的数量一致只生成上述JSON格式内容不要有其他任何额外内容
"""
# 提取内容文本(如果是列表内容数据)
tweet_content = ""
if isinstance(content_data_list, list):
for item in content_data_list:
if isinstance(item, dict):
title = item.get('title', '')
content = item.get('content', '')
tweet_content += f"<title>\n{title}\n</title>\n<content>\n{content}\n</content>\n\n"
elif isinstance(item, str):
tweet_content += item + "\n\n"
elif isinstance(content_data_list, str):
tweet_content = content_data_list
# 构建用户提示
if self.add_description:
user_content = f"""
以下是需要你处理的信息
关于景点的描述:
{self.add_description}
推文内容:
{tweet_content}
请根据这些信息生成{poster_num}个海报文案配置以JSON数组格式返回
"""
else:
user_content = f"""
以下是需要你处理的推文内容:
{tweet_content}
请根据这些信息生成{poster_num}个海报文案配置以JSON数组格式返回
"""
self.logger.info(f"正在生成{poster_num}个海报文案配置")
# 创建AI_Agent实例
ai_agent = AI_Agent(
api_url,
model_name,
api_key,
timeout=timeout,
max_retries=max_retries,
stream_chunk_timeout=30 # 流式块超时时间
)
full_response = ""
try:
# 使用AI_Agent的non-streaming方法
self.logger.info(f"调用AI生成海报配置模型: {model_name}")
full_response, tokens, time_cost = ai_agent.work(
system_prompt,
user_content,
"", # 历史消息(空)
self.temperature,
self.top_p,
self.presence_penalty
)
self.logger.info(f"AI生成完成耗时: {time_cost:.2f}s, 预估令牌数: {tokens}")
if not full_response:
self.logger.warning("AI返回空响应使用备用内容")
full_response = self._generate_fallback_content(poster_num)
except Exception as e:
self.logger.exception(f"AI生成过程发生错误: {e}")
full_response = self._generate_fallback_content(poster_num)
finally:
# 确保关闭AI Agent
ai_agent.close()
return full_response
def _generate_fallback_content(self, poster_num):
"""生成备用内容当API调用失败时使用"""
self.logger.info("生成备用内容")
default_configs = []
for i in range(poster_num):
default_configs.append({
"index": i + 1,
"main_title": f"景点风光 {i+1}",
"texts": ["自然美景", "人文体验"]
})
return json.dumps(default_configs, ensure_ascii=False)
def save_result(self, full_response, custom_output_dir=None):
"""
保存生成结果到文件
参数:
full_response: 生成的完整响应内容
custom_output_dir: 自定义输出目录可选
返回:
结果文件路径
"""
# 生成时间戳
date_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
output_dir = custom_output_dir or self.output_dir
try:
# 解析内容为JSON格式
parsed_data = self.split_content(full_response)
# 验证内容格式并修复
validated_data = self._validate_and_fix_data(parsed_data)
# 创建结果文件路径
result_path = os.path.join(output_dir, f"{date_time}.json")
os.makedirs(os.path.dirname(result_path), exist_ok=True)
# 保存结果到文件
with open(result_path, "w", encoding="utf-8") as f:
json.dump(validated_data, f, ensure_ascii=False, indent=4)
self.logger.info(f"结果已保存到: {result_path}")
return result_path
except Exception as e:
self.logger.error(f"保存结果到文件时出错: {e}")
# 尝试创建一个简单的备用配置
fallback_data = [{"main_title": "景点风光", "texts": ["自然美景", "人文体验"], "index": 1}]
# 保存备用数据
result_path = os.path.join(output_dir, f"{date_time}_fallback.json")
os.makedirs(os.path.dirname(result_path), exist_ok=True)
with open(result_path, "w", encoding="utf-8") as f:
json.dump(fallback_data, f, ensure_ascii=False, indent=4)
self.logger.info(f"出错后已保存备用数据到: {result_path}")
return result_path
def _validate_and_fix_data(self, data):
"""
验证并修复数据格式确保符合预期结构
参数:
data: 需要验证的数据
返回:
修复后的数据
"""
fixed_data = []
# 如果数据是列表
if isinstance(data, list):
for i, item in enumerate(data):
# 检查项目是否为字典
if isinstance(item, dict):
# 确保必需字段存在
fixed_item = {
"index": item.get("index", i + 1),
"main_title": item.get("main_title", f"景点风光 {i+1}"),
"texts": item.get("texts", ["自然美景", "人文体验"])
}
# 确保texts是列表格式
if not isinstance(fixed_item["texts"], list):
if isinstance(fixed_item["texts"], str):
fixed_item["texts"] = [fixed_item["texts"], "美景体验"]
else:
fixed_item["texts"] = ["自然美景", "人文体验"]
# 限制texts最多包含两个元素
if len(fixed_item["texts"]) > 2:
fixed_item["texts"] = fixed_item["texts"][:2]
elif len(fixed_item["texts"]) < 2:
while len(fixed_item["texts"]) < 2:
fixed_item["texts"].append("美景体验")
fixed_data.append(fixed_item)
# 如果项目是字符串可能是错误格式的texts值
elif isinstance(item, str):
self.logger.warning(f"配置项 {i+1} 是字符串格式,将转换为标准格式")
fixed_item = {
"index": i + 1,
"main_title": f"景点风光 {i+1}",
"texts": [item, "美景体验"]
}
fixed_data.append(fixed_item)
else:
self.logger.warning(f"配置项 {i+1} 格式不支持: {type(item)},将使用默认值")
fixed_data.append({
"index": i + 1,
"main_title": f"景点风光 {i+1}",
"texts": ["自然美景", "人文体验"]
})
# 如果数据是字典
elif isinstance(data, dict):
fixed_item = {
"index": data.get("index", 1),
"main_title": data.get("main_title", "景点风光"),
"texts": data.get("texts", ["自然美景", "人文体验"])
}
# 确保texts是列表格式
if not isinstance(fixed_item["texts"], list):
if isinstance(fixed_item["texts"], str):
fixed_item["texts"] = [fixed_item["texts"], "美景体验"]
else:
fixed_item["texts"] = ["自然美景", "人文体验"]
# 限制texts最多包含两个元素
if len(fixed_item["texts"]) > 2:
fixed_item["texts"] = fixed_item["texts"][:2]
elif len(fixed_item["texts"]) < 2:
while len(fixed_item["texts"]) < 2:
fixed_item["texts"].append("美景体验")
fixed_data.append(fixed_item)
# 如果数据是字符串或其他格式
else:
self.logger.warning(f"数据格式不支持: {type(data)},将使用默认值")
fixed_data.append({
"index": 1,
"main_title": "景点风光",
"texts": ["自然美景", "人文体验"]
})
# 确保至少有一个配置项
if not fixed_data:
fixed_data.append({
"index": 1,
"main_title": "景点风光",
"texts": ["自然美景", "人文体验"]
})
return fixed_data
def run(self, info_directory, poster_num, content_data, system_prompt=None,
api_url="http://localhost:8000/v1", model_name="qwenQWQ", api_key="EMPTY"):
"""
运行海报内容生成流程并返回生成的配置数据
参数:
info_directory: 信息目录路径列表 (e.g., ['/path/to/description.txt'])
poster_num: 需要生成的海报配置数量
content_data: 用于生成内容的文章内容可以是字符串或字典列表
system_prompt: 系统提示词默认为None使用内置提示词
api_url: API基础URL
model_name: 使用的模型名称
api_key: API密钥
返回:
list | dict | None: 生成的海报配置数据 (通常是列表)如果生成或解析失败则返回 None
"""
try:
# 加载描述信息
self.load_infomation(info_directory)
# 生成海报内容
self.logger.info(f"开始生成海报内容,数量: {poster_num}")
full_response = self.generate_posters(
poster_num,
content_data,
system_prompt,
api_url,
model_name,
api_key
)
# 检查生成是否失败
if not isinstance(full_response, str) or not full_response.strip():
self.logger.error("海报内容生成失败或返回空响应")
return None
# 从原始响应字符串中提取JSON数据
result_data = self.split_content(full_response)
# 验证并修复数据
fixed_data = self._validate_and_fix_data(result_data)
self.logger.info(f"成功生成并修复海报配置数据,包含 {len(fixed_data)} 个项目")
return fixed_data
except Exception as e:
self.logger.exception(f"海报内容生成过程中发生错误: {e}")
traceback.print_exc()
# 失败后创建一个默认配置
self.logger.info("创建默认海报配置数据")
default_configs = []
for i in range(poster_num):
default_configs.append({
"index": i + 1,
"main_title": f"景点风光 {i+1}",
"texts": ["自然美景", "人文体验"]
})
return default_configs
def set_temperature(self, temperature):
"""设置温度参数"""
self.temperature = temperature
def set_top_p(self, top_p):
"""设置top_p参数"""
self.top_p = top_p
def set_presence_penalty(self, presence_penalty):
"""设置存在惩罚参数"""
self.presence_penalty = presence_penalty
def set_model_para(self, temperature, top_p, presence_penalty):
"""一次性设置所有模型参数"""
self.temperature = temperature
self.top_p = top_p
self.presence_penalty = presence_penalty