迁移到了pydantic
This commit is contained in:
parent
157d3348a6
commit
c310c5069f
Binary file not shown.
Binary file not shown.
@ -9,7 +9,7 @@ import json
|
||||
import os
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Dict, Type, TypeVar, Optional
|
||||
from typing import Dict, Type, TypeVar, Optional, Any, cast
|
||||
|
||||
from core.config.models import (
|
||||
BaseConfig, AIModelConfig, SystemConfig, GenerateTopicConfig, ResourceConfig,
|
||||
@ -66,7 +66,7 @@ class ConfigManager:
|
||||
self.register_config('topic_gen', GenerateTopicConfig)
|
||||
self.register_config('content_gen', GenerateContentConfig)
|
||||
|
||||
def register_config(self, name: str, config_class: Type[T]):
|
||||
def register_config(self, name: str, config_class: Type[T]) -> None:
|
||||
"""
|
||||
注册一个配置类
|
||||
|
||||
@ -91,14 +91,27 @@ class ConfigManager:
|
||||
配置实例
|
||||
"""
|
||||
config = self._configs.get(name)
|
||||
if not isinstance(config, config_class):
|
||||
if config is None:
|
||||
# 如果配置不存在,先注册一个默认实例
|
||||
if config is None:
|
||||
self.register_config(name, config_class)
|
||||
config = self._configs.get(name)
|
||||
else:
|
||||
raise TypeError(f"Configuration '{name}' is not of type '{config_class.__name__}'")
|
||||
return config
|
||||
self.register_config(name, config_class)
|
||||
config = self._configs.get(name)
|
||||
|
||||
# 确保配置是正确的类型
|
||||
if not isinstance(config, config_class):
|
||||
# 尝试转换配置
|
||||
try:
|
||||
if isinstance(config, BaseConfig):
|
||||
# 将现有配置转换为请求的类型
|
||||
new_config = config_class(**config.model_dump())
|
||||
self._configs[name] = new_config
|
||||
config = new_config
|
||||
else:
|
||||
raise TypeError(f"Configuration '{name}' is not of type '{config_class.__name__}'")
|
||||
except Exception as e:
|
||||
logger.error(f"转换配置 '{name}' 到类型 '{config_class.__name__}' 失败: {e}")
|
||||
raise TypeError(f"Configuration '{name}' is not of type '{config_class.__name__}'") from e
|
||||
|
||||
return cast(T, config)
|
||||
|
||||
def _load_all_configs_from_dir(self):
|
||||
"""动态加载目录中的所有.json文件"""
|
||||
@ -176,7 +189,8 @@ class ConfigManager:
|
||||
raise ValueError("配置目录未设置,无法保存文件")
|
||||
|
||||
path = self.config_dir / f"{name}.json"
|
||||
config_data = self.get_config(name, BaseConfig).to_dict()
|
||||
config = self.get_config(name, BaseConfig)
|
||||
config_data = config.to_dict()
|
||||
|
||||
try:
|
||||
with open(path, 'w', encoding='utf-8') as f:
|
||||
|
||||
@ -4,18 +4,23 @@
|
||||
"""
|
||||
配置模型
|
||||
定义了项目中所有配置的数据类模型
|
||||
使用pydantic进行数据验证和序列化
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field, fields, asdict
|
||||
from typing import Dict, Any, List, Optional
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Dict, Any, List, Optional, Type, TypeVar, Union, ClassVar
|
||||
from pydantic import BaseModel, Field, model_validator, ConfigDict
|
||||
|
||||
T = TypeVar('T', bound='BaseConfig')
|
||||
|
||||
# 基础配置类,提供通用方法
|
||||
class BaseConfig:
|
||||
class BaseConfig(BaseModel):
|
||||
"""可从字典更新的配置基类"""
|
||||
|
||||
def update(self, new_data: Dict[str, Any]):
|
||||
model_config = ConfigDict(
|
||||
extra="allow", # 允许额外字段
|
||||
arbitrary_types_allowed=True # 允许任意类型
|
||||
)
|
||||
|
||||
def update(self, new_data: Dict[str, Any]) -> None:
|
||||
"""从字典递归更新配置"""
|
||||
for key, value in new_data.items():
|
||||
if hasattr(self, key):
|
||||
@ -24,20 +29,105 @@ class BaseConfig:
|
||||
if isinstance(current_attr, BaseConfig) and isinstance(value, dict):
|
||||
# 递归更新嵌套的配置对象
|
||||
current_attr.update(value)
|
||||
# 新增:处理ReferItem列表的特殊情况
|
||||
elif key == 'refer_list' and isinstance(value, list):
|
||||
# 显式地从字典创建ReferItem对象
|
||||
setattr(self, key, [ReferItem(**item) for item in value])
|
||||
else:
|
||||
# 否则,直接设置值
|
||||
setattr(self, key, value)
|
||||
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""将配置转换为字典"""
|
||||
return asdict(self)
|
||||
return self.model_dump()
|
||||
|
||||
|
||||
class ReferItem(BaseConfig):
|
||||
"""单个Refer项"""
|
||||
path: str = ""
|
||||
sampling_rate: float = 1.0
|
||||
step: str = "" # 可选值: "topic", "content", "judge",表示在哪个阶段使用
|
||||
|
||||
|
||||
class ReferConfig(BaseConfig):
|
||||
"""Refer配置,现在是一个列表"""
|
||||
refer_list: List[ReferItem] = Field(default_factory=list)
|
||||
|
||||
def update(self, new_data: Dict[str, Any]) -> None:
|
||||
"""特殊处理refer_list的更新"""
|
||||
if 'refer_list' in new_data and isinstance(new_data['refer_list'], list):
|
||||
# 将字典列表转换为ReferItem对象列表
|
||||
new_list = []
|
||||
for item in new_data['refer_list']:
|
||||
if isinstance(item, dict):
|
||||
new_list.append(ReferItem(**item))
|
||||
elif isinstance(item, ReferItem):
|
||||
new_list.append(item)
|
||||
self.refer_list = new_list
|
||||
|
||||
# 移除已处理的refer_list,避免重复处理
|
||||
new_data_copy = new_data.copy()
|
||||
new_data_copy.pop('refer_list')
|
||||
super().update(new_data_copy)
|
||||
else:
|
||||
super().update(new_data)
|
||||
|
||||
|
||||
class PathConfig(BaseConfig):
|
||||
"""单个路径配置"""
|
||||
path: str = ""
|
||||
|
||||
|
||||
class PathListConfig(BaseConfig):
|
||||
"""路径列表配置"""
|
||||
paths: List[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class SamplingPathListConfig(BaseConfig):
|
||||
"""带采样率的路径列表配置"""
|
||||
sampling_rate: float = 1.0
|
||||
paths: List[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class OutputConfig(BaseConfig):
|
||||
"""输出配置"""
|
||||
base_dir: str = "result"
|
||||
image_dir: str = "images"
|
||||
topic_dir: str = "topics"
|
||||
content_dir: str = "contents"
|
||||
|
||||
|
||||
class ResourceConfig(BaseConfig):
|
||||
"""资源配置"""
|
||||
resource_dirs: List[str] = Field(default_factory=list)
|
||||
style: PathListConfig = Field(default_factory=PathListConfig)
|
||||
demand: PathListConfig = Field(default_factory=PathListConfig)
|
||||
object: PathListConfig = Field(default_factory=PathListConfig)
|
||||
product: PathListConfig = Field(default_factory=PathListConfig)
|
||||
refer: ReferConfig = Field(default_factory=ReferConfig)
|
||||
image: PathListConfig = Field(default_factory=PathListConfig)
|
||||
output_dir: OutputConfig = Field(default_factory=OutputConfig)
|
||||
|
||||
|
||||
class SystemConfig(BaseConfig):
|
||||
"""系统配置"""
|
||||
debug: bool = False
|
||||
log_level: str = "INFO"
|
||||
parallel_processing: bool = True
|
||||
max_workers: int = 4
|
||||
|
||||
|
||||
class TopicConfig(BaseConfig):
|
||||
"""选题配置"""
|
||||
date: str = ""
|
||||
num: int = 5
|
||||
variants: int = 1
|
||||
|
||||
|
||||
class GenerateTopicConfig(BaseConfig):
|
||||
"""主题生成配置"""
|
||||
topic_system_prompt: str = "resource/prompt/generateTopics/system.txt"
|
||||
topic_user_prompt: str = "resource/prompt/generateTopics/user.txt"
|
||||
model: Dict[str, Any] = Field(default_factory=dict)
|
||||
topic: TopicConfig = Field(default_factory=TopicConfig)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GenerateContentConfig(BaseConfig):
|
||||
"""内容生成配置"""
|
||||
content_system_prompt: str = "resource/prompt/generateContent/contentSystem.txt"
|
||||
@ -46,9 +136,10 @@ class GenerateContentConfig(BaseConfig):
|
||||
judger_user_prompt: str = "resource/prompt/judgeContent/user.txt"
|
||||
enable_content_judge: bool = True
|
||||
refer_sampling_rate: float = 1.0
|
||||
model: Dict[str, Any] = Field(default_factory=dict)
|
||||
judger_model: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AIModelConfig(BaseConfig):
|
||||
"""AI模型配置"""
|
||||
model: str = "qwq-plus"
|
||||
@ -63,99 +154,19 @@ class AIModelConfig(BaseConfig):
|
||||
topic_user_prompt: str = "resource/prompt/generateTopics/user.txt"
|
||||
refer_sampling_rate: float = Field(0.5, ge=0.0, le=1.0)
|
||||
|
||||
class Config:
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class PosterConfig(BaseConfig):
|
||||
"""海报生成配置"""
|
||||
target_size: List[int] = field(default_factory=lambda: [900, 1200])
|
||||
target_size: List[int] = Field(default_factory=lambda: [900, 1200])
|
||||
additional_images_enabled: bool = True
|
||||
template_selection: str = "random" # random, business, vibrant, original
|
||||
available_templates: List[str] = field(default_factory=lambda: ["original", "business", "vibrant"])
|
||||
available_templates: List[str] = Field(default_factory=lambda: ["original", "business", "vibrant"])
|
||||
|
||||
|
||||
@dataclass
|
||||
class ContentConfig(BaseConfig):
|
||||
"""内容生成配置"""
|
||||
enable_content_judge: bool = True
|
||||
num: int = 5
|
||||
variants_per_topic: int = 1
|
||||
max_title_length: int = 30
|
||||
max_content_length: int = 500
|
||||
|
||||
|
||||
@dataclass
|
||||
class PathConfig(BaseConfig):
|
||||
"""单个路径配置"""
|
||||
path: str = ""
|
||||
|
||||
@dataclass
|
||||
class PathListConfig(BaseConfig):
|
||||
"""路径列表配置"""
|
||||
paths: List[str] = field(default_factory=list)
|
||||
|
||||
@dataclass
|
||||
class SamplingPathListConfig(BaseConfig):
|
||||
"""带采样率的路径列表配置"""
|
||||
sampling_rate: float = 1.0
|
||||
paths: List[str] = field(default_factory=list)
|
||||
|
||||
@dataclass
|
||||
class ReferItem:
|
||||
"""单个Refer项"""
|
||||
path: str = ""
|
||||
sampling_rate: float = 1.0
|
||||
step: str = "" # 可选值: "topic", "content", "judge",表示在哪个阶段使用
|
||||
|
||||
@dataclass
|
||||
class ReferConfig(BaseConfig):
|
||||
"""Refer配置,现在是一个列表"""
|
||||
refer_list: List[ReferItem] = field(default_factory=list)
|
||||
|
||||
@dataclass
|
||||
class OutputConfig(BaseConfig):
|
||||
"""输出配置"""
|
||||
base_dir: str = "result"
|
||||
image_dir: str = "images"
|
||||
topic_dir: str = "topics"
|
||||
content_dir: str = "contents"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResourceConfig(BaseConfig):
|
||||
"""资源配置"""
|
||||
resource_dirs: List[str] = field(default_factory=list)
|
||||
style: PathListConfig = field(default_factory=PathListConfig)
|
||||
demand: PathListConfig = field(default_factory=PathListConfig)
|
||||
object: PathListConfig = field(default_factory=PathListConfig)
|
||||
product: PathListConfig = field(default_factory=PathListConfig)
|
||||
refer: ReferConfig = field(default_factory=ReferConfig)
|
||||
image: PathListConfig = field(default_factory=PathListConfig)
|
||||
output_dir: OutputConfig = field(default_factory=OutputConfig)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SystemConfig(BaseConfig):
|
||||
"""系统配置"""
|
||||
debug: bool = False
|
||||
log_level: str = "INFO"
|
||||
parallel_processing: bool = True
|
||||
max_workers: int = 4
|
||||
|
||||
|
||||
@dataclass
|
||||
class TopicConfig(BaseConfig):
|
||||
"""选题配置"""
|
||||
date: str = ""
|
||||
num: int = 5
|
||||
variants: int = 1
|
||||
|
||||
@dataclass
|
||||
class GenerateTopicConfig(BaseConfig):
|
||||
"""主题生成配置"""
|
||||
topic_system_prompt: str = "resource/prompt/generateTopics/system.txt"
|
||||
topic_user_prompt: str = "resource/prompt/generateTopics/user.txt"
|
||||
model: Dict[str, Any] = field(default_factory=dict)
|
||||
topic: TopicConfig = field(default_factory=TopicConfig)
|
||||
max_content_length: int = 500
|
||||
146
tests/test_config.py
Normal file
146
tests/test_config.py
Normal file
@ -0,0 +1,146 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
测试配置管理器和pydantic模型的兼容性
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from core.config.manager import ConfigManager
|
||||
from core.config.models import (
|
||||
BaseConfig, AIModelConfig, SystemConfig, GenerateTopicConfig, ResourceConfig,
|
||||
GenerateContentConfig, PosterConfig, ContentConfig, ReferItem, ReferConfig
|
||||
)
|
||||
|
||||
def test_config_creation():
|
||||
"""测试创建配置实例"""
|
||||
print("测试创建配置实例...")
|
||||
|
||||
# 创建各种配置实例
|
||||
ai_config = AIModelConfig(model="test-model", api_key="test-key")
|
||||
print(f"AI配置: {ai_config.model}, {ai_config.api_key}")
|
||||
|
||||
# 测试嵌套配置
|
||||
resource_config = ResourceConfig(
|
||||
resource_dirs=["resource"],
|
||||
refer=ReferConfig(
|
||||
refer_list=[
|
||||
ReferItem(path="path1", sampling_rate=0.5, step="topic"),
|
||||
ReferItem(path="path2", sampling_rate=0.8, step="content")
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
# 验证嵌套配置
|
||||
print(f"资源目录: {resource_config.resource_dirs}")
|
||||
print(f"Refer项数量: {len(resource_config.refer.refer_list)}")
|
||||
print(f"第一个Refer项: {resource_config.refer.refer_list[0].path}, {resource_config.refer.refer_list[0].sampling_rate}")
|
||||
|
||||
# 测试序列化
|
||||
config_dict = resource_config.to_dict()
|
||||
print(f"序列化后的字典: {json.dumps(config_dict, indent=2)}")
|
||||
|
||||
return True
|
||||
|
||||
def test_config_manager():
|
||||
"""测试配置管理器"""
|
||||
print("\n测试配置管理器...")
|
||||
|
||||
# 创建临时目录
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
# 创建测试配置文件
|
||||
ai_config_path = Path(temp_dir) / "ai_model.json"
|
||||
ai_config = {
|
||||
"model": "test-model",
|
||||
"api_key": "test-key",
|
||||
"temperature": 0.8
|
||||
}
|
||||
|
||||
with open(ai_config_path, "w") as f:
|
||||
json.dump(ai_config, f)
|
||||
|
||||
# 创建嵌套配置文件
|
||||
resource_config_path = Path(temp_dir) / "resource.json"
|
||||
resource_config = {
|
||||
"resource_dirs": ["test_resource"],
|
||||
"refer": {
|
||||
"refer_list": [
|
||||
{"path": "test_path", "sampling_rate": 0.7, "step": "judge"}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
with open(resource_config_path, "w") as f:
|
||||
json.dump(resource_config, f)
|
||||
|
||||
# 初始化配置管理器
|
||||
config_manager = ConfigManager()
|
||||
config_manager.load_from_directory(temp_dir)
|
||||
|
||||
# 获取并验证AI配置
|
||||
ai_model_config = config_manager.get_config("ai_model", AIModelConfig)
|
||||
print(f"加载的AI配置: {ai_model_config.model}, {ai_model_config.api_key}, {ai_model_config.temperature}")
|
||||
assert ai_model_config.model == "test-model"
|
||||
assert ai_model_config.api_key == "test-key"
|
||||
assert ai_model_config.temperature == 0.8
|
||||
|
||||
# 获取并验证资源配置
|
||||
resource_config = config_manager.get_config("resource", ResourceConfig)
|
||||
print(f"加载的资源配置: {resource_config.resource_dirs}")
|
||||
print(f"加载的Refer项: {resource_config.refer.refer_list[0].path}, {resource_config.refer.refer_list[0].sampling_rate}")
|
||||
assert resource_config.resource_dirs == ["test_resource"]
|
||||
assert len(resource_config.refer.refer_list) == 1
|
||||
assert resource_config.refer.refer_list[0].path == "test_path"
|
||||
assert resource_config.refer.refer_list[0].sampling_rate == 0.7
|
||||
|
||||
# 测试更新配置
|
||||
ai_model_config.update({"model": "updated-model"})
|
||||
print(f"更新后的AI配置: {ai_model_config.model}")
|
||||
assert ai_model_config.model == "updated-model"
|
||||
|
||||
# 测试保存配置
|
||||
config_manager.save_config("ai_model")
|
||||
|
||||
# 重新加载并验证
|
||||
with open(ai_config_path, "r") as f:
|
||||
saved_config = json.load(f)
|
||||
print(f"保存的配置: {saved_config['model']}")
|
||||
assert saved_config["model"] == "updated-model"
|
||||
|
||||
return True
|
||||
|
||||
def test_config_type_conversion():
|
||||
"""测试配置类型转换"""
|
||||
print("\n测试配置类型转换...")
|
||||
|
||||
# 创建配置管理器
|
||||
config_manager = ConfigManager()
|
||||
|
||||
# 注册一个SystemConfig
|
||||
config_manager.register_config("test_config", SystemConfig)
|
||||
|
||||
# 尝试以不同类型获取
|
||||
system_config = config_manager.get_config("test_config", SystemConfig)
|
||||
print(f"获取为SystemConfig: {type(system_config).__name__}")
|
||||
|
||||
# 尝试转换类型
|
||||
try:
|
||||
content_config = config_manager.get_config("test_config", ContentConfig)
|
||||
print(f"成功转换为ContentConfig: {type(content_config).__name__}")
|
||||
except TypeError as e:
|
||||
print(f"类型转换失败,符合预期: {e}")
|
||||
|
||||
return True
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("开始测试pydantic配置模型和ConfigManager...")
|
||||
|
||||
test_config_creation()
|
||||
test_config_manager()
|
||||
test_config_type_conversion()
|
||||
|
||||
print("\n所有测试完成!")
|
||||
Loading…
x
Reference in New Issue
Block a user