TravelContentCreator/core/contentGen.py

705 lines
30 KiB
Python
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
from openai import OpenAI
import pandas as pd
from datetime import datetime
import cv2
import time
import random
import json
import logging
class ContentGenerator:
def __init__(self,
model_name="qwenQWQ",
api_base_url="http://localhost:8000/v1",
api_key="EMPTY",
output_dir="/root/autodl-tmp/poster_generate_result",
):
"""
初始化海报生成器
参数:
csv_path: CSV文件路径
img_base_dir: 图片基础目录
output_dir: 输出结果保存目录
model_name: 使用的模型名称
api_base_url: API基础URL
api_key: API密钥
"""
self.output_dir = output_dir
self.model_name = model_name
self.api_base_url = api_base_url
self.api_key = api_key
# 不在初始化时创建OpenAI客户端而是在需要时临时创建
self.client = None
# 初始化数据
self.df = None
self.all_images_info = []
self.structured_prompt = ""
self.current_img_info = None
self.add_description = ""
self.temperature = 0.7
self.top_p = 0.8
self.presence_penalty = 1.2
# 设置日志
logging.basicConfig(level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
self.logger = logging.getLogger(__name__)
def load_infomation(self, info_directory_path):
"""
加载额外描述文件
参数:
info_directory_path: 信息目录路径
"""
## 读取路径下的所有文件
for path in info_directory_path:
# file_extend_path = os.path.join(self.img_base_dir, info_directory, "description.txt")
try:
with open(path, "r") as f:
self.add_description += f.read()
except:
self.add_description = ""
def _create_temp_client(self):
"""
创建临时OpenAI客户端
返回:
OpenAI客户端实例
"""
try:
import gc
# 强制垃圾回收
gc.collect()
# 创建新的客户端实例
print(f"创建临时OpenAI客户端API URL: {self.api_base_url}")
client = OpenAI(
base_url=self.api_base_url,
api_key=self.api_key
)
return client
except Exception as e:
print(f"创建OpenAI客户端失败: {str(e)}")
return None
def _close_client(self, client):
"""
关闭并清理OpenAI客户端
参数:
client: 需要关闭的客户端实例
"""
try:
# OpenAI客户端可能没有显式的close方法
# 将引用设为None让Python垃圾回收处理
client = None
import gc
gc.collect()
print("OpenAI客户端资源已释放")
except Exception as e:
print(f"关闭客户端失败: {str(e)}")
def split_content(self, content):
"""
分割结果, 返回去除
```json
```的json内容
参数:
content: 需要分割的内容
返回:
分割后的json内容
"""
try:
# 首先尝试直接解析整个内容,以防已经是干净的 JSON
try:
return json.loads(content)
except json.JSONDecodeError:
pass # 不是干净的 JSON继续处理
# 常规模式:查找 ```json 和 ``` 之间的内容
if "```json" in content:
json_str = content.split("```json")[1].split("```")[0].strip()
try:
return json.loads(json_str)
except json.JSONDecodeError as e:
print(f"常规格式解析失败: {e}, 尝试其他方法")
# 备用模式1查找连续的 { 开头和 } 结尾的部分
import re
json_pattern = r'(\[.*?\])'
json_matches = re.findall(json_pattern, content, re.DOTALL)
if json_matches:
for match in json_matches:
try:
result = json.loads(match)
if isinstance(result, list) and len(result) > 0:
return result
except:
continue
# 备用模式2查找 { 开头 和 } 结尾,并尝试解析
content = content.strip()
square_bracket_start = content.find('[')
square_bracket_end = content.rfind(']')
if square_bracket_start != -1 and square_bracket_end != -1:
potential_json = content[square_bracket_start:square_bracket_end + 1]
try:
return json.loads(potential_json)
except:
print("尝试提取方括号内容失败")
# 最后一种尝试:查找所有可能的 JSON 结构并尝试解析
json_structures = re.findall(r'({.*?})', content, re.DOTALL)
if json_structures:
items = []
for i, struct in enumerate(json_structures):
try:
item = json.loads(struct)
# 验证结构包含预期字段
if 'main_title' in item and ('texts' in item or 'index' in item):
items.append(item)
except:
continue
if items:
return items
# 都失败了,打印错误并引发异常
print(f"无法解析内容,返回原始文本: {content[:200]}...")
raise ValueError("无法从响应中提取有效的 JSON 格式")
except Exception as e:
print(f"解析内容时出错: {e}")
print(f"原始内容: {content[:200]}...") # 仅显示前200个字符
raise e
def generate_posters(self, poster_num, tweet_content, system_prompt=None, max_retries=3):
"""
生成海报内容
参数:
poster_num: 海报数量
tweet_content: 推文内容
system_prompt: 系统提示默认为None则使用预设提示
max_retries: 最大重试次数
返回:
生成的海报内容
"""
full_response = ""
timeout = 60 # 请求超时时间(秒)
if not system_prompt:
# 使用默认系统提示词
system_prompt = """
你是一名资深海报设计师有丰富的爆款海报设计经验你现在要为旅游景点做宣传在小红书上发布大量宣传海报。你的主要工作目标有2个
1、你要根据我给你的图片描述和笔记推文内容设计图文匹配的海报。
2、为海报设计文案文案的<第一个小标题>和<第二个小标题>之间你需要检查是否逻辑关系合理,你将通过先去生成<第二个小标题>关于景区亮点的部分,再去综合判断<第一个小标题>应该如何搭配组合更符合两个小标题的逻辑再生成<第一个小标题>。
其中,生成三类标题文案的通用性要求如下:
1、生成的<大标题>字数必须小于8个字符
2、生成的<第一个小标题>字数和<第二个小标题>字数两者都必须小8个字符
3、标题和文案都应符合中国社会主义核心价值观
接下来先开始生成<大标题>部分,由于海报是用来宣传旅游景点,生成的海报<大标题>必须使用以下8种格式之一
①地名+景点名(例如福建厦门鼓浪屿/厦门鼓浪屿);
②地名+景点名+plog
③拿捏+地名+景点名;
④地名+景点名+攻略;
⑤速通+地名+景点名
⑥推荐!+地名+景点名
⑦勇闯!+地名+景点名
⑧收藏!+地名+景点名
你需要随机挑选一种格式生成对应景点的文案但是格式除了上面8种不可以有其他任何格式同时尽量保证每一种格式出现的频率均衡。
接下来先去生成<第二个小标题><第二个小标题>文案的创作必须遵循以下原则:
请根据笔记内容和图片识别用极简的文字概括这篇笔记和图片中景点的特色亮点其中你可以参考以下词汇进行创作这段文案字数控制6-8字符以内
特色亮点可能会出现的词汇不完全举例:非遗、古建、绝佳山水、祈福圣地、研学圣地、解压天堂、中国小瑞士、秘境竹筏游等等类型词汇
接下来再去生成<第一个小标题><第一个小标题>文案的创作必须遵循以下原则:
这部分文案创作公式有5种分别为
①<受众人群画像>+<痛点词>
②<受众人群画像>
③<痛点词>
④<受众人群画像>+ | +<痛点词>
⑤<痛点词>+ | +<受众人群画像>
请你根据实际笔记内容,结合这部分文案创作公式,需要结合<受众人群画像>和<痛点词>时,必须根据<第二个小标题>的景点特征和所对应的完整笔记推文内容主旨,特征挑选对应<受众人群画像>和<痛点词>。
我给你提供受众人群画像库和痛点词库如下:
1、受众人群画像库情侣党、亲子游、合家游、银发族、亲子研学、学生党、打工人、周边游、本地人、穷游党、性价比、户外人、美食党、出片
2、痛点词库3天2夜、必去、看了都哭了、不能错过、一定要来、问爆了、超全攻略、必打卡、强推、懒人攻略、必游榜、小众打卡、狂喜等等。
你需要为每个请求至少生成{poster_num}个海报设计。请使用JSON格式输出结果结构如下
[
{
"index": 1,
"main_title": "主标题内容",
"texts": ["第一个小标题", "第二个小标题"]
},
{
"index": 2,
"main_title": "主标题内容",
"texts": ["第一个小标题", "第二个小标题"]
},
// ... 更多海报
]
确保生成的数量与用户要求的数量一致。只生成上述JSON格式内容不要有其他任何额外内容。
"""
if self.add_description:
# 创建用户内容包括info信息和tweet_content
user_content = f"""
以下是需要你处理的信息:
关于景点的描述:
{self.add_description}
推文内容:
{tweet_content}
请根据这些信息,生成{poster_num}个海报文案配置以JSON数组格式返回。
"""
else:
# 仅使用tweet_content
user_content = f"""
以下是需要你处理的推文内容:
{tweet_content}
请根据这些信息,生成{poster_num}个海报文案配置以JSON数组格式返回。
"""
self.logger.info(f"正在生成{poster_num}个海报文案配置")
# 创建临时客户端
temp_client = self._create_temp_client()
if temp_client:
# 重试逻辑
for retry in range(max_retries):
try:
self.logger.info(f"尝试生成内容 (尝试 {retry+1}/{max_retries})")
# 定义流式响应处理回调函数
def handle_stream_chunk(chunk, is_last=False, is_timeout=False, is_error=False, error=None):
nonlocal full_response
if chunk:
full_response += chunk
# 实时输出到控制台
print(chunk, end="", flush=True)
if is_last:
print("\n") # 输出完成后换行
if is_timeout:
print("警告: 响应流超时")
if is_error:
print(f"错误: {error}")
# 使用AI_Agent的新回调方式
from core.ai_agent import AI_Agent
ai_agent = AI_Agent(
self.api_base_url,
self.model_name,
self.api_key,
timeout=timeout,
max_retries=max_retries,
stream_chunk_timeout=30 # 流式块超时时间
)
# 使用回调方式处理流式响应
try:
full_response = ai_agent.generate_text_stream_with_callback(
system_prompt,
user_content,
callback=handle_stream_chunk,
temperature=self.temperature,
top_p=self.top_p,
presence_penalty=self.presence_penalty
)
# 如果成功生成内容,跳出重试循环
ai_agent.close()
break
except Exception as e:
error_msg = str(e)
self.logger.error(f"AI生成错误: {error_msg}")
ai_agent.close()
# 继续重试逻辑
if retry + 1 >= max_retries:
self.logger.warning("已达到最大重试次数,使用备用方案...")
# 生成备用内容
full_response = self._generate_fallback_content(poster_num)
else:
self.logger.info(f"将在稍后重试,还剩 {max_retries - retry - 1} 次重试机会")
except Exception as e:
error_msg = str(e)
self.logger.error(f"API连接错误 (尝试 {retry+1}/{max_retries}): {error_msg}")
# 如果已经达到最大重试次数
if retry + 1 >= max_retries:
self.logger.warning("已达到最大重试次数,使用备用方案...")
# 生成备用内容(简单模板)
full_response = self._generate_fallback_content(poster_num)
else:
self.logger.info(f"将在稍后重试,还剩 {max_retries - retry - 1} 次重试机会")
# 关闭临时客户端
self._close_client(temp_client)
return full_response
def _generate_fallback_content(self, poster_num):
"""生成备用内容当API调用失败时使用"""
self.logger.info("生成备用内容")
default_configs = []
for i in range(poster_num):
default_configs.append({
"main_title": f"景点风光 {i+1}",
"texts": ["自然美景", "人文体验"]
})
return json.dumps(default_configs, ensure_ascii=False)
def save_result(self, full_response):
"""
保存生成结果到文件
参数:
full_response: 生成的完整响应内容
返回:
结果文件路径
"""
# 生成时间戳
print(full_response)
date_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
try:
# 解析内容为JSON格式
parsed_data = self.split_content(full_response)
# 验证内容格式并修复
validated_data = self._validate_and_fix_data(parsed_data)
# 创建结果文件路径
result_path = os.path.join(self.output_dir, f"{date_time}.json")
os.makedirs(os.path.dirname(result_path), exist_ok=True)
# 保存结果到文件
with open(result_path, "w", encoding="utf-8") as f:
json.dump(validated_data, f, ensure_ascii=False, indent=4)
print(f"结果已保存到: {result_path}")
return result_path
except Exception as e:
self.logger.error(f"保存结果到文件时出错: {e}")
# 尝试创建一个简单的备用配置
fallback_data = [{"main_title": "景点风光", "texts": ["自然美景", "人文体验"], "index": 1}]
# 保存备用数据
result_path = os.path.join(self.output_dir, f"{date_time}_fallback.json")
os.makedirs(os.path.dirname(result_path), exist_ok=True)
with open(result_path, "w", encoding="utf-8") as f:
json.dump(fallback_data, f, ensure_ascii=False, indent=4)
print(f"出错后已保存备用数据到: {result_path}")
return result_path
def _validate_and_fix_data(self, data):
"""
验证并修复数据格式,确保符合预期结构
参数:
data: 需要验证的数据
返回:
修复后的数据
"""
fixed_data = []
# 如果数据是列表
if isinstance(data, list):
for i, item in enumerate(data):
# 检查项目是否为字典
if isinstance(item, dict):
# 确保必需字段存在
fixed_item = {
"index": item.get("index", i + 1),
"main_title": item.get("main_title", f"景点风光 {i+1}"),
"texts": item.get("texts", ["自然美景", "人文体验"])
}
# 确保texts是列表格式
if not isinstance(fixed_item["texts"], list):
if isinstance(fixed_item["texts"], str):
fixed_item["texts"] = [fixed_item["texts"], "美景体验"]
else:
fixed_item["texts"] = ["自然美景", "人文体验"]
# 限制texts最多包含两个元素
if len(fixed_item["texts"]) > 2:
fixed_item["texts"] = fixed_item["texts"][:2]
elif len(fixed_item["texts"]) < 2:
while len(fixed_item["texts"]) < 2:
fixed_item["texts"].append("美景体验")
fixed_data.append(fixed_item)
# 如果项目是字符串可能是错误格式的texts值
elif isinstance(item, str):
self.logger.warning(f"配置项 {i+1} 是字符串格式,将转换为标准格式")
fixed_item = {
"index": i + 1,
"main_title": f"景点风光 {i+1}",
"texts": [item, "美景体验"]
}
fixed_data.append(fixed_item)
else:
self.logger.warning(f"配置项 {i+1} 格式不支持: {type(item)},将使用默认值")
fixed_data.append({
"index": i + 1,
"main_title": f"景点风光 {i+1}",
"texts": ["自然美景", "人文体验"]
})
# 如果数据是字典
elif isinstance(data, dict):
fixed_item = {
"index": data.get("index", 1),
"main_title": data.get("main_title", "景点风光"),
"texts": data.get("texts", ["自然美景", "人文体验"])
}
# 确保texts是列表格式
if not isinstance(fixed_item["texts"], list):
if isinstance(fixed_item["texts"], str):
fixed_item["texts"] = [fixed_item["texts"], "美景体验"]
else:
fixed_item["texts"] = ["自然美景", "人文体验"]
# 限制texts最多包含两个元素
if len(fixed_item["texts"]) > 2:
fixed_item["texts"] = fixed_item["texts"][:2]
elif len(fixed_item["texts"]) < 2:
while len(fixed_item["texts"]) < 2:
fixed_item["texts"].append("美景体验")
fixed_data.append(fixed_item)
# 如果数据是字符串或其他格式
else:
self.logger.warning(f"数据格式不支持: {type(data)},将使用默认值")
fixed_data.append({
"index": 1,
"main_title": "景点风光",
"texts": ["自然美景", "人文体验"]
})
# 确保至少有一个配置项
if not fixed_data:
fixed_data.append({
"index": 1,
"main_title": "景点风光",
"texts": ["自然美景", "人文体验"]
})
return fixed_data
def run(self, info_directory, poster_num, tweet_content):
"""
运行海报内容生成流程,并返回生成的配置数据。
参数:
info_directory: 信息目录路径列表 (e.g., ['/path/to/description.txt'])
poster_num: 需要生成的海报配置数量
tweet_content: 用于生成内容的推文/文章内容
返回:
list | dict | None: 生成的海报配置数据 (通常是列表),如果生成或解析失败则返回 None。
"""
self.load_infomation(info_directory)
# Generate the raw string response from AI
full_response = self.generate_posters(poster_num, tweet_content)
# Check if generation failed (indicated by return code 404 or other markers)
if full_response == 404 or not isinstance(full_response, str) or not full_response.strip():
logging.error("Poster content generation failed or returned empty response.")
return None
# Extract the JSON data from the raw response string
try:
result_data = self.split_content(full_response) # This should return the list/dict
# 验证并修复结果数据格式
fixed_data = []
# 如果结果是列表,检查每个项目
if isinstance(result_data, list):
for i, item in enumerate(result_data):
# 如果项目是字典并且有required_fields按原样添加或修复
if isinstance(item, dict):
# 检查并确保必需字段存在
if 'main_title' not in item:
item['main_title'] = f"景点标题 {i+1}"
logging.warning(f"配置项 {i+1} 缺少 main_title 字段,已添加默认值")
if 'texts' not in item:
item['texts'] = ["景点特色", "游玩体验"]
logging.warning(f"配置项 {i+1} 缺少 texts 字段,已添加默认值")
if 'index' not in item:
item['index'] = i + 1
logging.warning(f"配置项 {i+1} 缺少 index 字段,已添加默认值")
fixed_data.append(item)
# 如果项目是字符串可能是错误格式的texts值
elif isinstance(item, str):
logging.warning(f"配置项 {i+1} 是字符串格式,将转换为标准格式")
fixed_item = {
"index": i + 1,
"main_title": f"景点风光 {i+1}",
"texts": [item, "美景体验"]
}
fixed_data.append(fixed_item)
else:
logging.warning(f"配置项 {i+1} 格式不支持: {type(item)},将使用默认值")
fixed_data.append({
"index": i + 1,
"main_title": f"景点风光 {i+1}",
"texts": ["自然美景", "人文体验"]
})
# 如果处理后的列表为空(极端情况),则使用默认值
if not fixed_data:
logging.warning("处理后的配置列表为空,使用默认值")
for i in range(poster_num):
fixed_data.append({
"index": i + 1,
"main_title": f"景点风光 {i+1}",
"texts": ["自然美景", "人文体验"]
})
logging.info(f"成功生成并修复海报配置数据,包含 {len(fixed_data)} 个项目")
return fixed_data
# 如果结果是单个字典(不常见但可能),将其转换为列表
elif isinstance(result_data, dict):
logging.warning(f"生成的配置数据是单个字典格式,将转换为列表")
# 检查并确保必需字段存在
if 'main_title' not in result_data:
result_data['main_title'] = "景点风光"
if 'texts' not in result_data:
result_data['texts'] = ["自然美景", "人文体验"]
if 'index' not in result_data:
result_data['index'] = 1
fixed_data = [result_data]
return fixed_data
# 如果结果是其他格式(如字符串),创建默认配置
else:
logging.warning(f"生成的配置数据格式不支持: {type(result_data)},将使用默认值")
for i in range(poster_num):
fixed_data.append({
"index": i + 1,
"main_title": f"景点风光 {i+1}",
"texts": ["自然美景", "人文体验"]
})
return fixed_data
except Exception as e:
logging.exception(f"Failed to parse JSON from AI response in ContentGenerator: {e}\nRaw Response:\n{full_response[:500]}...") # Log error and partial response
# 失败后创建一个默认配置
logging.info("创建默认海报配置数据")
default_configs = []
for i in range(poster_num):
default_configs.append({
"index": i + 1,
"main_title": f"景点风光 {i+1}",
"texts": ["自然美景", "人文体验"]
})
return default_configs
def set_temperature(self, temperature):
self.temperature = temperature
def set_top_p(self, top_p):
self.top_p = top_p
def set_presence_penalty(self, presence_penalty):
self.presence_penalty = presence_penalty
def set_model_para(self, temperature, top_p, presence_penalty):
self.temperature = temperature
self.top_p = top_p
self.presence_penalty = presence_penalty
# def main():
# # 配置参数
# info_directory = [
# "/root/autodl-tmp/sanming_img/相机/甘露寺/description.txt"
# ] # 信息目录
# poster_num = 4 # 海报数量
# # 推文内容
# tweet_content = """<title>
# 🌿清明遛娃天花板!悬空古寺+非遗探秘
# </title>
# <content>
# 清明假期带娃哪里玩?泰宁甘露寺藏着明代建筑奇迹!一柱擎天的悬空阁楼+状元祈福传说,让孩子边玩边涨知识✨
# 🎒行程亮点:
# ✅ 安全科普第一站:讲解"一柱插地"千年不倒的秘密,用乐高积木模型让孩子理解力学原理
# ✅ 文化沉浸体验:穿汉服听"叶状元还愿建寺"故事触摸3.38米粗的"状元柱"许愿
# ✅ 自然探索路线:连接金湖栈道徒步,观察丹霞地貌与古建筑的巧妙融合
# 📌实用攻略:
# 📍位置:福建省三明市泰宁县金湖西路(导航搜"甘露岩寺"
# 🕒最佳时段上午10点前抵达避开人流下午可衔接参观明清园80元/人)
# ⚠️注意事项:悬空栈道设置儿童安全绳租赁点,建议穿防滑鞋
# 💡亲子彩蛋:
# 1⃣ 在"右鼓左钟"景观区玩声音实验,敲击不同岩石听回声差异
# 2⃣ 领取任务卡完成"寻找建筑中的T形拱"小游戏,集章兑换非遗木雕书签
# 3⃣ 结合清明节俗,用竹简模板书写祈福语系在古松枝头
# 周边推荐:游览完可直奔尚书第明代古民居群,对比不同时期建筑特色,晚餐推荐尝泰宁特色"灯盏糕",亲子套票更划算!
# 清明带着孩子来场穿越850年的建筑探险把课本里的力学知识变成触手可及的历史课堂🌸
# #清明节周边游 #亲子科普游 #福建遛娃 #泰宁旅行攻略 #建筑启蒙
# </content>
# """
# # 创建海报生成器
# generator = ContentGenerator()
# # 运行生成流程
# generator.run(info_directory, poster_num, tweet_content)
# if __name__ == "__main__":
# main()