"""Regenerate posters and additional images that are missing from a previous run.

The script scans a run directory for ``<topic_index>_<variant_index>`` variant
folders, reports which ones are missing or incomplete, and regenerates the
missing content, posters and additional images.
"""
import os
import sys
import json
import argparse
import logging
import shutil
from typing import List, Dict, Any, Tuple, Optional, Set
# python scripts/regenerate_missing.py --run_dir /root/autodl-tmp/TravelContentCreator/result/2025-05-12_21-36-33 --config poster_gen_config.json
# Add the project root directory to the system path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# Import the required project modules
from utils.output_handler import FileSystemOutputHandler
from utils.poster_notes_creator import select_additional_images
from utils.tweet_generator import generate_posters_for_topic, generate_content_for_topic
from core.topic_parser import TopicParser
from core.ai_agent import AI_Agent
from utils.prompt_manager import PromptManager


def load_config(config_path="poster_gen_config.json"):
    """Load the JSON configuration file.

    Args:
        config_path: Path to the JSON configuration file.

    Returns:
        The parsed configuration, or None when the file does not exist or
        cannot be read/parsed (the error is logged).
    """
    if not os.path.exists(config_path):
        logging.error(f"错误:配置文件 '{config_path}' 不存在")
        return None

    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            return json.load(f)
    except Exception as e:
        # Top-level boundary of a CLI script: report and let the caller exit.
        logging.error(f"加载配置文件时出错: {e}")
        return None


def check_content_status(
    run_dir: str,
    topic_indices: List[int],
    variants_count: int,
    poster_subdir: str = "poster",
) -> Dict[int, Dict[int, Dict[str, bool]]]:
    """Check the content status of every topic variant in a run directory.

    Args:
        run_dir: Path of the run directory.
        topic_indices: Indices of all topics.
        variants_count: Number of variants each topic is expected to have.
        poster_subdir: Name of the poster sub-directory inside a variant
            directory. Defaults to "poster".

    Returns:
        A mapping ``topic_index -> variant_index -> flags`` where the flags
        are ``exists``, ``has_content``, ``has_poster`` and ``has_additional``.
    """
    status: Dict[int, Dict[int, Dict[str, bool]]] = {}

    for topic_index in topic_indices:
        status[topic_index] = {}

        for variant_index in range(1, variants_count + 1):
            topic_dir = os.path.join(run_dir, f"{topic_index}_{variant_index}")

            # A variant without a directory is marked as completely missing.
            if not os.path.exists(topic_dir):
                status[topic_index][variant_index] = {
                    "exists": False,
                    "has_content": False,
                    "has_poster": False,
                    "has_additional": False,
                }
                continue

            content_file = os.path.join(topic_dir, "tweet_content.json")
            poster_dir = os.path.join(topic_dir, poster_subdir)

            # List the poster directory once; a missing directory (or a path
            # that is not a directory) simply means "no files".
            poster_entries = os.listdir(poster_dir) if os.path.isdir(poster_dir) else []
            jpg_files = [name for name in poster_entries if name.endswith(".jpg")]

            has_poster_image = any(not name.startswith("additional_") for name in jpg_files)
            has_additional_images = any(name.startswith("additional_") for name in jpg_files)
            has_poster_metadata = any(name.endswith("_metadata.json") for name in poster_entries)

            status[topic_index][variant_index] = {
                "exists": True,
                "has_content": os.path.exists(content_file),
                # A poster only counts when both the image and its metadata exist.
                "has_poster": has_poster_image and has_poster_metadata,
                "has_additional": has_additional_images,
            }

    return status


def _classify_variants(
    content_status: Dict[int, Dict[int, Dict[str, bool]]],
) -> Tuple[List[Tuple[int, int]], List[Tuple[int, int]]]:
    """Split variants into (completely missing, existing but incomplete) lists."""
    missing_variants = []
    incomplete_variants = []

    for topic_index, variants in content_status.items():
        for variant_index, status in variants.items():
            variant_key = f"{topic_index}_{variant_index}"

            if not status["exists"]:
                missing_variants.append((topic_index, variant_index))
                logging.info(f"变体 {variant_key} 完全不存在,需要创建")
            elif not (status["has_poster"] and status["has_additional"]):
                incomplete_variants.append((topic_index, variant_index))
                logging.info(f"变体 {variant_key} 不完整:海报={status['has_poster']} 配图={status['has_additional']}")

    return missing_variants, incomplete_variants


