116 lines
3.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
海报API路由
"""
import logging
from fastapi import APIRouter, Depends, HTTPException
from typing import List, Dict, Any
from core.config import ConfigManager
from core.ai import AIAgent
from utils.file_io import OutputManager
from api.services.poster import PosterService
from api.models.poster import (
PosterRequest, PosterResponse,
TemplateListResponse,
PosterTextRequest, PosterTextResponse
)
# 从依赖注入模块导入依赖
from api.dependencies import get_config, get_ai_agent, get_output_manager
# Module-level logger named after this module.
logger = logging.getLogger(name=__name__)

# Router on which this module's poster endpoints are registered.
router = APIRouter()
# 依赖注入函数
def get_poster_service(
config_manager: ConfigManager = Depends(get_config),
ai_agent: AIAgent = Depends(get_ai_agent),
output_manager: OutputManager = Depends(get_output_manager)
) -> PosterService:
"""获取海报服务"""
return PosterService(ai_agent, config_manager, output_manager)
@router.post("/generate", response_model=PosterResponse, summary="生成海报")
async def generate_poster(
    request: PosterRequest,
    poster_service: PosterService = Depends(get_poster_service)
):
    """
    Generate a poster.

    - **content**: content data, including the title, body text, etc.
    - **topic_index**: topic index, used for file naming
    - **template_name**: template name; if None, a template is chosen according to the configuration
    """
    try:
        # NOTE(review): `generate_poster` is not awaited, so it appears to be a
        # synchronous call made from inside an `async def` endpoint. If it does
        # heavy rendering or file I/O it blocks the event loop while it runs;
        # consider offloading it (e.g. fastapi.concurrency.run_in_threadpool)
        # once the service is confirmed safe to call from a worker thread.
        request_id, topic_index, poster_path, template_name = poster_service.generate_poster(
            content=request.content,
            topic_index=request.topic_index,
            template_name=request.template_name,
        )
        return PosterResponse(
            request_id=request_id,
            topic_index=topic_index,
            poster_path=poster_path,
            template_name=template_name,
        )
    except HTTPException:
        # Deliberate HTTP errors keep their own status code and detail
        # instead of being rewrapped as a generic 500 below.
        raise
    except Exception as e:
        logger.error(f"生成海报失败: {e}", exc_info=True)
        # NOTE(review): `detail` echoes the raw exception text to the client;
        # consider a generic message if internals should not be exposed.
        # `from e` chains the root cause into the traceback.
        raise HTTPException(status_code=500, detail=f"生成海报失败: {str(e)}") from e
@router.get("/templates", response_model=TemplateListResponse, summary="获取可用模板列表")
async def get_templates(
    poster_service: PosterService = Depends(get_poster_service)
):
    """Return the list of available poster templates and the default template."""
    try:
        templates, default_template = poster_service.get_available_templates()
        return TemplateListResponse(
            templates=templates,
            default_template=default_template,
        )
    except HTTPException:
        # Deliberate HTTP errors keep their own status code and detail
        # instead of being rewrapped as a generic 500 below.
        raise
    except Exception as e:
        logger.error(f"获取模板列表失败: {e}", exc_info=True)
        # `from e` chains the root cause into the traceback.
        raise HTTPException(status_code=500, detail=f"获取模板列表失败: {str(e)}") from e
@router.post("/text", response_model=PosterTextResponse, summary="生成海报文案")
async def generate_poster_text(
    request: PosterTextRequest,
    poster_service: PosterService = Depends(get_poster_service)
):
    """
    Generate poster copy (text).

    - **system_prompt**: system prompt
    - **user_prompt**: user prompt
    - **context_data**: context data used to fill placeholders in the prompts
    - **temperature**: sampling temperature
    - **top_p**: top_p parameter
    """
    try:
        request_id, text_content = await poster_service.generate_poster_text(
            system_prompt=request.system_prompt,
            user_prompt=request.user_prompt,
            context_data=request.context_data,
            temperature=request.temperature,
            top_p=request.top_p,
        )
        return PosterTextResponse(
            request_id=request_id,
            text_content=text_content,
        )
    except HTTPException:
        # Deliberate HTTP errors keep their own status code and detail
        # instead of being rewrapped as a generic 500 below.
        raise
    except Exception as e:
        logger.error(f"生成海报文案失败: {e}", exc_info=True)
        # `from e` chains the root cause into the traceback.
        raise HTTPException(status_code=500, detail=f"生成海报文案失败: {str(e)}") from e