TravelContentCreator/domain/aigc/engine_registry.py

216 lines
6.4 KiB
Python
Raw Normal View History

2025-12-08 14:58:35 +08:00
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
引擎注册表
负责引擎的自动发现注册和管理
"""
import logging
import importlib
import pkgutil
from typing import Dict, Type, Optional, List, Any
from pathlib import Path
from .engines.base import BaseAIGCEngine
logger = logging.getLogger(__name__)
class EngineRegistry:
"""
引擎注册表
支持
- 自动发现 engines 目录下的所有引擎
- 手动注册引擎
- ID 获取引擎实例
- 列出所有可用引擎
"""
_instance: Optional['EngineRegistry'] = None
def __new__(cls):
"""单例模式"""
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._initialized = False
return cls._instance
def __init__(self):
if self._initialized:
return
self._engines: Dict[str, Type[BaseAIGCEngine]] = {}
self._instances: Dict[str, BaseAIGCEngine] = {}
self._shared_components: Optional[Dict[str, Any]] = None
self._initialized = True
logger.info("引擎注册表初始化")
def set_shared_components(self, components: Dict[str, Any]):
"""
设置共享组件
Args:
components: 共享组件字典 (db, llm, prompt, image, storage, paths, config)
"""
self._shared_components = components
# 为已实例化的引擎注入组件
for engine in self._instances.values():
engine.initialize(components)
logger.info("共享组件已设置")
def auto_discover(self, package_path: str = "domain.aigc.engines"):
"""
自动发现并注册引擎
Args:
package_path: 引擎包路径
"""
logger.info(f"开始自动发现引擎: {package_path}")
try:
package = importlib.import_module(package_path)
package_dir = Path(package.__file__).parent
# 遍历包中的所有模块
for _, module_name, is_pkg in pkgutil.iter_modules([str(package_dir)]):
# 跳过 base 模块和私有模块
if module_name.startswith('_') or module_name == 'base':
continue
# 跳过子包(如 _future
if is_pkg:
continue
try:
module = importlib.import_module(f"{package_path}.{module_name}")
# 查找模块中的引擎类
for attr_name in dir(module):
attr = getattr(module, attr_name)
# 检查是否是 BaseAIGCEngine 的子类
if (isinstance(attr, type) and
issubclass(attr, BaseAIGCEngine) and
attr is not BaseAIGCEngine and
hasattr(attr, 'engine_id') and
attr.engine_id):
self.register(attr)
except Exception as e:
logger.warning(f"加载模块 {module_name} 失败: {e}")
except Exception as e:
logger.error(f"自动发现引擎失败: {e}")
logger.info(f"自动发现完成,共注册 {len(self._engines)} 个引擎")
def register(self, engine_class: Type[BaseAIGCEngine]):
"""
注册引擎
Args:
engine_class: 引擎类
"""
engine_id = engine_class.engine_id
if not engine_id:
logger.warning(f"引擎类 {engine_class.__name__} 未定义 engine_id跳过注册")
return
if engine_id in self._engines:
logger.warning(f"引擎 {engine_id} 已存在,将被覆盖")
self._engines[engine_id] = engine_class
logger.info(f"注册引擎: {engine_id} ({engine_class.engine_name})")
def get(self, engine_id: str) -> Optional[BaseAIGCEngine]:
"""
获取引擎实例
Args:
engine_id: 引擎 ID
Returns:
引擎实例不存在返回 None
"""
# 检查是否已实例化
if engine_id in self._instances:
return self._instances[engine_id]
# 检查是否已注册
if engine_id not in self._engines:
logger.warning(f"引擎 {engine_id} 未注册")
return None
# 创建实例
try:
engine_class = self._engines[engine_id]
engine = engine_class()
# 注入共享组件
if self._shared_components:
engine.initialize(self._shared_components)
self._instances[engine_id] = engine
return engine
except Exception as e:
logger.error(f"创建引擎 {engine_id} 实例失败: {e}")
return None
def has(self, engine_id: str) -> bool:
"""
检查引擎是否存在
Args:
engine_id: 引擎 ID
Returns:
是否存在
"""
return engine_id in self._engines
def list_engines(self) -> List[Dict[str, Any]]:
"""
列出所有可用引擎
Returns:
引擎信息列表
"""
engines = []
for engine_id, engine_class in self._engines.items():
engines.append({
"id": engine_id,
"name": engine_class.engine_name,
"version": engine_class.version,
"description": engine_class.description,
})
return engines
def get_engine_info(self, engine_id: str) -> Optional[Dict[str, Any]]:
"""
获取引擎详细信息
Args:
engine_id: 引擎 ID
Returns:
引擎信息包含参数 Schema
"""
engine = self.get(engine_id)
if engine:
return engine.get_info()
return None
def clear(self):
"""清空注册表(主要用于测试)"""
self._engines.clear()
self._instances.clear()
logger.info("引擎注册表已清空")