TravelContentCreator/domain/aigc/shared/component_factory.py

167 lines
4.9 KiB
Python

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
组件工厂
负责创建和初始化所有共享组件
"""
import logging
import json
from typing import Dict, Any, Optional
from pathlib import Path
from .llm_client import LLMClient
from .image_processor import ImageProcessor
from .database_accessor import DatabaseAccessor
from .file_storage import FileStorage
from domain.prompt import PromptRegistry
logger = logging.getLogger(__name__)
class ComponentFactory:
"""
组件工厂
负责:
- 加载配置
- 创建共享组件
- 注入依赖
"""
def __init__(self, project_root: Optional[str] = None):
"""
初始化组件工厂
Args:
project_root: 项目根目录,不传则自动检测
"""
self._project_root = self._detect_project_root(project_root)
self._paths_config: Dict[str, Any] = {}
self._components: Dict[str, Any] = {}
# 加载路径配置
self._load_paths_config()
def _detect_project_root(self, project_root: Optional[str]) -> Path:
"""检测项目根目录"""
if project_root:
return Path(project_root)
# 尝试从当前文件位置推断
current_file = Path(__file__)
# domain/aigc/shared/component_factory.py -> 项目根目录
return current_file.parent.parent.parent.parent
def _load_paths_config(self):
"""加载路径配置"""
paths_file = self._project_root / "config" / "paths.json"
if paths_file.exists():
try:
with open(paths_file, 'r', encoding='utf-8') as f:
self._paths_config = json.load(f)
logger.info(f"已加载路径配置: {paths_file}")
except Exception as e:
logger.warning(f"加载路径配置失败: {e}")
# 确保 project_root 正确
self._paths_config['project_root'] = str(self._project_root)
def create_components(
self,
ai_agent=None,
db_service=None,
config_manager=None
) -> Dict[str, Any]:
"""
创建所有共享组件
Args:
ai_agent: AI Agent 实例(可选,用于 LLM 调用)
db_service: DatabaseService 实例(可选,用于数据库访问)
config_manager: ConfigManager 实例(可选,用于配置访问)
Returns:
共享组件字典
"""
project_root = str(self._project_root)
# 创建 LLM 客户端
llm = LLMClient(ai_agent)
# 创建 PromptRegistry
prompts_path = str(self._project_root / "prompts")
prompt = PromptRegistry(prompts_path)
# 创建图片处理器
image = ImageProcessor()
# 创建数据库访问器
db = DatabaseAccessor(db_service)
# 创建文件存储
output_base = self._paths_config.get('output', {}).get('data', 'data')
storage = FileStorage(output_base)
storage.set_project_root(project_root)
# 组装组件
self._components = {
'llm': llm,
'prompt': prompt,
'image': image,
'db': db,
'storage': storage,
'paths': self._paths_config,
'config': self._get_config_dict(config_manager),
'project_root': project_root,
}
logger.info("共享组件创建完成")
return self._components
def _get_config_dict(self, config_manager) -> Dict[str, Any]:
"""从 ConfigManager 提取配置字典"""
if not config_manager:
return {}
try:
# 尝试获取常用配置
config = {}
# AI 模型配置
if hasattr(config_manager, 'get_config'):
from core.config import AIModelConfig
try:
ai_config = config_manager.get_config('ai_model', AIModelConfig)
config['ai_model'] = {
'model': getattr(ai_config, 'model', ''),
'temperature': getattr(ai_config, 'temperature', 0.7),
}
except:
pass
return config
except Exception as e:
logger.warning(f"获取配置失败: {e}")
return {}
def get_components(self) -> Dict[str, Any]:
"""获取已创建的组件"""
return self._components
def update_component(self, name: str, component: Any):
"""更新单个组件"""
self._components[name] = component
@property
def project_root(self) -> Path:
"""获取项目根目录"""
return self._project_root
@property
def paths(self) -> Dict[str, Any]:
"""获取路径配置"""
return self._paths_config