265 lines
9.6 KiB
Python
265 lines
9.6 KiB
Python
#!/usr/bin/env python3
|
|
# -*- coding: utf-8 -*-
|
|
|
|
"""
|
|
海报生成引擎 V2
|
|
- 不访问数据库,接收完整数据
|
|
- 图片使用 URL 而非 Base64
|
|
- 使用轻量 PosterServiceV2
|
|
"""
|
|
|
|
import logging
|
|
from typing import Dict, Any, Optional, List
|
|
import aiohttp
|
|
import base64
|
|
from io import BytesIO
|
|
|
|
from .base import BaseAIGCEngine, EngineResult
|
|
|
|
# Module-level logger. Note: the engine methods visible in this file log via
# self.log() (presumably provided by BaseAIGCEngine) rather than this logger.
logger: logging.Logger = logging.getLogger(__name__)
|
|
|
|
|
|
class PosterGenerateEngineV2(BaseAIGCEngine):
    """Poster generation engine, V2.

    Changes compared with the previous engine:

    1. No database access: every piece of data (template id, content,
       scenic-spot / product objects, images) is supplied by the caller.
    2. Images are preferably supplied as URLs and downloaded here, instead of
       being shipped inline as Base64.
    3. Intended to use a lightweight, database-free poster service.
       NOTE(review): ``_get_poster_service`` currently imports
       ``api.services.poster.PosterService`` (not a ``PosterServiceV2``) —
       confirm which service is meant.
    """

    engine_id = "poster_generate"
    engine_name = "海报生成"
    version = "2.1.0"
    description = "根据内容和模板生成营销海报"

    def __init__(self):
        super().__init__()
        # Created lazily on first use by _get_poster_service().
        self._poster_service: Optional[Any] = None

    def get_param_schema(self) -> Dict[str, Any]:
        """Return the parameter schema accepted by ``execute``."""
        return {
            # Template
            "template_id": {
                "type": "str",
                "required": True,
                "desc": "海报模板 ID",
            },

            # Content
            "content": {
                "type": "object",
                "required": True,
                "desc": "海报内容 {title, content, tag}",
            },

            # Full objects (passed in by the Java side)
            "scenic_spot": {
                "type": "object",
                "required": False,
                "desc": "景区信息对象 {id, name, description, ...}",
            },
            "product": {
                "type": "object",
                "required": False,
                "desc": "产品信息对象 {id, name, price, ...}",
            },

            # Images - URLs are preferred over Base64
            "image_urls": {
                "type": "list",
                "required": False,
                "desc": "图片 URL 列表 (优先使用)",
            },
            "image_paths": {
                "type": "list",
                "required": False,
                "desc": "图片本地路径列表 (共享存储场景)",
            },
            "images_base64": {
                "type": "list",
                "required": False,
                "desc": "图片 Base64 列表 (兼容旧接口,不推荐)",
            },

            # Generation options
            "force_llm": {
                "type": "bool",
                "required": False,
                "default": False,
                "desc": "是否强制使用 LLM 生成内容",
            },
            "generate_psd": {
                "type": "bool",
                "required": False,
                "default": False,
                "desc": "是否生成 PSD 文件",
            },
            "generate_fabric_json": {
                "type": "bool",
                "required": False,
                "default": False,
                "desc": "是否生成 Fabric.js JSON",
            },
        }

    def estimate_duration(self, params: Dict[str, Any]) -> int:
        """Return the estimated run time in seconds (45 with PSD output, else 30)."""
        generate_psd = params.get('generate_psd', False)
        return 45 if generate_psd else 30

    async def execute(self, params: Dict[str, Any], task_id: Optional[str] = None) -> EngineResult:
        """Generate a poster.

        Args:
            params: Parameters as described by ``get_param_schema``.
            task_id: Task identifier passed through to ``set_progress``.

        Returns:
            An ``EngineResult``. On success, ``data`` carries the generated
            image(s) as Base64 plus any Fabric.js JSON / PSD outputs returned
            by the poster service. On failure, ``error_code`` is
            ``MISSING_TEMPLATE``, ``GENERATION_FAILED`` or ``EXECUTION_ERROR``.
            Exceptions are caught and reported in the result, never raised.
        """
        try:
            self.log("开始生成海报 (V2)")
            self.set_progress(task_id, 10)

            # Extract parameters. The `or` fallbacks also cover keys that are
            # present but explicitly null.
            template_id = params.get('template_id')
            content = params.get('content') or {}
            scenic_spot = params.get('scenic_spot')
            product = params.get('product')
            image_urls = params.get('image_urls') or []
            image_paths = params.get('image_paths') or []
            images_base64 = params.get('images_base64') or []
            force_llm = params.get('force_llm', False)
            generate_psd = params.get('generate_psd', False)
            generate_fabric_json = params.get('generate_fabric_json', False)

            if not template_id:
                return EngineResult(
                    success=False,
                    error="缺少模板 ID",
                    error_code="MISSING_TEMPLATE",
                )

            self.set_progress(task_id, 20)

            # Resolve images from the first non-empty source
            # (priority: URL > local path > inline Base64).
            final_images_base64: List[str] = []
            if image_urls:
                self.log(f"从 URL 下载 {len(image_urls)} 张图片...")
                final_images_base64 = await self._download_images(image_urls)
            elif image_paths:
                self.log(f"从路径读取 {len(image_paths)} 张图片...")
                final_images_base64 = self._read_images_from_paths(image_paths)
            elif images_base64:
                self.log(f"使用传入的 {len(images_base64)} 张 Base64 图片")
                final_images_base64 = images_base64

            self.set_progress(task_id, 40)

            # Call the poster service.
            poster_service = self._get_poster_service()
            result = await poster_service.generate_poster(
                template_id=template_id,
                poster_content=content,
                content_id=None,
                product_id=self._object_id_str(product),
                scenic_spot_id=self._object_id_str(scenic_spot),
                images_base64=final_images_base64,
                force_llm_generation=force_llm,
                generate_psd=generate_psd,
                generate_fabric_json=generate_fabric_json,
            )

            self.set_progress(task_id, 80)

            # Tolerate a None result and keys that are present but null.
            result = result or {}
            images_data = result.get('resultImagesBase64') or []
            fabric_jsons = result.get('fabricJsons') or []
            psd_files = result.get('psdFiles') or []

            output_images = [self._extract_image_base64(item) for item in images_data]
            # Entries without image data (e.g. dicts lacking 'image') do not
            # count as a successful generation.
            first_image = next((img for img in output_images if img), None)

            if first_image is None:
                return EngineResult(
                    success=False,
                    error="海报生成失败",
                    error_code="GENERATION_FAILED",
                )

            self.set_progress(task_id, 100)
            return EngineResult(
                success=True,
                data={
                    "image_base64": first_image,
                    "images_base64": output_images,
                    "fabric_json": fabric_jsons[0] if fabric_jsons else None,
                    "fabric_jsons": fabric_jsons,
                    "psd_base64": psd_files[0] if psd_files else None,
                    "psd_files": psd_files,
                    "template_id": template_id,
                },
                metadata={
                    "template_id": template_id,
                    "has_psd": bool(psd_files),
                    "has_fabric_json": bool(fabric_jsons),
                    "num_images": len(output_images),
                    "image_source": self._describe_image_source(
                        image_urls, image_paths, images_base64
                    ),
                },
            )

        except Exception as e:
            # Top-level engine boundary: every failure becomes an EngineResult.
            self.log(f"海报生成异常: {e}", level='error')
            return EngineResult(
                success=False,
                error=str(e),
                error_code="EXECUTION_ERROR",
            )

    @staticmethod
    def _object_id_str(obj: Optional[Dict[str, Any]]) -> Optional[str]:
        """Return ``obj['id']`` as a string, or None when obj or its id is missing.

        Prevents passing the literal string ``"None"`` to the service when a
        caller-supplied object has no ``id`` key.
        """
        if not obj:
            return None
        obj_id = obj.get('id')
        return None if obj_id is None else str(obj_id)

    @staticmethod
    def _extract_image_base64(item: Any) -> Any:
        """Return the Base64 payload of a service result entry.

        Entries may be dicts carrying an ``'image'`` key ('' if absent) or the
        Base64 string itself, which is returned unchanged.
        """
        if isinstance(item, dict):
            return item.get('image', '')
        return item

    @staticmethod
    def _describe_image_source(image_urls: List[str], image_paths: List[str],
                               images_base64: List[str]) -> str:
        """Name the image input that was used: 'url', 'path', 'base64' or 'none'."""
        if image_urls:
            return "url"
        if image_paths:
            return "path"
        if images_base64:
            return "base64"
        return "none"

    def _get_poster_service(self):
        """Return the poster service, creating and caching it on first use."""
        if self._poster_service is not None:
            return self._poster_service

        # Imported lazily so the service module is only loaded when needed.
        # NOTE(review): the class docstring refers to PosterServiceV2.
        from api.services.poster import PosterService
        self._poster_service = PosterService()
        return self._poster_service

    async def _download_images(self, urls: List[str]) -> List[str]:
        """Download images from ``urls`` and return them Base64-encoded.

        Downloads run concurrently over a single session, with a 30 s total
        timeout per request. The result keeps the order of ``urls``. A URL
        that fails (non-200 status or any exception) is logged as a warning
        and skipped, so the returned list may be shorter than ``urls``.

        NOTE(review): URLs are fetched as given; if they can originate from
        untrusted input, consider restricting allowed hosts (SSRF) and
        capping the response size.
        """
        import asyncio  # local import, like the lazy import in _get_poster_service

        timeout = aiohttp.ClientTimeout(total=30)

        async def fetch(session: aiohttp.ClientSession, url: str) -> Optional[str]:
            try:
                async with session.get(url) as response:
                    if response.status != 200:
                        self.log(f"下载图片失败: {url}, 状态码: {response.status}", level='warning')
                        return None
                    image_bytes = await response.read()
                image_base64 = base64.b64encode(image_bytes).decode('utf-8')
                self.log(f"下载图片成功: {url[:50]}...")
                return image_base64
            except Exception as e:
                # Best effort: one bad URL must not abort the whole batch.
                self.log(f"下载图片异常: {url}, 错误: {e}", level='warning')
                return None

        async with aiohttp.ClientSession(timeout=timeout) as session:
            results = await asyncio.gather(*(fetch(session, url) for url in urls))

        return [image for image in results if image is not None]

    def _read_images_from_paths(self, paths: List[str]) -> List[str]:
        """Read image files from local paths and return them Base64-encoded.

        A path that cannot be read is logged as a warning and skipped, so the
        returned list may be shorter than ``paths``.

        NOTE(review): this uses blocking file I/O and reads whatever path it is
        given; if paths can come from untrusted input, restrict them to the
        shared-storage root.
        """
        images = []

        for path in paths:
            try:
                with open(path, 'rb') as f:
                    image_bytes = f.read()
                image_base64 = base64.b64encode(image_bytes).decode('utf-8')
                images.append(image_base64)
                self.log(f"读取图片成功: {path}")
            except Exception as e:
                # Best effort: skip unreadable files instead of failing.
                self.log(f"读取图片失败: {path}, 错误: {e}", level='warning')

        return images
|