236 lines
7.1 KiB
Python
236 lines
7.1 KiB
Python
#!/usr/bin/env python3
|
|
# -*- coding: utf-8 -*-
|
|
|
|
"""
|
|
选题生成引擎
|
|
负责生成营销选题
|
|
"""
|
|
|
|
import logging
|
|
from typing import Dict, Any, Optional, List, Tuple
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class TopicEngine:
    """
    Topic generation engine.

    Responsibilities:
    - Build the prompts used for topic generation.
    - Call the LLM to generate topics.
    - Parse and enrich the generated topics.
    """

    # Context entries handled by the engine. Each name is used three ways:
    # as the attribute looked up on the DB accessor, as the key stored in the
    # context dict, and as the prefix of the ``<name>_id`` / ``<name>_name``
    # fields added to every topic.
    _CONTEXT_KEYS = ("scenic_spot", "product", "style", "audience")

    def __init__(self, llm_client=None, prompt_builder=None, db_accessor=None):
        """
        Initialize the topic engine.

        Args:
            llm_client: LLM client; ``generate`` awaits its
                ``generate(prompt=..., system_prompt=...)`` method.
            prompt_builder: Optional prompt builder. When falsy, built-in
                default prompts are used.
            db_accessor: Optional database accessor. When falsy, no context
                records are loaded.
        """
        self._llm = llm_client
        self._prompt = prompt_builder
        self._db = db_accessor
        self.logger = logging.getLogger(f"{__name__}.TopicEngine")

    async def generate(
        self,
        scenic_spot_id: Optional[int] = None,
        product_id: Optional[int] = None,
        style_id: Optional[int] = None,
        audience_id: Optional[int] = None,
        num_topics: int = 5,
        month: Optional[str] = None,
        **kwargs
    ) -> List[Dict[str, Any]]:
        """
        Generate marketing topics.

        Args:
            scenic_spot_id: Scenic spot ID.
            product_id: Product ID.
            style_id: Style ID.
            audience_id: Audience ID.
            num_topics: Number of topics requested from the LLM.
            month: Target month; the prompts fall back to "当月" when unset.
            **kwargs: Accepted for call-site compatibility; not used here.

        Returns:
            The list of parsed topics, each enriched with its index and the
            id/name of every context record that was found.

        Raises:
            Exception: Any error raised while building the context, calling
                the LLM or enriching the topics is logged and re-raised.
        """
        try:
            self.logger.info("开始生成选题: num=%s", num_topics)

            # 1. Load the related records.
            context = await self._build_context(
                scenic_spot_id, product_id, style_id, audience_id
            )

            # 2. Build the prompts.
            system_prompt, user_prompt = self._build_prompts(
                context, num_topics, month
            )

            # 3. Call the LLM.
            response = await self._llm.generate(
                prompt=user_prompt,
                system_prompt=system_prompt,
            )

            # 4. Parse the raw response, then 5. enrich the topics.
            topics = self._parse_topics(response)
            enhanced_topics = await self._enhance_topics(topics, context)

            self.logger.info("选题生成完成: %s 个", len(enhanced_topics))
            return enhanced_topics
        except Exception as e:
            # logger.exception keeps the traceback, which logger.error drops.
            self.logger.exception("选题生成失败: %s", e)
            raise

    async def _build_context(
        self,
        scenic_spot_id: Optional[int],
        product_id: Optional[int],
        style_id: Optional[int],
        audience_id: Optional[int]
    ) -> Dict[str, Any]:
        """
        Load the records referenced by the given IDs.

        For every truthy ID the matching accessor attribute
        (``self._db.<name>.find_by_id``) is awaited; only truthy results are
        stored. Returns an empty dict when no DB accessor is configured.
        """
        context: Dict[str, Any] = {}
        if not self._db:
            return context

        requested_ids = (scenic_spot_id, product_id, style_id, audience_id)
        for name, record_id in zip(self._CONTEXT_KEYS, requested_ids):
            if not record_id:
                continue
            record = await getattr(self._db, name).find_by_id(record_id)
            if record:
                context[name] = record

        return context

    def _build_prompts(
        self,
        context: Dict[str, Any],
        num_topics: int,
        month: Optional[str]
    ) -> Tuple[str, str]:
        """
        Build the ``(system_prompt, user_prompt)`` pair.

        Delegates to the prompt builder when one is configured; otherwise
        falls back to the built-in default prompts.
        """
        if self._prompt:
            system_prompt = self._prompt.get_system_prompt(
                "topic_generate",
                **context
            )
            user_prompt = self._prompt.get_user_prompt(
                "topic_generate",
                num_topics=num_topics,
                month=month or "当月",
                **context
            )
            return system_prompt, user_prompt

        # Default prompts.
        system_prompt = (
            "你是一个专业的旅游内容营销专家。\n"
            "请根据提供的信息生成吸引人的营销选题。\n"
            "每个选题应该包含:标题、描述、关键词、目标受众。\n"
            "输出格式为 JSON 数组。"
        )

        spot = context.get('scenic_spot')
        spot_info = ""
        if spot:
            spot_info = f"景区:{spot.get('name', '')}\n描述:{spot.get('description', '')}"

        user_prompt = (
            f"请生成 {num_topics} 个营销选题。\n"
            "\n"
            f"{spot_info}\n"
            "\n"
            f"目标月份:{month or '当月'}\n"
            "\n"
            "请以 JSON 数组格式输出,每个选题包含:\n"
            "- title: 选题标题\n"
            "- description: 选题描述\n"
            "- keywords: 关键词数组\n"
            "- target_audience: 目标受众"
        )

        return system_prompt, user_prompt

    def _parse_topics(self, response: str) -> List[Dict[str, Any]]:
        """
        Parse the LLM response into a list of topic dicts.

        Candidates are tried in this order: the whole response, the content
        of each fenced code block, then the span from the first ``[`` to the
        last ``]``. A candidate is accepted only if it decodes to a JSON
        array; non-dict items of that array are dropped, and an array that
        is non-empty but holds no dict at all is rejected.

        Returns:
            The parsed topics, or an empty list (with a warning logged) when
            nothing usable could be extracted.
        """
        import json
        import re

        def load_topic_list(text):
            """Return the topic dicts decoded from *text*, or None if unusable."""
            try:
                parsed = json.loads(text)
            except (ValueError, TypeError):
                # ValueError covers JSONDecodeError; TypeError covers inputs
                # that are not str/bytes/bytearray.
                return None
            if not isinstance(parsed, list):
                return None
            topics = [item for item in parsed if isinstance(item, dict)]
            if parsed and not topics:
                return None
            return topics

        # Try the response as-is first.
        topics = load_topic_list(response)
        if topics is not None:
            return topics

        if isinstance(response, str):
            # (pattern, capture group) pairs; the fenced block goes first
            # because it is the more specific match.
            extractors = (
                (r'```(?:json)?\s*([\s\S]*?)\s*```', 1),
                (r'\[[\s\S]*\]', 0),
            )
            for pattern, group in extractors:
                for match in re.finditer(pattern, response):
                    topics = load_topic_list(match.group(group))
                    if topics is not None:
                        return topics

        self.logger.warning("无法解析选题响应")
        return []

    async def _enhance_topics(
        self,
        topics: List[Dict[str, Any]],
        context: Dict[str, Any]
    ) -> List[Dict[str, Any]]:
        """
        Return copies of the topics enriched with index and context fields.

        Each copy gets an ``index`` (its position in *topics*) plus
        ``<name>_id`` and ``<name>_name`` for every context record present.
        The input dicts are not modified.
        """
        if not topics:
            return []

        # The context fields are identical for every topic, so build them once.
        context_fields: Dict[str, Any] = {}
        for name in self._CONTEXT_KEYS:
            record = context.get(name)
            if record:
                context_fields[f"{name}_id"] = record.get('id')
                context_fields[f"{name}_name"] = record.get('name')

        enhanced = []
        for i, topic in enumerate(topics):
            enhanced_topic = topic.copy()
            enhanced_topic['index'] = i
            enhanced_topic.update(context_fields)
            enhanced.append(enhanced_topic)

        return enhanced
|