#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ 测试海报生成的脚本 """ import os import time import argparse import json import logging import sys import traceback from datetime import datetime from core.ai_agent import AI_Agent from utils.tweet_generator import generate_posters_for_topic from utils.output_handler import FileSystemOutputHandler from core.topic_parser import TopicParser def load_config(config_path="poster_config.json"): """从JSON文件加载配置""" if not os.path.exists(config_path): print(f"错误:配置文件 '{config_path}' 未找到。") sys.exit(1) try: with open(config_path, 'r', encoding='utf-8') as f: config = json.load(f) # 基本验证 required_keys = ["api_url", "model", "api_key", "resource_dir", "output_dir", "image_base_dir", "poster_assets_base_dir", "poster_content_system_prompt"] if not all(key in config for key in required_keys): missing_keys = [key for key in required_keys if key not in config] print(f"错误:配置文件 '{config_path}' 缺少必需的键:{missing_keys}") sys.exit(1) return config except json.JSONDecodeError: print(f"错误:无法从 '{config_path}' 解码JSON。请检查文件格式。") sys.exit(1) except Exception as e: print(f"从 '{config_path}' 加载配置时出错:{e}") sys.exit(1) def main(): # 设置日志记录 logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s', datefmt='%Y-%m-%d %H:%M:%S' ) # 解析命令行参数 parser = argparse.ArgumentParser(description="测试海报生成") parser.add_argument( "--config", type=str, default="/root/autodl-tmp/TravelContentCreator/poster_gen_config.json", help="配置文件路径(例如,poster_config.json)" ) parser.add_argument( "--topics_file", type=str, required=True, default="/root/autodl-tmp/TravelContentCreator/result/2025-04-25_17-27-03/tweet_topic_2025-04-25_17-27-03.json", help="必需的选题JSON文件路径,用于获取海报生成的主题数据。" ) parser.add_argument( "--topic_index", type=int, default=1, help="要生成海报的特定选题索引。如果未提供,将为所有选题生成海报。" ) parser.add_argument( "--run_id", type=str, default=None, help="可选的指定运行ID。如果未提供,将生成基于时间戳的ID。" ) parser.add_argument( "--debug", action='store_true', help="启用调试级别日志记录。" ) args = parser.parse_args() # 调整日志级别(如果启用了调试) if args.debug: logging.getLogger().setLevel(logging.DEBUG) logging.info("已启用调试日志记录。") logging.info("启动海报生成测试脚本...") logging.info(f"使用配置文件:{args.config}") logging.info(f"使用选题文件:{args.topics_file}") if args.topic_index is not None: logging.info(f"将仅处理选题索引:{args.topic_index}") if args.run_id: logging.info(f"使用指定的run_id:{args.run_id}") # 加载配置 config = load_config(args.config) if config is None: logging.critical("无法加载配置。退出。") sys.exit(1) # 初始化输出处理器 output_handler = FileSystemOutputHandler(config.get("output_dir", "result")) logging.info(f"使用输出处理器:{output_handler.__class__.__name__}") # 加载选题数据 logging.info(f"从以下位置加载选题:{args.topics_file}") topics_list = TopicParser.load_topics_from_json(args.topics_file) if not topics_list: logging.error(f"无法从{args.topics_file}加载选题。无法继续。") sys.exit(1) logging.info(f"成功加载{len(topics_list)}个选题。") # 设置run_id run_id = args.run_id if run_id is None: # 尝试从文件名推断run_id try: base = os.path.basename(args.topics_file) if base.startswith("tweet_topic_") and base.endswith(".json"): run_id = base[len("tweet_topic_"): -len(".json")] logging.info(f"从选题文件名推断的run_id:{run_id}") else: run_id = datetime.now().strftime("%Y-%m-%d_%H-%M-%S_poster") logging.info(f"为海报生成的run_id:{run_id}") except Exception as e: run_id = datetime.now().strftime("%Y-%m-%d_%H-%M-%S_poster") logging.info(f"生成的run_id:{run_id}") # 加载海报内容系统提示词 poster_content_system_prompt_path = config.get("poster_content_system_prompt") if not os.path.exists(poster_content_system_prompt_path): logging.error(f"海报内容系统提示词文件不存在:{poster_content_system_prompt_path}") sys.exit(1) with open(poster_content_system_prompt_path, "r", encoding="utf-8") as f: poster_content_system_prompt = f.read() # 准备海报生成参数 poster_variants = config.get("variants", 1) poster_assets_dir = config.get("poster_assets_base_dir") img_base_dir = config.get("image_base_dir") res_dir_config = config.get("resource_dir", []) poster_size = tuple(config.get("poster_target_size", [900, 1200])) txt_possibility = config.get("text_possibility", 0.3) img_frame_possibility = config.get("img_frame_possibility", 0.7) text_bg_possibility = config.get("text_bg_possibility", 0) collage_subdir = config.get("output_collage_subdir", "collage_img") poster_subdir = config.get("output_poster_subdir", "poster") poster_filename = config.get("output_poster_filename", "poster.jpg") title_possibility = config.get("title_possibility", 0.3) collage_style = config.get("collage_style", None) timeout = config.get("request_timeout", 180) # 检查关键路径 if not poster_assets_dir or not img_base_dir: logging.error("配置中缺少关键路径(poster_assets_base_dir或image_base_dir)。无法继续。") sys.exit(1) # 开始海报生成 pipeline_start_time = time.time() logging.info("开始执行海报生成...") poster_success = False # 如果指定了topic_index,只处理该选题 if args.topic_index is not None: topics_to_process = [] for topic in topics_list: if topic.get('index') == args.topic_index or (topic.get('index') is None and int(args.topic_index) == 1): topics_to_process.append(topic) break if not topics_to_process: logging.error(f"未找到索引为{args.topic_index}的选题。") sys.exit(1) else: topics_to_process = topics_list # 逐个处理选题 for i, topic_item in enumerate(topics_to_process): topic_index = topic_item.get('index', i + 1) logging.info(f"--- 处理选题 {topic_index}: {topic_item.get('object', 'N/A')} ---") try: posters_attempted = generate_posters_for_topic( topic_item=topic_item, output_dir=config["output_dir"], run_id=run_id, topic_index=topic_index, output_handler=output_handler, variants=poster_variants, poster_assets_base_dir=poster_assets_dir, image_base_dir=img_base_dir, resource_dir_config=res_dir_config, poster_target_size=poster_size, text_possibility=txt_possibility, img_frame_possibility=img_frame_possibility, text_bg_possibility=text_bg_possibility, output_collage_subdir=collage_subdir, output_poster_subdir=poster_subdir, output_poster_filename=poster_filename, system_prompt=poster_content_system_prompt, collage_style=collage_style, title_possibility=title_possibility, timeout=timeout ) if posters_attempted: logging.info(f"选题{topic_index}的海报生成过程已完成。") poster_success = True else: logging.warning(f"选题{topic_index}的海报生成被跳过或在早期失败。") except Exception as e: logging.exception(f"处理选题{topic_index}的海报生成时出错:") logging.error(f"错误信息: {e}") logging.info(f"--- 完成选题 {topic_index} ---") # 最终化输出 if run_id: output_handler.finalize(run_id) pipeline_end_time = time.time() if poster_success: logging.info(f"海报生成完成,耗时{pipeline_end_time - pipeline_start_time:.2f}秒。") else: logging.warning("海报生成完成,但可能遇到错误或未生成输出。") logging.info(f"运行ID'{run_id}'的结果位于:{os.path.join(config.get('output_dir', 'result'), run_id)}") if __name__ == "__main__": main()