#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ 引擎注册表 负责引擎的自动发现、注册和管理 """ import logging import importlib import pkgutil from typing import Dict, Type, Optional, List, Any from pathlib import Path from .engines.base import BaseAIGCEngine logger = logging.getLogger(__name__) class EngineRegistry: """ 引擎注册表 支持: - 自动发现 engines 目录下的所有引擎 - 手动注册引擎 - 按 ID 获取引擎实例 - 列出所有可用引擎 """ _instance: Optional['EngineRegistry'] = None def __new__(cls): """单例模式""" if cls._instance is None: cls._instance = super().__new__(cls) cls._instance._initialized = False return cls._instance def __init__(self): if self._initialized: return self._engines: Dict[str, Type[BaseAIGCEngine]] = {} self._instances: Dict[str, BaseAIGCEngine] = {} self._shared_components: Optional[Dict[str, Any]] = None self._initialized = True logger.info("引擎注册表初始化") def set_shared_components(self, components: Dict[str, Any]): """ 设置共享组件 Args: components: 共享组件字典 (db, llm, prompt, image, storage, paths, config) """ self._shared_components = components # 为已实例化的引擎注入组件 for engine in self._instances.values(): engine.initialize(components) logger.info("共享组件已设置") def auto_discover(self, package_path: str = "domain.aigc.engines"): """ 自动发现并注册引擎 Args: package_path: 引擎包路径 """ logger.info(f"开始自动发现引擎: {package_path}") try: package = importlib.import_module(package_path) package_dir = Path(package.__file__).parent # 遍历包中的所有模块 for _, module_name, is_pkg in pkgutil.iter_modules([str(package_dir)]): # 跳过 base 模块和私有模块 if module_name.startswith('_') or module_name == 'base': continue # 跳过子包(如 _future) if is_pkg: continue try: module = importlib.import_module(f"{package_path}.{module_name}") # 查找模块中的引擎类 for attr_name in dir(module): attr = getattr(module, attr_name) # 检查是否是 BaseAIGCEngine 的子类 if (isinstance(attr, type) and issubclass(attr, BaseAIGCEngine) and attr is not BaseAIGCEngine and hasattr(attr, 'engine_id') and attr.engine_id): self.register(attr) except Exception as e: logger.warning(f"加载模块 {module_name} 失败: {e}") except Exception as e: logger.error(f"自动发现引擎失败: {e}") logger.info(f"自动发现完成,共注册 {len(self._engines)} 个引擎") def register(self, engine_class: Type[BaseAIGCEngine]): """ 注册引擎 Args: engine_class: 引擎类 """ engine_id = engine_class.engine_id if not engine_id: logger.warning(f"引擎类 {engine_class.__name__} 未定义 engine_id,跳过注册") return if engine_id in self._engines: logger.warning(f"引擎 {engine_id} 已存在,将被覆盖") self._engines[engine_id] = engine_class logger.info(f"注册引擎: {engine_id} ({engine_class.engine_name})") def get(self, engine_id: str) -> Optional[BaseAIGCEngine]: """ 获取引擎实例 Args: engine_id: 引擎 ID Returns: 引擎实例,不存在返回 None """ # 检查是否已实例化 if engine_id in self._instances: return self._instances[engine_id] # 检查是否已注册 if engine_id not in self._engines: logger.warning(f"引擎 {engine_id} 未注册") return None # 创建实例 try: engine_class = self._engines[engine_id] engine = engine_class() # 注入共享组件 if self._shared_components: engine.initialize(self._shared_components) self._instances[engine_id] = engine return engine except Exception as e: logger.error(f"创建引擎 {engine_id} 实例失败: {e}") return None def has(self, engine_id: str) -> bool: """ 检查引擎是否存在 Args: engine_id: 引擎 ID Returns: 是否存在 """ return engine_id in self._engines def list_engines(self) -> List[Dict[str, Any]]: """ 列出所有可用引擎 Returns: 引擎信息列表 """ engines = [] for engine_id, engine_class in self._engines.items(): engines.append({ "id": engine_id, "name": engine_class.engine_name, "version": engine_class.version, "description": engine_class.description, }) return engines def get_engine_info(self, engine_id: str) -> Optional[Dict[str, Any]]: """ 获取引擎详细信息 Args: engine_id: 引擎 ID Returns: 引擎信息,包含参数 Schema """ engine = self.get(engine_id) if engine: return engine.get_info() return None def clear(self): """清空注册表(主要用于测试)""" self._engines.clear() self._instances.clear() logger.info("引擎注册表已清空")