124 lines
5.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
内容生成流程管理器
"""
import logging
import json
from typing import List, Dict, Any, Optional
from datetime import datetime
from core.config import (
ConfigManager,
GenerateTopicConfig,
GenerateContentConfig,
SystemConfig,
ResourceConfig,
AIModelConfig
)
from core.ai import AIAgent
from utils.file_io import OutputManager
from utils.tweet.topic_generator import TopicGenerator
from utils.tweet.content_generator import ContentGenerator
from utils.tweet.content_judger import ContentJudger
# from utils.poster_generator import PosterGenerator # 待实现
logger = logging.getLogger(__name__)
class PipelineManager:
    """Coordinate the whole content generation pipeline.

    The pipeline runs three stages in order: topic generation, per-topic
    content generation and, when enabled in the content configuration,
    content judging.
    """

    def __init__(self, config_dir: str, run_id: Optional[str] = None):
        """Load the configuration and build every pipeline component.

        Args:
            config_dir: Directory handed to
                ``ConfigManager.load_from_directory``.
            run_id: Identifier of this run. When falsy, one is generated
                from the current local time as ``run_YYYYmmdd_HHMMSS``.
        """
        # 1. Load configuration.
        self.config_manager = ConfigManager()
        self.config_manager.load_from_directory(config_dir)

        # 2. Fetch the configuration section of each module.
        self.ai_config = self.config_manager.get_config('ai_model', AIModelConfig)
        self.system_config = self.config_manager.get_config('system', SystemConfig)
        self.resource_config = self.config_manager.get_config('resource', ResourceConfig)
        self.topic_config = self.config_manager.get_config('topic_gen', GenerateTopicConfig)
        self.content_config = self.config_manager.get_config('content_gen', GenerateContentConfig)

        # 3. Initialise the run id.
        self.run_id = run_id or f"run_{datetime.now().strftime('%Y%m%d_%H%M%S')}"

        # 4. Initialise the components.
        self.ai_agent = AIAgent(self.ai_config)
        # Pass only the base_dir string, not the whole output config object.
        self.output_manager = OutputManager(self.resource_config.output_dir.base_dir, self.run_id)
        self.topic_generator = TopicGenerator(self.ai_agent, self.config_manager, self.output_manager)
        self.content_generator = ContentGenerator(self.ai_agent, self.config_manager, self.output_manager)
        self.content_judger = ContentJudger(self.ai_agent, self.config_manager, self.output_manager)
        # TODO: wire up a PosterGenerator here once it is implemented.

    @staticmethod
    def _failure_reason(result: Any) -> Optional[str]:
        """Return a failure description for a stage result, or None on success.

        A result counts as failed when it is ``None`` or when it is a dict
        carrying an ``"error"`` key. Anything else (for example a plain
        string) is a success; in particular a string is never searched for
        the substring "error".
        """
        if result is None:
            return "返回结果为空"
        if isinstance(result, dict) and "error" in result:
            return str(result["error"])
        return None

    async def process_content_generation(self, topics: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Generate content for every topic.

        Args:
            topics: Topic dicts; the optional ``'index'`` key is used only
                for logging.

        Returns:
            One ``{"topic": topic, "content": content_data}`` entry per topic
            whose generation succeeded, in input order. Topics whose
            generation raised or reported an error are logged and skipped.
        """
        logger.info("\n--- 步骤 2: 开始为每个选题生成内容 ---")
        generated_contents = []
        for topic in topics:
            topic_index = topic.get('index', 'N/A')
            logger.info(f"--- 正在处理选题 {topic_index} ---")
            # Only the generator call is guarded, so a bug in the result
            # handling below is not misreported as a generator failure.
            try:
                content_data = await self.content_generator.generate_content_for_topic(topic)
            except Exception as e:
                logger.critical(f"为选题 {topic_index} 处理内容生成时发生意外错误: {e}", exc_info=True)
                continue

            reason = self._failure_reason(content_data)
            if reason is None:
                generated_contents.append({"topic": topic, "content": content_data})
            else:
                logger.error(f"为选题 {topic_index} 生成内容失败: {reason}")
        return generated_contents

    async def process_content_judging(self, generated_contents: List[Dict[str, Any]]):
        """Run the content judger over every generated item.

        Does nothing except log when ``content_config.enable_content_judge``
        is falsy. Failures of individual items are logged and do not stop
        the remaining items from being judged.

        Args:
            generated_contents: Entries with ``"topic"`` and ``"content"``
                keys, as produced by ``process_content_generation``.
        """
        if not self.content_config.enable_content_judge:
            logger.info("内容审核已禁用,跳过此步骤。")
            return

        for item in generated_contents:
            topic = item["topic"]
            content = item["content"]
            topic_index = topic.get('index', 'N/A')
            logger.info(f"--- 步骤 3: 开始审核选题 {topic_index} 的内容 ---")
            try:
                # content may be a string or a dict; per the original note
                # the judger is expected to cope with both.
                judged_data = await self.content_judger.judge_content(content, topic)
            except Exception as e:
                logger.critical(f"为选题 {topic_index} 处理内容审核时发生意外错误: {e}", exc_info=True)
                continue

            reason = self._failure_reason(judged_data)
            if reason is not None:
                # Surface judger-reported failures instead of dropping them.
                logger.error(f"为选题 {topic_index} 审核内容失败: {reason}")
            # NOTE(review): a successful judged_data is not stored here;
            # presumably the judger persists it itself -- confirm.

    async def run_pipeline(self):
        """Run topic generation, content generation and judging in order.

        Stops early, after logging an error, when no topics are produced.
        Judging is skipped with a warning when no content was generated.
        """
        logger.info("--- 步骤 1: 开始生成选题 ---")
        topics = await self.topic_generator.generate_topics()
        if not topics:
            logger.error("未能生成任何选题,流程终止。")
            return

        # Step 2: content generation.
        generated_contents = await self.process_content_generation(topics)

        # Step 3: content judging.
        if generated_contents:
            await self.process_content_judging(generated_contents)
        else:
            logger.warning("没有成功生成任何内容,审核步骤将跳过。")

        logger.info("--- 所有任务已完成 ---")