171 lines
5.0 KiB
Python
171 lines
5.0 KiB
Python
#!/usr/bin/env python3
|
|
# -*- coding: utf-8 -*-
|
|
|
|
"""
|
|
依赖注入容器
|
|
统一管理所有服务和组件的生命周期
|
|
"""
|
|
|
|
import logging
|
|
from typing import Dict, Any, Type, TypeVar, Optional, Callable
|
|
from functools import lru_cache
|
|
|
|
logger = logging.getLogger(__name__)

T = TypeVar('T')


class Container:
    """
    Dependency-injection container backed by class-level registries.

    Features:
    1. Singleton management - each registered service is created at most once
       and cached until ``reset()`` or until its factory is re-registered.
    2. Lazy loading - a service's factory only runs on the first ``get()``.
    3. Dependency resolution - factories receive the container class and may
       call ``get()`` for their own dependencies; a circular dependency raises
       a ``RuntimeError`` naming the cycle instead of recursing without bound.

    Usage:
    ```python
    # Register services (every factory receives the Container class)
    Container.register(ConfigManager, lambda c: ConfigManager())
    Container.register(AIAgent, lambda c: AIAgent(c.get(ConfigManager)))

    # Resolve services
    config = Container.get(ConfigManager)
    agent = Container.get(AIAgent)
    ```
    """

    # Cached singleton instances, keyed by service type.
    _instances: Dict[Type, Any] = {}
    # Registered factories; each one is called with the Container class.
    _factories: Dict[Type, Callable] = {}
    # Service types whose factories are currently running, outermost first
    # (used to detect and report circular dependencies).
    _resolution_stack: list = []
    # Set once initialize() has registered the default services.
    _initialized: bool = False

    @classmethod
    def register(cls, service_type: Type[T], factory: Callable[[Type['Container']], T]) -> None:
        """
        Register (or replace) the factory for a service type.

        Any instance already cached for ``service_type`` is discarded, so the
        next ``get()`` builds the service with the new factory instead of
        returning a stale singleton produced by the previous one.

        Args:
            service_type: Service type, used as the lookup key.
            factory: Callable invoked with the Container class that returns
                the service instance.
        """
        cls._factories[service_type] = factory
        # Drop the old singleton; otherwise re-registration would be ignored.
        cls._instances.pop(service_type, None)
        logger.debug("注册服务: %s", service_type.__name__)

    @classmethod
    def get(cls, service_type: Type[T]) -> T:
        """
        Return the singleton instance of ``service_type``, creating it on first use.

        Args:
            service_type: Service type to resolve.

        Returns:
            The cached instance, or a new instance built by the registered factory.

        Raises:
            ValueError: If no instance is cached and no factory is registered.
            RuntimeError: If building the service re-enters its own resolution
                (a circular dependency between factories).
        """
        if service_type in cls._instances:
            return cls._instances[service_type]
        if service_type not in cls._factories:
            raise ValueError(f"服务未注册: {service_type.__name__}")
        return cls._create(service_type, cls._factories[service_type])

    @classmethod
    def _create(cls, service_type: Type[T], factory: Callable[[Type['Container']], T]) -> T:
        """Run ``factory`` for ``service_type`` with cycle detection, then cache the result."""
        if service_type in cls._resolution_stack:
            # Report the cycle from the first occurrence back to this type.
            start = cls._resolution_stack.index(service_type)
            chain = cls._resolution_stack[start:] + [service_type]
            raise RuntimeError(f"检测到循环依赖: {' -> '.join(t.__name__ for t in chain)}")

        cls._resolution_stack.append(service_type)
        try:
            instance = factory(cls)
        finally:
            # Always unwind, so a failed factory doesn't poison later lookups.
            cls._resolution_stack.pop()

        cls._instances[service_type] = instance
        logger.debug("创建服务实例: %s", service_type.__name__)
        return instance

    @classmethod
    def has(cls, service_type: Type) -> bool:
        """
        Return True if a factory is registered for ``service_type``.

        Instances stored only through ``get_or_create``'s fallback factory are
        not counted, because no factory is registered for them.
        """
        return service_type in cls._factories

    @classmethod
    def reset(cls) -> None:
        """Clear all factories, cached instances and resolution state (mainly for tests)."""
        cls._instances.clear()
        cls._factories.clear()
        cls._resolution_stack.clear()
        cls._initialized = False
        logger.info("容器已重置")

    @classmethod
    def _register_default(cls, service_type: Type[T], factory: Callable[[Type['Container']], T]) -> None:
        """Register a default factory unless the caller already registered one for this type."""
        if service_type in cls._factories:
            # Keep explicit registrations (e.g. test doubles) made before initialize().
            return
        cls._factories[service_type] = factory
        logger.debug("注册服务: %s", service_type.__name__)

    @classmethod
    def initialize(cls) -> None:
        """
        Register the default services once; later calls return immediately.

        Defaults never replace a factory that was registered before this
        method ran. Project modules are imported here, on demand, rather than
        at module import time.
        """
        if cls._initialized:
            return

        logger.info("初始化依赖注入容器...")

        # Configuration manager
        from core.config import ConfigManager
        cls._register_default(ConfigManager, lambda c: ConfigManager())

        # AI agent, built from the 'ai_model' section of the configuration
        from core.ai import AIAgent
        from core.config import AIModelConfig

        def create_ai_agent(c: 'Type[Container]') -> AIAgent:
            config_manager = c.get(ConfigManager)
            ai_config = config_manager.get_config('ai_model', AIModelConfig)
            return AIAgent(ai_config)

        cls._register_default(AIAgent, create_ai_agent)

        # Prompt registry
        from domain.prompt import PromptRegistry
        cls._register_default(PromptRegistry, lambda c: PromptRegistry('prompts'))

        # Output manager; the run id is generated when the factory first runs
        from utils.file_io import OutputManager
        import uuid
        from datetime import datetime

        def create_output_manager(c: 'Type[Container]') -> OutputManager:
            run_id = f"api_request-{datetime.now().strftime('%Y%m%d-%H%M%S')}-{str(uuid.uuid4())[:8]}"
            return OutputManager('result', run_id)

        cls._register_default(OutputManager, create_output_manager)

        cls._initialized = True
        logger.info("依赖注入容器初始化完成")

    @classmethod
    def get_or_create(cls, service_type: Type[T], factory: Callable[[], T]) -> T:
        """
        Resolve a service, falling back to ``factory`` when none is registered.

        Args:
            service_type: Service type to resolve.
            factory: Fallback called with no arguments when ``service_type``
                has neither a cached instance nor a registered factory.

        Returns:
            The cached instance, an instance from the registered factory (via
            ``get``), or the fallback's result; the result is cached as the
            singleton in every case.
        """
        if service_type in cls._instances or service_type in cls._factories:
            return cls.get(service_type)

        instance = factory()
        cls._instances[service_type] = instance
        logger.debug("创建服务实例: %s", service_type.__name__)
        return instance
|
|
|
|
|
|
# 便捷函数
|
|
def get_service(service_type: Type[T]) -> T:
    """
    Convenience accessor for a container-managed service.

    Calls ``Container.initialize()`` so the default services are registered,
    then resolves ``service_type`` through ``Container.get``.
    """
    container = Container
    container.initialize()
    service = container.get(service_type)
    return service
|
|
|
|
|
|
def inject(*service_types: Type):
    """
    Decorator that passes container services as leading positional arguments.

    On every call the wrapper runs ``Container.initialize()``, resolves each
    type in ``service_types`` (in order) with ``Container.get``, and calls the
    decorated function with those services first, followed by the caller's
    own positional and keyword arguments.

    Example:
    ```python
    @inject(ConfigManager, AIAgent)
    def my_function(config, agent, other_param):
        ...
    ```
    """
    from functools import wraps

    def decorator(func):
        # Preserve __name__, __doc__ and __wrapped__ of the decorated function.
        @wraps(func)
        def wrapper(*args, **kwargs):
            Container.initialize()
            # Resolve at call time so later registrations/resets are honoured.
            injected = [Container.get(t) for t in service_types]
            return func(*injected, *args, **kwargs)

        return wrapper

    return decorator
|