TravelContentCreator/domain/aigc/engines/content_generate_v2.py

421 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
内容生成引擎 V2
- 不访问数据库,接收完整数据
- 使用 PromptRegistry 管理 prompt
- 统一依赖注入
"""
import logging
from typing import Dict, Any, Optional, List
from .base import BaseAIGCEngine, EngineResult
logger = logging.getLogger(__name__)
class ContentGenerateEngineV2(BaseAIGCEngine):
    """
    Content generation engine, V2.

    Improvements:
    1. Does not access the database; all data is passed in by the caller.
    2. Uses PromptRegistry to manage prompts.
    3. Receives complete objects rather than IDs.
    """

    # Engine metadata exposed as class-level attributes.
    engine_id = "content_generate"
    engine_name = "内容生成"
    version = "2.1.0"
    description = "根据选题信息生成小红书风格的营销文案"

    def __init__(self) -> None:
        """Initialise the base engine and leave both cached dependencies unset."""
        super().__init__()
        # Both start as None and are populated on first use by
        # _get_prompt_registry() / _get_reference_manager().
        self._prompt_registry = None
        self._reference_manager = None
def get_param_schema(self) -> Dict[str, Any]:
    """
    Describe the parameters this engine accepts.

    V2.2: scenic_spot and product are merged into subject.

    Returns:
        Mapping of parameter name to a spec dict with the keys
        ``type``, ``required``, optionally ``default``, and ``desc``.
    """

    def optional(kind: str, desc: str, **extra: Any) -> Dict[str, Any]:
        # Keeps the key order of a hand-written entry:
        # type, required, [default], desc.
        return {"type": kind, "required": False, **extra, "desc": desc}

    schema: Dict[str, Any] = {}

    # Topic information (the only required parameter).
    schema["topic"] = {
        "type": "object",
        "required": True,
        "desc": "选题信息 {index, date, title, subject_name, product_name, style, audience, ...}",
    }

    # Subject information (scenic spot + product merged).
    schema["subject"] = optional(
        "object",
        "主体信息 {id, name, type, description, location, products: [...]}",
    )

    # Legacy fields kept for backward compatibility.
    schema["scenic_spot"] = optional("object", "[兼容] 景区信息,建议使用 subject")
    schema["product"] = optional("object", "[兼容] 产品信息,建议使用 subject.products")

    schema["style"] = optional("object", "风格信息对象 {id, name}")
    schema["audience"] = optional("object", "受众信息对象 {id, name}")

    # Hot-topic information.
    schema["hot_topics"] = optional(
        "object",
        "热点信息 {events: [], festivals: [], trending: []}",
    )

    # Reference content.
    schema["reference"] = optional(
        "object",
        "参考内容 {mode: 'none'/'reference'/'rewrite', title, content}. reference=参考风格原创内容, rewrite=保留框架换主体",
    )

    # Review option.
    schema["enable_judge"] = optional("bool", "是否启用内容审核", default=True)

    # Prompt version control.
    schema["prompt_version"] = optional("str", "使用的 prompt 版本", default="latest")

    return schema
def estimate_duration(self, params: Dict[str, Any]) -> int:
    """
    Estimate the execution time for the given parameters.

    Returns 30 when ``enable_judge`` is truthy (the default when the key
    is absent) and 20 otherwise.
    """
    if params.get('enable_judge', True):
        return 30
    return 20
async def execute(self, params: Dict[str, Any], task_id: Optional[str] = None) -> EngineResult:
    """
    Generate content for a topic and optionally run it through the judge.

    Args:
        params: Engine parameters; see get_param_schema() for the keys.
        task_id: Identifier forwarded to set_progress(); may be None.

    Returns:
        EngineResult. On success ``data`` holds ``content``,
        ``original_content``, ``topic``, ``judged`` and ``judge_analysis``,
        and ``metadata`` holds token counts, time cost and prompt version.
        On failure ``error_code`` is ``"PARSE_ERROR"`` when the LLM output
        cannot be parsed into a dict, or ``"EXECUTION_ERROR"`` for any
        exception raised along the way (exceptions are not propagated).
    """
    try:
        self.log("开始生成内容 (V2)")
        self.set_progress(task_id, 10)

        # Extract parameters. ``or {}`` also covers an explicit topic=None,
        # which would otherwise crash the formatters' topic.get() calls.
        topic = params.get('topic') or {}

        # Subject information (new and legacy formats are both supported).
        subject = params.get('subject')
        scenic_spot = params.get('scenic_spot')
        product = params.get('product')

        # Compatibility: without a subject, build one from scenic_spot + product.
        if not subject and scenic_spot:
            subject = {
                **scenic_spot,
                'type': scenic_spot.get('type', 'scenic_spot'),
                'products': [product] if product else [],
            }

        style = params.get('style')
        audience = params.get('audience')
        hot_topics = params.get('hot_topics')
        reference = params.get('reference')
        enable_judge = params.get('enable_judge', True)
        prompt_version = params.get('prompt_version', 'latest')
        self.set_progress(task_id, 20)

        prompt_registry = self._get_prompt_registry()

        # Take the first product of the subject; fall back to the legacy
        # ``product`` parameter so it is not dropped when the subject
        # carries no products (or no subject/scenic_spot was given).
        subject_products = subject.get('products') if subject else None
        current_product = subject_products[0] if subject_products else product

        # Without user-specified reference content, use the built-in library.
        use_builtin_examples = not reference or reference.get('mode') == 'none'
        title_examples = None
        content_examples = None
        if use_builtin_examples:
            # Pass the count by keyword: the first positional parameter of
            # the example getters is audience_id, not count. The ids come
            # from the {id, name} style/audience objects when supplied.
            audience_id = audience.get('id') if audience else None
            style_id = style.get('id') if style else None
            title_examples = self._get_title_examples(
                audience_id=audience_id, style_id=style_id, count=20
            )
            content_examples = self._get_content_examples(
                audience_id=audience_id, style_id=style_id, count=3
            )

        # Build the prompt context.
        context = {
            'style_content': self._format_style(style, topic),
            'demand_content': self._format_audience(audience, topic),
            'object_content': self._format_subject(subject, topic),
            'product_content': self._format_product(current_product),
            'hot_topics': hot_topics,
            'reference': reference,
            # Built-in references (only when no user reference is given).
            'title_examples': title_examples,
            'content_examples': content_examples,
        }

        # Render the prompt.
        system_prompt, user_prompt = prompt_registry.render(
            'content_generate',
            context=context,
            version=prompt_version
        )
        self.set_progress(task_id, 30)

        # Fetch the model parameters attached to this prompt version.
        prompt_config = prompt_registry.get('content_generate', prompt_version)
        model_params = prompt_config.get_model_params()

        # Call the LLM to generate the content.
        self.log("调用 LLM 生成内容...")
        raw_result, input_tokens, output_tokens, time_cost = await self.llm.generate(
            system_prompt=system_prompt,
            user_prompt=user_prompt,
            **model_params
        )
        self.set_progress(task_id, 60)

        # Parse the result; anything other than a non-empty dict is unusable
        # because the code below reads it with .get().
        content = self._parse_content(raw_result)
        if not content or not isinstance(content, dict):
            return EngineResult(
                success=False,
                error="内容生成失败,无法解析结果",
                error_code="PARSE_ERROR"
            )

        # Content review.
        final_content = content
        judge_result = None
        if enable_judge:
            self.set_progress(task_id, 70)
            self.log("执行内容审核...")
            # Review against the same subject/product the content was
            # generated from, so the new ``subject`` format is covered too.
            judge_result = await self._judge_content(
                content,
                subject or scenic_spot,
                current_product,
                prompt_registry,
                prompt_version
            )
            if judge_result and judge_result.get('success'):
                # Keep the generated text when the judge returns an empty
                # or missing title/content.
                final_content = {
                    'title': judge_result.get('title') or content.get('title'),
                    'content': judge_result.get('content') or content.get('content'),
                    'tag': content.get('tag', ''),
                }

        self.set_progress(task_id, 100)
        return EngineResult(
            success=True,
            data={
                "content": final_content,
                "original_content": content,
                "topic": topic,
                "judged": bool(enable_judge) and judge_result is not None,
                "judge_analysis": judge_result.get('analysis') if judge_result else None,
            },
            metadata={
                "input_tokens": input_tokens,
                "output_tokens": output_tokens,
                "time_cost": time_cost,
                "prompt_version": prompt_version,
            }
        )
    except Exception as e:
        # Top-level boundary: report the failure as a result, do not raise.
        self.log(f"内容生成异常: {e}", level='error')
        return EngineResult(
            success=False,
            error=str(e),
            error_code="EXECUTION_ERROR"
        )
def _get_prompt_registry(self):
"""获取 PromptRegistry"""
if self._prompt_registry:
return self._prompt_registry
from domain.prompt import PromptRegistry
self._prompt_registry = PromptRegistry('prompts')
return self._prompt_registry
def _get_reference_manager(self):
"""获取 ReferenceManager"""
if self._reference_manager is None:
from domain.prompt.reference_manager import get_reference_manager
self._reference_manager = get_reference_manager()
return self._reference_manager
def _get_title_examples(self, audience_id: str = None, style_id: str = None, count: int = 20) -> List[str]:
"""获取标题参考格式 (智能匹配 + 随机抽取)"""
manager = self._get_reference_manager()
return manager.get_titles(audience_id=audience_id, style_id=style_id, count=count)
def _get_content_examples(self, audience_id: str = None, style_id: str = None, count: int = 3) -> List[str]:
"""获取正文范文参考 (智能匹配 + 随机抽取)"""
manager = self._get_reference_manager()
return manager.get_contents(audience_id=audience_id, style_id=style_id, count=count)
def _format_style(self, style: Optional[Dict], topic: Dict) -> str:
"""格式化风格信息"""
if style:
name = style.get('name', style.get('styleName', ''))
desc = style.get('description', '')
return f"{name}\n{desc}" if desc else name
# 从 topic 中获取
return topic.get('style', '')
def _format_audience(self, audience: Optional[Dict], topic: Dict) -> str:
"""格式化受众信息"""
if audience:
name = audience.get('name', audience.get('audienceName', ''))
desc = audience.get('description', '')
return f"{name}\n{desc}" if desc else name
return topic.get('targetAudience', '')
def _format_subject(self, subject: Optional[Dict], topic: Dict) -> str:
"""格式化主体信息 (景区/酒店等)"""
if subject:
name = subject.get('name', '')
desc = subject.get('description', '')
location = subject.get('location', '')
highlights = subject.get('highlights', [])
parts = [f"名称: {name}"]
if location:
parts.append(f"位置: {location}")
if highlights:
parts.append(f"亮点: {', '.join(highlights)}")
if desc:
parts.append(f"描述: {desc}")
return "\n".join(parts)
# 兼容旧字段
return topic.get('object', topic.get('subject_name', ''))
def _format_scenic_spot(self, scenic_spot: Optional[Dict], topic: Dict) -> str:
"""[兼容] 格式化景区信息"""
return self._format_subject(scenic_spot, topic)
def _format_product(self, product: Optional[Dict]) -> str:
"""格式化产品信息"""
if not product:
return ""
name = product.get('name', product.get('productName', ''))
price = product.get('price', '')
desc = product.get('description', product.get('detailedDescription', ''))
parts = [name]
if price:
parts.append(f"价格: {price}")
if desc:
parts.append(desc)
return "\n".join(parts)
def _parse_content(self, raw_result: str) -> Optional[Dict[str, Any]]:
"""解析 LLM 返回的内容"""
import json
import re
# 尝试提取 JSON
json_match = re.search(r'\{[\s\S]*\}', raw_result)
if json_match:
try:
return json.loads(json_match.group())
except json.JSONDecodeError:
pass
# 尝试 json_repair
try:
import json_repair
return json_repair.loads(raw_result)
except:
pass
self.log("无法解析内容结果", level='error')
return None
async def _judge_content(self, content: Dict,
                         scenic_spot: Optional[Dict],
                         product: Optional[Dict],
                         prompt_registry,
                         prompt_version: str) -> Optional[Dict]:
    """
    Run the content review.

    Builds a product-info text from the scenic spot and product, renders
    the 'content_judge' prompt, calls the LLM and parses its answer.

    Returns:
        The parsed judge result with ``success`` set to True when parsing
        yields a non-empty result; otherwise whatever the parser returned.
        Returns None when any exception occurs (logged as a warning).
    """
    try:
        # Assemble the product material the judge checks the text against.
        info_lines = []
        if scenic_spot:
            info_lines.append(f"景区: {scenic_spot.get('name', '')}")
            spot_desc = scenic_spot.get('description')
            if spot_desc:
                info_lines.append(spot_desc)
        if product:
            info_lines.append(f"产品: {product.get('name', '')}")
            price = product.get('price')
            if price:
                info_lines.append(f"价格: {price}")
            product_desc = product.get('description')
            if product_desc:
                info_lines.append(product_desc)

        # Render the review prompt.
        judge_context = {
            'product_info': "\n".join(info_lines),
            'title_to_judge': content.get('title', ''),
            'content_to_judge': content.get('content', ''),
        }
        system_prompt, user_prompt = prompt_registry.render(
            'content_judge',
            context=judge_context,
            version=prompt_version
        )

        # Model parameters attached to the judge prompt.
        model_params = prompt_registry.get('content_judge', prompt_version).get_model_params()

        # Call the LLM; only the text is used, the other three values are dropped.
        raw_result, _, _, _ = await self.llm.generate(
            system_prompt=system_prompt,
            user_prompt=user_prompt,
            **model_params
        )

        # Parse the result and flag it as successful when non-empty.
        parsed = self._parse_content(raw_result)
        if parsed:
            parsed['success'] = True
        return parsed
    except Exception as e:
        self.log(f"内容审核失败: {e}", level='warning')
        return None