174 lines
5.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
配置模型
定义了项目中所有配置的数据类模型
使用pydantic进行数据验证和序列化
"""
from typing import Dict, Any, List, Optional, Type, TypeVar, Union, ClassVar
from pydantic import BaseModel, Field, model_validator, ConfigDict
# Generic type variable bound to BaseConfig (bound given by name because the
# class is defined just below). Not referenced in the visible part of this
# module -- presumably used by typed helpers elsewhere; verify before removing.
T = TypeVar("T", bound="BaseConfig")
class BaseConfig(BaseModel):
    """Base class for configuration models that can be updated in place from a dict."""

    model_config = ConfigDict(
        extra="allow",  # accept keys that are not declared as fields
        arbitrary_types_allowed=True,  # permit field types pydantic has no schema for
    )

    def update(self, new_data: Dict[str, Any]) -> None:
        """Recursively update this configuration from ``new_data``.

        For each ``key, value`` pair:

        * if the current attribute is a nested ``BaseConfig`` and ``value`` is a
          dict, the nested config is updated recursively (so a subclass's own
          ``update`` override is used for that nested object);
        * otherwise the attribute is replaced with ``value``.

        Only declared fields, extra fields already present on the instance,
        private attributes and settable properties are updated; any other key
        is silently ignored. In particular, keys that collide with methods or
        class attributes (e.g. ``"update"``, ``"to_dict"``, ``"model_config"``)
        are ignored instead of shadowing them on the instance.

        Values are assigned with ``setattr`` and are NOT re-validated, because
        ``validate_assignment`` is not enabled in ``model_config``.

        Args:
            new_data: Mapping of attribute names to their new values.
        """
        cls = type(self)
        for key, value in new_data.items():
            is_config_key = (
                key in cls.model_fields
                or key in (self.model_extra or {})
                or key in cls.__private_attributes__
                # Keep supporting settable properties a subclass may define.
                or isinstance(getattr(cls, key, None), property)
            )
            if not is_config_key:
                continue
            current_attr = getattr(self, key)
            if isinstance(current_attr, BaseConfig) and isinstance(value, dict):
                # Merge into the nested config rather than replacing it.
                current_attr.update(value)
            else:
                setattr(self, key, value)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the configuration, nested configs included, via ``model_dump``."""
        return self.model_dump()
class ReferItem(BaseConfig):
    """One reference-material entry."""

    path: str = Field(default="")
    sampling_rate: float = Field(default=1.0)
    # Pipeline stage the entry is used in; intended values are "topic",
    # "content" or "judge". Typed as plain str, so other values are not rejected.
    step: str = Field(default="")
class ReferConfig(BaseConfig):
    """Reference-material configuration, stored as a list of ReferItem entries."""

    refer_list: List[ReferItem] = Field(default_factory=list)

    def update(self, new_data: Dict[str, Any]) -> None:
        """Update from a dict, rebuilding ``refer_list`` from its raw entries.

        When ``new_data['refer_list']`` is a list, it replaces the current list
        wholesale: dict entries become ``ReferItem(**entry)``, ``ReferItem``
        instances are kept as they are, and entries of any other type are
        dropped. The remaining keys are then handed to ``BaseConfig.update``.
        If ``refer_list`` is missing or not a list, the whole dict is handed
        to ``BaseConfig.update`` unchanged.

        Args:
            new_data: Mapping of attribute names to their new values.
        """
        raw_entries = new_data.get('refer_list')
        if not isinstance(raw_entries, list):
            super().update(new_data)
            return

        rebuilt: List[ReferItem] = []
        for entry in raw_entries:
            if isinstance(entry, dict):
                rebuilt.append(ReferItem(**entry))
            elif isinstance(entry, ReferItem):
                rebuilt.append(entry)
            # Anything else is skipped.
        self.refer_list = rebuilt

        # Delegate every other key; refer_list has already been handled above.
        remaining = {k: v for k, v in new_data.items() if k != 'refer_list'}
        super().update(remaining)
class PathConfig(BaseConfig):
    """Configuration holding a single path."""

    path: str = Field(default="")
class PathListConfig(BaseConfig):
    """Configuration holding a list of paths (empty by default)."""

    # default_factory gives every instance its own fresh list.
    paths: List[str] = Field(default_factory=list)
class SamplingPathListConfig(BaseConfig):
    """A list of paths together with a sampling rate."""

    sampling_rate: float = Field(default=1.0)
    paths: List[str] = Field(default_factory=list)
