#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Configuration models.

Defines the data-class style models for all configuration used in the
project, using pydantic for validation and serialization.
"""

from typing import Dict, Any, List, Optional, Type, TypeVar, Union, ClassVar
from pydantic import BaseModel, Field, model_validator, ConfigDict

T = TypeVar('T', bound='BaseConfig')


class BaseConfig(BaseModel):
    """Base class for configs that can be updated in place from a dict."""

    model_config = ConfigDict(
        extra="allow",  # keep unknown keys supplied at construction time
        arbitrary_types_allowed=True  # allow fields of arbitrary (non-pydantic) types
    )

    def update(self, new_data: Dict[str, Any]) -> None:
        """Recursively update this config in place from a dict.

        Only keys that name a declared model field, or an extra field already
        present on this instance, are applied; every other key is ignored.
        When the current value is itself a ``BaseConfig`` and the new value is
        a dict, the nested config is updated recursively instead of being
        replaced. Otherwise the new value is assigned as-is; no validation
        happens because ``validate_assignment`` is not enabled.

        Args:
            new_data: Mapping of field names to new values. ``None`` or an
                empty mapping is a no-op.
        """
        if not new_data:
            return
        declared_fields = type(self).model_fields
        existing_extras = self.model_extra or {}
        for key, value in new_data.items():
            # Do not use hasattr(): it also matches methods and pydantic
            # internals (e.g. "update", "to_dict", "model_dump", "__class__"),
            # and assigning to those would shadow or break them.
            if key not in declared_fields and key not in existing_extras:
                continue
            current_attr = getattr(self, key)
            if isinstance(current_attr, BaseConfig) and isinstance(value, dict):
                # Merge into the nested config rather than replacing it.
                current_attr.update(value)
            else:
                setattr(self, key, value)

    def to_dict(self) -> Dict[str, Any]:
        """Return this config as a plain dict (via ``model_dump()``)."""
        return self.model_dump()


class ReferItem(BaseConfig):
    """A single reference ("refer") item."""

    path: str = ""
    sampling_rate: float = 1.0
    # Expected values: "topic", "content", "judge" -- the stage that uses this item.
    step: str = ""


class ReferConfig(BaseConfig):
    """Reference configuration, holding a list of ``ReferItem`` entries."""

    refer_list: List[ReferItem] = Field(default_factory=list)

    def update(self, new_data: Dict[str, Any]) -> None:
        """Update from a dict, rebuilding ``refer_list`` from raw entries.

        If ``new_data["refer_list"]`` is a list, it replaces the current list
        wholesale. Dict entries are converted with ``ReferItem(**entry)``,
        existing ``ReferItem`` instances are kept as-is, and any other entry
        is skipped. All remaining keys go through ``BaseConfig.update``.

        Args:
            new_data: Mapping of field names to new values. ``None`` or an
                empty mapping is a no-op.
        """
        if not new_data:
            return
        raw_list = new_data.get('refer_list')
        if not isinstance(raw_list, list):
            super().update(new_data)
            return

        new_list: List[ReferItem] = []
        for item in raw_list:
            if isinstance(item, dict):
                new_list.append(ReferItem(**item))
            elif isinstance(item, ReferItem):
                new_list.append(item)
            # Entries of any other type are skipped.
        self.refer_list = new_list

        # refer_list is already applied; pass only the remaining keys on so it
        # is not processed a second time.
        remaining = {k: v for k, v in new_data.items() if k != 'refer_list'}
        super().update(remaining)


class PathConfig(BaseConfig):
    """A single path."""

    path: str = ""


class PathListConfig(BaseConfig):
    """A list of paths."""

    paths: List[str] = Field(default_factory=list)


class SamplingPathListConfig(BaseConfig):
    """A list of paths together with a sampling rate."""

    sampling_rate: float = 1.0
    paths: List[str] = Field(default_factory=list)


class OutputConfig(BaseConfig):
    """Output directory settings."""

    base_dir: str = "result"
    image_dir: str = "images"
    topic_dir: str = "topics"
    content_dir: str = "contents"


class ResourceConfig(BaseConfig):
    """Resource settings: input path lists, references and output dirs."""

    resource_dirs: List[str] = Field(default_factory=list)
    style: PathListConfig = Field(default_factory=PathListConfig)
    demand: PathListConfig = Field(default_factory=PathListConfig)
    object: PathListConfig = Field(default_factory=PathListConfig)
    product: PathListConfig = Field(default_factory=PathListConfig)
    refer: ReferConfig = Field(default_factory=ReferConfig)
    image: PathListConfig = Field(default_factory=PathListConfig)
    output_dir: OutputConfig = Field(default_factory=OutputConfig)


class SystemConfig(BaseConfig):
    """System-level settings."""

    debug: bool = False
    log_level: str = "INFO"
    parallel_processing: bool = True
    max_workers: int = 4


class TopicConfig(BaseConfig):
    """Topic selection settings."""

    date: str = ""
    num: int = 5
    variants: int = 1


class GenerateTopicConfig(BaseConfig):
    """Topic generation settings."""

    topic_system_prompt: str = "resource/prompt/generateTopics/system.txt"
    topic_user_prompt: str = "resource/prompt/generateTopics/user.txt"
    model: Dict[str, Any] = Field(default_factory=dict)
    topic: TopicConfig = Field(default_factory=TopicConfig)


class GenerateContentConfig(BaseConfig):
    """Content generation (and optional content judging) settings."""

    content_system_prompt: str = "resource/prompt/generateContent/system.txt"
    content_user_prompt: str = "resource/prompt/generateContent/user.txt"
    judger_system_prompt: str = "resource/prompt/judgeContent/system.txt"
    judger_user_prompt: str = "resource/prompt/judgeContent/user.txt"
    enable_content_judge: bool = True
    # NOTE(review): unlike AIModelConfig.refer_sampling_rate this is not
    # range-checked; confirm whether it should also be limited to [0, 1].
    refer_sampling_rate: float = 1.0
    model: Dict[str, Any] = Field(default_factory=dict)
    judger_model: Dict[str, Any] = Field(default_factory=dict)


class AIModelConfig(BaseConfig):
    """AI model settings."""

    model: str = "qwq-plus"
    api_url: str = ""
    api_key: str = ""
    temperature: float = 0.7
    top_p: float = 0.5
    presence_penalty: float = 1.2
    timeout: int = 60
    max_retries: int = 3
    # NOTE(review): these prompt paths duplicate GenerateTopicConfig's
    # defaults; confirm whether they belong on the model config.
    topic_system_prompt: str = "resource/prompt/generateTopics/system.txt"
    topic_user_prompt: str = "resource/prompt/generateTopics/user.txt"
    refer_sampling_rate: float = Field(0.5, ge=0.0, le=1.0)


class PosterConfig(BaseConfig):
    """Poster generation settings."""

    target_size: List[int] = Field(default_factory=lambda: [900, 1200])
    additional_images_enabled: bool = True
    template_selection: str = "random"  # random, business, vibrant, original
    available_templates: List[str] = Field(default_factory=lambda: ["original", "business", "vibrant"])
    poster_system_prompt: str = "resource/prompt/generatePoster/system.txt"
    poster_user_prompt: str = "resource/prompt/generatePoster/user.txt"


class ContentConfig(BaseConfig):
    """Content generation limits and counts."""

    enable_content_judge: bool = True
    num: int = 5
    variants_per_topic: int = 1
    max_title_length: int = 30
    max_content_length: int = 500