TravelContentCreator/domain/aigc/engines/topic_generate_v2.py

280 lines
8.9 KiB
Python

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
选题生成引擎 V2
- 不访问数据库,接收完整数据
- 使用 PromptRegistry 管理 prompt
- 统一依赖注入
"""
import logging
from typing import Dict, Any, Optional, List
from .base import BaseAIGCEngine, EngineResult
logger = logging.getLogger(__name__)
class TopicGenerateEngineV2(BaseAIGCEngine):
    """Topic generation engine, V2.

    Improvements over the previous engine:
    1. No database access; all data is passed in by the caller.
    2. Prompts are managed through PromptRegistry.
    3. Accepts full objects instead of IDs.
    """

    engine_id = "topic_generate_v2"
    engine_name = "选题生成 V2"
    version = "2.0.0"
    description = "根据景区、产品、风格等信息生成营销选题(新版本,无数据库依赖)"

    # Topic count used when the caller omits ``num_topics`` or sends None.
    DEFAULT_NUM_TOPICS = 5

    def __init__(self):
        super().__init__()
        # PromptRegistry instance; built on first use by _get_prompt_registry().
        self._prompt_registry = None

    def get_param_schema(self) -> Dict[str, Any]:
        """Describe the parameters accepted by :meth:`execute`.

        V2 change: full objects are accepted instead of IDs.

        Returns:
            A mapping of parameter name to its spec (``type``, ``required``,
            ``desc`` and, where applicable, ``default``).
        """
        return {
            # Basic parameters
            "num_topics": {
                "type": "int",
                "required": False,
                "default": self.DEFAULT_NUM_TOPICS,
                "desc": "生成选题数量",
            },
            "month": {
                "type": "str",
                "required": True,
                "desc": "目标日期/月份 (如 '2024-12''12月5日')",
            },
            # Full objects (passed in by the Java side; no DB lookup in Python)
            "scenic_spot": {
                "type": "object",
                "required": False,
                "desc": "景区信息对象 {id, name, description, location, ...}",
            },
            "product": {
                "type": "object",
                "required": False,
                "desc": "产品信息对象 {id, name, price, description, ...}",
            },
            "style": {
                "type": "object",
                "required": False,
                "desc": "风格信息对象 {id, name, description}",
            },
            "audience": {
                "type": "object",
                "required": False,
                "desc": "受众信息对象 {id, name, description}",
            },
            # Optional: multi-select lists
            "styles_list": {
                "type": "list",
                "required": False,
                "desc": "可选风格列表 [{id, name}, ...]",
            },
            "audiences_list": {
                "type": "list",
                "required": False,
                "desc": "可选受众列表 [{id, name}, ...]",
            },
            # Prompt version control
            "prompt_version": {
                "type": "str",
                "required": False,
                "default": "latest",
                "desc": "使用的 prompt 版本",
            },
        }

    def estimate_duration(self, params: Dict[str, Any]) -> int:
        """Estimate the execution time as ``15 + 2 * num_topics``.

        Args:
            params: Engine parameters; only ``num_topics`` is read.
        """
        return 15 + self._resolve_num_topics(params) * 2

    async def execute(self, params: Dict[str, Any],
                      task_id: Optional[str] = None) -> EngineResult:
        """Generate topics from the supplied objects.

        Args:
            params: Parameters as described by :meth:`get_param_schema`.
            task_id: Task ID forwarded to ``set_progress``.

        Returns:
            On success, an ``EngineResult`` whose ``data`` holds ``topics`` and
            ``count`` and whose ``metadata`` holds token usage, time cost and
            the prompt version. If the LLM reply cannot be parsed, a failed
            result with ``error_code="PARSE_ERROR"``. Any other exception is
            logged and reported as ``error_code="EXECUTION_ERROR"``.
        """
        try:
            self.log("开始生成选题 (V2)")
            self.set_progress(task_id, 10)

            # Extract parameters
            num_topics = self._resolve_num_topics(params)
            month = params.get('month', '')
            scenic_spot = params.get('scenic_spot')
            product = params.get('product')
            style = params.get('style')
            audience = params.get('audience')
            styles_list = params.get('styles_list', [])
            audiences_list = params.get('audiences_list', [])
            prompt_version = params.get('prompt_version', 'latest')
            self.set_progress(task_id, 20)

            prompt_registry = self._get_prompt_registry()

            # Build the prompt rendering context
            context = {
                'num_topics': num_topics,
                'month': month,
                'scenic_spot': scenic_spot,
                'product': product,
                'style': style,
                'audience': audience,
                'styles_list': self._format_list(styles_list),
                'audiences_list': self._format_list(audiences_list),
            }

            # Render the prompt
            system_prompt, user_prompt = prompt_registry.render(
                'topic_generate',
                context=context,
                version=prompt_version,
            )
            self.set_progress(task_id, 30)

            # Model parameters come from the same prompt configuration
            prompt_config = prompt_registry.get('topic_generate', prompt_version)
            model_params = prompt_config.get_model_params()

            # Call the LLM
            self.log("调用 LLM 生成选题...")
            raw_result, input_tokens, output_tokens, time_cost = await self.llm.generate(
                system_prompt=system_prompt,
                user_prompt=user_prompt,
                **model_params
            )
            self.set_progress(task_id, 70)

            # Parse the reply
            topics = self._parse_topics(raw_result)
            if not topics:
                return EngineResult(
                    success=False,
                    error="选题生成失败,无法解析结果",
                    error_code="PARSE_ERROR"
                )
            self.set_progress(task_id, 90)

            # Enrich topics with ID references to the source objects
            enhanced_topics = self._enhance_topics(
                topics, scenic_spot, product, style, audience
            )
            self.set_progress(task_id, 100)

            return EngineResult(
                success=True,
                data={
                    "topics": enhanced_topics,
                    "count": len(enhanced_topics),
                },
                metadata={
                    "input_tokens": input_tokens,
                    "output_tokens": output_tokens,
                    "time_cost": time_cost,
                    "prompt_version": prompt_version,
                }
            )
        except Exception as e:
            # Engine boundary: report the failure as a result instead of raising.
            self.log(f"选题生成异常: {e}", level='error')
            return EngineResult(
                success=False,
                error=str(e),
                error_code="EXECUTION_ERROR"
            )

    def _resolve_num_topics(self, params: Dict[str, Any]) -> int:
        """Return ``params['num_topics']``, or the default when absent or None."""
        value = params.get('num_topics')
        return self.DEFAULT_NUM_TOPICS if value is None else value

    def _get_prompt_registry(self):
        """Return the PromptRegistry, creating and caching it on first use."""
        # Compare against None so a pre-set registry is never replaced merely
        # because it happens to evaluate as falsy.
        if self._prompt_registry is None:
            from domain.prompt import PromptRegistry
            self._prompt_registry = PromptRegistry('prompts')
        return self._prompt_registry

    def _format_list(self, items: Optional[List[Dict]]) -> str:
        """Format a list of ``{name, description}`` dicts as bullet lines.

        Each entry becomes ``- name: description``, or ``- name`` when it has
        no description. Entries that are not dicts or have no name are
        skipped. Returns an empty string for an empty or missing list.
        """
        if not items:
            return ""
        lines = []
        for item in items:
            if not isinstance(item, dict):
                continue
            name = item.get('name', '')
            desc = item.get('description', '')
            if name:
                lines.append(f"- {name}: {desc}" if desc else f"- {name}")
        return "\n".join(lines)

    def _parse_topics(self, raw_result: str) -> List[Dict[str, Any]]:
        """Parse the LLM reply into a list of topic dicts.

        The outermost ``[...]`` span is parsed as strict JSON first. If that
        yields no topics, ``json_repair`` (when installed) is tried on that
        span and then on the whole reply.

        Returns:
            The parsed topics (dict items only), or an empty list when nothing
            usable could be extracted; the failure is logged in that case.
        """
        import json
        import re

        topics: List[Dict[str, Any]] = []
        if isinstance(raw_result, str) and raw_result.strip():
            array_match = re.search(r'\[[\s\S]*\]', raw_result)
            array_text = array_match.group() if array_match else None

            if array_text is not None:
                try:
                    topics = self._only_dict_items(json.loads(array_text))
                except json.JSONDecodeError:
                    topics = []

            if not topics:
                topics = self._repair_topics(array_text, raw_result)

        if not topics:
            self.log("无法解析选题结果", level='error')
        return topics

    def _repair_topics(self, *candidates: Optional[str]) -> List[Dict[str, Any]]:
        """Run ``json_repair`` over each candidate text; return the first topic list found."""
        try:
            import json_repair
        except ImportError:
            # json_repair is an optional dependency; without it there is no fallback.
            return []

        attempted = set()
        for text in candidates:
            if not text or text in attempted:
                continue
            attempted.add(text)
            try:
                repaired = json_repair.loads(text)
            except Exception:
                # Repair is best-effort: a failure here just means "no result".
                continue
            topics = self._only_dict_items(repaired)
            if topics:
                return topics
        return []

    @staticmethod
    def _only_dict_items(parsed: Any) -> List[Dict[str, Any]]:
        """Return the dict entries of *parsed* if it is a list, else an empty list."""
        if not isinstance(parsed, list):
            return []
        return [entry for entry in parsed if isinstance(entry, dict)]

    def _enhance_topics(self, topics: List[Dict],
                        scenic_spot: Optional[Dict],
                        product: Optional[Dict],
                        style: Optional[Dict],
                        audience: Optional[Dict]) -> List[Dict]:
        """Return copies of *topics* with ID references to the source objects.

        For each non-empty source object the keys ``scenic_spot_id``,
        ``product_id``, ``style_id`` and ``audience_id`` are set to that
        object's ``id`` (overriding any same-named key in the topic). The
        input topics are not mutated; entries that are not dicts are skipped.
        """
        # The references are the same for every topic, so build them once.
        sources = (
            ('scenic_spot_id', scenic_spot),
            ('product_id', product),
            ('style_id', style),
            ('audience_id', audience),
        )
        references = {key: obj.get('id') for key, obj in sources if obj}
        return [
            {**topic, **references}
            for topic in topics
            if isinstance(topic, dict)
        ]