174 lines
5.8 KiB
Python
174 lines
5.8 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
|
||
"""
|
||
Configuration models.

Defines the data-class models for every configuration used in the project.
Uses pydantic for data validation and serialization.
|
||
"""
|
||
|
||
from typing import Dict, Any, List, Optional, Type, TypeVar, Union, ClassVar
|
||
from pydantic import BaseModel, Field, model_validator, ConfigDict
|
||
|
||
# Type variable restricted to BaseConfig and its subclasses. The bound is
# given as a string because BaseConfig is defined further down the module.
# NOTE(review): not referenced elsewhere in this module; presumably imported
# by callers -- confirm before removing.
T = TypeVar("T", bound="BaseConfig")
|
||
|
||
class BaseConfig(BaseModel):
    """Base class for configuration models that can be updated from a dict.

    The model accepts keys that are not declared as fields
    (``extra="allow"``) and permits arbitrary field types
    (``arbitrary_types_allowed=True``).
    """

    model_config = ConfigDict(
        extra="allow",  # accept extra, undeclared keys
        arbitrary_types_allowed=True  # allow arbitrary field types
    )

    def update(self, new_data: Dict[str, Any]) -> None:
        """Recursively update this configuration in place from a dictionary.

        Only keys naming a declared field or an already-stored extra value
        are applied; any other key is ignored. When the current value is a
        ``BaseConfig`` and the new value is a dict, the nested config is
        updated recursively; otherwise the new value replaces the old one.

        NOTE(review): ``validate_assignment`` is not enabled in
        ``model_config``, so assigned values are presumably stored without
        validation -- confirm against the pydantic documentation.

        Args:
            new_data: Mapping of attribute name to new value.

        Returns:
            None. The instance is mutated.
        """
        declared_fields = type(self).model_fields
        # ``model_extra`` is None when a subclass disables extra fields.
        stored_extras = self.model_extra or {}

        for key, value in new_data.items():
            # Check real data attributes instead of ``hasattr``: ``hasattr``
            # is also true for methods and class attributes, so a key such
            # as "update" or "to_dict" would overwrite the bound method.
            if key not in declared_fields and key not in stored_extras:
                continue

            current_attr = getattr(self, key)
            if isinstance(current_attr, BaseConfig) and isinstance(value, dict):
                # Merge into the nested configuration object.
                current_attr.update(value)
            else:
                # Plain value (or non-dict replacement): assign directly.
                setattr(self, key, value)

    def to_dict(self) -> Dict[str, Any]:
        """Return the configuration as a dictionary via ``model_dump``."""
        return self.model_dump()
|
||
|
||
|
||
class ReferItem(BaseConfig):
    """A single refer entry."""

    # Path of the refer resource; empty by default.
    path: str = ""
    # Sampling rate for this entry; defaults to 1.0. No range constraint is
    # declared on this field.
    sampling_rate: float = 1.0
    # Stage in which the entry is used. Documented choices: "topic",
    # "content", "judge". The field is a plain string, so these choices are
    # not enforced here.
    step: str = ""
|
||
|
||
|
||
class ReferConfig(BaseConfig):
    """Refer configuration, held as a list of ``ReferItem`` entries."""

    refer_list: List[ReferItem] = Field(default_factory=list)

    def update(self, new_data: Dict[str, Any]) -> None:
        """Update the configuration, with special handling for ``refer_list``.

        When ``new_data['refer_list']`` is a list, dict entries are turned
        into ``ReferItem`` objects, existing ``ReferItem`` objects are kept,
        and entries of any other type are dropped. The resulting list
        replaces ``self.refer_list``. All remaining keys are passed to the
        parent ``update``. If ``refer_list`` is absent or is not a list,
        ``new_data`` is passed to the parent unchanged.

        Args:
            new_data: Mapping of attribute name to new value.

        Returns:
            None. The instance is mutated.
        """
        raw_items = new_data.get('refer_list')

        # Nothing special to do: delegate the whole mapping to the parent.
        if not isinstance(raw_items, list):
            super().update(new_data)
            return

        # Convert dicts to ReferItem, keep ReferItem instances, skip the rest.
        self.refer_list = [
            entry if isinstance(entry, ReferItem) else ReferItem(**entry)
            for entry in raw_items
            if isinstance(entry, (dict, ReferItem))
        ]

        # refer_list has been handled; forward only the other keys so the
        # parent does not process it a second time.
        remaining = {
            name: value
            for name, value in new_data.items()
            if name != 'refer_list'
        }
        super().update(remaining)
|
||
|
||
|
||
class PathConfig(BaseConfig):
    """Configuration holding a single path."""

    # The path value; empty string by default.
    path: str = ""
|
||
|
||
|
||
class PathListConfig(BaseConfig):
    """Configuration holding a list of paths."""

    # Path strings; each instance gets its own empty list by default.
    paths: List[str] = Field(default_factory=list)
|
||
|
||
|
||
class SamplingPathListConfig(BaseConfig):
    """Configuration holding a list of paths together with a sampling rate."""

    # Sampling rate; defaults to 1.0. No range constraint is declared here.
    sampling_rate: float = 1.0
    # Path strings; each instance gets its own empty list by default.
    paths: List[str] = Field(default_factory=list)
|
||
|
||
|
||
class OutputConfig(BaseConfig):
    """Output configuration."""

    # Base output directory.
    base_dir: str = "result"
    # Directory names for images, topics and contents.
    # NOTE(review): presumably resolved relative to base_dir -- confirm
    # against the code that builds the output paths.
    image_dir: str = "images"
    topic_dir: str = "topics"
    content_dir: str = "contents"
