From b5ddb84cdd9ceb37c9820af66f04129ff305fedf Mon Sep 17 00:00:00 2001 From: jinye_huang Date: Thu, 10 Jul 2025 11:52:44 +0800 Subject: [PATCH] =?UTF-8?q?vibrant=E6=A8=A1=E7=89=88=E6=B5=8B=E8=AF=95?= =?UTF-8?q?=E6=88=90=E5=8A=9F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 3 +- config/ai_model.json | 2 +- scripts/generate_vibrant_poster.py | 360 +++++++++++++++++++++++++++++ 3 files changed, 363 insertions(+), 2 deletions(-) create mode 100644 scripts/generate_vibrant_poster.py diff --git a/.gitignore b/.gitignore index c308743..109245c 100644 --- a/.gitignore +++ b/.gitignore @@ -4,4 +4,5 @@ /log/ /__pycache__/ /assets/ -/resource/data/ \ No newline at end of file +/resource/data/ +/tests/output/ \ No newline at end of file diff --git a/config/ai_model.json b/config/ai_model.json index 5ede62a..66e8356 100644 --- a/config/ai_model.json +++ b/config/ai_model.json @@ -1,5 +1,5 @@ { - "model": "qwen-plus", + "model": "qwq-plus", "api_url": "https://dashscope.aliyuncs.com/compatible-mode/v1", "api_key": "sk-bd5ee62703bc41fc9b8a55d748dc1eb8", "temperature": 0.3, diff --git a/scripts/generate_vibrant_poster.py b/scripts/generate_vibrant_poster.py new file mode 100644 index 0000000..50e4544 --- /dev/null +++ b/scripts/generate_vibrant_poster.py @@ -0,0 +1,360 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +""" +使用LLM和Vibrant模板从源文件自动生成海报的脚本。 +""" + +import os +import sys +import json +import random +import asyncio +import argparse +import logging +import re +from datetime import datetime + +# 将项目根目录添加到Python路径中 +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from utils.poster.templates.vibrant_template import VibrantTemplate +from core.ai.ai_agent import AIAgent +from core.config import AIModelConfig +from multiprocessing.pool import ThreadPool +# 配置日志记录 +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +def read_data_file(file_path: str): + """安全地读取并解析各种格式的文件。""" + try: + with open(file_path, 'r', encoding='utf-8') as f: + content = f.read() + + # 尝试解析为标准JSON + try: + return json.loads(content) + except json.JSONDecodeError: + # 如果不是标准JSON,尝试解析为类JSON格式 + logger.info(f"文件 {file_path} 不是标准JSON格式,尝试其他解析方式...") + + # 检查是否是类JSON格式(带有大括号的键值对) + if content.strip().startswith('{') and content.strip().endswith('}'): + try: + # 尝试将内容解析为字典 + result = {} + # 使用正则表达式提取键值对 + pattern = r'"([^"]+)":\s*(.+?)(?=,\s*"[^"]+":|$)' + matches = re.findall(pattern, content, re.DOTALL) + + for key, value in matches: + # 清理值中的引号和多余空格 + value = value.strip() + if value.startswith('"') and value.endswith('"'): + # 字符串值 + result[key] = value[1:-1] + elif value.lower() in ['true', 'false']: + # 布尔值 + result[key] = value.lower() == 'true' + elif value.replace('.', '', 1).isdigit(): + # 数字值 + result[key] = float(value) if '.' in value else int(value) + else: + # 其他值作为字符串处理 + result[key] = value + + if result: + logger.info(f"成功解析文件 {file_path} 为字典") + return result + except Exception as e: + logger.error(f"尝试解析类JSON格式失败: {e}") + + # 如果上述方法失败,尝试直接提取文本内容 + logger.info(f"尝试将文件 {file_path} 作为纯文本处理") + + # 简单解析键值对格式的文本文件 + result = {} + lines = content.split('\n') + current_key = None + current_value = [] + + for line in lines: + line = line.strip() + if not line: + continue + + # 检查是否是新的键值对 + if ':' in line and not line.startswith(' ') and not line.startswith('\t'): + # 保存之前的键值对 + if current_key and current_value: + result[current_key] = '\n'.join(current_value) + current_value = [] + + # 提取新的键值对 + parts = line.split(':', 1) + current_key = parts[0].strip().strip('"').strip("'") + if len(parts) > 1 and parts[1].strip(): + current_value = [parts[1].strip()] + else: + current_value = [] + elif current_key: + # 继续添加到当前值 + current_value.append(line) + + # 保存最后一个键值对 + if current_key and current_value: + result[current_key] = '\n'.join(current_value) + + if result: + logger.info(f"成功将文件 {file_path} 解析为键值对") + return result + + # 如果所有方法都失败,返回文本内容 + logger.warning(f"无法解析文件 {file_path} 为结构化数据,返回原始文本") + return {"content": content} + + except FileNotFoundError: + logger.error(f"文件未找到: {file_path}") + except Exception as e: + logger.error(f"读取文件时发生未知错误 {file_path}: {e}") + return None + +async def generate_info_with_llm(ai_agent: AIAgent, scenic_info: dict, product_info: dict, tweet_info: dict): + """ + 调用LLM,根据输入信息生成海报所需的vibrant_content。 + + Args: + ai_agent: AIAgent实例 + scenic_info: 景区信息字典 + product_info: 产品信息字典 + tweet_info: 推文信息字典 + + Returns: + 一个符合VibrantTemplate格式的内容字典 + """ + logger.info("正在调用LLM以生成海报内容...") + + # 构建系统提示 + system_prompt = """你是一名专业的海报设计师,专门设计宣传海报。你现在要根据用户提供的图片描述和推文内容,生成适合模板的海报信息。 + +模板特点: +- 单图背景,毛玻璃渐变效果 +- 两栏布局(左栏内容,右栏价格) +- 适合展示套餐内容和价格信息 + +你需要生成的数据结构包含以下字段: + +**必填字段:** +1. `title`: 主标题(8-12字符,体现产品特色) +2. `slogan`: 副标题/宣传语(10-20字符,吸引人的描述) +3. `price`: 价格数字(纯数字,不含符号) +4. `ticket_type`: 票种类型(如"成人票"、"套餐票"、"夜场票"等) +5. `content_button`: 内容按钮文字(通常为"套餐内容"、"包含项目"等) +6. `content_items`: 套餐内容列表(3-5个项目,每项5-15字符,不要只包含项目名称,要做合适的美化,可以适当省略) + +**可选字段:** +7. `remarks`: 备注信息(1-3条,每条10-20字符) +8. `tag`: 标签(1条, 如"#限时优惠"、""等) +9. `pagination`: 分页信息(如"1/3",可为空) + +**内容创作要求:** +1. 套餐内容要具体实用:明确说明包含的服务、时间、数量 +2. 价格要有吸引力:突出性价比和优惠信息 + +**输出格式:** +请严格按照以下JSON格式输出,不要有任何额外内容: + +```json +{ + "title": "海洋奇幻世界", + "slogan": "探索深海秘境,感受蓝色奇迹的无限魅力", + "price": "299", + "ticket_type": "成人票", + "content_button": "套餐内容", + "content_items": [ + "海洋馆门票1张(含所有展区)", + "海豚表演VIP座位", + "鲨鱼隧道特别体验", + "专业摄影服务" + ], + "remarks": [ + "工作日可直接入园", + "周末需提前预约" + ], + "tag": "#海洋特惠", + "pagination": "1/2" +} +``` +""" + + # 构建用户提示 + user_prompt = f"""请根据以下信息,生成适合在旅游海报上展示的文案: + +## 景区信息 +{json.dumps(scenic_info, ensure_ascii=False, indent=2)} + +## 产品信息 +{json.dumps(product_info, ensure_ascii=False, indent=2)} + +## 推文信息 +{json.dumps(tweet_info, ensure_ascii=False, indent=2)} + +请提取关键信息并整合成一个JSON对象,包含title、slogan、price、ticket_type、content_items、remarks和tag字段。 +""" + + try: + # 调用AIAgent生成文本 + response, _, _, _ = await ai_agent.generate_text( + system_prompt=system_prompt, + user_prompt=user_prompt, + temperature=0.7, + stage="海报文案生成" + ) + + # 解析JSON响应 + try: + # 尝试找到JSON对象的开始和结束位置 + json_start = response.find('{') + json_end = response.rfind('}') + 1 + + if json_start >= 0 and json_end > json_start: + json_str = response[json_start:json_end] + content_dict = json.loads(json_str) + logger.info(f"LLM成功生成内容: {content_dict}") + + # 添加默认的按钮文本和分页信息 + content_dict["content_button"] = content_dict.get("content_button", "套餐内容") + content_dict["pagination"] = content_dict.get("pagination", "") + + # 确保所有值都是字符串类型 + for key, value in content_dict.items(): + if isinstance(value, (int, float)): + content_dict[key] = str(value) + elif isinstance(value, list): + content_dict[key] = [str(item) if isinstance(item, (int, float)) else item for item in value] + + logger.info(f"转换类型后的内容: {content_dict}") + return content_dict + else: + logger.error(f"无法在响应中找到JSON对象: {response}") + return None + except json.JSONDecodeError: + logger.error(f"无法解析LLM响应为JSON: {response}") + return None + + except Exception as e: + logger.error(f"调用LLM时发生错误: {e}") + return None + +async def generate_poster(img_dir, output_dir, scenic_info_file, product_info_file, tweet_info_file, model, api_key, api_url): + """ + 根据提供的参数生成海报。 + """ + # --- 1. 读取所有源文件 --- + scenic_info = read_data_file(scenic_info_file) + product_info = read_data_file(product_info_file) + tweet_info = read_data_file(tweet_info_file) + + if not all([scenic_info, product_info, tweet_info]): + logger.error("一个或多个源文件读取失败,终止执行。") + return + + # --- 2. 初始化AI代理 --- + ai_config = AIModelConfig( + model=model, + api_key=api_key, + api_url=api_url + ) + ai_agent = AIAgent(ai_config) + + # --- 3. 调用LLM生成海报内容 --- + vibrant_content = await generate_info_with_llm(ai_agent, scenic_info, product_info, tweet_info) + if not vibrant_content: + logger.error("未能从LLM获取有效内容,终止执行。") + return + + # --- 4. 随机选择图片 --- + try: + image_files = [f for f in os.listdir(img_dir) if f.lower().endswith(('png', 'jpg', 'jpeg', 'webp'))] + if not image_files: + logger.error(f"在目录 {img_dir} 中未找到任何图片文件。") + return + random_image_name = random.choice(image_files) + image_path = os.path.join(img_dir, random_image_name) + logger.info(f"随机选择图片: {image_path}") + except FileNotFoundError: + logger.error(f"图片目录不存在: {img_dir}") + return + + # --- 5. 调用模板生成海报 --- + try: + logger.info("正在初始化 VibrantTemplate...") + template = VibrantTemplate() + poster = template.generate(image_path=image_path, content=vibrant_content) + + if not poster: + logger.error("海报生成失败,模板返回了 None。") + return + + except Exception as e: + logger.error(f"生成海报时发生未知错误: {e}", exc_info=True) + return + + # --- 6. 保存图片 --- + try: + os.makedirs(output_dir, exist_ok=True) + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + scenic_name = scenic_info.get('主体名称', 'poster') + output_filename = f"vibrant_{scenic_name}_{timestamp}.png" + output_path = os.path.join(output_dir, output_filename) + + poster.save(output_path, 'PNG') + logger.info(f"海报已成功生成并保存至: {output_path}") + + except Exception as e: + logger.error(f"保存海报失败: {e}", exc_info=True) + +def run_async(func, *args): # 同步包装器 + return asyncio.run(func(*args)) + +async def main(): + parser = argparse.ArgumentParser( + description="通过LLM从源文件生成信息,并使用Vibrant模板创建海报。", + formatter_class=argparse.RawTextHelpFormatter + ) + + # 设置默认值 + img_dir = "/root/autodl-tmp/TCC_RESTRUCT/resource/data/images/天津冒险湾" + output_dir = "result/天津冒险湾" + scenic_info_file = "resource/data/Object/天津冒险湾.txt" + product_info_file = "resource/data/Product/天津冒险湾-2大2小套票.txt" + tweet_info_file = "/root/autodl-tmp/TCC_RESTRUCT/result/run_20250709_160942/topic_1/article_judged.json" + + model = "qwen-plus" + api_key = "sk-bd5ee62703bc41fc9b8a55d748dc1eb8" + api_url = "https://dashscope.aliyuncs.com/compatible-mode/v1" + + # 确保输出目录存在 + os.makedirs(output_dir, exist_ok=True) + tasks = [ + (img_dir, output_dir, scenic_info_file, product_info_file, tweet_info_file, model, api_key, api_url), + (img_dir, output_dir, scenic_info_file, product_info_file, tweet_info_file, model, api_key, api_url), + (img_dir, output_dir, scenic_info_file, product_info_file, tweet_info_file, model, api_key, api_url), + (img_dir, output_dir, scenic_info_file, product_info_file, tweet_info_file, model, api_key, api_url), + (img_dir, output_dir, scenic_info_file, product_info_file, tweet_info_file, model, api_key, api_url), + ] + with ThreadPool(5) as pool: + # 直接传递任务参数,不需要额外包装 + results = pool.starmap( + run_async, + [(generate_poster, *task_args) for task_args in tasks] + ) + pool.close() + pool.join() + +print("所有海报生成完成。") + + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file