254 lines
6.7 KiB
Python
254 lines
6.7 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
|
||
"""
|
||
AIGC unified API routes.

Provides a single entry point for all AIGC features.
|
||
"""
|
||
|
||
import logging
|
||
from typing import Dict, Any, Optional, List
|
||
from fastapi import APIRouter, HTTPException, Depends
|
||
from pydantic import BaseModel, Field
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
router = APIRouter(prefix="/aigc", tags=["AIGC"])
|
||
|
||
|
||
# ========== 请求/响应模型 ==========
|
||
|
||
class ExecuteRequest(BaseModel):
    """Request body accepted by the engine execution endpoint."""

    # Engine ID; the only required field.
    engine: str = Field(
        ...,
        description="引擎 ID",
    )
    # Engine parameters; a fresh empty dict is created when omitted.
    params: Dict[str, Any] = Field(
        default_factory=dict,
        description="引擎参数",
    )
    # Whether to execute asynchronously; defaults to synchronous execution.
    async_mode: bool = Field(
        False,
        description="是否异步执行",
    )
    # Optional task ID supplied by the caller.
    task_id: Optional[str] = Field(
        None,
        description="任务 ID (可选)",
    )
    # Optional callback URL, described as applying to async mode.
    callback_url: Optional[str] = Field(
        None,
        description="回调 URL (异步模式)",
    )
|
||
|
||
|
||
class ExecuteResponse(BaseModel):
    """Response body returned by the engine execution endpoint.

    Only ``success`` is required; every other field defaults to ``None``.
    """

    # Success flag copied from the executor result.
    success: bool
    # The execute endpoint fills these three only for async-mode requests.
    task_id: Optional[str] = Field(default=None)
    status: Optional[str] = Field(default=None)
    # The execute endpoint fills this only for synchronous requests.
    data: Optional[Dict[str, Any]] = Field(default=None)
    # Error text and error code, when the executor result carries them.
    error: Optional[str] = Field(default=None)
    error_code: Optional[str] = Field(default=None)
    # Async mode only; the unit is not specified in this file.
    estimated_duration: Optional[int] = Field(default=None)
|
||
|
||
|
||
class TaskStatusResponse(BaseModel):
    """Response body describing the state of a single task.

    The task status endpoint builds it by unpacking the mapping returned by
    the executor's ``get_task_status``.
    """

    # Required fields.
    task_id: str
    engine_id: str
    status: str
    # Progress as an integer; the scale is not defined in this file.
    progress: int

    # Optional fields, all defaulting to None.
    result: Optional[Dict[str, Any]] = Field(default=None)
    error: Optional[str] = Field(default=None)
    # Timestamps carried as strings; the format is not defined in this file.
    created_at: Optional[str] = Field(default=None)
    completed_at: Optional[str] = Field(default=None)
|
||
|
||
|
||
class EngineInfo(BaseModel):
    """Engine information.

    The engine endpoints build it by unpacking mappings returned by the
    engine registry.
    """

    # Required descriptive fields.
    id: str
    name: str
    version: str
    description: str

    # Optional parameter schema for the engine; None when not provided.
    param_schema: Optional[Dict[str, Any]] = Field(default=None)
|
||
|
||
|
||
class EngineListResponse(BaseModel):
    """Response body for the engine listing endpoint."""

    # One entry per engine; required.
    engines: List[EngineInfo] = Field(...)
    # Number of engines; the listing endpoint sets it to the list length.
    count: int = Field(...)
|
||
|
||
|
||
# ========== 依赖注入 ==========
|
||
|
||
def get_engine_registry():
    """Return an ``EngineRegistry`` instance.

    The import happens inside the function, so ``domain.aigc`` is only
    imported when this is called.

    NOTE(review): a new ``EngineRegistry()`` is constructed on every call,
    yet callers expect to see engines registered by ``ensure_initialized``;
    presumably the class shares state across instances -- confirm.
    """
    from domain.aigc import EngineRegistry as registry_cls

    registry = registry_cls()
    return registry
|
||
|
||
|
||
def get_engine_executor():
    """Return an ``EngineExecutor`` instance.

    The import happens inside the function, so ``domain.aigc`` is only
    imported when this is called.

    NOTE(review): a new ``EngineExecutor()`` is constructed on every call,
    while endpoints query tasks created by earlier calls; presumably the
    class shares state across instances -- confirm.
    """
    from domain.aigc import EngineExecutor as executor_cls

    executor = executor_cls()
    return executor
|
||
|
||
|
||
def ensure_initialized():
    """Initialize the AIGC engine system if no engines are registered yet.

    When the registry already lists at least one engine this returns
    immediately. Otherwise it creates the shared components, hands them to
    the registry, and runs engine auto-discovery.

    Building the components from the application dependencies is best
    effort: if importing or calling them fails, a warning is logged and the
    factory is asked for components without any arguments.
    """
    from domain.aigc import EngineRegistry
    from domain.aigc.shared import ComponentFactory

    engine_registry = EngineRegistry()

    # Guard: nothing to do once any engine is registered.
    if engine_registry.list_engines():
        return

    logger.info("初始化 AIGC 引擎系统...")

    # Create the shared components.
    component_factory = ComponentFactory()
    try:
        from api.dependencies import get_ai_agent, get_config, get_database_service

        shared_components = component_factory.create_components(
            ai_agent=get_ai_agent(),
            db_service=get_database_service(),
            config_manager=get_config(),
        )
    except Exception as exc:
        # Deliberate fallback: any failure above falls back to the
        # factory's argument-free components.
        logger.warning(f"获取依赖失败,使用空组件: {exc}")
        shared_components = component_factory.create_components()

    # Hand the shared components to the registry, then discover engines.
    engine_registry.set_shared_components(shared_components)
    engine_registry.auto_discover()

    engine_total = len(engine_registry.list_engines())
    logger.info(f"AIGC 引擎系统初始化完成,共 {engine_total} 个引擎")
