167 lines
5.8 KiB
Python
167 lines
5.8 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
|
||
"""
|
||
海报API路由 - 统一生成接口
|
||
"""
|
||
|
||
import logging
|
||
from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks
|
||
from typing import List, Dict, Any
|
||
|
||
from core.config import ConfigManager
|
||
from core.ai import AIAgent
|
||
from utils.file_io import OutputManager
|
||
from api.services.poster import PosterService
|
||
from api.models.poster import (
|
||
PosterGenerateRequest, PosterGenerateResponse,
|
||
ImageUsageRequest, ImageUsageResponse,
|
||
TemplateListResponse
|
||
)
|
||
|
||
# 从依赖注入模块导入依赖
|
||
from api.dependencies import get_config, get_ai_agent, get_output_manager
|
||
|
||
# Module-level logger used by the endpoints defined in this file.
logger = logging.getLogger(__name__)

# Create the router; every endpoint in this module registers on it via
# the @router.get / @router.post decorators.
router = APIRouter()
|
||
|
||
# Dependency provider: assembles the poster service from the injected components.
def get_poster_service(
    config_manager: ConfigManager = Depends(get_config),
    ai_agent: AIAgent = Depends(get_ai_agent),
    output_manager: OutputManager = Depends(get_output_manager)
) -> PosterService:
    """Build a ``PosterService`` from the injected config, AI agent and output manager.

    Each parameter is resolved by FastAPI through ``Depends``. A new
    ``PosterService`` is constructed every time this function is called.

    Returns:
        PosterService: created with positional arguments in the order
        (ai_agent, config_manager, output_manager).
    """
    service = PosterService(ai_agent, config_manager, output_manager)
    return service
|
||
|
||
|
||
@router.get("/test/templates", summary="测试模板配置")
async def test_templates(
    poster_service: PosterService = Depends(get_poster_service)
):
    """
    Check whether the template configuration was loaded correctly.

    For every loaded template, report:
    - its raw info,
    - whether it has a system prompt and a user prompt template,
    - the length of each of those prompts.

    Any failure is logged and surfaced as HTTP 500.
    """
    try:
        # NOTE(review): reads a private attribute of PosterService; assumed to
        # be a mapping of template_id -> template-info mapping (only .items()
        # and .get() are relied on here).
        templates = poster_service._templates

        result = {}
        for template_id, template_info in templates.items():
            # .get(key, '') only defaults when the key is absent. A key that is
            # present but set to None (e.g. an empty config entry) would make
            # len() raise TypeError, so normalise falsy values to ''.
            system_prompt = template_info.get('system_prompt') or ''
            user_prompt_template = template_info.get('user_prompt_template') or ''
            result[template_id] = {
                "basic_info": template_info,
                "has_system_prompt": bool(system_prompt),
                "has_user_prompt_template": bool(user_prompt_template),
                "prompt_lengths": {
                    "system_prompt": len(system_prompt),
                    "user_prompt_template": len(user_prompt_template)
                }
            }

        return {
            "message": "模板配置检查完成",
            "template_count": len(templates),
            "templates": result
        }
    except Exception as e:
        logger.error(f"测试模板配置失败: {e}", exc_info=True)
        # Chain the original error so server-side tracebacks keep the root cause.
        raise HTTPException(status_code=500, detail=f"测试失败: {str(e)}") from e
