TravelContentCreator/core/unified_config.py

175 lines
4.9 KiB
Python

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
统一配置加载器
合并分散的配置文件为统一接口
"""
import copy
import json
import logging
import os
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Optional
logger = logging.getLogger(__name__)


class UnifiedConfig:
    """Unified configuration manager.

    Merges several JSON files from one config directory into a single mapping
    of ``group name -> parsed JSON``. It also overlays a fixed set of
    environment variables. Values are read through dotted-path lookups
    (:meth:`get`), per-group access (:meth:`get_group` and the convenience
    properties) or a full export (:meth:`to_dict`).

    The class is a singleton. Every ``UnifiedConfig(...)`` call returns the
    same instance, and only the first call's ``config_dir`` is used.
    """

    # The single shared instance created by ``__new__``.
    _instance: Optional['UnifiedConfig'] = None

    # (group name, file name) pairs read from the config directory.
    _CONFIG_FILES = (
        ("ai_model", "ai_model.json"),
        ("database", "database.json"),
        ("paths", "paths.json"),
        ("resource", "resource.json"),
        ("system", "system.json"),
        ("cookies", "cookies.json"),
        ("engines", "engines.json"),
        ("poster_gen", "poster_gen.json"),
    )

    # Environment variable -> (group, key) overrides applied after file loading.
    _ENV_OVERRIDES = {
        "DATABASE_HOST": ("database", "host"),
        "DATABASE_PORT": ("database", "port"),
        "DATABASE_USER": ("database", "user"),
        "DATABASE_PASSWORD": ("database", "password"),
        "DATABASE_NAME": ("database", "database"),
        "AI_MODEL_API_KEY": ("ai_model", "api_key"),
        "AI_MODEL_BASE_URL": ("ai_model", "base_url"),
    }

    # Environment values are always strings. These overrides are converted to
    # int so consumers receive a numeric port.
    # NOTE(review): assumes database.json stores the port as a JSON number.
    _INT_ENV_KEYS = frozenset({"DATABASE_PORT"})

    def __new__(cls, config_dir: str = "config"):
        # Create the instance only once; later calls reuse it.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            # Read by __init__ so that repeated constructions skip loading.
            cls._instance._initialized = False
        return cls._instance

    def __init__(self, config_dir: str = "config"):
        """Load all configuration on the first construction only.

        Args:
            config_dir: Directory containing the JSON files. Relative paths
                are resolved against the current working directory when the
                files are opened. Only the first construction's value is used.
        """
        if self._initialized:
            # The singleton is already configured; a different directory
            # would be silently ignored, so make that visible.
            if Path(config_dir) != self._config_dir:
                logger.warning(
                    "UnifiedConfig 已使用 %s 初始化,忽略新的配置目录 %s",
                    self._config_dir,
                    config_dir,
                )
            return
        self._config_dir = Path(config_dir)
        self._config: Dict[str, Any] = {}
        self._load_all_configs()
        self._initialized = True

    def _load_all_configs(self):
        """Read every config file, apply environment overrides, then publish.

        The configuration is built in a fresh dict and installed with a single
        assignment. Readers therefore never see an empty or half-loaded
        configuration, for example during :meth:`reload`. A missing file
        yields an empty group silently. An unreadable file or invalid JSON
        yields an empty group and a warning.
        """
        loaded: Dict[str, Any] = {}
        for key, filename in self._CONFIG_FILES:
            loaded[key] = self._read_json(self._config_dir / filename)
        self._apply_env_overrides(loaded)
        self._config = loaded
        logger.info("配置加载完成,共 %d 个配置组", len(loaded))

    @staticmethod
    def _read_json(filepath: Path) -> Any:
        """Return the parsed JSON in *filepath*, or ``{}`` if it cannot be loaded."""
        try:
            with open(filepath, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except FileNotFoundError:
            # Optional file: absence is not an error.
            return {}
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            logger.warning("加载配置失败 %s: %s", filepath.name, e)
            return {}
        logger.debug("已加载配置: %s", filepath.name)
        return data

    def _apply_env_overrides(self, config: Optional[Dict[str, Any]] = None):
        """Overlay selected environment variables onto a configuration mapping.

        Args:
            config: Mapping to modify in place. Defaults to the live
                configuration.

        Empty or unset variables are ignored. Keys in ``_INT_ENV_KEYS`` are
        converted to ``int``; if the conversion fails, the string is kept and
        a warning is logged. A group whose loaded value is not a dict is left
        untouched, with a warning.
        """
        target = self._config if config is None else config
        for env_key, (group_name, config_key) in self._ENV_OVERRIDES.items():
            raw = os.environ.get(env_key)
            if not raw:
                continue
            value: Any = raw
            if env_key in self._INT_ENV_KEYS:
                try:
                    value = int(raw)
                except ValueError:
                    logger.warning("环境变量 %s 不是有效整数,按字符串使用", env_key)
            group = target.setdefault(group_name, {})
            if not isinstance(group, dict):
                logger.warning("配置组 %s 不是对象,忽略环境变量 %s", group_name, env_key)
                continue
            group[config_key] = value
            logger.debug("环境变量覆盖: %s", env_key)

    def get(self, key: str, default: Any = None) -> Any:
        """Return the value at a dotted *key* path, e.g. ``"database.host"``.

        Args:
            key: Group name, optionally followed by ``.``-separated nested keys.
            default: Returned when a path segment is missing, or when a
                non-dict value is reached before the path ends.

        Returns:
            The stored value (possibly itself a dict), or *default*.
        """
        value: Any = self._config
        for part in key.split("."):
            if not isinstance(value, dict) or part not in value:
                return default
            value = value[part]
        return value

    def get_group(self, group: str) -> Dict[str, Any]:
        """Return the configuration group named *group*, or ``{}`` if absent.

        The returned object is the live group, not a copy. Use
        :meth:`to_dict` for a detached snapshot.
        """
        return self._config.get(group, {})

    @property
    def database(self) -> Dict[str, Any]:
        """The ``database`` group."""
        return self.get_group("database")

    @property
    def ai_model(self) -> Dict[str, Any]:
        """The ``ai_model`` group."""
        return self.get_group("ai_model")

    @property
    def paths(self) -> Dict[str, Any]:
        """The ``paths`` group."""
        return self.get_group("paths")

    @property
    def poster_gen(self) -> Dict[str, Any]:
        """The ``poster_gen`` group."""
        return self.get_group("poster_gen")

    @property
    def engines(self) -> Dict[str, Any]:
        """The ``engines`` group."""
        return self.get_group("engines")

    @property
    def system(self) -> Dict[str, Any]:
        """The ``system`` group."""
        return self.get_group("system")

    def reload(self):
        """Re-read all config files and environment overrides.

        The previous configuration stays visible until the new one is built.
        """
        self._load_all_configs()
        logger.info("配置已重新加载")

    def to_dict(self) -> Dict[str, Any]:
        """Return a deep copy of the whole configuration.

        The copy is deep so that mutating a nested group in the result cannot
        alter the live configuration.
        """
        return copy.deepcopy(self._config)
@lru_cache(maxsize=1)
def get_unified_config() -> UnifiedConfig:
    """Return the shared :class:`UnifiedConfig` for the default config directory.

    The accessor takes no arguments, so ``lru_cache(maxsize=1)`` memoises the
    single result. ``get_unified_config.cache_clear()`` drops that cached
    reference.
    """
    shared = UnifiedConfig()
    return shared