|
||
|
||
|
||
class ResourceConfig(BaseConfig):
    """Resource configuration.

    Groups one list of resource directories, several path-list sections, the
    refer section and the output section. Every nested section is created
    from its own default factory, so instances do not share nested objects.
    """

    # List of resource directories.
    resource_dirs: List[str] = Field(default_factory=list)

    # Path-list sections.
    style: PathListConfig = Field(default_factory=PathListConfig)
    demand: PathListConfig = Field(default_factory=PathListConfig)
    # NOTE(review): this field name shadows the builtin ``object`` inside the
    # class body. It is a config key, so it cannot be renamed without
    # changing the configuration schema.
    object: PathListConfig = Field(default_factory=PathListConfig)
    product: PathListConfig = Field(default_factory=PathListConfig)

    # Refer section (list of ReferItem entries).
    refer: ReferConfig = Field(default_factory=ReferConfig)

    image: PathListConfig = Field(default_factory=PathListConfig)

    # Output directory settings.
    output_dir: OutputConfig = Field(default_factory=OutputConfig)
|
||
|
||
|
||
class SystemConfig(BaseConfig):
    """System configuration."""

    # Debug flag; off by default.
    debug: bool = False
    # Log level name; stored as a plain string, default "INFO".
    log_level: str = "INFO"
    # Parallel-processing flag; on by default.
    parallel_processing: bool = True
    # Worker count; defaults to 4. No lower bound is declared here.
    max_workers: int = 4
|
||
|
||
|
||
class TopicConfig(BaseConfig):
    """Topic selection configuration."""

    # Date as a string; empty by default.
    # NOTE(review): expected date format is not visible here -- confirm
    # with the consumer of this field.
    date: str = ""
    # Defaults to 5. Presumably the number of topics -- confirm.
    num: int = 5
    # Defaults to 1. Presumably the number of variants -- confirm.
    variants: int = 1
|
||
|
||
|
||
class GenerateTopicConfig(BaseConfig):
    """Topic generation configuration."""

    # Default paths of the system and user prompt files for topic generation.
    topic_system_prompt: str = "resource/prompt/generateTopics/system.txt"
    topic_user_prompt: str = "resource/prompt/generateTopics/user.txt"

    # Free-form model settings; each instance gets its own empty dict.
    # Because this is a plain dict and not a BaseConfig, BaseConfig.update
    # replaces it wholesale rather than merging into it.
    model: Dict[str, Any] = Field(default_factory=dict)

    # Nested topic selection settings.
    topic: TopicConfig = Field(default_factory=TopicConfig)
|
||
|
||
|
||
class GenerateContentConfig(BaseConfig):
    """Content generation configuration."""

    # Default paths of the prompt files for content generation.
    content_system_prompt: str = "resource/prompt/generateContent/contentSystem.txt"
    content_user_prompt: str = "resource/prompt/generateContent/user.txt"

    # Default paths of the prompt files for the content judge.
    judger_system_prompt: str = "resource/prompt/judgeContent/system.txt"
    judger_user_prompt: str = "resource/prompt/judgeContent/user.txt"

    # Content-judge flag; on by default.
    enable_content_judge: bool = True

    # Refer sampling rate; defaults to 1.0. No range constraint is declared
    # on this field.
    refer_sampling_rate: float = 1.0

    # Free-form settings for the generation model and for the judge model;
    # each instance gets its own empty dicts.
    model: Dict[str, Any] = Field(default_factory=dict)
    judger_model: Dict[str, Any] = Field(default_factory=dict)
|
||
|
||
|
||
class AIModelConfig(BaseConfig):
    """AI model configuration."""

    # Model name; defaults to "qwq-plus".
    model: str = "qwq-plus"

    # API endpoint and key; both empty by default.
    api_url: str = ""
    api_key: str = ""

    # Generation parameter defaults.
    temperature: float = 0.7
    top_p: float = 0.5
    presence_penalty: float = 1.2

    # Defaults to 60. NOTE(review): unit is presumably seconds -- confirm
    # with the client code that reads it.
    timeout: int = 60
    # Defaults to 3.
    max_retries: int = 3

    # Default paths of the system and user prompt files for topic generation.
    topic_system_prompt: str = "resource/prompt/generateTopics/system.txt"
    topic_user_prompt: str = "resource/prompt/generateTopics/user.txt"

    # Refer sampling rate; defaults to 0.5 and is constrained to the closed
    # range [0.0, 1.0].
    refer_sampling_rate: float = Field(0.5, ge=0.0, le=1.0)
|
||
|
||
|
||
class PosterConfig(BaseConfig):
    """Poster generation configuration."""

    # Two integers, default [900, 1200].
    # NOTE(review): presumably [width, height] in pixels -- confirm.
    target_size: List[int] = Field(default_factory=lambda: [900, 1200])

    # Additional-images flag; on by default.
    additional_images_enabled: bool = True

    # Template selection mode. Documented choices: random, business,
    # vibrant, original. Stored as a plain string, so not enforced here.
    template_selection: str = "random"

    # Template names available by default.
    available_templates: List[str] = Field(default_factory=lambda: ["original", "business", "vibrant"])

    # Default paths of the system and user prompt files for poster generation.
    poster_system_prompt: str = "resource/prompt/generatePoster/system.txt"
    poster_user_prompt: str = "resource/prompt/generatePoster/user.txt"
|
||
|
||
|
||
class ContentConfig(BaseConfig):
    """Content generation configuration."""

    # Content-judge flag; on by default.
    enable_content_judge: bool = True

    # Defaults to 5. Presumably the number of content items -- confirm.
    num: int = 5
    # Defaults to 1.
    variants_per_topic: int = 1

    # Length limits for title and content; defaults 30 and 500.
    # NOTE(review): unit (characters vs. tokens) is not visible here --
    # confirm with the consumer of these fields.
    max_title_length: int = 30
    max_content_length: int = 500