TravelContentCreator/utils/tweet/topic_generator.py

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
选题生成模块
"""
import logging
import json
import re
from typing import List, Dict, Any, Optional
from json_repair import loads as json_repair_loads
from core.ai import AIAgent
from core.config import GenerateTopicConfig, ConfigManager
from utils.prompts import TopicPromptBuilder
from utils.file_io import OutputManager
from .topic_parser import TopicParser
logger = logging.getLogger(__name__)


class TopicGenerator:
    """
    Responsible for generating, parsing, and saving topics.
    """

    def __init__(self, ai_agent: AIAgent, config_manager: ConfigManager, output_manager: OutputManager):
        """
        Initialize the topic generator.

        Args:
            ai_agent: AI agent
            config_manager: Configuration manager
            output_manager: Output manager
        """
        self.ai_agent = ai_agent
        self.config: GenerateTopicConfig = config_manager.get_config('topic_gen', GenerateTopicConfig)
        self.output_manager = output_manager
        self.parser = TopicParser()
        self.prompt_builder = TopicPromptBuilder(config_manager)
        logger.info(f"选题生成配置: {self.config}")

    async def generate_topics(self) -> Optional[List[Dict[str, Any]]]:
        """
        Run the full topic-generation pipeline: build prompts -> call the AI ->
        parse the result -> save artifacts.
        """
        logger.info("开始执行选题生成流程...")

        # 1. Build the prompts
        system_prompt = self.prompt_builder.get_system_prompt()
        user_prompt = self.prompt_builder.build_user_prompt(
            num_topics=self.config.topic.num,
            month=self.config.topic.date
        )
        self.output_manager.save_text(system_prompt, "topic_system_prompt.txt")
        self.output_manager.save_text(user_prompt, "topic_user_prompt.txt")

        # Forward only the sampling parameters that are explicitly set in the config
        model_params = {}
        if hasattr(self.config, 'model') and isinstance(self.config.model, dict):
            model_params = {
                'temperature': self.config.model.get('temperature'),
                'top_p': self.config.model.get('top_p'),
                'presence_penalty': self.config.model.get('presence_penalty')
            }
            # Drop unset (None) values
            model_params = {k: v for k, v in model_params.items() if v is not None}

        # 2. Call the AI to generate the topics
        try:
            raw_result, _, _, _ = await self.ai_agent.generate_text(
                system_prompt=system_prompt,
                user_prompt=user_prompt,
                use_stream=False,  # Topic generation normally does not need streaming output
                stage="选题生成",
                **model_params
            )
            self.output_manager.save_text(raw_result, "topics_raw_response.txt")
        except Exception as e:
            logger.critical(f"AI调用失败无法生成选题: {e}", exc_info=True)
            return None

        # 3. Parse the result
        topics = self.parser.parse(raw_result)
        if not topics:
            logger.error("未能从AI响应中解析出任何有效选题")
            return None

        # 4. Save the parsed result
        self.output_manager.save_json(topics, "topics_generated.json")

        # 5. (Optional) Also save a human-readable .txt version
        topics_text = self._format_topics_to_text(topics)
        self.output_manager.save_text(topics_text, "topics_generated.txt")

        logger.info(f"选题生成流程成功完成,共生成 {len(topics)} 个选题。")
        return topics

    def _format_topics_to_text(self, topics: List[Dict[str, Any]]) -> str:
        """Format the topic list as human-readable text."""
        text_parts = [f"# 选题列表 (Run ID: {self.output_manager.run_id})\n"]
        for topic in topics:
            text_parts.append(f"## 选题 {topic.get('index', 'N/A')}")
            text_parts.append(f"- 日期: {topic.get('date', 'N/A')}")
            text_parts.append(f"- 对象: {topic.get('object', 'N/A')}")
            text_parts.append(f"- 产品: {topic.get('product', 'N/A')}")
            text_parts.append(f"- 风格: {topic.get('style', 'N/A')}")
            text_parts.append(f"- 目标受众: {topic.get('target_audience', 'N/A')}")
            text_parts.append(f"- 逻辑: {topic.get('logic', 'N/A')}")
        return "\n".join(text_parts)