161 lines
5.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
配置模型
定义了项目中所有配置的数据类模型
"""
from dataclasses import dataclass, field, fields, asdict
from typing import Dict, Any, List, Optional
from pydantic import BaseModel, Field
# Base configuration class providing shared helpers.
class BaseConfig:
    """Base class for configs that can be updated from a dictionary.

    Intended to be subclassed by ``@dataclass`` types: ``update`` restricts
    itself to the dataclass's declared fields and ``to_dict`` relies on
    ``dataclasses.asdict``.
    """

    def update(self, new_data: Dict[str, Any]):
        """Recursively update this config in place from ``new_data``.

        Rules, applied per key:

        * Keys that are not fields of this config are ignored. For a dataclass
          instance only declared fields are accepted, so a key such as
          ``"update"`` or ``"to_dict"`` cannot overwrite a method. For a
          non-dataclass instance, any existing attribute is accepted.
        * If the current value is a ``BaseConfig`` and the new value is a dict,
          the nested config is updated recursively instead of being replaced.
        * A list stored under ``refer_list`` is converted element-wise to
          ``ReferItem`` objects. Elements that are already ``ReferItem``
          instances are kept unchanged.
        * Any other value simply replaces the current one.

        Args:
            new_data: Mapping of attribute name to new value.
        """
        try:
            allowed = {f.name for f in fields(self)}
        except TypeError:
            # Not a dataclass instance: fall back to plain attribute lookup.
            allowed = None

        for key, value in new_data.items():
            if allowed is not None:
                if key not in allowed:
                    continue
            elif not hasattr(self, key):
                continue

            current_attr = getattr(self, key)
            if isinstance(current_attr, BaseConfig) and isinstance(value, dict):
                # Merge into the nested config rather than replacing it.
                current_attr.update(value)
            elif key == 'refer_list' and isinstance(value, list):
                # Build ReferItem objects from dicts. ReferItem is defined later
                # in this module and is resolved at call time.
                setattr(self, key, [
                    item if isinstance(item, ReferItem) else ReferItem(**item)
                    for item in value
                ])
            else:
                setattr(self, key, value)

    def to_dict(self) -> Dict[str, Any]:
        """Return this config, including nested dataclasses, as a plain dict.

        Raises:
            TypeError: if the instance is not a dataclass instance.
        """
        return asdict(self)
@dataclass
class GenerateContentConfig(BaseConfig):
    """Content-generation configuration.

    Holds the generator and judge prompt-file paths, plus the
    ``enable_content_judge`` flag and ``refer_sampling_rate``.
    """

    # Prompt file locations (relative path strings).
    content_system_prompt: str = field(
        default="resource/prompt/generateContent/contentSystem.txt"
    )
    content_user_prompt: str = field(
        default="resource/prompt/generateContent/user.txt"
    )
    judger_system_prompt: str = field(
        default="resource/prompt/judgeContent/system.txt"
    )
    judger_user_prompt: str = field(
        default="resource/prompt/judgeContent/user.txt"
    )

    # Flags / numeric settings.
    enable_content_judge: bool = field(default=True)
    refer_sampling_rate: float = field(default=1.0)
@dataclass
class AIModelConfig(BaseConfig):
    """AI model configuration: model name, endpoint, sampling and retry
    settings, and topic prompt-file paths."""

    model: str = "qwq-plus"
    api_url: str = ""
    # Excluded from repr() so the secret does not leak when the config object
    # is printed, logged, or shown in a traceback.
    api_key: str = field(default="", repr=False)
    temperature: float = 0.7
    top_p: float = 0.5
    presence_penalty: float = 1.2
    timeout: int = 60
    max_retries: int = 3
    topic_system_prompt: str = "resource/prompt/generateTopics/system.txt"
    topic_user_prompt: str = "resource/prompt/generateTopics/user.txt"
    # Plain default. pydantic's Field() is ignored by stdlib dataclasses, so
    # the former `Field(0.5, ge=0.0, le=1.0)` made the default a pydantic
    # FieldInfo object instead of 0.5. The [0.0, 1.0] bound it expressed is
    # enforced in __post_init__. Note that BaseConfig.update() assigns via
    # setattr and therefore bypasses this check.
    refer_sampling_rate: float = 0.5

    class Config:
        # Leftover pydantic-style options hook; stdlib dataclasses ignore it.
        # Kept so any external reference to AIModelConfig.Config still works.
        pass

    def __post_init__(self):
        """Validate constructor arguments.

        Raises:
            ValueError: if ``refer_sampling_rate`` is outside [0.0, 1.0].
        """
        if not 0.0 <= self.refer_sampling_rate <= 1.0:
            raise ValueError(
                "refer_sampling_rate must be within [0.0, 1.0], "
                f"got {self.refer_sampling_rate!r}"
            )
@dataclass
class PosterConfig(BaseConfig):
    """Poster-generation configuration."""

    # Two-element size list; presumably (width, height) in pixels — TODO confirm.
    target_size: List[int] = field(
        default_factory=lambda: [
            900,
            1200,
        ]
    )
    additional_images_enabled: bool = field(default=True)
    # One of: random, business, vibrant, original.
    template_selection: str = field(default="random")
    available_templates: List[str] = field(
        default_factory=lambda: [
            "original",
            "business",
            "vibrant",
        ]
    )
@dataclass
class ContentConfig(BaseConfig):
    """Content-generation configuration: judge toggle, counts, and length limits."""

    enable_content_judge: bool = field(default=True)
    # Counts.
    num: int = field(default=5)
    variants_per_topic: int = field(default=1)
    # Length limits; the unit (characters vs. tokens) is not shown here — confirm.
    max_title_length: int = field(default=30)
    max_content_length: int = field(default=500)
@dataclass
class PathConfig(BaseConfig):
    """Configuration holding a single path string (empty by default)."""

    path: str = field(default="")
@dataclass
class PathListConfig(BaseConfig):
    """Configuration holding a list of path strings.

    Each instance gets its own fresh, empty list by default.
    """

    paths: List[str] = field(
        default_factory=list,
    )
@dataclass
class SamplingPathListConfig(BaseConfig):
    """Path-list configuration paired with a sampling rate (default 1.0)."""

    sampling_rate: float = field(default=1.0)
    # Fresh empty list per instance.
    paths: List[str] = field(
        default_factory=list,
    )
@dataclass
class ReferItem:
    """A single refer entry: a path, its sampling rate, and the stage it targets."""

    path: str = field(default="")
    sampling_rate: float = field(default=1.0)
    # Allowed values: "topic", "content", "judge" — the stage this entry is used in.
    step: str = field(default="")
@dataclass
class ReferConfig(BaseConfig):
    """Refer configuration, stored as a list of ``ReferItem`` entries.

    ``BaseConfig.update`` converts a ``refer_list`` list of dicts into
    ``ReferItem`` objects.
    """

    refer_list: List[ReferItem] = field(
        default_factory=list,
    )
@dataclass
class OutputConfig(BaseConfig):
    """Output configuration: a base directory name plus per-artifact subdirectory names."""

    base_dir: str = field(default="result")
    # Subdirectory names; how they are joined with base_dir is decided by callers.
    image_dir: str = field(default="images")
    topic_dir: str = field(default="topics")
    content_dir: str = field(default="contents")
@dataclass
class ResourceConfig(BaseConfig):
    """Resource configuration grouping the per-category path lists.

    Nested fields are ``BaseConfig`` instances, so ``BaseConfig.update``
    merges dict values into them rather than replacing them.
    """

    resource_dirs: List[str] = field(default_factory=list)

    # Per-category path lists (each instance gets its own object).
    style: PathListConfig = field(default_factory=PathListConfig)
    demand: PathListConfig = field(default_factory=PathListConfig)
    # NOTE: this field name shadows the builtin `object` inside the class body only.
    object: PathListConfig = field(default_factory=PathListConfig)
    product: PathListConfig = field(default_factory=PathListConfig)

    refer: ReferConfig = field(default_factory=ReferConfig)
    image: PathListConfig = field(default_factory=PathListConfig)
    output_dir: OutputConfig = field(default_factory=OutputConfig)
@dataclass
class SystemConfig(BaseConfig):
    """System-level configuration: debug flag, log level, and parallelism settings."""

    debug: bool = field(default=False)
    log_level: str = field(default="INFO")
    parallel_processing: bool = field(default=True)
    max_workers: int = field(default=4)
@dataclass
class TopicConfig(BaseConfig):
    """Topic-selection configuration."""

    # Free-form string; expected date format is not shown here — confirm with callers.
    date: str = field(default="")
    num: int = field(default=5)
    variants: int = field(default=1)
@dataclass
class GenerateTopicConfig(BaseConfig):
    """Topic-generation configuration: prompt-file paths, model settings, and topic options."""

    # Prompt file locations (relative path strings).
    topic_system_prompt: str = field(
        default="resource/prompt/generateTopics/system.txt"
    )
    topic_user_prompt: str = field(
        default="resource/prompt/generateTopics/user.txt"
    )
    # Plain dict, not a BaseConfig: BaseConfig.update replaces it wholesale
    # instead of merging keys.
    model: Dict[str, Any] = field(default_factory=dict)
    topic: TopicConfig = field(default_factory=TopicConfig)