167 lines
4.9 KiB
Python
167 lines
4.9 KiB
Python
#!/usr/bin/env python3
|
|
# -*- coding: utf-8 -*-
|
|
|
|
"""
|
|
组件工厂
|
|
负责创建和初始化所有共享组件
|
|
"""
|
|
|
|
import logging
|
|
import json
|
|
from typing import Dict, Any, Optional
|
|
from pathlib import Path
|
|
|
|
from .llm_client import LLMClient
|
|
from .image_processor import ImageProcessor
|
|
from .database_accessor import DatabaseAccessor
|
|
from .file_storage import FileStorage
|
|
from domain.prompt import PromptRegistry
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class ComponentFactory:
|
|
"""
|
|
组件工厂
|
|
|
|
负责:
|
|
- 加载配置
|
|
- 创建共享组件
|
|
- 注入依赖
|
|
"""
|
|
|
|
def __init__(self, project_root: Optional[str] = None):
|
|
"""
|
|
初始化组件工厂
|
|
|
|
Args:
|
|
project_root: 项目根目录,不传则自动检测
|
|
"""
|
|
self._project_root = self._detect_project_root(project_root)
|
|
self._paths_config: Dict[str, Any] = {}
|
|
self._components: Dict[str, Any] = {}
|
|
|
|
# 加载路径配置
|
|
self._load_paths_config()
|
|
|
|
def _detect_project_root(self, project_root: Optional[str]) -> Path:
|
|
"""检测项目根目录"""
|
|
if project_root:
|
|
return Path(project_root)
|
|
|
|
# 尝试从当前文件位置推断
|
|
current_file = Path(__file__)
|
|
# domain/aigc/shared/component_factory.py -> 项目根目录
|
|
return current_file.parent.parent.parent.parent
|
|
|
|
def _load_paths_config(self):
|
|
"""加载路径配置"""
|
|
paths_file = self._project_root / "config" / "paths.json"
|
|
|
|
if paths_file.exists():
|
|
try:
|
|
with open(paths_file, 'r', encoding='utf-8') as f:
|
|
self._paths_config = json.load(f)
|
|
logger.info(f"已加载路径配置: {paths_file}")
|
|
except Exception as e:
|
|
logger.warning(f"加载路径配置失败: {e}")
|
|
|
|
# 确保 project_root 正确
|
|
self._paths_config['project_root'] = str(self._project_root)
|
|
|
|
def create_components(
|
|
self,
|
|
ai_agent=None,
|
|
db_service=None,
|
|
config_manager=None
|
|
) -> Dict[str, Any]:
|
|
"""
|
|
创建所有共享组件
|
|
|
|
Args:
|
|
ai_agent: AI Agent 实例(可选,用于 LLM 调用)
|
|
db_service: DatabaseService 实例(可选,用于数据库访问)
|
|
config_manager: ConfigManager 实例(可选,用于配置访问)
|
|
|
|
Returns:
|
|
共享组件字典
|
|
"""
|
|
project_root = str(self._project_root)
|
|
|
|
# 创建 LLM 客户端
|
|
llm = LLMClient(ai_agent)
|
|
|
|
# 创建 PromptRegistry
|
|
prompts_path = str(self._project_root / "prompts")
|
|
prompt = PromptRegistry(prompts_path)
|
|
|
|
# 创建图片处理器
|
|
image = ImageProcessor()
|
|
|
|
# 创建数据库访问器
|
|
db = DatabaseAccessor(db_service)
|
|
|
|
# 创建文件存储
|
|
output_base = self._paths_config.get('output', {}).get('data', 'data')
|
|
storage = FileStorage(output_base)
|
|
storage.set_project_root(project_root)
|
|
|
|
# 组装组件
|
|
self._components = {
|
|
'llm': llm,
|
|
'prompt': prompt,
|
|
'image': image,
|
|
'db': db,
|
|
'storage': storage,
|
|
'paths': self._paths_config,
|
|
'config': self._get_config_dict(config_manager),
|
|
'project_root': project_root,
|
|
}
|
|
|
|
logger.info("共享组件创建完成")
|
|
return self._components
|
|
|
|
def _get_config_dict(self, config_manager) -> Dict[str, Any]:
|
|
"""从 ConfigManager 提取配置字典"""
|
|
if not config_manager:
|
|
return {}
|
|
|
|
try:
|
|
# 尝试获取常用配置
|
|
config = {}
|
|
|
|
# AI 模型配置
|
|
if hasattr(config_manager, 'get_config'):
|
|
from core.config import AIModelConfig
|
|
try:
|
|
ai_config = config_manager.get_config('ai_model', AIModelConfig)
|
|
config['ai_model'] = {
|
|
'model': getattr(ai_config, 'model', ''),
|
|
'temperature': getattr(ai_config, 'temperature', 0.7),
|
|
}
|
|
except:
|
|
pass
|
|
|
|
return config
|
|
except Exception as e:
|
|
logger.warning(f"获取配置失败: {e}")
|
|
return {}
|
|
|
|
def get_components(self) -> Dict[str, Any]:
|
|
"""获取已创建的组件"""
|
|
return self._components
|
|
|
|
def update_component(self, name: str, component: Any):
|
|
"""更新单个组件"""
|
|
self._components[name] = component
|
|
|
|
@property
|
|
def project_root(self) -> Path:
|
|
"""获取项目根目录"""
|
|
return self._project_root
|
|
|
|
@property
|
|
def paths(self) -> Dict[str, Any]:
|
|
"""获取路径配置"""
|
|
return self._paths_config
|