TravelContentCreator/scripts/reimage/regenerate_images.py

314 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import sys
import json
import argparse
import logging
from typing import List, Dict
# 添加项目根目录到系统路径
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
# 导入所需的模块
from utils.output_handler import FileSystemOutputHandler
from utils.poster_notes_creator import select_additional_images
from utils.tweet_generator import generate_posters_for_topic
from core.topic_parser import TopicParser
def load_config(config_path="poster_gen_config.json"):
    """Read a JSON configuration file and return its parsed content.

    Args:
        config_path: Path of the JSON file to read.

    Returns:
        The value decoded by ``json.load`` on success, or ``None`` when the
        file does not exist or cannot be opened/parsed. In both failure cases
        an error is written to the log instead of raising.
    """
    if os.path.exists(config_path):
        try:
            with open(config_path, mode='r', encoding='utf-8') as handle:
                return json.load(handle)
        except Exception as e:
            # Any open/decode/parse problem is reported to the log, not raised.
            logging.error(f"加载配置文件时出错: {e}")
    else:
        logging.error(f"错误:配置文件 '{config_path}' 不存在")
    return None
def regenerate_posters_and_images(run_dir: str, config: Dict, target_topic: int = None,
ignore_metadata: bool = False, poster_config_path: str = None):
"""为所有主题重新生成海报和配图
Args:
run_dir: 运行目录路径
config: 配置字典
target_topic: 只处理特定主题ID
ignore_metadata: 是否忽略现有的元数据
poster_config_path: 海报配置文件路径
"""
run_id = os.path.basename(run_dir)
logging.info(f"处理运行ID: {run_id}")
# 创建输出处理器
output_handler = FileSystemOutputHandler(config.get("output_dir", "result"))
# 加载主题文件
topics_file = os.path.join(run_dir, f"tweet_topic_{run_id}.json")
if not os.path.exists(topics_file):
alternative_topics_file = os.path.join(run_dir, "tweet_topic.json")
if not os.path.exists(alternative_topics_file):
logging.error(f"主题文件不存在: {topics_file}{alternative_topics_file}")
return
topics_file = alternative_topics_file
logging.info(f"使用主题文件: {topics_file}")
topics_list = TopicParser.load_topics_from_json(topics_file)
if not topics_list:
logging.error("无法加载主题列表")
return
# 如果指定了poster_config_path加载自定义海报配置
poster_configs = None
if poster_config_path:
try:
with open(poster_config_path, 'r', encoding='utf-8') as f:
poster_configs = json.load(f)
logging.info(f"已加载自定义海报配置,包含 {len(poster_configs)} 个配置项")
except Exception as e:
logging.error(f"加载海报配置文件时出错: {e}")
return
# 加载配置参数
poster_variants = config.get("variants", 1)
poster_assets_dir = config.get("poster_assets_base_dir")
img_base_dir = config.get("image_base_dir")
res_dir_config = config.get("resource_dir", [])
poster_size = tuple(config.get("poster_target_size", [900, 1200]))
txt_possibility = config.get("text_possibility", 0.3)
img_frame_possibility = config.get("img_frame_possibility", 0.7)
text_bg_possibility = config.get("text_bg_possibility", 0)
collage_subdir = config.get("output_collage_subdir", "collage_img")
poster_subdir = config.get("output_poster_subdir", "poster")
poster_filename = config.get("output_poster_filename", "poster.jpg")
poster_content_system_prompt_path = config.get("poster_content_system_prompt")
collage_style = config.get("collage_style")
title_possibility = config.get("title_possibility", 0.3)
request_timeout = config.get("request_timeout", 180)
additional_images_count = config.get("additional_images_count", 3)
# 获取图像处理参数
image_selection_config = config.get("image_selection", {})
variation_strength = image_selection_config.get("variation_strength", "medium")
extra_effects = image_selection_config.get("extra_effects", True)
# 检查关键配置
if not poster_assets_dir or not img_base_dir or not poster_content_system_prompt_path:
logging.error("缺少关键配置参数")
return
# 读取海报内容系统提示词
with open(poster_content_system_prompt_path, "r", encoding="utf-8") as f:
poster_content_system_prompt = f.read()
# 处理每个主题
for topic_item in topics_list:
topic_index = topic_item.get('index')
if not topic_index:
logging.warning(f"主题缺少索引,跳过: {topic_item}")
continue
# 如果指定了target_topic且不匹配当前主题则跳过
if target_topic is not None and topic_index != target_topic:
continue
# 检查是否有自定义海报配置
custom_config = None
if poster_configs:
for config_item in poster_configs:
if config_item.get('index') == topic_index:
custom_config = config_item
break
if custom_config:
# 更新主题标题
if 'main_title' in custom_config:
topic_item['title'] = custom_config['main_title']
logging.info(f"使用自定义标题: {custom_config['main_title']}")
logging.info(f"处理主题 {topic_index}: {topic_item.get('object', 'N/A')}")
# 如果忽略元数据,先清空所有现有的海报和配图文件
if ignore_metadata:
for variant_index in range(1, poster_variants + 1):
variant_key = f"{topic_index}_{variant_index}"
topic_dir = os.path.join(run_dir, variant_key)
if os.path.exists(topic_dir):
poster_dir = os.path.join(topic_dir, poster_subdir)
if os.path.exists(poster_dir):
logging.info(f"清空变体 {variant_key} 的海报目录")
for file_name in os.listdir(poster_dir):
file_path = os.path.join(poster_dir, file_name)
if os.path.isfile(file_path):
os.remove(file_path)
logging.debug(f"已删除文件: {file_path}")
else:
os.makedirs(poster_dir, exist_ok=True)
else:
os.makedirs(topic_dir, exist_ok=True)
os.makedirs(os.path.join(topic_dir, poster_subdir), exist_ok=True)
logging.info(f"创建变体目录: {topic_dir}")
# 1. 为此主题生成所有变体的海报(一次性调用)
logging.info(f"为主题 {topic_index} 生成所有变体的海报...")
try:
posters_generated = generate_posters_for_topic(
topic_item=topic_item,
output_dir=config["output_dir"],
run_id=run_id,
topic_index=topic_index,
output_handler=output_handler,
model_name=config["model"],
base_url=config["api_url"],
api_key=config["api_key"],
variants=poster_variants, # 生成所有变体
title_possibility=title_possibility,
poster_assets_base_dir=poster_assets_dir,
image_base_dir=img_base_dir,
resource_dir_config=res_dir_config,
poster_target_size=poster_size,
text_possibility=txt_possibility,
img_frame_possibility=img_frame_possibility,
text_bg_possibility=text_bg_possibility,
output_collage_subdir=collage_subdir,
output_poster_subdir=poster_subdir,
output_poster_filename=poster_filename,
system_prompt=poster_content_system_prompt,
collage_style=collage_style,
timeout=request_timeout
)
if posters_generated:
logging.info(f"主题 {topic_index} 的所有海报生成完成")
else:
logging.warning(f"主题 {topic_index} 的海报生成失败或未返回成功标志")
continue
except Exception as e:
logging.exception(f"生成主题 {topic_index} 海报时出错: {e}")
continue
# 2. 为每个变体生成额外配图
for variant_index in range(1, poster_variants + 1):
variant_key = f"{topic_index}_{variant_index}"
topic_dir = os.path.join(run_dir, variant_key)
if not os.path.exists(topic_dir):
logging.warning(f"变体目录不存在: {topic_dir},跳过生成配图")
continue
# 生成额外配图
logging.info(f"为变体 {variant_key} 生成额外配图...")
# 获取海报元数据
poster_dir = os.path.join(topic_dir, poster_subdir)
if not os.path.exists(poster_dir):
logging.warning(f"变体 {variant_key} 海报目录不存在,跳过生成配图")
continue
try:
# 删除现有的额外配图文件
for file_name in os.listdir(poster_dir):
if file_name.startswith("additional_") and file_name.endswith(".jpg"):
file_path = os.path.join(poster_dir, file_name)
os.remove(file_path)
logging.debug(f"已删除额外配图: {file_path}")
# 查找元数据文件
metadata_files = [f for f in os.listdir(poster_dir) if f.endswith("_metadata.json")]
if not metadata_files:
logging.warning(f"变体 {variant_key} 未找到海报元数据文件,跳过生成配图")
continue
poster_metadata_path = os.path.join(poster_dir, metadata_files[0])
# 获取源图像目录
object_name = topic_item.get("object", "").split(".")[0].replace("景点信息-", "").strip()
source_image_dir = os.path.join(img_base_dir, object_name)
if not os.path.exists(source_image_dir) or not os.path.isdir(source_image_dir):
logging.warning(f"变体 {variant_key} 源图像目录不存在: {source_image_dir}")
continue
# 生成额外配图
additional_paths = select_additional_images(
run_id=run_id,
topic_index=topic_index,
variant_index=variant_index,
poster_metadata_path=poster_metadata_path,
source_image_dir=source_image_dir,
num_additional_images=additional_images_count,
output_handler=output_handler,
variation_strength=variation_strength,
extra_effects=extra_effects
)
if additional_paths:
logging.info(f"已为变体 {variant_key} 生成 {len(additional_paths)} 张额外配图")
else:
logging.warning(f"未能为变体 {variant_key} 生成任何额外配图")
except Exception as e:
logging.exception(f"生成变体 {variant_key} 额外配图时出错: {e}")
# 如果指定了target_topic但没有找到
if target_topic is not None and all(topic.get('index') != target_topic for topic in topics_list):
logging.warning(f"未找到指定的主题ID: {target_topic}")
def main():
    """Command-line entry point: parse arguments, validate paths and start the regeneration.

    Exits with status 1 when the configuration cannot be loaded (or is empty),
    when ``--run_dir`` does not exist, or when the file given with
    ``--poster-config`` does not exist.
    """
    parser = argparse.ArgumentParser(description="为所有主题重新生成海报和配图")
    parser.add_argument("--run_dir", required=True, help="之前运行结果的目录")
    parser.add_argument("--config", default="poster_gen_config.json", help="配置文件路径")
    parser.add_argument("--debug", action="store_true", help="启用调试日志")
    parser.add_argument("--topic", type=int, help="只处理指定主题索引")
    parser.add_argument("--ignore-metadata", action="store_true", help="忽略现有的元数据,重新生成所有海报和图片")
    parser.add_argument("--poster-config", help="指定海报配置文件路径,用于自定义海报内容")
    args = parser.parse_args()

    # Configure logging before anything else reports errors.
    log_level = logging.DEBUG if args.debug else logging.INFO
    logging.basicConfig(
        level=log_level,
        format='%(asctime)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )

    config = load_config(args.config)
    if not config:
        sys.exit(1)

    if not os.path.exists(args.run_dir):
        logging.error(f"指定的运行目录不存在: {args.run_dir}")
        sys.exit(1)

    # Announce the optional modes that were requested.
    if args.ignore_metadata:
        logging.info("将忽略现有元数据,完全重新生成所有海报和配图")
    if args.poster_config:
        logging.info(f"使用自定义海报配置文件: {args.poster_config}")
        if not os.path.exists(args.poster_config):
            logging.error(f"指定的海报配置文件不存在: {args.poster_config}")
            sys.exit(1)

    # "is not None" rather than truthiness: the callee filters on
    # "target_topic is not None", so "--topic 0" must be announced as a
    # single-topic run as well.
    if args.topic is not None:
        logging.info(f"只处理主题 {args.topic} 的海报和配图")
    else:
        logging.info(f"开始为 {args.run_dir} 重新生成海报和配图")
    regenerate_posters_and_images(args.run_dir, config, args.topic, args.ignore_metadata, args.poster_config)
    logging.info("处理完成")
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()