167 lines
5.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
海报API路由 - 统一生成接口
"""
import logging
from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks
from typing import List, Dict, Any
from core.config import ConfigManager
from core.ai import AIAgent
from utils.file_io import OutputManager
from api.services.poster import PosterService
from api.models.poster import (
PosterGenerateRequest, PosterGenerateResponse,
ImageUsageRequest, ImageUsageResponse,
TemplateListResponse
)
# 从依赖注入模块导入依赖
from api.dependencies import get_config, get_ai_agent, get_output_manager
# Module-level logger for the poster API routes.
logger = logging.getLogger(__name__)
# Create the router; the poster endpoints below are registered on it.
router = APIRouter()
# Dependency-injection provider used by every poster endpoint below.
def get_poster_service(
    config_manager: ConfigManager = Depends(get_config),
    ai_agent: AIAgent = Depends(get_ai_agent),
    output_manager: OutputManager = Depends(get_output_manager)
) -> PosterService:
    """Assemble a PosterService from the injected collaborators.

    A fresh PosterService is constructed on every call; the collaborators are
    passed positionally in the order (ai_agent, config_manager, output_manager).
    """
    service = PosterService(ai_agent, config_manager, output_manager)
    return service
@router.get("/test/templates", summary="测试模板配置")
async def test_templates(
    poster_service: PosterService = Depends(get_poster_service)
):
    """
    Check that the template configuration was loaded correctly.

    For every template held by the service, report the raw template info,
    whether ``system_prompt`` / ``user_prompt_template`` are non-empty, and
    their lengths.

    Raises:
        HTTPException: 500 if inspecting the templates fails.
    """
    try:
        # NOTE(review): reads the service's private ``_templates`` mapping
        # directly; presumably keyed by template id — confirm in PosterService.
        templates = poster_service._templates
        result = {}
        for template_id, template_info in templates.items():
            # ``or ''`` also covers keys that are present but set to None:
            # ``.get(key, '')`` would return None there and ``len(None)``
            # would raise, turning the diagnostic itself into a 500.
            system_prompt = template_info.get('system_prompt') or ''
            user_prompt_template = template_info.get('user_prompt_template') or ''
            result[template_id] = {
                "basic_info": template_info,
                "has_system_prompt": bool(system_prompt),
                "has_user_prompt_template": bool(user_prompt_template),
                "prompt_lengths": {
                    "system_prompt": len(system_prompt),
                    "user_prompt_template": len(user_prompt_template),
                },
            }
        return {
            "message": "模板配置检查完成",
            "template_count": len(templates),
            "templates": result,
        }
    except Exception as e:
        logger.error(f"测试模板配置失败: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"测试失败: {str(e)}") from e
@router.get("/templates", response_model=TemplateListResponse, summary="获取海报模板列表")
async def get_templates(
    poster_service: PosterService = Depends(get_poster_service)
):
    """
    List every poster template the service reports as available.

    Returns:
        TemplateListResponse carrying the templates and their count (``totalCount``).

    Raises:
        HTTPException: 500 if listing or wrapping the templates fails.
    """
    try:
        available = poster_service.get_available_templates()
        response = TemplateListResponse(templates=available, totalCount=len(available))
    except Exception as exc:
        logger.error(f"获取模板列表失败: {exc}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"获取模板列表失败: {str(exc)}")
    return response
@router.get("/templates/{template_id}", summary="获取指定模板信息")
async def get_template_info(
    template_id: str,
    poster_service: PosterService = Depends(get_poster_service)
):
    """
    Return the detailed information for a single template.

    Raises:
        HTTPException: 404 when the service returns a falsy result for
            ``template_id``; any HTTPException is passed through unchanged;
            every other failure becomes a 500.
    """
    try:
        info = await poster_service.get_template_info(template_id)
        if info:
            return info
        raise HTTPException(status_code=404, detail=f"模板 {template_id} 不存在")
    except HTTPException:
        # Keep deliberate HTTP errors (such as the 404 above) as they are.
        raise
    except Exception as exc:
        logger.error(f"获取模板信息失败: {exc}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"获取模板信息失败: {str(exc)}")
@router.post("/generate", response_model=PosterGenerateResponse, summary="生成海报")
async def generate_poster(
    request: PosterGenerateRequest,
    poster_service: PosterService = Depends(get_poster_service)
):
    """
    Generate a poster.

    Request fields read by this endpoint (request-model attribute names):
    - **templateId**: template ID (documented as defaulting to vibrant)
    - **posterContent**: user-supplied poster content (optional)
    - **contentId**: content ID (optional)
    - **productId**: product ID (optional)
    - **scenicSpotId**: scenic spot ID (optional)
    - **imagesBase64**: base64-encoded images (optional)
    - **numVariations**: number of variations to generate
    - **forceLlmGeneration**: force LLM-generated content (optional)
    - **generatePsd**: also produce a layered PSD file (optional)
    - **psdOutputPath**: output path for the PSD file (optional)

    Raises:
        HTTPException: 400 when the service rejects the parameters with a
            ValueError; any HTTPException from the service is passed through;
            500 for any other failure, including a service result that cannot
            be turned into a PosterGenerateResponse.
    """
    try:
        result = await poster_service.generate_poster(
            template_id=request.templateId,
            poster_content=request.posterContent,
            content_id=request.contentId,
            product_id=request.productId,
            scenic_spot_id=request.scenicSpotId,
            images_base64=request.imagesBase64,
            num_variations=request.numVariations,
            force_llm_generation=request.forceLlmGeneration,
            generate_psd=request.generatePsd,
            psd_output_path=request.psdOutputPath
        )
    except HTTPException:
        raise
    except ValueError as e:
        logger.error(f"参数错误: {e}")
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        logger.error(f"生成海报失败: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"生成海报失败: {str(e)}") from e

    # Built outside the ValueError handler on purpose: pydantic's
    # ValidationError subclasses ValueError, and a malformed service result is
    # a server-side fault, not a bad client request, so it must not become 400.
    try:
        return PosterGenerateResponse(**result)
    except Exception as e:
        logger.error(f"生成海报失败: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"生成海报失败: {str(e)}") from e
@router.post("/image-usage", response_model=ImageUsageResponse, summary="查询图像使用情况")
async def get_image_usage(
    request: ImageUsageRequest,
    poster_service: PosterService = Depends(get_poster_service)
):
    """
    Look up how the requested images have been used.

    Passes ``request.image_ids`` to the service and unpacks the mapping it
    returns into an ImageUsageResponse.

    Raises:
        HTTPException: 500 on any failure, including building the response.
    """
    try:
        usage = poster_service.get_image_usage_info(request.image_ids)
        response = ImageUsageResponse(**usage)
    except Exception as exc:
        logger.error(f"查询图像使用情况失败: {exc}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"查询图像使用情况失败: {str(exc)}")
    return response