#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ 统一配置管理器 """ import json import os import logging from pathlib import Path from typing import Dict, Type, TypeVar, Optional, Any, cast, List, Set from core.config.models import ( BaseConfig, AIModelConfig, SystemConfig, GenerateTopicConfig, ResourceConfig, GenerateContentConfig, PosterConfig, ContentConfig ) logger = logging.getLogger(__name__) T = TypeVar('T', bound=BaseConfig) class ConfigManager: """ 统一配置管理器 负责加载、管理和访问所有配置 """ # 服务端必要的全局配置 SERVER_CONFIGS = {'system', 'ai_model', 'database'} # 单次生成任务的配置 TASK_CONFIGS = {'topic_gen', 'content_gen', 'poster_gen', 'resource'} def __init__(self): self._configs: Dict[str, BaseConfig] = {} self._raw_configs: Dict[str, Dict[str, Any]] = {} # 存储原始配置数据 self.config_dir: Optional[Path] = None self.config_objects = { 'ai_model': AIModelConfig(), 'system': SystemConfig(), 'resource': ResourceConfig() } self._loaded_configs: Set[str] = set() def load_from_directory(self, config_dir: str, server_mode: bool = False): """ 从目录加载配置 Args: config_dir: 配置文件目录 server_mode: 是否为服务器模式,如果是则只加载必要的全局配置 """ self.config_dir = Path(config_dir) if not self.config_dir.is_dir(): logger.error(f"配置目录不存在: {config_dir}") raise FileNotFoundError(f"配置目录不存在: {config_dir}") # 注册所有已知的配置类型 self._register_configs() # 动态加载目录中的所有.json文件 self._load_all_configs_from_dir(server_mode) def _register_configs(self): """注册所有配置""" self.register_config('ai_model', AIModelConfig) self.register_config('system', SystemConfig) self.register_config('resource', ResourceConfig) # 这些配置在服务器模式下不会自动加载,但仍然需要注册类型 self.register_config('poster', PosterConfig) self.register_config('content', ContentConfig) self.register_config('topic_gen', GenerateTopicConfig) self.register_config('content_gen', GenerateContentConfig) def register_config(self, name: str, config_class: Type[T]) -> None: """ 注册一个配置类 Args: name: 配置名称 config_class: 配置类 (必须是 BaseConfig 的子类) """ if not issubclass(config_class, BaseConfig): raise TypeError("config_class must be a subclass of BaseConfig") if name not in self._configs: self._configs[name] = config_class() def get_config(self, name: str, config_class: Type[T]) -> T: """ 获取配置实例 Args: name: 配置名称 config_class: 配置类 (用于类型提示) Returns: 配置实例 """ config = self._configs.get(name) if config is None: # 如果配置不存在,先注册一个默认实例 self.register_config(name, config_class) config = self._configs.get(name) # 确保配置是正确的类型 if not isinstance(config, config_class): # 尝试转换配置 try: if isinstance(config, BaseConfig): # 将现有配置转换为请求的类型 new_config = config_class(**config.model_dump()) self._configs[name] = new_config config = new_config else: raise TypeError(f"Configuration '{name}' is not of type '{config_class.__name__}'") except Exception as e: logger.error(f"转换配置 '{name}' 到类型 '{config_class.__name__}' 失败: {e}") raise TypeError(f"Configuration '{name}' is not of type '{config_class.__name__}'") from e return cast(T, config) def get_raw_config(self, name: str) -> Dict[str, Any]: """ 获取原始配置数据 Args: name: 配置名称 Returns: 原始配置数据字典 """ if name in self._raw_configs: return self._raw_configs[name] # 如果没有原始配置,但有对象配置,则转换为字典 if name in self._configs: return self._configs[name].to_dict() # 尝试从文件加载 if self.config_dir: config_path = self.config_dir / f"{name}.json" if config_path.exists(): try: with open(config_path, 'r', encoding='utf-8') as f: raw_config = json.load(f) self._raw_configs[name] = raw_config return raw_config except Exception as e: logger.error(f"加载原始配置 '{name}' 失败: {e}") # 返回空字典 return {} def _load_all_configs_from_dir(self, server_mode: bool = False): """ 动态加载目录中的所有.json文件 Args: server_mode: 是否为服务器模式,如果是则只加载必要的全局配置 """ try: # 遍历并加载目录中所有其他的 .json 文件 for config_path in self.config_dir.glob('*.json'): config_name = config_path.stem # 'topic_gen.json' -> 'topic_gen' # 服务器模式下,只加载必要的全局配置 if server_mode and config_name not in self.SERVER_CONFIGS: logger.info(f"服务器模式下跳过非全局配置: {config_name}") continue # 加载原始配置 with open(config_path, 'r', encoding='utf-8') as f: config_data = json.load(f) self._raw_configs[config_name] = config_data # 更新对象配置 if config_name in self._configs: logger.info(f"加载配置文件 '{config_name}': {config_path}") self._configs[config_name].update(config_data) self._loaded_configs.add(config_name) else: logger.info(f"加载原始配置 '{config_name}': {config_path}") # 最后应用环境变量覆盖 self._apply_env_overrides() except Exception as e: logger.error(f"从目录 '{self.config_dir}' 加载配置失败: {e}", exc_info=True) raise def load_task_config(self, config_name: str) -> bool: """ 按需加载任务配置 Args: config_name: 配置名称 Returns: 是否成功加载 """ if config_name in self._loaded_configs: return True if self.config_dir: config_path = self.config_dir / f"{config_name}.json" if config_path.exists(): try: with open(config_path, 'r', encoding='utf-8') as f: config_data = json.load(f) self._raw_configs[config_name] = config_data if config_name in self._configs: self._configs[config_name].update(config_data) self._loaded_configs.add(config_name) logger.info(f"按需加载任务配置 '{config_name}': {config_path}") return True except Exception as e: logger.error(f"加载任务配置 '{config_name}' 失败: {e}") logger.warning(f"未找到任务配置: {config_name}") return False def _apply_env_overrides(self): """应用环境变量覆盖""" logger.info("应用环境变量覆盖...") # 示例: AI模型配置环境变量覆盖 ai_model_config = self.get_config('ai_model', AIModelConfig) if not ai_model_config: return # 如果没有AI配置则跳过 env_mapping = { 'AI_MODEL': 'model', 'API_URL': 'api_url', 'API_KEY': 'api_key' } update_data = {} for env_var, config_key in env_mapping.items(): if os.getenv(env_var): update_data[config_key] = os.getenv(env_var) if update_data: ai_model_config.update(update_data) # 更新原始配置 if 'ai_model' in self._raw_configs: for key, value in update_data.items(): self._raw_configs['ai_model'][key] = value logger.info(f"通过环境变量更新了AI模型配置: {list(update_data.keys())}") def save_config(self, name: str): """ 保存指定的配置到文件 Args: name: 要保存的配置名称 """ if not self.config_dir: raise ValueError("配置目录未设置,无法保存文件") path = self.config_dir / f"{name}.json" config = self.get_config(name, BaseConfig) config_data = config.to_dict() # 更新原始配置 self._raw_configs[name] = config_data try: with open(path, 'w', encoding='utf-8') as f: json.dump(config_data, f, indent=4, ensure_ascii=False) logger.info(f"配置 '{name}' 已保存到 {path}") except Exception as e: logger.error(f"保存配置 '{name}' 到 {path} 失败: {e}", exc_info=True) raise # 全局配置管理器实例 config_manager = ConfigManager() def get_config_manager() -> ConfigManager: return config_manager