#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
AIGC unified API router: a single entry point for all AIGC features.

AIGC 统一 API 路由 — 提供所有 AIGC 功能的统一入口
"""
|
||
|
||
import logging
|
||
from typing import Dict, Any, Optional, List
|
||
from fastapi import APIRouter, HTTPException, Depends
|
||
from pydantic import BaseModel, Field
|
||
|
||
# Logger named after this module.
logger = logging.getLogger(name=__name__)

# Every AIGC route below is mounted under the "/aigc" prefix and grouped
# under the "AIGC" tag in the generated OpenAPI document.
router = APIRouter(tags=["AIGC"], prefix="/aigc")
|
||
|
||
|
||
# ========== Request/response models (请求/响应模型) ==========
|
||
|
||
class ExecuteRequest(BaseModel):
    """执行请求"""

    # Request body for POST /aigc/execute ("execute request"). The docstring is
    # kept verbatim because pydantic exposes it as the schema description.

    # Required: ID of the engine to run.
    engine: str = Field(default=..., description="引擎 ID")
    # Engine parameters; execute_engine passes them to the executor as-is.
    params: Dict[str, Any] = Field(default_factory=dict, description="引擎参数")
    # When True, execute_engine returns task metadata instead of result data.
    async_mode: bool = Field(False, description="是否异步执行")
    # Optional task ID supplied by the caller.
    task_id: Optional[str] = Field(None, description="任务 ID (可选)")
    # Callback URL for async mode.
    # NOTE(review): execute_engine never forwards this field to the executor;
    # confirm whether callbacks are supposed to be wired up.
    callback_url: Optional[str] = Field(None, description="回调 URL (异步模式)")
|
||
|
||
|
||
class ExecuteResponse(BaseModel):
    """执行响应"""

    # Response envelope for POST /aigc/execute ("execute response"). The docstring
    # is kept verbatim because pydantic exposes it as the schema description.
    success: bool
    # execute_engine fills task_id / status / estimated_duration in async mode.
    task_id: Optional[str] = Field(default=None)
    status: Optional[str] = Field(default=None)
    # execute_engine fills data in sync mode.
    data: Optional[Dict[str, Any]] = Field(default=None)
    # Error details copied from the executor result in both modes.
    error: Optional[str] = Field(default=None)
    error_code: Optional[str] = Field(default=None)
    # NOTE(review): unit (seconds?) is not visible here; confirm with EngineExecutor.
    estimated_duration: Optional[int] = Field(default=None)
|
||
|
||
|
||
class TaskStatusResponse(BaseModel):
    """任务状态响应"""

    # Response for GET /aigc/task/{task_id}/status, built by unpacking the
    # executor's status dict. The docstring is kept verbatim because pydantic
    # exposes it as the schema description.
    task_id: str
    engine_id: str
    status: str
    # NOTE(review): presumably a 0-100 percentage; confirm with EngineExecutor.
    progress: int
    result: Optional[Dict[str, Any]] = Field(default=None)
    error: Optional[str] = Field(default=None)
    # Timestamps are typed as plain strings; their format is not visible here.
    created_at: Optional[str] = Field(default=None)
    completed_at: Optional[str] = Field(default=None)
|
||
|
||
|
||
class EngineInfo(BaseModel):
    """引擎信息"""

    # Public description of one registered engine, built by unpacking the
    # registry's engine-info dicts. The docstring is kept verbatim because
    # pydantic exposes it as the schema description.
    id: str
    name: str
    version: str
    description: str
    # Presumably describes the engine's accepted ``params``; confirm with the engines.
    param_schema: Optional[Dict[str, Any]] = Field(default=None)
|
||
|
||
|
||
class EngineListResponse(BaseModel):
    """引擎列表响应"""

    # Payload of GET /aigc/engines. The docstring is kept verbatim because
    # pydantic exposes it as the schema description.
    engines: List[EngineInfo]
    # Set by the endpoint to len(engines); the model itself does not enforce it.
    count: int
|
||
|
||
|
||
# ========== Dependency providers (依赖注入) ==========
|
||
|
||
def get_engine_registry():
    """Return an ``EngineRegistry`` instance.

    The import is deferred until call time. A new ``EngineRegistry()`` is
    constructed on every call. ensure_initialized() populates one instance,
    and the endpoints then read from another. NOTE(review): this only works
    if ``EngineRegistry`` shares state across instances (e.g. a singleton);
    verify in ``domain.aigc``.
    """
    from domain.aigc import EngineRegistry as _Registry

    registry = _Registry()
    return registry
|
||
|
||
|
||
def get_engine_executor():
    """Return an ``EngineExecutor`` instance.

    The import is deferred until call time. A new ``EngineExecutor()`` is
    constructed on every call. NOTE(review): async tasks started in one
    request are later queried, cancelled and cleaned up through a different
    instance. That requires the executor's task store to be shared (e.g. a
    singleton); verify in ``domain.aigc``.
    """
    from domain.aigc import EngineExecutor as _Executor

    executor = _Executor()
    return executor
|
||
|
||
|
||
def ensure_initialized():
    """Lazily bootstrap the AIGC engine system.

    If the registry already lists engines, do nothing. Otherwise:

    1. Build the shared components from the API dependencies (AI agent,
       database service, config).
    2. If obtaining those dependencies fails, fall back to
       ``create_components()`` with no arguments.
    3. Attach the components to the registry and run engine auto-discovery.

    NOTE(review): if auto-discovery registers no engines, this whole bootstrap
    runs again on every request.
    """
    from domain.aigc import EngineRegistry
    from domain.aigc.shared import ComponentFactory

    registry = EngineRegistry()
    if registry.list_engines():
        # Engines already registered: nothing to do.
        return

    logger.info("初始化 AIGC 引擎系统...")
    factory = ComponentFactory()

    try:
        # Any failure, including the import itself, falls back to empty components.
        from api.dependencies import get_ai_agent, get_config, get_database_service

        components = factory.create_components(
            ai_agent=get_ai_agent(),
            db_service=get_database_service(),
            config_manager=get_config(),
        )
    except Exception as exc:
        # Deliberate best effort: keep the engine system usable without external services.
        logger.warning(f"获取依赖失败,使用空组件: {exc}")
        components = factory.create_components()

    registry.set_shared_components(components)
    registry.auto_discover()

    engine_count = len(registry.list_engines())
    logger.info(f"AIGC 引擎系统初始化完成,共 {engine_count} 个引擎")
|
||
|
||
|
||
# ========== API endpoints (API 端点) ==========
|
||
|
||
@router.post("/execute", response_model=ExecuteResponse)
async def execute_engine(request: ExecuteRequest):
    """
    执行 AIGC 引擎

    支持同步和异步两种模式:
    - 同步模式:等待执行完成,返回结果
    - 异步模式:立即返回 task_id,通过 /task/{task_id}/status 查询状态
    """
    # English summary: run an AIGC engine.
    # - Sync mode returns the executor's ``data``.
    # - Async mode returns task metadata (task_id, status, estimated_duration);
    #   poll /task/{task_id}/status afterwards.
    # Any exception becomes HTTP 500 carrying the exception text. The
    # docstring above is kept verbatim because FastAPI publishes it as this
    # route's OpenAPI description.
    ensure_initialized()
    executor = get_engine_executor()

    try:
        outcome = await executor.execute(
            engine_id=request.engine,
            params=request.params,
            async_mode=request.async_mode,
            task_id=request.task_id,
        )
        # NOTE(review): request.callback_url is accepted but not forwarded here.

        # Fields reported in both modes.
        common = dict(
            success=outcome.get('success', False),
            error=outcome.get('error'),
            error_code=outcome.get('error_code'),
        )

        if not request.async_mode:
            return ExecuteResponse(data=outcome.get('data'), **common)

        return ExecuteResponse(
            task_id=outcome.get('task_id'),
            status=outcome.get('status'),
            estimated_duration=outcome.get('estimated_duration'),
            **common,
        )

    except Exception as exc:
        logger.error(f"执行引擎失败: {exc}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(exc))
|
||
|
||
|
||
@router.get("/task/{task_id}/status", response_model=TaskStatusResponse)
async def get_task_status(task_id: str):
    """
    获取任务状态
    """
    # English summary: return the executor's status record for ``task_id``.
    # Raises 404 when the executor returns a falsy value.
    # Docstring kept verbatim: FastAPI publishes it in OpenAPI.
    ensure_initialized()
    task_info = get_engine_executor().get_task_status(task_id)

    if task_info:
        return TaskStatusResponse(**task_info)
    raise HTTPException(status_code=404, detail="任务不存在")
|
||
|
||
|
||
@router.post("/task/{task_id}/cancel")
async def cancel_task(task_id: str):
    """
    取消任务
    """
    # English summary: ask the executor to cancel ``task_id``. Return its
    # boolean outcome plus a human-readable message; failure is reported in
    # the body, not as an HTTP error. Docstring kept verbatim for OpenAPI.
    ensure_initialized()
    cancelled = await get_engine_executor().cancel_task(task_id)

    message = "任务已取消" if cancelled else "取消失败"
    return {"success": cancelled, "task_id": task_id, "message": message}
|
||
|
||
|
||
@router.get("/engines", response_model=EngineListResponse)
async def list_engines():
    """
    列出所有可用引擎
    """
    # English summary: list every engine known to the registry.
    # Docstring kept verbatim for OpenAPI.
    ensure_initialized()
    engine_dicts = get_engine_registry().list_engines()

    infos = [EngineInfo(**entry) for entry in engine_dicts]
    return EngineListResponse(engines=infos, count=len(engine_dicts))
|
||
|
||
|
||
@router.get("/engines/{engine_id}", response_model=EngineInfo)
async def get_engine_info(engine_id: str):
    """
    获取引擎详细信息
    """
    # English summary: return details for a single engine. Raises 404 when
    # the registry returns a falsy value. Docstring kept verbatim for OpenAPI.
    ensure_initialized()
    details = get_engine_registry().get_engine_info(engine_id)

    if not details:
        raise HTTPException(status_code=404, detail="引擎不存在")
    return EngineInfo(**details)
|
||
|
||
|
||
@router.get("/tasks")
async def list_tasks(
    status: Optional[str] = None,
    engine_id: Optional[str] = None,
    limit: int = 100
):
    """
    列出任务
    """
    # English summary: list executor tasks, optionally filtered by status and
    # engine, capped by ``limit``. All three are passed to the executor
    # unchanged. Docstring kept verbatim for OpenAPI.
    ensure_initialized()
    found = get_engine_executor().list_tasks(
        status=status,
        engine_id=engine_id,
        limit=limit,
    )
    return dict(tasks=found, count=len(found))
|
||
|
||
|
||
@router.post("/tasks/cleanup")
async def cleanup_tasks(max_age_seconds: int = 3600):
    """
    清理已完成的任务
    """
    # English summary: ask the executor to clear completed tasks older than
    # ``max_age_seconds`` (default 3600). The executor's return value is
    # ignored and success is always reported. Docstring kept verbatim for OpenAPI.
    ensure_initialized()
    get_engine_executor().clear_completed_tasks(max_age_seconds)
    return dict(success=True, message="清理完成")
|
||
|
||
|
||
# ========== Config query API, called by the Java side (配置查询 API, 供 Java 端调用) ==========
|
||
|
||
def get_prompt_registry():
    """Return a ``PromptRegistry`` constructed with ``'prompts'``.

    The import is deferred until call time, and a new registry is built on
    every call. NOTE(review): ``'prompts'`` looks like a path relative to the
    process working directory. Confirm the service is always started from the
    expected directory.
    """
    from domain.prompt import PromptRegistry as _PromptRegistry

    prompts_root = 'prompts'
    return _PromptRegistry(prompts_root)
|
||
|
||
|
||
@router.get("/config/styles")
async def get_styles():
    """
    获取所有风格配置

    供 Java 端调用,用于前端下拉框展示
    """
    # English summary: list every ``style/*`` prompt as a dropdown option for
    # the Java side. Each entry has id, name, description, icon and order,
    # taken from the prompt's ``meta`` with fallbacks. Prompts that fail to
    # load are logged and skipped. The result is sorted by ``order``
    # (default 99). Docstring kept verbatim: FastAPI publishes it in OpenAPI.
    prefix = 'style/'
    registry = get_prompt_registry()

    styles = []
    for prompt_name in registry.list_prompts():
        if not prompt_name.startswith(prefix):
            continue
        try:
            config = registry.get(prompt_name)
            # Strip only the leading prefix. str.replace would also drop any
            # later "style/" occurrence, e.g. "style/lifestyle/x" -> "lifex".
            style_id = prompt_name[len(prefix):]
            styles.append({
                "id": style_id,
                "name": config.meta.get('style_name', style_id),
                "description": config.meta.get('style_description', config.description),
                "icon": config.meta.get('style_icon', ''),
                "order": config.meta.get('style_order', 99)
            })
        except Exception as e:
            # Best effort: one broken prompt must not hide the others.
            logger.warning(f"加载风格 {prompt_name} 失败: {e}")

    # Stable sort: styles with equal order keep the registry's listing order.
    styles.sort(key=lambda x: x['order'])

    return {
        "styles": styles,
        "count": len(styles)
    }
|
||
|
||
|
||
@router.get("/config/audiences")
async def get_audiences():
    """
    获取所有人群配置

    供 Java 端调用,用于前端下拉框展示
    """
    # English summary: list every ``audience/*`` prompt as a dropdown option
    # for the Java side. Each entry has id, name, description, icon and order,
    # taken from the prompt's ``meta`` with fallbacks. Prompts that fail to
    # load are logged and skipped. The result is sorted by ``order``
    # (default 99). Docstring kept verbatim: FastAPI publishes it in OpenAPI.
    prefix = 'audience/'
    registry = get_prompt_registry()

    audiences = []
    for prompt_name in registry.list_prompts():
        if not prompt_name.startswith(prefix):
            continue
        try:
            config = registry.get(prompt_name)
            # Strip only the leading prefix. str.replace would also drop any
            # later "audience/" occurrence, e.g. "audience/subaudience/x" -> "subx".
            audience_id = prompt_name[len(prefix):]
            audiences.append({
                "id": audience_id,
                "name": config.meta.get('audience_name', audience_id),
                "description": config.meta.get('audience_description', config.description),
                "icon": config.meta.get('audience_icon', ''),
                "order": config.meta.get('audience_order', 99)
            })
        except Exception as e:
            # Best effort: one broken prompt must not hide the others.
            logger.warning(f"加载人群 {prompt_name} 失败: {e}")

    # Stable sort: audiences with equal order keep the registry's listing order.
    audiences.sort(key=lambda x: x['order'])

    return {
        "audiences": audiences,
        "count": len(audiences)
    }
|
||
|
||
|
||
@router.get("/config/all")
async def get_all_config():
    """
    获取所有配置 (风格 + 人群)

    一次性获取,减少请求次数
    """
    # English summary: return styles and audiences in one payload, so the
    # caller needs a single request. Delegates to get_styles() and
    # get_audiences(), each of which builds its own PromptRegistry.
    # Docstring kept verbatim for OpenAPI.
    style_part = await get_styles()
    audience_part = await get_audiences()

    combined = {}
    combined["styles"] = style_part["styles"]
    combined["audiences"] = audience_part["audiences"]
    combined["styles_count"] = style_part["count"]
    combined["audiences_count"] = audience_part["count"]
    return combined
|