#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Source: TravelContentCreator/scripts/generate_vibrant_poster.py
"""
使用LLM和Vibrant模板从源文件自动生成海报的脚本
"""
import os
import sys
import json
import json_repair
import random
import asyncio
import argparse
import logging
import re
from datetime import datetime
# 将项目根目录添加到Python路径中
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from poster.templates.vibrant_template import VibrantTemplate
from core.ai.ai_agent import AIAgent
from core.config import AIModelConfig
from multiprocessing.pool import ThreadPool
# Configure logging: INFO level, each record formatted as
# "<timestamp> - <level name> - <message>".
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Module-level logger used by every function in this script.
logger = logging.getLogger(__name__)
def _parse_loose_json_object(content: str) -> dict:
    """Best-effort extraction of top-level ``"key": value`` pairs.

    Only text wrapped in ``{ ... }`` is considered. Quoted values become
    strings, ``true``/``false`` become booleans, plain digit sequences become
    ``int``/``float``; anything else is kept as the raw (stripped) text.
    Nested objects are not understood and are kept as raw text at best.

    Returns:
        The extracted dict, or an empty dict when nothing could be extracted.
    """
    stripped = content.strip()
    if not (stripped.startswith('{') and stripped.endswith('}')):
        return {}

    # Remove the outer braces first; otherwise the last value runs up to the
    # end of the text and swallows the closing '}' (e.g. ``1}``).
    inner = stripped[1:-1]
    pattern = r'"([^"]+)":\s*(.+?)(?=,\s*"[^"]+":|$)'

    result = {}
    for key, value in re.findall(pattern, inner, re.DOTALL):
        # A trailing comma after the final pair is not part of the value.
        value = value.strip().rstrip(',').rstrip()
        if len(value) >= 2 and value.startswith('"') and value.endswith('"'):
            result[key] = value[1:-1]
        elif value.lower() in ('true', 'false'):
            result[key] = value.lower() == 'true'
        elif value.replace('.', '', 1).isdigit():
            result[key] = float(value) if '.' in value else int(value)
        else:
            result[key] = value
    return result


def _parse_key_value_text(content: str) -> dict:
    """Parse ``key: value`` text where indented lines continue the previous value.

    A non-indented line containing ``:`` starts a new key; every other
    non-blank line is appended to the value of the current key. Keys that end
    up without any value are dropped.

    Returns:
        The parsed dict, possibly empty.
    """
    result = {}
    current_key = None
    current_value = []

    for raw_line in content.split('\n'):
        line = raw_line.strip()
        if not line:
            continue

        # Indentation has to be tested on the raw line: after strip() the
        # leading whitespace is gone and the check could never fail.
        is_indented = raw_line.startswith((' ', '\t'))
        # An indented line can still open the first key when no key exists yet.
        if ':' in line and (not is_indented or not current_key):
            if current_key and current_value:
                result[current_key] = '\n'.join(current_value)
            key_part, _, value_part = line.partition(':')
            current_key = key_part.strip().strip('"').strip("'")
            value_part = value_part.strip()
            current_value = [value_part] if value_part else []
        elif current_key:
            current_value.append(line)

    if current_key and current_value:
        result[current_key] = '\n'.join(current_value)
    return result


