#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Script for testing topic generation followed by article (content) generation.

Step 1 either generates topics via ``run_topic_generation_pipeline`` or, when
``--topics_file`` is given, loads pre-generated topics from a JSON file.
Step 2 generates content for every topic via ``generate_content_for_topic``.
Results are written through a ``FileSystemOutputHandler`` rooted at the
configured ``output_dir`` (default ``"result"``).
"""

import os
import time
import argparse
import json
import logging
import sys
from datetime import datetime

from core.ai_agent import AI_Agent
from utils.prompt_manager import PromptManager
from utils.tweet_generator import run_topic_generation_pipeline, generate_content_for_topic
from utils.output_handler import FileSystemOutputHandler


def load_config(config_path="topic_content_config.json"):
    """Load and validate the JSON configuration file.

    Args:
        config_path: Path to the JSON configuration file.

    Returns:
        The parsed configuration as a dict.

    Never returns on failure: prints an error message and calls
    ``sys.exit(1)`` if the file is missing or unreadable, is not valid JSON,
    is not a JSON object, lacks a required key, or defines neither
    ``prompts_dir`` nor ``prompts_config``.
    """
    if not os.path.exists(config_path):
        print(f"错误:配置文件 '{config_path}' 未找到。")
        sys.exit(1)

    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            config = json.load(f)
    except json.JSONDecodeError:
        print(f"错误:无法从 '{config_path}' 解码JSON。请检查文件格式。")
        sys.exit(1)
    except Exception as e:
        print(f"从 '{config_path}' 加载配置时出错:{e}")
        sys.exit(1)

    # The key checks below assume a mapping. A top-level JSON list could
    # otherwise "contain" the required keys as plain strings and slip through.
    if not isinstance(config, dict):
        print(f"错误:配置文件 '{config_path}' 的顶层必须是JSON对象。")
        sys.exit(1)

    required_keys = ["api_url", "model", "api_key", "resource_dir", "output_dir",
                     "num", "variants", "topic_system_prompt", "topic_user_prompt",
                     "content_system_prompt"]
    missing_keys = [key for key in required_keys if key not in config]
    if missing_keys:
        print(f"错误:配置文件 '{config_path}' 缺少必需的键:{missing_keys}")
        sys.exit(1)

    # At least one prompt source must be configured.
    if "prompts_dir" not in config and "prompts_config" not in config:
        print(f"错误:配置文件 '{config_path}' 必须包含 'prompts_dir' 或 'prompts_config'")
        sys.exit(1)

    return config


def _parse_args():
    """Build the CLI parser and return the parsed arguments."""
    parser = argparse.ArgumentParser(description="测试选题和文章生成")
    parser.add_argument(
        "--config",
        type=str,
        default="topic_content_config.json",
        help="配置文件路径(例如,topic_content_config.json)"
    )
    parser.add_argument(
        "--run_id",
        type=str,
        default=None,
        help="可选的指定运行ID(例如,'test_run_01')。如果未提供,将生成基于时间戳的ID。"
    )
    parser.add_argument(
        "--topics_file",
        type=str,
        default=None,
        help="可选的预生成选题JSON文件路径。如果提供,则跳过选题生成。"
    )
    parser.add_argument(
        "--debug",
        action='store_true',
        help="启用调试级别日志记录。"
    )
    return parser.parse_args()


def _infer_run_id_from_topics_file(topics_file):
    """Try to derive a run ID from a ``tweet_topic_{run_id}.json`` file name.

    Returns the inferred run ID, or None when the file name does not follow
    that pattern (a warning is logged in that case).
    """
    prefix, suffix = "tweet_topic_", ".json"
    try:
        base = os.path.basename(topics_file)
        if base.startswith(prefix) and base.endswith(suffix):
            run_id = base[len(prefix):-len(suffix)]
            logging.info(f"从选题文件名推断的run_id:{run_id}")
            return run_id
        if base == "tweet_topic.json":
            logging.warning(f"从默认文件名'{base}'加载选题。未推断run_id。")
        else:
            logging.warning(f"无法从选题文件名推断run_id:{base}")
    except Exception as e:
        # Inference is best-effort; a timestamp ID is used as the fallback.
        logging.warning(f"尝试从选题文件名推断run_id时出错:{e}")
    return None


def _load_topics_from_file(topics_file, run_id):
    """Step 1 (skipped): load pre-generated topics and settle on a run ID.

    Args:
        topics_file: Path to the topics JSON file.
        run_id: Run ID requested on the command line, or None.

    Returns:
        A ``(run_id, topics_list)`` tuple. Exits the process if no topics
        could be loaded.
    """
    # Imported lazily: only needed when loading topics from disk.
    from core.topic_parser import TopicParser

    logging.info(f"跳过选题生成(步骤1)- 从以下位置加载选题:{topics_file}")
    topics_list = TopicParser.load_topics_from_json(topics_file)
    if not topics_list:
        logging.error(f"无法从{topics_file}加载选题。无法继续。")
        sys.exit(1)

    if not run_id:
        run_id = _infer_run_id_from_topics_file(topics_file)
    # Covers both "nothing inferred" and an empty inferred/requested ID.
    if not run_id:
        run_id = datetime.now().strftime("%Y-%m-%d_%H-%M-%S_loaded")
        logging.info(f"为加载的选题生成的run_id:{run_id}")

    # Prompts used to create these topics are not stored in the file.
    logging.info(f"成功加载{len(topics_list)}个选题,run_id:{run_id}。提示词不可用。")
    return run_id, topics_list


def _generate_topics(config, requested_run_id, output_handler):
    """Step 1: generate topics with the AI pipeline and persist the results.

    Returns:
        A ``(run_id, topics_list)`` tuple. Exits the process if generation
        fails (the pipeline returns None for either value).
    """
    logging.info("执行选题生成(步骤1)...")
    step1_start = time.time()
    run_id, topics_list, system_prompt, user_prompt = run_topic_generation_pipeline(config, requested_run_id)
    step1_end = time.time()

    if run_id is None or topics_list is None:
        logging.critical("选题生成(步骤1)失败。退出。")
        sys.exit(1)

    logging.info(f"步骤1成功完成,耗时{step1_end - step1_start:.2f}秒。运行ID:{run_id}")
    output_handler.handle_topic_results(run_id, topics_list, system_prompt, user_prompt)
    return run_id, topics_list


def _run_content_generation(config, run_id, topics_list, output_handler):
    """Step 2: generate content for every topic.

    Builds a PromptManager and an AI_Agent from ``config``, then calls
    ``generate_content_for_topic`` once per topic. Errors during generation
    are logged with a traceback rather than propagated. The agent is always
    closed. Exits the process if the PromptManager cannot be built because
    of a missing key.
    """
    logging.info("执行内容生成(步骤2)...")
    step2_start = time.time()

    try:
        prompt_manager = PromptManager(
            topic_system_prompt_path=config.get("topic_system_prompt"),
            topic_user_prompt_path=config.get("topic_user_prompt"),
            content_system_prompt_path=config.get("content_system_prompt"),
            prompts_config=config.get("prompts_config"),
            prompts_dir=config.get("prompts_dir"),
            resource_dir_config=config.get("resource_dir", []),
            topic_gen_num=config.get("num", 1),
            topic_gen_date=config.get("date", "")
        )
        logging.info("已为步骤2创建PromptManager实例。")
    except KeyError as e:
        logging.error(f"创建PromptManager时配置错误:缺少键'{e}'。无法继续内容生成。")
        sys.exit(1)

    # These parameters are identical for every topic, so read them once.
    content_variants = config.get("variants", 1)
    content_temp = config.get("content_temperature", 0.3)
    content_top_p = config.get("content_top_p", 0.4)
    content_presence_penalty = config.get("content_presence_penalty", 1.5)

    ai_agent = None
    content_success = False
    try:
        ai_agent = AI_Agent(
            config["api_url"],
            config["model"],
            config["api_key"],
            timeout=config.get("request_timeout", 30),
            max_retries=config.get("max_retries", 3),
            stream_chunk_timeout=config.get("stream_chunk_timeout", 60)
        )
        logging.info("已初始化用于内容生成的AI Agent。")

        total_topics = len(topics_list)
        for i, topic_item in enumerate(topics_list):
            # Prefer the topic's own index; fall back to its 1-based position.
            topic_index = topic_item.get('index', i + 1)
            logging.info(f"--- 处理选题 {topic_index}/{total_topics}: {topic_item.get('object', 'N/A')} ---")

            topic_success = generate_content_for_topic(
                ai_agent,
                prompt_manager,
                topic_item,
                run_id,
                topic_index,
                output_handler,
                variants=content_variants,
                temperature=content_temp,
                top_p=content_top_p,
                presence_penalty=content_presence_penalty
            )

            if topic_success:
                logging.info(f"选题{topic_index}的内容生成成功。")
                # Step 2 counts as successful if at least one topic succeeded.
                content_success = True
            else:
                logging.warning(f"选题{topic_index}的内容生成失败或未产生有效结果。")

            logging.info(f"--- 完成选题 {topic_index} ---")

    except KeyError as e:
        # logging.exception also records the traceback (the original called
        # traceback.print_exc() without importing traceback -> NameError).
        logging.exception(f"内容生成过程中的配置错误:缺少键'{e}'")
    except Exception:
        logging.exception("内容生成过程中发生意外错误:")
    finally:
        # Always release the agent, even after a failure mid-loop.
        if ai_agent is not None:
            logging.info("关闭内容生成AI Agent...")
            ai_agent.close()

    step2_end = time.time()
    if content_success:
        logging.info(f"步骤2完成,耗时{step2_end - step2_start:.2f}秒。")
    else:
        logging.warning("步骤2完成,但可能遇到错误或未生成输出。")


def main():
    """CLI entry point: load config, run step 1 (topics), then step 2 (content)."""
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )

    args = _parse_args()

    if args.debug:
        logging.getLogger().setLevel(logging.DEBUG)
        logging.info("已启用调试日志记录。")

    logging.info("启动选题和文章生成测试脚本...")
    logging.info(f"使用配置文件:{args.config}")
    if args.run_id:
        logging.info(f"使用指定的run_id:{args.run_id}")
    if args.topics_file:
        logging.info(f"使用现有选题文件:{args.topics_file}")

    # load_config() exits the process on any error, so it always yields a dict.
    config = load_config(args.config)

    output_dir = config.get("output_dir", "result")
    output_handler = FileSystemOutputHandler(output_dir)
    logging.info(f"使用输出处理器:{output_handler.__class__.__name__}")

    pipeline_start_time = time.time()

    # Step 1: topic generation, or loading existing topics from a file.
    if args.topics_file:
        run_id, topics_list = _load_topics_from_file(args.topics_file, args.run_id)
    else:
        run_id, topics_list = _generate_topics(config, args.run_id, output_handler)

    # Step 2: content generation.
    if run_id is not None and topics_list is not None:
        _run_content_generation(config, run_id, topics_list, output_handler)
    else:
        logging.error("无法进行步骤2:步骤1的run_id或topics_list无效。")

    if run_id:
        output_handler.finalize(run_id)

    pipeline_end_time = time.time()
    logging.info(f"流程完成。总执行时间:{pipeline_end_time - pipeline_start_time:.2f}秒。")
    # Guarded: os.path.join would raise on a None run_id.
    if run_id:
        logging.info(f"运行ID'{run_id}'的结果位于:{os.path.join(output_dir, run_id)}")


if __name__ == "__main__":
    main()