2025-07-31 15:35:23 +08:00

221 lines
7.1 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Poster Generation Router
海报生成路由 - API v2
"""
import logging
from typing import Dict, Any
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import JSONResponse
from ..models import (
PosterGenerationRequest,
PosterGenerationResponse,
ApiResponse
)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Router holding the poster endpoints below; any URL prefix/tags are presumably
# applied where this router is included (not visible here) — confirm.
router = APIRouter()
@router.post("/generate", response_model=PosterGenerationResponse, summary="生成海报")
async def generate_poster(
    request: PosterGenerationRequest,
    # NOTE(review): ``__import__`` runs when this module is imported; presumably
    # used to avoid a static circular import with api_v2.main — confirm. The
    # dependency callable itself is kept unchanged so dependency overrides still match.
    pipeline: Dict[str, Any] = Depends(__import__('api_v2.main', fromlist=['get_poster_pipeline']).get_poster_pipeline)
):
    """
    Generate a poster from a template.

    - **template_name**: template name (e.g. vibrant, business, collage)
    - **content**: poster content (optional; when empty, the content is to be
      AI-generated from the info fields below)
    - **images**: list of image paths (optional)
    - **style_options**: style options (optional)
    - **scenic_info**: scenic-spot info, only forwarded when ``content`` is empty
    - **product_info**: product info, only forwarded when ``content`` is empty
    - **tweet_info**: tweet/post info, only forwarded when ``content`` is empty

    Returns a ``PosterGenerationResponse`` with ``success=True`` on success.
    ``HTTPException`` raised by the pipeline is propagated unchanged so its
    status code reaches the client; any other error is logged with traceback
    and returned as a 500 JSON body with ``success=False``.
    """
    try:
        logger.info("开始生成海报,模板: %s", request.template_name)
        poster_generator = pipeline["poster_generator"]

        generation_params = {
            "template_name": request.template_name,
            "content": request.content,
            "images": request.images,
            "style_options": request.style_options,
        }
        # Without explicit content, hand the generator the source info it needs
        # to produce the content itself.
        if not request.content:
            generation_params.update({
                "scenic_info": request.scenic_info,
                "product_info": request.product_info,
                "tweet_info": request.tweet_info,
            })

        request_id, poster_data = await poster_generator.generate_poster(**generation_params)
        logger.info("海报生成完成请求ID: %s", request_id)
        return PosterGenerationResponse(
            success=True,
            message="海报生成成功",
            data=poster_data,
            request_id=request_id,
        )
    except HTTPException:
        # Deliberate HTTP errors (e.g. 4xx) must keep their status code rather
        # than being rewrapped as a generic 500 by the handler below.
        raise
    except Exception as e:
        error_msg = f"海报生成失败: {e}"
        logger.exception(error_msg)
        return JSONResponse(
            status_code=500,
            content=PosterGenerationResponse(
                success=False,
                message="海报生成失败",
                error=error_msg,
            ).dict(),
        )
@router.post("/vibrant", response_model=PosterGenerationResponse, summary="生成Vibrant模板海报")
async def generate_vibrant_poster(
    scenic_info: str,
    product_info: str,
    tweet_info: str = "",
    transparent_background: bool = True,
    # NOTE(review): ``__import__`` runs when this module is imported; presumably
    # used to avoid a static circular import with api_v2.main — confirm.
    pipeline: Dict[str, Any] = Depends(__import__('api_v2.main', fromlist=['get_poster_pipeline']).get_poster_pipeline)
):
    """
    Generate a poster with the Vibrant template.

    The content is first produced by the pipeline's text generator
    (``generate_vibrant_content``) from the inputs below. It is then rendered
    by the poster generator with ``template_name="vibrant"``.

    - **scenic_info**: scenic-spot information
    - **product_info**: product information
    - **tweet_info**: tweet/post information (defaults to an empty string)
    - **transparent_background**: whether to use a transparent background

    The poster is requested with ``output_format="fabric_json"`` (Fabric.js
    format). ``HTTPException`` raised by the pipeline is propagated unchanged.
    Any other error is logged with traceback and returned as a 500 JSON body
    with ``success=False``.
    """
    try:
        logger.info("开始生成Vibrant模板海报")
        text_generator = pipeline["text_generator"]
        poster_generator = pipeline["poster_generator"]

        vibrant_content = await text_generator.generate_vibrant_content(
            scenic_info=scenic_info,
            product_info=product_info,
            tweet_info=tweet_info,
        )

        style_options = {
            "transparent_background": transparent_background,
            "output_format": "fabric_json",  # Fabric.js JSON output
        }

        request_id, poster_data = await poster_generator.generate_poster(
            template_name="vibrant",
            content=vibrant_content,
            style_options=style_options,
        )
        logger.info("Vibrant海报生成完成请求ID: %s", request_id)
        return PosterGenerationResponse(
            success=True,
            message="Vibrant海报生成成功",
            data=poster_data,
            request_id=request_id,
        )
    except HTTPException:
        # Keep deliberate HTTP errors' status codes instead of masking them as 500.
        raise
    except Exception as e:
        error_msg = f"Vibrant海报生成失败: {e}"
        logger.exception(error_msg)
        return JSONResponse(
            status_code=500,
            content=PosterGenerationResponse(
                success=False,
                message="Vibrant海报生成失败",
                error=error_msg,
            ).dict(),
        )
@router.get("/templates", response_model=ApiResponse, summary="获取可用模板")
async def get_available_templates(
    pipeline: Dict[str, Any] = Depends(__import__('api_v2.main', fromlist=['get_poster_pipeline']).get_poster_pipeline)
):
    """
    List every available poster template along with its details.

    The response ``data`` holds ``templates`` (the names returned by the poster
    generator) and ``template_info`` (a mapping from each name to the poster
    generator's ``get_template_info`` result). On any error, the failure is
    logged with traceback and a 500 JSON body with ``success=False`` is returned.
    """
    try:
        generator = pipeline["poster_generator"]
        names = generator.get_available_templates()
        details = {name: generator.get_template_info(name) for name in names}
        payload = {"templates": names, "template_info": details}
        return ApiResponse(success=True, message="获取模板列表成功", data=payload)
    except Exception as exc:
        failure = f"获取模板列表失败: {str(exc)}"
        logger.error(failure, exc_info=True)
        body = ApiResponse(success=False, message="获取模板列表失败", error=failure)
        return JSONResponse(status_code=500, content=body.dict())
@router.get("/pipeline/stats", response_model=ApiResponse, summary="获取流水线统计")
async def get_pipeline_stats(
    pipeline: Dict[str, Any] = Depends(__import__('api_v2.main', fromlist=['get_poster_pipeline']).get_poster_pipeline)
):
    """
    Report statistics for each component of the poster-generation pipeline.

    Collects the poster generator's generation stats, the text generator's
    stats, the template manager's info and a ``dict`` dump of the
    ``poster_generation`` config section, in that order. On any error, the
    failure is logged with traceback and a 500 JSON body with
    ``success=False`` is returned.
    """
    try:
        # Queried in the same order as the resulting keys.
        poster_stats = pipeline["poster_generator"].get_generation_stats()
        text_stats = pipeline["text_generator"].get_generator_stats()
        manager_info = pipeline["template_manager"].get_manager_info()
        config_snapshot = pipeline["config"].poster_generation.dict()
        summary = {
            "poster_generator": poster_stats,
            "text_generator": text_stats,
            "template_manager": manager_info,
            "config": config_snapshot,
        }
        return ApiResponse(success=True, message="统计信息获取成功", data=summary)
    except Exception as exc:
        failure = f"获取统计信息失败: {str(exc)}"
        logger.error(failure, exc_info=True)
        body = ApiResponse(success=False, message="获取统计信息失败", error=failure)
        return JSONResponse(status_code=500, content=body.dict())