def read_data_file(file_path: str):
    """Read a source data file and parse it into structured data.

    Parsing is attempted in this order:

    1. ``json_repair.loads`` (accepts valid and slightly broken JSON);
    2. a regex-based extraction of ``"key": value`` pairs from ``{...}`` text;
    3. a line-based ``key: value`` parser;
    4. as a last resort the raw text wrapped as ``{"content": <text>}``.

    Args:
        file_path: Path of the UTF-8 text file to read.

    Returns:
        The parsed dict or list, or ``None`` when the file is missing, empty,
        or could not be read. This function logs errors instead of raising.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            content = f.read()

        if not content.strip():
            logger.warning(f"文件内容为空: {file_path}")
            return None

        # json_repair is lenient and normally does not raise on bad input; it
        # hands back a non-container value (typically "") instead. So the
        # result type, not an exception, decides whether to use the fallbacks.
        try:
            parsed = json_repair.loads(content)
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            parsed = None
        if isinstance(parsed, (dict, list)):
            return parsed

        logger.info(f"文件 {file_path} 不是标准JSON格式尝试其他解析方式...")
        try:
            result = _parse_loose_json_object(content)
        except Exception as e:
            logger.error(f"尝试解析类JSON格式失败: {e}")
            result = {}
        if result:
            logger.info(f"成功解析文件 {file_path} 为字典")
            return result

        logger.info(f"尝试将文件 {file_path} 作为纯文本处理")
        result = _parse_key_value_text(content)
        if result:
            logger.info(f"成功将文件 {file_path} 解析为键值对")
            return result

        logger.warning(f"无法解析文件 {file_path} 为结构化数据,返回原始文本")
        return {"content": content}
    except FileNotFoundError:
        logger.error(f"文件未找到: {file_path}")
    except Exception as e:
        logger.error(f"读取文件时发生未知错误 {file_path}: {e}")
    return None
def _stringify_numbers(value):
    """Return *value* with numeric scalars (also inside a list) converted to ``str``."""
    if isinstance(value, (int, float)):
        return str(value)
    if isinstance(value, list):
        return [str(item) if isinstance(item, (int, float)) else item for item in value]
    return value


async def generate_info_with_llm(ai_agent: AIAgent, scenic_info: dict, product_info: dict, tweet_info: dict):
    """Ask the LLM for the poster copy consumed by the Vibrant template.

    Args:
        ai_agent: Agent whose ``generate_text`` coroutine is awaited; the first
            element of its 4-tuple result is used as the response text.
        scenic_info: Scenic-spot information, embedded in the prompt as JSON.
        product_info: Product information, embedded in the prompt as JSON.
        tweet_info: Tweet/article information, embedded in the prompt as JSON.

    Returns:
        A dict with the generated fields, where ``content_button`` and
        ``pagination`` are always present and numeric values (also inside
        lists) are converted to strings; ``None`` when the call fails or the
        response does not contain a JSON object.
    """
    logger.info("正在调用LLM以生成海报内容...")

    # Triple single quotes are used on purpose: the prompt text itself contains
    # a run of three double quotes, which would terminate a """-literal early.
    system_prompt = '''你是一名专业的海报设计师,专门设计宣传海报。你现在要根据用户提供的图片描述和推文内容,生成适合模板的海报信息。
模板特点
- 单图背景毛玻璃渐变效果
- 两栏布局左栏内容右栏价格
- 适合展示套餐内容和价格信息
你需要生成的数据结构包含以下字段
**必填字段**
1. `title`: 主标题8-12字符体现产品特色
2. `slogan`: 副标题/宣传语10-20字符吸引人的描述
3. `price`: 价格数字纯数字不含符号
4. `ticket_type`: 票种类型"成人票""套餐票""夜场票"
5. `content_button`: 内容按钮文字通常为"套餐内容""包含项目"
6. `content_items`: 套餐内容列表3-5个项目每项5-15字符不要只包含项目名称要做合适的美化可以适当省略
**可选字段**
7. `remarks`: 备注信息1-3每条10-20字符
8. `tag`: 标签1, "#限时优惠"""
9. `pagination`: 分页信息"1/3"可为空
**内容创作要求**
1. 套餐内容要具体实用明确说明包含的服务时间数量
2. 价格要有吸引力突出性价比和优惠信息
**输出格式**
请严格按照以下JSON格式输出不要有任何额外内容
```json
{
"title": "海洋奇幻世界",
"slogan": "探索深海秘境,感受蓝色奇迹的无限魅力",
"price": "299",
"ticket_type": "成人票",
"content_button": "套餐内容",
"content_items": [
"海洋馆门票1张含所有展区",
"海豚表演VIP座位",
"鲨鱼隧道特别体验",
"专业摄影服务"
],
"remarks": [
"工作日可直接入园",
"周末需提前预约"
],
"tag": "#海洋特惠",
"pagination": "1/2"
}
```
'''

    user_prompt = f"""请根据以下信息,生成适合在旅游海报上展示的文案:
## 景区信息
{json.dumps(scenic_info, ensure_ascii=False, indent=2)}
## 产品信息
{json.dumps(product_info, ensure_ascii=False, indent=2)}
## 推文信息
{json.dumps(tweet_info, ensure_ascii=False, indent=2)}
请提取关键信息并整合成一个JSON对象包含titlesloganpriceticket_typecontent_itemsremarks和tag字段
"""

    try:
        response, _, _, _ = await ai_agent.generate_text(
            system_prompt=system_prompt,
            user_prompt=user_prompt,
            temperature=0.7,
            stage="海报文案生成"
        )
    except Exception as e:
        logger.error(f"调用LLM时发生错误: {e}")
        return None

    # Locate the outermost {...} span; models often wrap JSON in prose/fences.
    json_start = response.find('{') if isinstance(response, str) else -1
    json_end = response.rfind('}') + 1 if isinstance(response, str) else 0
    if json_start < 0 or json_end <= json_start:
        logger.error(f"无法在响应中找到JSON对象: {response}")
        return None

    try:
        content_dict = json_repair.loads(response[json_start:json_end])
    except Exception:
        content_dict = None
    # json_repair tends to return a non-dict value rather than raise, so the
    # result type must be checked before it is treated as a mapping.
    if not isinstance(content_dict, dict):
        logger.error(f"无法解析LLM响应为JSON: {response}")
        return None
    logger.info(f"LLM成功生成内容: {content_dict}")

    # Fill in the button text and pagination; `or` also replaces JSON null.
    content_dict["content_button"] = content_dict.get("content_button") or "套餐内容"
    content_dict["pagination"] = content_dict.get("pagination") or ""

    # The LLM may emit e.g. the price as a number; convert numbers to strings.
    content_dict = {key: _stringify_numbers(value) for key, value in content_dict.items()}
    logger.info(f"转换类型后的内容: {content_dict}")
    return content_dict
