352 lines
9.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
AIGC 统一 API 路由
提供所有 AIGC 功能的统一入口
"""
import logging
from typing import Dict, Any, Optional, List
from fastapi import APIRouter, HTTPException, Depends
from pydantic import BaseModel, Field
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Router for every endpoint declared in this file; paths get the "/aigc"
# prefix and the routes are tagged "AIGC".
router = APIRouter(prefix="/aigc", tags=["AIGC"])
# ========== 请求/响应模型 ==========
class ExecuteRequest(BaseModel):
    """Request body for executing an AIGC engine."""
    # ID of the engine to run (required).
    engine: str = Field(..., description="引擎 ID")
    # Engine-specific parameters; defaults to an empty dict.
    params: Dict[str, Any] = Field(default_factory=dict, description="引擎参数")
    # Whether to execute asynchronously; defaults to synchronous execution.
    async_mode: bool = Field(default=False, description="是否异步执行")
    # Optional caller-supplied task ID.
    task_id: Optional[str] = Field(default=None, description="任务 ID (可选)")
    # Callback URL intended for async mode.
    # NOTE(review): the /execute handler in this file does not pass this field
    # on to the executor — confirm whether it is consumed anywhere.
    callback_url: Optional[str] = Field(default=None, description="回调 URL (异步模式)")
class ExecuteResponse(BaseModel):
    """Response body returned by the engine execution endpoint."""
    # Whether the execution (or task submission) succeeded.
    success: bool
    # Task identifier; filled in by the /execute handler for async requests.
    task_id: Optional[str] = None
    # Task status; filled in by the /execute handler for async requests.
    status: Optional[str] = None
    # Result payload; filled in by the /execute handler for sync requests.
    data: Optional[Dict[str, Any]] = None
    # Error message, when the executor reports one.
    error: Optional[str] = None
    # Error code, when the executor reports one.
    error_code: Optional[str] = None
    # Estimated duration for async requests.
    # NOTE(review): unit is not visible in this file (presumably seconds) — confirm.
    estimated_duration: Optional[int] = None
class TaskStatusResponse(BaseModel):
    """Response body describing the state of a task.

    Built by unpacking the dict returned from the executor's
    ``get_task_status`` call, so that dict's keys must match these fields.
    """
    # Identifier of the task.
    task_id: str
    # Identifier of the engine the task belongs to.
    engine_id: str
    # Current status string of the task.
    status: str
    # Progress indicator.
    # NOTE(review): range is not visible in this file (presumably 0-100) — confirm.
    progress: int
    # Result payload, if any.
    result: Optional[Dict[str, Any]] = None
    # Error message, if any.
    error: Optional[str] = None
    # Creation timestamp as a string (format not visible in this file).
    created_at: Optional[str] = None
    # Completion timestamp as a string (format not visible in this file).
    completed_at: Optional[str] = None
class EngineInfo(BaseModel):
    """Descriptive information about a single engine.

    Built by unpacking the dicts returned from the engine registry.
    """
    # Engine identifier.
    id: str
    # Engine name.
    name: str
    # Engine version string.
    version: str
    # Engine description.
    description: str
    # Optional schema describing the engine's parameters.
    param_schema: Optional[Dict[str, Any]] = None
class EngineListResponse(BaseModel):
    """Response body for the engine listing endpoint."""
    # Information about each engine.
    engines: List[EngineInfo]
    # Number of engines; the /engines handler sets it to the list length.
    count: int
# ========== 依赖注入 ==========
def get_engine_registry():
    """Return an ``EngineRegistry`` instance.

    The class is imported at call time rather than at module import.
    """
    from domain.aigc import EngineRegistry

    engine_registry = EngineRegistry()
    return engine_registry
def get_engine_executor():
    """Return an ``EngineExecutor`` instance.

    The class is imported at call time rather than at module import.
    """
    from domain.aigc import EngineExecutor

    engine_executor = EngineExecutor()
    return engine_executor
def ensure_initialized():
    """Bootstrap the AIGC engine system the first time it is needed.

    Does nothing when the registry already lists at least one engine.
    Otherwise it builds the shared components, attaches them to the registry
    and runs engine auto-discovery.

    NOTE(review): this relies on ``EngineRegistry()`` sharing state across
    instances (presumably a singleton) — confirm.
    """
    from domain.aigc import EngineRegistry
    from domain.aigc.shared import ComponentFactory

    engine_registry = EngineRegistry()
    if engine_registry.list_engines():
        # Engines are already registered; nothing left to do.
        return

    logger.info("初始化 AIGC 引擎系统...")

    # Build the shared components, preferring the real application services.
    component_factory = ComponentFactory()
    try:
        from api.dependencies import get_ai_agent, get_config, get_database_service

        agent = get_ai_agent()
        database = get_database_service()
        config = get_config()
        shared = component_factory.create_components(
            ai_agent=agent,
            db_service=database,
            config_manager=config,
        )
    except Exception as e:
        # Best effort: fall back to components created without dependencies.
        logger.warning(f"获取依赖失败,使用空组件: {e}")
        shared = component_factory.create_components()

    # Hand the components to the registry, then discover the engines.
    engine_registry.set_shared_components(shared)
    engine_registry.auto_discover()

    engine_total = len(engine_registry.list_engines())
    logger.info(f"AIGC 引擎系统初始化完成,共 {engine_total} 个引擎")
# ========== API 端点 ==========
@router.post("/execute", response_model=ExecuteResponse)
async def execute_engine(request: ExecuteRequest):
    """
    Execute an AIGC engine.

    Two modes are supported:
    - Synchronous: wait for the execution to finish and return its result.
    - Asynchronous: return a task_id immediately; poll
      /task/{task_id}/status for progress.

    Args:
        request: Engine ID, engine parameters, execution mode and optional
            task ID. ``request.callback_url`` is accepted by the model but is
            not forwarded to the executor here.

    Returns:
        ExecuteResponse built from the executor's result dict. Async requests
        carry ``task_id``, ``status`` and ``estimated_duration``; sync
        requests carry ``data``. Both carry ``success``, ``error`` and
        ``error_code``.

    Raises:
        HTTPException: Any HTTPException raised during execution is passed
            through unchanged; every other exception is logged and reported
            as a 500 whose detail is the exception message.
    """
    ensure_initialized()
    executor = get_engine_executor()
    try:
        result = await executor.execute(
            engine_id=request.engine,
            params=request.params,
            async_mode=request.async_mode,
            task_id=request.task_id,
        )
        if request.async_mode:
            return ExecuteResponse(
                success=result.get('success', False),
                task_id=result.get('task_id'),
                status=result.get('status'),
                estimated_duration=result.get('estimated_duration'),
                error=result.get('error'),
                error_code=result.get('error_code'),
            )
        return ExecuteResponse(
            success=result.get('success', False),
            data=result.get('data'),
            error=result.get('error'),
            error_code=result.get('error_code'),
        )
    except HTTPException:
        # Keep deliberately raised HTTP errors (and their status codes) intact
        # instead of rewriting them as a generic 500 below.
        raise
    except Exception as e:
        logger.error(f"执行引擎失败: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.get("/task/{task_id}/status", response_model=TaskStatusResponse)
async def get_task_status(task_id: str):
    """
    Return the status of a task.

    Raises:
        HTTPException: 404 when the executor returns nothing for ``task_id``.
    """
    ensure_initialized()
    task_state = get_engine_executor().get_task_status(task_id)
    if task_state:
        # The executor's dict is unpacked straight into the response model.
        return TaskStatusResponse(**task_state)
    raise HTTPException(status_code=404, detail="任务不存在")
@router.post("/task/{task_id}/cancel")
async def cancel_task(task_id: str):
    """
    Cancel a task.

    Returns:
        A dict with the executor's cancellation result under ``success``, the
        ``task_id`` and a ``message`` matching that result.
    """
    ensure_initialized()
    engine_executor = get_engine_executor()
    cancelled = await engine_executor.cancel_task(task_id)
    if cancelled:
        message = "任务已取消"
    else:
        message = "取消失败"
    return {"success": cancelled, "task_id": task_id, "message": message}
@router.get("/engines", response_model=EngineListResponse)
async def list_engines():
    """
    List all available engines.

    Returns:
        EngineListResponse with one EngineInfo per registry entry and the
        number of entries.
    """
    ensure_initialized()
    registered = get_engine_registry().list_engines()
    engine_models = []
    for entry in registered:
        # Each registry entry is unpacked into the EngineInfo model.
        engine_models.append(EngineInfo(**entry))
    return EngineListResponse(engines=engine_models, count=len(registered))
@router.get("/engines/{engine_id}", response_model=EngineInfo)
async def get_engine_info(engine_id: str):
    """
    Return detailed information about one engine.

    Raises:
        HTTPException: 404 when the registry returns nothing for ``engine_id``.
    """
    ensure_initialized()
    engine_details = get_engine_registry().get_engine_info(engine_id)
    if engine_details:
        return EngineInfo(**engine_details)
    raise HTTPException(status_code=404, detail="引擎不存在")
@router.get("/tasks")
async def list_tasks(
    status: Optional[str] = None,
    engine_id: Optional[str] = None,
    limit: int = 100
):
    """
    List tasks.

    Args:
        status: Optional status filter, passed to the executor as-is.
        engine_id: Optional engine filter, passed to the executor as-is.
        limit: Limit passed to the executor as-is (default 100).

    Returns:
        A dict with the executor's task list under ``tasks`` and its length
        under ``count``.
    """
    ensure_initialized()
    filters = {"status": status, "engine_id": engine_id, "limit": limit}
    found = get_engine_executor().list_tasks(**filters)
    return {"tasks": found, "count": len(found)}
@router.post("/tasks/cleanup")
async def cleanup_tasks(max_age_seconds: int = 3600):
    """
    Clean up completed tasks.

    Args:
        max_age_seconds: Value handed to the executor's
            ``clear_completed_tasks`` (default 3600).

    Returns:
        A fixed success payload; the executor's return value is not inspected.
    """
    ensure_initialized()
    get_engine_executor().clear_completed_tasks(max_age_seconds)
    outcome = {"success": True, "message": "清理完成"}
    return outcome
# ========== 配置查询 API (供 Java 端调用) ==========
def get_prompt_registry():
    """Return a ``PromptRegistry`` constructed with the argument ``'prompts'``.

    The class is imported at call time rather than at module import.
    """
    from domain.prompt import PromptRegistry

    prompt_registry = PromptRegistry('prompts')
    return prompt_registry
def _collect_prompt_options(prefix, meta_key, label):
    """Collect dropdown options from prompts whose name starts with ``prefix``.

    Args:
        prefix: Prompt-name prefix including the trailing slash, e.g. ``'style/'``.
        meta_key: Prefix of the metadata keys to read, e.g. ``'style'`` for
            ``style_name``, ``style_description``, ``style_icon``, ``style_order``.
        label: Text inserted into the warning logged when a prompt fails to load.

    Returns:
        A list of dicts with keys ``id``, ``name``, ``description``, ``icon``
        and ``order``, sorted by ``order``. Prompts that fail to load are
        logged and skipped.
    """
    registry = get_prompt_registry()
    options = []
    for prompt_name in registry.list_prompts():
        if not prompt_name.startswith(prefix):
            continue
        try:
            config = registry.get(prompt_name)
            # Strip only the leading prefix. str.replace() would also remove
            # later occurrences (e.g. 'style/lifestyle/x' -> 'lifex').
            option_id = prompt_name[len(prefix):]
            meta = config.meta
            options.append({
                "id": option_id,
                "name": meta.get(f'{meta_key}_name', option_id),
                "description": meta.get(f'{meta_key}_description', config.description),
                "icon": meta.get(f'{meta_key}_icon', ''),
                "order": meta.get(f'{meta_key}_order', 99),
            })
        except Exception as e:
            # Best effort: one broken prompt must not hide the others.
            logger.warning(f"加载{label} {prompt_name} 失败: {e}")
    # Sort by the configured order; entries without one default to 99.
    options.sort(key=lambda option: option['order'])
    return options


@router.get("/config/styles")
async def get_styles():
    """
    Return all style configurations.

    Intended for the Java side, which uses it to fill a front-end dropdown.

    Returns:
        A dict with the sorted style options under ``styles`` and their
        number under ``count``.
    """
    styles = _collect_prompt_options('style/', 'style', '风格')
    return {
        "styles": styles,
        "count": len(styles)
    }


@router.get("/config/audiences")
async def get_audiences():
    """
    Return all audience configurations.

    Intended for the Java side, which uses it to fill a front-end dropdown.

    Returns:
        A dict with the sorted audience options under ``audiences`` and their
        number under ``count``.
    """
    audiences = _collect_prompt_options('audience/', 'audience', '人群')
    return {
        "audiences": audiences,
        "count": len(audiences)
    }
@router.get("/config/all")
async def get_all_config():
    """
    Return every configuration (styles + audiences) in one response.

    Combines the two config endpoints so a client needs only one request.

    Returns:
        A dict with ``styles``, ``audiences``, ``styles_count`` and
        ``audiences_count``.
    """
    style_payload = await get_styles()
    audience_payload = await get_audiences()
    combined = {
        "styles": style_payload["styles"],
        "audiences": audience_payload["audiences"],
        "styles_count": style_payload["count"],
        "audiences_count": audience_payload["count"],
    }
    return combined