TravelContentCreator/test_topic_content.py

261 lines
11 KiB
Python
Raw Normal View History

2025-04-25 10:11:45 +08:00
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
测试选题生成和文章生成的脚本
"""
import os
import time
import argparse
import json
import logging
import sys
from datetime import datetime
from core.ai_agent import AI_Agent
from utils.prompt_manager import PromptManager
from utils.tweet_generator import run_topic_generation_pipeline, generate_content_for_topic
from utils.output_handler import FileSystemOutputHandler
def load_config(config_path="topic_content_config.json"):
    """Load and validate the JSON configuration used by this test script.

    Args:
        config_path: Path to the JSON configuration file.

    Returns:
        The decoded JSON configuration, once every required key is present
        and at least one of ``prompts_dir`` / ``prompts_config`` is defined.

    Raises:
        SystemExit: With status 1 when the file does not exist, cannot be
            decoded as JSON, lacks a required key, lacks both prompt
            settings, or any other error occurs while loading. A message
            describing the problem is printed before exiting.
    """
    if not os.path.exists(config_path):
        print(f"错误:配置文件 '{config_path}' 未找到。")
        sys.exit(1)

    required_keys = ["api_url", "model", "api_key", "resource_dir", "output_dir",
                     "num", "variants", "topic_system_prompt", "topic_user_prompt",
                     "content_system_prompt"]
    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            config = json.load(f)

        # Basic validation: collect the missing keys once and report them all.
        missing_keys = [key for key in required_keys if key not in config]
        if missing_keys:
            print(f"错误:配置文件 '{config_path}' 缺少必需的键:{missing_keys}")
            sys.exit(1)

        # At least one of prompts_dir / prompts_config must be provided.
        if not ("prompts_dir" in config or "prompts_config" in config):
            print(f"错误:配置文件 '{config_path}' 必须包含 'prompts_dir' 或 'prompts_config'")
            sys.exit(1)

        return config
    except json.JSONDecodeError:
        print(f"错误:无法从 '{config_path}' 解码JSON。请检查文件格式。")
        sys.exit(1)
    except Exception as e:
        # SystemExit is not an Exception subclass, so the exits above pass
        # through this handler untouched.
        print(f"'{config_path}' 加载配置时出错:{e}")
        sys.exit(1)
def _build_arg_parser():
    """Build the command-line parser for the topic/content test script."""
    parser = argparse.ArgumentParser(description="测试选题和文章生成")
    parser.add_argument(
        "--config",
        type=str,
        default="topic_content_config.json",
        help="配置文件路径例如topic_content_config.json"
    )
    parser.add_argument(
        "--run_id",
        type=str,
        default=None,
        help="可选的指定运行ID例如'test_run_01'。如果未提供将生成基于时间戳的ID。"
    )
    parser.add_argument(
        "--topics_file",
        type=str,
        default=None,
        help="可选的预生成选题JSON文件路径。如果提供则跳过选题生成。"
    )
    parser.add_argument(
        "--debug",
        action='store_true',
        help="启用调试级别日志记录。"
    )
    return parser


def _infer_run_id_from_topics_file(topics_file):
    """Return the run_id embedded in a topics file name, or None if there is none.

    Recognises names of the form ``tweet_topic_{run_id}.json``; every other
    name only produces a warning.
    """
    try:
        base = os.path.basename(topics_file)
        if base.startswith("tweet_topic_") and base.endswith(".json"):
            inferred = base[len("tweet_topic_"): -len(".json")]
            logging.info(f"从选题文件名推断的run_id{inferred}")
            return inferred
        if base == "tweet_topic.json":
            logging.warning(f"从默认文件名'{base}'加载选题。未推断run_id。")
        else:
            logging.warning(f"无法从选题文件名推断run_id{base}")
    except Exception as e:
        logging.warning(f"尝试从选题文件名推断run_id时出错{e}")
    return None


def _run_content_generation(config, run_id, topics_list, output_handler):
    """Run step 2 for every topic; return True if at least one topic succeeded.

    Exits the process with status 1 if the PromptManager cannot be created
    because of a KeyError. Errors raised while generating content are logged
    and do not propagate; the AI agent is always closed afterwards.
    """
    try:
        prompt_manager = PromptManager(
            topic_system_prompt_path=config.get("topic_system_prompt"),
            topic_user_prompt_path=config.get("topic_user_prompt"),
            content_system_prompt_path=config.get("content_system_prompt"),
            prompts_config=config.get("prompts_config"),
            prompts_dir=config.get("prompts_dir"),
            resource_dir_config=config.get("resource_dir", []),
            topic_gen_num=config.get("num", 1),
            topic_gen_date=config.get("date", "")
        )
        logging.info("已为步骤2创建PromptManager实例。")
    except KeyError as e:
        logging.error(f"创建PromptManager时配置错误缺少键'{e}'。无法继续内容生成。")
        sys.exit(1)

    ai_agent = None
    content_success = False
    try:
        request_timeout = config.get("request_timeout", 30)
        max_retries = config.get("max_retries", 3)
        stream_chunk_timeout = config.get("stream_chunk_timeout", 60)
        ai_agent = AI_Agent(
            config["api_url"],
            config["model"],
            config["api_key"],
            timeout=request_timeout,
            max_retries=max_retries,
            stream_chunk_timeout=stream_chunk_timeout
        )
        logging.info("已初始化用于内容生成的AI Agent。")

        # Generation parameters do not change between topics, so read them once.
        content_variants = config.get("variants", 1)
        content_temp = config.get("content_temperature", 0.3)
        content_top_p = config.get("content_top_p", 0.4)
        content_presence_penalty = config.get("content_presence_penalty", 1.5)

        for i, topic_item in enumerate(topics_list):
            # Prefer the topic's own index; fall back to its 1-based position.
            topic_index = topic_item.get('index', i + 1)
            logging.info(f"--- 处理选题 {topic_index}/{len(topics_list)}: {topic_item.get('object', 'N/A')} ---")

            topic_success = generate_content_for_topic(
                ai_agent,
                prompt_manager,
                topic_item,
                run_id,
                topic_index,
                output_handler,
                variants=content_variants,
                temperature=content_temp,
                top_p=content_top_p,
                presence_penalty=content_presence_penalty
            )
            if topic_success:
                logging.info(f"选题{topic_index}的内容生成成功。")
                content_success = True
            else:
                logging.warning(f"选题{topic_index}的内容生成失败或未产生有效结果。")
            logging.info(f"--- 完成选题 {topic_index} ---")
    except KeyError as e:
        # logging.exception records the traceback too; the previous code
        # called traceback.print_exc() without importing traceback, which
        # raised NameError inside this handler.
        logging.exception(f"内容生成过程中的配置错误:缺少键'{e}'")
    except Exception:
        logging.exception("内容生成过程中发生意外错误:")
    finally:
        # Make sure the AI agent is closed even when generation fails.
        if ai_agent:
            logging.info("关闭内容生成AI Agent...")
            ai_agent.close()
    return content_success


def main():
    """Run the topic-generation and content-generation test pipeline.

    Step 1 either loads topics from ``--topics_file`` or generates them with
    ``run_topic_generation_pipeline``. Step 2 generates content for each
    topic. Results go through a ``FileSystemOutputHandler``.

    Raises:
        SystemExit: With status 1 when the configuration cannot be loaded,
            topics cannot be loaded or generated, or the PromptManager
            cannot be created.
    """
    # Configure logging.
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )

    # Parse command-line arguments.
    args = _build_arg_parser().parse_args()

    # Raise verbosity when debugging is requested.
    if args.debug:
        logging.getLogger().setLevel(logging.DEBUG)
        logging.info("已启用调试日志记录。")

    logging.info("启动选题和文章生成测试脚本...")
    logging.info(f"使用配置文件:{args.config}")
    if args.run_id:
        logging.info(f"使用指定的run_id{args.run_id}")
    if args.topics_file:
        logging.info(f"使用现有选题文件:{args.topics_file}")

    # Load configuration.
    config = load_config(args.config)
    if config is None:
        logging.critical("无法加载配置。退出。")
        sys.exit(1)

    # Initialise the output handler.
    output_handler = FileSystemOutputHandler(config.get("output_dir", "result"))
    logging.info(f"使用输出处理器:{output_handler.__class__.__name__}")

    run_id = args.run_id
    topics_list = None
    system_prompt = None
    user_prompt = None
    pipeline_start_time = time.time()

    # Step 1: generate topics, or load existing ones.
    if args.topics_file:
        from core.topic_parser import TopicParser

        logging.info(f"跳过选题生成步骤1- 从以下位置加载选题:{args.topics_file}")
        topics_list = TopicParser.load_topics_from_json(args.topics_file)
        if topics_list:
            # No run_id supplied: try to infer one from the file name.
            if not run_id:
                inferred = _infer_run_id_from_topics_file(args.topics_file)
                if inferred is not None:
                    run_id = inferred
            # Still no run_id after inference: generate a timestamped one.
            if run_id is None:
                run_id = datetime.now().strftime("%Y-%m-%d_%H-%M-%S_loaded")
                logging.info(f"为加载的选题生成的run_id{run_id}")
            # Prompts are not available when topics come from a file.
            system_prompt = ""
            user_prompt = ""
            logging.info(f"成功加载{len(topics_list)}个选题run_id{run_id}。提示词不可用。")
        else:
            logging.error(f"无法从{args.topics_file}加载选题。无法继续。")
            sys.exit(1)
    else:
        logging.info("执行选题生成步骤1...")
        step1_start = time.time()
        run_id, topics_list, system_prompt, user_prompt = run_topic_generation_pipeline(config, args.run_id)
        step1_end = time.time()
        if run_id is not None and topics_list is not None:
            logging.info(f"步骤1成功完成耗时{step1_end - step1_start:.2f}秒。运行ID{run_id}")
            # Persist the topic results through the output handler.
            output_handler.handle_topic_results(run_id, topics_list, system_prompt, user_prompt)
        else:
            logging.critical("选题生成步骤1失败。退出。")
            sys.exit(1)

    # Step 2: content generation.
    if run_id is not None and topics_list is not None:
        logging.info("执行内容生成步骤2...")
        step2_start = time.time()
        content_success = _run_content_generation(config, run_id, topics_list, output_handler)
        step2_end = time.time()
        if content_success:
            logging.info(f"步骤2完成耗时{step2_end - step2_start:.2f}秒。")
        else:
            logging.warning("步骤2完成但可能遇到错误或未生成输出。")
    else:
        logging.error("无法进行步骤2步骤1的run_id或topics_list无效。")

    # Finalise output.
    if run_id:
        output_handler.finalize(run_id)

    pipeline_end_time = time.time()
    logging.info(f"流程完成。总执行时间:{pipeline_end_time - pipeline_start_time:.2f}秒。")
    logging.info(f"运行ID'{run_id}'的结果位于:{os.path.join(config.get('output_dir', 'result'), run_id)}")
# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()