2025-12-08 14:58:35 +08:00
|
|
|
|
#!/usr/bin/env python3
|
|
|
|
|
|
# -*- coding: utf-8 -*-
|
|
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
Topic generation engine V2
- Does not access the database; receives complete data from the caller
- Uses PromptRegistry to manage prompts
- Unified dependency injection
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
import logging
|
|
|
|
|
|
from typing import Dict, Any, Optional, List
|
|
|
|
|
|
|
|
|
|
|
|
from .base import BaseAIGCEngine, EngineResult
|
|
|
|
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TopicGenerateEngineV2(BaseAIGCEngine):
    """Topic generation engine, V2.

    Improvements:
    1. Does not access the database; all data is passed in by the caller.
    2. Uses PromptRegistry to manage prompts.
    3. Receives complete objects rather than IDs.
    """

    engine_id = "topic_generate"
    engine_name = "选题生成"
    version = "2.0.0"
    description = "根据景区、产品、风格等信息生成营销选题"

    def __init__(self):
        """Initialise the base engine; the PromptRegistry is created lazily."""
        super().__init__()
        self._prompt_registry = None

    def get_param_schema(self) -> Dict[str, Any]:
        """Define the parameter structure accepted by ``execute``.

        V2.2: ``scenic_spot`` and ``product`` are merged into ``subject``;
        the two old fields are still declared for backward compatibility.

        Returns:
            A mapping of parameter name to its spec (``type``, ``required``,
            ``desc`` and, where applicable, ``default``).
        """
        return {
            # Basic parameters
            "num_topics": {
                "type": "int",
                "required": False,
                "default": 5,
                "desc": "生成选题数量",
            },
            "month": {
                "type": "str",
                "required": True,
                "desc": "目标日期/月份 (如 '2024-12' 或 '12月5日')",
            },

            # Subject information (scenic spot + products merged)
            "subject": {
                "type": "object",
                "required": False,
                "desc": "主体信息 {id, name, type, description, location, products: [...]}",
            },
            # Legacy fields kept for compatibility
            "scenic_spot": {
                "type": "object",
                "required": False,
                "desc": "[兼容] 景区信息,建议使用 subject",
            },
            "product": {
                "type": "object",
                "required": False,
                "desc": "[兼容] 产品信息,建议使用 subject.products",
            },

            "style": {
                "type": "object",
                "required": False,
                "desc": "风格信息对象 {id, name}",
            },
            "audience": {
                "type": "object",
                "required": False,
                "desc": "受众信息对象 {id, name}",
            },

            # Hot-topic information
            "hot_topics": {
                "type": "object",
                "required": False,
                "desc": "热点信息 {events: [], festivals: [], trending: []}",
            },

            # Optional: multi-choice lists
            "styles_list": {
                "type": "list",
                "required": False,
                "desc": "可选风格列表 [{id, name}, ...]",
            },
            "audiences_list": {
                "type": "list",
                "required": False,
                "desc": "可选受众列表 [{id, name}, ...]",
            },

            # Prompt version control
            "prompt_version": {
                "type": "str",
                "required": False,
                "default": "latest",
                "desc": "使用的 prompt 版本",
            },
        }

    def estimate_duration(self, params: Dict[str, Any]) -> int:
        """Estimate the execution time for the given parameters.

        The estimate is ``15 + num_topics * 2``; the unit is not stated in
        this file (presumably seconds - confirm against the base class).

        Args:
            params: Engine parameters; only ``num_topics`` is read.

        Returns:
            The estimated duration.
        """
        num_topics = params.get('num_topics')
        if num_topics is None:
            # An explicit None would otherwise raise TypeError below; fall
            # back to the schema default.
            num_topics = 5
        return 15 + num_topics * 2

    async def execute(self, params: Dict[str, Any], task_id: Optional[str] = None) -> EngineResult:
        """Generate topics with the LLM and return them as an EngineResult.

        Args:
            params: Parameters containing complete objects (see
                ``get_param_schema``). An optional ``prompt_context`` entry,
                when present and truthy, is used as the prompt context
                instead of one built locally.
            task_id: Task ID used for progress reporting.

        Returns:
            On success, an EngineResult whose ``data`` holds ``topics`` and
            ``count`` and whose ``metadata`` holds token counts, time cost
            and the prompt version. If the LLM output cannot be parsed the
            result has ``error_code="PARSE_ERROR"``; any exception raised
            along the way is converted into ``error_code="EXECUTION_ERROR"``.
        """
        try:
            self.log("开始生成选题 (V2)")
            self.set_progress(task_id, 10)

            # Extract parameters
            subject = self._resolve_subject(params)
            style = params.get('style')
            audience = params.get('audience')
            prompt_version = params.get('prompt_version', 'latest')

            self.set_progress(task_id, 20)

            prompt_registry = self._get_prompt_registry()
            context = self._build_context(params, subject)

            # Render the prompt
            system_prompt, user_prompt = prompt_registry.render(
                'topic_generate',
                context=context,
                version=prompt_version
            )

            self.set_progress(task_id, 30)

            # Model parameters come from the prompt configuration
            prompt_config = prompt_registry.get('topic_generate', prompt_version)
            model_params = prompt_config.get_model_params()

            # Call the LLM
            self.log("调用 LLM 生成选题...")
            raw_result, input_tokens, output_tokens, time_cost = await self.llm.generate(
                system_prompt=system_prompt,
                user_prompt=user_prompt,
                **model_params
            )

            self.set_progress(task_id, 70)

            # Parse the result
            topics = self._parse_topics(raw_result)

            if not topics:
                return EngineResult(
                    success=False,
                    error="选题生成失败,无法解析结果",
                    error_code="PARSE_ERROR"
                )

            self.set_progress(task_id, 90)

            # Enrich the topics with ID references to the input objects
            enhanced_topics = self._enhance_topics(
                topics, subject, style, audience
            )

            self.set_progress(task_id, 100)

            return EngineResult(
                success=True,
                data={
                    "topics": enhanced_topics,
                    "count": len(enhanced_topics),
                },
                metadata={
                    "input_tokens": input_tokens,
                    "output_tokens": output_tokens,
                    "time_cost": time_cost,
                    "prompt_version": prompt_version,
                }
            )

        except Exception as e:
            # Top-level boundary: report the failure to the caller as a
            # result object instead of letting the exception escape.
            self.log(f"选题生成异常: {e}", level='error')
            return EngineResult(
                success=False,
                error=str(e),
                error_code="EXECUTION_ERROR"
            )

    def _resolve_subject(self, params: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Return ``subject``, building it from the legacy fields if absent.

        When ``subject`` is missing/empty and ``scenic_spot`` is provided,
        the subject is the scenic spot's fields plus a ``type`` (defaulting
        to ``'scenic_spot'``) and a ``products`` list holding the legacy
        ``product`` (or empty when there is none).
        """
        subject = params.get('subject')
        scenic_spot = params.get('scenic_spot')
        if subject or not scenic_spot:
            return subject

        product = params.get('product')
        return {
            **scenic_spot,
            'type': scenic_spot.get('type', 'scenic_spot'),
            'products': [product] if product else [],
        }

    def _build_context(self, params: Dict[str, Any],
                       subject: Optional[Dict[str, Any]]) -> Dict[str, Any]:
        """Build the context dict handed to the prompt renderer.

        If ``params['prompt_context']`` is truthy it is used as the base and
        only missing ``num_topics`` / ``month`` / ``hot_topics`` entries are
        filled in; otherwise the context is assembled locally from the
        individual parameters (compatibility mode).
        """
        num_topics = params.get('num_topics', 5)
        month = params.get('month', '')
        hot_topics = params.get('hot_topics')

        prompt_context = params.get('prompt_context')
        if prompt_context:
            self.log("使用 Java 透传的 prompt_context")
            # Work on a shallow copy so the caller's dict is not mutated
            # when the missing fields are filled in.
            context = dict(prompt_context)
            context.setdefault('num_topics', num_topics)
            context.setdefault('month', month)
            if hot_topics:
                context.setdefault('hot_topics', hot_topics)
            return context

        self.log("使用本地构建的 context (兼容模式)")
        products = subject.get('products') if subject else None
        return {
            'num_topics': num_topics,
            'month': month,
            'subject': subject,
            # Kept for old prompt templates
            'scenic_spot': subject,
            'product': products[0] if products else params.get('product'),
            'style': params.get('style'),
            'audience': params.get('audience'),
            'hot_topics': hot_topics,
            'styles_list': self._format_list(params.get('styles_list', [])),
            'audiences_list': self._format_list(params.get('audiences_list', [])),
        }

    def _get_prompt_registry(self):
        """Return the PromptRegistry, creating and caching it on first use."""
        # Compare against None explicitly so that a registry object that
        # happens to be falsy is not rebuilt on every call.
        if self._prompt_registry is None:
            from domain.prompt import PromptRegistry
            self._prompt_registry = PromptRegistry('prompts')
        return self._prompt_registry

    def _format_list(self, items: List[Dict]) -> str:
        """Format a list of ``{name, description}`` dicts as a bullet list.

        Each item with a non-empty ``name`` becomes ``- name: description``
        (or ``- name`` without a description). Items that are not dicts or
        have no name are skipped.

        Returns:
            The lines joined by newlines, or ``""`` for empty input.
        """
        if not items:
            return ""

        lines = []
        for item in items:
            if not isinstance(item, dict):
                continue
            name = item.get('name', '')
            if not name:
                continue
            desc = item.get('description', '')
            lines.append(f"- {name}: {desc}" if desc else f"- {name}")

        return "\n".join(lines)

    def _parse_topics(self, raw_result: str) -> List[Dict[str, Any]]:
        """Parse the topics returned by the LLM.

        First tries strict JSON on the outermost ``[...]`` span of the text,
        then falls back to the optional ``json_repair`` package on the whole
        text. Only a list is accepted, and only its dict entries are kept.

        Returns:
            The list of topic dicts, or ``[]`` (after logging an error) when
            nothing usable could be parsed.
        """
        import json
        import re

        parsed = None

        if isinstance(raw_result, str) and raw_result:
            # Try to extract a JSON array
            json_match = re.search(r'\[[\s\S]*\]', raw_result)
            if json_match:
                try:
                    parsed = json.loads(json_match.group())
                except json.JSONDecodeError:
                    parsed = None

            if parsed is None:
                # Best-effort fallback: json_repair is optional and may be
                # missing or fail. Failure is reported below. Unlike a bare
                # ``except``, this lets KeyboardInterrupt/SystemExit through.
                try:
                    import json_repair
                    parsed = json_repair.loads(raw_result)
                except Exception:
                    parsed = None

        if isinstance(parsed, list):
            # Non-dict entries cannot be enriched downstream; drop them.
            topics = [item for item in parsed if isinstance(item, dict)]
            if topics:
                return topics

        self.log("无法解析选题结果", level='error')
        return []

    def _enhance_topics(self, topics: List[Dict],
                        subject: Optional[Dict],
                        style: Optional[Dict],
                        audience: Optional[Dict]) -> List[Dict]:
        """Attach ID references of the input objects to every topic.

        Adds ``subject_id`` (and ``product_id`` from the subject's first
        product, if any), ``style_id`` and ``audience_id`` for whichever of
        the objects were provided. Input topics are copied, not mutated;
        entries that are not dicts are skipped.
        """
        # The references are identical for every topic, so compute them once.
        references: Dict[str, Any] = {}
        if subject:
            references['subject_id'] = subject.get('id')
            # Take the ID of the first product, if there is one
            products = subject.get('products')
            if isinstance(products, (list, tuple)) and products:
                first_product = products[0]
                if isinstance(first_product, dict):
                    references['product_id'] = first_product.get('id')
        if style:
            references['style_id'] = style.get('id')
        if audience:
            references['audience_id'] = audience.get('id')

        return [
            {**topic, **references}
            for topic in topics
            if isinstance(topic, dict)
        ]
|