#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Topic generation engine V2.

- Does not access the database; all data is supplied by the caller
- Uses PromptRegistry to manage prompts
- Unified dependency injection
"""

import json
import logging
import re
from typing import Dict, Any, Optional, List

from .base import BaseAIGCEngine, EngineResult

logger = logging.getLogger(__name__)


class TopicGenerateEngineV2(BaseAIGCEngine):
    """
    Topic generation engine V2.

    Improvements over V1:
    1. Does not access the database; all data is passed in by the caller
    2. Uses PromptRegistry to manage prompts
    3. Receives complete objects instead of IDs
    """

    engine_id = "topic_generate"
    engine_name = "选题生成"
    version = "2.0.0"
    description = "根据景区、产品、风格等信息生成营销选题"

    # Number of topics generated when the caller does not specify one.
    _DEFAULT_NUM_TOPICS = 5

    def __init__(self):
        super().__init__()
        # Created lazily by _get_prompt_registry() on first use.
        self._prompt_registry = None

    def get_param_schema(self) -> Dict[str, Any]:
        """
        Describe the parameters accepted by execute().

        V2.2: scenic_spot and product are merged into subject.

        Returns:
            A mapping of parameter name to its spec
            (type / required / default / desc).
        """
        return {
            # Basic parameters
            "num_topics": {
                "type": "int",
                "required": False,
                "default": self._DEFAULT_NUM_TOPICS,
                "desc": "生成选题数量",
            },
            "month": {
                "type": "str",
                "required": True,
                "desc": "目标日期/月份 (如 '2024-12' 或 '12月5日')",
            },
            # Subject information (scenic spot + products merged)
            "subject": {
                "type": "object",
                "required": False,
                "desc": "主体信息 {id, name, type, description, location, products: [...]}",
            },
            # Legacy fields kept for backward compatibility
            "scenic_spot": {
                "type": "object",
                "required": False,
                "desc": "[兼容] 景区信息,建议使用 subject",
            },
            "product": {
                "type": "object",
                "required": False,
                "desc": "[兼容] 产品信息,建议使用 subject.products",
            },
            "style": {
                "type": "object",
                "required": False,
                "desc": "风格信息对象 {id, name}",
            },
            "audience": {
                "type": "object",
                "required": False,
                "desc": "受众信息对象 {id, name}",
            },
            # Hot-topic information
            "hot_topics": {
                "type": "object",
                "required": False,
                "desc": "热点信息 {events: [], festivals: [], trending: []}",
            },
            # Optional multi-choice lists
            "styles_list": {
                "type": "list",
                "required": False,
                "desc": "可选风格列表 [{id, name}, ...]",
            },
            "audiences_list": {
                "type": "list",
                "required": False,
                "desc": "可选受众列表 [{id, name}, ...]",
            },
            # Prompt version control
            "prompt_version": {
                "type": "str",
                "required": False,
                "default": "latest",
                "desc": "使用的 prompt 版本",
            },
        }

    def _resolve_num_topics(self, params: Dict[str, Any]) -> Any:
        """Return params['num_topics'], falling back to the default when absent or None."""
        num_topics = params.get('num_topics')
        if num_topics is None:
            return self._DEFAULT_NUM_TOPICS
        return num_topics

    def estimate_duration(self, params: Dict[str, Any]) -> int:
        """
        Estimate the execution time for the given parameters.

        Returns:
            15 plus 2 per requested topic.
        """
        return 15 + self._resolve_num_topics(params) * 2

    async def execute(self, params: Dict[str, Any],
                      task_id: Optional[str] = None) -> EngineResult:
        """
        Run topic generation.

        Args:
            params: Parameters containing complete objects
                (see get_param_schema()). An optional 'prompt_context'
                entry, when present, is used as the prompt context instead
                of the locally built one.
            task_id: Task ID used for progress reporting.

        Returns:
            EngineResult with ``data={"topics": [...], "count": n}`` on
            success; on failure ``success=False`` with error_code
            "PARSE_ERROR" (model output could not be parsed) or
            "EXECUTION_ERROR" (any exception raised while running).
        """
        try:
            self.log("开始生成选题 (V2)")
            self.set_progress(task_id, 10)

            # Extract parameters
            num_topics = self._resolve_num_topics(params)
            month = params.get('month', '')

            # Subject information (both new and legacy formats supported)
            subject = params.get('subject')
            scenic_spot = params.get('scenic_spot')
            product = params.get('product')

            # Compatibility: without a subject, build one from scenic_spot + product
            if not subject and scenic_spot:
                subject = {
                    **scenic_spot,
                    'type': scenic_spot.get('type', 'scenic_spot'),
                    'products': [product] if product else [],
                }

            style = params.get('style')
            audience = params.get('audience')
            hot_topics = params.get('hot_topics')
            styles_list = params.get('styles_list', [])
            audiences_list = params.get('audiences_list', [])
            prompt_version = params.get('prompt_version', 'latest')

            self.set_progress(task_id, 20)

            prompt_registry = self._get_prompt_registry()

            # Option A: prefer a prompt_context passed through from the Java side
            prompt_context = params.get('prompt_context')

            if prompt_context:
                self.log("使用 Java 透传的 prompt_context")
                # Shallow copy so the caller's dict is not mutated when the
                # missing fields are filled in below.
                context = dict(prompt_context)
                context.setdefault('num_topics', num_topics)
                context.setdefault('month', month)
                if hot_topics:
                    context.setdefault('hot_topics', hot_topics)
            else:
                self.log("使用本地构建的 context (兼容模式)")
                # Legacy prompt templates expect a single product: use the
                # subject's first product, else the legacy 'product' param.
                subject_products = subject.get('products') if subject else None
                first_product = subject_products[0] if subject_products else product
                context = {
                    'num_topics': num_topics,
                    'month': month,
                    'subject': subject,
                    # Kept for legacy prompt templates
                    'scenic_spot': subject,
                    'product': first_product,
                    'style': style,
                    'audience': audience,
                    'hot_topics': hot_topics,
                    'styles_list': self._format_list(styles_list),
                    'audiences_list': self._format_list(audiences_list),
                }

            # Render the prompt
            system_prompt, user_prompt = prompt_registry.render(
                'topic_generate',
                context=context,
                version=prompt_version
            )

            self.set_progress(task_id, 30)

            # Model parameters come from the prompt configuration
            prompt_config = prompt_registry.get('topic_generate', prompt_version)
            model_params = prompt_config.get_model_params()

            # Call the LLM
            self.log("调用 LLM 生成选题...")
            raw_result, input_tokens, output_tokens, time_cost = await self.llm.generate(
                system_prompt=system_prompt,
                user_prompt=user_prompt,
                **model_params
            )

            self.set_progress(task_id, 70)

            # Parse the result
            topics = self._parse_topics(raw_result)

            if not topics:
                return EngineResult(
                    success=False,
                    error="选题生成失败,无法解析结果",
                    error_code="PARSE_ERROR"
                )

            self.set_progress(task_id, 90)

            # Enrich the topics with references to the original objects
            enhanced_topics = self._enhance_topics(
                topics, subject, style, audience
            )

            self.set_progress(task_id, 100)

            return EngineResult(
                success=True,
                data={
                    "topics": enhanced_topics,
                    "count": len(enhanced_topics),
                },
                metadata={
                    "input_tokens": input_tokens,
                    "output_tokens": output_tokens,
                    "time_cost": time_cost,
                    "prompt_version": prompt_version,
                }
            )

        except Exception as e:
            # Top-level boundary: convert any failure into an EngineResult.
            self.log(f"选题生成异常: {e}", level='error')
            # Keep the traceback available for debugging without changing
            # the error-level output above.
            logger.debug("topic generation traceback", exc_info=True)
            return EngineResult(
                success=False,
                error=str(e),
                error_code="EXECUTION_ERROR"
            )

    def _get_prompt_registry(self):
        """Return the cached PromptRegistry, creating it on first use."""
        # Compare against None explicitly so a registry object that happens
        # to be falsy is not rebuilt on every call.
        if self._prompt_registry is not None:
            return self._prompt_registry

        # Imported at call time, as in the original implementation.
        from domain.prompt import PromptRegistry
        self._prompt_registry = PromptRegistry('prompts')
        return self._prompt_registry

    def _format_list(self, items: List[Dict]) -> str:
        """
        Format a list of ``{name, description}`` dicts as bullet lines.

        Items without a name (and items that are not dicts) are skipped.

        Returns:
            One "- name: description" (or "- name") line per item joined by
            newlines; an empty string when there is nothing to format.
        """
        if not items:
            return ""

        lines = []
        for item in items:
            if not isinstance(item, dict):
                continue
            name = item.get('name', '')
            desc = item.get('description', '')
            if name:
                lines.append(f"- {name}: {desc}" if desc else f"- {name}")

        return "\n".join(lines)

    @staticmethod
    def _coerce_topics(parsed: Any) -> List[Dict[str, Any]]:
        """Normalize a decoded JSON value into a list of topic dicts."""
        # Unwrap an object that carries the list under a "topics" key.
        if isinstance(parsed, dict) and isinstance(parsed.get('topics'), list):
            parsed = parsed['topics']
        if not isinstance(parsed, list):
            return []
        # Downstream code treats every topic as a dict; drop anything else.
        return [entry for entry in parsed if isinstance(entry, dict)]

    def _parse_topics(self, raw_result: str) -> List[Dict[str, Any]]:
        """
        Parse the topics returned by the LLM.

        First tries strict JSON on the outermost ``[...]`` span of the text,
        then falls back to ``json_repair`` (if installed) on the whole text.

        Returns:
            A list of topic dicts, or an empty list when nothing usable
            could be parsed.
        """
        if isinstance(raw_result, str) and raw_result:
            # Try to extract a JSON array (greedy: first '[' to last ']')
            json_match = re.search(r'\[[\s\S]*\]', raw_result)
            if json_match:
                try:
                    topics = self._coerce_topics(json.loads(json_match.group()))
                except json.JSONDecodeError:
                    topics = []
                if topics:
                    return topics

            # Fall back to json_repair, an optional third-party dependency
            try:
                import json_repair
            except ImportError:
                json_repair = None

            if json_repair is not None:
                try:
                    topics = self._coerce_topics(json_repair.loads(raw_result))
                except Exception:
                    # Best-effort fallback: the library's failure modes are
                    # not enumerated here, so record and treat as unparsable.
                    logger.debug("json_repair failed to parse topics", exc_info=True)
                    topics = []
                if topics:
                    return topics

        self.log("无法解析选题结果", level='error')
        return []

    def _enhance_topics(self, topics: List[Dict], subject: Optional[Dict],
                        style: Optional[Dict], audience: Optional[Dict]) -> List[Dict]:
        """
        Attach ID references of the originating objects to each topic.

        Adds 'subject_id', 'product_id' (first product of the subject, if
        any), 'style_id' and 'audience_id' for whichever objects were
        provided. The input topic dicts are not modified.

        Returns:
            New topic dicts with the reference IDs merged in.
        """
        # The references are identical for every topic, so build them once.
        refs: Dict[str, Any] = {}

        if subject:
            refs['subject_id'] = subject.get('id')
            # Take the ID of the first product, if there is one
            products = subject.get('products') or []
            if isinstance(products, (list, tuple)) and products \
                    and isinstance(products[0], dict):
                refs['product_id'] = products[0].get('id')
        if style:
            refs['style_id'] = style.get('id')
        if audience:
            refs['audience_id'] = audience.get('id')

        # Reference IDs override same-named keys coming from the model output.
        return [{**topic, **refs} for topic in topics if isinstance(topic, dict)]