265 lines
9.6 KiB
Python
265 lines
9.6 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
|
||
"""
|
||
海报生成引擎 V2
|
||
- 不访问数据库,接收完整数据
|
||
- 图片使用 URL 而非 Base64
|
||
- 统一依赖注入
|
||
"""
|
||
|
||
import logging
|
||
from typing import Any, Dict, List, Optional, Tuple
|
||
import aiohttp
|
||
import base64
|
||
from io import BytesIO
|
||
|
||
from .base import BaseAIGCEngine, EngineResult
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class PosterGenerateEngineV2(BaseAIGCEngine):
    """Poster generation engine, version 2.

    Differences from the previous engine:

    1. No database access; every piece of data is supplied by the caller.
    2. Images are fetched from URLs (or read from shared-storage paths)
       instead of being shipped as Base64 in the request.
    3. Full objects (scenic spot, product) are accepted instead of bare IDs.
    """

    engine_id = "poster_generate_v2"
    engine_name = "海报生成 V2"
    version = "2.0.0"
    description = "根据内容和模板生成营销海报(新版本,无数据库依赖,URL 图片)"

    # Total timeout, in seconds, applied to each individual image download.
    DOWNLOAD_TIMEOUT_SECONDS = 30

    def __init__(self):
        super().__init__()
        # Created lazily by _get_poster_service().
        self._poster_service = None

    def get_param_schema(self) -> Dict[str, Any]:
        """Return the parameter schema accepted by execute()."""
        return {
            # Template
            "template_id": {
                "type": "str",
                "required": True,
                "desc": "海报模板 ID",
            },

            # Content
            "content": {
                "type": "object",
                "required": True,
                "desc": "海报内容 {title, content, tag}",
            },

            # Full objects (supplied by the Java side)
            "scenic_spot": {
                "type": "object",
                "required": False,
                "desc": "景区信息对象 {id, name, description, ...}",
            },
            "product": {
                "type": "object",
                "required": False,
                "desc": "产品信息对象 {id, name, price, ...}",
            },

            # Images - URLs are preferred over Base64
            "image_urls": {
                "type": "list",
                "required": False,
                "desc": "图片 URL 列表 (优先使用)",
            },
            "image_paths": {
                "type": "list",
                "required": False,
                "desc": "图片本地路径列表 (共享存储场景)",
            },
            "images_base64": {
                "type": "list",
                "required": False,
                "desc": "图片 Base64 列表 (兼容旧接口,不推荐)",
            },

            # Generation options
            "force_llm": {
                "type": "bool",
                "required": False,
                "default": False,
                "desc": "是否强制使用 LLM 生成内容",
            },
            "generate_psd": {
                "type": "bool",
                "required": False,
                "default": False,
                "desc": "是否生成 PSD 文件",
            },
            "generate_fabric_json": {
                "type": "bool",
                "required": False,
                "default": False,
                "desc": "是否生成 Fabric.js JSON",
            },
        }

    def estimate_duration(self, params: Dict[str, Any]) -> int:
        """Estimate run time in seconds: 45 when a PSD is requested, else 30."""
        generate_psd = params.get('generate_psd', False)
        return 45 if generate_psd else 30

    async def execute(self, params: Dict[str, Any], task_id: Optional[str] = None) -> EngineResult:
        """Generate a poster from the supplied template, content and images.

        Args:
            params: Request parameters as described by get_param_schema().
            task_id: Identifier forwarded to set_progress(); may be None.

        Returns:
            An EngineResult. On success, ``data`` holds the rendered images
            (``image_base64`` / ``images_base64``), any Fabric.js JSON and PSD
            entries returned by the poster service, and the template id.
            On failure, ``error_code`` is one of MISSING_TEMPLATE,
            GENERATION_FAILED or EXECUTION_ERROR. Exceptions raised while
            generating are caught and reported as EXECUTION_ERROR rather than
            propagated.
        """
        try:
            self.log("开始生成海报 (V2)")
            self.set_progress(task_id, 10)

            # Extract parameters. List-valued inputs are normalized so that
            # None / "" / a single bare string are all handled safely.
            template_id = params.get('template_id')
            content = params.get('content') or {}
            scenic_spot = params.get('scenic_spot')
            product = params.get('product')
            image_urls = self._as_list(params.get('image_urls'))
            image_paths = self._as_list(params.get('image_paths'))
            images_base64 = self._as_list(params.get('images_base64'))
            force_llm = params.get('force_llm', False)
            generate_psd = params.get('generate_psd', False)
            generate_fabric_json = params.get('generate_fabric_json', False)

            if not template_id:
                return EngineResult(
                    success=False,
                    error="缺少模板 ID",
                    error_code="MISSING_TEMPLATE"
                )

            self.set_progress(task_id, 20)

            final_images_base64, image_source = await self._collect_images(
                image_urls, image_paths, images_base64
            )

            self.set_progress(task_id, 40)

            poster_service = self._get_poster_service()
            result = await poster_service.generate_poster(
                template_id=template_id,
                poster_content=content,
                content_id=None,
                product_id=self._extract_id(product),
                scenic_spot_id=self._extract_id(scenic_spot),
                images_base64=final_images_base64,
                force_llm_generation=force_llm,
                generate_psd=generate_psd,
                generate_fabric_json=generate_fabric_json,
            )

            self.set_progress(task_id, 80)

            # A missing/None result is treated as "nothing generated" rather
            # than crashing on .get().
            result = result or {}
            images_data = result.get('resultImagesBase64') or []
            fabric_jsons = result.get('fabricJsons') or []
            psd_files = result.get('psdFiles') or []

            # Drop empty entries (e.g. dicts without an 'image' key) so a
            # blank image is never reported as a successful generation.
            output_images = [
                image for image in map(self._extract_base64, images_data) if image
            ]

            if not output_images:
                return EngineResult(
                    success=False,
                    error="海报生成失败",
                    error_code="GENERATION_FAILED"
                )

            self.set_progress(task_id, 100)
            return EngineResult(
                success=True,
                data={
                    "image_base64": output_images[0],
                    "images_base64": output_images,
                    "fabric_json": fabric_jsons[0] if fabric_jsons else None,
                    "fabric_jsons": fabric_jsons,
                    "psd_base64": psd_files[0] if psd_files else None,
                    "psd_files": psd_files,
                    "template_id": template_id,
                },
                metadata={
                    "template_id": template_id,
                    "has_psd": bool(psd_files),
                    "has_fabric_json": bool(fabric_jsons),
                    "num_images": len(output_images),
                    # "url" / "path" / "base64", or "none" when no images
                    # were supplied at all.
                    "image_source": image_source,
                }
            )

        except Exception as e:
            # Engine boundary: report every failure as a result, never raise.
            self.log(f"海报生成异常: {e}", level='error')
            return EngineResult(
                success=False,
                error=str(e),
                error_code="EXECUTION_ERROR"
            )

    async def _collect_images(
        self,
        image_urls: List[Any],
        image_paths: List[Any],
        images_base64: List[Any],
    ) -> Tuple[List[Any], str]:
        """Resolve the input images to Base64 strings.

        Exactly one source is used, in priority order URL > path > Base64.

        Returns:
            ``(images, source)`` where ``source`` is "url", "path", "base64",
            or "none" when no images were supplied.
        """
        if image_urls:
            self.log(f"从 URL 下载 {len(image_urls)} 张图片...")
            images = await self._download_images(image_urls)
            requested, source = len(image_urls), "url"
        elif image_paths:
            self.log(f"从路径读取 {len(image_paths)} 张图片...")
            # NOTE(review): this is synchronous file I/O inside a coroutine and
            # blocks the event loop; consider an executor if paths live on slow
            # shared storage.
            images = self._read_images_from_paths(image_paths)
            requested, source = len(image_paths), "path"
        elif images_base64:
            self.log(f"使用传入的 {len(images_base64)} 张 Base64 图片")
            return images_base64, "base64"
        else:
            return [], "none"

        # Loading is best-effort; make partial failures visible.
        if len(images) < requested:
            self.log(f"仅成功获取 {len(images)}/{requested} 张图片", level='warning')
        return images, source

    @staticmethod
    def _as_list(value: Any) -> List[Any]:
        """Normalize an optional list parameter.

        Falsy values (None, "", []) become [], a single str/bytes becomes a
        one-element list, and any other iterable is copied into a list.
        """
        if not value:
            return []
        if isinstance(value, (str, bytes)):
            return [value]
        return list(value)

    @staticmethod
    def _extract_id(obj: Any) -> Optional[str]:
        """Return ``str(obj['id'])``, or None if obj is falsy or has no id."""
        if not obj:
            return None
        obj_id = obj.get('id')
        return None if obj_id is None else str(obj_id)

    @staticmethod
    def _extract_base64(item: Any) -> Any:
        """Extract the image payload from one poster-service result entry.

        Dict entries yield their 'image' value (or '' when absent); any other
        entry is returned unchanged.
        """
        if isinstance(item, dict):
            return item.get('image', '')
        return item

    def _get_poster_service(self):
        """Return the PosterService, creating it on first use."""
        if self._poster_service is None:
            # Deferred import: the service is only loaded when first needed.
            from api.services.poster import PosterService
            self._poster_service = PosterService()
        return self._poster_service

    async def _download_images(self, urls: List[str]) -> List[str]:
        """Download each URL and return the response bodies Base64-encoded.

        Best-effort: URLs that fail (non-200 status or any exception) are
        logged as warnings and skipped. The result therefore may be shorter
        than ``urls``; successful downloads keep their input order.
        """
        if not urls:
            return []

        images = []
        timeout = aiohttp.ClientTimeout(total=self.DOWNLOAD_TIMEOUT_SECONDS)
        # NOTE(review): URLs come straight from request params and bodies are
        # read without a size limit; if callers are not fully trusted, this
        # needs an allow-list (SSRF) and a maximum response size.
        async with aiohttp.ClientSession(timeout=timeout) as session:
            for url in urls:
                try:
                    async with session.get(url) as response:
                        if response.status != 200:
                            self.log(f"下载图片失败: {url}, 状态码: {response.status}", level='warning')
                            continue
                        image_bytes = await response.read()
                except Exception as e:
                    self.log(f"下载图片异常: {url}, 错误: {e}", level='warning')
                    continue
                images.append(base64.b64encode(image_bytes).decode('utf-8'))
                self.log(f"下载图片成功: {url[:50]}...")

        return images

    def _read_images_from_paths(self, paths: List[str]) -> List[str]:
        """Read each local file and return its contents Base64-encoded.

        Best-effort: unreadable paths are logged as warnings and skipped, so
        the result may be shorter than ``paths``.
        """
        images = []
        for path in paths:
            try:
                with open(path, 'rb') as f:
                    image_bytes = f.read()
            except (OSError, TypeError, ValueError) as e:
                # OSError: missing/unreadable file; TypeError/ValueError:
                # malformed path values (e.g. None, embedded NUL).
                self.log(f"读取图片失败: {path}, 错误: {e}", level='warning')
                continue
            images.append(base64.b64encode(image_bytes).decode('utf-8'))
            self.log(f"读取图片成功: {path}")

        return images
|