261 lines
11 KiB
Python
261 lines
11 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
测试选题生成和文章生成的脚本
|
||
"""
|
||
|
||
import os
|
||
import time
|
||
import argparse
|
||
import json
|
||
import logging
|
||
import sys
|
||
from datetime import datetime
|
||
|
||
from core.ai_agent import AI_Agent
|
||
from utils.prompt_manager import PromptManager
|
||
from utils.tweet_generator import run_topic_generation_pipeline, generate_content_for_topic
|
||
from utils.output_handler import FileSystemOutputHandler
|
||
|
||
def load_config(config_path="topic_content_config.json"):
    """Load and validate the JSON configuration file for this script.

    The process prints a message to stderr and exits with status 1 when the
    file does not exist, cannot be read or parsed, does not contain a JSON
    object at the top level, lacks any required key, or has neither
    ``prompts_dir`` nor ``prompts_config``.

    Args:
        config_path: Path to the UTF-8 encoded JSON configuration file.

    Returns:
        dict: The parsed configuration mapping.
    """
    if not os.path.exists(config_path):
        print(f"错误:配置文件 '{config_path}' 未找到。", file=sys.stderr)
        sys.exit(1)

    # Only reading/parsing can raise here; keep the try narrow so the
    # validation below is not routed through the generic handler.
    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            config = json.load(f)
    except json.JSONDecodeError:
        print(f"错误:无法从 '{config_path}' 解码JSON。请检查文件格式。", file=sys.stderr)
        sys.exit(1)
    except Exception as e:
        print(f"从 '{config_path}' 加载配置时出错:{e}", file=sys.stderr)
        sys.exit(1)

    # Valid JSON may still be a list/number/string; every caller uses
    # config.get(...) / config[...], so require a JSON object.
    if not isinstance(config, dict):
        print(f"错误:配置文件 '{config_path}' 的顶层必须是JSON对象。", file=sys.stderr)
        sys.exit(1)

    # Basic validation: these keys must all be present.
    required_keys = ["api_url", "model", "api_key", "resource_dir", "output_dir",
                     "num", "variants", "topic_system_prompt", "topic_user_prompt",
                     "content_system_prompt"]
    missing_keys = [key for key in required_keys if key not in config]
    if missing_keys:
        print(f"错误:配置文件 '{config_path}' 缺少必需的键:{missing_keys}", file=sys.stderr)
        sys.exit(1)

    # At least one prompt source (directory or explicit config) is required.
    if "prompts_dir" not in config and "prompts_config" not in config:
        print(f"错误:配置文件 '{config_path}' 必须包含 'prompts_dir' 或 'prompts_config'", file=sys.stderr)
        sys.exit(1)

    return config
|
||
|
||
def _parse_args():
    """Define this script's command-line options and parse ``sys.argv``."""
    parser = argparse.ArgumentParser(description="测试选题和文章生成")
    parser.add_argument(
        "--config",
        type=str,
        default="topic_content_config.json",
        help="配置文件路径(例如,topic_content_config.json)"
    )
    parser.add_argument(
        "--run_id",
        type=str,
        default=None,
        help="可选的指定运行ID(例如,'test_run_01')。如果未提供,将生成基于时间戳的ID。"
    )
    parser.add_argument(
        "--topics_file",
        type=str,
        default=None,
        help="可选的预生成选题JSON文件路径。如果提供,则跳过选题生成。"
    )
    parser.add_argument(
        "--debug",
        action='store_true',
        help="启用调试级别日志记录。"
    )
    return parser.parse_args()


def _infer_run_id_from_filename(topics_file):
    """Return the non-empty run ID encoded as ``tweet_topic_{run_id}.json``, else None."""
    prefix, suffix = "tweet_topic_", ".json"
    base = os.path.basename(topics_file)
    if base.startswith(prefix) and base.endswith(suffix):
        run_id = base[len(prefix):-len(suffix)]
        if run_id:
            logging.info(f"从选题文件名推断的run_id:{run_id}")
            return run_id
        # "tweet_topic_.json" would yield an empty ID; treat it as not inferable
        # so the caller falls back to a generated one.
        logging.warning(f"无法从选题文件名推断run_id:{base}")
    elif base == "tweet_topic.json":
        logging.warning(f"从默认文件名'{base}'加载选题。未推断run_id。")
    else:
        logging.warning(f"无法从选题文件名推断run_id:{base}")
    return None


def _load_topics_from_file(topics_file, run_id):
    """Load pre-generated topics (skipping step 1); return (run_id, topics_list) or exit(1)."""
    # Imported lazily because it is only needed when step 1 is skipped.
    from core.topic_parser import TopicParser

    logging.info(f"跳过选题生成(步骤1)- 从以下位置加载选题:{topics_file}")
    topics_list = TopicParser.load_topics_from_json(topics_file)
    if not topics_list:
        logging.error(f"无法从{topics_file}加载选题。无法继续。")
        sys.exit(1)

    # Prefer the caller's run ID, then one encoded in the filename, then a timestamp.
    if not run_id:
        run_id = _infer_run_id_from_filename(topics_file)
    if not run_id:
        run_id = datetime.now().strftime("%Y-%m-%d_%H-%M-%S_loaded")
        logging.info(f"为加载的选题生成的run_id:{run_id}")

    # Prompts used to generate the loaded topics are not available here.
    logging.info(f"成功加载{len(topics_list)}个选题,run_id:{run_id}。提示词不可用。")
    return run_id, topics_list


