236 lines
7.1 KiB
Python

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Topic generation engine.

Responsible for generating marketing topics (选题); see ``TopicEngine``.
"""
import logging
from typing import Dict, Any, Optional, List, Tuple
# Module-level logger. TopicEngine instances log through their own child
# logger named "<module>.TopicEngine", created in TopicEngine.__init__.
logger = logging.getLogger(__name__)
class TopicEngine:
    """
    Marketing-topic generation engine.

    Responsibilities:
    - build the prompts used for topic generation
    - call the LLM to generate topics
    - parse the LLM response and enrich each topic with context metadata
    """

    # Context entities, in lookup / output order. Each name is used three ways:
    # as the attribute on the db accessor (e.g. ``db.scenic_spot``), as the key
    # in the context dict, and as the prefix of the ``<name>_id`` /
    # ``<name>_name`` fields added to every topic.
    _CONTEXT_ENTITIES = ("scenic_spot", "product", "style", "audience")

    def __init__(self, llm_client=None, prompt_builder=None, db_accessor=None):
        """
        Initialize the topic engine.

        Args:
            llm_client: LLM client providing an async
                ``generate(prompt=..., system_prompt=...)``; required by
                ``generate``.
            prompt_builder: optional builder exposing ``get_system_prompt`` /
                ``get_user_prompt``; when absent the built-in default prompts
                are used.
            db_accessor: optional accessor whose ``scenic_spot`` / ``product``
                / ``style`` / ``audience`` attributes each provide an async
                ``find_by_id(id)``; when absent no context is looked up.
        """
        self._llm = llm_client
        self._prompt = prompt_builder
        self._db = db_accessor
        self.logger = logging.getLogger(f"{__name__}.TopicEngine")

    async def generate(
        self,
        scenic_spot_id: Optional[int] = None,
        product_id: Optional[int] = None,
        style_id: Optional[int] = None,
        audience_id: Optional[int] = None,
        num_topics: int = 5,
        month: Optional[str] = None,
        **kwargs,
    ) -> List[Dict[str, Any]]:
        """
        Generate marketing topics.

        Args:
            scenic_spot_id: scenic spot ID
            product_id: product ID
            style_id: style ID
            audience_id: audience ID
            num_topics: number of topics requested in the prompt
            month: target month; the prompt says "当月" when omitted
            **kwargs: accepted for call compatibility; currently ignored

        Returns:
            The parsed topics, each enriched with its position (``index``) and
            the ID/name of every context entity that was found. Empty when the
            LLM response could not be parsed.

        Raises:
            RuntimeError: if the engine was created without an LLM client.
            Any exception from the db accessor, prompt builder or LLM client
            is logged (with traceback) and re-raised unchanged.
        """
        try:
            if self._llm is None:
                # Fail fast with a clear message instead of an AttributeError
                # on ``None.generate`` after the DB lookups have already run.
                raise RuntimeError("TopicEngine requires an llm_client to generate topics")

            self.logger.info("开始生成选题: num=%s", num_topics)
            # 1. Fetch the referenced entities.
            context = await self._build_context(
                scenic_spot_id, product_id, style_id, audience_id
            )
            # 2. Build the prompts.
            system_prompt, user_prompt = self._build_prompts(
                context, num_topics, month
            )
            # 3. Call the LLM.
            response = await self._llm.generate(
                prompt=user_prompt,
                system_prompt=system_prompt,
            )
            # 4. Parse the response.
            topics = self._parse_topics(response)
            # 5. Enrich the topics with context metadata.
            enhanced_topics = await self._enhance_topics(topics, context)
            self.logger.info("选题生成完成: %s", len(enhanced_topics))
            return enhanced_topics
        except Exception as e:
            # logger.exception keeps the traceback that logger.error dropped.
            self.logger.exception("选题生成失败: %s", e)
            raise

    async def _build_context(
        self,
        scenic_spot_id: Optional[int],
        product_id: Optional[int],
        style_id: Optional[int],
        audience_id: Optional[int],
    ) -> Dict[str, Any]:
        """
        Look up the requested entities through the db accessor.

        An entity is included only when its ID was given (not ``None``) and
        its ``find_by_id`` result is truthy. Without a db accessor the
        context is empty.

        Returns:
            Mapping of entity name (``scenic_spot`` / ``product`` / ``style``
            / ``audience``) to the record returned by ``find_by_id``.
        """
        context: Dict[str, Any] = {}
        if not self._db:
            return context

        requested_ids = {
            "scenic_spot": scenic_spot_id,
            "product": product_id,
            "style": style_id,
            "audience": audience_id,
        }
        for name, entity_id in requested_ids.items():
            # ``is None`` rather than truthiness so an ID of 0 is still looked up.
            if entity_id is None:
                continue
            record = await getattr(self._db, name).find_by_id(entity_id)
            if record:
                context[name] = record
        return context

    def _build_prompts(
        self,
        context: Dict[str, Any],
        num_topics: int,
        month: Optional[str],
    ) -> Tuple[str, str]:
        """
        Build the ``(system_prompt, user_prompt)`` pair.

        With a prompt builder, both prompts come from its ``"topic_generate"``
        template and receive the context entities as keyword arguments; the
        user prompt additionally gets ``num_topics`` and ``month`` (defaulting
        to "当月"). Otherwise the built-in default prompts are returned.
        """
        if self._prompt:
            system_prompt = self._prompt.get_system_prompt(
                "topic_generate",
                **context
            )
            user_prompt = self._prompt.get_user_prompt(
                "topic_generate",
                num_topics=num_topics,
                month=month or "当月",
                **context
            )
            return system_prompt, user_prompt

        # Built-in default prompts.
        # NOTE(review): only the scenic spot is described here; product, style
        # and audience from ``context`` are not included in the default prompt.
        system_prompt = """你是一个专业的旅游内容营销专家。
