迁移到了pydantic

This commit is contained in:
jinye_huang 2025-07-10 16:31:32 +08:00
parent 157d3348a6
commit c310c5069f
5 changed files with 278 additions and 107 deletions

View File

@ -9,7 +9,7 @@ import json
import os import os
import logging import logging
from pathlib import Path from pathlib import Path
from typing import Dict, Type, TypeVar, Optional from typing import Dict, Type, TypeVar, Optional, Any, cast
from core.config.models import ( from core.config.models import (
BaseConfig, AIModelConfig, SystemConfig, GenerateTopicConfig, ResourceConfig, BaseConfig, AIModelConfig, SystemConfig, GenerateTopicConfig, ResourceConfig,
@ -66,7 +66,7 @@ class ConfigManager:
self.register_config('topic_gen', GenerateTopicConfig) self.register_config('topic_gen', GenerateTopicConfig)
self.register_config('content_gen', GenerateContentConfig) self.register_config('content_gen', GenerateContentConfig)
def register_config(self, name: str, config_class: Type[T]): def register_config(self, name: str, config_class: Type[T]) -> None:
""" """
注册一个配置类 注册一个配置类
@ -91,14 +91,27 @@ class ConfigManager:
配置实例 配置实例
""" """
config = self._configs.get(name) config = self._configs.get(name)
if not isinstance(config, config_class): if config is None:
# 如果配置不存在,先注册一个默认实例 # 如果配置不存在,先注册一个默认实例
if config is None: self.register_config(name, config_class)
self.register_config(name, config_class) config = self._configs.get(name)
config = self._configs.get(name)
else: # 确保配置是正确的类型
raise TypeError(f"Configuration '{name}' is not of type '{config_class.__name__}'") if not isinstance(config, config_class):
return config # 尝试转换配置
try:
if isinstance(config, BaseConfig):
# 将现有配置转换为请求的类型
new_config = config_class(**config.model_dump())
self._configs[name] = new_config
config = new_config
else:
raise TypeError(f"Configuration '{name}' is not of type '{config_class.__name__}'")
except Exception as e:
logger.error(f"转换配置 '{name}' 到类型 '{config_class.__name__}' 失败: {e}")
raise TypeError(f"Configuration '{name}' is not of type '{config_class.__name__}'") from e
return cast(T, config)
def _load_all_configs_from_dir(self): def _load_all_configs_from_dir(self):
"""动态加载目录中的所有.json文件""" """动态加载目录中的所有.json文件"""
@ -176,7 +189,8 @@ class ConfigManager:
raise ValueError("配置目录未设置,无法保存文件") raise ValueError("配置目录未设置,无法保存文件")
path = self.config_dir / f"{name}.json" path = self.config_dir / f"{name}.json"
config_data = self.get_config(name, BaseConfig).to_dict() config = self.get_config(name, BaseConfig)
config_data = config.to_dict()
try: try:
with open(path, 'w', encoding='utf-8') as f: with open(path, 'w', encoding='utf-8') as f:

View File

@ -4,18 +4,23 @@
""" """
配置模型 配置模型
定义了项目中所有配置的数据类模型 定义了项目中所有配置的数据类模型
使用pydantic进行数据验证和序列化
""" """
from dataclasses import dataclass, field, fields, asdict from typing import Dict, Any, List, Optional, Type, TypeVar, Union, ClassVar
from typing import Dict, Any, List, Optional from pydantic import BaseModel, Field, model_validator, ConfigDict
from pydantic import BaseModel, Field
T = TypeVar('T', bound='BaseConfig')
# 基础配置类,提供通用方法 class BaseConfig(BaseModel):
class BaseConfig:
"""可从字典更新的配置基类""" """可从字典更新的配置基类"""
def update(self, new_data: Dict[str, Any]): model_config = ConfigDict(
extra="allow", # 允许额外字段
arbitrary_types_allowed=True # 允许任意类型
)
def update(self, new_data: Dict[str, Any]) -> None:
"""从字典递归更新配置""" """从字典递归更新配置"""
for key, value in new_data.items(): for key, value in new_data.items():
if hasattr(self, key): if hasattr(self, key):
@ -24,20 +29,105 @@ class BaseConfig:
if isinstance(current_attr, BaseConfig) and isinstance(value, dict): if isinstance(current_attr, BaseConfig) and isinstance(value, dict):
# 递归更新嵌套的配置对象 # 递归更新嵌套的配置对象
current_attr.update(value) current_attr.update(value)
# 新增处理ReferItem列表的特殊情况
elif key == 'refer_list' and isinstance(value, list):
# 显式地从字典创建ReferItem对象
setattr(self, key, [ReferItem(**item) for item in value])
else: else:
# 否则,直接设置值 # 否则,直接设置值
setattr(self, key, value) setattr(self, key, value)
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> Dict[str, Any]:
"""将配置转换为字典""" """将配置转换为字典"""
return asdict(self) return self.model_dump()
class ReferItem(BaseConfig):
"""单个Refer项"""
path: str = ""
sampling_rate: float = 1.0
step: str = "" # 可选值: "topic", "content", "judge",表示在哪个阶段使用
class ReferConfig(BaseConfig):
"""Refer配置现在是一个列表"""
refer_list: List[ReferItem] = Field(default_factory=list)
def update(self, new_data: Dict[str, Any]) -> None:
"""特殊处理refer_list的更新"""
if 'refer_list' in new_data and isinstance(new_data['refer_list'], list):
# 将字典列表转换为ReferItem对象列表
new_list = []
for item in new_data['refer_list']:
if isinstance(item, dict):
new_list.append(ReferItem(**item))
elif isinstance(item, ReferItem):
new_list.append(item)
self.refer_list = new_list
# 移除已处理的refer_list避免重复处理
new_data_copy = new_data.copy()
new_data_copy.pop('refer_list')
super().update(new_data_copy)
else:
super().update(new_data)
class PathConfig(BaseConfig):
"""单个路径配置"""
path: str = ""
class PathListConfig(BaseConfig):
"""路径列表配置"""
paths: List[str] = Field(default_factory=list)
class SamplingPathListConfig(BaseConfig):
"""带采样率的路径列表配置"""
sampling_rate: float = 1.0
paths: List[str] = Field(default_factory=list)
class OutputConfig(BaseConfig):
"""输出配置"""
base_dir: str = "result"
image_dir: str = "images"
topic_dir: str = "topics"
content_dir: str = "contents"
class ResourceConfig(BaseConfig):
"""资源配置"""
resource_dirs: List[str] = Field(default_factory=list)
style: PathListConfig = Field(default_factory=PathListConfig)
demand: PathListConfig = Field(default_factory=PathListConfig)
object: PathListConfig = Field(default_factory=PathListConfig)
product: PathListConfig = Field(default_factory=PathListConfig)
refer: ReferConfig = Field(default_factory=ReferConfig)
image: PathListConfig = Field(default_factory=PathListConfig)
output_dir: OutputConfig = Field(default_factory=OutputConfig)
class SystemConfig(BaseConfig):
"""系统配置"""
debug: bool = False
log_level: str = "INFO"
parallel_processing: bool = True
max_workers: int = 4
class TopicConfig(BaseConfig):
"""选题配置"""
date: str = ""
num: int = 5
variants: int = 1
class GenerateTopicConfig(BaseConfig):
"""主题生成配置"""
topic_system_prompt: str = "resource/prompt/generateTopics/system.txt"
topic_user_prompt: str = "resource/prompt/generateTopics/user.txt"
model: Dict[str, Any] = Field(default_factory=dict)
topic: TopicConfig = Field(default_factory=TopicConfig)
@dataclass
class GenerateContentConfig(BaseConfig): class GenerateContentConfig(BaseConfig):
"""内容生成配置""" """内容生成配置"""
content_system_prompt: str = "resource/prompt/generateContent/contentSystem.txt" content_system_prompt: str = "resource/prompt/generateContent/contentSystem.txt"
@ -46,9 +136,10 @@ class GenerateContentConfig(BaseConfig):
judger_user_prompt: str = "resource/prompt/judgeContent/user.txt" judger_user_prompt: str = "resource/prompt/judgeContent/user.txt"
enable_content_judge: bool = True enable_content_judge: bool = True
refer_sampling_rate: float = 1.0 refer_sampling_rate: float = 1.0
model: Dict[str, Any] = Field(default_factory=dict)
judger_model: Dict[str, Any] = Field(default_factory=dict)
@dataclass
class AIModelConfig(BaseConfig): class AIModelConfig(BaseConfig):
"""AI模型配置""" """AI模型配置"""
model: str = "qwq-plus" model: str = "qwq-plus"
@ -63,99 +154,19 @@ class AIModelConfig(BaseConfig):
topic_user_prompt: str = "resource/prompt/generateTopics/user.txt" topic_user_prompt: str = "resource/prompt/generateTopics/user.txt"
refer_sampling_rate: float = Field(0.5, ge=0.0, le=1.0) refer_sampling_rate: float = Field(0.5, ge=0.0, le=1.0)
class Config:
pass
@dataclass
class PosterConfig(BaseConfig): class PosterConfig(BaseConfig):
"""海报生成配置""" """海报生成配置"""
target_size: List[int] = field(default_factory=lambda: [900, 1200]) target_size: List[int] = Field(default_factory=lambda: [900, 1200])
additional_images_enabled: bool = True additional_images_enabled: bool = True
template_selection: str = "random" # random, business, vibrant, original template_selection: str = "random" # random, business, vibrant, original
available_templates: List[str] = field(default_factory=lambda: ["original", "business", "vibrant"]) available_templates: List[str] = Field(default_factory=lambda: ["original", "business", "vibrant"])
@dataclass
class ContentConfig(BaseConfig): class ContentConfig(BaseConfig):
"""内容生成配置""" """内容生成配置"""
enable_content_judge: bool = True enable_content_judge: bool = True
num: int = 5 num: int = 5
variants_per_topic: int = 1 variants_per_topic: int = 1
max_title_length: int = 30 max_title_length: int = 30
max_content_length: int = 500 max_content_length: int = 500
@dataclass
class PathConfig(BaseConfig):
"""单个路径配置"""
path: str = ""
@dataclass
class PathListConfig(BaseConfig):
"""路径列表配置"""
paths: List[str] = field(default_factory=list)
@dataclass
class SamplingPathListConfig(BaseConfig):
"""带采样率的路径列表配置"""
sampling_rate: float = 1.0
paths: List[str] = field(default_factory=list)
@dataclass
class ReferItem:
"""单个Refer项"""
path: str = ""
sampling_rate: float = 1.0
step: str = "" # 可选值: "topic", "content", "judge",表示在哪个阶段使用
@dataclass
class ReferConfig(BaseConfig):
"""Refer配置现在是一个列表"""
refer_list: List[ReferItem] = field(default_factory=list)
@dataclass
class OutputConfig(BaseConfig):
"""输出配置"""
base_dir: str = "result"
image_dir: str = "images"
topic_dir: str = "topics"
content_dir: str = "contents"
@dataclass
class ResourceConfig(BaseConfig):
"""资源配置"""
resource_dirs: List[str] = field(default_factory=list)
style: PathListConfig = field(default_factory=PathListConfig)
demand: PathListConfig = field(default_factory=PathListConfig)
object: PathListConfig = field(default_factory=PathListConfig)
product: PathListConfig = field(default_factory=PathListConfig)
refer: ReferConfig = field(default_factory=ReferConfig)
image: PathListConfig = field(default_factory=PathListConfig)
output_dir: OutputConfig = field(default_factory=OutputConfig)
@dataclass
class SystemConfig(BaseConfig):
"""系统配置"""
debug: bool = False
log_level: str = "INFO"
parallel_processing: bool = True
max_workers: int = 4
@dataclass
class TopicConfig(BaseConfig):
"""选题配置"""
date: str = ""
num: int = 5
variants: int = 1
@dataclass
class GenerateTopicConfig(BaseConfig):
"""主题生成配置"""
topic_system_prompt: str = "resource/prompt/generateTopics/system.txt"
topic_user_prompt: str = "resource/prompt/generateTopics/user.txt"
model: Dict[str, Any] = field(default_factory=dict)
topic: TopicConfig = field(default_factory=TopicConfig)

146
tests/test_config.py Normal file
View File

@ -0,0 +1,146 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
测试配置管理器和pydantic模型的兼容性
"""
import os
import json
import tempfile
from pathlib import Path
from core.config.manager import ConfigManager
from core.config.models import (
BaseConfig, AIModelConfig, SystemConfig, GenerateTopicConfig, ResourceConfig,
GenerateContentConfig, PosterConfig, ContentConfig, ReferItem, ReferConfig
)
def test_config_creation():
"""测试创建配置实例"""
print("测试创建配置实例...")
# 创建各种配置实例
ai_config = AIModelConfig(model="test-model", api_key="test-key")
print(f"AI配置: {ai_config.model}, {ai_config.api_key}")
# 测试嵌套配置
resource_config = ResourceConfig(
resource_dirs=["resource"],
refer=ReferConfig(
refer_list=[
ReferItem(path="path1", sampling_rate=0.5, step="topic"),
ReferItem(path="path2", sampling_rate=0.8, step="content")
]
)
)
# 验证嵌套配置
print(f"资源目录: {resource_config.resource_dirs}")
print(f"Refer项数量: {len(resource_config.refer.refer_list)}")
print(f"第一个Refer项: {resource_config.refer.refer_list[0].path}, {resource_config.refer.refer_list[0].sampling_rate}")
# 测试序列化
config_dict = resource_config.to_dict()
print(f"序列化后的字典: {json.dumps(config_dict, indent=2)}")
return True
def test_config_manager():
"""测试配置管理器"""
print("\n测试配置管理器...")
# 创建临时目录
with tempfile.TemporaryDirectory() as temp_dir:
# 创建测试配置文件
ai_config_path = Path(temp_dir) / "ai_model.json"
ai_config = {
"model": "test-model",
"api_key": "test-key",
"temperature": 0.8
}
with open(ai_config_path, "w") as f:
json.dump(ai_config, f)
# 创建嵌套配置文件
resource_config_path = Path(temp_dir) / "resource.json"
resource_config = {
"resource_dirs": ["test_resource"],
"refer": {
"refer_list": [
{"path": "test_path", "sampling_rate": 0.7, "step": "judge"}
]
}
}
with open(resource_config_path, "w") as f:
json.dump(resource_config, f)
# 初始化配置管理器
config_manager = ConfigManager()
config_manager.load_from_directory(temp_dir)
# 获取并验证AI配置
ai_model_config = config_manager.get_config("ai_model", AIModelConfig)
print(f"加载的AI配置: {ai_model_config.model}, {ai_model_config.api_key}, {ai_model_config.temperature}")
assert ai_model_config.model == "test-model"
assert ai_model_config.api_key == "test-key"
assert ai_model_config.temperature == 0.8
# 获取并验证资源配置
resource_config = config_manager.get_config("resource", ResourceConfig)
print(f"加载的资源配置: {resource_config.resource_dirs}")
print(f"加载的Refer项: {resource_config.refer.refer_list[0].path}, {resource_config.refer.refer_list[0].sampling_rate}")
assert resource_config.resource_dirs == ["test_resource"]
assert len(resource_config.refer.refer_list) == 1
assert resource_config.refer.refer_list[0].path == "test_path"
assert resource_config.refer.refer_list[0].sampling_rate == 0.7
# 测试更新配置
ai_model_config.update({"model": "updated-model"})
print(f"更新后的AI配置: {ai_model_config.model}")
assert ai_model_config.model == "updated-model"
# 测试保存配置
config_manager.save_config("ai_model")
# 重新加载并验证
with open(ai_config_path, "r") as f:
saved_config = json.load(f)
print(f"保存的配置: {saved_config['model']}")
assert saved_config["model"] == "updated-model"
return True
def test_config_type_conversion():
"""测试配置类型转换"""
print("\n测试配置类型转换...")
# 创建配置管理器
config_manager = ConfigManager()
# 注册一个SystemConfig
config_manager.register_config("test_config", SystemConfig)
# 尝试以不同类型获取
system_config = config_manager.get_config("test_config", SystemConfig)
print(f"获取为SystemConfig: {type(system_config).__name__}")
# 尝试转换类型
try:
content_config = config_manager.get_config("test_config", ContentConfig)
print(f"成功转换为ContentConfig: {type(content_config).__name__}")
except TypeError as e:
print(f"类型转换失败,符合预期: {e}")
return True
if __name__ == "__main__":
print("开始测试pydantic配置模型和ConfigManager...")
test_config_creation()
test_config_manager()
test_config_type_conversion()
print("\n所有测试完成!")