TravelContentCreator/core/contentGen.py

671 lines
27 KiB
Python
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
from openai import OpenAI
import pandas as pd
from datetime import datetime
import cv2
import time
import random
import json
import logging
class ContentGenerator:
def __init__(self,
model_name="qwenQWQ",
api_base_url="http://localhost:8000/v1",
api_key="EMPTY",
output_dir="/root/autodl-tmp/poster_generate_result",
):
"""
初始化海报生成器
参数:
csv_path: CSV文件路径
img_base_dir: 图片基础目录
output_dir: 输出结果保存目录
model_name: 使用的模型名称
api_base_url: API基础URL
api_key: API密钥
"""
self.output_dir = output_dir
self.model_name = model_name
self.api_base_url = api_base_url
self.api_key = api_key
# 不在初始化时创建OpenAI客户端而是在需要时临时创建
self.client = None
# 初始化数据
self.df = None
self.all_images_info = []
self.structured_prompt = ""
self.current_img_info = None
self.add_description = ""
self.temperature = 0.7
self.top_p = 0.8
self.presence_penalty = 1.2
# 设置日志
logging.basicConfig(level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
self.logger = logging.getLogger(__name__)
def load_infomation(self, info_directory_path):
"""
加载额外描述文件
参数:
info_directory_path: 信息目录路径
"""
## 读取路径下的所有文件
for path in info_directory_path:
# file_extend_path = os.path.join(self.img_base_dir, info_directory, "description.txt")
try:
with open(path, "r") as f:
self.add_description += f.read()
except:
self.add_description = ""
def _create_temp_client(self):
"""
创建临时OpenAI客户端
返回:
OpenAI客户端实例
"""
try:
import gc
# 强制垃圾回收
gc.collect()
# 创建新的客户端实例
print(f"创建临时OpenAI客户端API URL: {self.api_base_url}")
client = OpenAI(
base_url=self.api_base_url,
api_key=self.api_key
)
return client
except Exception as e:
print(f"创建OpenAI客户端失败: {str(e)}")
return None
def _close_client(self, client):
"""
关闭并清理OpenAI客户端
参数:
client: 需要关闭的客户端实例
"""
try:
# OpenAI客户端可能没有显式的close方法
# 将引用设为None让Python垃圾回收处理
client = None
import gc
gc.collect()
print("OpenAI客户端资源已释放")
except Exception as e:
print(f"关闭客户端失败: {str(e)}")
def split_content(self, content):
"""
分割结果, 返回去除
```json
```的json内容
参数:
content: 需要分割的内容
返回:
分割后的json内容
"""
try:
# 首先尝试直接解析整个内容,以防已经是干净的 JSON
try:
return json.loads(content)
except json.JSONDecodeError:
pass # 不是干净的 JSON继续处理
# 常规模式:查找 ```json 和 ``` 之间的内容
if "```json" in content:
json_str = content.split("```json")[1].split("```")[0].strip()
try:
return json.loads(json_str)
except json.JSONDecodeError as e:
print(f"常规格式解析失败: {e}, 尝试其他方法")
# 备用模式1查找连续的 { 开头和 } 结尾的部分
import re
json_pattern = r'(\[.*?\])'
json_matches = re.findall(json_pattern, content, re.DOTALL)
if json_matches:
for match in json_matches:
try:
result = json.loads(match)
if isinstance(result, list) and len(result) > 0:
return result
except:
continue
# 备用模式2查找 { 开头 和 } 结尾,并尝试解析
content = content.strip()
square_bracket_start = content.find('[')
square_bracket_end = content.rfind(']')
if square_bracket_start != -1 and square_bracket_end != -1:
potential_json = content[square_bracket_start:square_bracket_end + 1]
try:
return json.loads(potential_json)
except:
print("尝试提取方括号内容失败")
# 最后一种尝试:查找所有可能的 JSON 结构并尝试解析
json_structures = re.findall(r'({.*?})', content, re.DOTALL)
if json_structures:
items = []
for i, struct in enumerate(json_structures):
try:
item = json.loads(struct)
# 验证结构包含预期字段
if 'main_title' in item and ('texts' in item or 'index' in item):
items.append(item)
except:
continue
if items:
return items
# 都失败了,打印错误并引发异常
print(f"无法解析内容,返回原始文本: {content[:200]}...")
raise ValueError("无法从响应中提取有效的 JSON 格式")
except Exception as e:
print(f"解析内容时出错: {e}")
print(f"原始内容: {content[:200]}...") # 仅显示前200个字符
raise e
def generate_posters(self, poster_num, tweet_content, system_prompt=None, max_retries=3):
"""
生成海报内容
参数:
poster_num: 海报数量
tweet_content: 推文内容
system_prompt: 系统提示默认为None则使用预设提示
max_retries: 最大重试次数
返回:
生成的海报内容
"""
full_response = ""
timeout = 60 # 请求超时时间(秒)
if not system_prompt:
# 使用默认系统提示词
system_prompt = f"""
你是一个专业的文案处理专家,擅长从文章中提取关键信息并生成吸引人的标题和简短描述。
现在,我需要你根据提供的文章内容,生成{poster_num}个海报的文案配置。
每个配置包含:
1. main_title主标题简短有力突出景点特点
2. texts两句简短文本每句不超过15字描述景点特色或游玩体验
以JSON数组格式返回配置示例
[
{{
"main_title": "泰宁古城",
"texts": ["千年古韵","匠心独运"]
}},
...
]
仅返回JSON数据不需要任何额外解释。确保生成的标题和文本能够准确反映文章提到的景点特色。
"""
if self.add_description:
# 创建用户内容包括info信息和tweet_content
user_content = f"""
以下是需要你处理的信息:
关于景点的描述:
{self.add_description}
推文内容:
{tweet_content}
请根据这些信息,生成{poster_num}个海报文案配置以JSON数组格式返回。
确保主标题(main_title)简短有力每个text不超过15字并能准确反映景点特色。
"""
else:
# 仅使用tweet_content
user_content = f"""
以下是需要你处理的推文内容:
{tweet_content}
请根据这些信息,生成{poster_num}个海报文案配置以JSON数组格式返回。
确保主标题(main_title)简短有力每个text不超过15字并能准确反映景点特色。
"""
self.logger.info(f"正在生成{poster_num}个海报文案配置")
# 创建临时客户端
temp_client = self._create_temp_client()
if temp_client:
# 重试逻辑
for retry in range(max_retries):
try:
self.logger.info(f"尝试生成内容 (尝试 {retry+1}/{max_retries})")
# 定义流式响应处理回调函数
def handle_stream_chunk(chunk, is_last=False, is_timeout=False, is_error=False, error=None):
nonlocal full_response
if chunk:
full_response += chunk
# 实时输出到控制台
print(chunk, end="", flush=True)
if is_last:
print("\n") # 输出完成后换行
if is_timeout:
print("警告: 响应流超时")
if is_error:
print(f"错误: {error}")
# 使用AI_Agent的新回调方式
from core.ai_agent import AI_Agent
ai_agent = AI_Agent(
self.api_base_url,
self.model_name,
self.api_key,
timeout=timeout,
max_retries=max_retries,
stream_chunk_timeout=30 # 流式块超时时间
)
# 使用回调方式处理流式响应
try:
full_response = ai_agent.generate_text_stream_with_callback(
system_prompt,
user_content,
handle_stream_chunk,
temperature=self.temperature,
top_p=self.top_p,
presence_penalty=self.presence_penalty
)
# 如果成功生成内容,跳出重试循环
ai_agent.close()
break
except Exception as e:
error_msg = str(e)
self.logger.error(f"AI生成错误: {error_msg}")
ai_agent.close()
# 继续重试逻辑
if retry + 1 >= max_retries:
self.logger.warning("已达到最大重试次数,使用备用方案...")
# 生成备用内容
full_response = self._generate_fallback_content(poster_num)
else:
self.logger.info(f"将在稍后重试,还剩 {max_retries - retry - 1} 次重试机会")
except Exception as e:
error_msg = str(e)
self.logger.error(f"API连接错误 (尝试 {retry+1}/{max_retries}): {error_msg}")
# 如果已经达到最大重试次数
if retry + 1 >= max_retries:
self.logger.warning("已达到最大重试次数,使用备用方案...")
# 生成备用内容(简单模板)
full_response = self._generate_fallback_content(poster_num)
else:
self.logger.info(f"将在稍后重试,还剩 {max_retries - retry - 1} 次重试机会")
# 关闭临时客户端
self._close_client(temp_client)
return full_response
def _generate_fallback_content(self, poster_num):
"""生成备用内容当API调用失败时使用"""
self.logger.info("生成备用内容")
default_configs = []
for i in range(poster_num):
default_configs.append({
"main_title": f"景点风光 {i+1}",
"texts": ["自然美景", "人文体验"]
})
return json.dumps(default_configs, ensure_ascii=False)
def save_result(self, full_response):
"""
保存生成结果到文件
参数:
full_response: 生成的完整响应内容
返回:
结果文件路径
"""
# 生成时间戳
print(full_response)
date_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
try:
# 解析内容为JSON格式
parsed_data = self.split_content(full_response)
# 验证内容格式并修复
validated_data = self._validate_and_fix_data(parsed_data)
# 创建结果文件路径
result_path = os.path.join(self.output_dir, f"{date_time}.json")
os.makedirs(os.path.dirname(result_path), exist_ok=True)
# 保存结果到文件
with open(result_path, "w", encoding="utf-8") as f:
json.dump(validated_data, f, ensure_ascii=False, indent=4)
print(f"结果已保存到: {result_path}")
return result_path
except Exception as e:
self.logger.error(f"保存结果到文件时出错: {e}")
# 尝试创建一个简单的备用配置
fallback_data = [{"main_title": "景点风光", "texts": ["自然美景", "人文体验"], "index": 1}]
# 保存备用数据
result_path = os.path.join(self.output_dir, f"{date_time}_fallback.json")
os.makedirs(os.path.dirname(result_path), exist_ok=True)
with open(result_path, "w", encoding="utf-8") as f:
json.dump(fallback_data, f, ensure_ascii=False, indent=4)
print(f"出错后已保存备用数据到: {result_path}")
return result_path
def _validate_and_fix_data(self, data):
"""
验证并修复数据格式,确保符合预期结构
参数:
data: 需要验证的数据
返回:
修复后的数据
"""
fixed_data = []
# 如果数据是列表
if isinstance(data, list):
for i, item in enumerate(data):
# 检查项目是否为字典
if isinstance(item, dict):
# 确保必需字段存在
fixed_item = {
"index": item.get("index", i + 1),
"main_title": item.get("main_title", f"景点风光 {i+1}"),
"texts": item.get("texts", ["自然美景", "人文体验"])
}
# 确保texts是列表格式
if not isinstance(fixed_item["texts"], list):
if isinstance(fixed_item["texts"], str):
fixed_item["texts"] = [fixed_item["texts"], "美景体验"]
else:
fixed_item["texts"] = ["自然美景", "人文体验"]
# 限制texts最多包含两个元素
if len(fixed_item["texts"]) > 2:
fixed_item["texts"] = fixed_item["texts"][:2]
elif len(fixed_item["texts"]) < 2:
while len(fixed_item["texts"]) < 2:
fixed_item["texts"].append("美景体验")
fixed_data.append(fixed_item)
# 如果项目是字符串可能是错误格式的texts值
elif isinstance(item, str):
self.logger.warning(f"配置项 {i+1} 是字符串格式,将转换为标准格式")
fixed_item = {
"index": i + 1,
"main_title": f"景点风光 {i+1}",
"texts": [item, "美景体验"]
}
fixed_data.append(fixed_item)
else:
self.logger.warning(f"配置项 {i+1} 格式不支持: {type(item)},将使用默认值")
fixed_data.append({
"index": i + 1,
"main_title": f"景点风光 {i+1}",
"texts": ["自然美景", "人文体验"]
})
# 如果数据是字典
elif isinstance(data, dict):
fixed_item = {
"index": data.get("index", 1),
"main_title": data.get("main_title", "景点风光"),
"texts": data.get("texts", ["自然美景", "人文体验"])
}
# 确保texts是列表格式
if not isinstance(fixed_item["texts"], list):
if isinstance(fixed_item["texts"], str):
fixed_item["texts"] = [fixed_item["texts"], "美景体验"]
else:
fixed_item["texts"] = ["自然美景", "人文体验"]
# 限制texts最多包含两个元素
if len(fixed_item["texts"]) > 2:
fixed_item["texts"] = fixed_item["texts"][:2]
elif len(fixed_item["texts"]) < 2:
while len(fixed_item["texts"]) < 2:
fixed_item["texts"].append("美景体验")
fixed_data.append(fixed_item)
# 如果数据是字符串或其他格式
else:
self.logger.warning(f"数据格式不支持: {type(data)},将使用默认值")
fixed_data.append({
"index": 1,
"main_title": "景点风光",
"texts": ["自然美景", "人文体验"]
})
# 确保至少有一个配置项
if not fixed_data:
fixed_data.append({
"index": 1,
"main_title": "景点风光",
"texts": ["自然美景", "人文体验"]
})
return fixed_data
def run(self, info_directory, poster_num, tweet_content):
"""
运行海报内容生成流程,并返回生成的配置数据。
参数:
info_directory: 信息目录路径列表 (e.g., ['/path/to/description.txt'])
poster_num: 需要生成的海报配置数量
tweet_content: 用于生成内容的推文/文章内容
返回:
list | dict | None: 生成的海报配置数据 (通常是列表),如果生成或解析失败则返回 None。
"""
self.load_infomation(info_directory)
# Generate the raw string response from AI
full_response = self.generate_posters(poster_num, tweet_content)
# Check if generation failed (indicated by return code 404 or other markers)
if full_response == 404 or not isinstance(full_response, str) or not full_response.strip():
logging.error("Poster content generation failed or returned empty response.")
return None
# Extract the JSON data from the raw response string
try:
result_data = self.split_content(full_response) # This should return the list/dict
# 验证并修复结果数据格式
fixed_data = []
# 如果结果是列表,检查每个项目
if isinstance(result_data, list):
for i, item in enumerate(result_data):
# 如果项目是字典并且有required_fields按原样添加或修复
if isinstance(item, dict):
# 检查并确保必需字段存在
if 'main_title' not in item:
item['main_title'] = f"景点标题 {i+1}"
logging.warning(f"配置项 {i+1} 缺少 main_title 字段,已添加默认值")
if 'texts' not in item:
item['texts'] = ["景点特色", "游玩体验"]
logging.warning(f"配置项 {i+1} 缺少 texts 字段,已添加默认值")
if 'index' not in item:
item['index'] = i + 1
logging.warning(f"配置项 {i+1} 缺少 index 字段,已添加默认值")
fixed_data.append(item)
# 如果项目是字符串可能是错误格式的texts值
elif isinstance(item, str):
logging.warning(f"配置项 {i+1} 是字符串格式,将转换为标准格式")
fixed_item = {
"index": i + 1,
"main_title": f"景点风光 {i+1}",
"texts": [item, "美景体验"]
}
fixed_data.append(fixed_item)
else:
logging.warning(f"配置项 {i+1} 格式不支持: {type(item)},将使用默认值")
fixed_data.append({
"index": i + 1,
"main_title": f"景点风光 {i+1}",
"texts": ["自然美景", "人文体验"]
})
# 如果处理后的列表为空(极端情况),则使用默认值
if not fixed_data:
logging.warning("处理后的配置列表为空,使用默认值")
for i in range(poster_num):
fixed_data.append({
"index": i + 1,
"main_title": f"景点风光 {i+1}",
"texts": ["自然美景", "人文体验"]
})
logging.info(f"成功生成并修复海报配置数据,包含 {len(fixed_data)} 个项目")
return fixed_data
# 如果结果是单个字典(不常见但可能),将其转换为列表
elif isinstance(result_data, dict):
logging.warning(f"生成的配置数据是单个字典格式,将转换为列表")
# 检查并确保必需字段存在
if 'main_title' not in result_data:
result_data['main_title'] = "景点风光"
if 'texts' not in result_data:
result_data['texts'] = ["自然美景", "人文体验"]
if 'index' not in result_data:
result_data['index'] = 1
fixed_data = [result_data]
return fixed_data
# 如果结果是其他格式(如字符串),创建默认配置
else:
logging.warning(f"生成的配置数据格式不支持: {type(result_data)},将使用默认值")
for i in range(poster_num):
fixed_data.append({
"index": i + 1,
"main_title": f"景点风光 {i+1}",
"texts": ["自然美景", "人文体验"]
})
return fixed_data
except Exception as e:
logging.exception(f"Failed to parse JSON from AI response in ContentGenerator: {e}\nRaw Response:\n{full_response[:500]}...") # Log error and partial response
# 失败后创建一个默认配置
logging.info("创建默认海报配置数据")
default_configs = []
for i in range(poster_num):
default_configs.append({
"index": i + 1,
"main_title": f"景点风光 {i+1}",
"texts": ["自然美景", "人文体验"]
})
return default_configs
def set_temperature(self, temperature):
self.temperature = temperature
def set_top_p(self, top_p):
self.top_p = top_p
def set_presence_penalty(self, presence_penalty):
self.presence_penalty = presence_penalty
def set_model_para(self, temperature, top_p, presence_penalty):
self.temperature = temperature
self.top_p = top_p
self.presence_penalty = presence_penalty
# def main():
# # 配置参数
# info_directory = [
# "/root/autodl-tmp/sanming_img/相机/甘露寺/description.txt"
# ] # 信息目录
# poster_num = 4 # 海报数量
# # 推文内容
# tweet_content = """<title>
# 🌿清明遛娃天花板!悬空古寺+非遗探秘
# </title>
# <content>
# 清明假期带娃哪里玩?泰宁甘露寺藏着明代建筑奇迹!一柱擎天的悬空阁楼+状元祈福传说,让孩子边玩边涨知识✨
# 🎒行程亮点:
# ✅ 安全科普第一站:讲解"一柱插地"千年不倒的秘密,用乐高积木模型让孩子理解力学原理
# ✅ 文化沉浸体验:穿汉服听"叶状元还愿建寺"故事触摸3.38米粗的"状元柱"许愿
# ✅ 自然探索路线:连接金湖栈道徒步,观察丹霞地貌与古建筑的巧妙融合
# 📌实用攻略:
# 📍位置:福建省三明市泰宁县金湖西路(导航搜"甘露岩寺"
# 🕒最佳时段上午10点前抵达避开人流下午可衔接参观明清园80元/人)
# ⚠️注意事项:悬空栈道设置儿童安全绳租赁点,建议穿防滑鞋
# 💡亲子彩蛋:
# 1⃣ 在"右鼓左钟"景观区玩声音实验,敲击不同岩石听回声差异
# 2⃣ 领取任务卡完成"寻找建筑中的T形拱"小游戏,集章兑换非遗木雕书签
# 3⃣ 结合清明节俗,用竹简模板书写祈福语系在古松枝头
# 周边推荐:游览完可直奔尚书第明代古民居群,对比不同时期建筑特色,晚餐推荐尝泰宁特色"灯盏糕",亲子套票更划算!
# 清明带着孩子来场穿越850年的建筑探险把课本里的力学知识变成触手可及的历史课堂🌸
# #清明节周边游 #亲子科普游 #福建遛娃 #泰宁旅行攻略 #建筑启蒙
# </content>
# """
# # 创建海报生成器
# generator = ContentGenerator()
# # 运行生成流程
# generator.run(info_directory, poster_num, tweet_content)
# if __name__ == "__main__":
# main()