|
||
|
||
|
||
@router.get("/templates", response_model=TemplateListResponse, summary="获取海报模板列表")
async def get_templates(
    poster_service: PosterService = Depends(get_poster_service)
):
    """
    List every poster template the service reports as available.

    The response carries the templates together with their count. Any failure
    is logged and surfaced as HTTP 500.
    """
    try:
        available = poster_service.get_available_templates()
        count = len(available)
        return TemplateListResponse(templates=available, totalCount=count)
    except Exception as exc:
        logger.error(f"获取模板列表失败: {exc}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"获取模板列表失败: {str(exc)}")
|
||
|
||
|
||
@router.get("/templates/{template_id}", summary="获取指定模板信息")
async def get_template_info(
    template_id: str,
    poster_service: PosterService = Depends(get_poster_service)
):
    """
    Return the detailed information for a single template.

    Error handling:
    - HTTP 404 when the service returns a falsy value for ``template_id``.
    - HTTPExceptions raised by the service propagate unchanged.
    - Any other failure is logged and surfaced as HTTP 500.
    """
    try:
        info = await poster_service.get_template_info(template_id)
        # Evaluate truthiness inside the try so a failing truth test is still
        # reported as a 500, exactly like any other lookup error.
        missing = not info
    except HTTPException:
        # Keep status codes chosen by the service itself.
        raise
    except Exception as exc:
        logger.error(f"获取模板信息失败: {exc}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"获取模板信息失败: {str(exc)}")

    if missing:
        raise HTTPException(status_code=404, detail=f"模板 {template_id} 不存在")
    return info
|
||
|
||
|
||
@router.post("/generate", response_model=PosterGenerateResponse, summary="生成海报")
async def generate_poster(
    request: PosterGenerateRequest,
    poster_service: PosterService = Depends(get_poster_service)
):
    """
    Generate a poster.

    Request fields (attribute names on ``PosterGenerateRequest``):

    - **contentId**: content ID (optional)
    - **productId**: product ID (optional)
    - **scenicSpotId**: scenic spot ID (optional)
    - **imagesBase64**: base64-encoded images (optional)
    - **templateId**: template ID (defaults to vibrant)
    - **posterContent**: user-supplied poster content (optional)
    - **numVariations**: number of variations, forwarded as ``num_variations``
    - **forceLlmGeneration**: force LLM-generated content (optional)
    - **generatePsd**: also generate a layered PSD file (optional)
    - **psdOutputPath**: output path for the PSD file (optional)

    Error handling:
    - A ``ValueError`` raised while generating maps to HTTP 400.
    - Any other failure maps to HTTP 500. This includes a service result that
      cannot be turned into a ``PosterGenerateResponse``.
    """
    try:
        result = await poster_service.generate_poster(
            template_id=request.templateId,
            poster_content=request.posterContent,
            content_id=request.contentId,
            product_id=request.productId,
            scenic_spot_id=request.scenicSpotId,
            images_base64=request.imagesBase64,
            num_variations=request.numVariations,
            force_llm_generation=request.forceLlmGeneration,
            generate_psd=request.generatePsd,
            psd_output_path=request.psdOutputPath
        )
    except ValueError as e:
        logger.error(f"参数错误: {e}")
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        logger.error(f"生成海报失败: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"生成海报失败: {str(e)}") from e

    # Build the response outside the ValueError handler. Assuming
    # PosterGenerateResponse is a pydantic model (it is the response_model),
    # its ValidationError subclasses ValueError. A malformed service result
    # would otherwise be misreported to the client as a 400 "parameter error"
    # instead of a server-side 500.
    try:
        return PosterGenerateResponse(**result)
    except Exception as e:
        logger.error(f"生成海报失败: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"生成海报失败: {str(e)}") from e
|
||
|
||
|
||
@router.post("/image-usage", response_model=ImageUsageResponse, summary="查询图像使用情况")
async def get_image_usage(
    request: ImageUsageRequest,
    poster_service: PosterService = Depends(get_poster_service)
):
    """
    Look up how the requested images are being used.

    Passes the request's image IDs to the poster service and unpacks the
    returned mapping into an ``ImageUsageResponse``. Any failure is logged and
    surfaced as HTTP 500.
    """
    try:
        # NOTE(review): this reads snake_case ``image_ids``, while the poster
        # request model is accessed through camelCase attributes (``templateId``,
        # ...). Confirm that ImageUsageRequest really exposes ``image_ids``.
        requested_ids = request.image_ids
        usage_info = poster_service.get_image_usage_info(requested_ids)
        return ImageUsageResponse(**usage_info)
    except Exception as exc:
        logger.error(f"查询图像使用情况失败: {exc}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"查询图像使用情况失败: {str(exc)}")