111 lines
4.3 KiB
Python
111 lines
4.3 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
|
||
"""
|
||
选题生成模块
|
||
"""
|
||
|
||
import logging
|
||
import json
|
||
import re
|
||
from typing import List, Dict, Any, Optional
|
||
from json_repair import loads as json_repair_loads
|
||
|
||
from core.ai import AIAgent
|
||
from core.config import GenerateTopicConfig, ConfigManager
|
||
from utils.prompts import TopicPromptBuilder
|
||
from utils.file_io import OutputManager
|
||
from .topic_parser import TopicParser
|
||
|
||
# Module-level logger, named after this module (``__name__``) so that
# logging configuration can target it through the logger hierarchy.
logger = logging.getLogger(name=__name__)
|
||
|
||
|
||
class TopicGenerator:
    """Generate, parse, and persist topics (选题) using an AI agent.

    Pipeline: build prompts -> call the AI -> parse the raw response ->
    save artifacts through the ``OutputManager``.
    """

    def __init__(self, ai_agent: AIAgent, config_manager: ConfigManager, output_manager: OutputManager):
        """Initialize the topic generator.

        Args:
            ai_agent: Agent used to call the language model.
            config_manager: Provides the ``'topic_gen'`` section (loaded as
                ``GenerateTopicConfig``) and is handed to the prompt builder.
            output_manager: Destination for prompts, the raw AI response,
                and the parsed/formatted results.
        """
        self.ai_agent = ai_agent
        self.config: GenerateTopicConfig = config_manager.get_config('topic_gen', GenerateTopicConfig)
        self.output_manager = output_manager
        self.parser = TopicParser()
        self.prompt_builder = TopicPromptBuilder(config_manager)
        logger.info("选题生成配置: %s", self.config)

    async def generate_topics(self) -> Optional[List[Dict[str, Any]]]:
        """Run the full topic-generation pipeline.

        Steps: build prompts -> call the AI -> parse the result -> save artifacts.

        Returns:
            The list of parsed topic dicts, or ``None`` if the AI call raised
            or the parser produced no topics. Exceptions raised while saving
            artifacts are not caught here and propagate to the caller.
        """
        logger.info("开始执行选题生成流程...")

        # 1. Build the prompts and persist them for later inspection.
        system_prompt = self.prompt_builder.get_system_prompt()
        user_prompt = self.prompt_builder.build_user_prompt(
            num_topics=self.config.topic.num,
            month=self.config.topic.date,
        )
        self.output_manager.save_text(system_prompt, "topic_system_prompt.txt")
        self.output_manager.save_text(user_prompt, "topic_user_prompt.txt")

        model_params = self._build_model_params()

        # 2. Call the AI. Only the model call is guarded, so a failure while
        #    saving the raw response is not misreported as an AI failure.
        try:
            raw_result, _, _, _ = await self.ai_agent.generate_text(
                system_prompt=system_prompt,
                user_prompt=user_prompt,
                # NOTE(review): streaming is enabled although the previous
                # comment said topic generation usually does not need
                # streaming output; the value is kept as-is — confirm intent.
                use_stream=True,
                stage="选题生成",
                **model_params,
            )
        except Exception as e:
            logger.critical("AI调用失败,无法生成选题: %s", e, exc_info=True)
            return None
        self.output_manager.save_text(raw_result, "topics_raw_response.txt")

        # 3. Parse the raw response into structured topics.
        topics = self.parser.parse(raw_result)
        if not topics:
            logger.error("未能从AI响应中解析出任何有效选题")
            return None

        # 4. Save the parsed result as JSON.
        self.output_manager.save_json(topics, "topics_generated.json")

        # 5. Also save a human-readable text rendering.
        topics_text = self._format_topics_to_text(topics)
        self.output_manager.save_text(topics_text, "topics_generated.txt")

        logger.info("选题生成流程成功完成,共生成 %s 个选题。", len(topics))
        return topics

    def _build_model_params(self) -> Dict[str, Any]:
        """Collect optional sampling parameters from ``self.config.model``.

        Only used when ``config.model`` exists and is a dict. Keys that are
        missing or ``None`` are dropped (presumably so the agent's own
        defaults apply — confirm against ``AIAgent.generate_text``).
        """
        model_cfg = getattr(self.config, 'model', None)
        if not isinstance(model_cfg, dict):
            return {}
        return {
            key: model_cfg.get(key)
            for key in ('temperature', 'top_p', 'presence_penalty')
            if model_cfg.get(key) is not None
        }

    def _format_topics_to_text(self, topics: List[Dict[str, Any]]) -> str:
        """Render the topic list as human-readable, Markdown-style text.

        Missing fields are shown as ``'N/A'``.
        """
        # (display label, topic dict key), in output order.
        fields = (
            ('日期', 'date'),
            ('对象', 'object'),
            ('产品', 'product'),
            ('风格', 'style'),
            ('目标受众', 'target_audience'),
            ('逻辑', 'logic'),
        )
        text_parts = [f"# 选题列表 (Run ID: {self.output_manager.run_id})\n"]
        for topic in topics:
            text_parts.append(f"## 选题 {topic.get('index', 'N/A')}")
            text_parts.extend(f"- {label}: {topic.get(key, 'N/A')}" for label, key in fields)
        return "\n".join(text_parts)
|