TravelContentCreator/domain/aigc/engine_registry.py

216 lines
6.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
引擎注册表
负责引擎的自动发现、注册和管理
"""
import logging
import importlib
import pkgutil
from typing import Dict, Type, Optional, List, Any
from pathlib import Path
from .engines.base import BaseAIGCEngine
logger = logging.getLogger(__name__)
class EngineRegistry:
"""
引擎注册表
支持:
- 自动发现 engines 目录下的所有引擎
- 手动注册引擎
- 按 ID 获取引擎实例
- 列出所有可用引擎
"""
_instance: Optional['EngineRegistry'] = None
def __new__(cls):
"""单例模式"""
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._initialized = False
return cls._instance
def __init__(self):
if self._initialized:
return
self._engines: Dict[str, Type[BaseAIGCEngine]] = {}
self._instances: Dict[str, BaseAIGCEngine] = {}
self._shared_components: Optional[Dict[str, Any]] = None
self._initialized = True
logger.info("引擎注册表初始化")
def set_shared_components(self, components: Dict[str, Any]):
"""
设置共享组件
Args:
components: 共享组件字典 (db, llm, prompt, image, storage, paths, config)
"""
self._shared_components = components
# 为已实例化的引擎注入组件
for engine in self._instances.values():
engine.initialize(components)
logger.info("共享组件已设置")
def auto_discover(self, package_path: str = "domain.aigc.engines"):
"""
自动发现并注册引擎
Args:
package_path: 引擎包路径
"""
logger.info(f"开始自动发现引擎: {package_path}")
try:
package = importlib.import_module(package_path)
package_dir = Path(package.__file__).parent
# 遍历包中的所有模块
for _, module_name, is_pkg in pkgutil.iter_modules([str(package_dir)]):
# 跳过 base 模块和私有模块
if module_name.startswith('_') or module_name == 'base':
continue
# 跳过子包(如 _future
if is_pkg:
continue
try:
module = importlib.import_module(f"{package_path}.{module_name}")
# 查找模块中的引擎类
for attr_name in dir(module):
attr = getattr(module, attr_name)
# 检查是否是 BaseAIGCEngine 的子类
if (isinstance(attr, type) and
issubclass(attr, BaseAIGCEngine) and
attr is not BaseAIGCEngine and
hasattr(attr, 'engine_id') and
attr.engine_id):
self.register(attr)
except Exception as e:
logger.warning(f"加载模块 {module_name} 失败: {e}")
except Exception as e:
logger.error(f"自动发现引擎失败: {e}")
logger.info(f"自动发现完成,共注册 {len(self._engines)} 个引擎")
def register(self, engine_class: Type[BaseAIGCEngine]):
"""
注册引擎
Args:
engine_class: 引擎类
"""
engine_id = engine_class.engine_id
if not engine_id:
logger.warning(f"引擎类 {engine_class.__name__} 未定义 engine_id跳过注册")
return
if engine_id in self._engines:
logger.warning(f"引擎 {engine_id} 已存在,将被覆盖")
self._engines[engine_id] = engine_class
logger.info(f"注册引擎: {engine_id} ({engine_class.engine_name})")
def get(self, engine_id: str) -> Optional[BaseAIGCEngine]:
"""
获取引擎实例
Args:
engine_id: 引擎 ID
Returns:
引擎实例,不存在返回 None
"""
# 检查是否已实例化
if engine_id in self._instances:
return self._instances[engine_id]
# 检查是否已注册
if engine_id not in self._engines:
logger.warning(f"引擎 {engine_id} 未注册")
return None
# 创建实例
try:
engine_class = self._engines[engine_id]
engine = engine_class()
# 注入共享组件
if self._shared_components:
engine.initialize(self._shared_components)
self._instances[engine_id] = engine
return engine
except Exception as e:
logger.error(f"创建引擎 {engine_id} 实例失败: {e}")
return None
def has(self, engine_id: str) -> bool:
"""
检查引擎是否存在
Args:
engine_id: 引擎 ID
Returns:
是否存在
"""
return engine_id in self._engines
def list_engines(self) -> List[Dict[str, Any]]:
"""
列出所有可用引擎
Returns:
引擎信息列表
"""
engines = []
for engine_id, engine_class in self._engines.items():
engines.append({
"id": engine_id,
"name": engine_class.engine_name,
"version": engine_class.version,
"description": engine_class.description,
})
return engines
def get_engine_info(self, engine_id: str) -> Optional[Dict[str, Any]]:
"""
获取引擎详细信息
Args:
engine_id: 引擎 ID
Returns:
引擎信息,包含参数 Schema
"""
engine = self.get(engine_id)
if engine:
return engine.get_info()
return None
def clear(self):
"""清空注册表(主要用于测试)"""
self._engines.clear()
self._instances.clear()
logger.info("引擎注册表已清空")