diff --git a/core/__pycache__/contentGen.cpython-312.pyc b/core/__pycache__/contentGen.cpython-312.pyc index c5755bf..aa6f51f 100644 Binary files a/core/__pycache__/contentGen.cpython-312.pyc and b/core/__pycache__/contentGen.cpython-312.pyc differ diff --git a/core/contentGen.py b/core/contentGen.py index 79bb14a..ba99eea 100644 --- a/core/contentGen.py +++ b/core/contentGen.py @@ -1,704 +1,107 @@ import os -from openai import OpenAI -import pandas as pd -from datetime import datetime -import cv2 -import time -import random -import json import logging +from utils.content_generator import ContentGenerator as NewContentGenerator +# 为了向后兼容,我们保留ContentGenerator类,但内部使用新的实现 class ContentGenerator: def __init__(self, model_name="qwenQWQ", api_base_url="http://localhost:8000/v1", api_key="EMPTY", - output_dir="/root/autodl-tmp/poster_generate_result", -): + output_dir="/root/autodl-tmp/poster_generate_result"): """ - 初始化海报生成器 + 初始化海报生成器,内部使用新的实现 参数: - csv_path: CSV文件路径 - img_base_dir: 图片基础目录 output_dir: 输出结果保存目录 model_name: 使用的模型名称 api_base_url: API基础URL api_key: API密钥 """ - self.output_dir = output_dir + # 创建新的实现实例 + self._impl = NewContentGenerator( + output_dir=output_dir, + temperature=0.7, + top_p=0.8, + presence_penalty=1.2 + ) + + # 存储API参数 self.model_name = model_name self.api_base_url = api_base_url self.api_key = api_key - # 不在初始化时创建OpenAI客户端,而是在需要时临时创建 - self.client = None + # 设置日志 + logging.basicConfig(level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') + self.logger = logging.getLogger(__name__) - # 初始化数据 + # 为向后兼容保留的字段 self.df = None self.all_images_info = [] self.structured_prompt = "" self.current_img_info = None self.add_description = "" - - self.temperature = 0.7 - self.top_p = 0.8 - self.presence_penalty = 1.2 - # 设置日志 - logging.basicConfig(level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') - self.logger = logging.getLogger(__name__) - - + # 将方法委托给新实现 def load_infomation(self, info_directory_path): - """ - 加载额外描述文件 - - 参数: - info_directory_path: 信息目录路径 - """ - ## 读取路径下的所有文件 - for path in info_directory_path: - # file_extend_path = os.path.join(self.img_base_dir, info_directory, "description.txt") - try: - with open(path, "r") as f: - self.add_description += f.read() - except: - self.add_description = "" + """加载额外描述文件""" + self._impl.load_infomation(info_directory_path) + self.add_description = self._impl.add_description - def _create_temp_client(self): - """ - 创建临时OpenAI客户端 - - 返回: - OpenAI客户端实例 - """ - try: - import gc - # 强制垃圾回收 - gc.collect() - - # 创建新的客户端实例 - print(f"创建临时OpenAI客户端,API URL: {self.api_base_url}") - client = OpenAI( - base_url=self.api_base_url, - api_key=self.api_key - ) - return client - except Exception as e: - print(f"创建OpenAI客户端失败: {str(e)}") - return None - - def _close_client(self, client): - """ - 关闭并清理OpenAI客户端 - - 参数: - client: 需要关闭的客户端实例 - """ - try: - # OpenAI客户端可能没有显式的close方法 - # 将引用设为None,让Python垃圾回收处理 - client = None - import gc - gc.collect() - print("OpenAI客户端资源已释放") - except Exception as e: - print(f"关闭客户端失败: {str(e)}") - - def split_content(self, content): - """ - 分割结果, 返回去除 - ```json - ```的json内容 - - 参数: - content: 需要分割的内容 - - 返回: - 分割后的json内容 - """ - try: - # 首先尝试直接解析整个内容,以防已经是干净的 JSON - try: - return json.loads(content) - except json.JSONDecodeError: - pass # 不是干净的 JSON,继续处理 - - # 常规模式:查找 ```json 和 ``` 之间的内容 - if "```json" in content: - json_str = content.split("```json")[1].split("```")[0].strip() - try: - return json.loads(json_str) - except json.JSONDecodeError as e: - print(f"常规格式解析失败: {e}, 尝试其他方法") - - # 备用模式1:查找连续的 { 开头和 } 结尾的部分 - import re - json_pattern = r'(\[.*?\])' - json_matches = re.findall(json_pattern, content, re.DOTALL) - if json_matches: - for match in json_matches: - try: - result = json.loads(match) - if isinstance(result, list) and len(result) > 0: - return result - except: - continue - - # 备用模式2:查找 { 开头 和 } 结尾,并尝试解析 - content = content.strip() - square_bracket_start = content.find('[') - square_bracket_end = content.rfind(']') - - if square_bracket_start != -1 and square_bracket_end != -1: - potential_json = content[square_bracket_start:square_bracket_end + 1] - try: - return json.loads(potential_json) - except: - print("尝试提取方括号内容失败") - - # 最后一种尝试:查找所有可能的 JSON 结构并尝试解析 - json_structures = re.findall(r'({.*?})', content, re.DOTALL) - if json_structures: - items = [] - for i, struct in enumerate(json_structures): - try: - item = json.loads(struct) - # 验证结构包含预期字段 - if 'main_title' in item and ('texts' in item or 'index' in item): - items.append(item) - except: - continue - - if items: - return items - - # 都失败了,打印错误并引发异常 - print(f"无法解析内容,返回原始文本: {content[:200]}...") - raise ValueError("无法从响应中提取有效的 JSON 格式") - - except Exception as e: - print(f"解析内容时出错: {e}") - print(f"原始内容: {content[:200]}...") # 仅显示前200个字符 - raise e + """分割JSON内容""" + return self._impl.split_content(content) def generate_posters(self, poster_num, tweet_content, system_prompt=None, max_retries=3): - """ - 生成海报内容 - - 参数: - poster_num: 海报数量 - tweet_content: 推文内容 - system_prompt: 系统提示,默认为None则使用预设提示 - max_retries: 最大重试次数 - - 返回: - 生成的海报内容 - """ - full_response = "" - timeout = 60 # 请求超时时间(秒) - - if not system_prompt: - # 使用默认系统提示词 - system_prompt = """ - 你是一名资深海报设计师,有丰富的爆款海报设计经验,你现在要为旅游景点做宣传,在小红书上发布大量宣传海报。你的主要工作目标有2个: - 1、你要根据我给你的图片描述和笔记推文内容,设计图文匹配的海报。 - 2、为海报设计文案,文案的<第一个小标题>和<第二个小标题>之间你需要检查是否逻辑关系合理,你将通过先去生成<第二个小标题>关于景区亮点的部分,再去综合判断<第一个小标题>应该如何搭配组合更符合两个小标题的逻辑再生成<第一个小标题>。 - - 其中,生成三类标题文案的通用性要求如下: - 1、生成的<大标题>字数必须小于8个字符 - 2、生成的<第一个小标题>字数和<第二个小标题>字数,两者都必须小8个字符 - 3、标题和文案都应符合中国社会主义核心价值观 - - 接下来先开始生成<大标题>部分,由于海报是用来宣传旅游景点,生成的海报<大标题>必须使用以下8种格式之一: - ①地名+景点名(例如福建厦门鼓浪屿/厦门鼓浪屿); - ②地名+景点名+plog; - ③拿捏+地名+景点名; - ④地名+景点名+攻略; - ⑤速通+地名+景点名 - ⑥推荐!+地名+景点名 - ⑦勇闯!+地名+景点名 - ⑧收藏!+地名+景点名 - 你需要随机挑选一种格式生成对应景点的文案,但是格式除了上面8种不可以有其他任何格式;同时尽量保证每一种格式出现的频率均衡。 - 接下来先去生成<第二个小标题>,<第二个小标题>文案的创作必须遵循以下原则: - 请根据笔记内容和图片识别,用极简的文字概括这篇笔记和图片中景点的特色亮点,其中你可以参考以下词汇进行创作,这段文案字数控制6-8字符以内; - - 特色亮点可能会出现的词汇不完全举例:非遗、古建、绝佳山水、祈福圣地、研学圣地、解压天堂、中国小瑞士、秘境竹筏游等等类型词汇 - - 接下来再去生成<第一个小标题>,<第一个小标题>文案的创作必须遵循以下原则: - 这部分文案创作公式有5种,分别为: - ①<受众人群画像>+<痛点词> - ②<受众人群画像> - ③<痛点词> - ④<受众人群画像>+ | +<痛点词> - ⑤<痛点词>+ | +<受众人群画像> - 请你根据实际笔记内容,结合这部分文案创作公式,需要结合<受众人群画像>和<痛点词>时,必须根据<第二个小标题>的景点特征和所对应的完整笔记推文内容主旨,特征挑选对应<受众人群画像>和<痛点词>。 - - 我给你提供受众人群画像库和痛点词库如下: - 1、受众人群画像库:情侣党、亲子游、合家游、银发族、亲子研学、学生党、打工人、周边游、本地人、穷游党、性价比、户外人、美食党、出片 - 2、痛点词库:3天2夜、必去、看了都哭了、不能错过、一定要来、问爆了、超全攻略、必打卡、强推、懒人攻略、必游榜、小众打卡、狂喜等等。 - - 你需要为每个请求至少生成{poster_num}个海报设计。请使用JSON格式输出结果,结构如下: - [ - { - "index": 1, - "main_title": "主标题内容", - "texts": ["第一个小标题", "第二个小标题"] - }, - { - "index": 2, - "main_title": "主标题内容", - "texts": ["第一个小标题", "第二个小标题"] - }, - // ... 更多海报 - ] - 确保生成的数量与用户要求的数量一致。只生成上述JSON格式内容,不要有其他任何额外内容。 - - """ - - if self.add_description: - # 创建用户内容,包括info信息和tweet_content - user_content = f""" - 以下是需要你处理的信息: - - 关于景点的描述: - {self.add_description} - - 推文内容: - {tweet_content} - - 请根据这些信息,生成{poster_num}个海报文案配置,以JSON数组格式返回。 - """ - else: - # 仅使用tweet_content - user_content = f""" - 以下是需要你处理的推文内容: - {tweet_content} - - 请根据这些信息,生成{poster_num}个海报文案配置,以JSON数组格式返回。 - """ - - self.logger.info(f"正在生成{poster_num}个海报文案配置") - - # 创建临时客户端 - temp_client = self._create_temp_client() - - if temp_client: - # 重试逻辑 - for retry in range(max_retries): - try: - self.logger.info(f"尝试生成内容 (尝试 {retry+1}/{max_retries})") - - # 定义流式响应处理回调函数 - def handle_stream_chunk(chunk, is_last=False, is_timeout=False, is_error=False, error=None): - nonlocal full_response - - if chunk: - full_response += chunk - # 实时输出到控制台 - print(chunk, end="", flush=True) - - if is_last: - print("\n") # 输出完成后换行 - if is_timeout: - print("警告: 响应流超时") - if is_error: - print(f"错误: {error}") - - # 使用AI_Agent的新回调方式 - from core.ai_agent import AI_Agent - ai_agent = AI_Agent( - self.api_base_url, - self.model_name, - self.api_key, - timeout=timeout, - max_retries=max_retries, - stream_chunk_timeout=30 # 流式块超时时间 - ) - - # 使用回调方式处理流式响应 - try: - full_response = ai_agent.generate_text_stream_with_callback( - system_prompt, - user_content, - callback=handle_stream_chunk, - temperature=self.temperature, - top_p=self.top_p, - presence_penalty=self.presence_penalty - ) - - # 如果成功生成内容,跳出重试循环 - ai_agent.close() - break - - except Exception as e: - error_msg = str(e) - self.logger.error(f"AI生成错误: {error_msg}") - ai_agent.close() - - # 继续重试逻辑 - if retry + 1 >= max_retries: - self.logger.warning("已达到最大重试次数,使用备用方案...") - # 生成备用内容 - full_response = self._generate_fallback_content(poster_num) - else: - self.logger.info(f"将在稍后重试,还剩 {max_retries - retry - 1} 次重试机会") - - except Exception as e: - error_msg = str(e) - self.logger.error(f"API连接错误 (尝试 {retry+1}/{max_retries}): {error_msg}") - - # 如果已经达到最大重试次数 - if retry + 1 >= max_retries: - self.logger.warning("已达到最大重试次数,使用备用方案...") - # 生成备用内容(简单模板) - full_response = self._generate_fallback_content(poster_num) - else: - self.logger.info(f"将在稍后重试,还剩 {max_retries - retry - 1} 次重试机会") - - # 关闭临时客户端 - self._close_client(temp_client) - - return full_response + """生成海报内容""" + return self._impl.generate_posters( + poster_num, + tweet_content, + system_prompt, + api_url=self.api_base_url, + model_name=self.model_name, + api_key=self.api_key, + timeout=60, + max_retries=max_retries + ) def _generate_fallback_content(self, poster_num): - """生成备用内容,当API调用失败时使用""" - self.logger.info("生成备用内容") - default_configs = [] - for i in range(poster_num): - default_configs.append({ - "main_title": f"景点风光 {i+1}", - "texts": ["自然美景", "人文体验"] - }) - return json.dumps(default_configs, ensure_ascii=False) - + """生成备用内容""" + return self._impl._generate_fallback_content(poster_num) + def save_result(self, full_response): - """ - 保存生成结果到文件 - - 参数: - full_response: 生成的完整响应内容 - - 返回: - 结果文件路径 - """ - # 生成时间戳 - print(full_response) - date_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") - - try: - # 解析内容为JSON格式 - parsed_data = self.split_content(full_response) - - # 验证内容格式并修复 - validated_data = self._validate_and_fix_data(parsed_data) - - # 创建结果文件路径 - result_path = os.path.join(self.output_dir, f"{date_time}.json") - os.makedirs(os.path.dirname(result_path), exist_ok=True) - - # 保存结果到文件 - with open(result_path, "w", encoding="utf-8") as f: - json.dump(validated_data, f, ensure_ascii=False, indent=4) - - print(f"结果已保存到: {result_path}") - return result_path - - except Exception as e: - self.logger.error(f"保存结果到文件时出错: {e}") - # 尝试创建一个简单的备用配置 - fallback_data = [{"main_title": "景点风光", "texts": ["自然美景", "人文体验"], "index": 1}] - - # 保存备用数据 - result_path = os.path.join(self.output_dir, f"{date_time}_fallback.json") - os.makedirs(os.path.dirname(result_path), exist_ok=True) - - with open(result_path, "w", encoding="utf-8") as f: - json.dump(fallback_data, f, ensure_ascii=False, indent=4) - - print(f"出错后已保存备用数据到: {result_path}") - return result_path - + """保存生成结果到文件""" + return self._impl.save_result(full_response) + def _validate_and_fix_data(self, data): - """ - 验证并修复数据格式,确保符合预期结构 - - 参数: - data: 需要验证的数据 - - 返回: - 修复后的数据 - """ - fixed_data = [] - - # 如果数据是列表 - if isinstance(data, list): - for i, item in enumerate(data): - # 检查项目是否为字典 - if isinstance(item, dict): - # 确保必需字段存在 - fixed_item = { - "index": item.get("index", i + 1), - "main_title": item.get("main_title", f"景点风光 {i+1}"), - "texts": item.get("texts", ["自然美景", "人文体验"]) - } - - # 确保texts是列表格式 - if not isinstance(fixed_item["texts"], list): - if isinstance(fixed_item["texts"], str): - fixed_item["texts"] = [fixed_item["texts"], "美景体验"] - else: - fixed_item["texts"] = ["自然美景", "人文体验"] - - # 限制texts最多包含两个元素 - if len(fixed_item["texts"]) > 2: - fixed_item["texts"] = fixed_item["texts"][:2] - elif len(fixed_item["texts"]) < 2: - while len(fixed_item["texts"]) < 2: - fixed_item["texts"].append("美景体验") - - fixed_data.append(fixed_item) - - # 如果项目是字符串(可能是错误格式的texts值) - elif isinstance(item, str): - self.logger.warning(f"配置项 {i+1} 是字符串格式,将转换为标准格式") - fixed_item = { - "index": i + 1, - "main_title": f"景点风光 {i+1}", - "texts": [item, "美景体验"] - } - fixed_data.append(fixed_item) - else: - self.logger.warning(f"配置项 {i+1} 格式不支持: {type(item)},将使用默认值") - fixed_data.append({ - "index": i + 1, - "main_title": f"景点风光 {i+1}", - "texts": ["自然美景", "人文体验"] - }) - - # 如果数据是字典 - elif isinstance(data, dict): - fixed_item = { - "index": data.get("index", 1), - "main_title": data.get("main_title", "景点风光"), - "texts": data.get("texts", ["自然美景", "人文体验"]) - } - - # 确保texts是列表格式 - if not isinstance(fixed_item["texts"], list): - if isinstance(fixed_item["texts"], str): - fixed_item["texts"] = [fixed_item["texts"], "美景体验"] - else: - fixed_item["texts"] = ["自然美景", "人文体验"] - - # 限制texts最多包含两个元素 - if len(fixed_item["texts"]) > 2: - fixed_item["texts"] = fixed_item["texts"][:2] - elif len(fixed_item["texts"]) < 2: - while len(fixed_item["texts"]) < 2: - fixed_item["texts"].append("美景体验") - - fixed_data.append(fixed_item) - - # 如果数据是字符串或其他格式 - else: - self.logger.warning(f"数据格式不支持: {type(data)},将使用默认值") - fixed_data.append({ - "index": 1, - "main_title": "景点风光", - "texts": ["自然美景", "人文体验"] - }) - - # 确保至少有一个配置项 - if not fixed_data: - fixed_data.append({ - "index": 1, - "main_title": "景点风光", - "texts": ["自然美景", "人文体验"] - }) - - return fixed_data - + """验证并修复数据格式""" + return self._impl._validate_and_fix_data(data) + def run(self, info_directory, poster_num, tweet_content, system_prompt=None): - """ - 运行海报内容生成流程,并返回生成的配置数据。 - - 参数: - info_directory: 信息目录路径列表 (e.g., ['/path/to/description.txt']) - poster_num: 需要生成的海报配置数量 - tweet_content: 用于生成内容的推文/文章内容 - - 返回: - list | dict | None: 生成的海报配置数据 (通常是列表),如果生成或解析失败则返回 None。 - """ - self.load_infomation(info_directory) - - # Generate the raw string response from AI - full_response = self.generate_posters(poster_num, tweet_content, system_prompt) - - # Check if generation failed (indicated by return code 404 or other markers) - if full_response == 404 or not isinstance(full_response, str) or not full_response.strip(): - logging.error("Poster content generation failed or returned empty response.") - return None - - # Extract the JSON data from the raw response string - try: - result_data = self.split_content(full_response) # This should return the list/dict - - # 验证并修复结果数据格式 - fixed_data = [] - - # 如果结果是列表,检查每个项目 - if isinstance(result_data, list): - for i, item in enumerate(result_data): - # 如果项目是字典并且有required_fields,按原样添加或修复 - if isinstance(item, dict): - # 检查并确保必需字段存在 - if 'main_title' not in item: - item['main_title'] = f"景点标题 {i+1}" - logging.warning(f"配置项 {i+1} 缺少 main_title 字段,已添加默认值") - - if 'texts' not in item: - item['texts'] = ["景点特色", "游玩体验"] - logging.warning(f"配置项 {i+1} 缺少 texts 字段,已添加默认值") - - if 'index' not in item: - item['index'] = i + 1 - logging.warning(f"配置项 {i+1} 缺少 index 字段,已添加默认值") - - fixed_data.append(item) - # 如果项目是字符串(可能是错误格式的texts值) - elif isinstance(item, str): - logging.warning(f"配置项 {i+1} 是字符串格式,将转换为标准格式") - fixed_item = { - "index": i + 1, - "main_title": f"景点风光 {i+1}", - "texts": [item, "美景体验"] - } - fixed_data.append(fixed_item) - else: - logging.warning(f"配置项 {i+1} 格式不支持: {type(item)},将使用默认值") - fixed_data.append({ - "index": i + 1, - "main_title": f"景点风光 {i+1}", - "texts": ["自然美景", "人文体验"] - }) - - # 如果处理后的列表为空(极端情况),则使用默认值 - if not fixed_data: - logging.warning("处理后的配置列表为空,使用默认值") - for i in range(poster_num): - fixed_data.append({ - "index": i + 1, - "main_title": f"景点风光 {i+1}", - "texts": ["自然美景", "人文体验"] - }) - - logging.info(f"成功生成并修复海报配置数据,包含 {len(fixed_data)} 个项目") - return fixed_data - - # 如果结果是单个字典(不常见但可能),将其转换为列表 - elif isinstance(result_data, dict): - logging.warning(f"生成的配置数据是单个字典格式,将转换为列表") - - # 检查并确保必需字段存在 - if 'main_title' not in result_data: - result_data['main_title'] = "景点风光" - - if 'texts' not in result_data: - result_data['texts'] = ["自然美景", "人文体验"] - - if 'index' not in result_data: - result_data['index'] = 1 - - fixed_data = [result_data] - return fixed_data - - # 如果结果是其他格式(如字符串),创建默认配置 - else: - logging.warning(f"生成的配置数据格式不支持: {type(result_data)},将使用默认值") - for i in range(poster_num): - fixed_data.append({ - "index": i + 1, - "main_title": f"景点风光 {i+1}", - "texts": ["自然美景", "人文体验"] - }) - return fixed_data - - except Exception as e: - logging.exception(f"Failed to parse JSON from AI response in ContentGenerator: {e}\nRaw Response:\n{full_response[:500]}...") # Log error and partial response - - # 失败后创建一个默认配置 - logging.info("创建默认海报配置数据") - default_configs = [] - for i in range(poster_num): - default_configs.append({ - "index": i + 1, - "main_title": f"景点风光 {i+1}", - "texts": ["自然美景", "人文体验"] - }) - return default_configs - + """运行海报内容生成流程""" + return self._impl.run( + info_directory, + poster_num, + tweet_content, + system_prompt, + api_url=self.api_base_url, + model_name=self.model_name, + api_key=self.api_key + ) + def set_temperature(self, temperature): - self.temperature = temperature - + """设置温度参数""" + self._impl.set_temperature(temperature) + def set_top_p(self, top_p): - self.top_p = top_p - + """设置top_p参数""" + self._impl.set_top_p(top_p) + def set_presence_penalty(self, presence_penalty): - self.presence_penalty = presence_penalty + """设置存在惩罚参数""" + self._impl.set_presence_penalty(presence_penalty) def set_model_para(self, temperature, top_p, presence_penalty): - self.temperature = temperature - self.top_p = top_p - self.presence_penalty = presence_penalty - -# def main(): -# # 配置参数 -# info_directory = [ -# "/root/autodl-tmp/sanming_img/相机/甘露寺/description.txt" -# ] # 信息目录 -# poster_num = 4 # 海报数量 - -# # 推文内容 -# tweet_content = """ -# 🌿清明遛娃天花板!悬空古寺+非遗探秘 -# -# -# 清明假期带娃哪里玩?泰宁甘露寺藏着明代建筑奇迹!一柱擎天的悬空阁楼+状元祈福传说,让孩子边玩边涨知识✨ - -# 🎒行程亮点: -# ✅ 安全科普第一站:讲解"一柱插地"千年不倒的秘密,用乐高积木模型让孩子理解力学原理 -# ✅ 文化沉浸体验:穿汉服听"叶状元还愿建寺"故事,触摸3.38米粗的"状元柱"许愿 -# ✅ 自然探索路线:连接金湖栈道徒步,观察丹霞地貌与古建筑的巧妙融合 - -# 📌实用攻略: -# 📍位置:福建省三明市泰宁县金湖西路(导航搜"甘露岩寺") -# 🕒最佳时段:上午10点前抵达避开人流,下午可衔接参观明清园(80元/人) -# ⚠️注意事项:悬空栈道设置儿童安全绳租赁点,建议穿防滑鞋 - -# 💡亲子彩蛋: -# 1️⃣ 在"右鼓左钟"景观区玩声音实验,敲击不同岩石听回声差异 -# 2️⃣ 领取任务卡完成"寻找建筑中的T形拱"小游戏,集章兑换非遗木雕书签 -# 3️⃣ 结合清明节俗,用竹简模板书写祈福语系在古松枝头 - -# 周边推荐:游览完可直奔尚书第明代古民居群,对比不同时期建筑特色,晚餐推荐尝泰宁特色"灯盏糕",亲子套票更划算! - -# 清明带着孩子来场穿越850年的建筑探险,把课本里的力学知识变成触手可及的历史课堂!🌸 - -# #清明节周边游 #亲子科普游 #福建遛娃 #泰宁旅行攻略 #建筑启蒙 -# -# """ - -# # 创建海报生成器 -# generator = ContentGenerator() - -# # 运行生成流程 -# generator.run(info_directory, poster_num, tweet_content) - - -# if __name__ == "__main__": -# main() + """一次性设置所有模型参数""" + self._impl.set_model_para(temperature, top_p, presence_penalty) diff --git a/output/2025-04-24_19-44-15.json b/output/2025-04-24_19-44-15.json new file mode 100644 index 0000000..4368846 --- /dev/null +++ b/output/2025-04-24_19-44-15.json @@ -0,0 +1,26 @@ +[ + { + "index": 1, + "main_title": "泰宁甘露寺攻略", + "texts": [ + "悬空古建探秘", + "亲子研学首选" + ] + }, + { + "index": 2, + "main_title": "推荐!泰宁甘露寺", + "texts": [ + "千年古刹秘境", + "自驾打卡胜地" + ] + }, + { + "index": 3, + "main_title": "泰宁甘露寺", + "texts": [ + "非遗文化沉浸", + "状元祈福之旅" + ] + } +] \ No newline at end of file diff --git a/test_content_generator.py b/test_content_generator.py new file mode 100644 index 0000000..3d929fb --- /dev/null +++ b/test_content_generator.py @@ -0,0 +1,118 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +# 导入新旧实现 +from core.contentGen import ContentGenerator as OldContentGenerator +from utils.content_generator import ContentGenerator as NewContentGenerator + +def test_both_implementations(): + """测试新旧实现的兼容性""" + # 创建测试文本内容 + test_content = """ + 🌿清明遛娃天花板!悬空古寺+非遗探秘 + + + 清明假期带娃哪里玩?泰宁甘露寺藏着明代建筑奇迹!一柱擎天的悬空阁楼+状元祈福传说,让孩子边玩边涨知识✨ + + 🎒行程亮点: + ✅ 安全科普第一站:讲解"一柱插地"千年不倒的秘密,用乐高积木模型让孩子理解力学原理 + ✅ 文化沉浸体验:穿汉服听"叶状元还愿建寺"故事,触摸3.38米粗的"状元柱"许愿 + + """ + + print("=" * 50) + print("测试新旧ContentGenerator实现") + print("=" * 50) + + # 创建输出目录 + import os + os.makedirs("./test_output", exist_ok=True) + + # 测试参数 + api_url = "http://localhost:8000/v1" # 替换为实际URL + model_name = "qwenQWQ" # 替换为实际模型 + api_key = "EMPTY" # 替换为实际密钥 + poster_num = 2 # 生成2个海报配置 + + # 1. 测试旧实现(现在委托给新实现) + print("\n1. 测试旧实现 (core.contentGen.ContentGenerator)") + old_generator = OldContentGenerator( + model_name=model_name, + api_base_url=api_url, + api_key=api_key, + output_dir="./test_output/old" + ) + + # 设置生成参数 + old_generator.set_model_para(0.7, 0.8, 1.2) + + # 运行生成 + print("正在使用旧实现生成海报配置...") + old_result = old_generator.run([], poster_num, test_content) + + if old_result: + print(f"旧实现成功生成 {len(old_result)} 个配置项") + for i, config in enumerate(old_result): + print(f" 配置 {i+1}: {config.get('main_title')} - {config.get('texts')}") + else: + print("旧实现生成失败") + + # 2. 测试新实现 + print("\n2. 测试新实现 (utils.content_generator.ContentGenerator)") + new_generator = NewContentGenerator( + output_dir="./test_output/new", + temperature=0.7, + top_p=0.8, + presence_penalty=1.2 + ) + + # 运行生成 + print("正在使用新实现生成海报配置...") + new_result = new_generator.run( + [], + poster_num, + test_content, + api_url=api_url, + model_name=model_name, + api_key=api_key + ) + + if new_result: + print(f"新实现成功生成 {len(new_result)} 个配置项") + for i, config in enumerate(new_result): + print(f" 配置 {i+1}: {config.get('main_title')} - {config.get('texts')}") + else: + print("新实现生成失败") + + print("\n3. 比较结果") + if old_result and new_result: + import json + + # 格式化输出结果比较 + print("\n旧实现结果:") + print(json.dumps(old_result, ensure_ascii=False, indent=2)) + + print("\n新实现结果:") + print(json.dumps(new_result, ensure_ascii=False, indent=2)) + + # 结构比较 + old_format = [type(config) for config in old_result] + new_format = [type(config) for config in new_result] + + print(f"\n结构比较: 旧实现: {old_format}, 新实现: {new_format}") + print(f"数量比较: 旧实现: {len(old_result)}, 新实现: {len(new_result)}") + + # 检查所有必要的字段 + field_check = True + for result in [old_result, new_result]: + for config in result: + if not all(key in config for key in ["index", "main_title", "texts"]): + field_check = False + break + + print(f"字段检查: {'通过' if field_check else '失败'}") + + print("\n测试完成!") + +if __name__ == "__main__": + test_both_implementations() \ No newline at end of file diff --git a/utils/__pycache__/content_generator.cpython-312.pyc b/utils/__pycache__/content_generator.cpython-312.pyc new file mode 100644 index 0000000..4afd604 Binary files /dev/null and b/utils/__pycache__/content_generator.cpython-312.pyc differ diff --git a/utils/__pycache__/tweet_generator.cpython-312.pyc b/utils/__pycache__/tweet_generator.cpython-312.pyc index d54f51e..7a62d6a 100644 Binary files a/utils/__pycache__/tweet_generator.cpython-312.pyc and b/utils/__pycache__/tweet_generator.cpython-312.pyc differ diff --git a/utils/content_generator.py b/utils/content_generator.py new file mode 100644 index 0000000..530c96b --- /dev/null +++ b/utils/content_generator.py @@ -0,0 +1,529 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import os +import json +import logging +import traceback +from datetime import datetime +import sys +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from core.ai_agent import AI_Agent + +class ContentGenerator: + """ + 海报文本内容生成器 + 使用AI_Agent代替直接管理OpenAI客户端,简化代码结构 + """ + def __init__(self, + output_dir="/root/autodl-tmp/poster_generate_result", + temperature=0.7, + top_p=0.8, + presence_penalty=1.2): + """ + 初始化内容生成器 + + 参数: + output_dir: 输出结果保存目录 + temperature: 生成温度参数 + top_p: top_p参数 + presence_penalty: 惩罚参数 + """ + self.output_dir = output_dir + self.temperature = temperature + self.top_p = top_p + self.presence_penalty = presence_penalty + self.add_description = "" + + # 设置日志 + logging.basicConfig(level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') + self.logger = logging.getLogger(__name__) + + def load_infomation(self, info_directory_path): + """ + 加载额外描述文件 + + 参数: + info_directory_path: 信息目录路径列表 + """ + self.add_description = "" # 重置描述文本 + for path in info_directory_path: + try: + with open(path, "r", encoding="utf-8") as f: + self.add_description += f.read() + self.logger.info(f"成功加载描述文件: {path}") + except Exception as e: + self.logger.warning(f"加载描述文件失败: {path}, 错误: {e}") + self.add_description = "" + + def split_content(self, content): + """ + 分割结果, 返回去除 + ```json + ```的json内容 + + 参数: + content: 需要分割的内容 + + 返回: + 分割后的json内容 + """ + try: + # 首先尝试直接解析整个内容,以防已经是干净的 JSON + try: + return json.loads(content) + except json.JSONDecodeError: + pass # 不是干净的 JSON,继续处理 + + # 常规模式:查找 ```json 和 ``` 之间的内容 + if "```json" in content: + json_str = content.split("```json")[1].split("```")[0].strip() + try: + return json.loads(json_str) + except json.JSONDecodeError as e: + self.logger.warning(f"常规格式解析失败: {e}, 尝试其他方法") + + # 备用模式1:查找连续的 [ 开头和 ] 结尾的部分 + import re + json_pattern = r'(\[.*?\])' + json_matches = re.findall(json_pattern, content, re.DOTALL) + if json_matches: + for match in json_matches: + try: + result = json.loads(match) + if isinstance(result, list) and len(result) > 0: + return result + except: + continue + + # 备用模式2:查找 [ 开头 和 ] 结尾,并尝试解析 + content = content.strip() + square_bracket_start = content.find('[') + square_bracket_end = content.rfind(']') + + if square_bracket_start != -1 and square_bracket_end != -1: + potential_json = content[square_bracket_start:square_bracket_end + 1] + try: + return json.loads(potential_json) + except: + self.logger.warning("尝试提取方括号内容失败") + + # 最后一种尝试:查找所有可能的 JSON 结构并尝试解析 + json_structures = re.findall(r'({.*?})', content, re.DOTALL) + if json_structures: + items = [] + for i, struct in enumerate(json_structures): + try: + item = json.loads(struct) + # 验证结构包含预期字段 + if 'main_title' in item and ('texts' in item or 'index' in item): + items.append(item) + except: + continue + + if items: + return items + + # 都失败了,打印错误并引发异常 + self.logger.error(f"无法解析内容,返回原始文本: {content[:200]}...") + raise ValueError("无法从响应中提取有效的 JSON 格式") + + except Exception as e: + self.logger.error(f"解析内容时出错: {e}") + self.logger.debug(f"原始内容: {content[:200]}...") # 仅显示前200个字符 + raise e + + def generate_posters(self, + poster_num, + content_data_list, + system_prompt=None, + api_url="http://localhost:8000/v1", + model_name="qwenQWQ", + api_key="EMPTY", + timeout=60, + max_retries=3): + """ + 生成海报内容 + + 参数: + poster_num: 海报数量 + content_data_list: 内容数据列表(字典或字符串) + system_prompt: 系统提示,默认为None则使用预设提示 + api_url: API基础URL + model_name: 使用的模型名称 + api_key: API密钥 + timeout: 请求超时时间 + max_retries: 最大重试次数 + + 返回: + 生成的海报内容 + """ + # 构建默认系统提示词 + if not system_prompt: + system_prompt = """ + 你是一名资深海报设计师,有丰富的爆款海报设计经验,你现在要为旅游景点做宣传,在小红书上发布大量宣传海报。你的主要工作目标有2个: + 1、你要根据我给你的图片描述和笔记推文内容,设计图文匹配的海报。 + 2、为海报设计文案,文案的<第一个小标题>和<第二个小标题>之间你需要检查是否逻辑关系合理,你将通过先去生成<第二个小标题>关于景区亮点的部分,再去综合判断<第一个小标题>应该如何搭配组合更符合两个小标题的逻辑再生成<第一个小标题>。 + + 其中,生成三类标题文案的通用性要求如下: + 1、生成的<大标题>字数必须小于8个字符 + 2、生成的<第一个小标题>字数和<第二个小标题>字数,两者都必须小8个字符 + 3、标题和文案都应符合中国社会主义核心价值观 + + 接下来先开始生成<大标题>部分,由于海报是用来宣传旅游景点,生成的海报<大标题>必须使用以下8种格式之一: + ①地名+景点名(例如福建厦门鼓浪屿/厦门鼓浪屿); + ②地名+景点名+plog; + ③拿捏+地名+景点名; + ④地名+景点名+攻略; + ⑤速通+地名+景点名 + ⑥推荐!+地名+景点名 + ⑦勇闯!+地名+景点名 + ⑧收藏!+地名+景点名 + 你需要随机挑选一种格式生成对应景点的文案,但是格式除了上面8种不可以有其他任何格式;同时尽量保证每一种格式出现的频率均衡。 + 接下来先去生成<第二个小标题>,<第二个小标题>文案的创作必须遵循以下原则: + 请根据笔记内容和图片识别,用极简的文字概括这篇笔记和图片中景点的特色亮点,其中你可以参考以下词汇进行创作,这段文案字数控制6-8字符以内; + + 特色亮点可能会出现的词汇不完全举例:非遗、古建、绝佳山水、祈福圣地、研学圣地、解压天堂、中国小瑞士、秘境竹筏游等等类型词汇 + + 接下来再去生成<第一个小标题>,<第一个小标题>文案的创作必须遵循以下原则: + 这部分文案创作公式有5种,分别为: + ①<受众人群画像>+<痛点词> + ②<受众人群画像> + ③<痛点词> + ④<受众人群画像>+ | +<痛点词> + ⑤<痛点词>+ | +<受众人群画像> + 请你根据实际笔记内容,结合这部分文案创作公式,需要结合<受众人群画像>和<痛点词>时,必须根据<第二个小标题>的景点特征和所对应的完整笔记推文内容主旨,特征挑选对应<受众人群画像>和<痛点词>。 + + 我给你提供受众人群画像库和痛点词库如下: + 1、受众人群画像库:情侣党、亲子游、合家游、银发族、亲子研学、学生党、打工人、周边游、本地人、穷游党、性价比、户外人、美食党、出片 + 2、痛点词库:3天2夜、必去、看了都哭了、不能错过、一定要来、问爆了、超全攻略、必打卡、强推、懒人攻略、必游榜、小众打卡、狂喜等等。 + + 你需要为每个请求至少生成{poster_num}个海报设计。请使用JSON格式输出结果,结构如下: + [ + { + "index": 1, + "main_title": "主标题内容", + "texts": ["第一个小标题", "第二个小标题"] + }, + { + "index": 2, + "main_title": "主标题内容", + "texts": ["第一个小标题", "第二个小标题"] + }, + // ... 更多海报 + ] + 确保生成的数量与用户要求的数量一致。只生成上述JSON格式内容,不要有其他任何额外内容。 + """ + + # 提取内容文本(如果是列表内容数据) + tweet_content = "" + if isinstance(content_data_list, list): + for item in content_data_list: + if isinstance(item, dict): + title = item.get('title', '') + content = item.get('content', '') + tweet_content += f"\n{title}\n\n\n{content}\n\n\n" + elif isinstance(item, str): + tweet_content += item + "\n\n" + elif isinstance(content_data_list, str): + tweet_content = content_data_list + + # 构建用户提示 + if self.add_description: + user_content = f""" + 以下是需要你处理的信息: + + 关于景点的描述: + {self.add_description} + + 推文内容: + {tweet_content} + + 请根据这些信息,生成{poster_num}个海报文案配置,以JSON数组格式返回。 + """ + else: + user_content = f""" + 以下是需要你处理的推文内容: + {tweet_content} + + 请根据这些信息,生成{poster_num}个海报文案配置,以JSON数组格式返回。 + """ + + self.logger.info(f"正在生成{poster_num}个海报文案配置") + + # 创建AI_Agent实例 + ai_agent = AI_Agent( + api_url, + model_name, + api_key, + timeout=timeout, + max_retries=max_retries, + stream_chunk_timeout=30 # 流式块超时时间 + ) + + full_response = "" + try: + # 使用AI_Agent的non-streaming方法 + self.logger.info(f"调用AI生成海报配置,模型: {model_name}") + full_response, tokens, time_cost = ai_agent.work( + system_prompt, + user_content, + "", # 历史消息(空) + self.temperature, + self.top_p, + self.presence_penalty + ) + + self.logger.info(f"AI生成完成,耗时: {time_cost:.2f}s, 预估令牌数: {tokens}") + + if not full_response: + self.logger.warning("AI返回空响应,使用备用内容") + full_response = self._generate_fallback_content(poster_num) + except Exception as e: + self.logger.exception(f"AI生成过程发生错误: {e}") + full_response = self._generate_fallback_content(poster_num) + finally: + # 确保关闭AI Agent + ai_agent.close() + + return full_response + + def _generate_fallback_content(self, poster_num): + """生成备用内容,当API调用失败时使用""" + self.logger.info("生成备用内容") + default_configs = [] + for i in range(poster_num): + default_configs.append({ + "index": i + 1, + "main_title": f"景点风光 {i+1}", + "texts": ["自然美景", "人文体验"] + }) + return json.dumps(default_configs, ensure_ascii=False) + + def save_result(self, full_response, custom_output_dir=None): + """ + 保存生成结果到文件 + + 参数: + full_response: 生成的完整响应内容 + custom_output_dir: 自定义输出目录(可选) + + 返回: + 结果文件路径 + """ + # 生成时间戳 + date_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + output_dir = custom_output_dir or self.output_dir + + try: + # 解析内容为JSON格式 + parsed_data = self.split_content(full_response) + + # 验证内容格式并修复 + validated_data = self._validate_and_fix_data(parsed_data) + + # 创建结果文件路径 + result_path = os.path.join(output_dir, f"{date_time}.json") + os.makedirs(os.path.dirname(result_path), exist_ok=True) + + # 保存结果到文件 + with open(result_path, "w", encoding="utf-8") as f: + json.dump(validated_data, f, ensure_ascii=False, indent=4) + + self.logger.info(f"结果已保存到: {result_path}") + return result_path + + except Exception as e: + self.logger.error(f"保存结果到文件时出错: {e}") + # 尝试创建一个简单的备用配置 + fallback_data = [{"main_title": "景点风光", "texts": ["自然美景", "人文体验"], "index": 1}] + + # 保存备用数据 + result_path = os.path.join(output_dir, f"{date_time}_fallback.json") + os.makedirs(os.path.dirname(result_path), exist_ok=True) + + with open(result_path, "w", encoding="utf-8") as f: + json.dump(fallback_data, f, ensure_ascii=False, indent=4) + + self.logger.info(f"出错后已保存备用数据到: {result_path}") + return result_path + + def _validate_and_fix_data(self, data): + """ + 验证并修复数据格式,确保符合预期结构 + + 参数: + data: 需要验证的数据 + + 返回: + 修复后的数据 + """ + fixed_data = [] + + # 如果数据是列表 + if isinstance(data, list): + for i, item in enumerate(data): + # 检查项目是否为字典 + if isinstance(item, dict): + # 确保必需字段存在 + fixed_item = { + "index": item.get("index", i + 1), + "main_title": item.get("main_title", f"景点风光 {i+1}"), + "texts": item.get("texts", ["自然美景", "人文体验"]) + } + + # 确保texts是列表格式 + if not isinstance(fixed_item["texts"], list): + if isinstance(fixed_item["texts"], str): + fixed_item["texts"] = [fixed_item["texts"], "美景体验"] + else: + fixed_item["texts"] = ["自然美景", "人文体验"] + + # 限制texts最多包含两个元素 + if len(fixed_item["texts"]) > 2: + fixed_item["texts"] = fixed_item["texts"][:2] + elif len(fixed_item["texts"]) < 2: + while len(fixed_item["texts"]) < 2: + fixed_item["texts"].append("美景体验") + + fixed_data.append(fixed_item) + + # 如果项目是字符串(可能是错误格式的texts值) + elif isinstance(item, str): + self.logger.warning(f"配置项 {i+1} 是字符串格式,将转换为标准格式") + fixed_item = { + "index": i + 1, + "main_title": f"景点风光 {i+1}", + "texts": [item, "美景体验"] + } + fixed_data.append(fixed_item) + else: + self.logger.warning(f"配置项 {i+1} 格式不支持: {type(item)},将使用默认值") + fixed_data.append({ + "index": i + 1, + "main_title": f"景点风光 {i+1}", + "texts": ["自然美景", "人文体验"] + }) + + # 如果数据是字典 + elif isinstance(data, dict): + fixed_item = { + "index": data.get("index", 1), + "main_title": data.get("main_title", "景点风光"), + "texts": data.get("texts", ["自然美景", "人文体验"]) + } + + # 确保texts是列表格式 + if not isinstance(fixed_item["texts"], list): + if isinstance(fixed_item["texts"], str): + fixed_item["texts"] = [fixed_item["texts"], "美景体验"] + else: + fixed_item["texts"] = ["自然美景", "人文体验"] + + # 限制texts最多包含两个元素 + if len(fixed_item["texts"]) > 2: + fixed_item["texts"] = fixed_item["texts"][:2] + elif len(fixed_item["texts"]) < 2: + while len(fixed_item["texts"]) < 2: + fixed_item["texts"].append("美景体验") + + fixed_data.append(fixed_item) + + # 如果数据是字符串或其他格式 + else: + self.logger.warning(f"数据格式不支持: {type(data)},将使用默认值") + fixed_data.append({ + "index": 1, + "main_title": "景点风光", + "texts": ["自然美景", "人文体验"] + }) + + # 确保至少有一个配置项 + if not fixed_data: + fixed_data.append({ + "index": 1, + "main_title": "景点风光", + "texts": ["自然美景", "人文体验"] + }) + + return fixed_data + + def run(self, info_directory, poster_num, content_data, system_prompt=None, + api_url="http://localhost:8000/v1", model_name="qwenQWQ", api_key="EMPTY"): + """ + 运行海报内容生成流程,并返回生成的配置数据。 + + 参数: + info_directory: 信息目录路径列表 (e.g., ['/path/to/description.txt']) + poster_num: 需要生成的海报配置数量 + content_data: 用于生成内容的文章内容(可以是字符串或字典列表) + system_prompt: 系统提示词,默认为None使用内置提示词 + api_url: API基础URL + model_name: 使用的模型名称 + api_key: API密钥 + + 返回: + list | dict | None: 生成的海报配置数据 (通常是列表),如果生成或解析失败则返回 None。 + """ + try: + # 加载描述信息 + self.load_infomation(info_directory) + + # 生成海报内容 + self.logger.info(f"开始生成海报内容,数量: {poster_num}") + full_response = self.generate_posters( + poster_num, + content_data, + system_prompt, + api_url, + model_name, + api_key + ) + + # 检查生成是否失败 + if not isinstance(full_response, str) or not full_response.strip(): + self.logger.error("海报内容生成失败或返回空响应") + return None + + # 从原始响应字符串中提取JSON数据 + result_data = self.split_content(full_response) + + # 验证并修复数据 + fixed_data = self._validate_and_fix_data(result_data) + + self.logger.info(f"成功生成并修复海报配置数据,包含 {len(fixed_data)} 个项目") + return fixed_data + + except Exception as e: + self.logger.exception(f"海报内容生成过程中发生错误: {e}") + traceback.print_exc() + + # 失败后创建一个默认配置 + self.logger.info("创建默认海报配置数据") + default_configs = [] + for i in range(poster_num): + default_configs.append({ + "index": i + 1, + "main_title": f"景点风光 {i+1}", + "texts": ["自然美景", "人文体验"] + }) + return default_configs + + def set_temperature(self, temperature): + """设置温度参数""" + self.temperature = temperature + + def set_top_p(self, top_p): + """设置top_p参数""" + self.top_p = top_p + + def set_presence_penalty(self, presence_penalty): + """设置存在惩罚参数""" + self.presence_penalty = presence_penalty + + def set_model_para(self, temperature, top_p, presence_penalty): + """一次性设置所有模型参数""" + self.temperature = temperature + self.top_p = top_p + self.presence_penalty = presence_penalty \ No newline at end of file