TravelContentCreator/domain/aigc/engines/topic_generate_v2.py

311 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
选题生成引擎 V2
- 不访问数据库,接收完整数据
- 使用 PromptRegistry 管理 prompt
- 统一依赖注入
"""
import logging
from typing import Dict, Any, Optional, List
from .base import BaseAIGCEngine, EngineResult
logger = logging.getLogger(__name__)
class TopicGenerateEngineV2(BaseAIGCEngine):
    """Topic generation engine, V2.

    Design points, as stated by the original author:

    1. The engine never queries the database; callers pass in all data.
    2. Prompts are managed through ``PromptRegistry``.
    3. Parameters are complete objects rather than IDs.
    """

    engine_id = "topic_generate"
    engine_name = "选题生成"
    version = "2.0.0"
    description = "根据景区、产品、风格等信息生成营销选题"

    # Topic count used when the caller omits ``num_topics`` or passes None.
    _DEFAULT_NUM_TOPICS = 5

    def __init__(self):
        super().__init__()
        # Built lazily by _get_prompt_registry(); an instance assigned here
        # beforehand is reused instead of constructing a new one.
        self._prompt_registry = None

    def get_param_schema(self) -> Dict[str, Any]:
        """Return the parameter schema accepted by :meth:`execute`.

        Since V2.2 the former ``scenic_spot`` and ``product`` parameters are
        merged into ``subject``; the old fields remain for compatibility.

        Returns:
            A mapping of parameter name to a descriptor with ``type``,
            ``required``, ``desc`` and, where applicable, ``default``.
        """
        return {
            # Basic parameters
            "num_topics": {
                "type": "int",
                "required": False,
                "default": self._DEFAULT_NUM_TOPICS,
                "desc": "生成选题数量",
            },
            "month": {
                "type": "str",
                "required": True,
                "desc": "目标日期/月份 (如 '2024-12''12月5日')",
            },
            # Subject (scenic spot and products merged)
            "subject": {
                "type": "object",
                "required": False,
                "desc": "主体信息 {id, name, type, description, location, products: [...]}",
            },
            # Legacy fields kept for backward compatibility
            "scenic_spot": {
                "type": "object",
                "required": False,
                "desc": "[兼容] 景区信息,建议使用 subject",
            },
            "product": {
                "type": "object",
                "required": False,
                "desc": "[兼容] 产品信息,建议使用 subject.products",
            },
            "style": {
                "type": "object",
                "required": False,
                "desc": "风格信息对象 {id, name}",
            },
            "audience": {
                "type": "object",
                "required": False,
                "desc": "受众信息对象 {id, name}",
            },
            # Trending / hot-topic information
            "hot_topics": {
                "type": "object",
                "required": False,
                "desc": "热点信息 {events: [], festivals: [], trending: []}",
            },
            # Optional multi-choice lists
            "styles_list": {
                "type": "list",
                "required": False,
                "desc": "可选风格列表 [{id, name}, ...]",
            },
            "audiences_list": {
                "type": "list",
                "required": False,
                "desc": "可选受众列表 [{id, name}, ...]",
            },
            # Prompt version control
            "prompt_version": {
                "type": "str",
                "required": False,
                "default": "latest",
                "desc": "使用的 prompt 版本",
            },
        }

    def estimate_duration(self, params: Dict[str, Any]) -> int:
        """Estimate execution time: a base of 15 plus 2 per requested topic.

        The unit is not stated in this file (presumably seconds - confirm
        against the base class).

        Args:
            params: Engine parameters; only ``num_topics`` is read.

        Returns:
            The estimated duration.
        """
        num_topics = params.get('num_topics')
        if num_topics is None:
            # An explicit None would otherwise raise TypeError below.
            num_topics = self._DEFAULT_NUM_TOPICS
        return 15 + num_topics * 2

    async def execute(self, params: Dict[str, Any], task_id: Optional[str] = None) -> EngineResult:
        """Generate topics from the supplied objects.

        Renders the ``topic_generate`` prompt, calls the LLM, parses the
        response and attaches ID references to every topic.

        Args:
            params: Parameters as described by :meth:`get_param_schema`.
            task_id: Task ID passed to ``set_progress`` for progress updates.

        Returns:
            On success, an ``EngineResult`` whose ``data`` holds ``topics``
            and ``count`` and whose ``metadata`` holds token usage, time cost
            and the prompt version. On failure, an ``EngineResult`` with
            ``error_code`` ``PARSE_ERROR`` (no usable topics in the LLM
            output) or ``EXECUTION_ERROR`` (any exception raised here).
        """
        try:
            self.log("开始生成选题 (V2)")
            self.set_progress(task_id, 10)

            # Extract parameters
            num_topics = params.get('num_topics')
            if num_topics is None:
                num_topics = self._DEFAULT_NUM_TOPICS
            month = params.get('month', '')
            product = params.get('product')
            subject = self._resolve_subject(params)
            style = params.get('style')
            audience = params.get('audience')
            hot_topics = params.get('hot_topics')
            styles_list = params.get('styles_list', [])
            audiences_list = params.get('audiences_list', [])
            prompt_version = params.get('prompt_version', 'latest')
            self.set_progress(task_id, 20)

            prompt_registry = self._get_prompt_registry()

            # Older prompt templates expect a single product: use the
            # subject's first product, else the legacy ``product`` parameter.
            subject_products = subject.get('products') if subject else None
            primary_product = subject_products[0] if subject_products else product

            context = {
                'num_topics': num_topics,
                'month': month,
                'subject': subject,
                # Aliases for older prompt templates
                'scenic_spot': subject,
                'product': primary_product,
                'style': style,
                'audience': audience,
                'hot_topics': hot_topics,
                'styles_list': self._format_list(styles_list),
                'audiences_list': self._format_list(audiences_list),
            }

            # Render the prompt
            system_prompt, user_prompt = prompt_registry.render(
                'topic_generate',
                context=context,
                version=prompt_version
            )
            self.set_progress(task_id, 30)

            # Model parameters come from the prompt configuration
            prompt_config = prompt_registry.get('topic_generate', prompt_version)
            model_params = prompt_config.get_model_params()

            # Call the LLM
            self.log("调用 LLM 生成选题...")
            raw_result, input_tokens, output_tokens, time_cost = await self.llm.generate(
                system_prompt=system_prompt,
                user_prompt=user_prompt,
                **model_params
            )
            self.set_progress(task_id, 70)

            # Parse the response
            topics = self._parse_topics(raw_result)
            if not topics:
                return EngineResult(
                    success=False,
                    error="选题生成失败,无法解析结果",
                    error_code="PARSE_ERROR"
                )
            self.set_progress(task_id, 90)

            # Attach references to the source objects
            enhanced_topics = self._enhance_topics(
                topics, subject, style, audience
            )
            self.set_progress(task_id, 100)

            return EngineResult(
                success=True,
                data={
                    "topics": enhanced_topics,
                    "count": len(enhanced_topics),
                },
                metadata={
                    "input_tokens": input_tokens,
                    "output_tokens": output_tokens,
                    "time_cost": time_cost,
                    "prompt_version": prompt_version,
                }
            )
        except Exception as e:
            # Engine boundary: report every failure as a result object
            # instead of propagating it to the caller.
            self.log(f"选题生成异常: {e}", level='error')
            return EngineResult(
                success=False,
                error=str(e),
                error_code="EXECUTION_ERROR"
            )

    @staticmethod
    def _resolve_subject(params: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Return ``subject``, building it from the legacy fields if absent.

        When ``subject`` is missing or empty and ``scenic_spot`` is given,
        the subject is a copy of ``scenic_spot`` with ``type`` defaulting to
        ``'scenic_spot'`` and ``products`` set to the legacy ``product``
        (or an empty list).
        """
        subject = params.get('subject')
        scenic_spot = params.get('scenic_spot')
        if subject or not scenic_spot:
            return subject
        product = params.get('product')
        return {
            **scenic_spot,
            'type': scenic_spot.get('type', 'scenic_spot'),
            'products': [product] if product else [],
        }

    def _get_prompt_registry(self):
        """Return the cached ``PromptRegistry``, creating it on first use."""
        # Identity check, so a registry object that evaluates falsy is still
        # reused instead of being rebuilt on every call.
        if self._prompt_registry is not None:
            return self._prompt_registry
        # Imported here rather than at module level, as in the original.
        from domain.prompt import PromptRegistry
        self._prompt_registry = PromptRegistry('prompts')
        return self._prompt_registry

    def _format_list(self, items: List[Dict]) -> str:
        """Format option objects as a bulleted, newline-joined string.

        Each item with a non-empty ``name`` becomes ``- name: description``,
        or ``- name`` when it has no description. Plain strings are rendered
        as names; other non-dict items and items without a name are skipped.

        Args:
            items: Objects such as ``{name, description}``; may be empty or None.

        Returns:
            The formatted text, or ``""`` when there is nothing to render.
        """
        if not items:
            return ""
        lines = []
        for item in items:
            if isinstance(item, str):
                if item:
                    lines.append(f"- {item}")
                continue
            if not isinstance(item, dict):
                continue
            name = item.get('name', '')
            desc = item.get('description', '')
            if name:
                lines.append(f"- {name}: {desc}" if desc else f"- {name}")
        return "\n".join(lines)

    def _parse_topics(self, raw_result: str) -> List[Dict[str, Any]]:
        """Parse the LLM response into a list of topic dicts.

        First tries strict JSON on the span from the first ``[`` to the last
        ``]``; if that yields nothing usable, falls back to the optional
        ``json_repair`` package on the whole response.

        Args:
            raw_result: Raw text returned by the LLM.

        Returns:
            The parsed topics, or ``[]`` (after logging an error) when no
            topic dict could be extracted.
        """
        import json
        import re

        if isinstance(raw_result, str) and raw_result.strip():
            # Greedy on purpose: span the outermost brackets so that nested
            # arrays inside the topic objects are kept intact.
            json_match = re.search(r'\[[\s\S]*\]', raw_result)
            if json_match:
                try:
                    topics = self._normalize_topics(json.loads(json_match.group()))
                except json.JSONDecodeError:
                    topics = []
                if topics:
                    return topics

            # json_repair is an optional dependency; its absence is not an error.
            try:
                import json_repair
            except ImportError:
                json_repair = None
            if json_repair is not None:
                try:
                    topics = self._normalize_topics(json_repair.loads(raw_result))
                except Exception:
                    # Best-effort fallback: any failure means "unparseable".
                    topics = []
                if topics:
                    return topics

        self.log("无法解析选题结果", level='error')
        return []

    @staticmethod
    def _normalize_topics(parsed: Any) -> List[Dict[str, Any]]:
        """Coerce a decoded JSON value into a list of topic dicts.

        A dict with a list under ``topics`` is unwrapped; any other non-empty
        dict is treated as a single topic. Non-dict list entries are dropped
        and every other value yields ``[]``.
        """
        if isinstance(parsed, dict):
            nested = parsed.get('topics')
            if isinstance(nested, list):
                parsed = nested
            else:
                parsed = [parsed] if parsed else []
        if not isinstance(parsed, list):
            return []
        return [item for item in parsed if isinstance(item, dict)]

    def _enhance_topics(self, topics: List[Dict],
                        subject: Optional[Dict],
                        style: Optional[Dict],
                        audience: Optional[Dict]) -> List[Dict]:
        """Return copies of ``topics`` with ID references attached.

        Adds ``subject_id``, ``product_id`` (ID of the subject's first
        product, if any), ``style_id`` and ``audience_id``, each only when
        the corresponding object was supplied. The inputs are not mutated.

        Args:
            topics: Parsed topic dicts.
            subject: Subject object, optionally carrying ``products``.
            style: Style object.
            audience: Audience object.

        Returns:
            A new list of new dicts; added keys override same-named keys
            already present in a topic.
        """
        # The references are identical for every topic, so compute them once.
        refs: Dict[str, Any] = {}
        if subject:
            refs['subject_id'] = subject.get('id')
            products = subject.get('products')
            if isinstance(products, (list, tuple)) and products:
                first_product = products[0]
                if isinstance(first_product, dict):
                    refs['product_id'] = first_product.get('id')
        if style:
            refs['style_id'] = style.get('id')
        if audience:
            refs['audience_id'] = audience.get('id')
        return [{**topic, **refs} for topic in topics]