# TravelContentCreator/test_topic_content.py
#
# 261 lines
# 11 KiB
# Python
# Raw Blame History
#
# This file contains ambiguous Unicode characters
#
# This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
测试选题生成和文章生成的脚本
"""
import os
import time
import argparse
import json
import logging
import sys
from datetime import datetime
from core.ai_agent import AI_Agent
from utils.prompt_manager import PromptManager
from utils.tweet_generator import run_topic_generation_pipeline, generate_content_for_topic
from utils.output_handler import FileSystemOutputHandler
def _exit_with_error(message):
    """Print *message* to stdout and terminate the process with exit status 1."""
    print(message)
    sys.exit(1)


def load_config(config_path="topic_content_config.json"):
    """Load the JSON configuration file and validate its required keys.

    Args:
        config_path: Path to the JSON configuration file.

    Returns:
        The parsed configuration as a dict.

    Raises:
        SystemExit: With status 1 (after printing an error message) when the
            file is missing, cannot be read or decoded, is not a JSON object,
            lacks a required key, or contains neither ``prompts_dir`` nor
            ``prompts_config``.
    """
    if not os.path.exists(config_path):
        _exit_with_error(f"错误:配置文件 '{config_path}' 未找到。")

    required_keys = ["api_url", "model", "api_key", "resource_dir", "output_dir",
                     "num", "variants", "topic_system_prompt", "topic_user_prompt",
                     "content_system_prompt"]
    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            config = json.load(f)

        # A top-level list/string would otherwise be checked by membership and
        # could slip through validation even though callers need a mapping.
        if not isinstance(config, dict):
            _exit_with_error(f"错误:配置文件 '{config_path}' 的顶层必须是JSON对象。")

        # Basic validation: a single scan collects every missing key.
        missing_keys = [key for key in required_keys if key not in config]
        if missing_keys:
            _exit_with_error(f"错误:配置文件 '{config_path}' 缺少必需的键:{missing_keys}")

        # At least one of prompts_dir / prompts_config must be present.
        if "prompts_dir" not in config and "prompts_config" not in config:
            _exit_with_error(
                f"错误:配置文件 '{config_path}' 必须包含 'prompts_dir' 或 'prompts_config'"
            )
        return config
    except json.JSONDecodeError:
        _exit_with_error(f"错误:无法从 '{config_path}' 解码JSON。请检查文件格式。")
    except Exception as e:
        # SystemExit is not an Exception subclass, so the exits above are not
        # intercepted here; this only handles I/O and other unexpected errors.
        _exit_with_error(f"'{config_path}' 加载配置时出错:{e}")
def _parse_cli_args():
    """Define the command-line interface and parse the process arguments."""
    parser = argparse.ArgumentParser(description="测试选题和文章生成")
    parser.add_argument(
        "--config",
        type=str,
        default="topic_content_config.json",
        help="配置文件路径例如topic_content_config.json"
    )
    parser.add_argument(
        "--run_id",
        type=str,
        default=None,
        help="可选的指定运行ID例如'test_run_01'。如果未提供将生成基于时间戳的ID。"
    )
    parser.add_argument(
        "--topics_file",
        type=str,
        default=None,
        help="可选的预生成选题JSON文件路径。如果提供则跳过选题生成。"
    )
    parser.add_argument(
        "--debug",
        action='store_true',
        help="启用调试级别日志记录。"
    )
    return parser.parse_args()


def _infer_run_id_from_topics_file(topics_file):
    """Return the run_id embedded in a ``tweet_topic_{run_id}.json`` name, else None."""
    prefix, suffix = "tweet_topic_", ".json"
    try:
        base = os.path.basename(topics_file)
        if base.startswith(prefix) and base.endswith(suffix):
            run_id = base[len(prefix): -len(suffix)]
            logging.info(f"从选题文件名推断的run_id{run_id}")
            return run_id
        if base == "tweet_topic.json":
            logging.warning(f"从默认文件名'{base}'加载选题。未推断run_id。")
        else:
            logging.warning(f"无法从选题文件名推断run_id{base}")
    except Exception as e:
        # Best effort only: a failure here must not abort the run.
        logging.warning(f"尝试从选题文件名推断run_id时出错{e}")
    return None


def _run_content_generation(config, run_id, topics_list, output_handler):
    """Run step 2: generate content for every topic in *topics_list*."""
    logging.info("执行内容生成步骤2...")
    step2_start = time.time()

    # Create the PromptManager instance.
    try:
        prompt_manager = PromptManager(
            topic_system_prompt_path=config.get("topic_system_prompt"),
            topic_user_prompt_path=config.get("topic_user_prompt"),
            content_system_prompt_path=config.get("content_system_prompt"),
            prompts_config=config.get("prompts_config"),
            prompts_dir=config.get("prompts_dir"),
            resource_dir_config=config.get("resource_dir", []),
            topic_gen_num=config.get("num", 1),
            topic_gen_date=config.get("date", "")
        )
        logging.info("已为步骤2创建PromptManager实例。")
    except KeyError as e:
        logging.error(f"创建PromptManager时配置错误缺少键'{e}'。无法继续内容生成。")
        sys.exit(1)

    ai_agent = None
    content_success = False
    try:
        ai_agent = AI_Agent(
            config["api_url"],
            config["model"],
            config["api_key"],
            timeout=config.get("request_timeout", 30),
            max_retries=config.get("max_retries", 3),
            stream_chunk_timeout=config.get("stream_chunk_timeout", 60)
        )
        logging.info("已初始化用于内容生成的AI Agent。")

        # Generation parameters do not vary per topic, so read them once.
        content_variants = config.get("variants", 1)
        content_temp = config.get("content_temperature", 0.3)
        content_top_p = config.get("content_top_p", 0.4)
        content_presence_penalty = config.get("content_presence_penalty", 1.5)
        total_topics = len(topics_list)

        for i, topic_item in enumerate(topics_list):
            # Fall back to the 1-based position when the topic has no 'index'.
            topic_index = topic_item.get('index', i + 1)
            logging.info(f"--- 处理选题 {topic_index}/{total_topics}: {topic_item.get('object', 'N/A')} ---")
            topic_success = generate_content_for_topic(
                ai_agent,
                prompt_manager,
                topic_item,
                run_id,
                topic_index,
                output_handler,
                variants=content_variants,
                temperature=content_temp,
                top_p=content_top_p,
                presence_penalty=content_presence_penalty
            )
            if topic_success:
                logging.info(f"选题{topic_index}的内容生成成功。")
                # One successful topic is enough to mark the step as successful.
                content_success = True
            else:
                logging.warning(f"选题{topic_index}的内容生成失败或未产生有效结果。")
            logging.info(f"--- 完成选题 {topic_index} ---")
    except KeyError as e:
        # logging.exception records the traceback; the previous
        # traceback.print_exc() call referenced a module that was never imported.
        logging.exception(f"内容生成过程中的配置错误:缺少键'{e}'")
    except Exception:
        logging.exception("内容生成过程中发生意外错误:")
    finally:
        # Make sure the AI agent is closed even when generation failed.
        if ai_agent:
            logging.info("关闭内容生成AI Agent...")
            ai_agent.close()

    step2_end = time.time()
    if content_success:
        logging.info(f"步骤2完成耗时{step2_end - step2_start:.2f}秒。")
    else:
        logging.warning("步骤2完成但可能遇到错误或未生成输出。")


def main():
    """Run the topic-generation and content-generation test pipeline.

    Step 1 either loads topics from ``--topics_file`` or calls
    ``run_topic_generation_pipeline``; step 2 generates content for each
    topic. The process exits with status 1 when the configuration or the
    topics cannot be obtained.
    """
    # Configure logging.
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )
    args = _parse_cli_args()

    # Raise verbosity when debugging was requested.
    if args.debug:
        logging.getLogger().setLevel(logging.DEBUG)
        logging.info("已启用调试日志记录。")
    logging.info("启动选题和文章生成测试脚本...")
    logging.info(f"使用配置文件:{args.config}")
    if args.run_id:
        logging.info(f"使用指定的run_id{args.run_id}")
    if args.topics_file:
        logging.info(f"使用现有选题文件:{args.topics_file}")

    # Load the configuration; the None check is a defensive guard.
    config = load_config(args.config)
    if config is None:
        logging.critical("无法加载配置。退出。")
        sys.exit(1)

    # Initialise the output handler.
    output_handler = FileSystemOutputHandler(config.get("output_dir", "result"))
    logging.info(f"使用输出处理器:{output_handler.__class__.__name__}")

    run_id = args.run_id
    topics_list = None
    pipeline_start_time = time.time()

    # Step 1: generate topics, or load existing ones.
    if args.topics_file:
        from core.topic_parser import TopicParser
        logging.info(f"跳过选题生成步骤1- 从以下位置加载选题:{args.topics_file}")
        topics_list = TopicParser.load_topics_from_json(args.topics_file)
        if not topics_list:
            logging.error(f"无法从{args.topics_file}加载选题。无法继续。")
            sys.exit(1)

        # Without an explicit run_id, try to infer one from the file name.
        if not run_id:
            inferred_run_id = _infer_run_id_from_topics_file(args.topics_file)
            if inferred_run_id is not None:
                run_id = inferred_run_id
        # Still nothing: fall back to a timestamp-based id.
        if run_id is None:
            run_id = datetime.now().strftime("%Y-%m-%d_%H-%M-%S_loaded")
            logging.info(f"为加载的选题生成的run_id{run_id}")
        # The prompts used to create a loaded topics file are not available.
        logging.info(f"成功加载{len(topics_list)}个选题run_id{run_id}。提示词不可用。")
    else:
        logging.info("执行选题生成步骤1...")
        step1_start = time.time()
        run_id, topics_list, system_prompt, user_prompt = run_topic_generation_pipeline(config, args.run_id)
        step1_end = time.time()
        if run_id is None or topics_list is None:
            logging.critical("选题生成步骤1失败。退出。")
            sys.exit(1)
        logging.info(f"步骤1成功完成耗时{step1_end - step1_start:.2f}秒。运行ID{run_id}")
        # Persist the step 1 results through the output handler.
        output_handler.handle_topic_results(run_id, topics_list, system_prompt, user_prompt)

    # Step 2: content generation.
    if run_id is not None and topics_list is not None:
        _run_content_generation(config, run_id, topics_list, output_handler)
    else:
        logging.error("无法进行步骤2步骤1的run_id或topics_list无效。")

    # Finalise the output.
    if run_id:
        output_handler.finalize(run_id)
    pipeline_end_time = time.time()
    logging.info(f"流程完成。总执行时间:{pipeline_end_time - pipeline_start_time:.2f}秒。")
    logging.info(f"运行ID'{run_id}'的结果位于:{os.path.join(config.get('output_dir', 'result'), run_id)}")
# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()