import os from openai import OpenAI import pandas as pd from datetime import datetime import cv2 import time import random import json import logging class ContentGenerator: def __init__(self, model_name="qwenQWQ", api_base_url="http://localhost:8000/v1", api_key="EMPTY", output_dir="/root/autodl-tmp/poster_generate_result", ): """ 初始化海报生成器 参数: csv_path: CSV文件路径 img_base_dir: 图片基础目录 output_dir: 输出结果保存目录 model_name: 使用的模型名称 api_base_url: API基础URL api_key: API密钥 """ self.output_dir = output_dir self.model_name = model_name self.api_base_url = api_base_url self.api_key = api_key # 不在初始化时创建OpenAI客户端,而是在需要时临时创建 self.client = None # 初始化数据 self.df = None self.all_images_info = [] self.structured_prompt = "" self.current_img_info = None self.add_description = "" self.temperature = 0.7 self.top_p = 0.8 self.presence_penalty = 1.2 # 设置日志 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') self.logger = logging.getLogger(__name__) def load_infomation(self, info_directory_path): """ 加载额外描述文件 参数: info_directory_path: 信息目录路径 """ ## 读取路径下的所有文件 for path in info_directory_path: # file_extend_path = os.path.join(self.img_base_dir, info_directory, "description.txt") try: with open(path, "r") as f: self.add_description += f.read() except: self.add_description = "" def _create_temp_client(self): """ 创建临时OpenAI客户端 返回: OpenAI客户端实例 """ try: import gc # 强制垃圾回收 gc.collect() # 创建新的客户端实例 print(f"创建临时OpenAI客户端,API URL: {self.api_base_url}") client = OpenAI( base_url=self.api_base_url, api_key=self.api_key ) return client except Exception as e: print(f"创建OpenAI客户端失败: {str(e)}") return None def _close_client(self, client): """ 关闭并清理OpenAI客户端 参数: client: 需要关闭的客户端实例 """ try: # OpenAI客户端可能没有显式的close方法 # 将引用设为None,让Python垃圾回收处理 client = None import gc gc.collect() print("OpenAI客户端资源已释放") except Exception as e: print(f"关闭客户端失败: {str(e)}") def split_content(self, content): """ 分割结果, 返回去除 ```json ```的json内容 参数: content: 需要分割的内容 返回: 分割后的json内容 """ try: # 首先尝试直接解析整个内容,以防已经是干净的 JSON try: return json.loads(content) except json.JSONDecodeError: pass # 不是干净的 JSON,继续处理 # 常规模式:查找 ```json 和 ``` 之间的内容 if "```json" in content: json_str = content.split("```json")[1].split("```")[0].strip() try: return json.loads(json_str) except json.JSONDecodeError as e: print(f"常规格式解析失败: {e}, 尝试其他方法") # 备用模式1:查找连续的 { 开头和 } 结尾的部分 import re json_pattern = r'(\[.*?\])' json_matches = re.findall(json_pattern, content, re.DOTALL) if json_matches: for match in json_matches: try: result = json.loads(match) if isinstance(result, list) and len(result) > 0: return result except: continue # 备用模式2:查找 { 开头 和 } 结尾,并尝试解析 content = content.strip() square_bracket_start = content.find('[') square_bracket_end = content.rfind(']') if square_bracket_start != -1 and square_bracket_end != -1: potential_json = content[square_bracket_start:square_bracket_end + 1] try: return json.loads(potential_json) except: print("尝试提取方括号内容失败") # 最后一种尝试:查找所有可能的 JSON 结构并尝试解析 json_structures = re.findall(r'({.*?})', content, re.DOTALL) if json_structures: items = [] for i, struct in enumerate(json_structures): try: item = json.loads(struct) # 验证结构包含预期字段 if 'main_title' in item and ('texts' in item or 'index' in item): items.append(item) except: continue if items: return items # 都失败了,打印错误并引发异常 print(f"无法解析内容,返回原始文本: {content[:200]}...") raise ValueError("无法从响应中提取有效的 JSON 格式") except Exception as e: print(f"解析内容时出错: {e}") print(f"原始内容: {content[:200]}...") # 仅显示前200个字符 raise e def generate_posters(self, poster_num, tweet_content, system_prompt=None, max_retries=3): """ 生成海报内容 参数: poster_num: 海报数量 tweet_content: 推文内容 system_prompt: 系统提示,默认为None则使用预设提示 max_retries: 最大重试次数 返回: 生成的海报内容 """ full_response = "" timeout = 60 # 请求超时时间(秒) if not system_prompt: # 使用默认系统提示词 system_prompt = f""" 你是一个专业的文案处理专家,擅长从文章中提取关键信息并生成吸引人的标题和简短描述。 现在,我需要你根据提供的文章内容,生成{poster_num}个海报的文案配置。 每个配置包含: 1. main_title:主标题,简短有力,突出景点特点 2. texts:两句简短文本,每句不超过15字,描述景点特色或游玩体验 以JSON数组格式返回配置,示例: [ {{ "main_title": "泰宁古城", "texts": ["千年古韵","匠心独运"] }}, ... ] 仅返回JSON数据,不需要任何额外解释。确保生成的标题和文本能够准确反映文章提到的景点特色。 """ if self.add_description: # 创建用户内容,包括info信息和tweet_content user_content = f""" 以下是需要你处理的信息: 关于景点的描述: {self.add_description} 推文内容: {tweet_content} 请根据这些信息,生成{poster_num}个海报文案配置,以JSON数组格式返回。 确保主标题(main_title)简短有力,每个text不超过15字,并能准确反映景点特色。 """ else: # 仅使用tweet_content user_content = f""" 以下是需要你处理的推文内容: {tweet_content} 请根据这些信息,生成{poster_num}个海报文案配置,以JSON数组格式返回。 确保主标题(main_title)简短有力,每个text不超过15字,并能准确反映景点特色。 """ self.logger.info(f"正在生成{poster_num}个海报文案配置") # 创建临时客户端 temp_client = self._create_temp_client() if temp_client: # 重试逻辑 for retry in range(max_retries): try: self.logger.info(f"尝试生成内容 (尝试 {retry+1}/{max_retries})") # 定义流式响应处理回调函数 def handle_stream_chunk(chunk, is_last=False, is_timeout=False, is_error=False, error=None): nonlocal full_response if chunk: full_response += chunk # 实时输出到控制台 print(chunk, end="", flush=True) if is_last: print("\n") # 输出完成后换行 if is_timeout: print("警告: 响应流超时") if is_error: print(f"错误: {error}") # 使用AI_Agent的新回调方式 from core.ai_agent import AI_Agent ai_agent = AI_Agent( self.api_base_url, self.model_name, self.api_key, timeout=timeout, max_retries=max_retries, stream_chunk_timeout=30 # 流式块超时时间 ) # 使用回调方式处理流式响应 try: full_response = ai_agent.generate_text_stream_with_callback( system_prompt, user_content, handle_stream_chunk, temperature=self.temperature, top_p=self.top_p, presence_penalty=self.presence_penalty ) # 如果成功生成内容,跳出重试循环 ai_agent.close() break except Exception as e: error_msg = str(e) self.logger.error(f"AI生成错误: {error_msg}") ai_agent.close() # 继续重试逻辑 if retry + 1 >= max_retries: self.logger.warning("已达到最大重试次数,使用备用方案...") # 生成备用内容 full_response = self._generate_fallback_content(poster_num) else: self.logger.info(f"将在稍后重试,还剩 {max_retries - retry - 1} 次重试机会") except Exception as e: error_msg = str(e) self.logger.error(f"API连接错误 (尝试 {retry+1}/{max_retries}): {error_msg}") # 如果已经达到最大重试次数 if retry + 1 >= max_retries: self.logger.warning("已达到最大重试次数,使用备用方案...") # 生成备用内容(简单模板) full_response = self._generate_fallback_content(poster_num) else: self.logger.info(f"将在稍后重试,还剩 {max_retries - retry - 1} 次重试机会") # 关闭临时客户端 self._close_client(temp_client) return full_response def _generate_fallback_content(self, poster_num): """生成备用内容,当API调用失败时使用""" self.logger.info("生成备用内容") default_configs = [] for i in range(poster_num): default_configs.append({ "main_title": f"景点风光 {i+1}", "texts": ["自然美景", "人文体验"] }) return json.dumps(default_configs, ensure_ascii=False) def save_result(self, full_response): """ 保存生成结果到文件 参数: full_response: 生成的完整响应内容 返回: 结果文件路径 """ # 生成时间戳 print(full_response) date_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") try: # 解析内容为JSON格式 parsed_data = self.split_content(full_response) # 验证内容格式并修复 validated_data = self._validate_and_fix_data(parsed_data) # 创建结果文件路径 result_path = os.path.join(self.output_dir, f"{date_time}.json") os.makedirs(os.path.dirname(result_path), exist_ok=True) # 保存结果到文件 with open(result_path, "w", encoding="utf-8") as f: json.dump(validated_data, f, ensure_ascii=False, indent=4) print(f"结果已保存到: {result_path}") return result_path except Exception as e: self.logger.error(f"保存结果到文件时出错: {e}") # 尝试创建一个简单的备用配置 fallback_data = [{"main_title": "景点风光", "texts": ["自然美景", "人文体验"], "index": 1}] # 保存备用数据 result_path = os.path.join(self.output_dir, f"{date_time}_fallback.json") os.makedirs(os.path.dirname(result_path), exist_ok=True) with open(result_path, "w", encoding="utf-8") as f: json.dump(fallback_data, f, ensure_ascii=False, indent=4) print(f"出错后已保存备用数据到: {result_path}") return result_path def _validate_and_fix_data(self, data): """ 验证并修复数据格式,确保符合预期结构 参数: data: 需要验证的数据 返回: 修复后的数据 """ fixed_data = [] # 如果数据是列表 if isinstance(data, list): for i, item in enumerate(data): # 检查项目是否为字典 if isinstance(item, dict): # 确保必需字段存在 fixed_item = { "index": item.get("index", i + 1), "main_title": item.get("main_title", f"景点风光 {i+1}"), "texts": item.get("texts", ["自然美景", "人文体验"]) } # 确保texts是列表格式 if not isinstance(fixed_item["texts"], list): if isinstance(fixed_item["texts"], str): fixed_item["texts"] = [fixed_item["texts"], "美景体验"] else: fixed_item["texts"] = ["自然美景", "人文体验"] # 限制texts最多包含两个元素 if len(fixed_item["texts"]) > 2: fixed_item["texts"] = fixed_item["texts"][:2] elif len(fixed_item["texts"]) < 2: while len(fixed_item["texts"]) < 2: fixed_item["texts"].append("美景体验") fixed_data.append(fixed_item) # 如果项目是字符串(可能是错误格式的texts值) elif isinstance(item, str): self.logger.warning(f"配置项 {i+1} 是字符串格式,将转换为标准格式") fixed_item = { "index": i + 1, "main_title": f"景点风光 {i+1}", "texts": [item, "美景体验"] } fixed_data.append(fixed_item) else: self.logger.warning(f"配置项 {i+1} 格式不支持: {type(item)},将使用默认值") fixed_data.append({ "index": i + 1, "main_title": f"景点风光 {i+1}", "texts": ["自然美景", "人文体验"] }) # 如果数据是字典 elif isinstance(data, dict): fixed_item = { "index": data.get("index", 1), "main_title": data.get("main_title", "景点风光"), "texts": data.get("texts", ["自然美景", "人文体验"]) } # 确保texts是列表格式 if not isinstance(fixed_item["texts"], list): if isinstance(fixed_item["texts"], str): fixed_item["texts"] = [fixed_item["texts"], "美景体验"] else: fixed_item["texts"] = ["自然美景", "人文体验"] # 限制texts最多包含两个元素 if len(fixed_item["texts"]) > 2: fixed_item["texts"] = fixed_item["texts"][:2] elif len(fixed_item["texts"]) < 2: while len(fixed_item["texts"]) < 2: fixed_item["texts"].append("美景体验") fixed_data.append(fixed_item) # 如果数据是字符串或其他格式 else: self.logger.warning(f"数据格式不支持: {type(data)},将使用默认值") fixed_data.append({ "index": 1, "main_title": "景点风光", "texts": ["自然美景", "人文体验"] }) # 确保至少有一个配置项 if not fixed_data: fixed_data.append({ "index": 1, "main_title": "景点风光", "texts": ["自然美景", "人文体验"] }) return fixed_data def run(self, info_directory, poster_num, tweet_content): """ 运行海报内容生成流程,并返回生成的配置数据。 参数: info_directory: 信息目录路径列表 (e.g., ['/path/to/description.txt']) poster_num: 需要生成的海报配置数量 tweet_content: 用于生成内容的推文/文章内容 返回: list | dict | None: 生成的海报配置数据 (通常是列表),如果生成或解析失败则返回 None。 """ self.load_infomation(info_directory) # Generate the raw string response from AI full_response = self.generate_posters(poster_num, tweet_content) # Check if generation failed (indicated by return code 404 or other markers) if full_response == 404 or not isinstance(full_response, str) or not full_response.strip(): logging.error("Poster content generation failed or returned empty response.") return None # Extract the JSON data from the raw response string try: result_data = self.split_content(full_response) # This should return the list/dict # 验证并修复结果数据格式 fixed_data = [] # 如果结果是列表,检查每个项目 if isinstance(result_data, list): for i, item in enumerate(result_data): # 如果项目是字典并且有required_fields,按原样添加或修复 if isinstance(item, dict): # 检查并确保必需字段存在 if 'main_title' not in item: item['main_title'] = f"景点标题 {i+1}" logging.warning(f"配置项 {i+1} 缺少 main_title 字段,已添加默认值") if 'texts' not in item: item['texts'] = ["景点特色", "游玩体验"] logging.warning(f"配置项 {i+1} 缺少 texts 字段,已添加默认值") if 'index' not in item: item['index'] = i + 1 logging.warning(f"配置项 {i+1} 缺少 index 字段,已添加默认值") fixed_data.append(item) # 如果项目是字符串(可能是错误格式的texts值) elif isinstance(item, str): logging.warning(f"配置项 {i+1} 是字符串格式,将转换为标准格式") fixed_item = { "index": i + 1, "main_title": f"景点风光 {i+1}", "texts": [item, "美景体验"] } fixed_data.append(fixed_item) else: logging.warning(f"配置项 {i+1} 格式不支持: {type(item)},将使用默认值") fixed_data.append({ "index": i + 1, "main_title": f"景点风光 {i+1}", "texts": ["自然美景", "人文体验"] }) # 如果处理后的列表为空(极端情况),则使用默认值 if not fixed_data: logging.warning("处理后的配置列表为空,使用默认值") for i in range(poster_num): fixed_data.append({ "index": i + 1, "main_title": f"景点风光 {i+1}", "texts": ["自然美景", "人文体验"] }) logging.info(f"成功生成并修复海报配置数据,包含 {len(fixed_data)} 个项目") return fixed_data # 如果结果是单个字典(不常见但可能),将其转换为列表 elif isinstance(result_data, dict): logging.warning(f"生成的配置数据是单个字典格式,将转换为列表") # 检查并确保必需字段存在 if 'main_title' not in result_data: result_data['main_title'] = "景点风光" if 'texts' not in result_data: result_data['texts'] = ["自然美景", "人文体验"] if 'index' not in result_data: result_data['index'] = 1 fixed_data = [result_data] return fixed_data # 如果结果是其他格式(如字符串),创建默认配置 else: logging.warning(f"生成的配置数据格式不支持: {type(result_data)},将使用默认值") for i in range(poster_num): fixed_data.append({ "index": i + 1, "main_title": f"景点风光 {i+1}", "texts": ["自然美景", "人文体验"] }) return fixed_data except Exception as e: logging.exception(f"Failed to parse JSON from AI response in ContentGenerator: {e}\nRaw Response:\n{full_response[:500]}...") # Log error and partial response # 失败后创建一个默认配置 logging.info("创建默认海报配置数据") default_configs = [] for i in range(poster_num): default_configs.append({ "index": i + 1, "main_title": f"景点风光 {i+1}", "texts": ["自然美景", "人文体验"] }) return default_configs def set_temperature(self, temperature): self.temperature = temperature def set_top_p(self, top_p): self.top_p = top_p def set_presence_penalty(self, presence_penalty): self.presence_penalty = presence_penalty def set_model_para(self, temperature, top_p, presence_penalty): self.temperature = temperature self.top_p = top_p self.presence_penalty = presence_penalty # def main(): # # 配置参数 # info_directory = [ # "/root/autodl-tmp/sanming_img/相机/甘露寺/description.txt" # ] # 信息目录 # poster_num = 4 # 海报数量 # # 推文内容 # tweet_content = """ # 🌿清明遛娃天花板!悬空古寺+非遗探秘 # # # 清明假期带娃哪里玩?泰宁甘露寺藏着明代建筑奇迹!一柱擎天的悬空阁楼+状元祈福传说,让孩子边玩边涨知识✨ # 🎒行程亮点: # ✅ 安全科普第一站:讲解"一柱插地"千年不倒的秘密,用乐高积木模型让孩子理解力学原理 # ✅ 文化沉浸体验:穿汉服听"叶状元还愿建寺"故事,触摸3.38米粗的"状元柱"许愿 # ✅ 自然探索路线:连接金湖栈道徒步,观察丹霞地貌与古建筑的巧妙融合 # 📌实用攻略: # 📍位置:福建省三明市泰宁县金湖西路(导航搜"甘露岩寺") # 🕒最佳时段:上午10点前抵达避开人流,下午可衔接参观明清园(80元/人) # ⚠️注意事项:悬空栈道设置儿童安全绳租赁点,建议穿防滑鞋 # 💡亲子彩蛋: # 1️⃣ 在"右鼓左钟"景观区玩声音实验,敲击不同岩石听回声差异 # 2️⃣ 领取任务卡完成"寻找建筑中的T形拱"小游戏,集章兑换非遗木雕书签 # 3️⃣ 结合清明节俗,用竹简模板书写祈福语系在古松枝头 # 周边推荐:游览完可直奔尚书第明代古民居群,对比不同时期建筑特色,晚餐推荐尝泰宁特色"灯盏糕",亲子套票更划算! # 清明带着孩子来场穿越850年的建筑探险,把课本里的力学知识变成触手可及的历史课堂!🌸 # #清明节周边游 #亲子科普游 #福建遛娃 #泰宁旅行攻略 #建筑启蒙 # # """ # # 创建海报生成器 # generator = ContentGenerator() # # 运行生成流程 # generator.run(info_directory, poster_num, tweet_content) # if __name__ == "__main__": # main()