#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ 使用LLM和Vibrant模板从源文件自动生成海报的脚本。 """ import os import sys import json import json_repair import random import asyncio import argparse import logging import re from datetime import datetime # 将项目根目录添加到Python路径中 sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from poster.templates.vibrant_template import VibrantTemplate from core.ai.ai_agent import AIAgent from core.config import AIModelConfig from multiprocessing.pool import ThreadPool # 配置日志记录 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') logger = logging.getLogger(__name__) def read_data_file(file_path: str): """安全地读取并解析各种格式的文件。""" try: with open(file_path, 'r', encoding='utf-8') as f: content = f.read() # 尝试解析为标准JSON try: return json_repair.loads(content) except json.JSONDecodeError: # 如果不是标准JSON,尝试解析为类JSON格式 logger.info(f"文件 {file_path} 不是标准JSON格式,尝试其他解析方式...") # 检查是否是类JSON格式(带有大括号的键值对) if content.strip().startswith('{') and content.strip().endswith('}'): try: # 尝试将内容解析为字典 result = {} # 使用正则表达式提取键值对 pattern = r'"([^"]+)":\s*(.+?)(?=,\s*"[^"]+":|$)' matches = re.findall(pattern, content, re.DOTALL) for key, value in matches: # 清理值中的引号和多余空格 value = value.strip() if value.startswith('"') and value.endswith('"'): # 字符串值 result[key] = value[1:-1] elif value.lower() in ['true', 'false']: # 布尔值 result[key] = value.lower() == 'true' elif value.replace('.', '', 1).isdigit(): # 数字值 result[key] = float(value) if '.' in value else int(value) else: # 其他值作为字符串处理 result[key] = value if result: logger.info(f"成功解析文件 {file_path} 为字典") return result except Exception as e: logger.error(f"尝试解析类JSON格式失败: {e}") # 如果上述方法失败,尝试直接提取文本内容 logger.info(f"尝试将文件 {file_path} 作为纯文本处理") # 简单解析键值对格式的文本文件 result = {} lines = content.split('\n') current_key = None current_value = [] for line in lines: line = line.strip() if not line: continue # 检查是否是新的键值对 if ':' in line and not line.startswith(' ') and not line.startswith('\t'): # 保存之前的键值对 if current_key and current_value: result[current_key] = '\n'.join(current_value) current_value = [] # 提取新的键值对 parts = line.split(':', 1) current_key = parts[0].strip().strip('"').strip("'") if len(parts) > 1 and parts[1].strip(): current_value = [parts[1].strip()] else: current_value = [] elif current_key: # 继续添加到当前值 current_value.append(line) # 保存最后一个键值对 if current_key and current_value: result[current_key] = '\n'.join(current_value) if result: logger.info(f"成功将文件 {file_path} 解析为键值对") return result # 如果所有方法都失败,返回文本内容 logger.warning(f"无法解析文件 {file_path} 为结构化数据,返回原始文本") return {"content": content} except FileNotFoundError: logger.error(f"文件未找到: {file_path}") except Exception as e: logger.error(f"读取文件时发生未知错误 {file_path}: {e}") return None async def generate_info_with_llm(ai_agent: AIAgent, scenic_info: dict, product_info: dict, tweet_info: dict): """ 调用LLM,根据输入信息生成海报所需的vibrant_content。 Args: ai_agent: AIAgent实例 scenic_info: 景区信息字典 product_info: 产品信息字典 tweet_info: 推文信息字典 Returns: 一个符合VibrantTemplate格式的内容字典 """ logger.info("正在调用LLM以生成海报内容...") # 构建系统提示 system_prompt = """你是一名专业的海报设计师,专门设计宣传海报。你现在要根据用户提供的图片描述和推文内容,生成适合模板的海报信息。 模板特点: - 单图背景,毛玻璃渐变效果 - 两栏布局(左栏内容,右栏价格) - 适合展示套餐内容和价格信息 你需要生成的数据结构包含以下字段: **必填字段:** 1. `title`: 主标题(8-12字符,体现产品特色) 2. `slogan`: 副标题/宣传语(10-20字符,吸引人的描述) 3. `price`: 价格数字(纯数字,不含符号) 4. `ticket_type`: 票种类型(如"成人票"、"套餐票"、"夜场票"等) 5. `content_button`: 内容按钮文字(通常为"套餐内容"、"包含项目"等) 6. `content_items`: 套餐内容列表(3-5个项目,每项5-15字符,不要只包含项目名称,要做合适的美化,可以适当省略) **可选字段:** 7. `remarks`: 备注信息(1-3条,每条10-20字符) 8. `tag`: 标签(1条, 如"#限时优惠"、""等) 9. `pagination`: 分页信息(如"1/3",可为空) **内容创作要求:** 1. 套餐内容要具体实用:明确说明包含的服务、时间、数量 2. 价格要有吸引力:突出性价比和优惠信息 **输出格式:** 请严格按照以下JSON格式输出,不要有任何额外内容: ```json { "title": "海洋奇幻世界", "slogan": "探索深海秘境,感受蓝色奇迹的无限魅力", "price": "299", "ticket_type": "成人票", "content_button": "套餐内容", "content_items": [ "海洋馆门票1张(含所有展区)", "海豚表演VIP座位", "鲨鱼隧道特别体验", "专业摄影服务" ], "remarks": [ "工作日可直接入园", "周末需提前预约" ], "tag": "#海洋特惠", "pagination": "1/2" } ``` """ # 构建用户提示 user_prompt = f"""请根据以下信息,生成适合在旅游海报上展示的文案: ## 景区信息 {json.dumps(scenic_info, ensure_ascii=False, indent=2)} ## 产品信息 {json.dumps(product_info, ensure_ascii=False, indent=2)} ## 推文信息 {json.dumps(tweet_info, ensure_ascii=False, indent=2)} 请提取关键信息并整合成一个JSON对象,包含title、slogan、price、ticket_type、content_items、remarks和tag字段。 """ try: # 调用AIAgent生成文本 response, _, _, _ = await ai_agent.generate_text( system_prompt=system_prompt, user_prompt=user_prompt, temperature=0.7, stage="海报文案生成" ) # 解析JSON响应 try: # 尝试找到JSON对象的开始和结束位置 json_start = response.find('{') json_end = response.rfind('}') + 1 if json_start >= 0 and json_end > json_start: json_str = response[json_start:json_end] content_dict = json_repair.loads(json_str) logger.info(f"LLM成功生成内容: {content_dict}") # 添加默认的按钮文本和分页信息 content_dict["content_button"] = content_dict.get("content_button", "套餐内容") content_dict["pagination"] = content_dict.get("pagination", "") # 确保所有值都是字符串类型 for key, value in content_dict.items(): if isinstance(value, (int, float)): content_dict[key] = str(value) elif isinstance(value, list): content_dict[key] = [str(item) if isinstance(item, (int, float)) else item for item in value] logger.info(f"转换类型后的内容: {content_dict}") return content_dict else: logger.error(f"无法在响应中找到JSON对象: {response}") return None except json.JSONDecodeError: logger.error(f"无法解析LLM响应为JSON: {response}") return None except Exception as e: logger.error(f"调用LLM时发生错误: {e}") return None async def generate_poster(img_dir, output_dir, scenic_info_file, product_info_file, tweet_info_file, model, api_key, api_url): """ 根据提供的参数生成海报。 """ # --- 1. 读取所有源文件 --- scenic_info = read_data_file(scenic_info_file) product_info = read_data_file(product_info_file) tweet_info = read_data_file(tweet_info_file) if not all([scenic_info, product_info, tweet_info]): logger.error("一个或多个源文件读取失败,终止执行。") return # --- 2. 初始化AI代理 --- ai_config = AIModelConfig( model=model, api_key=api_key, api_url=api_url ) ai_agent = AIAgent(ai_config) # --- 3. 调用LLM生成海报内容 --- vibrant_content = await generate_info_with_llm(ai_agent, scenic_info, product_info, tweet_info) if not vibrant_content: logger.error("未能从LLM获取有效内容,终止执行。") return # --- 4. 随机选择图片 --- try: image_files = [f for f in os.listdir(img_dir) if f.lower().endswith(('png', 'jpg', 'jpeg', 'webp'))] if not image_files: logger.error(f"在目录 {img_dir} 中未找到任何图片文件。") return random_image_name = random.choice(image_files) image_path = os.path.join(img_dir, random_image_name) logger.info(f"随机选择图片: {image_path}") except FileNotFoundError: logger.error(f"图片目录不存在: {img_dir}") return # --- 5. 调用模板生成海报 --- try: logger.info("正在初始化 VibrantTemplate...") template = VibrantTemplate() poster = template.generate(image_path=image_path, content=vibrant_content) if not poster: logger.error("海报生成失败,模板返回了 None。") return except Exception as e: logger.error(f"生成海报时发生未知错误: {e}", exc_info=True) return # --- 6. 保存图片 --- try: os.makedirs(output_dir, exist_ok=True) timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") scenic_name = scenic_info.get('主体名称', 'poster') output_filename = f"vibrant_{scenic_name}_{timestamp}.png" output_path = os.path.join(output_dir, output_filename) poster.save(output_path, 'PNG') logger.info(f"海报已成功生成并保存至: {output_path}") except Exception as e: logger.error(f"保存海报失败: {e}", exc_info=True) def run_async(func, *args): # 同步包装器 return asyncio.run(func(*args)) async def main(): parser = argparse.ArgumentParser( description="通过LLM从源文件生成信息,并使用Vibrant模板创建海报。", formatter_class=argparse.RawTextHelpFormatter ) # 设置默认值 img_dir = "/root/autodl-tmp/TCC_RESTRUCT/resource/data/images/天津冒险湾" output_dir = "result/天津冒险湾" scenic_info_file = "resource/data/Object/天津冒险湾.txt" product_info_file = "resource/data/Product/天津冒险湾-2大2小套票.txt" tweet_info_file = "/root/autodl-tmp/TCC_RESTRUCT/result/run_20250709_160942/topic_1/article_judged.json" model = "qwen-plus" api_key = "sk-bd5ee62703bc41fc9b8a55d748dc1eb8" api_url = "https://dashscope.aliyuncs.com/compatible-mode/v1" # 确保输出目录存在 os.makedirs(output_dir, exist_ok=True) tasks = [ (img_dir, output_dir, scenic_info_file, product_info_file, tweet_info_file, model, api_key, api_url), (img_dir, output_dir, scenic_info_file, product_info_file, tweet_info_file, model, api_key, api_url), (img_dir, output_dir, scenic_info_file, product_info_file, tweet_info_file, model, api_key, api_url), (img_dir, output_dir, scenic_info_file, product_info_file, tweet_info_file, model, api_key, api_url), (img_dir, output_dir, scenic_info_file, product_info_file, tweet_info_file, model, api_key, api_url), ] with ThreadPool(5) as pool: # 直接传递任务参数,不需要额外包装 results = pool.starmap( run_async, [(generate_poster, *task_args) for task_args in tasks] ) pool.close() pool.join() print("所有海报生成完成。") if __name__ == "__main__": asyncio.run(main())