def _generate_topics(config, run_id, output_handler):
    """Run topic generation (step 1), save its results, and return (run_id, topics_list) or exit(1)."""
    logging.info("执行选题生成(步骤1)...")
    step1_start = time.time()
    run_id, topics_list, system_prompt, user_prompt = run_topic_generation_pipeline(config, run_id)
    step1_end = time.time()

    if run_id is None or topics_list is None:
        logging.critical("选题生成(步骤1)失败。退出。")
        sys.exit(1)

    logging.info(f"步骤1成功完成,耗时{step1_end - step1_start:.2f}秒。运行ID:{run_id}")
    output_handler.handle_topic_results(run_id, topics_list, system_prompt, user_prompt)
    return run_id, topics_list


def _generate_content(config, run_id, topics_list, output_handler):
    """Run content generation (step 2) for every topic; return True if any topic succeeded."""
    logging.info("执行内容生成(步骤2)...")
    step2_start = time.time()

    try:
        prompt_manager = PromptManager(
            topic_system_prompt_path=config.get("topic_system_prompt"),
            topic_user_prompt_path=config.get("topic_user_prompt"),
            content_system_prompt_path=config.get("content_system_prompt"),
            prompts_config=config.get("prompts_config"),
            prompts_dir=config.get("prompts_dir"),
            resource_dir_config=config.get("resource_dir", []),
            topic_gen_num=config.get("num", 1),
            topic_gen_date=config.get("date", "")
        )
        logging.info("已为步骤2创建PromptManager实例。")
    except KeyError as e:
        logging.error(f"创建PromptManager时配置错误:缺少键'{e}'。无法继续内容生成。")
        sys.exit(1)

    # Sampling parameters are the same for every topic, so read them once.
    content_variants = config.get("variants", 1)
    content_temp = config.get("content_temperature", 0.3)
    content_top_p = config.get("content_top_p", 0.4)
    content_presence_penalty = config.get("content_presence_penalty", 1.5)

    ai_agent = None
    content_success = False
    try:
        ai_agent = AI_Agent(
            config["api_url"],
            config["model"],
            config["api_key"],
            timeout=config.get("request_timeout", 30),
            max_retries=config.get("max_retries", 3),
            stream_chunk_timeout=config.get("stream_chunk_timeout", 60)
        )
        logging.info("已初始化用于内容生成的AI Agent。")

        total_topics = len(topics_list)
        for i, topic_item in enumerate(topics_list):
            # Prefer the topic's own index; fall back to its 1-based position.
            topic_index = topic_item.get('index', i + 1)
            logging.info(f"--- 处理选题 {topic_index}/{total_topics}: {topic_item.get('object', 'N/A')} ---")

            topic_success = generate_content_for_topic(
                ai_agent,
                prompt_manager,
                topic_item,
                run_id,
                topic_index,
                output_handler,
                variants=content_variants,
                temperature=content_temp,
                top_p=content_top_p,
                presence_penalty=content_presence_penalty
            )

            if topic_success:
                logging.info(f"选题{topic_index}的内容生成成功。")
                content_success = True
            else:
                logging.warning(f"选题{topic_index}的内容生成失败或未产生有效结果。")
            logging.info(f"--- 完成选题 {topic_index} ---")

    except KeyError as e:
        # logging.exception logs at ERROR level and appends the traceback
        # (the traceback module is not imported in this file).
        logging.exception(f"内容生成过程中的配置错误:缺少键'{e}'")
    except Exception:
        logging.exception("内容生成过程中发生意外错误:")
    finally:
        # Always release the agent, even if generation failed part-way.
        if ai_agent is not None:
            logging.info("关闭内容生成AI Agent...")
            ai_agent.close()

    step2_end = time.time()
    if content_success:
        logging.info(f"步骤2完成,耗时{step2_end - step2_start:.2f}秒。")
    else:
        logging.warning("步骤2完成,但可能遇到错误或未生成输出。")
    return content_success


def main():
    """Run topic generation (or load topics from a file), then content generation.

    Command-line options: ``--config``, ``--run_id``, ``--topics_file`` and
    ``--debug``. Exits with status 1 if the config cannot be loaded or if
    step 1 yields no topics. Failures during step 2 are logged but do not
    change the exit status.
    """
    # Logging setup.
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )

    args = _parse_args()

    # Raise verbosity if requested.
    if args.debug:
        logging.getLogger().setLevel(logging.DEBUG)
        logging.info("已启用调试日志记录。")

    logging.info("启动选题和文章生成测试脚本...")
    logging.info(f"使用配置文件:{args.config}")
    if args.run_id:
        logging.info(f"使用指定的run_id:{args.run_id}")
    if args.topics_file:
        logging.info(f"使用现有选题文件:{args.topics_file}")

    config = load_config(args.config)
    if config is None:
        # Defensive: load_config exits on failure.
        logging.critical("无法加载配置。退出。")
        sys.exit(1)

    output_dir = config.get("output_dir", "result")
    output_handler = FileSystemOutputHandler(output_dir)
    logging.info(f"使用输出处理器:{output_handler.__class__.__name__}")

    pipeline_start_time = time.time()

    # Step 1: topic generation, or load existing topics. Both helpers exit on failure.
    if args.topics_file:
        run_id, topics_list = _load_topics_from_file(args.topics_file, args.run_id)
    else:
        run_id, topics_list = _generate_topics(config, args.run_id, output_handler)

    # Step 2: content generation.
    _generate_content(config, run_id, topics_list, output_handler)

    # Finalize output.
    if run_id:
        output_handler.finalize(run_id)

    pipeline_end_time = time.time()
    logging.info(f"流程完成。总执行时间:{pipeline_end_time - pipeline_start_time:.2f}秒。")
    logging.info(f"运行ID'{run_id}'的结果位于:{os.path.join(output_dir, run_id)}")
|
||
|
||
if __name__ == "__main__":
    # Script entry point. main() returns None, so this exits with status 0
    # unless main() itself calls sys.exit() with another code.
    sys.exit(main())