226 lines
7.2 KiB
Python
226 lines
7.2 KiB
Python
|
|
#!/usr/bin/env python
|
||
|
|
# -*- coding: utf-8 -*-
|
||
|
|
|
||
|
|
import os
|
||
|
|
import sys
|
||
|
|
import logging
|
||
|
|
import argparse
|
||
|
|
import json
|
||
|
|
import random
|
||
|
|
from pathlib import Path
|
||
|
|
from datetime import datetime
|
||
|
|
from PIL import Image
|
||
|
|
|
||
|
|
# 添加项目根目录到路径
|
||
|
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||
|
|
|
||
|
|
from utils.output_handler import OutputHandler
|
||
|
|
from utils.poster_notes_creator import process_poster_for_notes
|
||
|
|
|
||
|
|
# 配置日志
|
||
|
|
logging.basicConfig(
|
||
|
|
level=logging.INFO,
|
||
|
|
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||
|
|
)
|
||
|
|
logger = logging.getLogger(__name__)
|
||
|
|
|
||
|
|
class SimpleOutputHandler(OutputHandler):
|
||
|
|
"""简单的输出处理器,用于测试目的"""
|
||
|
|
|
||
|
|
def __init__(self, output_dir):
|
||
|
|
"""初始化输出处理器
|
||
|
|
|
||
|
|
Args:
|
||
|
|
output_dir: 输出目录路径
|
||
|
|
"""
|
||
|
|
self.output_dir = output_dir
|
||
|
|
os.makedirs(output_dir, exist_ok=True)
|
||
|
|
logger.info(f"已初始化输出处理器,输出目录: {output_dir}")
|
||
|
|
|
||
|
|
def handle_generated_image(self, run_id, topic_index, variant_index,
|
||
|
|
image_type, image_data, output_filename, metadata=None):
|
||
|
|
"""处理生成的图像
|
||
|
|
|
||
|
|
Args:
|
||
|
|
run_id: 运行ID
|
||
|
|
topic_index: 主题索引
|
||
|
|
variant_index: 变体索引
|
||
|
|
image_type: 图像类型
|
||
|
|
image_data: 图像数据
|
||
|
|
output_filename: 输出文件名
|
||
|
|
metadata: 图像元数据
|
||
|
|
|
||
|
|
Returns:
|
||
|
|
str: 保存的图像路径
|
||
|
|
"""
|
||
|
|
# 创建目录结构
|
||
|
|
output_dir = ""
|
||
|
|
if image_type == 'note':
|
||
|
|
# 笔记图像保存在image目录
|
||
|
|
output_dir = os.path.join(self.output_dir, f"{run_id}/{topic_index}_{variant_index}/image")
|
||
|
|
else:
|
||
|
|
# 其他类型图像保存在各自的类型目录
|
||
|
|
output_dir = os.path.join(self.output_dir, image_type)
|
||
|
|
|
||
|
|
os.makedirs(output_dir, exist_ok=True)
|
||
|
|
|
||
|
|
# 保存图像
|
||
|
|
image_path = os.path.join(output_dir, output_filename)
|
||
|
|
image_data.save(image_path)
|
||
|
|
|
||
|
|
# 如果有元数据,保存它
|
||
|
|
if metadata:
|
||
|
|
metadata_path = os.path.splitext(image_path)[0] + '.json'
|
||
|
|
with open(metadata_path, 'w', encoding='utf-8') as f:
|
||
|
|
json.dump(metadata, f, ensure_ascii=False, indent=2)
|
||
|
|
|
||
|
|
logger.info(f"已保存图像: {image_path}")
|
||
|
|
return image_path
|
||
|
|
|
||
|
|
def handle_poster_configs(self, run_id, topic_index, variant_index,
|
||
|
|
poster_configs, output_filename):
|
||
|
|
"""处理海报配置
|
||
|
|
|
||
|
|
Args:
|
||
|
|
run_id: 运行ID
|
||
|
|
topic_index: 主题索引
|
||
|
|
variant_index: 变体索引
|
||
|
|
poster_configs: 海报配置
|
||
|
|
output_filename: 输出文件名
|
||
|
|
|
||
|
|
Returns:
|
||
|
|
str: 保存的配置路径
|
||
|
|
"""
|
||
|
|
# 创建目录
|
||
|
|
config_dir = os.path.join(self.output_dir, 'configs')
|
||
|
|
os.makedirs(config_dir, exist_ok=True)
|
||
|
|
|
||
|
|
# 保存配置
|
||
|
|
config_path = os.path.join(config_dir, output_filename)
|
||
|
|
with open(config_path, 'w', encoding='utf-8') as f:
|
||
|
|
json.dump(poster_configs, f, ensure_ascii=False, indent=2)
|
||
|
|
|
||
|
|
logger.info(f"已保存配置: {config_path}")
|
||
|
|
return config_path
|
||
|
|
|
||
|
|
def create_sample_metadata(image_dir, num_images=3):
|
||
|
|
"""创建示例海报元数据
|
||
|
|
|
||
|
|
Args:
|
||
|
|
image_dir: 图像目录
|
||
|
|
num_images: 使用的图像数量
|
||
|
|
|
||
|
|
Returns:
|
||
|
|
dict: 海报元数据
|
||
|
|
"""
|
||
|
|
# 获取目录中的图像
|
||
|
|
image_extensions = ('.jpg', '.jpeg', '.png', '.bmp')
|
||
|
|
available_images = [
|
||
|
|
f for f in os.listdir(image_dir)
|
||
|
|
if os.path.isfile(os.path.join(image_dir, f)) and
|
||
|
|
f.lower().endswith(image_extensions)
|
||
|
|
]
|
||
|
|
|
||
|
|
if len(available_images) < num_images:
|
||
|
|
logger.warning(f"可用图像数量({len(available_images)})少于请求数量({num_images}),将使用所有可用图像")
|
||
|
|
selected_images = available_images
|
||
|
|
else:
|
||
|
|
selected_images = random.sample(available_images, num_images)
|
||
|
|
|
||
|
|
# 创建元数据
|
||
|
|
metadata = {
|
||
|
|
"title": "示例海报",
|
||
|
|
"description": "这是一个用于测试的示例海报",
|
||
|
|
"collage_images": selected_images,
|
||
|
|
"generation_time": datetime.now().isoformat(),
|
||
|
|
"style": "modern"
|
||
|
|
}
|
||
|
|
|
||
|
|
return metadata
|
||
|
|
|
||
|
|
def main():
|
||
|
|
# 解析命令行参数
|
||
|
|
parser = argparse.ArgumentParser(description='测试海报笔记创建器')
|
||
|
|
parser.add_argument('--image_dir', type=str, required=True,
|
||
|
|
help='包含图像的目录路径')
|
||
|
|
parser.add_argument('--output_dir', type=str, default='./output',
|
||
|
|
help='输出目录路径')
|
||
|
|
parser.add_argument('--num_notes', type=int, default=3,
|
||
|
|
help='要创建的笔记数量')
|
||
|
|
|
||
|
|
args = parser.parse_args()
|
||
|
|
|
||
|
|
# 验证图像目录
|
||
|
|
if not os.path.exists(args.image_dir) or not os.path.isdir(args.image_dir):
|
||
|
|
logger.error(f"图像目录不存在: {args.image_dir}")
|
||
|
|
return 1
|
||
|
|
|
||
|
|
# 创建输出目录
|
||
|
|
os.makedirs(args.output_dir, exist_ok=True)
|
||
|
|
|
||
|
|
# 初始化输出处理器
|
||
|
|
output_handler = SimpleOutputHandler(args.output_dir)
|
||
|
|
|
||
|
|
# 创建一个示例海报和元数据
|
||
|
|
run_id = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||
|
|
topic_index = 0
|
||
|
|
variant_index = 0
|
||
|
|
|
||
|
|
# 从图像目录中选择一张图像作为"海报"
|
||
|
|
image_extensions = ('.jpg', '.jpeg', '.png', '.bmp')
|
||
|
|
available_images = [
|
||
|
|
f for f in os.listdir(args.image_dir)
|
||
|
|
if os.path.isfile(os.path.join(args.image_dir, f)) and
|
||
|
|
f.lower().endswith(image_extensions)
|
||
|
|
]
|
||
|
|
|
||
|
|
if not available_images:
|
||
|
|
logger.error(f"图像目录中没有找到图像: {args.image_dir}")
|
||
|
|
return 1
|
||
|
|
|
||
|
|
# 选择第一张图像作为海报
|
||
|
|
poster_image_name = available_images[0]
|
||
|
|
poster_image_path = os.path.join(args.image_dir, poster_image_name)
|
||
|
|
|
||
|
|
# 创建示例元数据
|
||
|
|
poster_metadata = create_sample_metadata(args.image_dir, 2)
|
||
|
|
|
||
|
|
# 保存海报图像
|
||
|
|
poster_image = Image.open(poster_image_path)
|
||
|
|
saved_poster_path = output_handler.handle_generated_image(
|
||
|
|
run_id,
|
||
|
|
topic_index,
|
||
|
|
variant_index,
|
||
|
|
'poster',
|
||
|
|
poster_image,
|
||
|
|
'sample_poster.jpg',
|
||
|
|
poster_metadata
|
||
|
|
)
|
||
|
|
|
||
|
|
# 保存海报元数据
|
||
|
|
poster_metadata_path = os.path.splitext(saved_poster_path)[0] + '.json'
|
||
|
|
|
||
|
|
logger.info(f"已创建示例海报: {saved_poster_path}")
|
||
|
|
logger.info(f"元数据路径: {poster_metadata_path}")
|
||
|
|
|
||
|
|
# 处理海报笔记
|
||
|
|
logger.info("开始创建海报笔记...")
|
||
|
|
note_paths = process_poster_for_notes(
|
||
|
|
run_id,
|
||
|
|
topic_index,
|
||
|
|
variant_index,
|
||
|
|
saved_poster_path,
|
||
|
|
poster_metadata_path,
|
||
|
|
args.image_dir,
|
||
|
|
args.num_notes,
|
||
|
|
output_handler
|
||
|
|
)
|
||
|
|
|
||
|
|
logger.info(f"创建了 {len(note_paths)} 个笔记图像:")
|
||
|
|
for i, path in enumerate(note_paths):
|
||
|
|
logger.info(f" {i+1}. {path}")
|
||
|
|
|
||
|
|
return 0
|
||
|
|
|
||
|
|
if __name__ == "__main__":
|
||
|
|
sys.exit(main())
|