314 lines
10 KiB
Python
314 lines
10 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
|
||
"""
|
||
统一海报API路由
|
||
支持多种模板类型的海报生成,配置化管理
|
||
"""
|
||
|
||
import logging
|
||
import uuid
|
||
from datetime import datetime, timezone
|
||
from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks
|
||
from typing import Dict, Any, List
|
||
|
||
from core.ai import AIAgent
|
||
from api.services.poster_service import UnifiedPosterService
|
||
from api.models.vibrant_poster import (
|
||
PosterGenerationRequest, PosterGenerationResponse,
|
||
ContentGenerationRequest, ContentGenerationResponse,
|
||
TemplateListResponse, TemplateInfo,
|
||
BaseAPIResponse
|
||
)
|
||
|
||
# 从依赖注入模块导入依赖
|
||
from api.dependencies import get_ai_agent
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
# 创建路由
|
||
router = APIRouter()
|
||
|
||
# 依赖注入函数
|
||
def get_unified_poster_service(
|
||
ai_agent: AIAgent = Depends(get_ai_agent)
|
||
) -> UnifiedPosterService:
|
||
"""获取统一海报服务"""
|
||
return UnifiedPosterService(ai_agent)
|
||
|
||
|
||
def create_response(success: bool, message: str, data: Any = None, request_id: str = None) -> Dict[str, Any]:
|
||
"""创建标准响应"""
|
||
if request_id is None:
|
||
request_id = f"req-{datetime.now().strftime('%Y%m%d-%H%M%S')}-{str(uuid.uuid4())[:8]}"
|
||
|
||
return {
|
||
"success": success,
|
||
"message": message,
|
||
"request_id": request_id,
|
||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||
"data": data
|
||
}
|
||
|
||
|
||
@router.get("/templates", response_model=TemplateListResponse, summary="获取所有可用模板")
|
||
async def get_templates(
|
||
service: UnifiedPosterService = Depends(get_unified_poster_service)
|
||
):
|
||
"""
|
||
获取所有可用的海报模板列表
|
||
|
||
返回每个模板的详细信息,包括:
|
||
- 模板ID和名称
|
||
- 模板描述
|
||
- 模板尺寸
|
||
- 必填字段和可选字段
|
||
"""
|
||
try:
|
||
templates = service.get_available_templates()
|
||
response_data = create_response(
|
||
success=True,
|
||
message="获取模板列表成功",
|
||
data=templates
|
||
)
|
||
return TemplateListResponse(**response_data)
|
||
|
||
except Exception as e:
|
||
logger.error(f"获取模板列表失败: {e}", exc_info=True)
|
||
raise HTTPException(status_code=500, detail=f"获取模板列表失败: {str(e)}")
|
||
|
||
|
||
@router.get("/templates/{template_id}", response_model=BaseAPIResponse, summary="获取指定模板信息")
|
||
async def get_template_info(
|
||
template_id: str,
|
||
service: UnifiedPosterService = Depends(get_unified_poster_service)
|
||
):
|
||
"""
|
||
获取指定模板的详细信息
|
||
|
||
参数:
|
||
- **template_id**: 模板ID
|
||
"""
|
||
try:
|
||
template_info = service.get_template_info(template_id)
|
||
if not template_info:
|
||
raise HTTPException(status_code=404, detail=f"模板 {template_id} 不存在")
|
||
|
||
response_data = create_response(
|
||
success=True,
|
||
message="获取模板信息成功",
|
||
data=template_info.dict()
|
||
)
|
||
return BaseAPIResponse(**response_data)
|
||
|
||
except HTTPException:
|
||
raise
|
||
except Exception as e:
|
||
logger.error(f"获取模板信息失败: {e}", exc_info=True)
|
||
raise HTTPException(status_code=500, detail=f"获取模板信息失败: {str(e)}")
|
||
|
||
|
||
@router.post("/content/generate", response_model=ContentGenerationResponse, summary="生成海报内容")
|
||
async def generate_content(
|
||
request: ContentGenerationRequest,
|
||
service: UnifiedPosterService = Depends(get_unified_poster_service)
|
||
):
|
||
"""
|
||
根据源数据生成海报内容,不生成实际图片
|
||
|
||
用于:
|
||
1. 预览生成的内容
|
||
2. 调试和测试内容生成
|
||
3. 分步骤生成(先生成内容,再生成图片)
|
||
|
||
参数:
|
||
- **template_id**: 模板ID
|
||
- **source_data**: 源数据,用于AI生成内容
|
||
- **temperature**: AI生成温度参数
|
||
"""
|
||
try:
|
||
content = await service.generate_content(
|
||
template_id=request.template_id,
|
||
source_data=request.source_data,
|
||
temperature=request.temperature
|
||
)
|
||
|
||
response_data = create_response(
|
||
success=True,
|
||
message="内容生成成功",
|
||
data={
|
||
"template_id": request.template_id,
|
||
"content": content,
|
||
"metadata": {
|
||
"generation_method": "ai_generated",
|
||
"temperature": request.temperature
|
||
}
|
||
}
|
||
)
|
||
return ContentGenerationResponse(**response_data)
|
||
|
||
except Exception as e:
|
||
logger.error(f"生成内容失败: {e}", exc_info=True)
|
||
raise HTTPException(status_code=500, detail=f"生成内容失败: {str(e)}")
|
||
|
||
|
||
@router.post("/generate", response_model=PosterGenerationResponse, summary="生成海报")
|
||
async def generate_poster(
|
||
request: PosterGenerationRequest,
|
||
service: UnifiedPosterService = Depends(get_unified_poster_service)
|
||
):
|
||
"""
|
||
生成海报图片
|
||
|
||
支持两种模式:
|
||
1. 直接提供内容(content字段)
|
||
2. 提供源数据让AI生成内容(source_data字段)
|
||
|
||
参数:
|
||
- **template_id**: 模板ID
|
||
- **content**: 直接提供的海报内容(可选)
|
||
- **source_data**: 源数据,用于AI生成内容(可选)
|
||
- **topic_name**: 主题名称,用于文件命名
|
||
- **image_path**: 指定图片路径(可选)
|
||
- **image_dir**: 图片目录(可选)
|
||
- **output_dir**: 输出目录(可选)
|
||
- **temperature**: AI生成温度参数
|
||
"""
|
||
try:
|
||
result = await service.generate_poster(
|
||
template_id=request.template_id,
|
||
content=request.content,
|
||
source_data=request.source_data,
|
||
topic_name=request.topic_name,
|
||
image_path=request.image_path,
|
||
image_dir=request.image_dir,
|
||
output_dir=request.output_dir,
|
||
temperature=request.temperature
|
||
)
|
||
|
||
response_data = create_response(
|
||
success=True,
|
||
message="海报生成成功",
|
||
data=result,
|
||
request_id=result["request_id"]
|
||
)
|
||
return PosterGenerationResponse(**response_data)
|
||
|
||
except Exception as e:
|
||
logger.error(f"生成海报失败: {e}", exc_info=True)
|
||
raise HTTPException(status_code=500, detail=f"生成海报失败: {str(e)}")
|
||
|
||
|
||
@router.post("/batch", response_model=BaseAPIResponse, summary="批量生成海报")
|
||
async def batch_generate_posters(
|
||
template_id: str,
|
||
base_path: str,
|
||
image_dir: str = None,
|
||
source_files: Dict[str, str] = None,
|
||
output_base: str = "result/posters",
|
||
parallel_count: int = 3,
|
||
temperature: float = 0.7,
|
||
service: UnifiedPosterService = Depends(get_unified_poster_service)
|
||
):
|
||
"""
|
||
批量生成海报
|
||
|
||
自动扫描指定目录下的topic文件夹,为每个topic生成海报。
|
||
|
||
参数:
|
||
- **template_id**: 模板ID
|
||
- **base_path**: 包含多个topic目录的基础路径
|
||
- **image_dir**: 图片目录(可选)
|
||
- **source_files**: 源文件配置字典(可选)
|
||
- **output_base**: 输出基础目录
|
||
- **parallel_count**: 并发处理数量
|
||
- **temperature**: AI生成温度参数
|
||
"""
|
||
try:
|
||
result = await service.batch_generate_posters(
|
||
template_id=template_id,
|
||
base_path=base_path,
|
||
image_dir=image_dir,
|
||
source_files=source_files or {},
|
||
output_base=output_base,
|
||
parallel_count=parallel_count,
|
||
temperature=temperature
|
||
)
|
||
|
||
response_data = create_response(
|
||
success=True,
|
||
message=f"批量生成完成,成功: {result['successful_count']}, 失败: {result['failed_count']}",
|
||
data=result,
|
||
request_id=result["request_id"]
|
||
)
|
||
return BaseAPIResponse(**response_data)
|
||
|
||
except Exception as e:
|
||
logger.error(f"批量生成海报失败: {e}", exc_info=True)
|
||
raise HTTPException(status_code=500, detail=f"批量生成海报失败: {str(e)}")
|
||
|
||
|
||
@router.post("/config/reload", response_model=BaseAPIResponse, summary="重新加载配置")
|
||
async def reload_config(
|
||
service: UnifiedPosterService = Depends(get_unified_poster_service)
|
||
):
|
||
"""
|
||
重新加载海报配置
|
||
|
||
用于在不重启服务的情况下更新配置,包括:
|
||
- 提示词模板
|
||
- 模板配置
|
||
- 默认参数
|
||
"""
|
||
try:
|
||
service.reload_config()
|
||
response_data = create_response(
|
||
success=True,
|
||
message="配置重新加载成功"
|
||
)
|
||
return BaseAPIResponse(**response_data)
|
||
|
||
except Exception as e:
|
||
logger.error(f"重新加载配置失败: {e}", exc_info=True)
|
||
raise HTTPException(status_code=500, detail=f"重新加载配置失败: {str(e)}")
|
||
|
||
|
||
@router.get("/health", summary="健康检查")
|
||
async def health_check():
|
||
"""服务健康检查"""
|
||
return create_response(
|
||
success=True,
|
||
message="统一海报服务运行正常",
|
||
data={
|
||
"service": "unified_poster",
|
||
"status": "healthy",
|
||
"version": "2.0.0"
|
||
}
|
||
)
|
||
|
||
|
||
@router.get("/config", summary="获取服务配置")
|
||
async def get_service_config(
|
||
service: UnifiedPosterService = Depends(get_unified_poster_service)
|
||
):
|
||
"""获取服务配置信息"""
|
||
try:
|
||
config_info = {
|
||
"default_image_dir": service.config_manager.get_default_config("image_dir"),
|
||
"default_output_dir": service.config_manager.get_default_config("output_dir"),
|
||
"default_font_dir": service.config_manager.get_default_config("font_dir"),
|
||
"default_template": service.config_manager.get_default_config("template"),
|
||
"supported_image_formats": ["png", "jpg", "jpeg", "webp"],
|
||
"available_templates": len(service.get_available_templates())
|
||
}
|
||
|
||
response_data = create_response(
|
||
success=True,
|
||
message="获取配置成功",
|
||
data=config_info
|
||
)
|
||
return BaseAPIResponse(**response_data)
|
||
|
||
except Exception as e:
|
||
logger.error(f"获取配置失败: {e}", exc_info=True)
|
||
raise HTTPException(status_code=500, detail=f"获取配置失败: {str(e)}") |