TravelContentCreator/examples/test_poster_notes.py

226 lines
7.2 KiB
Python
Raw Normal View History

2025-04-26 14:53:54 +08:00
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import sys
import logging
import argparse
import json
import random
from pathlib import Path
from datetime import datetime
from PIL import Image
# 添加项目根目录到路径
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils.output_handler import OutputHandler
from utils.poster_notes_creator import process_poster_for_notes
# 配置日志
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
class SimpleOutputHandler(OutputHandler):
"""简单的输出处理器,用于测试目的"""
def __init__(self, output_dir):
"""初始化输出处理器
Args:
output_dir: 输出目录路径
"""
self.output_dir = output_dir
os.makedirs(output_dir, exist_ok=True)
logger.info(f"已初始化输出处理器,输出目录: {output_dir}")
def handle_generated_image(self, run_id, topic_index, variant_index,
image_type, image_data, output_filename, metadata=None):
"""处理生成的图像
Args:
run_id: 运行ID
topic_index: 主题索引
variant_index: 变体索引
image_type: 图像类型
image_data: 图像数据
output_filename: 输出文件名
metadata: 图像元数据
Returns:
str: 保存的图像路径
"""
# 创建目录结构
output_dir = ""
if image_type == 'note':
# 笔记图像保存在image目录
output_dir = os.path.join(self.output_dir, f"{run_id}/{topic_index}_{variant_index}/image")
else:
# 其他类型图像保存在各自的类型目录
output_dir = os.path.join(self.output_dir, image_type)
os.makedirs(output_dir, exist_ok=True)
# 保存图像
image_path = os.path.join(output_dir, output_filename)
image_data.save(image_path)
# 如果有元数据,保存它
if metadata:
metadata_path = os.path.splitext(image_path)[0] + '.json'
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(metadata, f, ensure_ascii=False, indent=2)
logger.info(f"已保存图像: {image_path}")
return image_path
def handle_poster_configs(self, run_id, topic_index, variant_index,
poster_configs, output_filename):
"""处理海报配置
Args:
run_id: 运行ID
topic_index: 主题索引
variant_index: 变体索引
poster_configs: 海报配置
output_filename: 输出文件名
Returns:
str: 保存的配置路径
"""
# 创建目录
config_dir = os.path.join(self.output_dir, 'configs')
os.makedirs(config_dir, exist_ok=True)
# 保存配置
config_path = os.path.join(config_dir, output_filename)
with open(config_path, 'w', encoding='utf-8') as f:
json.dump(poster_configs, f, ensure_ascii=False, indent=2)
logger.info(f"已保存配置: {config_path}")
return config_path
def create_sample_metadata(image_dir, num_images=3):
"""创建示例海报元数据
Args:
image_dir: 图像目录
num_images: 使用的图像数量
Returns:
dict: 海报元数据
"""
# 获取目录中的图像
image_extensions = ('.jpg', '.jpeg', '.png', '.bmp')
available_images = [
f for f in os.listdir(image_dir)
if os.path.isfile(os.path.join(image_dir, f)) and
f.lower().endswith(image_extensions)
]
if len(available_images) < num_images:
logger.warning(f"可用图像数量({len(available_images)})少于请求数量({num_images}),将使用所有可用图像")
selected_images = available_images
else:
selected_images = random.sample(available_images, num_images)
# 创建元数据
metadata = {
"title": "示例海报",
"description": "这是一个用于测试的示例海报",
"collage_images": selected_images,
"generation_time": datetime.now().isoformat(),
"style": "modern"
}
return metadata
def main():
# 解析命令行参数
parser = argparse.ArgumentParser(description='测试海报笔记创建器')
parser.add_argument('--image_dir', type=str, required=True,
help='包含图像的目录路径')
parser.add_argument('--output_dir', type=str, default='./output',
help='输出目录路径')
parser.add_argument('--num_notes', type=int, default=3,
help='要创建的笔记数量')
args = parser.parse_args()
# 验证图像目录
if not os.path.exists(args.image_dir) or not os.path.isdir(args.image_dir):
logger.error(f"图像目录不存在: {args.image_dir}")
return 1
# 创建输出目录
os.makedirs(args.output_dir, exist_ok=True)
# 初始化输出处理器
output_handler = SimpleOutputHandler(args.output_dir)
# 创建一个示例海报和元数据
run_id = datetime.now().strftime("%Y%m%d_%H%M%S")
topic_index = 0
variant_index = 0
# 从图像目录中选择一张图像作为"海报"
image_extensions = ('.jpg', '.jpeg', '.png', '.bmp')
available_images = [
f for f in os.listdir(args.image_dir)
if os.path.isfile(os.path.join(args.image_dir, f)) and
f.lower().endswith(image_extensions)
]
if not available_images:
logger.error(f"图像目录中没有找到图像: {args.image_dir}")
return 1
# 选择第一张图像作为海报
poster_image_name = available_images[0]
poster_image_path = os.path.join(args.image_dir, poster_image_name)
# 创建示例元数据
poster_metadata = create_sample_metadata(args.image_dir, 2)
# 保存海报图像
poster_image = Image.open(poster_image_path)
saved_poster_path = output_handler.handle_generated_image(
run_id,
topic_index,
variant_index,
'poster',
poster_image,
'sample_poster.jpg',
poster_metadata
)
# 保存海报元数据
poster_metadata_path = os.path.splitext(saved_poster_path)[0] + '.json'
logger.info(f"已创建示例海报: {saved_poster_path}")
logger.info(f"元数据路径: {poster_metadata_path}")
# 处理海报笔记
logger.info("开始创建海报笔记...")
note_paths = process_poster_for_notes(
run_id,
topic_index,
variant_index,
saved_poster_path,
poster_metadata_path,
args.image_dir,
args.num_notes,
output_handler
)
logger.info(f"创建了 {len(note_paths)} 个笔记图像:")
for i, path in enumerate(note_paths):
logger.info(f" {i+1}. {path}")
return 0
if __name__ == "__main__":
sys.exit(main())