async def generate_poster(img_dir, output_dir, scenic_info_file, product_info_file, tweet_info_file, model, api_key, api_url):
    """Generate one poster from the source files and save it as a PNG.

    Steps: read the three source files, ask the LLM for the poster copy, pick
    a random image from ``img_dir``, render it with ``VibrantTemplate`` and
    save the result under ``output_dir``. Every failure is logged and ends the
    function early; nothing is raised to the caller.

    Args:
        img_dir: Directory containing candidate background images.
        output_dir: Directory the poster is written to (created if missing).
        scenic_info_file: Path of the scenic-spot information file.
        product_info_file: Path of the product information file.
        tweet_info_file: Path of the tweet/article information file.
        model: Model name passed to ``AIModelConfig``.
        api_key: API key passed to ``AIModelConfig``.
        api_url: API base URL passed to ``AIModelConfig``.

    Returns:
        None.
    """
    # --- 1. Read all source files ---
    scenic_info = read_data_file(scenic_info_file)
    product_info = read_data_file(product_info_file)
    tweet_info = read_data_file(tweet_info_file)
    if not all([scenic_info, product_info, tweet_info]):
        logger.error("一个或多个源文件读取失败,终止执行。")
        return

    # --- 2. Set up the AI agent ---
    ai_config = AIModelConfig(
        model=model,
        api_key=api_key,
        api_url=api_url
    )
    ai_agent = AIAgent(ai_config)

    # --- 3. Ask the LLM for the poster content ---
    vibrant_content = await generate_info_with_llm(ai_agent, scenic_info, product_info, tweet_info)
    if not vibrant_content:
        logger.error("未能从LLM获取有效内容终止执行。")
        return

    # --- 4. Pick a random image ---
    # Extensions include the dot so that names merely ending in "png" etc.
    # (without being that file type) are not selected.
    image_extensions = ('.png', '.jpg', '.jpeg', '.webp')
    try:
        image_files = [f for f in os.listdir(img_dir) if f.lower().endswith(image_extensions)]
    except (FileNotFoundError, NotADirectoryError):
        logger.error(f"图片目录不存在: {img_dir}")
        return
    if not image_files:
        logger.error(f"在目录 {img_dir} 中未找到任何图片文件。")
        return
    image_path = os.path.join(img_dir, random.choice(image_files))
    logger.info(f"随机选择图片: {image_path}")

    # --- 5. Render the poster with the template ---
    try:
        logger.info("正在初始化 VibrantTemplate...")
        template = VibrantTemplate()
        poster = template.generate(image_path=image_path, content=vibrant_content)
    except Exception as e:
        logger.error(f"生成海报时发生未知错误: {e}", exc_info=True)
        return
    if not poster:
        logger.error("海报生成失败,模板返回了 None。")
        return

    # --- 6. Save the image ---
    try:
        os.makedirs(output_dir, exist_ok=True)
        # Microseconds keep concurrently generated posters from overwriting
        # each other when several finish within the same second.
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
        # The parsed file may be a list; only a dict can supply the name.
        scenic_name = scenic_info.get('主体名称', 'poster') if isinstance(scenic_info, dict) else 'poster'
        # Replace whitespace and characters that are unsafe in file names.
        safe_name = re.sub(r'[\\/:*?"<>|\s]+', '_', str(scenic_name)).strip('_') or 'poster'
        output_path = os.path.join(output_dir, f"vibrant_{safe_name}_{timestamp}.png")
        poster.save(output_path, 'PNG')
        logger.info(f"海报已成功生成并保存至: {output_path}")
    except Exception as e:
        logger.error(f"保存海报失败: {e}", exc_info=True)
