360 lines
14 KiB
Python
360 lines
14 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
|
||
"""
|
||
使用LLM和Vibrant模板从源文件自动生成海报的脚本。
|
||
"""
|
||
|
||
import os
|
||
import sys
|
||
import json
|
||
import json_repair
|
||
import random
|
||
import asyncio
|
||
import argparse
|
||
import logging
|
||
import re
|
||
from datetime import datetime
|
||
# 将项目根目录添加到Python路径中
|
||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||
|
||
from utils.poster.templates.vibrant_template import VibrantTemplate
|
||
from core.ai.ai_agent import AIAgent
|
||
from core.config import AIModelConfig
|
||
from multiprocessing.pool import ThreadPool
|
||
# 配置日志记录
|
||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||
logger = logging.getLogger(__name__)
|
||
|
||
def read_data_file(file_path: str):
|
||
"""安全地读取并解析各种格式的文件。"""
|
||
try:
|
||
with open(file_path, 'r', encoding='utf-8') as f:
|
||
content = f.read()
|
||
|
||
# 尝试解析为标准JSON
|
||
try:
|
||
return json_repair.loads(content)
|
||
except json.JSONDecodeError:
|
||
# 如果不是标准JSON,尝试解析为类JSON格式
|
||
logger.info(f"文件 {file_path} 不是标准JSON格式,尝试其他解析方式...")
|
||
|
||
# 检查是否是类JSON格式(带有大括号的键值对)
|
||
if content.strip().startswith('{') and content.strip().endswith('}'):
|
||
try:
|
||
# 尝试将内容解析为字典
|
||
result = {}
|
||
# 使用正则表达式提取键值对
|
||
pattern = r'"([^"]+)":\s*(.+?)(?=,\s*"[^"]+":|$)'
|
||
matches = re.findall(pattern, content, re.DOTALL)
|
||
|
||
for key, value in matches:
|
||
# 清理值中的引号和多余空格
|
||
value = value.strip()
|
||
if value.startswith('"') and value.endswith('"'):
|
||
# 字符串值
|
||
result[key] = value[1:-1]
|
||
elif value.lower() in ['true', 'false']:
|
||
# 布尔值
|
||
result[key] = value.lower() == 'true'
|
||
elif value.replace('.', '', 1).isdigit():
|
||
# 数字值
|
||
result[key] = float(value) if '.' in value else int(value)
|
||
else:
|
||
# 其他值作为字符串处理
|
||
result[key] = value
|
||
|
||
if result:
|
||
logger.info(f"成功解析文件 {file_path} 为字典")
|
||
return result
|
||
except Exception as e:
|
||
logger.error(f"尝试解析类JSON格式失败: {e}")
|
||
|
||
# 如果上述方法失败,尝试直接提取文本内容
|
||
logger.info(f"尝试将文件 {file_path} 作为纯文本处理")
|
||
|
||
# 简单解析键值对格式的文本文件
|
||
result = {}
|
||
lines = content.split('\n')
|
||
current_key = None
|
||
current_value = []
|
||
|
||
for line in lines:
|
||
line = line.strip()
|
||
if not line:
|
||
continue
|
||
|
||
# 检查是否是新的键值对
|
||
if ':' in line and not line.startswith(' ') and not line.startswith('\t'):
|
||
# 保存之前的键值对
|
||
if current_key and current_value:
|
||
result[current_key] = '\n'.join(current_value)
|
||
current_value = []
|
||
|
||
# 提取新的键值对
|
||
parts = line.split(':', 1)
|
||
current_key = parts[0].strip().strip('"').strip("'")
|
||
if len(parts) > 1 and parts[1].strip():
|
||
current_value = [parts[1].strip()]
|
||
else:
|
||
current_value = []
|
||
elif current_key:
|
||
# 继续添加到当前值
|
||
current_value.append(line)
|
||
|
||
# 保存最后一个键值对
|
||
if current_key and current_value:
|
||
result[current_key] = '\n'.join(current_value)
|
||
|
||
if result:
|
||
logger.info(f"成功将文件 {file_path} 解析为键值对")
|
||
return result
|
||
|
||
# 如果所有方法都失败,返回文本内容
|
||
logger.warning(f"无法解析文件 {file_path} 为结构化数据,返回原始文本")
|
||
return {"content": content}
|
||
|
||
except FileNotFoundError:
|
||
logger.error(f"文件未找到: {file_path}")
|
||
except Exception as e:
|
||
logger.error(f"读取文件时发生未知错误 {file_path}: {e}")
|
||
return None
|
||
|
||
async def generate_info_with_llm(ai_agent: AIAgent, scenic_info: dict, product_info: dict, tweet_info: dict):
|
||
"""
|
||
调用LLM,根据输入信息生成海报所需的vibrant_content。
|
||
|
||
Args:
|
||
ai_agent: AIAgent实例
|
||
scenic_info: 景区信息字典
|
||
product_info: 产品信息字典
|
||
tweet_info: 推文信息字典
|
||
|
||
Returns:
|
||
一个符合VibrantTemplate格式的内容字典
|
||
"""
|
||
logger.info("正在调用LLM以生成海报内容...")
|
||
|
||
# 构建系统提示
|
||
system_prompt = """你是一名专业的海报设计师,专门设计宣传海报。你现在要根据用户提供的图片描述和推文内容,生成适合模板的海报信息。
|
||
|
||
模板特点:
|
||
- 单图背景,毛玻璃渐变效果
|
||
- 两栏布局(左栏内容,右栏价格)
|
||
- 适合展示套餐内容和价格信息
|
||
|
||
你需要生成的数据结构包含以下字段:
|
||
|
||
**必填字段:**
|
||
1. `title`: 主标题(8-12字符,体现产品特色)
|
||
2. `slogan`: 副标题/宣传语(10-20字符,吸引人的描述)
|
||
3. `price`: 价格数字(纯数字,不含符号)
|
||
4. `ticket_type`: 票种类型(如"成人票"、"套餐票"、"夜场票"等)
|
||
5. `content_button`: 内容按钮文字(通常为"套餐内容"、"包含项目"等)
|
||
6. `content_items`: 套餐内容列表(3-5个项目,每项5-15字符,不要只包含项目名称,要做合适的美化,可以适当省略)
|
||
|
||
**可选字段:**
|
||
7. `remarks`: 备注信息(1-3条,每条10-20字符)
|
||
8. `tag`: 标签(1条, 如"#限时优惠"、""等)
|
||
9. `pagination`: 分页信息(如"1/3",可为空)
|
||
|
||
**内容创作要求:**
|
||
1. 套餐内容要具体实用:明确说明包含的服务、时间、数量
|
||
2. 价格要有吸引力:突出性价比和优惠信息
|
||
|
||
**输出格式:**
|
||
请严格按照以下JSON格式输出,不要有任何额外内容:
|
||
|
||
```json
|
||
{
|
||
"title": "海洋奇幻世界",
|
||
"slogan": "探索深海秘境,感受蓝色奇迹的无限魅力",
|
||
"price": "299",
|
||
"ticket_type": "成人票",
|
||
"content_button": "套餐内容",
|
||
"content_items": [
|
||
"海洋馆门票1张(含所有展区)",
|
||
"海豚表演VIP座位",
|
||
"鲨鱼隧道特别体验",
|
||
"专业摄影服务"
|
||
],
|
||
"remarks": [
|
||
"工作日可直接入园",
|
||
"周末需提前预约"
|
||
],
|
||
"tag": "#海洋特惠",
|
||
"pagination": "1/2"
|
||
}
|
||
```
|
||
"""
|
||
|
||
# 构建用户提示
|
||
user_prompt = f"""请根据以下信息,生成适合在旅游海报上展示的文案:
|
||
|
||
## 景区信息
|
||
{json.dumps(scenic_info, ensure_ascii=False, indent=2)}
|
||
|
||
## 产品信息
|
||
{json.dumps(product_info, ensure_ascii=False, indent=2)}
|
||
|
||
## 推文信息
|
||
{json.dumps(tweet_info, ensure_ascii=False, indent=2)}
|
||
|
||
请提取关键信息并整合成一个JSON对象,包含title、slogan、price、ticket_type、content_items、remarks和tag字段。
|
||
"""
|
||
|
||
try:
|
||
# 调用AIAgent生成文本
|
||
response, _, _, _ = await ai_agent.generate_text(
|
||
system_prompt=system_prompt,
|
||
user_prompt=user_prompt,
|
||
temperature=0.7,
|
||
stage="海报文案生成"
|
||
)
|
||
|
||
# 解析JSON响应
|
||
try:
|
||
# 尝试找到JSON对象的开始和结束位置
|
||
json_start = response.find('{')
|
||
json_end = response.rfind('}') + 1
|
||
|
||
if json_start >= 0 and json_end > json_start:
|
||
json_str = response[json_start:json_end]
|
||
content_dict = json_repair.loads(json_str)
|
||
logger.info(f"LLM成功生成内容: {content_dict}")
|
||
|
||
# 添加默认的按钮文本和分页信息
|
||
content_dict["content_button"] = content_dict.get("content_button", "套餐内容")
|
||
content_dict["pagination"] = content_dict.get("pagination", "")
|
||
|
||
# 确保所有值都是字符串类型
|
||
for key, value in content_dict.items():
|
||
if isinstance(value, (int, float)):
|
||
content_dict[key] = str(value)
|
||
elif isinstance(value, list):
|
||
content_dict[key] = [str(item) if isinstance(item, (int, float)) else item for item in value]
|
||
|
||
logger.info(f"转换类型后的内容: {content_dict}")
|
||
return content_dict
|
||
else:
|
||
logger.error(f"无法在响应中找到JSON对象: {response}")
|
||
return None
|
||
except json.JSONDecodeError:
|
||
logger.error(f"无法解析LLM响应为JSON: {response}")
|
||
return None
|
||
|
||
except Exception as e:
|
||
logger.error(f"调用LLM时发生错误: {e}")
|
||
return None
|
||
|
||
async def generate_poster(img_dir, output_dir, scenic_info_file, product_info_file, tweet_info_file, model, api_key, api_url):
|
||
"""
|
||
根据提供的参数生成海报。
|
||
"""
|
||
# --- 1. 读取所有源文件 ---
|
||
scenic_info = read_data_file(scenic_info_file)
|
||
product_info = read_data_file(product_info_file)
|
||
tweet_info = read_data_file(tweet_info_file)
|
||
|
||
if not all([scenic_info, product_info, tweet_info]):
|
||
logger.error("一个或多个源文件读取失败,终止执行。")
|
||
return
|
||
|
||
# --- 2. 初始化AI代理 ---
|
||
ai_config = AIModelConfig(
|
||
model=model,
|
||
api_key=api_key,
|
||
api_url=api_url
|
||
)
|
||
ai_agent = AIAgent(ai_config)
|
||
|
||
# --- 3. 调用LLM生成海报内容 ---
|
||
vibrant_content = await generate_info_with_llm(ai_agent, scenic_info, product_info, tweet_info)
|
||
if not vibrant_content:
|
||
logger.error("未能从LLM获取有效内容,终止执行。")
|
||
return
|
||
|
||
# --- 4. 随机选择图片 ---
|
||
try:
|
||
image_files = [f for f in os.listdir(img_dir) if f.lower().endswith(('png', 'jpg', 'jpeg', 'webp'))]
|
||
if not image_files:
|
||
logger.error(f"在目录 {img_dir} 中未找到任何图片文件。")
|
||
return
|
||
random_image_name = random.choice(image_files)
|
||
image_path = os.path.join(img_dir, random_image_name)
|
||
logger.info(f"随机选择图片: {image_path}")
|
||
except FileNotFoundError:
|
||
logger.error(f"图片目录不存在: {img_dir}")
|
||
return
|
||
|
||
# --- 5. 调用模板生成海报 ---
|
||
try:
|
||
logger.info("正在初始化 VibrantTemplate...")
|
||
template = VibrantTemplate()
|
||
poster = template.generate(image_path=image_path, content=vibrant_content)
|
||
|
||
if not poster:
|
||
logger.error("海报生成失败,模板返回了 None。")
|
||
return
|
||
|
||
except Exception as e:
|
||
logger.error(f"生成海报时发生未知错误: {e}", exc_info=True)
|
||
return
|
||
|
||
# --- 6. 保存图片 ---
|
||
try:
|
||
os.makedirs(output_dir, exist_ok=True)
|
||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||
scenic_name = scenic_info.get('主体名称', 'poster')
|
||
output_filename = f"vibrant_{scenic_name}_{timestamp}.png"
|
||
output_path = os.path.join(output_dir, output_filename)
|
||
|
||
poster.save(output_path, 'PNG')
|
||
logger.info(f"海报已成功生成并保存至: {output_path}")
|
||
|
||
except Exception as e:
|
||
logger.error(f"保存海报失败: {e}", exc_info=True)
|
||
|
||
def run_async(func, *args): # 同步包装器
|
||
return asyncio.run(func(*args))
|
||
|
||
async def main():
|
||
parser = argparse.ArgumentParser(
|
||
description="通过LLM从源文件生成信息,并使用Vibrant模板创建海报。",
|
||
formatter_class=argparse.RawTextHelpFormatter
|
||
)
|
||
|
||
# 设置默认值
|
||
img_dir = "/root/autodl-tmp/TCC_RESTRUCT/resource/data/images/天津冒险湾"
|
||
output_dir = "result/天津冒险湾"
|
||
scenic_info_file = "resource/data/Object/天津冒险湾.txt"
|
||
product_info_file = "resource/data/Product/天津冒险湾-2大2小套票.txt"
|
||
tweet_info_file = "/root/autodl-tmp/TCC_RESTRUCT/result/run_20250709_160942/topic_1/article_judged.json"
|
||
|
||
model = "qwen-plus"
|
||
api_key = "sk-bd5ee62703bc41fc9b8a55d748dc1eb8"
|
||
api_url = "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||
|
||
# 确保输出目录存在
|
||
os.makedirs(output_dir, exist_ok=True)
|
||
tasks = [
|
||
(img_dir, output_dir, scenic_info_file, product_info_file, tweet_info_file, model, api_key, api_url),
|
||
(img_dir, output_dir, scenic_info_file, product_info_file, tweet_info_file, model, api_key, api_url),
|
||
(img_dir, output_dir, scenic_info_file, product_info_file, tweet_info_file, model, api_key, api_url),
|
||
(img_dir, output_dir, scenic_info_file, product_info_file, tweet_info_file, model, api_key, api_url),
|
||
(img_dir, output_dir, scenic_info_file, product_info_file, tweet_info_file, model, api_key, api_url),
|
||
]
|
||
with ThreadPool(5) as pool:
|
||
# 直接传递任务参数,不需要额外包装
|
||
results = pool.starmap(
|
||
run_async,
|
||
[(generate_poster, *task_args) for task_args in tasks]
|
||
)
|
||
pool.close()
|
||
pool.join()
|
||
|
||
print("所有海报生成完成。")
|
||
|
||
|
||
if __name__ == "__main__":
|
||
asyncio.run(main()) |