#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Unified AIGC API routes.

Provides a single HTTP entry point (mounted under ``/aigc``) for all AIGC
engines, plus read-only configuration endpoints (styles / audiences) that are
consumed by the Java side.
"""

import logging
import math
from typing import Dict, Any, Optional, List

from fastapi import APIRouter, HTTPException, Depends
from pydantic import BaseModel, Field

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/aigc", tags=["AIGC"])


# ========== Request / response models ==========

class ExecuteRequest(BaseModel):
    """Request body for ``POST /aigc/execute``."""
    engine: str = Field(..., description="引擎 ID")
    params: Dict[str, Any] = Field(default_factory=dict, description="引擎参数")
    async_mode: bool = Field(default=False, description="是否异步执行")
    task_id: Optional[str] = Field(default=None, description="任务 ID (可选)")
    # NOTE(review): callback_url is accepted here but never forwarded to
    # executor.execute() in execute_engine — confirm whether callbacks are
    # handled elsewhere or this field is currently a no-op.
    callback_url: Optional[str] = Field(default=None, description="回调 URL (异步模式)")


class ExecuteResponse(BaseModel):
    """Response of ``POST /aigc/execute`` (sync and async modes share this shape)."""
    success: bool
    task_id: Optional[str] = None
    status: Optional[str] = None
    data: Optional[Dict[str, Any]] = None
    error: Optional[str] = None
    error_code: Optional[str] = None
    estimated_duration: Optional[int] = None


class TaskStatusResponse(BaseModel):
    """Response of ``GET /aigc/task/{task_id}/status``."""
    task_id: str
    engine_id: str
    status: str
    progress: int
    result: Optional[Dict[str, Any]] = None
    error: Optional[str] = None
    created_at: Optional[str] = None
    completed_at: Optional[str] = None


class EngineInfo(BaseModel):
    """Public description of a registered engine."""
    id: str
    name: str
    version: str
    description: str
    param_schema: Optional[Dict[str, Any]] = None


class EngineListResponse(BaseModel):
    """Response of ``GET /aigc/engines``."""
    engines: List[EngineInfo]
    count: int


# ========== Dependency helpers ==========

def get_engine_registry():
    """Return an ``EngineRegistry`` instance (imported lazily from ``domain.aigc``)."""
    from domain.aigc import EngineRegistry
    return EngineRegistry()


def get_engine_executor():
    """Return an ``EngineExecutor`` instance (imported lazily from ``domain.aigc``)."""
    from domain.aigc import EngineExecutor
    return EngineExecutor()


def ensure_initialized() -> None:
    """Initialise the AIGC engine system if no engine is registered yet.

    Builds the shared components, falling back to empty components when the
    application dependencies cannot be resolved. Hands them to the registry
    and then runs engine auto-discovery.

    NOTE(review): this assumes ``EngineRegistry()`` returns shared state across
    instantiations (e.g. a singleton). Otherwise the emptiness check would
    re-run initialisation on every request. Likewise, if auto-discovery finds
    no engines, initialisation is attempted again on every call.
    """
    from domain.aigc import EngineRegistry
    from domain.aigc.shared import ComponentFactory

    registry = EngineRegistry()

    # Engines already registered: the system is initialised.
    if registry.list_engines():
        return

    logger.info("初始化 AIGC 引擎系统...")

    # Build the shared components handed to every engine.
    factory = ComponentFactory()
    try:
        from api.dependencies import get_ai_agent, get_config, get_database_service
        components = factory.create_components(
            ai_agent=get_ai_agent(),
            db_service=get_database_service(),
            config_manager=get_config(),
        )
    except Exception as e:
        # Deliberate best effort: degrade to empty components instead of
        # failing the incoming request.
        logger.warning("获取依赖失败,使用空组件: %s", e)
        components = factory.create_components()

    registry.set_shared_components(components)
    registry.auto_discover()
    logger.info("AIGC 引擎系统初始化完成,共 %d 个引擎", len(registry.list_engines()))


# ========== API endpoints ==========

@router.post("/execute", response_model=ExecuteResponse)
async def execute_engine(request: ExecuteRequest):
    """
    Execute an AIGC engine.

    Two modes are supported:
    - synchronous: wait for the engine to finish and return its result in ``data``;
    - asynchronous: return a ``task_id`` immediately; poll ``/task/{task_id}/status``.
    """
    ensure_initialized()
    executor = get_engine_executor()

    try:
        result = await executor.execute(
            engine_id=request.engine,
            params=request.params,
            async_mode=request.async_mode,
            task_id=request.task_id,
        )

        if request.async_mode:
            return ExecuteResponse(
                success=result.get('success', False),
                task_id=result.get('task_id'),
                status=result.get('status'),
                estimated_duration=result.get('estimated_duration'),
                error=result.get('error'),
                error_code=result.get('error_code'),
            )
        return ExecuteResponse(
            success=result.get('success', False),
            data=result.get('data'),
            error=result.get('error'),
            error_code=result.get('error_code'),
        )
    except Exception as e:
        # NOTE(review): str(e) is returned to the client verbatim and may
        # expose internal details; consider a generic message.
        logger.exception("执行引擎失败: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e


@router.get("/task/{task_id}/status", response_model=TaskStatusResponse)
async def get_task_status(task_id: str):
    """
    Get the status of a task; 404 if the executor does not know it.
    """
    ensure_initialized()
    executor = get_engine_executor()

    status = executor.get_task_status(task_id)
    if not status:
        raise HTTPException(status_code=404, detail="任务不存在")

    return TaskStatusResponse(**status)


@router.post("/task/{task_id}/cancel")
async def cancel_task(task_id: str):
    """
    Cancel a task and report whether the executor accepted the cancellation.
    """
    ensure_initialized()
    executor = get_engine_executor()

    success = await executor.cancel_task(task_id)

    return {
        "success": success,
        "task_id": task_id,
        "message": "任务已取消" if success else "取消失败",
    }


@router.get("/engines", response_model=EngineListResponse)
async def list_engines():
    """
    List all available engines.
    """
    ensure_initialized()
    registry = get_engine_registry()

    engines = registry.list_engines()

    return EngineListResponse(
        engines=[EngineInfo(**e) for e in engines],
        count=len(engines),
    )


@router.get("/engines/{engine_id}", response_model=EngineInfo)
async def get_engine_info(engine_id: str):
    """
    Get detailed information about one engine; 404 if it is not registered.
    """
    ensure_initialized()
    registry = get_engine_registry()

    info = registry.get_engine_info(engine_id)
    if not info:
        raise HTTPException(status_code=404, detail="引擎不存在")

    return EngineInfo(**info)


@router.get("/tasks")
async def list_tasks(
    status: Optional[str] = None,
    engine_id: Optional[str] = None,
    limit: int = 100
):
    """
    List tasks, optionally filtered by status and/or engine id.
    """
    ensure_initialized()
    executor = get_engine_executor()

    tasks = executor.list_tasks(status=status, engine_id=engine_id, limit=limit)

    return {
        "tasks": tasks,
        "count": len(tasks),
    }


@router.post("/tasks/cleanup")
async def cleanup_tasks(max_age_seconds: int = 3600):
    """
    Clean up completed tasks older than ``max_age_seconds``.
    """
    ensure_initialized()
    executor = get_engine_executor()

    executor.clear_completed_tasks(max_age_seconds)

    return {"success": True, "message": "清理完成"}


# ========== Configuration query API (called by the Java side) ==========

def get_prompt_registry():
    """Return a ``PromptRegistry`` constructed with ``'prompts'`` (presumably the prompt root)."""
    from domain.prompt import PromptRegistry
    return PromptRegistry('prompts')


def _order_sort_key(option: Dict[str, Any]) -> float:
    """Sort key for config options: numeric ``order`` ascending, unusable values last.

    A malformed ``*_order`` meta value (string, None, NaN, ...) must not make
    the whole listing fail, so anything that is not a usable number sorts last.
    """
    try:
        value = float(option["order"])
    except (TypeError, ValueError, OverflowError):
        return math.inf
    return math.inf if math.isnan(value) else value


def _load_prompt_options(prefix: str, meta_prefix: str, label: str) -> List[Dict[str, Any]]:
    """Collect the drop-down options for every prompt whose name starts with ``prefix``.

    Args:
        prefix: prompt-name prefix, e.g. ``'style/'``; stripped once to form the id.
        meta_prefix: prefix of the meta keys to read (``<meta_prefix>_name`` etc.).
        label: human-readable kind used in the warning logged for a broken prompt.

    Returns:
        Option dicts (``id``, ``name``, ``description``, ``icon``, ``order``),
        sorted by ``order``.
    """
    registry = get_prompt_registry()
    options: List[Dict[str, Any]] = []

    for prompt_name in registry.list_prompts():
        if not prompt_name.startswith(prefix):
            continue
        try:
            config = registry.get(prompt_name)
            # Strip only the leading prefix; str.replace would also remove
            # later occurrences inside the name.
            option_id = prompt_name[len(prefix):]
            options.append({
                "id": option_id,
                "name": config.meta.get(f"{meta_prefix}_name", option_id),
                "description": config.meta.get(f"{meta_prefix}_description", config.description),
                "icon": config.meta.get(f"{meta_prefix}_icon", ''),
                "order": config.meta.get(f"{meta_prefix}_order", 99),
            })
        except Exception as e:
            # One broken prompt must not hide the others.
            logger.warning("加载%s %s 失败: %s", label, prompt_name, e)

    # Stable sort: options with equal order keep registry order.
    options.sort(key=_order_sort_key)
    return options


@router.get("/config/styles")
async def get_styles():
    """
    List all style configurations.

    Called by the Java side to populate a front-end drop-down.
    """
    styles = _load_prompt_options('style/', 'style', '风格')
    return {
        "styles": styles,
        "count": len(styles),
    }


@router.get("/config/audiences")
async def get_audiences():
    """
    List all audience configurations.

    Called by the Java side to populate a front-end drop-down.
    """
    audiences = _load_prompt_options('audience/', 'audience', '人群')
    return {
        "audiences": audiences,
        "count": len(audiences),
    }


@router.get("/config/all")
async def get_all_config():
    """
    Get all configuration (styles + audiences) in one call to save round-trips.
    """
    styles_response = await get_styles()
    audiences_response = await get_audiences()

    return {
        "styles": styles_response["styles"],
        "audiences": audiences_response["audiences"],
        "styles_count": styles_response["count"],
        "audiences_count": audiences_response["count"],
    }