314 lines
14 KiB
Python
314 lines
14 KiB
Python
import os
|
||
import sys
|
||
import json
|
||
import argparse
|
||
import logging
|
||
from typing import List, Dict
|
||
|
||
# 添加项目根目录到系统路径
|
||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
||
# 导入所需的模块
|
||
from utils.output_handler import FileSystemOutputHandler
|
||
from utils.poster_notes_creator import select_additional_images
|
||
from utils.tweet_generator import generate_posters_for_topic
|
||
from core.topic_parser import TopicParser
|
||
|
||
def load_config(config_path="poster_gen_config.json"):
|
||
"""加载配置文件"""
|
||
if not os.path.exists(config_path):
|
||
logging.error(f"错误:配置文件 '{config_path}' 不存在")
|
||
return None
|
||
|
||
try:
|
||
with open(config_path, 'r', encoding='utf-8') as f:
|
||
config = json.load(f)
|
||
return config
|
||
except Exception as e:
|
||
logging.error(f"加载配置文件时出错: {e}")
|
||
return None
|
||
|
||
def regenerate_posters_and_images(run_dir: str, config: Dict, target_topic: int = None,
|
||
ignore_metadata: bool = False, poster_config_path: str = None):
|
||
"""为所有主题重新生成海报和配图
|
||
|
||
Args:
|
||
run_dir: 运行目录路径
|
||
config: 配置字典
|
||
target_topic: 只处理特定主题ID
|
||
ignore_metadata: 是否忽略现有的元数据
|
||
poster_config_path: 海报配置文件路径
|
||
"""
|
||
run_id = os.path.basename(run_dir)
|
||
logging.info(f"处理运行ID: {run_id}")
|
||
|
||
# 创建输出处理器
|
||
output_handler = FileSystemOutputHandler(config.get("output_dir", "result"))
|
||
|
||
# 加载主题文件
|
||
topics_file = os.path.join(run_dir, f"tweet_topic_{run_id}.json")
|
||
if not os.path.exists(topics_file):
|
||
alternative_topics_file = os.path.join(run_dir, "tweet_topic.json")
|
||
if not os.path.exists(alternative_topics_file):
|
||
logging.error(f"主题文件不存在: {topics_file} 或 {alternative_topics_file}")
|
||
return
|
||
topics_file = alternative_topics_file
|
||
|
||
logging.info(f"使用主题文件: {topics_file}")
|
||
topics_list = TopicParser.load_topics_from_json(topics_file)
|
||
if not topics_list:
|
||
logging.error("无法加载主题列表")
|
||
return
|
||
|
||
# 如果指定了poster_config_path,加载自定义海报配置
|
||
poster_configs = None
|
||
if poster_config_path:
|
||
try:
|
||
with open(poster_config_path, 'r', encoding='utf-8') as f:
|
||
poster_configs = json.load(f)
|
||
logging.info(f"已加载自定义海报配置,包含 {len(poster_configs)} 个配置项")
|
||
except Exception as e:
|
||
logging.error(f"加载海报配置文件时出错: {e}")
|
||
return
|
||
|
||
# 加载配置参数
|
||
poster_variants = config.get("variants", 1)
|
||
poster_assets_dir = config.get("poster_assets_base_dir")
|
||
img_base_dir = config.get("image_base_dir")
|
||
res_dir_config = config.get("resource_dir", [])
|
||
poster_size = tuple(config.get("poster_target_size", [900, 1200]))
|
||
txt_possibility = config.get("text_possibility", 0.3)
|
||
img_frame_possibility = config.get("img_frame_possibility", 0.7)
|
||
text_bg_possibility = config.get("text_bg_possibility", 0)
|
||
collage_subdir = config.get("output_collage_subdir", "collage_img")
|
||
poster_subdir = config.get("output_poster_subdir", "poster")
|
||
poster_filename = config.get("output_poster_filename", "poster.jpg")
|
||
poster_content_system_prompt_path = config.get("poster_content_system_prompt")
|
||
collage_style = config.get("collage_style")
|
||
title_possibility = config.get("title_possibility", 0.3)
|
||
request_timeout = config.get("request_timeout", 180)
|
||
additional_images_count = config.get("additional_images_count", 3)
|
||
|
||
# 获取图像处理参数
|
||
image_selection_config = config.get("image_selection", {})
|
||
variation_strength = image_selection_config.get("variation_strength", "medium")
|
||
extra_effects = image_selection_config.get("extra_effects", True)
|
||
|
||
# 检查关键配置
|
||
if not poster_assets_dir or not img_base_dir or not poster_content_system_prompt_path:
|
||
logging.error("缺少关键配置参数")
|
||
return
|
||
|
||
# 读取海报内容系统提示词
|
||
with open(poster_content_system_prompt_path, "r", encoding="utf-8") as f:
|
||
poster_content_system_prompt = f.read()
|
||
|
||
# 处理每个主题
|
||
for topic_item in topics_list:
|
||
topic_index = topic_item.get('index')
|
||
if not topic_index:
|
||
logging.warning(f"主题缺少索引,跳过: {topic_item}")
|
||
continue
|
||
|
||
# 如果指定了target_topic且不匹配当前主题,则跳过
|
||
if target_topic is not None and topic_index != target_topic:
|
||
continue
|
||
|
||
# 检查是否有自定义海报配置
|
||
custom_config = None
|
||
if poster_configs:
|
||
for config_item in poster_configs:
|
||
if config_item.get('index') == topic_index:
|
||
custom_config = config_item
|
||
break
|
||
|
||
if custom_config:
|
||
# 更新主题标题
|
||
if 'main_title' in custom_config:
|
||
topic_item['title'] = custom_config['main_title']
|
||
logging.info(f"使用自定义标题: {custom_config['main_title']}")
|
||
|
||
logging.info(f"处理主题 {topic_index}: {topic_item.get('object', 'N/A')}")
|
||
|
||
# 如果忽略元数据,先清空所有现有的海报和配图文件
|
||
if ignore_metadata:
|
||
for variant_index in range(1, poster_variants + 1):
|
||
variant_key = f"{topic_index}_{variant_index}"
|
||
topic_dir = os.path.join(run_dir, variant_key)
|
||
|
||
if os.path.exists(topic_dir):
|
||
poster_dir = os.path.join(topic_dir, poster_subdir)
|
||
if os.path.exists(poster_dir):
|
||
logging.info(f"清空变体 {variant_key} 的海报目录")
|
||
for file_name in os.listdir(poster_dir):
|
||
file_path = os.path.join(poster_dir, file_name)
|
||
if os.path.isfile(file_path):
|
||
os.remove(file_path)
|
||
logging.debug(f"已删除文件: {file_path}")
|
||
else:
|
||
os.makedirs(poster_dir, exist_ok=True)
|
||
else:
|
||
os.makedirs(topic_dir, exist_ok=True)
|
||
os.makedirs(os.path.join(topic_dir, poster_subdir), exist_ok=True)
|
||
logging.info(f"创建变体目录: {topic_dir}")
|
||
|
||
# 1. 为此主题生成所有变体的海报(一次性调用)
|
||
logging.info(f"为主题 {topic_index} 生成所有变体的海报...")
|
||
|
||
try:
|
||
posters_generated = generate_posters_for_topic(
|
||
topic_item=topic_item,
|
||
output_dir=config["output_dir"],
|
||
run_id=run_id,
|
||
topic_index=topic_index,
|
||
output_handler=output_handler,
|
||
model_name=config["model"],
|
||
base_url=config["api_url"],
|
||
api_key=config["api_key"],
|
||
variants=poster_variants, # 生成所有变体
|
||
title_possibility=title_possibility,
|
||
poster_assets_base_dir=poster_assets_dir,
|
||
image_base_dir=img_base_dir,
|
||
resource_dir_config=res_dir_config,
|
||
poster_target_size=poster_size,
|
||
text_possibility=txt_possibility,
|
||
img_frame_possibility=img_frame_possibility,
|
||
text_bg_possibility=text_bg_possibility,
|
||
output_collage_subdir=collage_subdir,
|
||
output_poster_subdir=poster_subdir,
|
||
output_poster_filename=poster_filename,
|
||
system_prompt=poster_content_system_prompt,
|
||
collage_style=collage_style,
|
||
timeout=request_timeout
|
||
)
|
||
|
||
if posters_generated:
|
||
logging.info(f"主题 {topic_index} 的所有海报生成完成")
|
||
else:
|
||
logging.warning(f"主题 {topic_index} 的海报生成失败或未返回成功标志")
|
||
continue
|
||
|
||
except Exception as e:
|
||
logging.exception(f"生成主题 {topic_index} 海报时出错: {e}")
|
||
continue
|
||
|
||
# 2. 为每个变体生成额外配图
|
||
for variant_index in range(1, poster_variants + 1):
|
||
variant_key = f"{topic_index}_{variant_index}"
|
||
topic_dir = os.path.join(run_dir, variant_key)
|
||
|
||
if not os.path.exists(topic_dir):
|
||
logging.warning(f"变体目录不存在: {topic_dir},跳过生成配图")
|
||
continue
|
||
|
||
# 生成额外配图
|
||
logging.info(f"为变体 {variant_key} 生成额外配图...")
|
||
|
||
# 获取海报元数据
|
||
poster_dir = os.path.join(topic_dir, poster_subdir)
|
||
|
||
if not os.path.exists(poster_dir):
|
||
logging.warning(f"变体 {variant_key} 海报目录不存在,跳过生成配图")
|
||
continue
|
||
|
||
try:
|
||
# 删除现有的额外配图文件
|
||
for file_name in os.listdir(poster_dir):
|
||
if file_name.startswith("additional_") and file_name.endswith(".jpg"):
|
||
file_path = os.path.join(poster_dir, file_name)
|
||
os.remove(file_path)
|
||
logging.debug(f"已删除额外配图: {file_path}")
|
||
|
||
# 查找元数据文件
|
||
metadata_files = [f for f in os.listdir(poster_dir) if f.endswith("_metadata.json")]
|
||
|
||
if not metadata_files:
|
||
logging.warning(f"变体 {variant_key} 未找到海报元数据文件,跳过生成配图")
|
||
continue
|
||
|
||
poster_metadata_path = os.path.join(poster_dir, metadata_files[0])
|
||
|
||
# 获取源图像目录
|
||
object_name = topic_item.get("object", "").split(".")[0].replace("景点信息-", "").strip()
|
||
source_image_dir = os.path.join(img_base_dir, object_name)
|
||
|
||
if not os.path.exists(source_image_dir) or not os.path.isdir(source_image_dir):
|
||
logging.warning(f"变体 {variant_key} 源图像目录不存在: {source_image_dir}")
|
||
continue
|
||
|
||
# 生成额外配图
|
||
additional_paths = select_additional_images(
|
||
run_id=run_id,
|
||
topic_index=topic_index,
|
||
variant_index=variant_index,
|
||
poster_metadata_path=poster_metadata_path,
|
||
source_image_dir=source_image_dir,
|
||
num_additional_images=additional_images_count,
|
||
output_handler=output_handler,
|
||
variation_strength=variation_strength,
|
||
extra_effects=extra_effects
|
||
)
|
||
|
||
if additional_paths:
|
||
logging.info(f"已为变体 {variant_key} 生成 {len(additional_paths)} 张额外配图")
|
||
else:
|
||
logging.warning(f"未能为变体 {variant_key} 生成任何额外配图")
|
||
|
||
except Exception as e:
|
||
logging.exception(f"生成变体 {variant_key} 额外配图时出错: {e}")
|
||
|
||
# 如果指定了target_topic但没有找到
|
||
if target_topic is not None and all(topic.get('index') != target_topic for topic in topics_list):
|
||
logging.warning(f"未找到指定的主题ID: {target_topic}")
|
||
|
||
def main():
|
||
parser = argparse.ArgumentParser(description="为所有主题重新生成海报和配图")
|
||
parser.add_argument("--run_dir", required=True, help="之前运行结果的目录")
|
||
parser.add_argument("--config", default="poster_gen_config.json", help="配置文件路径")
|
||
parser.add_argument("--debug", action="store_true", help="启用调试日志")
|
||
parser.add_argument("--topic", type=int, help="只处理指定主题索引")
|
||
parser.add_argument("--ignore-metadata", action="store_true", help="忽略现有的元数据,重新生成所有海报和图片")
|
||
parser.add_argument("--poster-config", help="指定海报配置文件路径,用于自定义海报内容")
|
||
args = parser.parse_args()
|
||
|
||
# 设置日志级别
|
||
log_level = logging.DEBUG if args.debug else logging.INFO
|
||
logging.basicConfig(
|
||
level=log_level,
|
||
format='%(asctime)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s',
|
||
datefmt='%Y-%m-%d %H:%M:%S'
|
||
)
|
||
|
||
# 加载配置
|
||
config = load_config(args.config)
|
||
if not config:
|
||
sys.exit(1)
|
||
|
||
# 检查目录是否存在
|
||
if not os.path.exists(args.run_dir):
|
||
logging.error(f"指定的运行目录不存在: {args.run_dir}")
|
||
sys.exit(1)
|
||
|
||
# 修改regenerate_posters_and_images函数,添加新参数
|
||
if args.ignore_metadata:
|
||
logging.info("将忽略现有元数据,完全重新生成所有海报和配图")
|
||
|
||
# 如果指定了poster-config
|
||
if args.poster_config:
|
||
logging.info(f"使用自定义海报配置文件: {args.poster_config}")
|
||
# 检查文件是否存在
|
||
if not os.path.exists(args.poster_config):
|
||
logging.error(f"指定的海报配置文件不存在: {args.poster_config}")
|
||
sys.exit(1)
|
||
|
||
# 开始处理
|
||
if args.topic:
|
||
logging.info(f"只处理主题 {args.topic} 的海报和配图")
|
||
else:
|
||
logging.info(f"开始为 {args.run_dir} 重新生成海报和配图")
|
||
|
||
regenerate_posters_and_images(args.run_dir, config, args.topic, args.ignore_metadata, args.poster_config)
|
||
logging.info("处理完成")
|
||
|
||
if __name__ == "__main__":
|
||
main() |