TravelContentCreator/domain/aigc/engines/poster_generate_v2.py

265 lines
9.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
海报生成引擎 V2
- 不访问数据库,接收完整数据
- 图片使用 URL 而非 Base64
- 统一依赖注入
"""
import asyncio
import base64
import logging
from io import BytesIO
from typing import Dict, Any, Optional, List

import aiohttp

from .base import BaseAIGCEngine, EngineResult

# Module-level logger.
# NOTE(review): not referenced by the code visible in this module; the engine
# logs through self.log(), which is provided by the base class.
logger = logging.getLogger(__name__)


class PosterGenerateEngineV2(BaseAIGCEngine):
    """
    Poster generation engine, version 2.

    Differences from the previous engine:
    1. No database access; every piece of data is supplied by the caller.
    2. Images are fetched from URLs (or shared-storage paths) instead of being
       shipped as Base64 by default.
    3. Complete objects (scenic spot, product) are passed instead of bare IDs.
    """

    engine_id = "poster_generate_v2"
    engine_name = "海报生成 V2"
    version = "2.0.0"
    description = "根据内容和模板生成营销海报新版本无数据库依赖URL 图片)"

    # Per-request timeout, in seconds, applied to each image download.
    _IMAGE_DOWNLOAD_TIMEOUT_S = 30

    def __init__(self):
        super().__init__()
        # Created lazily by _get_poster_service().
        self._poster_service = None

    def get_param_schema(self) -> Dict[str, Any]:
        """Return the parameter schema accepted by :meth:`execute`."""
        return {
            # Template
            "template_id": {
                "type": "str",
                "required": True,
                "desc": "海报模板 ID",
            },
            # Poster text content
            "content": {
                "type": "object",
                "required": True,
                "desc": "海报内容 {title, content, tag}",
            },
            # Complete objects (supplied by the Java side)
            "scenic_spot": {
                "type": "object",
                "required": False,
                "desc": "景区信息对象 {id, name, description, ...}",
            },
            "product": {
                "type": "object",
                "required": False,
                "desc": "产品信息对象 {id, name, price, ...}",
            },
            # Images - URLs are preferred over Base64
            "image_urls": {
                "type": "list",
                "required": False,
                "desc": "图片 URL 列表 (优先使用)",
            },
            "image_paths": {
                "type": "list",
                "required": False,
                "desc": "图片本地路径列表 (共享存储场景)",
            },
            "images_base64": {
                "type": "list",
                "required": False,
                "desc": "图片 Base64 列表 (兼容旧接口,不推荐)",
            },
            # Generation options
            "force_llm": {
                "type": "bool",
                "required": False,
                "default": False,
                "desc": "是否强制使用 LLM 生成内容",
            },
            "generate_psd": {
                "type": "bool",
                "required": False,
                "default": False,
                "desc": "是否生成 PSD 文件",
            },
            "generate_fabric_json": {
                "type": "bool",
                "required": False,
                "default": False,
                "desc": "是否生成 Fabric.js JSON",
            },
        }

    def estimate_duration(self, params: Dict[str, Any]) -> int:
        """Return the estimated run time (presumably seconds); PSD output costs more."""
        generate_psd = params.get('generate_psd', False)
        return 45 if generate_psd else 30

    async def execute(self, params: Dict[str, Any], task_id: Optional[str] = None) -> EngineResult:
        """Generate a poster from a template, text content and images.

        Args:
            params: Parameters described by :meth:`get_param_schema`.
            task_id: Forwarded to ``set_progress`` for progress reporting.

        Returns:
            An ``EngineResult``. On success ``data`` carries the rendered
            image(s) and, when produced, Fabric.js JSON and PSD payloads. On
            failure ``error_code`` is ``MISSING_TEMPLATE``,
            ``GENERATION_FAILED`` or ``EXECUTION_ERROR``. Exceptions raised
            while generating are caught and reported, never propagated.
        """
        try:
            self.log("开始生成海报 (V2)")
            self.set_progress(task_id, 10)

            template_id = params.get('template_id')
            content = params.get('content', {})
            scenic_spot = params.get('scenic_spot')
            product = params.get('product')
            image_urls = params.get('image_urls', [])
            image_paths = params.get('image_paths', [])
            images_base64 = params.get('images_base64', [])
            force_llm = params.get('force_llm', False)
            generate_psd = params.get('generate_psd', False)
            generate_fabric_json = params.get('generate_fabric_json', False)

            if not template_id:
                return EngineResult(
                    success=False,
                    error="缺少模板 ID",
                    error_code="MISSING_TEMPLATE"
                )
            self.set_progress(task_id, 20)

            final_images_base64 = await self._collect_images(
                image_urls, image_paths, images_base64
            )
            self.set_progress(task_id, 40)

            poster_service = self._get_poster_service()
            result = await poster_service.generate_poster(
                template_id=template_id,
                poster_content=content,
                content_id=None,
                product_id=self._object_id(product),
                scenic_spot_id=self._object_id(scenic_spot),
                images_base64=final_images_base64,
                force_llm_generation=force_llm,
                generate_psd=generate_psd,
                generate_fabric_json=generate_fabric_json,
            )
            self.set_progress(task_id, 80)

            # A missing result is a generation failure, not an AttributeError.
            result = result or {}
            output_images = [
                self._extract_base64(img)
                for img in (result.get('resultImagesBase64') or [])
            ]
            fabric_jsons = result.get('fabricJsons', [])
            psd_files = result.get('psdFiles', [])

            if not output_images:
                return EngineResult(
                    success=False,
                    error="海报生成失败",
                    error_code="GENERATION_FAILED"
                )

            self.set_progress(task_id, 100)
            return EngineResult(
                success=True,
                data={
                    "image_base64": output_images[0],
                    "images_base64": output_images,
                    "fabric_json": fabric_jsons[0] if fabric_jsons else None,
                    "fabric_jsons": fabric_jsons,
                    "psd_base64": psd_files[0] if psd_files else None,
                    "psd_files": psd_files,
                    "template_id": template_id,
                },
                metadata={
                    "template_id": template_id,
                    "has_psd": bool(psd_files),
                    "has_fabric_json": bool(fabric_jsons),
                    "num_images": len(output_images),
                    "image_source": "url" if image_urls else ("path" if image_paths else "base64"),
                }
            )
        except Exception as e:
            self.log(f"海报生成异常: {e}", level='error')
            return EngineResult(
                success=False,
                error=str(e),
                error_code="EXECUTION_ERROR"
            )

    async def _collect_images(
        self,
        image_urls: List[str],
        image_paths: List[str],
        images_base64: List[str],
    ) -> List[str]:
        """Return Base64 images from the first non-empty source.

        Priority: URLs > local paths > pre-encoded Base64. Returns an empty
        list when no source is supplied.
        """
        if image_urls:
            self.log(f"从 URL 下载 {len(image_urls)} 张图片...")
            return await self._download_images(image_urls)
        if image_paths:
            self.log(f"从路径读取 {len(image_paths)} 张图片...")
            return self._read_images_from_paths(image_paths)
        if images_base64:
            self.log(f"使用传入的 {len(images_base64)} 张 Base64 图片")
            return images_base64
        return []

    @staticmethod
    def _object_id(obj: Optional[Dict[str, Any]]) -> Optional[str]:
        """Return ``obj['id']`` as a string, or None if obj or its id is missing.

        Avoids forwarding the literal string "None" when the object has no id.
        """
        if not obj:
            return None
        obj_id = obj.get('id')
        return None if obj_id is None else str(obj_id)

    @staticmethod
    def _extract_base64(item: Any) -> Any:
        """Return the Base64 payload of one result image entry.

        Entries may be dicts holding the payload under ``'image'`` (missing
        key yields ``''``) or the payload itself, returned unchanged.
        """
        if isinstance(item, dict):
            return item.get('image', '')
        return item

    def _get_poster_service(self):
        """Return the poster service, creating it on first use."""
        if self._poster_service is None:
            # Imported lazily, as in the original, to defer loading the service module.
            from api.services.poster import PosterService
            self._poster_service = PosterService()
        return self._poster_service

    async def _download_images(self, urls: List[str]) -> List[str]:
        """Download images concurrently and return them Base64-encoded.

        Best-effort: failed downloads are logged as warnings and skipped, so
        the result may be shorter than ``urls``. Successful images keep the
        order of ``urls``.
        """
        if not urls:
            return []
        timeout = aiohttp.ClientTimeout(total=self._IMAGE_DOWNLOAD_TIMEOUT_S)
        async with aiohttp.ClientSession() as session:
            # gather() preserves input order; each task swallows its own errors.
            results = await asyncio.gather(
                *(self._download_one(session, url, timeout) for url in urls)
            )
        return [image for image in results if image is not None]

    async def _download_one(
        self,
        session: "aiohttp.ClientSession",
        url: str,
        timeout: "aiohttp.ClientTimeout",
    ) -> Optional[str]:
        """Download one image; return it Base64-encoded, or None on failure."""
        try:
            async with session.get(url, timeout=timeout) as response:
                if response.status != 200:
                    self.log(f"下载图片失败: {url}, 状态码: {response.status}", level='warning')
                    return None
                image_bytes = await response.read()
        except Exception as e:
            self.log(f"下载图片异常: {url}, 错误: {e}", level='warning')
            return None
        self.log(f"下载图片成功: {url[:50]}...")
        return base64.b64encode(image_bytes).decode('utf-8')

    def _read_images_from_paths(self, paths: List[str]) -> List[str]:
        """Read images from local paths and return them Base64-encoded.

        Best-effort: unreadable paths are logged as warnings and skipped.
        NOTE(review): this is synchronous file I/O executed from the async
        ``execute``; consider offloading if paths live on slow storage.
        """
        images = []
        for path in paths:
            try:
                with open(path, 'rb') as f:
                    image_bytes = f.read()
            except (OSError, TypeError, ValueError) as e:
                # OSError: missing/unreadable file; TypeError/ValueError: bad path value.
                self.log(f"读取图片失败: {path}, 错误: {e}", level='warning')
                continue
            images.append(base64.b64encode(image_bytes).decode('utf-8'))
            self.log(f"读取图片成功: {path}")
        return images