209 lines
7.2 KiB
Python
209 lines
7.2 KiB
Python
|
|
#!/usr/bin/env python3
|
|||
|
|
# -*- coding: utf-8 -*-
|
|||
|
|
|
|||
|
|
"""
|
|||
|
|
海报生成引擎 V3
|
|||
|
|
使用 poster_v2 模块,支持5种布局和5种主题
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
import logging
|
|||
|
|
from typing import Dict, Any, Optional, List
|
|||
|
|
import aiohttp
|
|||
|
|
import base64
|
|||
|
|
|
|||
|
|
from .base import BaseAIGCEngine, EngineResult
|
|||
|
|
|
|||
|
|
# Module-level logger, keyed by this module's dotted import path.
logger = logging.getLogger(name=__name__)
|
|||
|
|
|
|||
|
|
|
|||
|
|
class PosterGenerateEngineV3(BaseAIGCEngine):
    """Poster generation engine, V3.

    Delegates rendering to the ``poster_v2`` service, which (per the parameter
    schema below) offers:

    - 5 layouts: hero_bottom, overlay_center, overlay_bottom, split_vertical, card_float
    - 5 themes: ocean, sunset, peach, mint, latte
    - an optional real background image, given as a URL, a local path or Base64
    """

    engine_id = "poster_generate_v3"
    engine_name = "海报生成 V3"
    version = "3.0.0"
    description = "基于小红书风格的海报生成,支持5种布局和5种主题"

    # Fallbacks used when the caller omits ``layout``/``theme`` or passes a
    # falsy value (None, ""); they match the defaults advertised in the schema.
    _DEFAULT_LAYOUT = "hero_bottom"
    _DEFAULT_THEME = "ocean"

    def __init__(self):
        super().__init__()
        # poster_v2 service instance; created lazily by _get_service().
        self._service = None

    def get_param_schema(self) -> Dict[str, Any]:
        """Return the parameter schema accepted by :meth:`execute`."""
        return {
            "layout": {
                "type": "str",
                "required": False,
                "default": self._DEFAULT_LAYOUT,
                "desc": "布局: hero_bottom, overlay_center, overlay_bottom, split_vertical, card_float",
            },
            "theme": {
                "type": "str",
                "required": False,
                "default": self._DEFAULT_THEME,
                "desc": "主题: ocean, sunset, peach, mint, latte",
            },
            "content": {
                "type": "object",
                "required": True,
                "desc": "内容 {title, subtitle, price, tags, details, ...}",
            },
            "image_url": {
                "type": "str",
                "required": False,
                "desc": "背景图片 URL",
            },
            "image_path": {
                "type": "str",
                "required": False,
                "desc": "背景图片本地路径",
            },
            "image_base64": {
                "type": "str",
                "required": False,
                "desc": "背景图片 Base64",
            },
            "auto_layout": {
                "type": "bool",
                "required": False,
                "default": False,
                "desc": "是否自动推荐布局",
            },
            "auto_theme": {
                "type": "bool",
                "required": False,
                "default": False,
                "desc": "是否自动推荐主题",
            },
            "category": {
                "type": "str",
                "required": False,
                "desc": "内容分类 (用于自动推荐主题)",
            },
        }

    def estimate_duration(self, params: Dict[str, Any]) -> int:
        """Return a rough duration estimate: 10 if a background image is given, else 5.

        The unit is presumably seconds (TODO confirm against BaseAIGCEngine).
        """
        has_image = bool(params.get('image_url') or params.get('image_path') or params.get('image_base64'))
        return 10 if has_image else 5

    async def execute(self, params: Dict[str, Any], task_id: Optional[str] = None) -> EngineResult:
        """Generate a poster from ``params``.

        Args:
            params: Engine parameters, as described by :meth:`get_param_schema`.
            task_id: Identifier forwarded to ``set_progress`` for progress reporting.

        Returns:
            EngineResult. On success, ``data`` holds ``image_base64``,
            ``images_base64`` (a one-element list), ``layout`` and ``theme``.
            On failure, ``error_code`` is one of ``INVALID_CONTENT``,
            ``MISSING_TITLE``, ``GENERATION_FAILED`` or ``EXECUTION_ERROR``.
        """
        try:
            self.log("开始生成海报 (V3)")
            self.set_progress(task_id, 10)

            # Extract parameters. `or` also maps an explicit None/"" to the default,
            # so the service never receives layout=None / theme=None.
            layout = params.get('layout') or self._DEFAULT_LAYOUT
            theme = params.get('theme') or self._DEFAULT_THEME
            content = params.get('content') or {}
            image_url = params.get('image_url')
            image_path = params.get('image_path')
            image_base64 = params.get('image_base64')
            auto_layout = params.get('auto_layout', False)
            auto_theme = params.get('auto_theme', False)
            category = params.get('category')

            # Reject non-mapping content explicitly rather than letting `.get`
            # raise and surface as an opaque EXECUTION_ERROR.
            if not isinstance(content, dict):
                return EngineResult(
                    success=False,
                    error="内容格式错误, 应为对象",
                    error_code="INVALID_CONTENT",
                )

            if not content.get('title'):
                return EngineResult(
                    success=False,
                    error="内容缺少标题",
                    error_code="MISSING_TITLE",
                )

            self.set_progress(task_id, 20)

            service = self._get_service()

            # Optionally let the service pick the layout / theme.
            if auto_layout:
                layout = service.suggest_layout(content)
                self.log(f"自动推荐布局: {layout}")

            # Theme suggestion is keyed on the category, so it is skipped without one.
            if auto_theme and category:
                theme = service.suggest_theme(category)
                self.log(f"自动推荐主题: {theme}")

            self.set_progress(task_id, 30)

            # An explicit Base64 image wins; otherwise fetch the URL (best-effort:
            # a failed download yields None and the poster is rendered without it).
            final_image_base64 = image_base64
            if image_url and not final_image_base64:
                self.log(f"下载图片: {image_url[:50]}...")
                final_image_base64 = await self._download_image(image_url)

            self.set_progress(task_id, 50)

            # NOTE(review): service.generate() is a synchronous call inside a
            # coroutine, so it blocks the event loop while rendering. Consider
            # offloading it (e.g. asyncio.to_thread) once the service is confirmed
            # to be thread-safe.
            result = service.generate(
                layout=layout,
                theme=theme,
                content=content,
                image_base64=final_image_base64,
                image_path=image_path,
            )

            self.set_progress(task_id, 90)

            # Treat a "successful" result without an image as a generation failure
            # instead of crashing on a missing key.
            result = result or {}
            poster_base64 = result.get('image_base64')
            if not result.get('success') or not poster_base64:
                return EngineResult(
                    success=False,
                    error=result.get('error') or '海报生成失败',
                    error_code="GENERATION_FAILED",
                )

            # Fall back to the requested layout/theme if the service omits them.
            final_layout = result.get('layout', layout)
            final_theme = result.get('theme', theme)

            self.set_progress(task_id, 100)
            return EngineResult(
                success=True,
                data={
                    "image_base64": poster_base64,
                    "images_base64": [poster_base64],
                    "layout": final_layout,
                    "theme": final_theme,
                },
                metadata={
                    "layout": final_layout,
                    "theme": final_theme,
                    "has_image": bool(final_image_base64 or image_path),
                },
            )

        except Exception as e:
            # Top-level boundary: turn any unexpected failure into an error result.
            # self.log only receives the message text, so keep the traceback on the
            # module logger for debugging.
            logger.debug("PosterGenerateEngineV3.execute failed", exc_info=True)
            self.log(f"海报生成异常: {e}", level='error')
            return EngineResult(
                success=False,
                error=str(e),
                error_code="EXECUTION_ERROR",
            )

    def _get_service(self):
        """Return the poster_v2 service, creating and caching it on first use."""
        if self._service is None:
            # Imported lazily so poster_v2 is only loaded when the engine runs
            # (presumably to keep module import cheap / avoid import cycles).
            from poster_v2 import get_poster_service_v2
            self._service = get_poster_service_v2()
        return self._service

    async def _download_image(self, url: str) -> Optional[str]:
        """Download ``url`` and return its body Base64-encoded as a str.

        Best-effort: a non-200 status, an empty body, a network error or the
        30-second total timeout is logged as a warning and ``None`` is returned.
        """
        # NOTE(review): the whole body is read into memory with no size cap and
        # the URL is not restricted. If image_url can come from end users, consider
        # a max size and an allow-list of hosts.
        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(url, timeout=aiohttp.ClientTimeout(total=30)) as response:
                    if response.status != 200:
                        self.log(f"下载图片失败: {url}, 状态码: {response.status}", level='warning')
                        return None
                    image_bytes = await response.read()
                    if not image_bytes:
                        self.log(f"下载图片为空: {url}", level='warning')
                        return None
                    return base64.b64encode(image_bytes).decode('utf-8')
        except Exception as e:
            self.log(f"下载图片异常: {url}, 错误: {e}", level='warning')
        return None
|