TravelContentCreator/domain/aigc/engines/poster_generate_v2.py

253 lines
8.9 KiB
Python

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
海报生成引擎 V2
- 不访问数据库,接收完整数据
- 图片使用 URL 而非 Base64
- 使用轻量 PosterServiceV2
"""
import logging
from typing import Dict, Any, Optional, List
import aiohttp
import base64
from io import BytesIO
from .base import BaseAIGCEngine, EngineResult
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class PosterGenerateEngineV2(BaseAIGCEngine):
    """Poster generation engine, V2.

    Design points stated by the original author:

    1. No database access; all data is passed in by the caller.
    2. Images are fetched by URL instead of being transferred as Base64.
    3. Rendering is delegated to the lightweight poster service returned by
       ``domain.poster.poster_service.get_poster_service``.
    """

    engine_id = "poster_generate"
    engine_name = "海报生成"
    version = "2.1.0"
    description = "根据内容和模板生成营销海报"

    def __init__(self):
        """Initialise the base engine; the poster service is resolved lazily."""
        super().__init__()
        # Filled in on first use by _get_poster_service().
        self._poster_service = None

    def get_param_schema(self) -> Dict[str, Any]:
        """Return the parameter schema accepted by this engine.

        Returns:
            A mapping of parameter name to a spec dict with ``type``,
            ``required``, ``desc`` and, for the boolean options, ``default``.
        """
        return {
            # Template
            "template_id": {
                "type": "str",
                "required": True,
                "desc": "海报模板 ID",
            },
            # Content
            "content": {
                "type": "object",
                "required": True,
                "desc": "海报内容 {title, content, tag}",
            },
            # Full objects (passed in by the Java side, per the original note)
            "scenic_spot": {
                "type": "object",
                "required": False,
                "desc": "景区信息对象 {id, name, description, ...}",
            },
            "product": {
                "type": "object",
                "required": False,
                "desc": "产品信息对象 {id, name, price, ...}",
            },
            # Images - URLs are preferred over Base64
            "image_urls": {
                "type": "list",
                "required": False,
                "desc": "图片 URL 列表 (优先使用)",
            },
            "image_paths": {
                "type": "list",
                "required": False,
                "desc": "图片本地路径列表 (共享存储场景)",
            },
            "images_base64": {
                "type": "list",
                "required": False,
                "desc": "图片 Base64 列表 (兼容旧接口,不推荐)",
            },
            # Generation options
            "force_llm": {
                "type": "bool",
                "required": False,
                "default": False,
                "desc": "是否强制使用 LLM 生成内容",
            },
            "generate_psd": {
                "type": "bool",
                "required": False,
                "default": False,
                "desc": "是否生成 PSD 文件",
            },
            "generate_fabric_json": {
                "type": "bool",
                "required": False,
                "default": False,
                "desc": "是否生成 Fabric.js JSON",
            },
        }

    def estimate_duration(self, params: Dict[str, Any]) -> int:
        """Return the estimated run time for ``params``.

        Returns:
            45 when ``generate_psd`` is truthy, otherwise 30 (the unit is not
            stated in this file; presumably seconds - confirm with the base
            class).
        """
        generate_psd = params.get('generate_psd', False)
        return 45 if generate_psd else 30

    async def execute(self, params: Dict[str, Any], task_id: Optional[str] = None) -> EngineResult:
        """Generate a poster from ``params``.

        Images are taken from exactly one source, chosen in priority order:
        ``image_urls`` (downloaded), then ``image_paths`` (read from disk),
        then ``images_base64`` (used as given).

        Args:
            params: Engine parameters; see :meth:`get_param_schema`.
            task_id: Identifier forwarded to ``set_progress``; may be ``None``.

        Returns:
            An ``EngineResult``. On success ``data`` carries ``image_base64``,
            ``images_base64``, ``fabric_json`` and ``template_id``. On failure
            ``error_code`` is ``MISSING_TEMPLATE``, ``GENERATION_FAILED`` or
            ``EXECUTION_ERROR``. Exceptions are converted into a failed
            result rather than propagated.
        """
        try:
            self.log("开始生成海报 (V2)")
            self.set_progress(task_id, 10)

            # Extract parameters. ``or`` is used for the container values so
            # that a key which is present but explicitly None behaves like a
            # missing key.
            template_id = params.get('template_id')
            content = params.get('content') or {}
            image_urls = params.get('image_urls') or []
            image_paths = params.get('image_paths') or []
            images_base64 = params.get('images_base64') or []
            generate_fabric_json = params.get('generate_fabric_json', False)
            # NOTE(review): scenic_spot, product, force_llm and generate_psd
            # are declared in the schema but are not forwarded to the poster
            # service by this method - confirm that this is intended.

            if not template_id:
                return EngineResult(
                    success=False,
                    error="缺少模板 ID",
                    error_code="MISSING_TEMPLATE"
                )

            self.set_progress(task_id, 20)

            # Acquire images (priority: URL > path > Base64).
            # "base64" is also what gets reported when no images were
            # supplied at all.
            final_images_base64 = []
            image_source = "base64"
            if image_urls:
                self.log(f"从 URL 下载 {len(image_urls)} 张图片...")
                final_images_base64 = await self._download_images(image_urls)
                image_source = "url"
            elif image_paths:
                self.log(f"从路径读取 {len(image_paths)} 张图片...")
                # NOTE(review): synchronous file reads inside a coroutine.
                final_images_base64 = self._read_images_from_paths(image_paths)
                image_source = "path"
            elif images_base64:
                self.log(f"使用传入的 {len(images_base64)} 张 Base64 图片")
                final_images_base64 = images_base64

            self.set_progress(task_id, 40)

            # Call the poster service (lightweight V2 service).
            # NOTE(review): this is a synchronous call made from a coroutine;
            # if rendering is slow it occupies the event loop for its whole
            # duration - confirm whether it should be offloaded.
            poster_service = self._get_poster_service()
            result = poster_service.generate_poster(
                template_id=template_id,
                content=content,
                images_base64=final_images_base64,
                generate_fabric_json=generate_fabric_json,
            )

            self.set_progress(task_id, 80)

            # Handle the service result.
            if not result.get('success'):
                return EngineResult(
                    success=False,
                    error=result.get('error', '海报生成失败'),
                    error_code="GENERATION_FAILED"
                )

            output_image = result.get('image_base64')
            fabric_json = result.get('fabric_json')

            if not output_image:
                return EngineResult(
                    success=False,
                    error="海报生成失败: 无输出图片",
                    error_code="GENERATION_FAILED"
                )

            self.set_progress(task_id, 100)
            return EngineResult(
                success=True,
                data={
                    "image_base64": output_image,
                    "images_base64": [output_image],
                    "fabric_json": fabric_json,
                    "template_id": template_id,
                },
                metadata={
                    "template_id": template_id,
                    "has_fabric_json": bool(fabric_json),
                    "image_source": image_source,
                }
            )
        except Exception as e:
            # Top-level boundary: report the failure as a result object.
            self.log(f"海报生成异常: {e}", level='error')
            return EngineResult(
                success=False,
                error=str(e),
                error_code="EXECUTION_ERROR"
            )

    def _get_poster_service(self):
        """Return the poster service, resolving and caching it on first use."""
        # Compare against None explicitly so a service object that happens to
        # be falsy is not resolved again on every call.
        if self._poster_service is None:
            # Imported here rather than at module level, as in the original.
            from domain.poster.poster_service import get_poster_service
            self._poster_service = get_poster_service()
        return self._poster_service

    async def _download_images(self, urls: List[str]) -> List[str]:
        """Download each URL and return the bodies as Base64 strings.

        Downloads run one after another with a 30-second total timeout each.
        A URL that raises or answers with a status other than 200 is logged
        as a warning and skipped, so the result may be shorter than ``urls``.
        """
        images = []
        # Same timeout for every request; build it once.
        timeout = aiohttp.ClientTimeout(total=30)
        async with aiohttp.ClientSession() as session:
            for url in urls:
                try:
                    async with session.get(url, timeout=timeout) as response:
                        if response.status == 200:
                            image_bytes = await response.read()
                            image_base64 = base64.b64encode(image_bytes).decode('utf-8')
                            images.append(image_base64)
                            self.log(f"下载图片成功: {url[:50]}...")
                        else:
                            self.log(f"下载图片失败: {url}, 状态码: {response.status}", level='warning')
                except Exception as e:
                    # Best effort: one bad URL must not abort the others.
                    self.log(f"下载图片异常: {url}, 错误: {e}", level='warning')
        return images

    def _read_images_from_paths(self, paths: List[str]) -> List[str]:
        """Read each local file and return the contents as Base64 strings.

        A path that cannot be read is logged as a warning and skipped, so the
        result may be shorter than ``paths``.
        """
        images = []
        for path in paths:
            try:
                with open(path, 'rb') as f:
                    image_bytes = f.read()
                image_base64 = base64.b64encode(image_bytes).decode('utf-8')
                images.append(image_base64)
                self.log(f"读取图片成功: {path}")
            except Exception as e:
                # Best effort: skip unreadable files and keep going.
                self.log(f"读取图片失败: {path}, 错误: {e}", level='warning')
        return images