TravelContentCreator/domain/aigc/engines/poster_generate_v3.py

209 lines
7.2 KiB
Python
Raw Normal View History

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
海报生成引擎 V3
使用 poster_v2 模块，支持5种布局和5种主题
"""
import logging
from typing import Dict, Any, Optional, List
import aiohttp
import base64
from .base import BaseAIGCEngine, EngineResult
# Module-level logger, named after this module's import path.
logger: logging.Logger = logging.getLogger(name=__name__)
class PosterGenerateEngineV3(BaseAIGCEngine):
    """Poster generation engine, version 3.

    Rendering is delegated to the ``poster_v2`` service. Options, as listed
    in this engine's parameter schema:

    - 5 layouts: hero_bottom, overlay_center, overlay_bottom,
      split_vertical, card_float
    - 5 themes: ocean, sunset, peach, mint, latte
    - an optional real background image (URL, local path or Base64)
    """

    engine_id = "poster_generate_v3"
    engine_name = "海报生成 V3"
    version = "3.0.0"
    description = "基于小红书风格的海报生成支持5种布局和5种主题"

    def __init__(self):
        super().__init__()
        # poster_v2 service instance; created lazily by _get_service().
        self._service = None

    def get_param_schema(self) -> Dict[str, Any]:
        """Return the parameter schema for this engine.

        Each key is a parameter name mapped to its metadata: ``type``,
        ``required``, an optional ``default`` and a human-readable ``desc``.
        """
        return {
            "layout": {
                "type": "str",
                "required": False,
                "default": "hero_bottom",
                "desc": "布局: hero_bottom, overlay_center, overlay_bottom, split_vertical, card_float",
            },
            "theme": {
                "type": "str",
                "required": False,
                "default": "ocean",
                "desc": "主题: ocean, sunset, peach, mint, latte",
            },
            "content": {
                "type": "object",
                "required": True,
                "desc": "内容 {title, subtitle, price, tags, details, ...}",
            },
            "image_url": {
                "type": "str",
                "required": False,
                "desc": "背景图片 URL",
            },
            "image_path": {
                "type": "str",
                "required": False,
                "desc": "背景图片本地路径",
            },
            "image_base64": {
                "type": "str",
                "required": False,
                "desc": "背景图片 Base64",
            },
            "auto_layout": {
                "type": "bool",
                "required": False,
                "default": False,
                "desc": "是否自动推荐布局",
            },
            "auto_theme": {
                "type": "bool",
                "required": False,
                "default": False,
                "desc": "是否自动推荐主题",
            },
            "category": {
                "type": "str",
                "required": False,
                "desc": "内容分类 (用于自动推荐主题)",
            },
        }

    def estimate_duration(self, params: Dict[str, Any]) -> int:
        """Estimate how long ``execute`` will take for ``params``.

        Returns 10 when any image source (``image_url``, ``image_path`` or
        ``image_base64``) is supplied, and 5 otherwise. The unit is whatever
        ``BaseAIGCEngine`` expects (presumably seconds -- confirm).
        """
        has_image = bool(
            params.get('image_url')
            or params.get('image_path')
            or params.get('image_base64')
        )
        return 10 if has_image else 5

    async def execute(self, params: Dict[str, Any], task_id: Optional[str] = None) -> EngineResult:
        """Generate a poster from ``params``.

        Args:
            params: Request parameters as described by ``get_param_schema``.
            task_id: Identifier passed to ``set_progress`` when reporting
                progress; may be ``None``.

        Returns:
            An ``EngineResult``. On success, ``data`` holds ``image_base64``,
            ``images_base64`` (a one-element list), ``layout`` and ``theme``.
            On failure, ``error_code`` is one of ``INVALID_CONTENT``,
            ``MISSING_TITLE``, ``GENERATION_FAILED`` or ``EXECUTION_ERROR``.
            Exceptions are not propagated; they are reported as
            ``EXECUTION_ERROR``.
        """
        try:
            self.log("开始生成海报 (V3)")
            self.set_progress(task_id, 10)

            # Extract parameters.
            layout = params.get('layout', 'hero_bottom')
            theme = params.get('theme', 'ocean')
            # ``or {}`` turns an explicit None/empty value into "no content", so it is
            # reported as MISSING_TITLE instead of crashing on ``.get`` below.
            content = params.get('content') or {}
            image_url = params.get('image_url')
            image_path = params.get('image_path')
            image_base64 = params.get('image_base64')
            auto_layout = params.get('auto_layout', False)
            auto_theme = params.get('auto_theme', False)
            category = params.get('category')

            if not isinstance(content, dict):
                return EngineResult(
                    success=False,
                    error="内容必须是对象",
                    error_code="INVALID_CONTENT",
                )
            if not content.get('title'):
                return EngineResult(
                    success=False,
                    error="内容缺少标题",
                    error_code="MISSING_TITLE",
                )
            self.set_progress(task_id, 20)

            service = self._get_service()

            # Optionally let the service choose the layout/theme. Theme suggestion
            # needs a category, so auto_theme without one keeps the given theme.
            if auto_layout:
                layout = service.suggest_layout(content)
                self.log(f"自动推荐布局: {layout}")
            if auto_theme and category:
                theme = service.suggest_theme(category)
                self.log(f"自动推荐主题: {theme}")
            self.set_progress(task_id, 30)

            # An explicit Base64 image takes priority; the URL is only fetched when
            # none was given. A failed download yields None, and generation proceeds
            # without that image.
            final_image_base64 = image_base64
            if image_url and not final_image_base64:
                self.log(f"下载图片: {image_url[:50]}...")
                final_image_base64 = await self._download_image(image_url)
            self.set_progress(task_id, 50)

            # NOTE(review): generate() is not awaited, i.e. it runs synchronously on
            # the event loop. If rendering is slow, consider offloading it to a worker
            # thread (after confirming the service is thread-safe). How it arbitrates
            # between image_base64 and image_path is decided by the service.
            result = service.generate(
                layout=layout,
                theme=theme,
                content=content,
                image_base64=final_image_base64,
                image_path=image_path,
            )
            self.set_progress(task_id, 90)

            if not result.get('success'):
                return EngineResult(
                    success=False,
                    error=result.get('error', '海报生成失败'),
                    error_code="GENERATION_FAILED",
                )

            self.set_progress(task_id, 100)
            return EngineResult(
                success=True,
                data={
                    "image_base64": result['image_base64'],
                    "images_base64": [result['image_base64']],
                    "layout": result['layout'],
                    "theme": result['theme'],
                },
                metadata={
                    "layout": result['layout'],
                    "theme": result['theme'],
                    "has_image": bool(final_image_base64 or image_path),
                },
            )
        except Exception as e:
            # Engine boundary: report any failure as a result instead of raising.
            self.log(f"海报生成异常: {e}", level='error')
            return EngineResult(
                success=False,
                error=str(e),
                error_code="EXECUTION_ERROR",
            )

    def _get_service(self):
        """Return the poster_v2 service, creating and caching it on first call.

        The ``poster_v2`` import is deferred until the service is first needed.
        """
        # ``is None`` rather than truthiness: a falsy service object must not be
        # rebuilt on every call.
        if self._service is None:
            from poster_v2 import get_poster_service_v2
            self._service = get_poster_service_v2()
        return self._service

    async def _download_image(self, url: str) -> Optional[str]:
        """Download the image at ``url`` and return its bytes Base64-encoded.

        Best effort: a non-200 response, or any exception (network error, the
        30-second total timeout, ...), is logged as a warning and ``None`` is
        returned instead of raising.
        """
        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(url, timeout=aiohttp.ClientTimeout(total=30)) as response:
                    if response.status == 200:
                        # NOTE(review): the whole body is read into memory with no
                        # size cap, and ``url`` comes from the caller -- consider a
                        # limit.
                        image_bytes = await response.read()
                        return base64.b64encode(image_bytes).decode('utf-8')
                    self.log(f"下载图片失败: {url}, 状态码: {response.status}", level='warning')
        except Exception as e:
            self.log(f"下载图片异常: {url}, 错误: {e}", level='warning')
        return None