# TravelContentCreator/tweet/topic_generator.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
选题生成模块
"""
import logging
from typing import Dict, Any, List, Optional, Tuple
from core.ai import AIAgent
from core.config import ConfigManager, GenerateTopicConfig
from utils.prompts import TopicPromptBuilder
from utils.file_io import OutputManager, process_llm_json_text
from .topic_parser import TopicParser
logger = logging.getLogger(__name__)
class TopicGenerator:
    """Generate travel-related topics via an LLM.

    Orchestrates the full pipeline: build (or accept) prompts, call the AI
    agent, parse the raw response into structured topics, and persist every
    intermediate artifact through the output manager.
    """

    def __init__(self, ai_agent: AIAgent, config_manager: ConfigManager, output_manager: OutputManager):
        """Initialize the topic generator.

        Args:
            ai_agent: Wrapper around the LLM used for text generation.
            config_manager: Source of the 'topic_gen' configuration section.
            output_manager: Persists prompts, raw responses and parsed results.
        """
        self.ai_agent = ai_agent
        self.config_manager = config_manager
        self.config = config_manager.get_config('topic_gen', GenerateTopicConfig)
        self.output_manager = output_manager
        self.prompt_builder = TopicPromptBuilder(config_manager)
        self.parser = TopicParser()

    def _get_model_params(self) -> Dict[str, Any]:
        """Collect LLM sampling parameters from config, dropping unset values.

        Returns:
            Keyword arguments (temperature / top_p / presence_penalty) to
            forward to the AI agent; empty when no ``model`` dict is configured.
        """
        if not (hasattr(self.config, 'model') and isinstance(self.config.model, dict)):
            return {}
        params = {
            'temperature': self.config.model.get('temperature'),
            'top_p': self.config.model.get('top_p'),
            'presence_penalty': self.config.model.get('presence_penalty'),
        }
        # Drop None values so the agent's own defaults apply.
        return {k: v for k, v in params.items() if v is not None}

    async def _generate_and_parse(self, system_prompt: str, user_prompt: str) -> Optional[List[Dict[str, Any]]]:
        """Shared pipeline: save prompts, call the AI, parse and persist topics.

        Args:
            system_prompt: Fully built system prompt.
            user_prompt: Fully built user prompt.

        Returns:
            The parsed topic list, or None when the AI call fails or no
            valid topic can be parsed from the response.
        """
        # Save prompts up front for debugging / reproducibility.
        self.output_manager.save_text(system_prompt, "topic_system_prompt.txt")
        self.output_manager.save_text(user_prompt, "topic_user_prompt.txt")

        model_params = self._get_model_params()

        # Call the AI; stream the response (progress feedback for long outputs).
        try:
            raw_result, _, _, _ = await self.ai_agent.generate_text(
                system_prompt=system_prompt,
                user_prompt=user_prompt,
                use_stream=True,
                stage="选题生成",
                **model_params
            )
            self.output_manager.save_text(raw_result, "topics_raw_response.txt")
        except Exception as e:
            # Without a raw response there is nothing to parse; abort the run.
            logger.critical(f"AI调用失败无法生成选题: {e}", exc_info=True)
            return None

        # Parse the raw LLM text into structured topic dicts.
        topics = self.parser.parse(raw_result)
        if not topics:
            logger.error("未能从AI响应中解析出任何有效选题")
            return None

        # Persist the final result.
        self.output_manager.save_json(topics, "topics.json")
        logger.info(f"成功生成并保存 {len(topics)} 个选题")
        return topics

    async def generate_topics(self) -> Optional[List[Dict[str, Any]]]:
        """Run the full topic-generation flow with internally built prompts.

        Pipeline: build prompts -> call AI -> parse result -> save artifacts.

        Returns:
            The generated topic list, or None on failure.
        """
        logger.info("开始执行选题生成流程...")
        # 1. Build the prompts from configuration (topic count and month).
        system_prompt = self.prompt_builder.get_system_prompt()
        user_prompt = self.prompt_builder.build_user_prompt(
            numTopics=self.config.topic.num,
            month=self.config.topic.date
        )
        # 2-4. Delegate the AI call, parsing and persistence to the shared helper.
        return await self._generate_and_parse(system_prompt, user_prompt)

    async def generate_topics_with_prompt(self, system_prompt: str, user_prompt: str) -> Optional[List[Dict[str, Any]]]:
        """Run the topic-generation flow with caller-supplied prompts.

        Args:
            system_prompt: Pre-built system prompt.
            user_prompt: Pre-built user prompt.

        Returns:
            The generated topic list, or None on failure.
        """
        logger.info("使用预构建提示词开始执行选题生成流程...")
        return await self._generate_and_parse(system_prompt, user_prompt)