def _load_generation_settings(config: Dict, poster_subdir: str) -> Optional[Dict[str, Any]]:
    """Read poster / image generation settings from the config.

    Returns None (after logging) when a key setting is missing. Raises OSError
    when the poster content system prompt file cannot be read.
    """
    poster_assets_dir = config.get("poster_assets_base_dir")
    img_base_dir = config.get("image_base_dir")
    poster_content_system_prompt_path = config.get("poster_content_system_prompt")

    # Validate the key settings
    if not poster_assets_dir or not img_base_dir or not poster_content_system_prompt_path:
        logging.error("缺少关键配置参数")
        return None

    # Read the poster content system prompt
    with open(poster_content_system_prompt_path, "r", encoding="utf-8") as f:
        poster_content_system_prompt = f.read()

    image_selection_config = config.get("image_selection", {})

    return {
        "image_base_dir": img_base_dir,
        # Keyword arguments shared by every generate_posters_for_topic call.
        "poster_kwargs": {
            "title_possibility": config.get("title_possibility", 0.3),
            "poster_assets_base_dir": poster_assets_dir,
            "image_base_dir": img_base_dir,
            "resource_dir_config": config.get("resource_dir", []),
            "poster_target_size": tuple(config.get("poster_target_size", [900, 1200])),
            "text_possibility": config.get("text_possibility", 0.3),
            "img_frame_possibility": config.get("img_frame_possibility", 0.7),
            "text_bg_possibility": config.get("text_bg_possibility", 0),
            "output_collage_subdir": config.get("output_collage_subdir", "collage_img"),
            "output_poster_subdir": poster_subdir,
            "output_poster_filename": config.get("output_poster_filename", "poster.jpg"),
            "system_prompt": poster_content_system_prompt,
            "collage_style": config.get("collage_style"),
            "timeout": config.get("request_timeout", 180),
        },
        # Keyword arguments shared by every select_additional_images call.
        "image_kwargs": {
            "num_additional_images": config.get("additional_images_count", 3),
            "variation_strength": image_selection_config.get("variation_strength", "medium"),
            "extra_effects": image_selection_config.get("extra_effects", True),
        },
    }


def _create_ai_components(config: Dict) -> Optional[Tuple[Any, Any]]:
    """Create the (AI_Agent, PromptManager) pair, or return None on failure."""
    try:
        prompt_manager = PromptManager(
            topic_system_prompt_path=config.get("topic_system_prompt"),
            topic_user_prompt_path=config.get("topic_user_prompt"),
            content_system_prompt_path=config.get("content_system_prompt"),
            prompts_config=config.get("prompts_config"),
            resource_dir_config=config.get("resource_dir", []),
            topic_gen_num=config.get("num", 1),
            topic_gen_date=config.get("date", ""),
            content_judger_system_prompt_path=config.get("content_judger_system_prompt"),
        )
        logging.info("PromptManager实例创建成功")
    except Exception as e:
        logging.error(f"创建PromptManager实例失败: {e}")
        return None

    try:
        ai_agent = AI_Agent(
            config.get("api_url"),
            config.get("model"),
            config.get("api_key"),
            timeout=config.get("request_timeout", 180),
            max_retries=config.get("max_retries", 3),
            stream_chunk_timeout=config.get("stream_chunk_timeout", 30),
        )
        logging.info("AI代理创建成功")
    except Exception as e:
        logging.error(f"创建AI代理失败: {e}")
        return None

    return ai_agent, prompt_manager


def _copy_existing_content(run_dir: str, topic_index: int, variant_index: int, variants_count: int) -> bool:
    """Copy tweet_content.json from a sibling variant of the same topic.

    Returns True when a sibling content file was found and copied.
    """
    variant_key = f"{topic_index}_{variant_index}"

    for existing_variant in range(1, variants_count + 1):
        if existing_variant == variant_index:
            continue

        existing_content = os.path.join(
            run_dir, f"{topic_index}_{existing_variant}", "tweet_content.json"
        )
        if not os.path.exists(existing_content):
            continue

        # Create the variant directory and copy the content file into it.
        new_variant_dir = os.path.join(run_dir, variant_key)
        os.makedirs(new_variant_dir, exist_ok=True)
        shutil.copy2(existing_content, os.path.join(new_variant_dir, "tweet_content.json"))
        logging.info(f"已从变体 {topic_index}_{existing_variant} 复制内容文件到 {variant_key}")
        return True

    return False


def _generate_poster(
    topic_item: Dict,
    run_id: str,
    topic_index: int,
    variant_index: int,
    output_handler: Any,
    config: Dict,
    output_dir: str,
    poster_kwargs: Dict[str, Any],
) -> bool:
    """Generate the poster of a single variant; return True on success."""
    variant_key = f"{topic_index}_{variant_index}"

    try:
        posters_attempted = generate_posters_for_topic(
            topic_item=topic_item,
            # Same resolved directory the output handler was created with.
            output_dir=output_dir,
            run_id=run_id,
            topic_index=topic_index,
            output_handler=output_handler,
            model_name=config["model"],
            base_url=config["api_url"],
            api_key=config["api_key"],
            variants=1,  # Generate exactly one variant
            variant_start_index=variant_index,  # Start at the requested variant index
            **poster_kwargs,
        )
    except Exception as e:
        logging.exception(f"生成海报时出错: {e}")
        return False

    if posters_attempted:
        logging.info(f"变体 {variant_key} 海报生成完成")
        return True

    logging.warning(f"变体 {variant_key} 海报生成失败")
    return False


