diff --git a/core/config/__pycache__/manager.cpython-312.pyc b/core/config/__pycache__/manager.cpython-312.pyc index 7e59efb..f76d679 100644 Binary files a/core/config/__pycache__/manager.cpython-312.pyc and b/core/config/__pycache__/manager.cpython-312.pyc differ diff --git a/core/config/__pycache__/models.cpython-312.pyc b/core/config/__pycache__/models.cpython-312.pyc index 9dd9014..6e86875 100644 Binary files a/core/config/__pycache__/models.cpython-312.pyc and b/core/config/__pycache__/models.cpython-312.pyc differ diff --git a/core/config/manager.py b/core/config/manager.py index 1b39204..f73a7e0 100644 --- a/core/config/manager.py +++ b/core/config/manager.py @@ -9,7 +9,7 @@ import json import os import logging from pathlib import Path -from typing import Dict, Type, TypeVar, Optional +from typing import Dict, Type, TypeVar, Optional, Any, cast from core.config.models import ( BaseConfig, AIModelConfig, SystemConfig, GenerateTopicConfig, ResourceConfig, @@ -66,7 +66,7 @@ class ConfigManager: self.register_config('topic_gen', GenerateTopicConfig) self.register_config('content_gen', GenerateContentConfig) - def register_config(self, name: str, config_class: Type[T]): + def register_config(self, name: str, config_class: Type[T]) -> None: """ 注册一个配置类 @@ -91,14 +91,27 @@ class ConfigManager: 配置实例 """ config = self._configs.get(name) - if not isinstance(config, config_class): + if config is None: # 如果配置不存在,先注册一个默认实例 - if config is None: - self.register_config(name, config_class) - config = self._configs.get(name) - else: - raise TypeError(f"Configuration '{name}' is not of type '{config_class.__name__}'") - return config + self.register_config(name, config_class) + config = self._configs.get(name) + + # 确保配置是正确的类型 + if not isinstance(config, config_class): + # 尝试转换配置 + try: + if isinstance(config, BaseConfig): + # 将现有配置转换为请求的类型 + new_config = config_class(**config.model_dump()) + self._configs[name] = new_config + config = new_config + else: + raise TypeError(f"Configuration '{name}' is not of type '{config_class.__name__}'") + except Exception as e: + logger.error(f"转换配置 '{name}' 到类型 '{config_class.__name__}' 失败: {e}") + raise TypeError(f"Configuration '{name}' is not of type '{config_class.__name__}'") from e + + return cast(T, config) def _load_all_configs_from_dir(self): """动态加载目录中的所有.json文件""" @@ -176,7 +189,8 @@ class ConfigManager: raise ValueError("配置目录未设置,无法保存文件") path = self.config_dir / f"{name}.json" - config_data = self.get_config(name, BaseConfig).to_dict() + config = self.get_config(name, BaseConfig) + config_data = config.to_dict() try: with open(path, 'w', encoding='utf-8') as f: diff --git a/core/config/models.py b/core/config/models.py index 5afc4f0..7e9c217 100644 --- a/core/config/models.py +++ b/core/config/models.py @@ -4,18 +4,23 @@ """ 配置模型 定义了项目中所有配置的数据类模型 +使用pydantic进行数据验证和序列化 """ -from dataclasses import dataclass, field, fields, asdict -from typing import Dict, Any, List, Optional -from pydantic import BaseModel, Field +from typing import Dict, Any, List, Optional, Type, TypeVar, Union, ClassVar +from pydantic import BaseModel, Field, model_validator, ConfigDict +T = TypeVar('T', bound='BaseConfig') -# 基础配置类,提供通用方法 -class BaseConfig: +class BaseConfig(BaseModel): """可从字典更新的配置基类""" - def update(self, new_data: Dict[str, Any]): + model_config = ConfigDict( + extra="allow", # 允许额外字段 + arbitrary_types_allowed=True # 允许任意类型 + ) + + def update(self, new_data: Dict[str, Any]) -> None: """从字典递归更新配置""" for key, value in new_data.items(): if hasattr(self, key): @@ -24,20 +29,105 @@ class BaseConfig: if isinstance(current_attr, BaseConfig) and isinstance(value, dict): # 递归更新嵌套的配置对象 current_attr.update(value) - # 新增:处理ReferItem列表的特殊情况 - elif key == 'refer_list' and isinstance(value, list): - # 显式地从字典创建ReferItem对象 - setattr(self, key, [ReferItem(**item) for item in value]) else: # 否则,直接设置值 setattr(self, key, value) - + def to_dict(self) -> Dict[str, Any]: """将配置转换为字典""" - return asdict(self) + return self.model_dump() + + +class ReferItem(BaseConfig): + """单个Refer项""" + path: str = "" + sampling_rate: float = 1.0 + step: str = "" # 可选值: "topic", "content", "judge",表示在哪个阶段使用 + + +class ReferConfig(BaseConfig): + """Refer配置,现在是一个列表""" + refer_list: List[ReferItem] = Field(default_factory=list) + + def update(self, new_data: Dict[str, Any]) -> None: + """特殊处理refer_list的更新""" + if 'refer_list' in new_data and isinstance(new_data['refer_list'], list): + # 将字典列表转换为ReferItem对象列表 + new_list = [] + for item in new_data['refer_list']: + if isinstance(item, dict): + new_list.append(ReferItem(**item)) + elif isinstance(item, ReferItem): + new_list.append(item) + self.refer_list = new_list + + # 移除已处理的refer_list,避免重复处理 + new_data_copy = new_data.copy() + new_data_copy.pop('refer_list') + super().update(new_data_copy) + else: + super().update(new_data) + + +class PathConfig(BaseConfig): + """单个路径配置""" + path: str = "" + + +class PathListConfig(BaseConfig): + """路径列表配置""" + paths: List[str] = Field(default_factory=list) + + +class SamplingPathListConfig(BaseConfig): + """带采样率的路径列表配置""" + sampling_rate: float = 1.0 + paths: List[str] = Field(default_factory=list) + + +class OutputConfig(BaseConfig): + """输出配置""" + base_dir: str = "result" + image_dir: str = "images" + topic_dir: str = "topics" + content_dir: str = "contents" + + +class ResourceConfig(BaseConfig): + """资源配置""" + resource_dirs: List[str] = Field(default_factory=list) + style: PathListConfig = Field(default_factory=PathListConfig) + demand: PathListConfig = Field(default_factory=PathListConfig) + object: PathListConfig = Field(default_factory=PathListConfig) + product: PathListConfig = Field(default_factory=PathListConfig) + refer: ReferConfig = Field(default_factory=ReferConfig) + image: PathListConfig = Field(default_factory=PathListConfig) + output_dir: OutputConfig = Field(default_factory=OutputConfig) + + +class SystemConfig(BaseConfig): + """系统配置""" + debug: bool = False + log_level: str = "INFO" + parallel_processing: bool = True + max_workers: int = 4 + + +class TopicConfig(BaseConfig): + """选题配置""" + date: str = "" + num: int = 5 + variants: int = 1 + + +class GenerateTopicConfig(BaseConfig): + """主题生成配置""" + topic_system_prompt: str = "resource/prompt/generateTopics/system.txt" + topic_user_prompt: str = "resource/prompt/generateTopics/user.txt" + model: Dict[str, Any] = Field(default_factory=dict) + topic: TopicConfig = Field(default_factory=TopicConfig) -@dataclass class GenerateContentConfig(BaseConfig): """内容生成配置""" content_system_prompt: str = "resource/prompt/generateContent/contentSystem.txt" @@ -46,9 +136,10 @@ class GenerateContentConfig(BaseConfig): judger_user_prompt: str = "resource/prompt/judgeContent/user.txt" enable_content_judge: bool = True refer_sampling_rate: float = 1.0 + model: Dict[str, Any] = Field(default_factory=dict) + judger_model: Dict[str, Any] = Field(default_factory=dict) -@dataclass class AIModelConfig(BaseConfig): """AI模型配置""" model: str = "qwq-plus" @@ -63,99 +154,19 @@ class AIModelConfig(BaseConfig): topic_user_prompt: str = "resource/prompt/generateTopics/user.txt" refer_sampling_rate: float = Field(0.5, ge=0.0, le=1.0) - class Config: - pass - -@dataclass class PosterConfig(BaseConfig): """海报生成配置""" - target_size: List[int] = field(default_factory=lambda: [900, 1200]) + target_size: List[int] = Field(default_factory=lambda: [900, 1200]) additional_images_enabled: bool = True template_selection: str = "random" # random, business, vibrant, original - available_templates: List[str] = field(default_factory=lambda: ["original", "business", "vibrant"]) + available_templates: List[str] = Field(default_factory=lambda: ["original", "business", "vibrant"]) -@dataclass class ContentConfig(BaseConfig): """内容生成配置""" enable_content_judge: bool = True num: int = 5 variants_per_topic: int = 1 max_title_length: int = 30 - max_content_length: int = 500 - - -@dataclass -class PathConfig(BaseConfig): - """单个路径配置""" - path: str = "" - -@dataclass -class PathListConfig(BaseConfig): - """路径列表配置""" - paths: List[str] = field(default_factory=list) - -@dataclass -class SamplingPathListConfig(BaseConfig): - """带采样率的路径列表配置""" - sampling_rate: float = 1.0 - paths: List[str] = field(default_factory=list) - -@dataclass -class ReferItem: - """单个Refer项""" - path: str = "" - sampling_rate: float = 1.0 - step: str = "" # 可选值: "topic", "content", "judge",表示在哪个阶段使用 - -@dataclass -class ReferConfig(BaseConfig): - """Refer配置,现在是一个列表""" - refer_list: List[ReferItem] = field(default_factory=list) - -@dataclass -class OutputConfig(BaseConfig): - """输出配置""" - base_dir: str = "result" - image_dir: str = "images" - topic_dir: str = "topics" - content_dir: str = "contents" - - -@dataclass -class ResourceConfig(BaseConfig): - """资源配置""" - resource_dirs: List[str] = field(default_factory=list) - style: PathListConfig = field(default_factory=PathListConfig) - demand: PathListConfig = field(default_factory=PathListConfig) - object: PathListConfig = field(default_factory=PathListConfig) - product: PathListConfig = field(default_factory=PathListConfig) - refer: ReferConfig = field(default_factory=ReferConfig) - image: PathListConfig = field(default_factory=PathListConfig) - output_dir: OutputConfig = field(default_factory=OutputConfig) - - -@dataclass -class SystemConfig(BaseConfig): - """系统配置""" - debug: bool = False - log_level: str = "INFO" - parallel_processing: bool = True - max_workers: int = 4 - - -@dataclass -class TopicConfig(BaseConfig): - """选题配置""" - date: str = "" - num: int = 5 - variants: int = 1 - -@dataclass -class GenerateTopicConfig(BaseConfig): - """主题生成配置""" - topic_system_prompt: str = "resource/prompt/generateTopics/system.txt" - topic_user_prompt: str = "resource/prompt/generateTopics/user.txt" - model: Dict[str, Any] = field(default_factory=dict) - topic: TopicConfig = field(default_factory=TopicConfig) \ No newline at end of file + max_content_length: int = 500 \ No newline at end of file diff --git a/tests/test_config.py b/tests/test_config.py new file mode 100644 index 0000000..79f6c6e --- /dev/null +++ b/tests/test_config.py @@ -0,0 +1,146 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +""" +测试配置管理器和pydantic模型的兼容性 +""" + +import os +import json +import tempfile +from pathlib import Path + +from core.config.manager import ConfigManager +from core.config.models import ( + BaseConfig, AIModelConfig, SystemConfig, GenerateTopicConfig, ResourceConfig, + GenerateContentConfig, PosterConfig, ContentConfig, ReferItem, ReferConfig +) + +def test_config_creation(): + """测试创建配置实例""" + print("测试创建配置实例...") + + # 创建各种配置实例 + ai_config = AIModelConfig(model="test-model", api_key="test-key") + print(f"AI配置: {ai_config.model}, {ai_config.api_key}") + + # 测试嵌套配置 + resource_config = ResourceConfig( + resource_dirs=["resource"], + refer=ReferConfig( + refer_list=[ + ReferItem(path="path1", sampling_rate=0.5, step="topic"), + ReferItem(path="path2", sampling_rate=0.8, step="content") + ] + ) + ) + + # 验证嵌套配置 + print(f"资源目录: {resource_config.resource_dirs}") + print(f"Refer项数量: {len(resource_config.refer.refer_list)}") + print(f"第一个Refer项: {resource_config.refer.refer_list[0].path}, {resource_config.refer.refer_list[0].sampling_rate}") + + # 测试序列化 + config_dict = resource_config.to_dict() + print(f"序列化后的字典: {json.dumps(config_dict, indent=2)}") + + return True + +def test_config_manager(): + """测试配置管理器""" + print("\n测试配置管理器...") + + # 创建临时目录 + with tempfile.TemporaryDirectory() as temp_dir: + # 创建测试配置文件 + ai_config_path = Path(temp_dir) / "ai_model.json" + ai_config = { + "model": "test-model", + "api_key": "test-key", + "temperature": 0.8 + } + + with open(ai_config_path, "w") as f: + json.dump(ai_config, f) + + # 创建嵌套配置文件 + resource_config_path = Path(temp_dir) / "resource.json" + resource_config = { + "resource_dirs": ["test_resource"], + "refer": { + "refer_list": [ + {"path": "test_path", "sampling_rate": 0.7, "step": "judge"} + ] + } + } + + with open(resource_config_path, "w") as f: + json.dump(resource_config, f) + + # 初始化配置管理器 + config_manager = ConfigManager() + config_manager.load_from_directory(temp_dir) + + # 获取并验证AI配置 + ai_model_config = config_manager.get_config("ai_model", AIModelConfig) + print(f"加载的AI配置: {ai_model_config.model}, {ai_model_config.api_key}, {ai_model_config.temperature}") + assert ai_model_config.model == "test-model" + assert ai_model_config.api_key == "test-key" + assert ai_model_config.temperature == 0.8 + + # 获取并验证资源配置 + resource_config = config_manager.get_config("resource", ResourceConfig) + print(f"加载的资源配置: {resource_config.resource_dirs}") + print(f"加载的Refer项: {resource_config.refer.refer_list[0].path}, {resource_config.refer.refer_list[0].sampling_rate}") + assert resource_config.resource_dirs == ["test_resource"] + assert len(resource_config.refer.refer_list) == 1 + assert resource_config.refer.refer_list[0].path == "test_path" + assert resource_config.refer.refer_list[0].sampling_rate == 0.7 + + # 测试更新配置 + ai_model_config.update({"model": "updated-model"}) + print(f"更新后的AI配置: {ai_model_config.model}") + assert ai_model_config.model == "updated-model" + + # 测试保存配置 + config_manager.save_config("ai_model") + + # 重新加载并验证 + with open(ai_config_path, "r") as f: + saved_config = json.load(f) + print(f"保存的配置: {saved_config['model']}") + assert saved_config["model"] == "updated-model" + + return True + +def test_config_type_conversion(): + """测试配置类型转换""" + print("\n测试配置类型转换...") + + # 创建配置管理器 + config_manager = ConfigManager() + + # 注册一个SystemConfig + config_manager.register_config("test_config", SystemConfig) + + # 尝试以不同类型获取 + system_config = config_manager.get_config("test_config", SystemConfig) + print(f"获取为SystemConfig: {type(system_config).__name__}") + + # 尝试转换类型 + try: + content_config = config_manager.get_config("test_config", ContentConfig) + print(f"成功转换为ContentConfig: {type(content_config).__name__}") + except TypeError as e: + print(f"类型转换失败,符合预期: {e}") + + return True + +if __name__ == "__main__": + print("开始测试pydantic配置模型和ConfigManager...") + + test_config_creation() + test_config_manager() + test_config_type_conversion() + + print("\n所有测试完成!") \ No newline at end of file