TravelContentCreator/scripts/generate_vibrant_poster.py

360 lines
14 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
使用LLM和Vibrant模板从源文件自动生成海报的脚本。
"""
import os
import sys
import json
import json_repair
import random
import asyncio
import argparse
import logging
import re
from datetime import datetime
# 将项目根目录添加到Python路径中
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from poster.templates.vibrant_template import VibrantTemplate
from core.ai.ai_agent import AIAgent
from core.config import AIModelConfig
from multiprocessing.pool import ThreadPool
# Configure root logging for the whole script: INFO level, timestamped lines.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)
# Matches `"key": value` pairs inside the braces of JSON-like text; a value runs
# until the next `, "key":` or the end of the text.
_JSON_LIKE_PAIR = re.compile(r'"([^"]+)":\s*(.+?)(?=,\s*"[^"]+":|$)', re.DOTALL)

# Key/value separator for plain-text files: ASCII colon or the full-width colon
# commonly used in Chinese text.
_KV_SEPARATOR = re.compile(r'[:：]')


def _coerce_scalar(raw: str):
    """Convert one loosely-written scalar token into a str, bool, int or float."""
    value = raw.strip()
    if len(value) >= 2 and value.startswith('"') and value.endswith('"'):
        return value[1:-1]
    if value.lower() in ('true', 'false'):
        return value.lower() == 'true'
    if value.replace('.', '', 1).isdigit():
        try:
            return float(value) if '.' in value else int(value)
        except ValueError:
            # str.isdigit() accepts Unicode digits such as '²' that int()/float() reject.
            return value
    return value


def _parse_json_like(text: str) -> dict:
    """Extract `"key": value` pairs from brace-wrapped text that is not valid JSON."""
    # Drop the outer braces first so the last value does not swallow the closing '}'.
    inner = text.strip()[1:-1]
    return {key: _coerce_scalar(value) for key, value in _JSON_LIKE_PAIR.findall(inner)}


def _parse_key_value_text(text: str) -> dict:
    """Parse `key: value` lines; indented lines continue the previous key's value.

    A key whose value never receives any non-empty text is omitted.
    """
    result = {}
    current_key = None
    current_value = []
    for raw_line in text.splitlines():
        line = raw_line.strip()
        if not line:
            continue
        # Indentation must be checked on the raw line: after strip() it is gone.
        indented = raw_line.startswith((' ', '\t'))
        parts = _KV_SEPARATOR.split(line, maxsplit=1)
        if len(parts) == 2 and not indented:
            if current_key and current_value:
                result[current_key] = '\n'.join(current_value)
            current_key = parts[0].strip().strip('"').strip("'")
            first_value = parts[1].strip()
            current_value = [first_value] if first_value else []
        elif current_key:
            current_value.append(line)
    if current_key and current_value:
        result[current_key] = '\n'.join(current_value)
    return result


def read_data_file(file_path: str):
    """Read *file_path* (UTF-8) and parse it, trying progressively looser formats.

    Order of attempts:
      1. strict JSON via ``json.loads``;
      2. ``json_repair`` when the text starts with ``{`` or ``[`` (accepted only
         if it yields a non-empty dict or list);
      3. a regex extractor for brace-wrapped ``"key": value`` text;
      4. a line-based ``key: value`` parser (``:`` or ``：`` separators);
      5. fallback ``{"content": <raw text>}``.

    Returns:
        The parsed object, or None when the file is missing, unreadable or empty.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            content = f.read()
    except FileNotFoundError:
        logger.error(f"文件未找到: {file_path}")
        return None
    except Exception as e:
        logger.error(f"读取文件时发生未知错误 {file_path}: {e}")
        return None

    # Strict JSON first: json_repair almost never raises, so it cannot be used
    # to detect "this is not JSON" (for plain text it returns an empty string).
    try:
        return json.loads(content)
    except json.JSONDecodeError:
        pass

    stripped = content.strip()
    if not stripped:
        logger.warning(f"文件 {file_path} 为空")
        return None

    logger.info(f"文件 {file_path} 不是标准JSON格式尝试其他解析方式...")

    if stripped.startswith(('{', '[')):
        try:
            repaired = json_repair.loads(stripped)
        except Exception as e:
            logger.warning(f"json_repair 修复文件 {file_path} 失败: {e}")
            repaired = None
        if isinstance(repaired, (dict, list)) and repaired:
            return repaired

    if stripped.startswith('{') and stripped.endswith('}'):
        try:
            result = _parse_json_like(stripped)
            if result:
                logger.info(f"成功解析文件 {file_path} 为字典")
                return result
        except Exception as e:
            logger.error(f"尝试解析类JSON格式失败: {e}")

    logger.info(f"尝试将文件 {file_path} 作为纯文本处理")
    result = _parse_key_value_text(content)
    if result:
        logger.info(f"成功将文件 {file_path} 解析为键值对")
        return result

    logger.warning(f"无法解析文件 {file_path} 为结构化数据,返回原始文本")
    return {"content": content}
# System prompt for the Vibrant template copywriting request.
# NOTE: the `tag` example previously read `"#限时优惠"""等)`; the embedded `"""`
# terminated this triple-quoted string early (SyntaxError).
_VIBRANT_SYSTEM_PROMPT = """你是一名专业的海报设计师,专门设计宣传海报。你现在要根据用户提供的图片描述和推文内容,生成适合模板的海报信息。
模板特点:
- 单图背景,毛玻璃渐变效果
- 两栏布局(左栏内容,右栏价格)
- 适合展示套餐内容和价格信息
你需要生成的数据结构包含以下字段:
**必填字段:**
1. `title`: 主标题8-12字符体现产品特色
2. `slogan`: 副标题/宣传语10-20字符吸引人的描述
3. `price`: 价格数字(纯数字,不含符号)
4. `ticket_type`: 票种类型(如"成人票""套餐票""夜场票"等)
5. `content_button`: 内容按钮文字(通常为"套餐内容""包含项目"等)
6. `content_items`: 套餐内容列表3-5个项目每项5-15字符不要只包含项目名称要做合适的美化可以适当省略
**可选字段:**
7. `remarks`: 备注信息1-3条每条10-20字符
8. `tag`: 标签1条, 如"#限时优惠"等)
9. `pagination`: 分页信息(如"1/3",可为空)
**内容创作要求:**
1. 套餐内容要具体实用:明确说明包含的服务、时间、数量
2. 价格要有吸引力:突出性价比和优惠信息
**输出格式:**
请严格按照以下JSON格式输出不要有任何额外内容
```json
{
"title": "海洋奇幻世界",
"slogan": "探索深海秘境,感受蓝色奇迹的无限魅力",
"price": "299",
"ticket_type": "成人票",
"content_button": "套餐内容",
"content_items": [
"海洋馆门票1张含所有展区",
"海豚表演VIP座位",
"鲨鱼隧道特别体验",
"专业摄影服务"
],
"remarks": [
"工作日可直接入园",
"周末需提前预约"
],
"tag": "#海洋特惠",
"pagination": "1/2"
}
```
"""


def _extract_json_object(response):
    """Return the dict parsed from the outermost `{...}` span of *response*, or None.

    None is returned when *response* is not a string, has no brace-delimited
    span, or the repaired JSON is not a non-empty dict.
    """
    if not isinstance(response, str):
        return None
    json_start = response.find('{')
    json_end = response.rfind('}') + 1
    if json_start < 0 or json_end <= json_start:
        return None
    parsed = json_repair.loads(response[json_start:json_end])
    if isinstance(parsed, dict) and parsed:
        return parsed
    return None


def _stringify_numbers(value):
    """Convert int/float values (and int/float items of lists) to strings."""
    if isinstance(value, (int, float)):
        return str(value)
    if isinstance(value, list):
        return [str(item) if isinstance(item, (int, float)) else item for item in value]
    return value


def _normalize_vibrant_content(content_dict: dict) -> dict:
    """Fill default template fields and make scalar values strings."""
    # Only missing/null fields get defaults; an explicit empty string is kept.
    if content_dict.get("content_button") is None:
        content_dict["content_button"] = "套餐内容"
    if content_dict.get("pagination") is None:
        content_dict["pagination"] = ""
    return {key: _stringify_numbers(value) for key, value in content_dict.items()}


async def generate_info_with_llm(ai_agent: AIAgent, scenic_info: dict, product_info: dict, tweet_info: dict):
    """
    Ask the LLM for the `vibrant_content` dict consumed by VibrantTemplate.

    Args:
        ai_agent: AIAgent instance whose `generate_text` coroutine is awaited.
        scenic_info: scenic-spot information (serialized to JSON in the prompt).
        product_info: product information (serialized to JSON in the prompt).
        tweet_info: article/tweet information (serialized to JSON in the prompt).

    Returns:
        The content dict (numbers converted to strings, `content_button` and
        `pagination` defaulted), or None when the LLM call fails or its reply
        contains no usable JSON object.
    """
    logger.info("正在调用LLM以生成海报内容...")
    user_prompt = f"""请根据以下信息,生成适合在旅游海报上展示的文案:
## 景区信息
{json.dumps(scenic_info, ensure_ascii=False, indent=2)}
## 产品信息
{json.dumps(product_info, ensure_ascii=False, indent=2)}
## 推文信息
{json.dumps(tweet_info, ensure_ascii=False, indent=2)}
请提取关键信息并整合成一个JSON对象包含title、slogan、price、ticket_type、content_items、remarks和tag字段。
"""
    try:
        # generate_text returns a 4-tuple; only the text is used here.
        response, _, _, _ = await ai_agent.generate_text(
            system_prompt=_VIBRANT_SYSTEM_PROMPT,
            user_prompt=user_prompt,
            temperature=0.7,
            stage="海报文案生成"
        )
    except Exception as e:
        logger.error(f"调用LLM时发生错误: {e}")
        return None

    try:
        content_dict = _extract_json_object(response)
    except Exception:
        logger.error(f"无法解析LLM响应为JSON: {response}")
        return None
    if content_dict is None:
        logger.error(f"无法在响应中找到JSON对象: {response}")
        return None

    logger.info(f"LLM成功生成内容: {content_dict}")
    content_dict = _normalize_vibrant_content(content_dict)
    logger.info(f"转换类型后的内容: {content_dict}")
    return content_dict
# Background image extensions accepted from img_dir (matched case-insensitively).
_IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.webp')


def _safe_filename_part(name, fallback='poster', max_length=50):
    """Turn an arbitrary value into a short token safe to embed in a file name."""
    cleaned = re.sub(r'[\\/:*?"<>|\s]+', '_', str(name)).strip('_')
    return cleaned[:max_length] or fallback


async def generate_poster(img_dir, output_dir, scenic_info_file, product_info_file, tweet_info_file, model, api_key, api_url):
    """
    Generate one Vibrant poster end-to-end and save it as a PNG in *output_dir*.

    Steps: read the three source files, ask the LLM for the poster copy, pick a
    random background image from *img_dir*, render it with VibrantTemplate, and
    save it. Failures in these steps are logged and end the run early; the
    coroutine always returns None.
    """
    # --- 1. Read all source files ---
    scenic_info = read_data_file(scenic_info_file)
    product_info = read_data_file(product_info_file)
    tweet_info = read_data_file(tweet_info_file)
    if not all([scenic_info, product_info, tweet_info]):
        logger.error("一个或多个源文件读取失败,终止执行。")
        return

    # --- 2. Initialise the AI agent ---
    ai_config = AIModelConfig(model=model, api_key=api_key, api_url=api_url)
    ai_agent = AIAgent(ai_config)

    # --- 3. Ask the LLM for the poster content ---
    vibrant_content = await generate_info_with_llm(ai_agent, scenic_info, product_info, tweet_info)
    if not vibrant_content:
        logger.error("未能从LLM获取有效内容终止执行。")
        return

    # --- 4. Pick a random background image ---
    try:
        image_files = [name for name in os.listdir(img_dir) if name.lower().endswith(_IMAGE_EXTENSIONS)]
    except FileNotFoundError:
        logger.error(f"图片目录不存在: {img_dir}")
        return
    except OSError as e:
        logger.error(f"无法读取图片目录 {img_dir}: {e}")
        return
    if not image_files:
        logger.error(f"在目录 {img_dir} 中未找到任何图片文件。")
        return
    image_path = os.path.join(img_dir, random.choice(image_files))
    logger.info(f"随机选择图片: {image_path}")

    # --- 5. Render the poster with the template ---
    try:
        logger.info("正在初始化 VibrantTemplate...")
        template = VibrantTemplate()
        poster = template.generate(image_path=image_path, content=vibrant_content)
    except Exception as e:
        logger.error(f"生成海报时发生未知错误: {e}", exc_info=True)
        return
    if not poster:
        logger.error("海报生成失败,模板返回了 None。")
        return

    # --- 6. Save the image ---
    try:
        os.makedirs(output_dir, exist_ok=True)
        # The source may be a list or raw text, so only dicts are queried.
        raw_name = scenic_info.get('主体名称', 'poster') if isinstance(scenic_info, dict) else 'poster'
        # Microseconds keep concurrent runs (main() launches several in
        # parallel) from writing to the same file within one second.
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
        output_filename = f"vibrant_{_safe_filename_part(raw_name)}_{timestamp}.png"
        output_path = os.path.join(output_dir, output_filename)
        poster.save(output_path, 'PNG')
        logger.info(f"海报已成功生成并保存至: {output_path}")
    except Exception as e:
        logger.error(f"保存海报失败: {e}", exc_info=True)
def run_async(func, *args):
    """Synchronously execute the coroutine function *func* with *args*.

    Uses ``asyncio.run``, which creates and closes a fresh event loop, so the
    coroutine can be driven from a worker thread that has no loop of its own.
    Returns whatever the coroutine returns.
    """
    pending = func(*args)
    outcome = asyncio.run(pending)
    return outcome
async def main():
    """Command-line entry point: render several Vibrant posters concurrently.

    All paths, the model and the endpoint default to the values this script
    previously hard-coded, so running it without arguments behaves as before.
    The API key is taken from ``--api-key`` or the ``DASHSCOPE_API_KEY``
    environment variable.
    NOTE(review): a key was previously committed in this function; treat it as
    leaked and rotate it.
    """
    parser = argparse.ArgumentParser(
        description="通过LLM从源文件生成信息并使用Vibrant模板创建海报。",
        formatter_class=argparse.RawTextHelpFormatter
    )
    parser.add_argument("--img-dir", default="/root/autodl-tmp/TCC_RESTRUCT/resource/data/images/天津冒险湾")
    parser.add_argument("--output-dir", default="result/天津冒险湾")
    parser.add_argument("--scenic-info-file", default="resource/data/Object/天津冒险湾.txt")
    parser.add_argument("--product-info-file", default="resource/data/Product/天津冒险湾-2大2小套票.txt")
    parser.add_argument(
        "--tweet-info-file",
        default="/root/autodl-tmp/TCC_RESTRUCT/result/run_20250709_160942/topic_1/article_judged.json",
    )
    parser.add_argument("--model", default="qwen-plus")
    parser.add_argument("--api-key", default=os.environ.get("DASHSCOPE_API_KEY"))
    parser.add_argument("--api-url", default="https://dashscope.aliyuncs.com/compatible-mode/v1")
    parser.add_argument("--count", type=int, default=5)
    parser.add_argument("--workers", type=int, default=5)
    args = parser.parse_args()

    if not args.api_key:
        logger.error("未提供API密钥: 请使用 --api-key 参数或设置环境变量 DASHSCOPE_API_KEY。")
        return
    if args.count < 1 or args.workers < 1:
        logger.error("--count 和 --workers 必须为正整数。")
        return

    # Make sure the output directory exists.
    os.makedirs(args.output_dir, exist_ok=True)

    task_args = (
        args.img_dir, args.output_dir, args.scenic_info_file, args.product_info_file,
        args.tweet_info_file, args.model, args.api_key, args.api_url,
    )
    # Each worker thread runs its own event loop through run_async/asyncio.run,
    # so the synchronous parts of generate_poster also overlap across posters.
    # starmap blocks until every task has finished.
    with ThreadPool(args.workers) as pool:
        pool.starmap(run_async, [(generate_poster, *task_args)] * args.count)
    print("所有海报生成完成。")
if __name__ == "__main__":
    # Executed only when run as a script: drive the async CLI on a new event loop.
    entry_coroutine = main()
    asyncio.run(entry_coroutine)