增加了配图模块
This commit is contained in:
parent
218a91659b
commit
b85d98b95a
129
examples/test_additional_images.py
Normal file
129
examples/test_additional_images.py
Normal file
@ -0,0 +1,129 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
import argparse
|
||||
import json
|
||||
import random
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from PIL import Image
|
||||
|
||||
# 添加项目根目录到路径
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from utils.output_handler import FileSystemOutputHandler
|
||||
from utils.poster_notes_creator import select_additional_images
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def create_sample_metadata(image_dir, num_images=2):
|
||||
"""创建示例海报元数据
|
||||
|
||||
Args:
|
||||
image_dir: 图像目录
|
||||
num_images: 使用的图像数量
|
||||
|
||||
Returns:
|
||||
dict: 海报元数据
|
||||
"""
|
||||
# 获取目录中的图像
|
||||
image_extensions = ('.jpg', '.jpeg', '.png', '.bmp')
|
||||
available_images = [
|
||||
f for f in os.listdir(image_dir)
|
||||
if os.path.isfile(os.path.join(image_dir, f)) and
|
||||
f.lower().endswith(image_extensions)
|
||||
]
|
||||
|
||||
if len(available_images) < num_images:
|
||||
logger.warning(f"可用图像数量({len(available_images)})少于请求数量({num_images}),将使用所有可用图像")
|
||||
selected_images = available_images
|
||||
else:
|
||||
selected_images = random.sample(available_images, num_images)
|
||||
|
||||
# 创建元数据
|
||||
metadata = {
|
||||
"title": "示例海报",
|
||||
"description": "这是一个用于测试额外配图功能的示例海报",
|
||||
"collage_images": selected_images,
|
||||
"generation_time": datetime.now().isoformat(),
|
||||
"style": "modern"
|
||||
}
|
||||
|
||||
return metadata
|
||||
|
||||
def main():
|
||||
# 解析命令行参数
|
||||
parser = argparse.ArgumentParser(description='测试额外配图功能')
|
||||
parser.add_argument('--image_dir', type=str, required=True,
|
||||
help='包含图像的目录路径')
|
||||
parser.add_argument('--output_dir', type=str, default='./output',
|
||||
help='输出目录路径')
|
||||
parser.add_argument('--num_images', type=int, default=3,
|
||||
help='要选择的额外配图数量')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# 验证图像目录
|
||||
if not os.path.exists(args.image_dir) or not os.path.isdir(args.image_dir):
|
||||
logger.error(f"图像目录不存在: {args.image_dir}")
|
||||
return 1
|
||||
|
||||
# 创建输出目录
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
# 初始化输出处理器
|
||||
output_handler = FileSystemOutputHandler(args.output_dir)
|
||||
|
||||
# 创建测试运行ID和主题/变体索引
|
||||
run_id = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
topic_index = 1
|
||||
variant_index = 1
|
||||
|
||||
# 创建一个模拟的海报元数据文件
|
||||
metadata = create_sample_metadata(args.image_dir, 2)
|
||||
metadata_path = os.path.join(args.output_dir, f"{run_id}_{topic_index}_{variant_index}_metadata.json")
|
||||
|
||||
# 保存元数据到文件
|
||||
with open(metadata_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(metadata, f, ensure_ascii=False, indent=2)
|
||||
logger.info(f"模拟海报元数据已保存: {metadata_path}")
|
||||
|
||||
# 打印将要使用的"已用"图像
|
||||
logger.info(f"模拟海报使用的图像: {', '.join(metadata['collage_images'])}")
|
||||
|
||||
# 调用额外配图选择函数
|
||||
logger.info(f"开始选择额外配图...")
|
||||
image_paths = select_additional_images(
|
||||
run_id=run_id,
|
||||
topic_index=topic_index,
|
||||
variant_index=variant_index,
|
||||
poster_metadata_path=metadata_path,
|
||||
source_image_dir=args.image_dir,
|
||||
num_additional_images=args.num_images,
|
||||
output_handler=output_handler
|
||||
)
|
||||
|
||||
# 打印结果
|
||||
if image_paths:
|
||||
logger.info(f"已选择并保存 {len(image_paths)} 张额外配图:")
|
||||
for i, path in enumerate(image_paths):
|
||||
logger.info(f" {i+1}. {path}")
|
||||
|
||||
logger.info(f"额外配图已保存到: {os.path.join(args.output_dir, run_id, f'{topic_index}_{variant_index}', 'image')}")
|
||||
logger.info(f"测试成功完成!")
|
||||
else:
|
||||
logger.error("未能选择任何额外配图.")
|
||||
return 1
|
||||
|
||||
return 0
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
226
examples/test_poster_notes.py
Normal file
226
examples/test_poster_notes.py
Normal file
@ -0,0 +1,226 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
import argparse
|
||||
import json
|
||||
import random
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from PIL import Image
|
||||
|
||||
# 添加项目根目录到路径
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from utils.output_handler import OutputHandler
|
||||
from utils.poster_notes_creator import process_poster_for_notes
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class SimpleOutputHandler(OutputHandler):
|
||||
"""简单的输出处理器,用于测试目的"""
|
||||
|
||||
def __init__(self, output_dir):
|
||||
"""初始化输出处理器
|
||||
|
||||
Args:
|
||||
output_dir: 输出目录路径
|
||||
"""
|
||||
self.output_dir = output_dir
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
logger.info(f"已初始化输出处理器,输出目录: {output_dir}")
|
||||
|
||||
def handle_generated_image(self, run_id, topic_index, variant_index,
|
||||
image_type, image_data, output_filename, metadata=None):
|
||||
"""处理生成的图像
|
||||
|
||||
Args:
|
||||
run_id: 运行ID
|
||||
topic_index: 主题索引
|
||||
variant_index: 变体索引
|
||||
image_type: 图像类型
|
||||
image_data: 图像数据
|
||||
output_filename: 输出文件名
|
||||
metadata: 图像元数据
|
||||
|
||||
Returns:
|
||||
str: 保存的图像路径
|
||||
"""
|
||||
# 创建目录结构
|
||||
output_dir = ""
|
||||
if image_type == 'note':
|
||||
# 笔记图像保存在image目录
|
||||
output_dir = os.path.join(self.output_dir, f"{run_id}/{topic_index}_{variant_index}/image")
|
||||
else:
|
||||
# 其他类型图像保存在各自的类型目录
|
||||
output_dir = os.path.join(self.output_dir, image_type)
|
||||
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# 保存图像
|
||||
image_path = os.path.join(output_dir, output_filename)
|
||||
image_data.save(image_path)
|
||||
|
||||
# 如果有元数据,保存它
|
||||
if metadata:
|
||||
metadata_path = os.path.splitext(image_path)[0] + '.json'
|
||||
with open(metadata_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(metadata, f, ensure_ascii=False, indent=2)
|
||||
|
||||
logger.info(f"已保存图像: {image_path}")
|
||||
return image_path
|
||||
|
||||
def handle_poster_configs(self, run_id, topic_index, variant_index,
|
||||
poster_configs, output_filename):
|
||||
"""处理海报配置
|
||||
|
||||
Args:
|
||||
run_id: 运行ID
|
||||
topic_index: 主题索引
|
||||
variant_index: 变体索引
|
||||
poster_configs: 海报配置
|
||||
output_filename: 输出文件名
|
||||
|
||||
Returns:
|
||||
str: 保存的配置路径
|
||||
"""
|
||||
# 创建目录
|
||||
config_dir = os.path.join(self.output_dir, 'configs')
|
||||
os.makedirs(config_dir, exist_ok=True)
|
||||
|
||||
# 保存配置
|
||||
config_path = os.path.join(config_dir, output_filename)
|
||||
with open(config_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(poster_configs, f, ensure_ascii=False, indent=2)
|
||||
|
||||
logger.info(f"已保存配置: {config_path}")
|
||||
return config_path
|
||||
|
||||
def create_sample_metadata(image_dir, num_images=3):
|
||||
"""创建示例海报元数据
|
||||
|
||||
Args:
|
||||
image_dir: 图像目录
|
||||
num_images: 使用的图像数量
|
||||
|
||||
Returns:
|
||||
dict: 海报元数据
|
||||
"""
|
||||
# 获取目录中的图像
|
||||
image_extensions = ('.jpg', '.jpeg', '.png', '.bmp')
|
||||
available_images = [
|
||||
f for f in os.listdir(image_dir)
|
||||
if os.path.isfile(os.path.join(image_dir, f)) and
|
||||
f.lower().endswith(image_extensions)
|
||||
]
|
||||
|
||||
if len(available_images) < num_images:
|
||||
logger.warning(f"可用图像数量({len(available_images)})少于请求数量({num_images}),将使用所有可用图像")
|
||||
selected_images = available_images
|
||||
else:
|
||||
selected_images = random.sample(available_images, num_images)
|
||||
|
||||
# 创建元数据
|
||||
metadata = {
|
||||
"title": "示例海报",
|
||||
"description": "这是一个用于测试的示例海报",
|
||||
"collage_images": selected_images,
|
||||
"generation_time": datetime.now().isoformat(),
|
||||
"style": "modern"
|
||||
}
|
||||
|
||||
return metadata
|
||||
|
||||
def main():
|
||||
# 解析命令行参数
|
||||
parser = argparse.ArgumentParser(description='测试海报笔记创建器')
|
||||
parser.add_argument('--image_dir', type=str, required=True,
|
||||
help='包含图像的目录路径')
|
||||
parser.add_argument('--output_dir', type=str, default='./output',
|
||||
help='输出目录路径')
|
||||
parser.add_argument('--num_notes', type=int, default=3,
|
||||
help='要创建的笔记数量')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# 验证图像目录
|
||||
if not os.path.exists(args.image_dir) or not os.path.isdir(args.image_dir):
|
||||
logger.error(f"图像目录不存在: {args.image_dir}")
|
||||
return 1
|
||||
|
||||
# 创建输出目录
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
# 初始化输出处理器
|
||||
output_handler = SimpleOutputHandler(args.output_dir)
|
||||
|
||||
# 创建一个示例海报和元数据
|
||||
run_id = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
topic_index = 0
|
||||
variant_index = 0
|
||||
|
||||
# 从图像目录中选择一张图像作为"海报"
|
||||
image_extensions = ('.jpg', '.jpeg', '.png', '.bmp')
|
||||
available_images = [
|
||||
f for f in os.listdir(args.image_dir)
|
||||
if os.path.isfile(os.path.join(args.image_dir, f)) and
|
||||
f.lower().endswith(image_extensions)
|
||||
]
|
||||
|
||||
if not available_images:
|
||||
logger.error(f"图像目录中没有找到图像: {args.image_dir}")
|
||||
return 1
|
||||
|
||||
# 选择第一张图像作为海报
|
||||
poster_image_name = available_images[0]
|
||||
poster_image_path = os.path.join(args.image_dir, poster_image_name)
|
||||
|
||||
# 创建示例元数据
|
||||
poster_metadata = create_sample_metadata(args.image_dir, 2)
|
||||
|
||||
# 保存海报图像
|
||||
poster_image = Image.open(poster_image_path)
|
||||
saved_poster_path = output_handler.handle_generated_image(
|
||||
run_id,
|
||||
topic_index,
|
||||
variant_index,
|
||||
'poster',
|
||||
poster_image,
|
||||
'sample_poster.jpg',
|
||||
poster_metadata
|
||||
)
|
||||
|
||||
# 保存海报元数据
|
||||
poster_metadata_path = os.path.splitext(saved_poster_path)[0] + '.json'
|
||||
|
||||
logger.info(f"已创建示例海报: {saved_poster_path}")
|
||||
logger.info(f"元数据路径: {poster_metadata_path}")
|
||||
|
||||
# 处理海报笔记
|
||||
logger.info("开始创建海报笔记...")
|
||||
note_paths = process_poster_for_notes(
|
||||
run_id,
|
||||
topic_index,
|
||||
variant_index,
|
||||
saved_poster_path,
|
||||
poster_metadata_path,
|
||||
args.image_dir,
|
||||
args.num_notes,
|
||||
output_handler
|
||||
)
|
||||
|
||||
logger.info(f"创建了 {len(note_paths)} 个笔记图像:")
|
||||
for i, path in enumerate(note_paths):
|
||||
logger.info(f" {i+1}. {path}")
|
||||
|
||||
return 0
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
50
main.py
50
main.py
@ -24,6 +24,8 @@ import random
|
||||
from utils.output_handler import FileSystemOutputHandler, OutputHandler
|
||||
from core.topic_parser import TopicParser
|
||||
from utils.tweet_generator import tweetTopicRecord # Needed only if loading old topics files?
|
||||
# 导入额外配图选择函数
|
||||
from utils.poster_notes_creator import select_additional_images
|
||||
|
||||
def load_config(config_path="poster_gen_config.json"):
|
||||
"""Loads configuration from a JSON file."""
|
||||
@ -200,6 +202,54 @@ def generate_content_and_posters_step(config, run_id, topics_list, output_handle
|
||||
)
|
||||
if posters_attempted:
|
||||
logging.info(f"Poster generation process completed for Topic {topic_index}.")
|
||||
|
||||
# --- 为每个变体添加额外配图 ---
|
||||
logging.info(f"开始为主题 {topic_index} 添加额外配图...")
|
||||
additional_images_count = config.get("additional_images_count", 3)
|
||||
|
||||
# 循环处理每个变体
|
||||
for j in range(poster_variants):
|
||||
variant_index = j + 1
|
||||
variant_dir = os.path.join(config["output_dir"], run_id, f"{topic_index}_{variant_index}")
|
||||
|
||||
# 获取海报元数据路径
|
||||
poster_dir = os.path.join(variant_dir, poster_subdir)
|
||||
metadata_files = [f for f in os.listdir(poster_dir)
|
||||
if f.endswith("_metadata.json") and os.path.isfile(os.path.join(poster_dir, f))]
|
||||
|
||||
if metadata_files:
|
||||
poster_metadata_path = os.path.join(poster_dir, metadata_files[0])
|
||||
logging.info(f"为变体 {variant_index} 选择额外配图,使用元数据: {poster_metadata_path}")
|
||||
|
||||
# 调用额外配图选择函数
|
||||
try:
|
||||
object_name = topic_item.get("object", "").split(".")[0].replace("景点信息-", "").strip()
|
||||
source_image_dir = os.path.join(img_base_dir, object_name)
|
||||
|
||||
if os.path.exists(source_image_dir) and os.path.isdir(source_image_dir):
|
||||
additional_paths = select_additional_images(
|
||||
run_id=run_id,
|
||||
topic_index=topic_index,
|
||||
variant_index=variant_index,
|
||||
poster_metadata_path=poster_metadata_path,
|
||||
source_image_dir=source_image_dir,
|
||||
num_additional_images=additional_images_count,
|
||||
output_handler=output_handler
|
||||
)
|
||||
|
||||
if additional_paths:
|
||||
logging.info(f"已为变体 {variant_index} 选择 {len(additional_paths)} 张额外配图")
|
||||
else:
|
||||
logging.warning(f"未能为变体 {variant_index} 选择任何额外配图")
|
||||
else:
|
||||
logging.warning(f"源图像目录不存在: {source_image_dir}")
|
||||
except Exception as e:
|
||||
logging.exception(f"选择额外配图时发生错误: {e}")
|
||||
else:
|
||||
logging.warning(f"未找到变体 {variant_index} 的海报元数据文件,跳过额外配图选择")
|
||||
|
||||
# --- 结束添加额外配图 ---
|
||||
|
||||
success_flag = True # Mark overall success if content AND poster attempts were made
|
||||
else:
|
||||
logging.warning(f"Poster generation skipped or failed early for Topic {topic_index}.")
|
||||
|
||||
@ -76,5 +76,8 @@
|
||||
"text_possibility": 0.3,
|
||||
"img_frame_possibility": 0,
|
||||
"text_bg_possibility": 0,
|
||||
"collage_style": ["grid_2x2", "overlap", "mosaic", "fullscreen", "vertical_stack"]
|
||||
"collage_style": ["grid_2x2", "overlap", "mosaic", "fullscreen", "vertical_stack"],
|
||||
|
||||
"additional_images_count": 3,
|
||||
"additional_images_enabled": true
|
||||
}
|
||||
80
poster_gen_config.json.bak
Normal file
80
poster_gen_config.json.bak
Normal file
@ -0,0 +1,80 @@
|
||||
{
|
||||
"date": "4月30日, 4月28日, 5月1日",
|
||||
"num": 1,
|
||||
"variants": 3,
|
||||
"topic_temperature": 0.2,
|
||||
"topic_top_p": 0.3,
|
||||
"topic_presence_penalty": 1.5,
|
||||
"content_temperature": 0.3,
|
||||
"content_top_p": 0.4,
|
||||
"content_presence_penalty": 1.5,
|
||||
"model": "qwenQWQ",
|
||||
"api_url": "http://localhost:8000/v1/",
|
||||
"api_key": "EMPTY",
|
||||
"topic_system_prompt": "./SelectPrompt/systemPrompt.txt",
|
||||
"topic_user_prompt": "./SelectPrompt/userPrompt.txt",
|
||||
"content_system_prompt": "./genPrompts/systemPrompt.txt",
|
||||
"poster_content_system_prompt": "./genPrompts/poster_content_systemPrompt.txt",
|
||||
"prompts_config": [
|
||||
{
|
||||
"type": "Style",
|
||||
"file_path": [
|
||||
"./genPrompts/Style/攻略风文案提示词.txt"
|
||||
]
|
||||
},
|
||||
{
|
||||
"type": "Demand",
|
||||
"file_path": [
|
||||
"./genPrompts/Demand/亲子向文旅需求.txt",
|
||||
"./genPrompts/Demand/周边游文旅需求.txt",
|
||||
"./genPrompts/Demand/情侣向文旅需求.txt"
|
||||
]
|
||||
},
|
||||
{
|
||||
"type": "Refer",
|
||||
"file_path": [
|
||||
"./genPrompts/Refer/标题参考格式.txt",
|
||||
"./genPrompts/Refer/正文开头引入段落参考.txt"
|
||||
]
|
||||
}
|
||||
],
|
||||
"resource_dir": [
|
||||
{
|
||||
"type": "Object",
|
||||
"file_path": [
|
||||
"./resource/Object/笔架山居森林度假酒店.txt",
|
||||
"./resource/Object/乌镇民宿.txt"
|
||||
]
|
||||
},
|
||||
{
|
||||
"type": "Description",
|
||||
"file_path": [
|
||||
"./resource/Object/笔架山居森林度假酒店.txt",
|
||||
"./resource/Object/乌镇民宿.txt"
|
||||
]
|
||||
},
|
||||
{
|
||||
"type": "Product",
|
||||
"file_path": [
|
||||
]
|
||||
}
|
||||
],
|
||||
"output_dir": "./result",
|
||||
"image_base_dir": "/root/autodl-tmp/TravelContentCreator/hotel_img",
|
||||
"poster_assets_base_dir": "/root/autodl-tmp/poster_baseboard_0403",
|
||||
"request_timeout": 210,
|
||||
"max_retries": 3,
|
||||
"description_filename": "description.txt",
|
||||
"output_collage_subdir": "collage_img",
|
||||
"output_poster_subdir": "poster",
|
||||
"output_poster_filename": "poster.jpg",
|
||||
"poster_target_size": [
|
||||
900,
|
||||
1200
|
||||
],
|
||||
"title_possibility": 0.3,
|
||||
"text_possibility": 0.3,
|
||||
"img_frame_possibility": 0,
|
||||
"text_bg_possibility": 0,
|
||||
"collage_style": ["grid_2x2", "overlap", "mosaic", "fullscreen", "vertical_stack"]
|
||||
}
|
||||
88
test_additional_images.py
Normal file
88
test_additional_images.py
Normal file
@ -0,0 +1,88 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import logging
|
||||
import argparse
|
||||
from datetime import datetime
|
||||
|
||||
# 添加项目根目录到路径
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from utils.output_handler import FileSystemOutputHandler
|
||||
from utils.poster_notes_creator import select_additional_images
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='测试额外配图选择功能')
|
||||
parser.add_argument('--run_id', type=str, required=True, help='指定运行ID')
|
||||
parser.add_argument('--topic_index', type=int, required=True, help='指定主题索引')
|
||||
parser.add_argument('--variant_index', type=int, required=True, help='指定变体索引')
|
||||
parser.add_argument('--source_dir', type=str, required=True, help='源图像目录路径')
|
||||
parser.add_argument('--output_dir', type=str, default='./result', help='输出目录路径')
|
||||
parser.add_argument('--num_images', type=int, default=3, help='要选择的配图数量')
|
||||
args = parser.parse_args()
|
||||
|
||||
# 检查源目录是否存在
|
||||
if not os.path.exists(args.source_dir) or not os.path.isdir(args.source_dir):
|
||||
logger.error(f"源图像目录不存在: {args.source_dir}")
|
||||
return 1
|
||||
|
||||
# 检查输出目录结构
|
||||
run_dir = os.path.join(args.output_dir, args.run_id)
|
||||
variant_dir = os.path.join(run_dir, f"{args.topic_index}_{args.variant_index}")
|
||||
poster_dir = os.path.join(variant_dir, "poster")
|
||||
|
||||
if not os.path.exists(poster_dir):
|
||||
logger.error(f"海报目录不存在: {poster_dir}")
|
||||
return 1
|
||||
|
||||
# 查找海报元数据文件
|
||||
metadata_files = [f for f in os.listdir(poster_dir)
|
||||
if f.endswith("_metadata.json") and os.path.isfile(os.path.join(poster_dir, f))]
|
||||
|
||||
if not metadata_files:
|
||||
logger.error(f"未找到海报元数据文件")
|
||||
return 1
|
||||
|
||||
metadata_path = os.path.join(poster_dir, metadata_files[0])
|
||||
logger.info(f"使用元数据文件: {metadata_path}")
|
||||
|
||||
# 初始化输出处理器
|
||||
output_handler = FileSystemOutputHandler(args.output_dir)
|
||||
|
||||
# 调用额外配图选择函数
|
||||
logger.info(f"开始为运行ID {args.run_id} 主题 {args.topic_index} 变体 {args.variant_index} 选择额外配图...")
|
||||
image_paths = select_additional_images(
|
||||
run_id=args.run_id,
|
||||
topic_index=args.topic_index,
|
||||
variant_index=args.variant_index,
|
||||
poster_metadata_path=metadata_path,
|
||||
source_image_dir=args.source_dir,
|
||||
num_additional_images=args.num_images,
|
||||
output_handler=output_handler
|
||||
)
|
||||
|
||||
# 输出结果
|
||||
if image_paths:
|
||||
logger.info(f"已选择并保存 {len(image_paths)} 张额外配图:")
|
||||
for i, path in enumerate(image_paths):
|
||||
logger.info(f" {i+1}. {path}")
|
||||
logger.info(f"额外配图已保存到: {os.path.join(args.output_dir, args.run_id, f'{args.topic_index}_{args.variant_index}', 'image')}")
|
||||
logger.info("测试成功完成!")
|
||||
else:
|
||||
logger.error("未能选择任何额外配图")
|
||||
return 1
|
||||
|
||||
return 0
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
Binary file not shown.
BIN
utils/__pycache__/poster_notes_creator.cpython-312.pyc
Normal file
BIN
utils/__pycache__/poster_notes_creator.cpython-312.pyc
Normal file
Binary file not shown.
@ -121,38 +121,53 @@ class FileSystemOutputHandler(OutputHandler):
|
||||
logging.error(f"Failed to save complete poster configurations for topic {topic_index} to {config_path}: {save_err}")
|
||||
|
||||
def handle_generated_image(self, run_id: str, topic_index: int, variant_index: int, image_type: str, image_data, output_filename: str, metadata: dict = None):
|
||||
"""Saves a generated image (PIL Image) to the appropriate variant subdirectory."""
|
||||
subdir = None
|
||||
if image_type == 'collage':
|
||||
subdir = 'collage_img' # TODO: Make these subdir names configurable?
|
||||
elif image_type == 'poster':
|
||||
subdir = 'poster'
|
||||
"""处理生成的图像,对于笔记图像和额外配图保存到image目录,其他类型保持原有路径结构"""
|
||||
# 根据图像类型确定保存路径
|
||||
if image_type == 'note' or image_type == 'additional': # 笔记图像和额外配图保存到image目录
|
||||
# 创建run_id/i_j/image目录
|
||||
run_dir = self._get_run_dir(run_id)
|
||||
variant_dir = os.path.join(run_dir, f"{topic_index}_{variant_index}")
|
||||
image_dir = os.path.join(variant_dir, "image")
|
||||
os.makedirs(image_dir, exist_ok=True)
|
||||
|
||||
# 在输出文件名前加上图像类型前缀
|
||||
prefixed_filename = f"{image_type}_{output_filename}" if not output_filename.startswith(image_type) else output_filename
|
||||
save_path = os.path.join(image_dir, prefixed_filename)
|
||||
else:
|
||||
logging.warning(f"Unknown image type '{image_type}'. Saving to variant root.")
|
||||
subdir = None # Save directly in variant dir if type is unknown
|
||||
|
||||
target_dir = self._get_variant_dir(run_id, topic_index, variant_index, subdir=subdir)
|
||||
save_path = os.path.join(target_dir, output_filename)
|
||||
# 其他类型图像使用原有的保存路径逻辑
|
||||
subdir = None
|
||||
if image_type == 'collage':
|
||||
subdir = 'collage_img' # 可配置的子目录名称
|
||||
elif image_type == 'poster':
|
||||
subdir = 'poster'
|
||||
else:
|
||||
logging.warning(f"未知图像类型 '{image_type}',保存到变体根目录。")
|
||||
subdir = None # 如果类型未知,直接保存到变体目录
|
||||
|
||||
target_dir = self._get_variant_dir(run_id, topic_index, variant_index, subdir=subdir)
|
||||
save_path = os.path.join(target_dir, output_filename)
|
||||
|
||||
try:
|
||||
# 保存图片
|
||||
# Assuming image_data is a PIL Image object based on posterGen/simple_collage
|
||||
image_data.save(save_path)
|
||||
logging.info(f"Saved {image_type} image to: {save_path}")
|
||||
logging.info(f"保存{image_type}图像到: {save_path}")
|
||||
|
||||
# 保存元数据(如果有)
|
||||
if metadata:
|
||||
metadata_filename = os.path.splitext(output_filename)[0] + "_metadata.json"
|
||||
metadata_path = os.path.join(target_dir, metadata_filename)
|
||||
# 确保元数据文件与图像在同一目录
|
||||
metadata_filename = os.path.splitext(os.path.basename(save_path))[0] + "_metadata.json"
|
||||
metadata_path = os.path.join(os.path.dirname(save_path), metadata_filename)
|
||||
try:
|
||||
with open(metadata_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(metadata, f, ensure_ascii=False, indent=4)
|
||||
logging.info(f"Saved {image_type} metadata to: {metadata_path}")
|
||||
logging.info(f"保存{image_type}元数据到: {metadata_path}")
|
||||
except Exception as me:
|
||||
logging.error(f"Failed to save {image_type} metadata to {metadata_path}: {me}")
|
||||
logging.error(f"无法保存{image_type}元数据到{metadata_path}: {me}")
|
||||
|
||||
except Exception as e:
|
||||
logging.exception(f"Failed to save {image_type} image to {save_path}: {e}")
|
||||
logging.exception(f"无法保存{image_type}图像到{save_path}: {e}")
|
||||
|
||||
return save_path
|
||||
|
||||
def finalize(self, run_id: str):
|
||||
logging.info(f"FileSystemOutputHandler finalizing run: {run_id}. No specific actions needed.")
|
||||
|
||||
372
utils/poster_notes_creator.py
Normal file
372
utils/poster_notes_creator.py
Normal file
@ -0,0 +1,372 @@
|
||||
import os
|
||||
import random
|
||||
import logging
|
||||
import json
|
||||
from PIL import Image
|
||||
import traceback
|
||||
from typing import List, Tuple, Dict, Any, Optional
|
||||
|
||||
from .output_handler import OutputHandler
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class PosterNotesCreator:
|
||||
"""
|
||||
处理原始海报作为主图,并随机选择额外的图片作为笔记图片。
|
||||
确保选择的笔记图片与海报中使用的图片不重复。
|
||||
"""
|
||||
|
||||
def __init__(self, output_handler: OutputHandler):
|
||||
"""
|
||||
初始化 PosterNotesCreator
|
||||
|
||||
Args:
|
||||
output_handler: 可选的 OutputHandler 实例,用于处理输出
|
||||
"""
|
||||
self.output_handler = output_handler
|
||||
logging.info("PosterNotesCreator 初始化完成")
|
||||
|
||||
def create_notes_images(
|
||||
self,
|
||||
run_id: str,
|
||||
topic_index: int,
|
||||
variant_index: int,
|
||||
poster_image_path: str,
|
||||
poster_metadata_path: str,
|
||||
source_image_dir: str,
|
||||
num_additional_images: int,
|
||||
output_filename_template: str = "note_{index}.jpg"
|
||||
) -> List[str]:
|
||||
"""
|
||||
创建笔记图像
|
||||
|
||||
Args:
|
||||
run_id: 运行ID
|
||||
topic_index: 主题索引
|
||||
variant_index: 变体索引
|
||||
poster_image_path: 海报图像路径
|
||||
poster_metadata_path: 海报元数据路径
|
||||
source_image_dir: 源图像目录
|
||||
num_additional_images: 要使用的额外图像数量
|
||||
output_filename_template: 输出文件名模板
|
||||
|
||||
Returns:
|
||||
List[str]: 保存的笔记图像路径列表
|
||||
"""
|
||||
# 检查输入路径是否存在
|
||||
if not os.path.exists(poster_image_path):
|
||||
logger.error(f"海报图像不存在: {poster_image_path}")
|
||||
return []
|
||||
|
||||
if not os.path.exists(poster_metadata_path):
|
||||
logger.error(f"海报元数据不存在: {poster_metadata_path}")
|
||||
return []
|
||||
|
||||
if not os.path.exists(source_image_dir) or not os.path.isdir(source_image_dir):
|
||||
logger.error(f"源图像目录不存在: {source_image_dir}")
|
||||
return []
|
||||
|
||||
# 从元数据文件中读取已使用的图像信息
|
||||
try:
|
||||
with open(poster_metadata_path, 'r', encoding='utf-8') as f:
|
||||
poster_metadata = json.load(f)
|
||||
except Exception as e:
|
||||
logger.error(f"无法读取海报元数据: {e}")
|
||||
return []
|
||||
|
||||
# 获取已经在海报中使用的图像
|
||||
used_images = []
|
||||
if 'collage_images' in poster_metadata:
|
||||
used_images = poster_metadata['collage_images']
|
||||
logger.info(f"海报中已使用 {len(used_images)} 张图像: {', '.join(used_images)}")
|
||||
|
||||
# 列出源目录中的所有图像文件
|
||||
image_extensions = ('.jpg', '.jpeg', '.png', '.bmp')
|
||||
available_images = [
|
||||
f for f in os.listdir(source_image_dir)
|
||||
if os.path.isfile(os.path.join(source_image_dir, f)) and
|
||||
f.lower().endswith(image_extensions)
|
||||
]
|
||||
|
||||
if not available_images:
|
||||
logger.error(f"源目录中没有找到图像: {source_image_dir}")
|
||||
return []
|
||||
|
||||
logger.info(f"源目录中找到 {len(available_images)} 张图像")
|
||||
|
||||
# 过滤掉已经在海报中使用的图像
|
||||
available_images = [img for img in available_images if img not in used_images]
|
||||
|
||||
if not available_images:
|
||||
logger.warning("所有图像都已在海报中使用,无法创建额外笔记")
|
||||
return []
|
||||
|
||||
logger.info(f"过滤后可用图像数量: {len(available_images)}")
|
||||
|
||||
# 如果可用图像少于请求数量,进行警告但继续处理
|
||||
if len(available_images) < num_additional_images:
|
||||
logger.warning(
|
||||
f"可用图像数量 ({len(available_images)}) 少于请求的笔记数量 ({num_additional_images}),"
|
||||
f"将使用所有可用图像"
|
||||
)
|
||||
selected_images = available_images
|
||||
else:
|
||||
# 随机选择额外图像
|
||||
selected_images = random.sample(available_images, num_additional_images)
|
||||
|
||||
logger.info(f"已选择 {len(selected_images)} 张图像作为笔记")
|
||||
|
||||
# 保存选择的笔记图像
|
||||
saved_paths = []
|
||||
for i, image_filename in enumerate(selected_images):
|
||||
try:
|
||||
# 加载图像
|
||||
image_path = os.path.join(source_image_dir, image_filename)
|
||||
image = Image.open(image_path)
|
||||
|
||||
# 生成输出文件名
|
||||
output_filename = output_filename_template.format(index=i+1)
|
||||
|
||||
# 创建元数据
|
||||
note_metadata = {
|
||||
"original_image": image_filename,
|
||||
"note_index": i + 1,
|
||||
"source_dir": source_image_dir,
|
||||
"associated_poster": os.path.basename(poster_image_path)
|
||||
}
|
||||
|
||||
# 使用输出处理器保存图像
|
||||
saved_path = self.output_handler.handle_generated_image(
|
||||
run_id,
|
||||
topic_index,
|
||||
variant_index,
|
||||
'note', # 图像类型为note
|
||||
image,
|
||||
output_filename,
|
||||
note_metadata
|
||||
)
|
||||
|
||||
saved_paths.append(saved_path)
|
||||
logger.info(f"已保存笔记图像 {i+1}/{len(selected_images)}: {saved_path}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理图像时出错 '{image_filename}': {e}")
|
||||
|
||||
return saved_paths
|
||||
|
||||
def create_additional_images(
|
||||
self,
|
||||
run_id: str,
|
||||
topic_index: int,
|
||||
variant_index: int,
|
||||
poster_metadata_path: str,
|
||||
source_image_dir: str,
|
||||
num_additional_images: int = 3,
|
||||
output_filename_template: str = "additional_{index}.jpg"
|
||||
) -> List[str]:
|
||||
"""
|
||||
选择未被海报使用的图像作为额外配图
|
||||
|
||||
Args:
|
||||
run_id: 运行ID
|
||||
topic_index: 主题索引
|
||||
variant_index: 变体索引
|
||||
poster_metadata_path: 海报元数据路径
|
||||
source_image_dir: 源图像目录
|
||||
num_additional_images: 要选择的额外图像数量
|
||||
output_filename_template: 输出文件名模板
|
||||
|
||||
Returns:
|
||||
List[str]: 保存的额外配图路径列表
|
||||
"""
|
||||
logger.info(f"开始为主题 {topic_index} 变体 {variant_index} 选择额外配图")
|
||||
|
||||
# 检查输入路径是否存在
|
||||
if not os.path.exists(poster_metadata_path):
|
||||
logger.error(f"海报元数据不存在: {poster_metadata_path}")
|
||||
return []
|
||||
|
||||
if not os.path.exists(source_image_dir) or not os.path.isdir(source_image_dir):
|
||||
logger.error(f"源图像目录不存在: {source_image_dir}")
|
||||
return []
|
||||
|
||||
# 从元数据文件中读取已使用的图像信息
|
||||
try:
|
||||
with open(poster_metadata_path, 'r', encoding='utf-8') as f:
|
||||
poster_metadata = json.load(f)
|
||||
except Exception as e:
|
||||
logger.error(f"无法读取海报元数据: {e}")
|
||||
return []
|
||||
|
||||
# 获取已经在海报中使用的图像
|
||||
used_images = []
|
||||
if 'collage_images' in poster_metadata:
|
||||
used_images = poster_metadata['collage_images']
|
||||
logger.info(f"海报中已使用 {len(used_images)} 张图像: {', '.join(used_images)}")
|
||||
|
||||
# 列出源目录中的所有图像文件
|
||||
image_extensions = ('.jpg', '.jpeg', '.png', '.bmp')
|
||||
available_images = [
|
||||
f for f in os.listdir(source_image_dir)
|
||||
if os.path.isfile(os.path.join(source_image_dir, f)) and
|
||||
f.lower().endswith(image_extensions)
|
||||
]
|
||||
|
||||
if not available_images:
|
||||
logger.error(f"源目录中没有找到图像: {source_image_dir}")
|
||||
return []
|
||||
|
||||
logger.info(f"源目录中找到 {len(available_images)} 张图像")
|
||||
|
||||
# 过滤掉已经在海报中使用的图像
|
||||
available_images = [img for img in available_images if img not in used_images]
|
||||
|
||||
if not available_images:
|
||||
logger.warning("所有图像都已在海报中使用,无法创建额外配图")
|
||||
return []
|
||||
|
||||
logger.info(f"过滤后可用图像数量: {len(available_images)}")
|
||||
|
||||
# 如果可用图像少于请求数量,进行警告但继续处理
|
||||
if len(available_images) < num_additional_images:
|
||||
logger.warning(
|
||||
f"可用图像数量 ({len(available_images)}) 少于请求的配图数量 ({num_additional_images}),"
|
||||
f"将使用所有可用图像"
|
||||
)
|
||||
selected_images = available_images
|
||||
else:
|
||||
# 随机选择额外图像
|
||||
selected_images = random.sample(available_images, num_additional_images)
|
||||
|
||||
logger.info(f"已选择 {len(selected_images)} 张图像作为额外配图")
|
||||
|
||||
# 保存选择的额外配图
|
||||
saved_paths = []
|
||||
for i, image_filename in enumerate(selected_images):
|
||||
try:
|
||||
# 加载图像
|
||||
image_path = os.path.join(source_image_dir, image_filename)
|
||||
image = Image.open(image_path)
|
||||
|
||||
# 生成输出文件名
|
||||
output_filename = output_filename_template.format(index=i+1)
|
||||
|
||||
# 创建元数据
|
||||
additional_metadata = {
|
||||
"original_image": image_filename,
|
||||
"additional_index": i + 1,
|
||||
"source_dir": source_image_dir,
|
||||
"is_additional_image": True
|
||||
}
|
||||
|
||||
# 使用输出处理器保存图像
|
||||
saved_path = self.output_handler.handle_generated_image(
|
||||
run_id,
|
||||
topic_index,
|
||||
variant_index,
|
||||
'additional', # 图像类型为additional
|
||||
image,
|
||||
output_filename,
|
||||
additional_metadata
|
||||
)
|
||||
|
||||
saved_paths.append(saved_path)
|
||||
logger.info(f"已保存额外配图 {i+1}/{len(selected_images)}: {saved_path}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理图像时出错 '{image_filename}': {e}")
|
||||
|
||||
return saved_paths
|
||||
|
||||
def process_poster_for_notes(
|
||||
run_id: str,
|
||||
topic_index: int,
|
||||
variant_index: int,
|
||||
poster_image_path: str,
|
||||
poster_metadata_path: str,
|
||||
source_image_dir: str,
|
||||
num_additional_images: int,
|
||||
output_handler: OutputHandler,
|
||||
output_filename_template: str = "note_{index}.jpg"
|
||||
) -> List[str]:
|
||||
"""
|
||||
处理海报并创建笔记图像
|
||||
|
||||
Args:
|
||||
run_id: 运行ID
|
||||
topic_index: 主题索引
|
||||
variant_index: 变体索引
|
||||
poster_image_path: 海报图像路径
|
||||
poster_metadata_path: 海报元数据路径
|
||||
source_image_dir: 源图像目录
|
||||
num_additional_images: 要使用的额外图像数量
|
||||
output_handler: 输出处理器
|
||||
output_filename_template: 输出文件名模板
|
||||
|
||||
Returns:
|
||||
List[str]: 保存的笔记图像路径列表
|
||||
"""
|
||||
logger.info(f"开始为海报创建笔记图像: {poster_image_path}")
|
||||
|
||||
# 验证输入
|
||||
if not os.path.exists(poster_image_path):
|
||||
logger.error(f"海报图像不存在: {poster_image_path}")
|
||||
return []
|
||||
|
||||
# 创建处理器实例并处理
|
||||
creator = PosterNotesCreator(output_handler)
|
||||
return creator.create_notes_images(
|
||||
run_id,
|
||||
topic_index,
|
||||
variant_index,
|
||||
poster_image_path,
|
||||
poster_metadata_path,
|
||||
source_image_dir,
|
||||
num_additional_images,
|
||||
output_filename_template
|
||||
)
|
||||
|
||||
def select_additional_images(
|
||||
run_id: str,
|
||||
topic_index: int,
|
||||
variant_index: int,
|
||||
poster_metadata_path: str,
|
||||
source_image_dir: str,
|
||||
num_additional_images: int,
|
||||
output_handler: OutputHandler,
|
||||
output_filename_template: str = "additional_{index}.jpg"
|
||||
) -> List[str]:
|
||||
"""
|
||||
选择未被海报使用的图像作为额外配图
|
||||
|
||||
Args:
|
||||
run_id: 运行ID
|
||||
topic_index: 主题索引
|
||||
variant_index: 变体索引
|
||||
poster_metadata_path: 海报元数据路径
|
||||
source_image_dir: 源图像目录
|
||||
num_additional_images: 要选择的额外图像数量
|
||||
output_handler: 输出处理器
|
||||
output_filename_template: 输出文件名模板
|
||||
|
||||
Returns:
|
||||
List[str]: 保存的额外配图路径列表
|
||||
"""
|
||||
logger.info(f"开始为主题 {topic_index} 变体 {variant_index} 选择额外配图")
|
||||
|
||||
# 验证输入
|
||||
if not os.path.exists(poster_metadata_path):
|
||||
logger.error(f"海报元数据不存在: {poster_metadata_path}")
|
||||
return []
|
||||
|
||||
# 创建处理器实例并处理
|
||||
creator = PosterNotesCreator(output_handler)
|
||||
return creator.create_additional_images(
|
||||
run_id,
|
||||
topic_index,
|
||||
variant_index,
|
||||
poster_metadata_path,
|
||||
source_image_dir,
|
||||
num_additional_images,
|
||||
output_filename_template
|
||||
)
|
||||
Loading…
x
Reference in New Issue
Block a user