176 lines
5.6 KiB
Python
Raw Normal View History

2025-07-08 17:45:40 +08:00
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
统一配置管理器
"""
import json
import os
import logging
from pathlib import Path
from typing import Dict, Type, TypeVar, Optional
from core.config.models import BaseConfig, AIModelConfig, PosterConfig, ContentConfig, ResourceConfig, SystemConfig, GenerateTopicConfig
logger = logging.getLogger(__name__)
T = TypeVar('T', bound=BaseConfig)
class ConfigManager:
"""
统一配置管理器
负责加载管理和访问所有配置
"""
def __init__(self):
self._configs: Dict[str, BaseConfig] = {}
self._config_files: Dict[str, Path] = {}
self.config_dir: Optional[Path] = None
def load_from_directory(self, config_dir: str):
"""
从目录加载配置
Args:
config_dir: 配置文件目录
"""
self.config_dir = Path(config_dir)
# 定义配置文件
self._config_files = {
'main': self.config_dir / 'poster_gen_config.json',
'topic_gen': self.config_dir / 'topic_gen_config.json',
}
# 注册配置
self._register_configs()
# 加载所有配置
self._load_all_configs()
def _register_configs(self):
"""注册所有配置"""
self.register_config('ai_model', AIModelConfig)
self.register_config('poster', PosterConfig)
self.register_config('content', ContentConfig)
self.register_config('resource', ResourceConfig)
self.register_config('system', SystemConfig)
self.register_config('topic_gen', GenerateTopicConfig)
def register_config(self, name: str, config_class: Type[T]):
"""
注册一个配置类
Args:
name: 配置名称
config_class: 配置类 (必须是 BaseConfig 的子类)
"""
if not issubclass(config_class, BaseConfig):
raise TypeError("config_class must be a subclass of BaseConfig")
self._configs[name] = config_class()
def get_config(self, name: str, config_class: Type[T]) -> T:
"""
获取配置实例
Args:
name: 配置名称
config_class: 配置类 (用于类型提示)
Returns:
配置实例
"""
config = self._configs.get(name)
if not isinstance(config, config_class):
raise TypeError(f"Configuration '{name}' is not of type '{config_class.__name__}'")
return config
def _load_all_configs(self):
"""加载所有配置文件"""
try:
# 加载主配置文件
if self._config_files['main'].exists():
self._load_main_config(self._config_files['main'])
else:
logger.warning(f"主配置文件不存在: {self._config_files['main']}")
# 加载其他配置文件
for name, path in self._config_files.items():
if name != 'main' and path.exists():
self._load_specific_config(name, path)
elif name != 'main':
logger.warning(f"配置文件不存在: {path}")
# 应用环境变量覆盖
self._apply_env_overrides()
except Exception as e:
logger.error(f"配置加载失败: {e}", exc_info=True)
raise
def _load_main_config(self, path: Path):
"""加载主配置文件,并分发到各个配置对象"""
logger.info(f"加载主配置文件: {path}")
with open(path, 'r', encoding='utf-8') as f:
config_data = json.load(f)
for name, config_obj in self._configs.items():
if name in config_data:
config_obj.update(config_data[name])
def _load_specific_config(self, name: str, path: Path):
"""加载特定的配置文件"""
if name in self._configs:
logger.info(f"加载配置文件 '{name}': {path}")
with open(path, 'r', encoding='utf-8') as f:
config_data = json.load(f)
self._configs[name].update(config_data)
def _apply_env_overrides(self):
"""应用环境变量覆盖"""
logger.info("应用环境变量覆盖...")
# 示例: AI模型配置环境变量覆盖
ai_model_config = self.get_config('ai_model', AIModelConfig)
env_mapping = {
'AI_MODEL': 'model',
'API_URL': 'api_url',
'API_KEY': 'api_key'
}
update_data = {}
for env_var, config_key in env_mapping.items():
if os.getenv(env_var):
update_data[config_key] = os.getenv(env_var)
if update_data:
ai_model_config.update(update_data)
logger.info(f"通过环境变量更新了AI模型配置: {update_data}")
def save_config(self, name: str):
"""
保存指定的配置到文件
Args:
name: 要保存的配置名称
"""
if name not in self._config_files:
raise ValueError(f"没有为 '{name}' 定义配置文件路径")
path = self._config_files[name]
config_data = self.get_config(name, BaseConfig).to_dict()
try:
with open(path, 'w', encoding='utf-8') as f:
json.dump(config_data, f, indent=4, ensure_ascii=False)
logger.info(f"配置 '{name}' 已保存到 {path}")
except Exception as e:
logger.error(f"保存配置 '{name}'{path} 失败: {e}", exc_info=True)
raise
# 全局配置管理器实例
config_manager = ConfigManager()
def get_config_manager() -> ConfigManager:
return config_manager