221 lines
7.1 KiB
Python
221 lines
7.1 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
|
||
"""
|
||
Poster Generation Router
|
||
海报生成路由 - API v2
|
||
"""
|
||
|
||
import logging
|
||
from typing import Dict, Any
|
||
from fastapi import APIRouter, Depends, HTTPException
|
||
from fastapi.responses import JSONResponse
|
||
|
||
from ..models import (
|
||
PosterGenerationRequest,
|
||
PosterGenerationResponse,
|
||
ApiResponse
|
||
)
|
||
|
||
# Module-level logger, named after this module so its records can be filtered per module.
logger = logging.getLogger(__name__)

# Router holding the poster-generation endpoints defined below.
# NOTE(review): presumably included by the application (api_v2.main?) via
# include_router, possibly with a URL prefix — confirm against the caller.
router = APIRouter()
|
||
|
||
|
||
@router.post("/generate", response_model=PosterGenerationResponse, summary="生成海报")
async def generate_poster(
    request: PosterGenerationRequest,
    # The default is evaluated when this function is defined (module import time),
    # so api_v2.main is imported then; the lookup is presumably done this way to
    # sidestep a static circular import — confirm against api_v2.main.
    pipeline: Dict[str, Any] = Depends(__import__('api_v2.main', fromlist=['get_poster_pipeline']).get_poster_pipeline)
):
    """
    Generate a poster.

    - **template_name**: template name (vibrant, business, collage, ...)
    - **content**: poster content (optional; generated by AI when empty)
    - **images**: list of image paths (optional)
    - **style_options**: style options (optional)
    - **scenic_info**: scenic-spot info (used for AI content generation)
    - **product_info**: product info (used for AI content generation)
    - **tweet_info**: tweet info (used for AI content generation)
    """
    try:
        logger.info(f"开始生成海报,模板: {request.template_name}")

        generator = pipeline["poster_generator"]

        params: Dict[str, Any] = dict(
            template_name=request.template_name,
            content=request.content,
            images=request.images,
            style_options=request.style_options,
        )

        # The AI-generation inputs are forwarded only when no ready-made
        # content was supplied (falsy content).
        if not request.content:
            params["scenic_info"] = request.scenic_info
            params["product_info"] = request.product_info
            params["tweet_info"] = request.tweet_info

        request_id, poster_data = await generator.generate_poster(**params)

        logger.info(f"海报生成完成,请求ID: {request_id}")

        return PosterGenerationResponse(
            success=True,
            message="海报生成成功",
            data=poster_data,
            request_id=request_id,
        )

    except Exception as exc:
        # Any failure is logged with its traceback and reported as HTTP 500.
        # NOTE(review): the raw exception text is sent back to the client in
        # `error`; confirm that exposing internal error details is intended.
        error_msg = f"海报生成失败: {str(exc)}"
        logger.exception(error_msg)

        failure = PosterGenerationResponse(
            success=False,
            message="海报生成失败",
            error=error_msg,
        )
        return JSONResponse(status_code=500, content=failure.dict())
|
||
|
||
|
||
@router.post("/vibrant", response_model=PosterGenerationResponse, summary="生成Vibrant模板海报")
async def generate_vibrant_poster(
    scenic_info: str,
    product_info: str,
    tweet_info: str = "",
    transparent_background: bool = True,
    # Evaluated at definition (import) time; see the other endpoints in this module.
    pipeline: Dict[str, Any] = Depends(__import__('api_v2.main', fromlist=['get_poster_pipeline']).get_poster_pipeline)
):
    """
    Generate a poster with the Vibrant template.

    - **scenic_info**: scenic-spot info
    - **product_info**: product info
    - **tweet_info**: tweet info
    - **transparent_background**: whether to use a transparent background
    """
    try:
        logger.info("开始生成Vibrant模板海报")

        text_gen = pipeline["text_generator"]
        poster_gen = pipeline["poster_generator"]

        # Step 1: ask the text generator for Vibrant-specific content.
        content = await text_gen.generate_vibrant_content(
            scenic_info=scenic_info,
            product_info=product_info,
            tweet_info=tweet_info,
        )

        # Step 2: render it with the "vibrant" template, requesting
        # Fabric.js JSON output.
        request_id, poster_data = await poster_gen.generate_poster(
            template_name="vibrant",
            content=content,
            style_options={
                "transparent_background": transparent_background,
                "output_format": "fabric_json",
            },
        )

        logger.info(f"Vibrant海报生成完成,请求ID: {request_id}")

        return PosterGenerationResponse(
            success=True,
            message="Vibrant海报生成成功",
            data=poster_data,
            request_id=request_id,
        )

    except Exception as exc:
        # Failures are logged with traceback and surfaced as HTTP 500.
        # NOTE(review): the exception text is included in the client-visible
        # `error` field; confirm that is acceptable.
        error_msg = f"Vibrant海报生成失败: {str(exc)}"
        logger.exception(error_msg)

        failure = PosterGenerationResponse(
            success=False,
            message="Vibrant海报生成失败",
            error=error_msg,
        )
        return JSONResponse(status_code=500, content=failure.dict())
|
||
|
||
|
||
@router.get("/templates", response_model=ApiResponse, summary="获取可用模板")
async def get_available_templates(
    pipeline: Dict[str, Any] = Depends(__import__('api_v2.main', fromlist=['get_poster_pipeline']).get_poster_pipeline)
):
    """Return all available poster templates together with per-template info."""
    try:
        generator = pipeline["poster_generator"]
        names = generator.get_available_templates()

        # One get_template_info() call per template, in listing order.
        details = {name: generator.get_template_info(name) for name in names}

        payload = {
            "templates": names,
            "template_info": details,
        }
        return ApiResponse(success=True, message="获取模板列表成功", data=payload)

    except Exception as exc:
        # Logged with traceback; the client receives a 500 carrying the message.
        error_msg = f"获取模板列表失败: {str(exc)}"
        logger.exception(error_msg)

        failure = ApiResponse(
            success=False,
            message="获取模板列表失败",
            error=error_msg,
        )
        return JSONResponse(status_code=500, content=failure.dict())
|
||
|
||
|
||
@router.get("/pipeline/stats", response_model=ApiResponse, summary="获取流水线统计")
async def get_pipeline_stats(
    pipeline: Dict[str, Any] = Depends(__import__('api_v2.main', fromlist=['get_poster_pipeline']).get_poster_pipeline)
):
    """Return statistics gathered from each component of the poster-generation pipeline."""
    try:
        # Collected one component at a time, in a fixed order, so a failure
        # in an earlier component is the one reported.
        stats: Dict[str, Any] = {}
        stats["poster_generator"] = pipeline["poster_generator"].get_generation_stats()
        stats["text_generator"] = pipeline["text_generator"].get_generator_stats()
        stats["template_manager"] = pipeline["template_manager"].get_manager_info()
        stats["config"] = pipeline["config"].poster_generation.dict()

        return ApiResponse(success=True, message="统计信息获取成功", data=stats)

    except Exception as exc:
        # Logged with traceback; reported to the client as HTTP 500.
        error_msg = f"获取统计信息失败: {str(exc)}"
        logger.exception(error_msg)

        failure = ApiResponse(
            success=False,
            message="获取统计信息失败",
            error=error_msg,
        )
        return JSONResponse(status_code=500, content=failure.dict())