class OutputConfig(BaseConfig):
    """Output directory names."""

    base_dir: str = Field(default="result")
    # The directories below are presumably relative to base_dir -- verify in callers.
    image_dir: str = Field(default="images")
    topic_dir: str = Field(default="topics")
    content_dir: str = Field(default="contents")
class ResourceConfig(BaseConfig):
    """Resource locations: directories, per-category path lists, references, outputs."""

    resource_dirs: List[str] = Field(default_factory=list)

    # Per-category path lists; each default is a fresh, empty PathListConfig.
    style: PathListConfig = Field(default_factory=PathListConfig)
    demand: PathListConfig = Field(default_factory=PathListConfig)
    object: PathListConfig = Field(default_factory=PathListConfig)
    product: PathListConfig = Field(default_factory=PathListConfig)

    # Reference material entries.
    refer: ReferConfig = Field(default_factory=ReferConfig)

    image: PathListConfig = Field(default_factory=PathListConfig)

    # Output directory settings.
    output_dir: OutputConfig = Field(default_factory=OutputConfig)
class SystemConfig(BaseConfig):
    """System-level runtime settings."""

    debug: bool = Field(default=False)
    log_level: str = Field(default="INFO")
    parallel_processing: bool = Field(default=True)
    # Presumably only relevant when parallel_processing is enabled -- confirm in callers.
    max_workers: int = Field(default=4)
class TopicConfig(BaseConfig):
    """Topic-selection settings."""

    date: str = Field(default="")  # plain string; no date format is enforced here
    num: int = Field(default=5)
    variants: int = Field(default=1)
class GenerateTopicConfig(BaseConfig):
    """Topic-generation settings: prompt file paths, model options, topic options."""

    topic_system_prompt: str = Field(default="resource/prompt/generateTopics/system.txt")
    topic_user_prompt: str = Field(default="resource/prompt/generateTopics/user.txt")
    # Free-form model options; their schema is not declared in this model.
    model: Dict[str, Any] = Field(default_factory=dict)
    topic: TopicConfig = Field(default_factory=TopicConfig)
class GenerateContentConfig(BaseConfig):
    """Content-generation settings, including the optional content-judging step."""

    # Prompt file paths for generation and for judging.
    content_system_prompt: str = Field(default="resource/prompt/generateContent/system.txt")
    content_user_prompt: str = Field(default="resource/prompt/generateContent/user.txt")
    judger_system_prompt: str = Field(default="resource/prompt/judgeContent/system.txt")
    judger_user_prompt: str = Field(default="resource/prompt/judgeContent/user.txt")

    enable_content_judge: bool = Field(default=True)
    refer_sampling_rate: float = Field(default=1.0)

    # Free-form model options for the generator and the judge; schema not declared here.
    model: Dict[str, Any] = Field(default_factory=dict)
    judger_model: Dict[str, Any] = Field(default_factory=dict)
class AIModelConfig(BaseConfig):
    """AI model connection, sampling and prompt settings."""

    model: str = "qwq-plus"
    api_url: str = ""
    # repr=False keeps the key out of repr()/str(), so printing or logging the
    # config does not leak the secret. It is still included by model_dump().
    api_key: str = Field(default="", repr=False)
    temperature: float = 0.7
    top_p: float = 0.5
    presence_penalty: float = 1.2
    timeout: int = 60  # unit presumably seconds -- TODO confirm against the client
    max_retries: int = 3
    topic_system_prompt: str = "resource/prompt/generateTopics/system.txt"
    topic_user_prompt: str = "resource/prompt/generateTopics/user.txt"
    # Sampling rate constrained to [0.0, 1.0] when the model is validated.
    refer_sampling_rate: float = Field(0.5, ge=0.0, le=1.0)
class PosterConfig(BaseConfig):
    """Poster-generation settings."""

    # Two-element size list; presumably [width, height] -- verify in the renderer.
    target_size: List[int] = Field(default_factory=lambda: [900, 1200])
    additional_images_enabled: bool = Field(default=True)
    # Intended values: random, business, vibrant, original (not validated).
    template_selection: str = Field(default="random")
    available_templates: List[str] = Field(default_factory=lambda: ["original", "business", "vibrant"])
    poster_system_prompt: str = Field(default="resource/prompt/generatePoster/system.txt")
    poster_user_prompt: str = Field(default="resource/prompt/generatePoster/user.txt")
class ContentConfig(BaseConfig):
    """Content settings: judging toggle, counts and length limits."""

    enable_content_judge: bool = Field(default=True)
    num: int = Field(default=5)
    variants_per_topic: int = Field(default=1)
    # Length limits; unit (characters vs. tokens) is not stated here -- TODO confirm.
    max_title_length: int = Field(default=30)
    max_content_length: int = Field(default=500)