bangbang-aigc-server/core/config/manager.py.backup
2025-07-31 15:35:23 +08:00

281 lines
10 KiB
Python

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
统一配置管理器
"""
import json
import os
import logging
from pathlib import Path
from typing import Dict, Type, TypeVar, Optional, Any, cast, List, Set
from core.config.models import (
BaseConfig, AIModelConfig, SystemConfig, GenerateTopicConfig, ResourceConfig,
GenerateContentConfig, PosterConfig, ContentConfig
)
logger = logging.getLogger(__name__)
T = TypeVar('T', bound=BaseConfig)
class ConfigManager:
"""
统一配置管理器
负责加载、管理和访问所有配置
"""
# 服务端必要的全局配置
SERVER_CONFIGS = {'system', 'ai_model', 'database'}
# 单次生成任务的配置
TASK_CONFIGS = {'topic_gen', 'content_gen', 'poster_gen', 'resource'}
def __init__(self):
self._configs: Dict[str, BaseConfig] = {}
self._raw_configs: Dict[str, Dict[str, Any]] = {} # 存储原始配置数据
self.config_dir: Optional[Path] = None
self.config_objects = {
'ai_model': AIModelConfig(),
'system': SystemConfig(),
'resource': ResourceConfig()
}
self._loaded_configs: Set[str] = set()
def load_from_directory(self, config_dir: str, server_mode: bool = False):
"""
从目录加载配置
Args:
config_dir: 配置文件目录
server_mode: 是否为服务器模式,如果是则只加载必要的全局配置
"""
self.config_dir = Path(config_dir)
if not self.config_dir.is_dir():
logger.error(f"配置目录不存在: {config_dir}")
raise FileNotFoundError(f"配置目录不存在: {config_dir}")
# 注册所有已知的配置类型
self._register_configs()
# 动态加载目录中的所有.json文件
self._load_all_configs_from_dir(server_mode)
def _register_configs(self):
"""注册所有配置"""
self.register_config('ai_model', AIModelConfig)
self.register_config('system', SystemConfig)
self.register_config('resource', ResourceConfig)
# 这些配置在服务器模式下不会自动加载,但仍然需要注册类型
self.register_config('poster', PosterConfig)
self.register_config('content', ContentConfig)
self.register_config('topic_gen', GenerateTopicConfig)
self.register_config('content_gen', GenerateContentConfig)
def register_config(self, name: str, config_class: Type[T]) -> None:
"""
注册一个配置类
Args:
name: 配置名称
config_class: 配置类 (必须是 BaseConfig 的子类)
"""
if not issubclass(config_class, BaseConfig):
raise TypeError("config_class must be a subclass of BaseConfig")
if name not in self._configs:
self._configs[name] = config_class()
def get_config(self, name: str, config_class: Type[T]) -> T:
"""
获取配置实例
Args:
name: 配置名称
config_class: 配置类 (用于类型提示)
Returns:
配置实例
"""
config = self._configs.get(name)
if config is None:
# 如果配置不存在,先注册一个默认实例
self.register_config(name, config_class)
config = self._configs.get(name)
# 确保配置是正确的类型
if not isinstance(config, config_class):
# 尝试转换配置
try:
if isinstance(config, BaseConfig):
# 将现有配置转换为请求的类型
new_config = config_class(**config.model_dump())
self._configs[name] = new_config
config = new_config
else:
raise TypeError(f"Configuration '{name}' is not of type '{config_class.__name__}'")
except Exception as e:
logger.error(f"转换配置 '{name}' 到类型 '{config_class.__name__}' 失败: {e}")
raise TypeError(f"Configuration '{name}' is not of type '{config_class.__name__}'") from e
return cast(T, config)
def get_raw_config(self, name: str) -> Dict[str, Any]:
"""
获取原始配置数据
Args:
name: 配置名称
Returns:
原始配置数据字典
"""
if name in self._raw_configs:
return self._raw_configs[name]
# 如果没有原始配置,但有对象配置,则转换为字典
if name in self._configs:
return self._configs[name].to_dict()
# 尝试从文件加载
if self.config_dir:
config_path = self.config_dir / f"{name}.json"
if config_path.exists():
try:
with open(config_path, 'r', encoding='utf-8') as f:
raw_config = json.load(f)
self._raw_configs[name] = raw_config
return raw_config
except Exception as e:
logger.error(f"加载原始配置 '{name}' 失败: {e}")
# 返回空字典
return {}
def _load_all_configs_from_dir(self, server_mode: bool = False):
"""
动态加载目录中的所有.json文件
Args:
server_mode: 是否为服务器模式,如果是则只加载必要的全局配置
"""
try:
# 遍历并加载目录中所有其他的 .json 文件
for config_path in self.config_dir.glob('*.json'):
config_name = config_path.stem # 'topic_gen.json' -> 'topic_gen'
# 服务器模式下,只加载必要的全局配置
if server_mode and config_name not in self.SERVER_CONFIGS:
logger.info(f"服务器模式下跳过非全局配置: {config_name}")
continue
# 加载原始配置
with open(config_path, 'r', encoding='utf-8') as f:
config_data = json.load(f)
self._raw_configs[config_name] = config_data
# 更新对象配置
if config_name in self._configs:
logger.info(f"加载配置文件 '{config_name}': {config_path}")
self._configs[config_name].update(config_data)
self._loaded_configs.add(config_name)
else:
logger.info(f"加载原始配置 '{config_name}': {config_path}")
# 最后应用环境变量覆盖
self._apply_env_overrides()
except Exception as e:
logger.error(f"从目录 '{self.config_dir}' 加载配置失败: {e}", exc_info=True)
raise
def load_task_config(self, config_name: str) -> bool:
"""
按需加载任务配置
Args:
config_name: 配置名称
Returns:
是否成功加载
"""
if config_name in self._loaded_configs:
return True
if self.config_dir:
config_path = self.config_dir / f"{config_name}.json"
if config_path.exists():
try:
with open(config_path, 'r', encoding='utf-8') as f:
config_data = json.load(f)
self._raw_configs[config_name] = config_data
if config_name in self._configs:
self._configs[config_name].update(config_data)
self._loaded_configs.add(config_name)
logger.info(f"按需加载任务配置 '{config_name}': {config_path}")
return True
except Exception as e:
logger.error(f"加载任务配置 '{config_name}' 失败: {e}")
logger.warning(f"未找到任务配置: {config_name}")
return False
def _apply_env_overrides(self):
"""应用环境变量覆盖"""
logger.info("应用环境变量覆盖...")
# 示例: AI模型配置环境变量覆盖
ai_model_config = self.get_config('ai_model', AIModelConfig)
if not ai_model_config: return # 如果没有AI配置则跳过
env_mapping = {
'AI_MODEL': 'model',
'API_URL': 'api_url',
'API_KEY': 'api_key'
}
update_data = {}
for env_var, config_key in env_mapping.items():
if os.getenv(env_var):
update_data[config_key] = os.getenv(env_var)
if update_data:
ai_model_config.update(update_data)
# 更新原始配置
if 'ai_model' in self._raw_configs:
for key, value in update_data.items():
self._raw_configs['ai_model'][key] = value
logger.info(f"通过环境变量更新了AI模型配置: {list(update_data.keys())}")
def save_config(self, name: str):
"""
保存指定的配置到文件
Args:
name: 要保存的配置名称
"""
if not self.config_dir:
raise ValueError("配置目录未设置,无法保存文件")
path = self.config_dir / f"{name}.json"
config = self.get_config(name, BaseConfig)
config_data = config.to_dict()
# 更新原始配置
self._raw_configs[name] = config_data
try:
with open(path, 'w', encoding='utf-8') as f:
json.dump(config_data, f, indent=4, ensure_ascii=False)
logger.info(f"配置 '{name}' 已保存到 {path}")
except Exception as e:
logger.error(f"保存配置 '{name}'{path} 失败: {e}", exc_info=True)
raise
# 全局配置管理器实例
config_manager = ConfigManager()
def get_config_manager() -> ConfigManager:
return config_manager