161 lines
5.0 KiB
Python
161 lines
5.0 KiB
Python
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
配置模型

定义了项目中所有配置的数据类模型

Configuration models: dataclass definitions for every configuration
section used by the project.
"""

from dataclasses import asdict, dataclass, field, fields, is_dataclass
from typing import Any, Dict, List, Optional, get_args, get_origin

from pydantic import BaseModel, Field


# Base configuration class providing shared helpers for the config dataclasses.
class BaseConfig:
    """Mixin for config dataclasses that can be updated from plain dicts.

    Subclasses are expected to be decorated with ``@dataclass``: ``to_dict``
    relies on :func:`dataclasses.asdict`, and ``update`` reads the dataclass
    field annotations to rebuild lists of nested dataclasses.
    """

    def update(self, new_data: Dict[str, Any]):
        """Recursively update this config in place from a dictionary.

        For each ``key, value`` pair in ``new_data``:

        * keys that are not attributes of this object are ignored;
        * if the current attribute is a ``BaseConfig`` and ``value`` is a
          ``dict``, the nested config is updated recursively (merged, not
          replaced);
        * if the attribute is a dataclass field annotated ``List[X]`` where
          ``X`` is a dataclass (e.g. ``ReferConfig.refer_list``) and
          ``value`` is a list, every ``dict`` element is turned into
          ``X(**element)``; elements that are not dicts (for instance
          ready-made ``X`` instances) are kept unchanged;
        * otherwise the attribute is replaced by ``value``.

        Args:
            new_data: Mapping of attribute names to new values. ``None`` or
                an empty mapping leaves the config unchanged.

        Raises:
            TypeError: If a dict element of a ``List[X]`` field does not
                match ``X``'s constructor signature.
        """
        if not new_data:
            return

        # Element types of List[<dataclass>] fields, derived from the field
        # annotations rather than hard-coding field names such as "refer_list".
        list_item_types: Dict[str, type] = {}
        if is_dataclass(self):
            for f in fields(self):
                if get_origin(f.type) is not list:
                    continue
                type_args = get_args(f.type)
                if (len(type_args) == 1
                        and isinstance(type_args[0], type)
                        and is_dataclass(type_args[0])):
                    list_item_types[f.name] = type_args[0]

        for key, value in new_data.items():
            if not hasattr(self, key):
                continue
            current_attr = getattr(self, key)
            if isinstance(current_attr, BaseConfig) and isinstance(value, dict):
                # Merge into the existing nested config object.
                current_attr.update(value)
            elif key in list_item_types and isinstance(value, list):
                item_type = list_item_types[key]
                setattr(self, key, [
                    item_type(**item) if isinstance(item, dict) else item
                    for item in value
                ])
            else:
                setattr(self, key, value)

    def to_dict(self) -> Dict[str, Any]:
        """Return the config, including nested dataclasses, as a plain dict."""
        return asdict(self)


@dataclass
class GenerateContentConfig(BaseConfig):
    """Settings for the content-generation step.

    Holds the prompt files used to generate content and to judge it,
    plus a switch for the judging pass and a reference sampling rate.
    """

    # Prompt files for generating content.
    content_system_prompt: str = "resource/prompt/generateContent/contentSystem.txt"
    content_user_prompt: str = "resource/prompt/generateContent/user.txt"

    # Prompt files for judging generated content.
    judger_system_prompt: str = "resource/prompt/judgeContent/system.txt"
    judger_user_prompt: str = "resource/prompt/judgeContent/user.txt"

    # Switch for the content-judging pass.
    enable_content_judge: bool = True

    # Sampling rate for reference material (presumably a fraction in [0, 1] — confirm with caller).
    refer_sampling_rate: float = 1.0


@dataclass
class AIModelConfig(BaseConfig):
    """AI model configuration.

    Holds the model name, API endpoint and key, sampling parameters,
    timeout/retry settings, the prompt files for topic generation and a
    reference sampling rate.
    """

    model: str = "qwq-plus"
    api_url: str = ""
    api_key: str = ""
    temperature: float = 0.7
    top_p: float = 0.5
    presence_penalty: float = 1.2
    # Timeout value; units are not defined here (presumably seconds — confirm with the API client).
    timeout: int = 60
    max_retries: int = 3
    topic_system_prompt: str = "resource/prompt/generateTopics/system.txt"
    topic_user_prompt: str = "resource/prompt/generateTopics/user.txt"
    # Intended range: 0.0 <= refer_sampling_rate <= 1.0.
    # This was previously declared as pydantic ``Field(0.5, ge=0.0, le=1.0)``.
    # That has no effect in a stdlib dataclass: the default became a pydantic
    # FieldInfo object instead of 0.5, and the bounds were never checked.
    refer_sampling_rate: float = 0.5

    class Config:
        # Pydantic-style inner config left over from a BaseModel version of
        # this class; it has no effect on a stdlib dataclass.
        pass


@dataclass
class PosterConfig(BaseConfig):
    """Settings for poster generation."""

    # Two-element target size, default [900, 1200]
    # (presumably [width, height] in pixels — confirm with the renderer).
    target_size: List[int] = field(default_factory=lambda: [900, 1200])

    additional_images_enabled: bool = True

    # Template choice: "random", "business", "vibrant" or "original".
    template_selection: str = "random"

    # Template names that can be picked from.
    available_templates: List[str] = field(
        default_factory=lambda: ["original", "business", "vibrant"]
    )


@dataclass
class ContentConfig(BaseConfig):
    """Content quantity and length settings."""

    # Switch for the content-judging pass.
    enable_content_judge: bool = True

    # How many items to produce and how many variants per topic.
    num: int = 5
    variants_per_topic: int = 1

    # Length limits; units are not defined here (presumably characters).
    max_title_length: int = 30
    max_content_length: int = 500


@dataclass
class PathConfig(BaseConfig):
    """Configuration holding a single path (empty string by default)."""

    path: str = ""


@dataclass
class PathListConfig(BaseConfig):
    """Configuration holding a list of paths (a fresh empty list by default)."""

    paths: List[str] = field(default_factory=list)


@dataclass
class SamplingPathListConfig(BaseConfig):
    """A list of paths together with a sampling rate.

    The sampling rate defaults to 1.0; presumably a fraction of the paths
    to use (confirm with caller).
    """

    sampling_rate: float = 1.0
    paths: List[str] = field(default_factory=list)


@dataclass
class ReferItem:
    """A single reference entry: path, sampling rate and the stage that uses it."""

    path: str = ""
    sampling_rate: float = 1.0
    # Stage that uses this reference; expected values: "topic", "content", "judge".
    step: str = ""


@dataclass
class ReferConfig(BaseConfig):
    """Reference configuration, stored as a list of ``ReferItem`` entries."""

    refer_list: List[ReferItem] = field(default_factory=list)


@dataclass
class OutputConfig(BaseConfig):
    """Output directory names.

    ``base_dir`` defaults to "result"; the other entries are presumably
    sub-directories under it (confirm with the code that builds paths).
    """

    base_dir: str = "result"
    image_dir: str = "images"
    topic_dir: str = "topics"
    content_dir: str = "contents"


@dataclass
class ResourceConfig(BaseConfig):
    """Resource configuration.

    Groups the resource directories, several path lists, the reference
    config and the output directory settings. Each nested field is its own
    ``BaseConfig`` instance, created fresh per ``ResourceConfig``.
    """

    resource_dirs: List[str] = field(default_factory=list)

    # Path-list resources.
    style: PathListConfig = field(default_factory=PathListConfig)
    demand: PathListConfig = field(default_factory=PathListConfig)
    # NOTE: field name shadows the builtin ``object`` inside this class body;
    # kept as-is because it is part of the config schema.
    object: PathListConfig = field(default_factory=PathListConfig)
    product: PathListConfig = field(default_factory=PathListConfig)

    refer: ReferConfig = field(default_factory=ReferConfig)
    image: PathListConfig = field(default_factory=PathListConfig)
    output_dir: OutputConfig = field(default_factory=OutputConfig)


@dataclass
class SystemConfig(BaseConfig):
    """System-wide runtime settings."""

    debug: bool = False
    # Log level name (presumably passed to the logging module — confirm).
    log_level: str = "INFO"
    # Parallelism switch and worker count.
    parallel_processing: bool = True
    max_workers: int = 4


@dataclass
class TopicConfig(BaseConfig):
    """Topic selection settings."""

    # Date as a string; the expected format is not defined here.
    date: str = ""
    num: int = 5
    variants: int = 1


@dataclass
class GenerateTopicConfig(BaseConfig):
    """Settings for the topic-generation step."""

    # Prompt files for topic generation.
    topic_system_prompt: str = "resource/prompt/generateTopics/system.txt"
    topic_user_prompt: str = "resource/prompt/generateTopics/user.txt"

    # Free-form model settings. Being a plain dict (not a BaseConfig),
    # BaseConfig.update replaces it wholesale instead of merging.
    model: Dict[str, Any] = field(default_factory=dict)

    # Nested topic settings; BaseConfig.update merges dicts into it.
    topic: TopicConfig = field(default_factory=TopicConfig)