def _generate_additional_images(
    topic_item: Dict,
    run_id: str,
    topic_index: int,
    variant_index: int,
    poster_dir: str,
    image_base_dir: str,
    output_handler: Any,
    image_kwargs: Dict[str, Any],
    supplement: bool = False,
) -> None:
    """Generate the additional images of a single variant from its poster metadata.

    ``supplement`` only selects the wording of the progress log message.
    Failures are logged and never raised.
    """
    variant_key = f"{topic_index}_{variant_index}"

    if not os.path.exists(poster_dir):
        logging.warning(f"海报目录不存在: {poster_dir},无法生成配图")
        return

    try:
        # Sorted so the chosen metadata file does not depend on listdir order.
        metadata_files = sorted(
            name for name in os.listdir(poster_dir)
            if name.endswith("_metadata.json") and os.path.isfile(os.path.join(poster_dir, name))
        )
    except OSError as e:
        logging.warning(f"访问海报目录时出错: {e}")
        return

    if not metadata_files:
        logging.warning("未找到海报元数据文件,无法生成额外配图")
        return

    poster_metadata_path = os.path.join(poster_dir, metadata_files[0])
    action = "补充生成额外配图" if supplement else "生成额外配图"
    logging.info(f"为变体 {variant_key} {action},使用元数据: {poster_metadata_path}")

    try:
        # The source image folder is derived from the topic's "object" field:
        # drop everything after the first dot and the "景点信息-" prefix.
        object_name = topic_item.get("object", "").split(".")[0].replace("景点信息-", "").strip()
        source_image_dir = os.path.join(image_base_dir, object_name)

        if not os.path.isdir(source_image_dir):
            logging.warning(f"源图像目录不存在: {source_image_dir}")
            return

        additional_paths = select_additional_images(
            run_id=run_id,
            topic_index=topic_index,
            variant_index=variant_index,
            poster_metadata_path=poster_metadata_path,
            source_image_dir=source_image_dir,
            output_handler=output_handler,
            **image_kwargs,
        )

        if additional_paths:
            logging.info(f"已为变体 {variant_key} 生成 {len(additional_paths)} 张额外配图")
        else:
            logging.warning(f"未能为变体 {variant_key} 生成任何额外配图")
    except Exception as e:
        logging.exception(f"生成额外配图时出错: {e}")