def run_async(func, *args):
    """Synchronous wrapper: run coroutine function ``func(*args)`` to completion.

    A fresh event loop is created via ``asyncio.run`` for the call, and the
    coroutine's result is returned.
    """
    coroutine = func(*args)
    return asyncio.run(coroutine)
async def main():
    """Parse command-line options and generate several posters in parallel.

    Every option defaults to the value that used to be hard-coded, so running
    the script without arguments behaves as before: five posters are produced
    for the same inputs using a pool of five worker threads.
    """
    parser = argparse.ArgumentParser(
        description="通过LLM从源文件生成信息并使用Vibrant模板创建海报。",
        formatter_class=argparse.RawTextHelpFormatter
    )
    parser.add_argument("--img-dir", default="/root/autodl-tmp/TCC_RESTRUCT/resource/data/images/天津冒险湾",
                        help="背景图片所在目录")
    parser.add_argument("--output-dir", default="result/天津冒险湾",
                        help="海报输出目录")
    parser.add_argument("--scenic-info-file", default="resource/data/Object/天津冒险湾.txt",
                        help="景区信息文件路径")
    parser.add_argument("--product-info-file", default="resource/data/Product/天津冒险湾-2大2小套票.txt",
                        help="产品信息文件路径")
    parser.add_argument("--tweet-info-file",
                        default="/root/autodl-tmp/TCC_RESTRUCT/result/run_20250709_160942/topic_1/article_judged.json",
                        help="推文信息文件路径")
    parser.add_argument("--model", default="qwen-plus", help="模型名称")
    # SECURITY(review): a real-looking API key is committed in this file. It is
    # kept only as a fallback so existing runs keep working; rotate the key and
    # supply it through DASHSCOPE_API_KEY or --api-key, then delete the literal.
    parser.add_argument("--api-key",
                        default=os.environ.get("DASHSCOPE_API_KEY", "sk-bd5ee62703bc41fc9b8a55d748dc1eb8"),
                        help="API密钥(默认读取环境变量 DASHSCOPE_API_KEY)")
    parser.add_argument("--api-url", default="https://dashscope.aliyuncs.com/compatible-mode/v1",
                        help="API地址")
    parser.add_argument("--count", type=int, default=5, help="要生成的海报数量")
    parser.add_argument("--workers", type=int, default=5, help="并发线程数")
    args = parser.parse_args()

    # Make sure the output directory exists before any worker starts.
    os.makedirs(args.output_dir, exist_ok=True)

    task_args = (
        args.img_dir, args.output_dir, args.scenic_info_file, args.product_info_file,
        args.tweet_info_file, args.model, args.api_key, args.api_url,
    )
    # Every job uses the same inputs; a non-positive count yields no jobs.
    tasks = [task_args] * args.count

    # Each worker thread runs its own event loop through run_async. starmap
    # blocks until every job has finished, so no explicit close()/join() is
    # needed before the pool is torn down by the with-block.
    with ThreadPool(max(1, args.workers)) as pool:
        pool.starmap(
            run_async,
            [(generate_poster, *job) for job in tasks]
        )
    print("所有海报生成完成。")


if __name__ == "__main__":
    # Script entry point.
    asyncio.run(main())