迁移到了pydantic

This commit is contained in:
jinye_huang 2025-07-10 16:31:32 +08:00
parent 157d3348a6
commit c310c5069f
5 changed files with 278 additions and 107 deletions

View File

@ -9,7 +9,7 @@ import json
import os
import logging
from pathlib import Path
from typing import Dict, Type, TypeVar, Optional
from typing import Dict, Type, TypeVar, Optional, Any, cast
from core.config.models import (
BaseConfig, AIModelConfig, SystemConfig, GenerateTopicConfig, ResourceConfig,
@ -66,7 +66,7 @@ class ConfigManager:
self.register_config('topic_gen', GenerateTopicConfig)
self.register_config('content_gen', GenerateContentConfig)
def register_config(self, name: str, config_class: Type[T]):
def register_config(self, name: str, config_class: Type[T]) -> None:
"""
注册一个配置类
@ -91,14 +91,27 @@ class ConfigManager:
配置实例
"""
config = self._configs.get(name)
if not isinstance(config, config_class):
if config is None:
# 如果配置不存在,先注册一个默认实例
if config is None:
self.register_config(name, config_class)
config = self._configs.get(name)
else:
raise TypeError(f"Configuration '{name}' is not of type '{config_class.__name__}'")
return config
self.register_config(name, config_class)
config = self._configs.get(name)
# 确保配置是正确的类型
if not isinstance(config, config_class):
# 尝试转换配置
try:
if isinstance(config, BaseConfig):
# 将现有配置转换为请求的类型
new_config = config_class(**config.model_dump())
self._configs[name] = new_config
config = new_config
else:
raise TypeError(f"Configuration '{name}' is not of type '{config_class.__name__}'")
except Exception as e:
logger.error(f"转换配置 '{name}' 到类型 '{config_class.__name__}' 失败: {e}")
raise TypeError(f"Configuration '{name}' is not of type '{config_class.__name__}'") from e
return cast(T, config)
def _load_all_configs_from_dir(self):
"""动态加载目录中的所有.json文件"""
@ -176,7 +189,8 @@ class ConfigManager:
raise ValueError("配置目录未设置,无法保存文件")
path = self.config_dir / f"{name}.json"
config_data = self.get_config(name, BaseConfig).to_dict()
config = self.get_config(name, BaseConfig)
config_data = config.to_dict()
try:
with open(path, 'w', encoding='utf-8') as f:

View File

@ -4,18 +4,23 @@
"""
配置模型
定义了项目中所有配置的数据类模型
使用pydantic进行数据验证和序列化
"""
from dataclasses import dataclass, field, fields, asdict
from typing import Dict, Any, List, Optional
from pydantic import BaseModel, Field
from typing import Dict, Any, List, Optional, Type, TypeVar, Union, ClassVar
from pydantic import BaseModel, Field, model_validator, ConfigDict
T = TypeVar('T', bound='BaseConfig')
# 基础配置类,提供通用方法
class BaseConfig:
class BaseConfig(BaseModel):
"""可从字典更新的配置基类"""
def update(self, new_data: Dict[str, Any]):
model_config = ConfigDict(
extra="allow", # 允许额外字段
arbitrary_types_allowed=True # 允许任意类型
)
def update(self, new_data: Dict[str, Any]) -> None:
"""从字典递归更新配置"""
for key, value in new_data.items():
if hasattr(self, key):
@ -24,20 +29,105 @@ class BaseConfig:
if isinstance(current_attr, BaseConfig) and isinstance(value, dict):
# 递归更新嵌套的配置对象
current_attr.update(value)
# 新增处理ReferItem列表的特殊情况
elif key == 'refer_list' and isinstance(value, list):
# 显式地从字典创建ReferItem对象
setattr(self, key, [ReferItem(**item) for item in value])
else:
# 否则,直接设置值
setattr(self, key, value)
def to_dict(self) -> Dict[str, Any]:
"""将配置转换为字典"""
return asdict(self)
return self.model_dump()
class ReferItem(BaseConfig):
    """A single reference ("refer") entry."""

    # Location of the reference material; empty by default.
    path: str = ""
    # Sampling rate attached to this entry; defaults to 1.0.
    sampling_rate: float = 1.0
    # Allowed values: "topic", "content", "judge" -- the stage in which
    # this entry is used.
    step: str = ""
class ReferConfig(BaseConfig):
    """Refer configuration, held as a list of ``ReferItem`` entries."""

    refer_list: List[ReferItem] = Field(default_factory=list)

    def update(self, new_data: Dict[str, Any]) -> None:
        """Update the configuration, with special handling for ``refer_list``.

        When ``new_data['refer_list']`` is a list, each dict element is turned
        into a ``ReferItem`` and each existing ``ReferItem`` is kept as is;
        elements of any other type are left out of the result. The remaining
        keys are forwarded to ``BaseConfig.update``. ``new_data`` itself is
        not modified.

        Args:
            new_data: Mapping of field names to new values.
        """
        incoming = new_data.get('refer_list')

        # No list under 'refer_list': nothing special to do here.
        if not isinstance(incoming, list):
            super().update(new_data)
            return

        # Replace the stored list before the generic update runs.
        self.refer_list = [
            ReferItem(**entry) if isinstance(entry, dict) else entry
            for entry in incoming
            if isinstance(entry, (dict, ReferItem))
        ]

        # Forward everything except 'refer_list' so it is not processed twice.
        remaining = {
            key: value for key, value in new_data.items() if key != 'refer_list'
        }
        super().update(remaining)
class PathConfig(BaseConfig):
    """Configuration holding one path."""

    # The path string; empty by default.
    path: str = ""
class PathListConfig(BaseConfig):
    """Configuration holding a list of paths."""

    # Path strings; every instance starts with its own empty list.
    paths: List[str] = Field(default_factory=list)
class SamplingPathListConfig(BaseConfig):
    """Configuration holding a list of paths together with a sampling rate."""

    # Sampling rate for the list; defaults to 1.0.
    sampling_rate: float = 1.0
    # Path strings; every instance starts with its own empty list.
    paths: List[str] = Field(default_factory=list)
class OutputConfig(BaseConfig):
    """Output configuration: directory names used for generated results."""

    # Defaults below are plain strings; how they are joined into real
    # paths is decided by the consumer of this config.
    base_dir: str = "result"
    image_dir: str = "images"
    topic_dir: str = "topics"
    content_dir: str = "contents"
class ResourceConfig(BaseConfig):
    """Resource configuration grouping the nested path/refer/output configs."""

    # Resource directory strings; empty list by default.
    resource_dirs: List[str] = Field(default_factory=list)

    # Path-list groups. Each defaults to a fresh, empty PathListConfig.
    style: PathListConfig = Field(default_factory=PathListConfig)
    demand: PathListConfig = Field(default_factory=PathListConfig)
    # NOTE: this field name shadows the ``object`` builtin inside the class
    # body; the name is part of the config schema and is kept unchanged.
    object: PathListConfig = Field(default_factory=PathListConfig)
    product: PathListConfig = Field(default_factory=PathListConfig)

    # Reference entries; defaults to an empty ReferConfig.
    refer: ReferConfig = Field(default_factory=ReferConfig)
    image: PathListConfig = Field(default_factory=PathListConfig)

    # Output directory settings; defaults to OutputConfig's own defaults.
    output_dir: OutputConfig = Field(default_factory=OutputConfig)
class SystemConfig(BaseConfig):
    """System-level configuration."""

    # Debug switch; off by default.
    debug: bool = False
    # Log level name; defaults to "INFO".
    log_level: str = "INFO"
    # Parallel-processing switch; on by default.
    parallel_processing: bool = True
    # Worker count; defaults to 4.
    max_workers: int = 4
class TopicConfig(BaseConfig):
    """Topic-selection configuration."""

    # Date string; empty by default (format is not fixed here).
    date: str = ""
    # Defaults to 5.
    num: int = 5
    # Defaults to 1.
    variants: int = 1
class GenerateTopicConfig(BaseConfig):
    """Topic-generation configuration."""

    # Default prompt file locations.
    topic_system_prompt: str = "resource/prompt/generateTopics/system.txt"
    topic_user_prompt: str = "resource/prompt/generateTopics/user.txt"

    # Free-form model settings; empty dict by default.
    model: Dict[str, Any] = Field(default_factory=dict)
    # Nested topic-selection settings; defaults to TopicConfig's defaults.
    topic: TopicConfig = Field(default_factory=TopicConfig)
@dataclass
class GenerateContentConfig(BaseConfig):
"""内容生成配置"""
content_system_prompt: str = "resource/prompt/generateContent/contentSystem.txt"
@ -46,9 +136,10 @@ class GenerateContentConfig(BaseConfig):
judger_user_prompt: str = "resource/prompt/judgeContent/user.txt"
enable_content_judge: bool = True
refer_sampling_rate: float = 1.0
model: Dict[str, Any] = Field(default_factory=dict)
judger_model: Dict[str, Any] = Field(default_factory=dict)
@dataclass
class AIModelConfig(BaseConfig):
"""AI模型配置"""
model: str = "qwq-plus"
@ -63,99 +154,19 @@ class AIModelConfig(BaseConfig):
topic_user_prompt: str = "resource/prompt/generateTopics/user.txt"
refer_sampling_rate: float = Field(0.5, ge=0.0, le=1.0)
class Config:
pass
@dataclass
class PosterConfig(BaseConfig):
"""海报生成配置"""
target_size: List[int] = field(default_factory=lambda: [900, 1200])
target_size: List[int] = Field(default_factory=lambda: [900, 1200])
additional_images_enabled: bool = True
template_selection: str = "random" # random, business, vibrant, original
available_templates: List[str] = field(default_factory=lambda: ["original", "business", "vibrant"])
available_templates: List[str] = Field(default_factory=lambda: ["original", "business", "vibrant"])
@dataclass
class ContentConfig(BaseConfig):
    """Content-generation configuration."""

    # Content-judge switch; on by default.
    enable_content_judge: bool = True
    # Defaults to 5.
    num: int = 5
    # Defaults to 1.
    variants_per_topic: int = 1
    # Length limits for title and content; defaults are 30 and 500.
    max_title_length: int = 30
    max_content_length: int = 500
@dataclass
class PathConfig(BaseConfig):
    """Configuration holding one path (dataclass form)."""

    # The path string; empty by default.
    path: str = ""
@dataclass
class PathListConfig(BaseConfig):
    """Configuration holding a list of paths (dataclass form)."""

    # Path strings; every instance starts with its own empty list.
    paths: List[str] = field(default_factory=list)
@dataclass
class SamplingPathListConfig(BaseConfig):
    """Path list together with a sampling rate (dataclass form)."""

    # Sampling rate for the list; defaults to 1.0.
    sampling_rate: float = 1.0
    # Path strings; every instance starts with its own empty list.
    paths: List[str] = field(default_factory=list)
@dataclass
class ReferItem:
    """A single reference ("refer") entry (plain dataclass form)."""

    # Location of the reference material; empty by default.
    path: str = ""
    # Sampling rate attached to this entry; defaults to 1.0.
    sampling_rate: float = 1.0
    # Allowed values: "topic", "content", "judge" -- the stage in which
    # this entry is used.
    step: str = ""
@dataclass
class ReferConfig(BaseConfig):
    """Refer configuration, held as a list of ``ReferItem`` (dataclass form)."""

    # Reference entries; every instance starts with its own empty list.
    refer_list: List[ReferItem] = field(default_factory=list)
@dataclass
class OutputConfig(BaseConfig):
    """Output configuration: directory names for results (dataclass form)."""

    # Defaults below are plain strings; how they are joined into real
    # paths is decided by the consumer of this config.
    base_dir: str = "result"
    image_dir: str = "images"
    topic_dir: str = "topics"
    content_dir: str = "contents"
@dataclass
class ResourceConfig(BaseConfig):
    """Resource configuration grouping the nested configs (dataclass form)."""

    # Resource directory strings; empty list by default.
    resource_dirs: List[str] = field(default_factory=list)

    # Path-list groups. Each defaults to a fresh, empty PathListConfig.
    style: PathListConfig = field(default_factory=PathListConfig)
    demand: PathListConfig = field(default_factory=PathListConfig)
    # NOTE: this field name shadows the ``object`` builtin inside the class
    # body; the name is part of the config schema and is kept unchanged.
    object: PathListConfig = field(default_factory=PathListConfig)
    product: PathListConfig = field(default_factory=PathListConfig)

    # Reference entries; defaults to an empty ReferConfig.
    refer: ReferConfig = field(default_factory=ReferConfig)
    image: PathListConfig = field(default_factory=PathListConfig)

    # Output directory settings; defaults to OutputConfig's own defaults.
    output_dir: OutputConfig = field(default_factory=OutputConfig)
@dataclass
class SystemConfig(BaseConfig):
    """System-level configuration (dataclass form)."""

    # Debug switch; off by default.
    debug: bool = False
    # Log level name; defaults to "INFO".
    log_level: str = "INFO"
    # Parallel-processing switch; on by default.
    parallel_processing: bool = True
    # Worker count; defaults to 4.
    max_workers: int = 4
@dataclass
class TopicConfig(BaseConfig):
    """Topic-selection configuration (dataclass form)."""

    # Date string; empty by default (format is not fixed here).
    date: str = ""
    # Defaults to 5.
    num: int = 5
    # Defaults to 1.
    variants: int = 1
@dataclass
class GenerateTopicConfig(BaseConfig):
    """Topic-generation configuration (dataclass form)."""

    # Default prompt file locations.
    topic_system_prompt: str = "resource/prompt/generateTopics/system.txt"
    topic_user_prompt: str = "resource/prompt/generateTopics/user.txt"

    # Free-form model settings; empty dict by default.
    model: Dict[str, Any] = field(default_factory=dict)
    # Nested topic-selection settings; defaults to TopicConfig's defaults.
    topic: TopicConfig = field(default_factory=TopicConfig)
max_content_length: int = 500

146
tests/test_config.py Normal file
View File

@ -0,0 +1,146 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
测试配置管理器和pydantic模型的兼容性
"""
import os
import json
import tempfile
from pathlib import Path
from core.config.manager import ConfigManager
from core.config.models import (
BaseConfig, AIModelConfig, SystemConfig, GenerateTopicConfig, ResourceConfig,
GenerateContentConfig, PosterConfig, ContentConfig, ReferItem, ReferConfig
)
def test_config_creation():
    """Check that config models can be built directly and serialised.

    Builds an ``AIModelConfig`` and a ``ResourceConfig`` with a nested
    ``ReferConfig``, then verifies the values that were passed in are
    readable as attributes and survive ``to_dict()``.

    The function deliberately returns nothing: pytest treats a non-``None``
    return value from a test function as a mistake.
    """
    print("测试创建配置实例...")

    # Flat configuration built from keyword arguments.
    ai_config = AIModelConfig(model="test-model", api_key="test-key")
    print(f"AI配置: {ai_config.model}, {ai_config.api_key}")
    assert ai_config.model == "test-model"
    assert ai_config.api_key == "test-key"

    # Nested configuration: ResourceConfig -> ReferConfig -> [ReferItem, ...].
    resource_config = ResourceConfig(
        resource_dirs=["resource"],
        refer=ReferConfig(
            refer_list=[
                ReferItem(path="path1", sampling_rate=0.5, step="topic"),
                ReferItem(path="path2", sampling_rate=0.8, step="content"),
            ]
        ),
    )

    refer_items = resource_config.refer.refer_list
    print(f"资源目录: {resource_config.resource_dirs}")
    print(f"Refer项数量: {len(refer_items)}")
    print(f"第一个Refer项: {refer_items[0].path}, {refer_items[0].sampling_rate}")
    assert resource_config.resource_dirs == ["resource"]
    assert len(refer_items) == 2
    assert refer_items[0].path == "path1"
    assert refer_items[0].sampling_rate == 0.5

    # Serialisation must produce plain, JSON-compatible nested data.
    config_dict = resource_config.to_dict()
    print(f"序列化后的字典: {json.dumps(config_dict, indent=2)}")
    assert config_dict["resource_dirs"] == ["resource"]
    assert config_dict["refer"]["refer_list"][0]["path"] == "path1"
def test_config_manager():
    """Round-trip JSON config files through ``ConfigManager``.

    Writes ``ai_model.json`` and ``resource.json`` into a temporary
    directory, loads them through the manager, checks the typed configs,
    updates one value, saves it and verifies the file on disk.

    All files are opened with ``encoding="utf-8"`` so the test reads and
    writes with the same encoding the manager uses when saving, instead of
    depending on the platform's locale default.

    The function deliberately returns nothing: pytest treats a non-``None``
    return value from a test function as a mistake.
    """
    print("\n测试配置管理器...")

    with tempfile.TemporaryDirectory() as temp_dir:
        config_dir = Path(temp_dir)

        # Flat config file.
        ai_config_path = config_dir / "ai_model.json"
        ai_payload = {
            "model": "test-model",
            "api_key": "test-key",
            "temperature": 0.8,
        }
        with open(ai_config_path, "w", encoding="utf-8") as f:
            json.dump(ai_payload, f)

        # Nested config file.
        resource_config_path = config_dir / "resource.json"
        resource_payload = {
            "resource_dirs": ["test_resource"],
            "refer": {
                "refer_list": [
                    {"path": "test_path", "sampling_rate": 0.7, "step": "judge"}
                ]
            },
        }
        with open(resource_config_path, "w", encoding="utf-8") as f:
            json.dump(resource_payload, f)

        # Load both files through the manager.
        config_manager = ConfigManager()
        config_manager.load_from_directory(temp_dir)

        # Flat config: values from the file must be visible on the model.
        ai_model_config = config_manager.get_config("ai_model", AIModelConfig)
        print(f"加载的AI配置: {ai_model_config.model}, {ai_model_config.api_key}, {ai_model_config.temperature}")
        assert ai_model_config.model == "test-model"
        assert ai_model_config.api_key == "test-key"
        assert ai_model_config.temperature == 0.8

        # Nested config: the refer list must have been rebuilt as objects.
        resource_config = config_manager.get_config("resource", ResourceConfig)
        refer_items = resource_config.refer.refer_list
        print(f"加载的资源配置: {resource_config.resource_dirs}")
        print(f"加载的Refer项: {refer_items[0].path}, {refer_items[0].sampling_rate}")
        assert resource_config.resource_dirs == ["test_resource"]
        assert len(refer_items) == 1
        assert refer_items[0].path == "test_path"
        assert refer_items[0].sampling_rate == 0.7

        # In-place update of a loaded config.
        ai_model_config.update({"model": "updated-model"})
        print(f"更新后的AI配置: {ai_model_config.model}")
        assert ai_model_config.model == "updated-model"

        # Persist and verify what actually landed on disk.
        config_manager.save_config("ai_model")
        with open(ai_config_path, "r", encoding="utf-8") as f:
            saved_config = json.load(f)
        print(f"保存的配置: {saved_config['model']}")
        assert saved_config["model"] == "updated-model"
def test_config_type_conversion():
    """Exercise ``ConfigManager.get_config`` with matching and other types.

    A ``SystemConfig`` is registered under ``"test_config"`` and fetched
    first as ``SystemConfig`` and then as ``ContentConfig``. The second
    fetch may either convert the stored config or raise ``TypeError``;
    both outcomes are accepted, but whatever is returned must be an
    instance of the requested class.

    The function deliberately returns nothing: pytest treats a non-``None``
    return value from a test function as a mistake.
    """
    print("\n测试配置类型转换...")

    config_manager = ConfigManager()
    config_manager.register_config("test_config", SystemConfig)

    # Fetching with the registered type must hand back that type.
    system_config = config_manager.get_config("test_config", SystemConfig)
    print(f"获取为SystemConfig: {type(system_config).__name__}")
    assert isinstance(system_config, SystemConfig)

    # Fetching with a different type: conversion or TypeError.
    try:
        content_config = config_manager.get_config("test_config", ContentConfig)
    except TypeError as e:
        print(f"类型转换失败,符合预期: {e}")
    else:
        print(f"成功转换为ContentConfig: {type(content_config).__name__}")
        assert isinstance(content_config, ContentConfig)
if __name__ == "__main__":
    # Script entry point: run every test in order, framed by banner lines.
    print("开始测试pydantic配置模型和ConfigManager...")
    for run_test in (
        test_config_creation,
        test_config_manager,
        test_config_type_conversion,
    ):
        run_test()
    print("\n所有测试完成!")