TravelContentCreator/scripts/test_poster_gen.py

202 lines
6.7 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import sys
import argparse
from PIL import Image
import random
import json
from datetime import datetime
# 导入海报生成模块
from core.poster_gen import PosterGenerator, PosterConfig
def parse_arguments():
"""解析命令行参数"""
parser = argparse.ArgumentParser(description='测试海报生成模块')
# 必需参数
parser.add_argument('--assets-dir', type=str, required=True,
help='海报资源目录包含frames, font, stickers等子目录')
# 可选参数
parser.add_argument('--image', type=str, default=None,
help='输入图片路径(单张图片测试)')
parser.add_argument('--image-dir', type=str, default=None,
help='输入图片目录(批量测试)')
parser.add_argument('--config', type=str, default=None,
help='海报配置文件路径')
parser.add_argument('--output-dir', type=str, default='poster_test_output',
help='输出目录')
parser.add_argument('--count', type=int, default=1,
help='生成海报数量')
parser.add_argument('--frame-possibility', type=float, default=0.7,
help='添加边框的概率')
parser.add_argument('--text-bg-possibility', type=float, default=0.0,
help='添加文本背景的概率')
return parser.parse_args()
def create_sample_config():
"""创建示例配置"""
return [
{
"index": 0,
"main_title": "美丽的海滩度假胜地",
"texts": [
"阳光沙滩,椰林树影,尽享热带风情",
"给自己一次难忘的海岛旅行"
]
},
{
"index": 1,
"main_title": "壮丽的山川河流",
"texts": [
"峰峦叠嶂,溪水潺潺,领略自然之美",
"放慢脚步,感受大自然的馈赠"
]
},
{
"index": 2,
"main_title": "古老的城市文化",
"texts": [
"千年古迹,历史遗存,品味人文气息",
"穿越时空,探寻文明的足迹"
]
}
]
def load_config(config_path):
"""加载海报配置"""
if not config_path:
return create_sample_config()
try:
if os.path.exists(config_path):
with open(config_path, 'r', encoding='utf-8') as f:
return json.load(f)
else:
print(f"配置文件不存在: {config_path},使用默认配置")
return create_sample_config()
except Exception as e:
print(f"加载配置失败: {e},使用默认配置")
return create_sample_config()
def get_image_files(image_dir):
"""获取目录中的图片文件"""
if not os.path.exists(image_dir):
print(f"图片目录不存在: {image_dir}")
return []
extensions = ('.jpg', '.jpeg', '.png')
return [os.path.join(image_dir, f) for f in os.listdir(image_dir)
if f.lower().endswith(extensions)]
def main():
# 解析命令行参数
args = parse_arguments()
# 确保资源目录存在
if not os.path.exists(args.assets_dir):
print(f"错误: 海报资源目录不存在: {args.assets_dir}")
return
# 创建输出目录
os.makedirs(args.output_dir, exist_ok=True)
# 加载配置
configs = load_config(args.config)
print(f"加载了 {len(configs)} 个海报配置")
# 创建海报生成器
poster_generator = PosterGenerator(args.assets_dir, args.output_dir)
poster_generator.set_img_frame_possibility(args.frame_possibility)
poster_generator.set_text_bg_possibility(args.text_bg_possibility)
# 环境设置
os.environ['USE_TEXT_BG'] = 'True' if args.text_bg_possibility > 0 else 'False'
# 获取图片列表
image_list = []
if args.image:
if os.path.exists(args.image):
image_list = [args.image]
else:
print(f"错误: 指定的图片不存在: {args.image}")
return
elif args.image_dir:
image_list = get_image_files(args.image_dir)
if not image_list:
print(f"错误: 在目录中没有找到图片: {args.image_dir}")
return
else:
print("错误: 必须提供单张图片路径或图片目录")
return
print(f"找到 {len(image_list)} 张图片用于测试")
# 生成海报
generated_count = 0
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
for i in range(args.count):
# 选择图片
if not image_list:
print("错误: 没有可用的图片")
break
image_path = random.choice(image_list)
# 选择配置
config = random.choice(configs)
# 构建文本数据
text_data = {
"title": config["main_title"],
"subtitle": "",
"additional_texts": []
}
# 添加额外文本
if "texts" in config and config["texts"]:
for idx, text in enumerate(config["texts"]):
if text:
position = "bottom" if idx > 0 else "middle"
text_data["additional_texts"].append({
"text": text,
"position": position,
"size_factor": 0.6 if idx > 0 else 0.8
})
try:
# 加载图片
print(f"\n生成海报 {i+1}/{args.count}")
print(f"使用图片: {image_path}")
print(f"标题: {text_data['title']}")
# 尝试先加载PIL图像对象
img = Image.open(image_path).convert('RGBA')
# 生成海报
poster = poster_generator.create_poster(img, text_data)
if poster:
# 保存结果
output_filename = f"poster_{timestamp}_{i+1}.jpg"
output_path = os.path.join(args.output_dir, output_filename)
poster.save(output_path)
print(f"成功生成海报: {output_path}")
generated_count += 1
else:
print(f"生成海报失败")
except Exception as e:
print(f"处理图片时出错: {e}")
import traceback
traceback.print_exc()
print(f"\n测试完成! 成功生成 {generated_count}/{args.count} 张海报")
print(f"输出目录: {args.output_dir}")
if __name__ == "__main__":
main()