请根据提供的信息生成吸引人的营销选题。
每个选题应该包含:标题、描述、关键词、目标受众。
输出格式为 JSON 数组。"""
        spot_info = ""
        spot = context.get('scenic_spot')
        if spot:
            spot_info = f"景区:{spot.get('name', '')}\n描述:{spot.get('description', '')}"
        user_prompt = f"""请生成 {num_topics} 个营销选题。
{spot_info}
目标月份:{month or '当月'}
请以 JSON 数组格式输出,每个选题包含:
- title: 选题标题
- description: 选题描述
- keywords: 关键词数组
- target_audience: 目标受众"""
        return system_prompt, user_prompt

    def _parse_topics(self, response: str) -> List[Dict[str, Any]]:
        """
        Parse the raw LLM response into a list of topic dicts.

        Candidates are tried in order: the whole response, the body of a
        ```json fenced block, then the span from the first ``[`` to the last
        ``]``. The first candidate that decodes to a JSON array wins; array
        elements that are not JSON objects are dropped with a warning.

        Returns:
            The topic dicts, or ``[]`` (with a warning) when the response is
            not a string or no candidate decodes to a JSON array.
        """
        import json
        import re

        if not isinstance(response, str):
            # e.g. None from the LLM client; re.search would raise TypeError.
            self.logger.warning("无法解析选题响应")
            return []

        candidates = [response]
        fenced = re.search(r'```json\s*([\s\S]*?)\s*```', response)
        if fenced:
            candidates.append(fenced.group(1))
        # Greedy on purpose: spans the outermost array, including one wrapped
        # in an object such as {"topics": [...]}.
        bracketed = re.search(r'\[[\s\S]*\]', response)
        if bracketed:
            candidates.append(bracketed.group(0))

        for candidate in candidates:
            try:
                parsed = json.loads(candidate)
            except ValueError:  # json.JSONDecodeError subclasses ValueError
                continue
            if not isinstance(parsed, list):
                # An object wrapper or scalar: try the next candidate instead
                # of handing a non-list to _enhance_topics.
                continue
            topics = [item for item in parsed if isinstance(item, dict)]
            dropped = len(parsed) - len(topics)
            if dropped:
                self.logger.warning("选题响应中有 %d 个非对象元素, 已忽略", dropped)
            return topics

        self.logger.warning("无法解析选题响应")
        return []

    async def _enhance_topics(
        self,
        topics: List[Dict[str, Any]],
        context: Dict[str, Any],
    ) -> List[Dict[str, Any]]:
        """
        Return shallow copies of ``topics`` enriched with metadata.

        Each copy gets ``index`` (its position in ``topics``) and, for every
        entity present in ``context``, ``<name>_id`` / ``<name>_name`` read
        from that entity's ``id`` / ``name`` via ``.get`` (``None`` when
        missing). The input dicts are not modified.
        """
        enhanced = []
        for position, topic in enumerate(topics):
            enhanced_topic = topic.copy()
            enhanced_topic['index'] = position
            for name in self._CONTEXT_ENTITIES:
                entity = context.get(name)
                if entity:
                    enhanced_topic[f'{name}_id'] = entity.get('id')
                    enhanced_topic[f'{name}_name'] = entity.get('name')
            enhanced.append(enhanced_topic)
        return enhanced