174 lines
5.8 KiB
Python
Raw Normal View History

2025-07-08 17:45:40 +08:00
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
配置模型
定义了项目中所有配置的数据类模型
2025-07-10 16:31:32 +08:00
使用pydantic进行数据验证和序列化
2025-07-08 17:45:40 +08:00
"""
2025-07-10 16:31:32 +08:00
from typing import Dict, Any, List, Optional, Type, TypeVar, Union, ClassVar
from pydantic import BaseModel, Field, model_validator, ConfigDict
2025-07-08 17:45:40 +08:00
2025-07-10 16:31:32 +08:00
# Type variable bound to BaseConfig (and its subclasses); the bound is given
# as a forward-reference string because BaseConfig is defined further down.
T = TypeVar("T", bound="BaseConfig")
2025-07-08 17:45:40 +08:00
2025-07-10 16:31:32 +08:00
class BaseConfig(BaseModel):
    """Base class for configuration models that can be updated from a dict.

    Undeclared keys are accepted and kept at construction time
    (``extra="allow"``), and fields may use arbitrary, non-pydantic types
    (``arbitrary_types_allowed=True``).
    """

    model_config = ConfigDict(
        extra="allow",  # accept and keep undeclared keys
        arbitrary_types_allowed=True  # allow field types pydantic cannot validate natively
    )

    def update(self, new_data: Dict[str, Any]) -> None:
        """Recursively update this configuration in place from a dict.

        Only keys that already exist as configuration data on the model are
        applied: declared fields, or extras captured at construction time.
        Any other key is ignored. When the current value of a key is a
        ``BaseConfig`` and the new value is a dict, the nested config is
        updated recursively through its own ``update``. Every other value
        replaces the current one. Assignments are not validated
        (``validate_assignment`` is not enabled in ``model_config``).

        Args:
            new_data: Mapping of attribute name to new value.
        """
        declared_fields = type(self).model_fields
        # model_extra is None when no extras are stored; never grown below,
        # because keys that are not already present are skipped.
        extras = self.model_extra or {}
        for key, value in new_data.items():
            # Check membership against the model's data rather than using
            # hasattr(): hasattr() is also true for methods such as
            # update/to_dict/model_dump. Assigning to one of those would
            # shadow the method on this instance.
            if key not in declared_fields and key not in extras:
                continue
            current_attr = getattr(self, key)
            if isinstance(current_attr, BaseConfig) and isinstance(value, dict):
                # Merge into the nested config object instead of replacing it.
                current_attr.update(value)
            else:
                # Otherwise replace the value as given.
                setattr(self, key, value)

    def to_dict(self) -> Dict[str, Any]:
        """Return the configuration as a dict (via ``model_dump``)."""
        return self.model_dump()
2025-07-08 17:45:40 +08:00
2025-07-10 16:31:32 +08:00
class ReferItem(BaseConfig):
    """A single reference ("refer") entry."""

    path: str = Field(default="")
    sampling_rate: float = Field(default=1.0)
    # Stage this entry is meant for; the intended values are "topic",
    # "content" or "judge" (a plain str, so this is not enforced).
    step: str = Field(default="")
2025-07-08 17:45:40 +08:00
2025-07-10 16:31:32 +08:00
class ReferConfig(BaseConfig):
    """Reference configuration, stored as a list of ``ReferItem`` entries."""

    refer_list: List[ReferItem] = Field(default_factory=list)

    def update(self, new_data: Dict[str, Any]) -> None:
        """Update in place, rebuilding ``refer_list`` from raw entries.

        If ``new_data['refer_list']`` is a list, it replaces the current
        ``refer_list``:
        - dict entries become ``ReferItem(**entry)``;
        - existing ``ReferItem`` objects are kept as-is;
        - any other entry is dropped.
        The remaining keys are then passed to ``BaseConfig.update``.
        If ``refer_list`` is missing or not a list, the whole of
        ``new_data`` is delegated unchanged.
        """
        raw_items = new_data.get('refer_list')
        if not isinstance(raw_items, list):
            super().update(new_data)
            return

        self.refer_list = [
            ReferItem(**entry) if isinstance(entry, dict) else entry
            for entry in raw_items
            if isinstance(entry, (dict, ReferItem))
        ]
        # refer_list has been handled above; delegate only the other keys.
        remaining = {k: v for k, v in new_data.items() if k != 'refer_list'}
        super().update(remaining)
2025-07-08 17:45:40 +08:00
class PathConfig(BaseConfig):
    """Configuration holding one path string."""

    path: str = Field(default="")
2025-07-10 16:31:32 +08:00
class PathListConfig(BaseConfig):
    """Configuration holding a list of path strings."""

    # default_factory gives every instance its own empty list.
    paths: List[str] = Field(default_factory=list)
class SamplingPathListConfig(BaseConfig):
    """A list of path strings paired with a sampling rate."""

    sampling_rate: float = Field(default=1.0)
    # default_factory gives every instance its own empty list.
    paths: List[str] = Field(default_factory=list)
class OutputConfig(BaseConfig):
    """Output directory names."""

    base_dir: str = Field(default="result")
    # Sub-directory names; presumably resolved under base_dir by the caller.
    image_dir: str = Field(default="images")
    topic_dir: str = Field(default="topics")
    content_dir: str = Field(default="contents")
2025-07-08 17:45:40 +08:00
class ResourceConfig(BaseConfig):
    """Resource configuration, grouped by category.

    Each category is its own nested config object, created fresh per
    instance via ``default_factory``.
    """

    resource_dirs: List[str] = Field(default_factory=list)

    # Path lists per resource category.
    style: PathListConfig = Field(default_factory=PathListConfig)
    demand: PathListConfig = Field(default_factory=PathListConfig)
    object: PathListConfig = Field(default_factory=PathListConfig)
    product: PathListConfig = Field(default_factory=PathListConfig)

    # Reference entries with per-item path / sampling rate / stage.
    refer: ReferConfig = Field(default_factory=ReferConfig)

    image: PathListConfig = Field(default_factory=PathListConfig)

    # Output directory layout.
    output_dir: OutputConfig = Field(default_factory=OutputConfig)
2025-07-08 17:45:40 +08:00
class SystemConfig(BaseConfig):
    """System-level settings."""

    debug: bool = Field(default=False)
    # Presumably a standard logging level name — not validated here.
    log_level: str = Field(default="INFO")
    parallel_processing: bool = Field(default=True)
    # Presumably the worker count used when parallel_processing is on — confirm.
    max_workers: int = Field(default=4)
class TopicConfig(BaseConfig):
    """Topic-selection settings."""

    # Free-form string; no date format is enforced here.
    date: str = Field(default="")
    num: int = Field(default=5)
    variants: int = Field(default=1)
2025-07-10 16:31:32 +08:00
2025-07-08 17:45:40 +08:00
class GenerateTopicConfig(BaseConfig):
    """Settings for topic generation."""

    # Paths to the prompt template files.
    topic_system_prompt: str = Field(default="resource/prompt/generateTopics/system.txt")
    topic_user_prompt: str = Field(default="resource/prompt/generateTopics/user.txt")
    # Free-form model parameters; the dict's schema is not validated here.
    model: Dict[str, Any] = Field(default_factory=dict)
    topic: TopicConfig = Field(default_factory=TopicConfig)
class GenerateContentConfig(BaseConfig):
    """Settings for content generation and optional content judging."""

    # Prompt template paths for the generation step.
    content_system_prompt: str = Field(default="resource/prompt/generateContent/system.txt")
    content_user_prompt: str = Field(default="resource/prompt/generateContent/user.txt")

    # Prompt template paths for the judging step.
    judger_system_prompt: str = Field(default="resource/prompt/judgeContent/system.txt")
    judger_user_prompt: str = Field(default="resource/prompt/judgeContent/user.txt")

    enable_content_judge: bool = Field(default=True)
    # No range validation here (unlike AIModelConfig.refer_sampling_rate).
    refer_sampling_rate: float = Field(default=1.0)

    # Free-form model parameters for generation and judging; not validated.
    model: Dict[str, Any] = Field(default_factory=dict)
    judger_model: Dict[str, Any] = Field(default_factory=dict)
class AIModelConfig(BaseConfig):
    """AI model configuration: endpoint, credentials and request parameters."""

    model: str = "qwq-plus"
    api_url: str = ""
    # Secret credential: excluded from repr()/str() so printing or logging the
    # config does not leak it. It is still present in model_dump()/to_dict().
    api_key: str = Field(default="", repr=False)
    temperature: float = 0.7
    top_p: float = 0.5
    presence_penalty: float = 1.2
    timeout: int = 60  # presumably seconds — confirm against the client
    max_retries: int = 3
    topic_system_prompt: str = "resource/prompt/generateTopics/system.txt"
    topic_user_prompt: str = "resource/prompt/generateTopics/user.txt"
    # Validated at construction to lie within [0.0, 1.0].
    refer_sampling_rate: float = Field(0.5, ge=0.0, le=1.0)
class PosterConfig(BaseConfig):
    """Settings for poster generation."""

    # Presumably [width, height] in pixels — confirm.
    target_size: List[int] = Field(default_factory=lambda: [900, 1200])
    additional_images_enabled: bool = Field(default=True)
    # Expected values: random, business, vibrant, original.
    template_selection: str = Field(default="random")
    available_templates: List[str] = Field(
        default_factory=lambda: ["original", "business", "vibrant"]
    )
    # Paths to the prompt template files.
    poster_system_prompt: str = Field(default="resource/prompt/generatePoster/system.txt")
    poster_user_prompt: str = Field(default="resource/prompt/generatePoster/user.txt")
2025-07-10 16:31:32 +08:00
class ContentConfig(BaseConfig):
    """Content-generation settings: judging toggle, counts and length limits."""
    enable_content_judge: bool = True
    num: int = 5
    variants_per_topic: int = 1
    max_title_length: int = 30  # unit not stated here; presumably characters — confirm
    max_content_length: int = 500  # unit not stated here; presumably characters — confirm