def regenerate_missing_content(run_dir: str, config: Dict):
    """Regenerate missing posters and additional images, including fully missing variants.

    Args:
        run_dir: Directory of a previous run. Its base name is used as the run ID.
        config: Parsed configuration dictionary.

    Raises:
        OSError: If the poster content system prompt file cannot be read.
    """
    # normpath strips a trailing separator, which would otherwise make
    # basename() return an empty run ID.
    run_id = os.path.basename(os.path.normpath(run_dir))
    logging.info(f"处理运行ID: {run_id}")

    # Create the output handler
    # NOTE(review): generated files go through the output handler rooted at
    # output_dir, while status is read from run_dir — presumably run_dir is
    # <output_dir>/<run_id>; confirm against FileSystemOutputHandler.
    output_dir = config.get("output_dir", "result")
    output_handler = FileSystemOutputHandler(output_dir)

    # Locate the topics file, falling back to the un-suffixed name
    topics_file = os.path.join(run_dir, f"tweet_topic_{run_id}.json")
    if not os.path.exists(topics_file):
        alternative_topics_file = os.path.join(run_dir, "tweet_topic.json")
        if not os.path.exists(alternative_topics_file):
            logging.error(f"主题文件不存在: {topics_file} 或 {alternative_topics_file}")
            return
        topics_file = alternative_topics_file

    logging.info(f"使用主题文件: {topics_file}")
    topics_list = TopicParser.load_topics_from_json(topics_file)
    if not topics_list:
        logging.error("无法加载主题列表")
        return

    # Map topic index -> topic item; topics without a truthy index are ignored
    topic_map = {topic.get('index'): topic for topic in topics_list if topic.get('index')}
    topic_indices = list(topic_map)
    variants_count = config.get("variants", 1)
    poster_subdir = config.get("output_poster_subdir", "poster")

    logging.info(f"找到 {len(topic_indices)} 个主题,每个主题应有 {variants_count} 个变体")

    # Check the status of all variants and work out what has to be done
    content_status = check_content_status(run_dir, topic_indices, variants_count, poster_subdir)
    missing_variants, incomplete_variants = _classify_variants(content_status)

    logging.info(f"需要创建 {len(missing_variants)} 个完全缺失的变体,补充 {len(incomplete_variants)} 个不完整的变体")

    # Validate the settings before creating the AI agent so that an invalid
    # config never leaves an agent open.
    settings = _load_generation_settings(config, poster_subdir)
    if settings is None:
        return
    poster_kwargs = settings["poster_kwargs"]
    image_kwargs = settings["image_kwargs"]
    img_base_dir = settings["image_base_dir"]

    # The AI agent is only needed when new content has to be created
    ai_agent = None
    prompt_manager = None
    if missing_variants:
        components = _create_ai_components(config)
        if components is None:
            return
        ai_agent, prompt_manager = components

    try:
        # 1. Handle completely missing variants
        for topic_index, variant_index in missing_variants:
            variant_key = f"{topic_index}_{variant_index}"
            logging.info(f"创建完全缺失的变体: {variant_key}")

            topic_item = topic_map.get(topic_index)
            if not topic_item:
                logging.error(f"找不到主题 {topic_index} 的数据,跳过")
                continue

            # 1.1 Reuse the content file of a sibling variant when one exists;
            # 1.2 otherwise generate new content for this variant only.
            if not _copy_existing_content(run_dir, topic_index, variant_index, variants_count):
                logging.info(f"为变体 {variant_key} 生成新内容...")

                content_success = generate_content_for_topic(
                    ai_agent,
                    prompt_manager,
                    topic_item,
                    run_id,
                    topic_index,
                    output_handler,
                    variants=1,
                    variant_start_index=variant_index,
                    temperature=config.get("content_temperature", 0.3),
                    top_p=config.get("content_top_p", 0.4),
                    presence_penalty=config.get("content_presence_penalty", 1.5),
                    enable_content_judge=config.get("enable_content_judge", False),
                )

                if not content_success:
                    logging.error(f"为变体 {variant_key} 生成内容失败,跳过后续处理")
                    continue

                logging.info(f"已为变体 {variant_key} 生成内容")

            # 1.3 Generate the poster
            logging.info(f"为变体 {variant_key} 生成海报...")
            if not _generate_poster(
                topic_item, run_id, topic_index, variant_index,
                output_handler, config, output_dir, poster_kwargs,
            ):
                continue

            # 1.4 Generate the additional images
            _generate_additional_images(
                topic_item, run_id, topic_index, variant_index,
                os.path.join(run_dir, variant_key, poster_subdir),
                img_base_dir, output_handler, image_kwargs,
            )

        # 2. Handle variants that exist but are incomplete
        for topic_index, variant_index in incomplete_variants:
            variant_key = f"{topic_index}_{variant_index}"
            topic_item = topic_map.get(topic_index)

            if not topic_item:
                logging.error(f"找不到主题 {topic_index} 的数据,跳过")
                continue

            status = content_status[topic_index][variant_index]

            # 2.1 Generate the poster when it is missing
            if not status["has_poster"]:
                logging.info(f"为变体 {variant_key} 补充生成海报...")
                if _generate_poster(
                    topic_item, run_id, topic_index, variant_index,
                    output_handler, config, output_dir, poster_kwargs,
                ):
                    # Record the new poster so step 2.2 can run
                    status["has_poster"] = True

            # 2.2 Generate additional images when a poster exists but they are missing
            if status["has_poster"] and not status["has_additional"]:
                _generate_additional_images(
                    topic_item, run_id, topic_index, variant_index,
                    os.path.join(run_dir, variant_key, poster_subdir),
                    img_base_dir, output_handler, image_kwargs,
                    supplement=True,
                )
    finally:
        # Close the AI agent even when processing raised
        if ai_agent:
            ai_agent.close()
            logging.info("AI代理已关闭")


def main():
    """Parse the command line, load the config and regenerate the missing content."""
    parser = argparse.ArgumentParser(description="补充生成丢失的海报和配图")
    parser.add_argument("--run_dir", required=True, help="之前运行结果的目录")
    parser.add_argument("--config", default="poster_gen_config.json", help="配置文件路径")
    parser.add_argument("--debug", action="store_true", help="启用调试日志")
    args = parser.parse_args()

    # Configure the log level
    log_level = logging.DEBUG if args.debug else logging.INFO
    logging.basicConfig(
        level=log_level,
        format='%(asctime)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )

    # Load the configuration
    config = load_config(args.config)
    if not config:
        sys.exit(1)

    # The run directory must be an existing directory
    if not os.path.isdir(args.run_dir):
        logging.error(f"指定的运行目录不存在: {args.run_dir}")
        sys.exit(1)

    # Start processing
    logging.info(f"开始为 {args.run_dir} 补充生成内容")
    regenerate_missing_content(args.run_dir, config)
    logging.info("处理完成")


if __name__ == "__main__":
    main()