TravelContentCreator/domain/aigc/engines/content_generate_v2.py

352 lines
12 KiB
Python

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
内容生成引擎 V2
- 不访问数据库,接收完整数据
- 使用 PromptRegistry 管理 prompt
- 统一依赖注入
"""
import logging
from typing import Dict, Any, Optional
from .base import BaseAIGCEngine, EngineResult
# Module-level logger named after this module's import path.
logger = logging.getLogger(__name__)
class ContentGenerateEngineV2(BaseAIGCEngine):
    """Content generation engine, V2.

    Design notes carried over from the original author:

    1. The engine never touches the database; all data is supplied by the
       caller.
    2. Prompts are managed through ``PromptRegistry``.
    3. Full objects are received instead of IDs.
    """

    engine_id = "content_generate_v2"
    engine_name = "内容生成 V2"
    version = "2.0.0"
    description = "根据选题信息生成小红书风格的营销文案(新版本,无数据库依赖)"

    def __init__(self):
        """Initialize the base engine; the prompt registry is created on first use."""
        super().__init__()
        self._prompt_registry = None

    def get_param_schema(self) -> Dict[str, Any]:
        """Return the parameter schema accepted by :meth:`execute`.

        Returns:
            A mapping of parameter name to a descriptor dict with ``type``,
            ``required``, ``desc`` and, for some entries, ``default``.
        """
        return {
            # Topic information
            "topic": {
                "type": "object",
                "required": True,
                "desc": "选题信息 {index, date, object, style, targetAudience, ...}",
            },
            # Full objects (passed in by the Java side)
            "scenic_spot": {
                "type": "object",
                "required": False,
                "desc": "景区信息对象 {id, name, description, ...}",
            },
            "product": {
                "type": "object",
                "required": False,
                "desc": "产品信息对象 {id, name, price, description, ...}",
            },
            "style": {
                "type": "object",
                "required": False,
                "desc": "风格信息对象 {id, name, description}",
            },
            "audience": {
                "type": "object",
                "required": False,
                "desc": "受众信息对象 {id, name, description}",
            },
            # Reference content
            "refer_content": {
                "type": "str",
                "required": False,
                "desc": "参考范文内容",
            },
            # Review option
            "enable_judge": {
                "type": "bool",
                "required": False,
                "default": True,
                "desc": "是否启用内容审核",
            },
            # Prompt version control
            "prompt_version": {
                "type": "str",
                "required": False,
                "default": "latest",
                "desc": "使用的 prompt 版本",
            },
        }

    def estimate_duration(self, params: Dict[str, Any]) -> int:
        """Return the estimated execution time.

        Args:
            params: Execution parameters; only ``enable_judge`` is read.

        Returns:
            30 when judging is enabled (the default), otherwise 20.
            The unit is presumably seconds -- TODO confirm against the base
            class contract.
        """
        enable_judge = params.get('enable_judge', True)
        return 30 if enable_judge else 20

    async def execute(self, params: Dict[str, Any], task_id: Optional[str] = None) -> EngineResult:
        """Generate content for a topic and optionally run it through the judge.

        Steps: build the prompt context from the supplied objects, render the
        ``content_generate`` prompt, call the LLM, parse the reply, and, when
        ``enable_judge`` is truthy, run the ``content_judge`` prompt and take
        its title/content over the generated ones if it succeeded.

        Args:
            params: Parameters described by :meth:`get_param_schema`.
            task_id: Identifier forwarded to ``set_progress``.

        Returns:
            An ``EngineResult``. On success ``data`` holds ``content``,
            ``original_content``, ``topic``, ``judged`` and
            ``judge_analysis``; ``metadata`` holds the token counts and time
            cost of the generation call plus the prompt version. On failure
            ``error_code`` is ``PARSE_ERROR`` (reply could not be parsed) or
            ``EXECUTION_ERROR`` (any exception raised along the way).
        """
        try:
            self.log("开始生成内容 (V2)")
            self.set_progress(task_id, 10)

            # Extract parameters. ``or`` (rather than a .get default) also
            # covers keys that are present but explicitly null.
            topic = params.get('topic') or {}
            scenic_spot = params.get('scenic_spot')
            product = params.get('product')
            style = params.get('style')
            audience = params.get('audience')
            refer_content = params.get('refer_content') or ''
            enable_judge = params.get('enable_judge', True)
            prompt_version = params.get('prompt_version') or 'latest'
            self.set_progress(task_id, 20)

            prompt_registry = self._get_prompt_registry()

            # Build the prompt context
            context = {
                'style_content': self._format_style(style, topic),
                'demand_content': self._format_audience(audience, topic),
                'object_content': self._format_scenic_spot(scenic_spot, topic),
                'product_content': self._format_product(product),
                'refer_content': refer_content,
            }

            # Render the prompt
            system_prompt, user_prompt = prompt_registry.render(
                'content_generate',
                context=context,
                version=prompt_version
            )
            self.set_progress(task_id, 30)

            # Model parameters come from the prompt configuration
            prompt_config = prompt_registry.get('content_generate', prompt_version)
            model_params = prompt_config.get_model_params()

            # Call the LLM to generate the content
            self.log("调用 LLM 生成内容...")
            raw_result, input_tokens, output_tokens, time_cost = await self.llm.generate(
                system_prompt=system_prompt,
                user_prompt=user_prompt,
                **model_params
            )
            self.set_progress(task_id, 60)

            # Parse the reply
            content = self._parse_content(raw_result)
            if not content:
                return EngineResult(
                    success=False,
                    error="内容生成失败,无法解析结果",
                    error_code="PARSE_ERROR"
                )

            # Content review
            final_content = content
            judge_result = None
            if enable_judge:
                self.set_progress(task_id, 70)
                self.log("执行内容审核...")
                judge_result = await self._judge_content(
                    content,
                    scenic_spot,
                    product,
                    prompt_registry,
                    prompt_version
                )
                if judge_result and judge_result.get('success'):
                    # The judge may rewrite title/content; tags always come
                    # from the generated content.
                    final_content = {
                        'title': judge_result.get('title', content.get('title')),
                        'content': judge_result.get('content', content.get('content')),
                        'tag': content.get('tag', ''),
                    }

            self.set_progress(task_id, 100)
            return EngineResult(
                success=True,
                data={
                    "content": final_content,
                    "original_content": content,
                    "topic": topic,
                    # Always a real bool, even if enable_judge was e.g. 0 or None.
                    "judged": bool(enable_judge) and judge_result is not None,
                    "judge_analysis": judge_result.get('analysis') if judge_result else None,
                },
                metadata={
                    "input_tokens": input_tokens,
                    "output_tokens": output_tokens,
                    "time_cost": time_cost,
                    "prompt_version": prompt_version,
                }
            )
        except Exception as e:
            # Top-level boundary: report the failure as a result object
            # instead of propagating it to the caller.
            self.log(f"内容生成异常: {e}", level='error')
            return EngineResult(
                success=False,
                error=str(e),
                error_code="EXECUTION_ERROR"
            )

    def _get_prompt_registry(self):
        """Return the cached ``PromptRegistry``, creating it on first use."""
        if self._prompt_registry is None:
            # Imported at first use rather than at module import time
            # (presumably to avoid an import cycle or start-up cost).
            from domain.prompt import PromptRegistry
            self._prompt_registry = PromptRegistry('prompts')
        return self._prompt_registry

    @staticmethod
    def _first_non_empty(mapping: Optional[Dict], *keys: str) -> str:
        """Return the first truthy value found under ``keys`` as a string.

        Keys that are missing, null or empty are skipped, so an alternate key
        is still consulted when the primary one is present but null.
        Returns ``''`` when ``mapping`` is empty/None or no key has a value.
        """
        if not mapping:
            return ''
        for key in keys:
            value = mapping.get(key)
            if value:
                return str(value)
        return ''

    def _format_style(self, style: Optional[Dict], topic: Dict) -> str:
        """Format the style for the prompt.

        Uses the ``style`` object (``name``/``styleName`` plus optional
        ``description``) when given; otherwise falls back to
        ``topic['style']``.
        """
        if style:
            name = self._first_non_empty(style, 'name', 'styleName')
            desc = self._first_non_empty(style, 'description')
            return "\n".join(part for part in (name, desc) if part)
        # Fall back to the topic
        return self._first_non_empty(topic, 'style')

    def _format_audience(self, audience: Optional[Dict], topic: Dict) -> str:
        """Format the target audience for the prompt.

        Uses the ``audience`` object (``name``/``audienceName`` plus optional
        ``description``) when given; otherwise falls back to
        ``topic['targetAudience']``.
        """
        if audience:
            name = self._first_non_empty(audience, 'name', 'audienceName')
            desc = self._first_non_empty(audience, 'description')
            return "\n".join(part for part in (name, desc) if part)
        return self._first_non_empty(topic, 'targetAudience')

    def _format_scenic_spot(self, scenic_spot: Optional[Dict], topic: Dict) -> str:
        """Format the scenic spot for the prompt.

        Emits name, location and description on separate lines (empty ones
        are omitted) when ``scenic_spot`` is given; otherwise falls back to
        ``topic['object']``.
        """
        if scenic_spot:
            name = self._first_non_empty(scenic_spot, 'name')
            location = self._first_non_empty(scenic_spot, 'location')
            desc = self._first_non_empty(scenic_spot, 'description')
            parts = [name] if name else []
            if location:
                parts.append(f"位置: {location}")
            if desc:
                parts.append(desc)
            return "\n".join(parts)
        return self._first_non_empty(topic, 'object')

    def _format_product(self, product: Optional[Dict]) -> str:
        """Format the product for the prompt.

        Emits name (``name``/``productName``), price and description
        (``description``/``detailedDescription``) on separate lines, omitting
        empty ones. Returns ``''`` when no product is given.
        """
        if not product:
            return ""
        name = self._first_non_empty(product, 'name', 'productName')
        price = product.get('price', '')
        desc = self._first_non_empty(product, 'description', 'detailedDescription')
        parts = [name] if name else []
        if price:
            parts.append(f"价格: {price}")
        if desc:
            parts.append(desc)
        return "\n".join(parts)

    def _parse_content(self, raw_result: str) -> Optional[Dict[str, Any]]:
        """Parse the LLM reply into a dict.

        First tries ``json.loads`` on the span from the first ``{`` to the
        last ``}``; if that fails, tries the optional ``json_repair`` package
        on the whole reply. Only JSON objects are accepted, so callers can
        rely on ``.get``.

        Returns:
            The parsed dict, or ``None`` (after logging an error) when the
            reply is not a non-empty string or no JSON object can be
            recovered from it.
        """
        import json
        import re

        if isinstance(raw_result, str) and raw_result.strip():
            # Try to extract a JSON object embedded in the reply
            json_match = re.search(r'\{[\s\S]*\}', raw_result)
            if json_match:
                try:
                    parsed = json.loads(json_match.group())
                except json.JSONDecodeError:
                    parsed = None
                if isinstance(parsed, dict):
                    return parsed

            # Best-effort fallback: json_repair may be missing or may fail on
            # this input; either way we simply report a parse failure below.
            try:
                import json_repair
                repaired = json_repair.loads(raw_result)
            except Exception:
                repaired = None
            if isinstance(repaired, dict):
                return repaired

        self.log("无法解析内容结果", level='error')
        return None

    async def _judge_content(self, content: Dict,
                             scenic_spot: Optional[Dict],
                             product: Optional[Dict],
                             prompt_registry,
                             prompt_version: str) -> Optional[Dict]:
        """Run the ``content_judge`` prompt over the generated content.

        Args:
            content: Generated content; ``title`` and ``content`` are judged.
            scenic_spot: Optional scenic spot used as reference material.
            product: Optional product used as reference material.
            prompt_registry: Registry used to render the prompt and fetch the
                model parameters.
            prompt_version: Prompt version to render.

        Returns:
            The parsed judge reply with ``success`` set to ``True``, or
            ``None`` when the reply is empty/unparseable or any step raises
            (judging is best-effort and never fails the generation).
        """
        try:
            # Build the reference material, honouring the same alternate
            # keys that _format_product accepts.
            product_info_parts = []
            if scenic_spot:
                spot_name = self._first_non_empty(scenic_spot, 'name')
                product_info_parts.append(f"景区: {spot_name}")
                spot_desc = self._first_non_empty(scenic_spot, 'description')
                if spot_desc:
                    product_info_parts.append(spot_desc)
            if product:
                product_name = self._first_non_empty(product, 'name', 'productName')
                product_info_parts.append(f"产品: {product_name}")
                price = product.get('price')
                if price:
                    product_info_parts.append(f"价格: {price}")
                product_desc = self._first_non_empty(
                    product, 'description', 'detailedDescription'
                )
                if product_desc:
                    product_info_parts.append(product_desc)
            product_info = "\n".join(product_info_parts)

            # Render the judge prompt
            context = {
                'product_info': product_info,
                'title_to_judge': content.get('title', ''),
                'content_to_judge': content.get('content', ''),
            }
            system_prompt, user_prompt = prompt_registry.render(
                'content_judge',
                context=context,
                version=prompt_version
            )

            # Model parameters come from the prompt configuration
            judge_config = prompt_registry.get('content_judge', prompt_version)
            model_params = judge_config.get_model_params()

            # Call the LLM
            raw_result, _, _, _ = await self.llm.generate(
                system_prompt=system_prompt,
                user_prompt=user_prompt,
                **model_params
            )

            # Parse the reply; an empty or unparseable reply counts as
            # "no judge result".
            result = self._parse_content(raw_result)
            if not result:
                return None
            result['success'] = True
            return result
        except Exception as e:
            self.log(f"内容审核失败: {e}", level='warning')
            return None