216 lines
6.4 KiB
Python
216 lines
6.4 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
|
||
"""
|
||
引擎注册表
|
||
负责引擎的自动发现、注册和管理
|
||
"""
|
||
|
||
import logging
|
||
import importlib
|
||
import pkgutil
|
||
from typing import Dict, Type, Optional, List, Any
|
||
from pathlib import Path
|
||
|
||
from .engines.base import BaseAIGCEngine
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class EngineRegistry:
|
||
"""
|
||
引擎注册表
|
||
|
||
支持:
|
||
- 自动发现 engines 目录下的所有引擎
|
||
- 手动注册引擎
|
||
- 按 ID 获取引擎实例
|
||
- 列出所有可用引擎
|
||
"""
|
||
|
||
_instance: Optional['EngineRegistry'] = None
|
||
|
||
def __new__(cls):
|
||
"""单例模式"""
|
||
if cls._instance is None:
|
||
cls._instance = super().__new__(cls)
|
||
cls._instance._initialized = False
|
||
return cls._instance
|
||
|
||
def __init__(self):
|
||
if self._initialized:
|
||
return
|
||
|
||
self._engines: Dict[str, Type[BaseAIGCEngine]] = {}
|
||
self._instances: Dict[str, BaseAIGCEngine] = {}
|
||
self._shared_components: Optional[Dict[str, Any]] = None
|
||
self._initialized = True
|
||
|
||
logger.info("引擎注册表初始化")
|
||
|
||
def set_shared_components(self, components: Dict[str, Any]):
|
||
"""
|
||
设置共享组件
|
||
|
||
Args:
|
||
components: 共享组件字典 (db, llm, prompt, image, storage, paths, config)
|
||
"""
|
||
self._shared_components = components
|
||
|
||
# 为已实例化的引擎注入组件
|
||
for engine in self._instances.values():
|
||
engine.initialize(components)
|
||
|
||
logger.info("共享组件已设置")
|
||
|
||
def auto_discover(self, package_path: str = "domain.aigc.engines"):
|
||
"""
|
||
自动发现并注册引擎
|
||
|
||
Args:
|
||
package_path: 引擎包路径
|
||
"""
|
||
logger.info(f"开始自动发现引擎: {package_path}")
|
||
|
||
try:
|
||
package = importlib.import_module(package_path)
|
||
package_dir = Path(package.__file__).parent
|
||
|
||
# 遍历包中的所有模块
|
||
for _, module_name, is_pkg in pkgutil.iter_modules([str(package_dir)]):
|
||
# 跳过 base 模块和私有模块
|
||
if module_name.startswith('_') or module_name == 'base':
|
||
continue
|
||
|
||
# 跳过子包(如 _future)
|
||
if is_pkg:
|
||
continue
|
||
|
||
try:
|
||
module = importlib.import_module(f"{package_path}.{module_name}")
|
||
|
||
# 查找模块中的引擎类
|
||
for attr_name in dir(module):
|
||
attr = getattr(module, attr_name)
|
||
|
||
# 检查是否是 BaseAIGCEngine 的子类
|
||
if (isinstance(attr, type) and
|
||
issubclass(attr, BaseAIGCEngine) and
|
||
attr is not BaseAIGCEngine and
|
||
hasattr(attr, 'engine_id') and
|
||
attr.engine_id):
|
||
|
||
self.register(attr)
|
||
|
||
except Exception as e:
|
||
logger.warning(f"加载模块 {module_name} 失败: {e}")
|
||
|
||
except Exception as e:
|
||
logger.error(f"自动发现引擎失败: {e}")
|
||
|
||
logger.info(f"自动发现完成,共注册 {len(self._engines)} 个引擎")
|
||
|
||
def register(self, engine_class: Type[BaseAIGCEngine]):
|
||
"""
|
||
注册引擎
|
||
|
||
Args:
|
||
engine_class: 引擎类
|
||
"""
|
||
engine_id = engine_class.engine_id
|
||
|
||
if not engine_id:
|
||
logger.warning(f"引擎类 {engine_class.__name__} 未定义 engine_id,跳过注册")
|
||
return
|
||
|
||
if engine_id in self._engines:
|
||
logger.warning(f"引擎 {engine_id} 已存在,将被覆盖")
|
||
|
||
self._engines[engine_id] = engine_class
|
||
logger.info(f"注册引擎: {engine_id} ({engine_class.engine_name})")
|
||
|
||
def get(self, engine_id: str) -> Optional[BaseAIGCEngine]:
|
||
"""
|
||
获取引擎实例
|
||
|
||
Args:
|
||
engine_id: 引擎 ID
|
||
|
||
Returns:
|
||
引擎实例,不存在返回 None
|
||
"""
|
||
# 检查是否已实例化
|
||
if engine_id in self._instances:
|
||
return self._instances[engine_id]
|
||
|
||
# 检查是否已注册
|
||
if engine_id not in self._engines:
|
||
logger.warning(f"引擎 {engine_id} 未注册")
|
||
return None
|
||
|
||
# 创建实例
|
||
try:
|
||
engine_class = self._engines[engine_id]
|
||
engine = engine_class()
|
||
|
||
# 注入共享组件
|
||
if self._shared_components:
|
||
engine.initialize(self._shared_components)
|
||
|
||
self._instances[engine_id] = engine
|
||
return engine
|
||
|
||
except Exception as e:
|
||
logger.error(f"创建引擎 {engine_id} 实例失败: {e}")
|
||
return None
|
||
|
||
def has(self, engine_id: str) -> bool:
|
||
"""
|
||
检查引擎是否存在
|
||
|
||
Args:
|
||
engine_id: 引擎 ID
|
||
|
||
Returns:
|
||
是否存在
|
||
"""
|
||
return engine_id in self._engines
|
||
|
||
def list_engines(self) -> List[Dict[str, Any]]:
|
||
"""
|
||
列出所有可用引擎
|
||
|
||
Returns:
|
||
引擎信息列表
|
||
"""
|
||
engines = []
|
||
for engine_id, engine_class in self._engines.items():
|
||
engines.append({
|
||
"id": engine_id,
|
||
"name": engine_class.engine_name,
|
||
"version": engine_class.version,
|
||
"description": engine_class.description,
|
||
})
|
||
return engines
|
||
|
||
def get_engine_info(self, engine_id: str) -> Optional[Dict[str, Any]]:
|
||
"""
|
||
获取引擎详细信息
|
||
|
||
Args:
|
||
engine_id: 引擎 ID
|
||
|
||
Returns:
|
||
引擎信息,包含参数 Schema
|
||
"""
|
||
engine = self.get(engine_id)
|
||
if engine:
|
||
return engine.get_info()
|
||
return None
|
||
|
||
def clear(self):
|
||
"""清空注册表(主要用于测试)"""
|
||
self._engines.clear()
|
||
self._instances.clear()
|
||
logger.info("引擎注册表已清空")
|