|
||
|
||
|
||
# ========== API 端点 ==========
|
||
|
||
@router.post("/execute", response_model=ExecuteResponse)
async def execute_engine(request: ExecuteRequest):
    """
    Execute an AIGC engine.

    Two modes are supported:
    - Synchronous: wait for execution to finish and return the result.
    - Asynchronous: return a task_id immediately; query the state through
      /task/{task_id}/status.

    Args:
        request: Engine ID, engine parameters, mode flag and optional task ID.

    Returns:
        ExecuteResponse built from the executor's result mapping. Async
        requests carry task_id / status / estimated_duration; synchronous
        requests carry data. Both carry success, error and error_code.

    Raises:
        HTTPException: Status 500 with the error text when execution or
            response construction fails. An HTTPException raised further
            down is propagated unchanged.
    """
    ensure_initialized()
    executor = get_engine_executor()

    # NOTE(review): request.callback_url is accepted by the request model
    # but is not forwarded to the executor here.
    try:
        result = await executor.execute(
            engine_id=request.engine,
            params=request.params,
            async_mode=request.async_mode,
            task_id=request.task_id,
        )

        if request.async_mode:
            return ExecuteResponse(
                success=result.get('success', False),
                task_id=result.get('task_id'),
                status=result.get('status'),
                estimated_duration=result.get('estimated_duration'),
                error=result.get('error'),
                error_code=result.get('error_code'),
            )

        return ExecuteResponse(
            success=result.get('success', False),
            data=result.get('data'),
            error=result.get('error'),
            error_code=result.get('error_code'),
        )

    except HTTPException:
        # Keep a deliberately chosen status code instead of rewriting it
        # into a generic 500 below.
        raise
    except Exception as e:
        logger.error(f"执行引擎失败: {e}", exc_info=True)
        # Chain the cause so the original traceback stays attached.
        raise HTTPException(status_code=500, detail=str(e)) from e
|
||
|
||
|
||
@router.get("/task/{task_id}/status", response_model=TaskStatusResponse)
async def get_task_status(task_id: str):
    """
    Get the status of a task.

    Returns a TaskStatusResponse built from the executor's status mapping,
    or responds with 404 when the executor returns nothing truthy for
    ``task_id``.
    """
    ensure_initialized()
    task_executor = get_engine_executor()

    task_state = task_executor.get_task_status(task_id)
    if task_state:
        return TaskStatusResponse(**task_state)

    # A falsy lookup result is treated as "no such task".
    raise HTTPException(status_code=404, detail="任务不存在")
|
||
|
||
|
||
@router.post("/task/{task_id}/cancel")
async def cancel_task(task_id: str):
    """
    Cancel a task.

    Returns a dict with the executor's cancellation result under
    ``success``, the ``task_id`` echoed back, and a matching ``message``.
    """
    ensure_initialized()
    task_executor = get_engine_executor()

    cancelled = await task_executor.cancel_task(task_id)

    # Choose the message from the truthiness of the executor's result.
    if cancelled:
        message = "任务已取消"
    else:
        message = "取消失败"

    return {"success": cancelled, "task_id": task_id, "message": message}
|
||
|
||
|
||
@router.get("/engines", response_model=EngineListResponse)
async def list_engines():
    """
    List all available engines.

    Each entry returned by the registry is unpacked into an EngineInfo;
    ``count`` is the length of the registry's listing.
    """
    ensure_initialized()
    engine_registry = get_engine_registry()

    listed = engine_registry.list_engines()

    engine_models = []
    for entry in listed:
        engine_models.append(EngineInfo(**entry))

    return EngineListResponse(engines=engine_models, count=len(listed))
|
||
|
||
|
||
@router.get("/engines/{engine_id}", response_model=EngineInfo)
async def get_engine_info(engine_id: str):
    """
    Get detailed information about one engine.

    Returns an EngineInfo built from the registry's mapping for
    ``engine_id``, or responds with 404 when the registry returns nothing
    truthy.
    """
    ensure_initialized()
    engine_registry = get_engine_registry()

    details = engine_registry.get_engine_info(engine_id)
    if details:
        return EngineInfo(**details)

    # A falsy lookup result is treated as "no such engine".
    raise HTTPException(status_code=404, detail="引擎不存在")
|
||
|
||
|
||
@router.get("/tasks")
async def list_tasks(
    status: Optional[str] = None,
    engine_id: Optional[str] = None,
    limit: int = 100
):
    """
    List tasks.

    ``status``, ``engine_id`` and ``limit`` are passed through to the
    executor's ``list_tasks`` as keyword arguments. The response holds the
    returned tasks and their count.
    """
    ensure_initialized()
    task_executor = get_engine_executor()

    filters = {"status": status, "engine_id": engine_id, "limit": limit}
    matching = task_executor.list_tasks(**filters)

    return {"tasks": matching, "count": len(matching)}
|
||
|
||
|
||
@router.post("/tasks/cleanup")
async def cleanup_tasks(max_age_seconds: int = 3600):
    """
    Clean up completed tasks.

    ``max_age_seconds`` is passed positionally to the executor's
    ``clear_completed_tasks`` (presumably an age threshold -- confirm
    against the executor). The response is the same fixed success payload
    whenever the call returns.
    """
    ensure_initialized()
    task_executor = get_engine_executor()

    task_executor.clear_completed_tasks(max_age_seconds)

    outcome = {"success": True, "message": "清理完成"}
    return outcome
|