352 lines
9.6 KiB
Python
Raw Normal View History

2025-12-08 14:58:35 +08:00
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
AIGC 统一 API 路由
提供所有 AIGC 功能的统一入口
"""
import logging
from typing import Dict, Any, Optional, List
from fastapi import APIRouter, HTTPException, Depends
from pydantic import BaseModel, Field
# Module-level logger, plus the router that every endpoint in this module is
# registered on (all routes get the "/aigc" prefix and the "AIGC" tag).
logger = logging.getLogger(name=__name__)
router = APIRouter(tags=["AIGC"], prefix="/aigc")
# ========== 请求/响应模型 ==========
class ExecuteRequest(BaseModel):
    """Request body for the ``/execute`` endpoint.

    Names the engine to run, the parameters to pass to it, and whether the
    call should run synchronously or be submitted as a background task.
    """
    # ID of the engine to run (required: ``...`` means no default).
    engine: str = Field(..., description="引擎 ID")
    # Engine-specific parameters; each request gets its own empty dict by default.
    params: Dict[str, Any] = Field(default_factory=dict, description="引擎参数")
    # False: wait for the result. True: return right away with a task_id.
    async_mode: bool = Field(default=False, description="是否异步执行")
    # Optional caller-supplied task ID; execute_engine passes it to the executor.
    task_id: Optional[str] = Field(default=None, description="任务 ID (可选)")
    # Callback URL intended for async mode.
    # NOTE(review): execute_engine in this file does not forward this value to
    # the executor — confirm whether anything consumes it.
    callback_url: Optional[str] = Field(default=None, description="回调 URL (异步模式)")
class ExecuteResponse(BaseModel):
    """Response body returned by the ``/execute`` endpoint.

    ``execute_engine`` in this file fills these fields as follows:
    - both modes: ``success``, ``error`` and ``error_code``;
    - async mode: also ``task_id``, ``status`` and ``estimated_duration``;
    - sync mode: also ``data``.
    """
    # Whether the engine call succeeded (required; no default).
    success: bool
    # Task handle; filled for async-mode executions.
    task_id: Optional[str] = None
    # Task status string; filled for async-mode executions.
    status: Optional[str] = None
    # Engine output payload; filled for sync-mode executions.
    data: Optional[Dict[str, Any]] = None
    # Error message reported by the executor, if any.
    error: Optional[str] = None
    # Error code reported by the executor, if any.
    error_code: Optional[str] = None
    # Estimated run time for async tasks; the unit is not shown here
    # (presumably seconds — confirm against EngineExecutor).
    estimated_duration: Optional[int] = None
class TaskStatusResponse(BaseModel):
    """Response body for the task-status endpoint.

    The endpoint builds this with ``TaskStatusResponse(**status)`` from the
    dict returned by ``EngineExecutor.get_task_status``. That dict must
    therefore supply at least the fields below that have no default.
    """
    # Identifier of the task.
    task_id: str
    # ID of the engine that runs the task.
    engine_id: str
    # Current task status; the allowed values are defined by the executor (not shown here).
    status: str
    # Progress indicator; presumably a 0-100 percentage — confirm.
    progress: int
    # Task output, if available.
    result: Optional[Dict[str, Any]] = None
    # Error message if the task failed.
    error: Optional[str] = None
    # Creation / completion timestamps as strings; the format is not shown here
    # (presumably ISO 8601 — confirm).
    created_at: Optional[str] = None
    completed_at: Optional[str] = None
class EngineInfo(BaseModel):
    """Public metadata describing one registered engine.

    The endpoints in this module build it from the dicts returned by
    ``EngineRegistry.list_engines()`` / ``get_engine_info()``.
    """
    # Engine ID (the value clients pass as ``ExecuteRequest.engine``).
    id: str
    # Display name.
    name: str
    # Engine version string.
    version: str
    # Short description.
    description: str
    # Optional description of the accepted parameters; presumably a JSON
    # Schema dict — confirm.
    param_schema: Optional[Dict[str, Any]] = None
class EngineListResponse(BaseModel):
    """Response body for the engine-listing endpoint."""
    # Metadata for every registered engine.
    engines: List[EngineInfo]
    # Number of entries in ``engines``.
    count: int
# ========== 依赖注入 ==========
def get_engine_registry():
    """Return an ``EngineRegistry`` from ``domain.aigc``.

    The import is deferred until call time (presumably to avoid an import
    cycle or import-time cost — confirm).
    """
    from domain.aigc import EngineRegistry as registry_cls
    instance = registry_cls()
    return instance
def get_engine_executor():
    """Return an ``EngineExecutor`` from ``domain.aigc``.

    The import is deferred until call time (presumably to avoid an import
    cycle or import-time cost — confirm).
    """
    from domain.aigc import EngineExecutor as executor_cls
    instance = executor_cls()
    return instance
def ensure_initialized():
    """Lazily bootstrap the AIGC engine system.

    Returns immediately when the registry already lists at least one engine.
    Otherwise it:
    1. builds shared components with ``ComponentFactory``, falling back to a
       dependency-free component set if the API-layer dependencies cannot be
       obtained;
    2. installs those components on the registry;
    3. runs engine auto-discovery.

    NOTE(review): the early-return check only works if ``EngineRegistry()``
    returns shared state (e.g. a singleton) — confirm in ``domain.aigc``.
    """
    from domain.aigc import EngineRegistry
    from domain.aigc.shared import ComponentFactory

    registry = EngineRegistry()
    if registry.list_engines():
        # Engines are already registered: nothing to initialize.
        return

    logger.info("初始化 AIGC 引擎系统...")
    component_factory = ComponentFactory()
    try:
        from api.dependencies import get_ai_agent, get_config, get_database_service
        shared = component_factory.create_components(
            ai_agent=get_ai_agent(),
            db_service=get_database_service(),
            config_manager=get_config(),
        )
    except Exception as e:
        # Best effort: fall back to components built without external deps.
        logger.warning(f"获取依赖失败,使用空组件: {e}")
        shared = component_factory.create_components()

    registry.set_shared_components(shared)
    registry.auto_discover()
    logger.info(f"AIGC 引擎系统初始化完成,共 {len(registry.list_engines())} 个引擎")
# ========== API 端点 ==========
@router.post("/execute", response_model=ExecuteResponse)
async def execute_engine(request: ExecuteRequest):
    """Run an AIGC engine, synchronously or as a background task.

    - Sync mode (``async_mode=False``): waits for the engine and returns its
      ``data`` in the response.
    - Async mode (``async_mode=True``): returns right away with a ``task_id``;
      poll ``/task/{task_id}/status`` for progress.

    Raises:
        HTTPException: any ``HTTPException`` raised during execution is
            re-raised unchanged, keeping its status code. Any other exception
            is logged with its traceback and converted to a 500.
    """
    ensure_initialized()
    executor = get_engine_executor()

    if request.callback_url:
        # The executor call below has no callback parameter, so the URL would
        # otherwise be dropped silently; make that visible in the logs.
        logger.warning(f"callback_url 暂未转发给执行器,已忽略: {request.callback_url}")

    try:
        result = await executor.execute(
            engine_id=request.engine,
            params=request.params,
            async_mode=request.async_mode,
            task_id=request.task_id,
        )

        # Fields reported in both modes.
        common = {
            "success": result.get("success", False),
            "error": result.get("error"),
            "error_code": result.get("error_code"),
        }
        if request.async_mode:
            return ExecuteResponse(
                task_id=result.get("task_id"),
                status=result.get("status"),
                estimated_duration=result.get("estimated_duration"),
                **common,
            )
        return ExecuteResponse(data=result.get("data"), **common)
    except HTTPException:
        # Already an HTTP error with a deliberate status code: don't turn it into a 500.
        raise
    except Exception as e:
        logger.exception(f"执行引擎失败: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.get("/task/{task_id}/status", response_model=TaskStatusResponse)
async def get_task_status(task_id: str):
    """Return the executor's current status record for ``task_id``.

    Raises:
        HTTPException: 404 when the executor returns no (falsy) record.
    """
    ensure_initialized()
    task_info = get_engine_executor().get_task_status(task_id)
    if task_info:
        return TaskStatusResponse(**task_info)
    raise HTTPException(status_code=404, detail="任务不存在")
@router.post("/task/{task_id}/cancel")
async def cancel_task(task_id: str):
    """Ask the executor to cancel ``task_id`` and report the outcome.

    The response echoes the task ID, the executor's result as ``success``,
    and a message chosen from that result's truthiness.
    """
    ensure_initialized()
    executor = get_engine_executor()
    cancelled = await executor.cancel_task(task_id)
    message = "任务已取消" if cancelled else "取消失败"
    return {"success": cancelled, "task_id": task_id, "message": message}
@router.get("/engines", response_model=EngineListResponse)
async def list_engines():
    """List every engine the registry reports, together with the total count."""
    ensure_initialized()
    engine_dicts = get_engine_registry().list_engines()
    infos = [EngineInfo(**entry) for entry in engine_dicts]
    return EngineListResponse(engines=infos, count=len(engine_dicts))
@router.get("/engines/{engine_id}", response_model=EngineInfo)
async def get_engine_info(engine_id: str):
    """Return the registry's metadata for a single engine.

    Raises:
        HTTPException: 404 when the registry returns no (falsy) info.
    """
    ensure_initialized()
    details = get_engine_registry().get_engine_info(engine_id)
    if details:
        return EngineInfo(**details)
    raise HTTPException(status_code=404, detail="引擎不存在")
@router.get("/tasks")
async def list_tasks(
    status: Optional[str] = None,
    engine_id: Optional[str] = None,
    limit: int = 100
):
    """List tasks from the executor.

    The optional ``status`` / ``engine_id`` filters and ``limit`` are passed
    to ``EngineExecutor.list_tasks`` as-is. The response also includes the
    number of tasks returned.
    """
    ensure_initialized()
    matching = get_engine_executor().list_tasks(
        status=status,
        engine_id=engine_id,
        limit=limit,
    )
    return {"tasks": matching, "count": len(matching)}
@router.post("/tasks/cleanup")
async def cleanup_tasks(max_age_seconds: int = 3600):
    """Ask the executor to clear completed tasks.

    ``max_age_seconds`` is passed to ``EngineExecutor.clear_completed_tasks``
    (presumably an age threshold in seconds — confirm). The executor's return
    value is not inspected; the endpoint always reports success.
    """
    ensure_initialized()
    get_engine_executor().clear_completed_tasks(max_age_seconds)
    return {"success": True, "message": "清理完成"}
# ========== 配置查询 API (供 Java 端调用) ==========
def get_prompt_registry():
    """Construct a new ``PromptRegistry`` for the ``'prompts'`` location.

    A fresh registry is built on every call. ``'prompts'`` is presumably a
    directory resolved relative to the working directory — confirm in
    ``domain.prompt``.
    """
    from domain.prompt import PromptRegistry as prompt_registry_cls
    prompts_root = 'prompts'
    return prompt_registry_cls(prompts_root)
def _collect_prompt_options(registry, prefix, meta_prefix, label):
    """Build the sorted dropdown options for prompts whose name starts with ``prefix``.

    Args:
        registry: Prompt registry providing ``list_prompts()`` and ``get(name)``.
        prefix: Prompt-name prefix that selects the group (e.g. ``"style/"``).
            Only the leading occurrence is removed to form the option id.
        meta_prefix: Prefix of the ``config.meta`` keys to read
            (``<meta_prefix>_name``, ``_description``, ``_icon``, ``_order``).
        label: Group name used only in the warning for a prompt that fails to load.

    Returns:
        A list of ``{"id", "name", "description", "icon", "order"}`` dicts,
        sorted by ``order``. A missing order defaults to 99; ties keep
        ``list_prompts()`` order.
    """
    options = []
    for prompt_name in registry.list_prompts():
        if not prompt_name.startswith(prefix):
            continue
        try:
            config = registry.get(prompt_name)
            # Slice rather than str.replace: replace() would also strip later
            # occurrences of the prefix inside the name.
            option_id = prompt_name[len(prefix):]
            options.append({
                "id": option_id,
                "name": config.meta.get(f"{meta_prefix}_name", option_id),
                "description": config.meta.get(f"{meta_prefix}_description", config.description),
                "icon": config.meta.get(f"{meta_prefix}_icon", ''),
                "order": config.meta.get(f"{meta_prefix}_order", 99),
            })
        except Exception as e:
            # Best effort: one broken prompt must not hide the others.
            logger.warning(f"加载{label} {prompt_name} 失败: {e}")
    options.sort(key=lambda item: item['order'])
    return options


@router.get("/config/styles")
async def get_styles():
    """List all style presets (``style/*`` prompts) for the frontend dropdown.

    Called by the Java side. Returns ``{"styles": [...], "count": n}`` with the
    styles sorted by their configured order.
    """
    styles = _collect_prompt_options(get_prompt_registry(), 'style/', 'style', '风格')
    return {
        "styles": styles,
        "count": len(styles)
    }


@router.get("/config/audiences")
async def get_audiences():
    """List all audience presets (``audience/*`` prompts) for the frontend dropdown.

    Called by the Java side. Returns ``{"audiences": [...], "count": n}`` with
    the audiences sorted by their configured order.
    """
    audiences = _collect_prompt_options(get_prompt_registry(), 'audience/', 'audience', '人群')
    return {
        "audiences": audiences,
        "count": len(audiences)
    }
@router.get("/config/all")
async def get_all_config():
    """Return the style and audience options together, so clients need one request.

    Delegates to ``get_styles`` and then ``get_audiences`` and merges their
    payloads, adding a per-group count.
    """
    style_payload = await get_styles()
    audience_payload = await get_audiences()
    return dict(
        styles=style_payload["styles"],
        audiences=audience_payload["audiences"],
        styles_count=style_payload["count"],
        audiences_count=audience_payload["count"],
    )