495 lines
18 KiB
Python
Raw Normal View History

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
智能海报生成引擎
AI 生成文案 + poster_v2 生成海报
"""
import base64
import json
import logging
import os
import re
from typing import Dict, Any, Optional, List, Tuple

import aiohttp

from .base import BaseAIGCEngine, EngineResult
# Module-level logger.
# NOTE(review): the engine class in this file logs through self.log(...)
# (presumably provided by BaseAIGCEngine), so in the code visible here this
# logger is not used — confirm before removing.
logger = logging.getLogger(__name__)
class PosterSmartEngine(BaseAIGCEngine):
    """
    Smart poster generation engine.

    Pipeline:
    1. Receive the raw product / attraction information.
    2. Let the AI write Xiaohongshu-style poster copy (title, subtitle,
       highlights, tags, ...) plus a suggested layout and theme.
    3. Render the final poster with poster_v2.

    Features:
    - Automatically produces an eye-catching title, subtitle and highlights.
    - Recommends a layout and theme; the caller may override either.
    - Supports several content categories (attraction / food / hotel /
      event / guide / homestay).
    - Falls back to a template-based generator when no AI agent is available,
      the AI reply cannot be parsed, or the AI call raises.
    """

    engine_id = "poster_smart"
    engine_name = "智能海报生成"
    version = "1.0.0"
    description = "AI生成文案+自动生成小红书风格海报"

    # Layout/theme used when neither the caller nor the copy provides one.
    _DEFAULT_LAYOUT = "hero_bottom"
    _DEFAULT_THEME = "ocean"

    def __init__(self):
        super().__init__()
        # Both collaborators are created lazily on first use
        # (see _get_poster_service / _get_ai_agent).
        self._poster_service = None
        self._ai_agent = None

    def get_param_schema(self) -> Dict[str, Any]:
        """Describe the accepted parameters (type, required flag, description)."""
        return {
            # === Required: product information ===
            "category": {
                "type": "str",
                "required": True,
                "desc": "内容类型: 景点/美食/酒店/活动/攻略/民宿",
            },
            "name": {
                "type": "str",
                "required": True,
                "desc": "产品/景点名称",
            },
            "description": {
                "type": "str",
                "required": False,
                "desc": "详细描述",
            },
            # === Optional: supplementary information ===
            "price": {
                "type": "str",
                "required": False,
                "desc": "价格信息 (如 199元/人)",
            },
            "location": {
                "type": "str",
                "required": False,
                "desc": "位置/地点",
            },
            "features": {
                "type": "str",
                "required": False,
                "desc": "特色/卖点 (逗号分隔或描述文本)",
            },
            "target_audience": {
                "type": "str",
                "required": False,
                "desc": "目标人群 (如 亲子家庭、情侣、闺蜜)",
            },
            "style_hint": {
                "type": "str",
                "required": False,
                "desc": "风格提示 (如 清新、活力、高级感)",
            },
            # === Background image (any one of the three) ===
            "image_url": {
                "type": "str",
                "required": False,
                "desc": "背景图片 URL",
            },
            "image_path": {
                "type": "str",
                "required": False,
                "desc": "背景图片本地路径",
            },
            "image_base64": {
                "type": "str",
                "required": False,
                "desc": "背景图片 Base64",
            },
            # === Overrides ===
            "override_layout": {
                "type": "str",
                "required": False,
                "desc": "强制指定布局 (覆盖AI推荐)",
            },
            "override_theme": {
                "type": "str",
                "required": False,
                "desc": "强制指定主题 (覆盖AI推荐)",
            },
            # === Advanced ===
            "skip_ai": {
                "type": "bool",
                "required": False,
                "default": False,
                "desc": "跳过AI生成直接使用提供的内容",
            },
            "raw_content": {
                "type": "object",
                "required": False,
                "desc": "直接使用的海报内容 (skip_ai=true 时)",
            },
        }

    def estimate_duration(self, params: Dict[str, Any]) -> int:
        """
        Estimate how long ``execute`` will take (presumably in seconds).

        The AI step is skipped only when ``skip_ai`` is set AND ``raw_content``
        is non-empty, mirroring the rule used in ``execute``. Downloading an
        ``image_url`` adds extra time.
        """
        skip_ai = bool(params.get('skip_ai', False) and params.get('raw_content'))
        has_image = bool(params.get('image_url'))
        if skip_ai:
            return 10 if has_image else 5
        return 25 if has_image else 20  # AI generation takes longer

    async def execute(self, params: Dict[str, Any], task_id: Optional[str] = None) -> EngineResult:
        """
        Run the smart poster pipeline.

        Args:
            params: Parameters as described by ``get_param_schema``.
            task_id: Task id used for progress reporting via ``set_progress``.

        Returns:
            EngineResult. On success ``data`` holds ``image_base64``,
            ``images_base64`` (one-element list), the final ``layout`` /
            ``theme`` and the rendered ``generated_content``. On failure
            ``error_code`` is MISSING_NAME, AI_GENERATION_FAILED,
            POSTER_GENERATION_FAILED or EXECUTION_ERROR. Any ``Exception``
            raised along the way is converted into EXECUTION_ERROR.
        """
        try:
            self.log("开始智能海报生成")
            self.set_progress(task_id, 5)

            # Extract parameters
            category = params.get('category', '景点')
            name = params.get('name')
            description = params.get('description', '')
            price = params.get('price')
            location = params.get('location')
            features = params.get('features')
            target_audience = params.get('target_audience')
            style_hint = params.get('style_hint')
            image_url = params.get('image_url')
            image_path = params.get('image_path')
            image_base64 = params.get('image_base64')
            override_layout = params.get('override_layout')
            override_theme = params.get('override_theme')
            skip_ai = params.get('skip_ai', False)
            raw_content = params.get('raw_content', {})

            if not name:
                return EngineResult(
                    success=False,
                    error="缺少产品/景点名称",
                    error_code="MISSING_NAME"
                )

            self.set_progress(task_id, 10)

            # === Step 1: obtain the poster copy ===
            # skip_ai only takes effect when raw_content is actually supplied;
            # otherwise the AI path runs (and is reported as such below).
            use_raw_content = bool(skip_ai and raw_content)
            if use_raw_content:
                self.log("跳过AI生成使用提供的内容")
                # Shallow copy so the pops below never mutate the caller's dict.
                poster_content = dict(raw_content)
            else:
                self.log("AI 生成海报文案...")
                ai_result = await self._generate_copywriting(
                    category=category,
                    name=name,
                    description=description,
                    price=price,
                    location=location,
                    features=features,
                    target_audience=target_audience,
                    style_hint=style_hint,
                )
                if not ai_result.get('success'):
                    return EngineResult(
                        success=False,
                        error=ai_result.get('error', 'AI文案生成失败'),
                        error_code="AI_GENERATION_FAILED"
                    )
                poster_content = ai_result['content']

            # Layout/theme hints are rendering configuration, not poster text:
            # strip them from the content in BOTH branches.
            suggested_layout, suggested_theme = self._pop_suggestions(poster_content)
            if not use_raw_content:
                self.log(f"AI 推荐: 布局={suggested_layout}, 主题={suggested_theme}")

            self.set_progress(task_id, 50)

            # === Step 2: resolve layout and theme (explicit overrides win) ===
            final_layout = override_layout or suggested_layout
            final_theme = override_theme or suggested_theme
            self.log(f"最终配置: 布局={final_layout}, 主题={final_theme}")

            # === Step 3: download the image when only a URL was given ===
            final_image_base64 = image_base64
            if image_url and not final_image_base64:
                self.log(f"下载图片: {image_url[:50]}...")
                # None on failure; rendering then proceeds with image_path only.
                final_image_base64 = await self._download_image(image_url)

            self.set_progress(task_id, 70)

            # === Step 4: render the poster ===
            self.log("生成海报图片...")
            poster_service = self._get_poster_service()
            result = poster_service.generate(
                layout=final_layout,
                theme=final_theme,
                content=poster_content,
                image_base64=final_image_base64,
                image_path=image_path,
            )

            self.set_progress(task_id, 95)

            if not result.get('success'):
                return EngineResult(
                    success=False,
                    error=result.get('error', '海报生成失败'),
                    error_code="POSTER_GENERATION_FAILED"
                )

            self.set_progress(task_id, 100)
            return EngineResult(
                success=True,
                data={
                    "image_base64": result['image_base64'],
                    "images_base64": [result['image_base64']],
                    "layout": final_layout,
                    "theme": final_theme,
                    "generated_content": poster_content,  # copy that was rendered
                },
                metadata={
                    "layout": final_layout,
                    "theme": final_theme,
                    # True whenever the AI path ran (including its template fallback).
                    "ai_generated": not use_raw_content,
                    "category": category,
                    "name": name,
                }
            )

        except Exception as e:
            self.log(f"智能海报生成异常: {e}", level='error')
            return EngineResult(
                success=False,
                error=str(e),
                error_code="EXECUTION_ERROR"
            )

    def _pop_suggestions(self, content: Dict[str, Any]) -> Tuple[str, str]:
        """
        Remove ``suggested_layout`` / ``suggested_theme`` from *content* in place.

        Returns:
            ``(layout, theme)``. The class defaults are used when a key is
            missing or holds an empty value (e.g. the AI returned ``null``).
        """
        layout = content.pop('suggested_layout', None) or self._DEFAULT_LAYOUT
        theme = content.pop('suggested_theme', None) or self._DEFAULT_THEME
        return layout, theme

    async def _generate_copywriting(
        self,
        category: str,
        name: str,
        description: str = "",
        price: str = None,
        location: str = None,
        features: str = None,
        target_audience: str = None,
        style_hint: str = None,
    ) -> Dict[str, Any]:
        """
        Generate poster copy with the AI agent.

        Falls back to ``_fallback_generate`` when no agent is available, when
        the reply contains no JSON object, or when any step raises.

        Returns:
            ``{"success": True, "content": {...}}``. ``content`` may contain
            ``suggested_layout`` / ``suggested_theme`` hints.
        """
        try:
            ai_agent = self._get_ai_agent()
            if not ai_agent:
                # No AI agent available: use the template generator.
                return self._fallback_generate(
                    category, name, description, price, location, features
                )

            # Build the prompt
            from domain.prompt import PromptRegistry

            registry = PromptRegistry('prompts')
            prompt_config = registry.get('poster_copywriting')

            # Plain placeholder substitution (not str.format / Jinja rendering),
            # because the user template may itself contain Jinja syntax.
            user_prompt = self._fill_placeholders(prompt_config.user, {
                "category": category,
                "name": name,
                "description": description or "无详细描述",
                "price": price or "",
                "location": location or "",
                "features": features or "",
                "target_audience": target_audience or "",
                "style_hint": style_hint or "",
            })

            # generate_text returns a 4-tuple; only the text is needed here.
            content_text, _, _, _ = await ai_agent.generate_text(
                system_prompt=prompt_config.system,
                user_prompt=user_prompt,
                temperature=0.7,
                use_stream=False,  # qwen-plus does not need streaming
            )

            json_content = self._extract_json(content_text)
            if json_content:
                return {
                    "success": True,
                    "content": json_content
                }

            self.log("AI 返回格式异常,使用备用生成", level='warning')
            return self._fallback_generate(
                category, name, description, price, location, features
            )

        except Exception as e:
            self.log(f"AI 文案生成失败: {e}, 使用备用生成", level='warning')
            return self._fallback_generate(
                category, name, description, price, location, features
            )

    @staticmethod
    def _fill_placeholders(template: str, values: Dict[str, str]) -> str:
        """
        Replace every ``{key}`` token in *template* with ``values[key]``.

        Substitution happens in a single pass, so placeholder-looking text that
        appears inside a substituted value (e.g. a description containing
        "{price}") is not substituted again. Only the exact ``{key}`` tokens
        for keys in *values* are replaced; all other braces are left as-is.
        """
        if not values:
            return template
        pattern = re.compile(r"\{(" + "|".join(map(re.escape, values)) + r")\}")
        return pattern.sub(lambda match: values[match.group(1)], template)

    def _fallback_generate(
        self,
        category: str,
        name: str,
        description: str = "",
        price: str = None,
        location: str = None,
        features: str = None,
    ) -> Dict[str, Any]:
        """
        Template-based copy generation used when the AI is unavailable.

        Returns:
            ``{"success": True, "content": {...}}``. ``content`` always has
            title, subtitle, tags and the layout/theme suggestions chosen from
            *category*. It also has highlights/details/price/price_suffix when
            those could be derived (None-valued keys are dropped).
        """
        # Recommended (layout, theme) per content category
        layout_map = {
            "景点": ("hero_bottom", "ocean"),
            "美食": ("overlay_bottom", "peach"),
            "酒店": ("card_float", "mint"),
            "民宿": ("split_vertical", "latte"),
            "活动": ("overlay_center", "sunset"),
            "攻略": ("overlay_center", "ocean"),
        }
        suggested_layout, suggested_theme = layout_map.get(
            category, (self._DEFAULT_LAYOUT, self._DEFAULT_THEME)
        )

        # Features: split on ASCII or full-width commas, keep at most 4.
        feature_list = []
        if features:
            feature_list = [f.strip() for f in re.split(r'[,，]', features) if f.strip()][:4]

        price_display, price_suffix = self._parse_price(price)

        # Subtitle: first 15 characters of the description, else a generic slogan.
        subtitle_text = (description or "").strip()[:15] or f"发现{category}好去处"

        # Location tag: first whitespace-separated token. A missing or
        # whitespace-only location falls back to "推荐" (previously IndexError).
        location_tokens = location.split() if location else []
        location_tag = location_tokens[0] if location_tokens else "推荐"

        content = {
            "title": name,
            "subtitle": subtitle_text,
            "highlights": feature_list[:3] if feature_list else None,
            "details": feature_list if feature_list else None,
            "price": price_display,
            "price_suffix": price_suffix,
            "tags": [category, location_tag],
            "suggested_layout": suggested_layout,
            "suggested_theme": suggested_theme,
        }

        # Drop None values so the renderer only sees populated fields.
        content = {k: v for k, v in content.items() if v is not None}

        return {
            "success": True,
            "content": content
        }

    @staticmethod
    def _parse_price(price: Optional[str]) -> Tuple[Optional[str], Optional[str]]:
        """
        Convert a free-form price string into ``(display, suffix)``.

        Examples:
            "199元/人"        -> ("¥199", "/人")
            "¥1,299每晚"      -> ("¥1299", "/晚")
            "199元/人，2人起" -> ("¥199", "/人")   # only the first number
            "免费" or "0元"   -> ("免费", None)
            "面议"            -> (None, None)
        """
        if not price:
            return None, None

        # Unit suffix written as "/X" or "每X" (per person, portion, night, ...)
        price_suffix = None
        suffix_match = re.search(r'[/每](人|份|晚|位|间|套|次)', price)
        if suffix_match:
            price_suffix = f"/{suffix_match.group(1)}"

        if '免费' in price:
            return "免费", None

        # Remove only thousands separators (a comma followed by exactly three
        # digits), then take the FIRST number. Stripping every non-digit would
        # glue unrelated numbers together ("199元/人，2人起" -> 1992).
        normalized = re.sub(r'(?<=\d)[,，](?=\d{3}(?!\d))', '', price)
        number_match = re.search(r'\d+(?:\.\d+)?', normalized)
        if not number_match:
            return None, price_suffix

        number = number_match.group()
        if float(number) == 0:
            return "免费", None
        return f"¥{number}", price_suffix

    def _extract_json(self, text: str) -> Optional[Dict]:
        """
        Extract a JSON object from an AI reply.

        Candidates are tried in order: the whole text, the first fenced
        ```json block, then the span from the first "{" to the last "}".
        Only JSON objects (dicts) are accepted. Arrays, scalars and non-string
        input yield ``None``, so callers can safely treat the result as a dict.
        """
        if not isinstance(text, str):
            return None

        candidates = [text]
        fenced = re.search(r'```(?:json)?\s*([\s\S]*?)\s*```', text)
        if fenced:
            candidates.append(fenced.group(1))
        braces = re.search(r'\{[\s\S]*\}', text)
        if braces:
            candidates.append(braces.group())

        for candidate in candidates:
            try:
                parsed = json.loads(candidate)
            except ValueError:  # json.JSONDecodeError is a ValueError subclass
                continue
            if isinstance(parsed, dict):
                return parsed
        return None

    def _get_poster_service(self):
        """Return the poster_v2 service, importing and creating it on first use."""
        if self._poster_service is not None:
            return self._poster_service

        from poster_v2 import get_poster_service_v2
        self._poster_service = get_poster_service_v2()
        return self._poster_service

    def _get_ai_agent(self):
        """
        Return an AI agent, or None if one cannot be obtained.

        Prefers ``shared_components.ai_agent`` when present. Otherwise it builds
        a qwen-plus ``AIAgent`` from the project config. A failure is logged and
        not cached, so a later call will try again.
        """
        if self._ai_agent is not None:
            return self._ai_agent

        try:
            # Prefer the agent from shared components
            shared = getattr(self, 'shared_components', None)
            if shared and getattr(shared, 'ai_agent', None):
                self._ai_agent = shared.ai_agent
                return self._ai_agent

            # Otherwise create a dedicated agent
            from core.ai.ai_agent import AIAgent
            from core.config_loader import get_config
            from core.config import AIModelConfig

            config = get_config()

            # qwen-plus: no chain-of-thought, so faster and cheaper in tokens
            ai_config = AIModelConfig(
                model="qwen-plus",  # deliberately not qwq-plus (chain-of-thought)
                api_url=config.get("ai_model.api_url"),
                api_key=config.get("ai_model.api_key") or os.environ.get("AI_API_KEY", ""),
                temperature=0.7,
                # Original note: "30 seconds is enough" -- so units are
                # presumably milliseconds; confirm against AIModelConfig.
                timeout=30000,
            )
            self._ai_agent = AIAgent(ai_config)

        except Exception as e:
            self.log(f"获取 AI Agent 失败: {e}", level='warning')
            self._ai_agent = None

        return self._ai_agent

    async def _download_image(self, url: str) -> Optional[str]:
        """
        Download *url* and return its body Base64-encoded.

        Returns None (after logging a warning) on a non-200 status or on any
        exception, including the 30 s total timeout.
        """
        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(url, timeout=aiohttp.ClientTimeout(total=30)) as response:
                    if response.status == 200:
                        image_bytes = await response.read()
                        return base64.b64encode(image_bytes).decode('utf-8')
                    self.log(f"下载图片失败: {url}, 状态码: {response.status}", level='warning')
        except Exception as e:
            self.log(f"下载图片异常: {url}, 错误: {e}", level='warning')
        return None