116 lines
3.7 KiB
Python
116 lines
3.7 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
|
||
"""
|
||
海报API路由
|
||
"""
|
||
|
||
import logging
|
||
from fastapi import APIRouter, Depends, HTTPException
|
||
from typing import List, Dict, Any
|
||
|
||
from core.config import ConfigManager
|
||
from core.ai import AIAgent
|
||
from utils.file_io import OutputManager
|
||
from api.services.poster import PosterService
|
||
from api.models.poster import (
|
||
PosterRequest, PosterResponse,
|
||
TemplateListResponse,
|
||
PosterTextRequest, PosterTextResponse
|
||
)
|
||
|
||
# 从依赖注入模块导入依赖
|
||
from api.dependencies import get_config, get_ai_agent, get_output_manager
|
||
|
||
# Module-level logger, named after this module's import path.
logger: logging.Logger = logging.getLogger(__name__)

# Router holding the poster endpoints of this module. No prefix is set here;
# presumably it is applied wherever this router is included — confirm.
router: APIRouter = APIRouter()
|
||
|
||
# 依赖注入函数
|
||
def get_poster_service(
    config_manager: ConfigManager = Depends(get_config),
    ai_agent: AIAgent = Depends(get_ai_agent),
    output_manager: OutputManager = Depends(get_output_manager)
) -> PosterService:
    """FastAPI dependency that builds a new ``PosterService`` on each call.

    Each argument is supplied by its own dependency provider
    (``get_config``, ``get_ai_agent``, ``get_output_manager``).

    Note: ``PosterService`` takes its arguments positionally as
    (ai_agent, config_manager, output_manager), which differs from the
    parameter order of this function.
    """
    service = PosterService(ai_agent, config_manager, output_manager)
    return service
|
||
|
||
|
||
@router.post(
    "/generate",
    response_model=PosterResponse,
    summary="生成海报",
    # Keep the OpenAPI description identical to the original (Chinese) text;
    # without this FastAPI would derive it from the English docstring below.
    description=(
        "生成海报\n"
        "\n"
        "- **content**: 内容数据,包含标题、正文等\n"
        "- **topic_index**: 主题索引,用于文件命名\n"
        "- **template_name**: 模板名称,如果为None则根据配置选择"
    ),
)
async def generate_poster(
    request: PosterRequest,
    poster_service: PosterService = Depends(get_poster_service)
):
    """Generate a poster for one topic.

    Request fields:

    - ``content``: content data, including the title, body text, etc.
    - ``topic_index``: topic index, used for naming the output file.
    - ``template_name``: template to use; when None, one is chosen
      according to the configuration.

    ``PosterService.generate_poster`` is synchronous (its result is unpacked
    directly rather than awaited). It is therefore run in a worker thread so
    that poster generation does not block the event loop for other requests.

    Returns:
        PosterResponse with the request id, topic index, poster path and
        the template name that was used.

    Raises:
        HTTPException: re-raised unchanged when the service raises one, so
            deliberate client errors are not masked as 500. Any other
            exception is logged and converted into a 500 whose detail
            includes the error text.
    """
    # Local import keeps this change self-contained; fastapi is already a
    # dependency of this module.
    from fastapi.concurrency import run_in_threadpool

    try:
        request_id, topic_index, poster_path, template_name = await run_in_threadpool(
            poster_service.generate_poster,
            content=request.content,
            topic_index=request.topic_index,
            template_name=request.template_name,
        )

        return PosterResponse(
            request_id=request_id,
            topic_index=topic_index,
            poster_path=poster_path,
            template_name=template_name
        )
    except HTTPException:
        # Already an HTTP-level error: keep its status code and detail.
        raise
    except Exception as e:
        # NOTE(review): the raw exception text is echoed to the client in
        # `detail`; consider whether that is acceptable for this API.
        logger.error(f"生成海报失败: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"生成海报失败: {str(e)}") from e
|
||
|
||
|
||
@router.get(
    "/templates",
    response_model=TemplateListResponse,
    summary="获取可用模板列表",
    # Keep the OpenAPI description identical to the original (Chinese) text;
    # without this FastAPI would derive it from the English docstring below.
    description="获取可用的海报模板列表",
)
async def get_templates(
    poster_service: PosterService = Depends(get_poster_service)
):
    """Return the available poster templates and the default template.

    Returns:
        TemplateListResponse built from the ``(templates, default_template)``
        pair returned by ``PosterService.get_available_templates``.

    Raises:
        HTTPException: re-raised unchanged when the service raises one, so
            deliberate client errors are not masked as 500. Any other
            exception is logged and converted into a 500 whose detail
            includes the error text.
    """
    try:
        templates, default_template = poster_service.get_available_templates()

        return TemplateListResponse(
            templates=templates,
            default_template=default_template
        )
    except HTTPException:
        # Already an HTTP-level error: keep its status code and detail.
        raise
    except Exception as e:
        logger.error(f"获取模板列表失败: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"获取模板列表失败: {str(e)}") from e
|
||
|
||
|
||
@router.post(
    "/text",
    response_model=PosterTextResponse,
    summary="生成海报文案",
    # Keep the OpenAPI description identical to the original (Chinese) text;
    # without this FastAPI would derive it from the English docstring below.
    description=(
        "生成海报文案\n"
        "\n"
        "- **system_prompt**: 系统提示词\n"
        "- **user_prompt**: 用户提示词\n"
        "- **context_data**: 上下文数据,用于填充提示词中的占位符\n"
        "- **temperature**: 生成温度参数\n"
        "- **top_p**: top_p参数"
    ),
)
async def generate_poster_text(
    request: PosterTextRequest,
    poster_service: PosterService = Depends(get_poster_service)
):
    """Generate poster copy text from the given prompts.

    Request fields:

    - ``system_prompt``: system prompt.
    - ``user_prompt``: user prompt.
    - ``context_data``: context data used to fill placeholders in the prompts.
    - ``temperature``: generation temperature parameter.
    - ``top_p``: top_p parameter.

    Returns:
        PosterTextResponse carrying the request id and the generated text.

    Raises:
        HTTPException: re-raised unchanged when the service raises one, so
            deliberate client errors are not masked as 500. Any other
            exception is logged and converted into a 500 whose detail
            includes the error text.
    """
    try:
        request_id, text_content = await poster_service.generate_poster_text(
            system_prompt=request.system_prompt,
            user_prompt=request.user_prompt,
            context_data=request.context_data,
            temperature=request.temperature,
            top_p=request.top_p
        )

        return PosterTextResponse(
            request_id=request_id,
            text_content=text_content
        )
    except HTTPException:
        # Already an HTTP-level error: keep its status code and detail.
        raise
    except Exception as e:
        logger.error(f"生成海报文案失败: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"生成海报文案失败: {str(e)}") from e