487 lines
18 KiB
Python
Raw Normal View History

2025-07-31 15:35:23 +08:00
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Configuration Models
配置数据模型 - 定义所有算法组件的配置结构
"""
from typing import Dict, List, Any, Optional, Union
from pydantic import BaseModel, Field, validator
from pathlib import Path
import os
class BaseConfig(BaseModel):
    """Common base class for configuration models.

    Fields that are not declared on a model are kept instead of rejected,
    and assigning to an attribute after construction re-runs validation.
    """

    class Config:
        # Re-validate values assigned to attributes after construction
        validate_assignment = True
        # Accept and retain undeclared fields
        extra = "allow"
class TaskModelConfig(BaseModel):
    """Per-task overrides for model sampling parameters.

    Every field defaults to None, meaning "use the value from the global
    AIModelConfig" when merged via ``merge_with_default``.
    """

    temperature: Optional[float] = Field(default=None, ge=0.0, le=2.0)
    top_p: Optional[float] = Field(default=None, ge=0.0, le=1.0)
    presence_penalty: Optional[float] = Field(default=None, ge=-2.0, le=2.0)
    frequency_penalty: Optional[float] = Field(default=None, ge=-2.0, le=2.0)
    max_tokens: Optional[int] = Field(default=None, gt=0)
    enable_stream: Optional[bool] = None

    def merge_with_default(self, default_config: "AIModelConfig") -> Dict[str, Any]:
        """Overlay this task's overrides onto ``default_config``.

        For each parameter the task's value wins when it is not None;
        otherwise the value is read from ``default_config``. Parameters that
        are None on both sides are left out of the result.

        Args:
            default_config: Global model configuration supplying fallbacks.

        Returns:
            Parameter name -> effective value, in the order temperature,
            top_p, presence_penalty, frequency_penalty, max_tokens,
            enable_stream (None values omitted).
        """
        merged: Dict[str, Any] = {}
        for name in (
            "temperature",
            "top_p",
            "presence_penalty",
            "frequency_penalty",
            "max_tokens",
            "enable_stream",
        ):
            value = getattr(self, name)
            if value is None:
                value = getattr(default_config, name)
            if value is not None:
                merged[name] = value
        return merged
class AIModelConfig(BaseConfig):
    """AI model configuration with support for per-task parameter overrides."""

    # --- Connection settings ---
    model: str = "qwq-plus"
    api_url: str = ""
    api_key: str = ""
    timeout: int = Field(60, gt=0)
    max_retries: int = Field(3, ge=0)

    # --- Default model parameters (used where a task has no override) ---
    temperature: float = Field(0.7, ge=0.0, le=2.0)
    top_p: float = Field(0.5, ge=0.0, le=1.0)
    presence_penalty: float = Field(1.2, ge=-2.0, le=2.0)
    frequency_penalty: float = Field(0.0, ge=-2.0, le=2.0)
    max_tokens: Optional[int] = Field(None, gt=0)

    # --- Streaming output ---
    enable_stream: bool = True
    stream_chunk_size: int = 1024

    # --- Per-task parameter overrides, keyed by task name ---
    task_configs: Dict[str, TaskModelConfig] = Field(default_factory=dict)

    def get_task_config(self, task_name: str) -> Dict[str, Any]:
        """Return the effective model parameters for a task.

        Values from ``task_configs[task_name]`` take precedence. Anything the
        task does not override falls back to this config's defaults.
        Parameters whose effective value is None (e.g. ``max_tokens`` when
        unset) are omitted from the result. This holds whether or not the
        task has an override entry, so callers always see the same dict
        shape.

        Args:
            task_name: Task name (e.g. topic_generation, content_generation,
                content_judging, poster_generation).

        Returns:
            Dict of model parameter name -> value.
        """
        task_config = self.task_configs.get(task_name)
        if task_config is None:
            # An empty override inherits every default. Both paths therefore
            # share the same merge and None-dropping rules, instead of the
            # no-override path leaking ``max_tokens: None``.
            task_config = TaskModelConfig()
        return task_config.merge_with_default(self)

    def set_task_config(self, task_name: str, **params) -> None:
        """Set (replace) the model parameter overrides for a task.

        Args:
            task_name: Task name.
            **params: Keyword arguments forwarded to ``TaskModelConfig``
                (temperature, top_p, presence_penalty, frequency_penalty,
                max_tokens, enable_stream). Out-of-range values fail that
                model's field constraints. NOTE(review): unrecognized keys are
                presumably ignored under pydantic's default ``extra``
                handling; verify if strictness is needed.
        """
        self.task_configs[task_name] = TaskModelConfig(**params)
class PromptConfig(BaseConfig):
    """Prompt configuration: system/user prompt pairs for each pipeline stage.

    Makes prompt content configurable rather than hard-coded.
    """

    # Prompt pairs per stage. Supplied values (at construction, and on
    # assignment since BaseConfig enables validate_assignment) are normalized
    # by ``validate_prompts`` to always carry "system" and "user".
    topic_generation: Dict[str, str] = Field(default_factory=dict)
    content_generation: Dict[str, str] = Field(default_factory=dict)
    content_judging: Dict[str, str] = Field(default_factory=dict)
    poster_generation: Dict[str, str] = Field(default_factory=dict)

    # Optional custom prompt template directory
    template_directory: Optional[str] = None

    @validator("*", pre=True)
    def validate_prompts(cls, v):
        """Normalize dict-valued fields so they contain "system" and "user".

        Applied to every declared field before type coercion. Non-dict values
        (e.g. ``template_directory``) are returned unchanged. For dicts, a
        missing "system" or "user" entry is filled with "", and every other
        key is preserved.

        NOTE(review): pydantic v1 does not run validators on defaults unless
        ``always=True``, so a field that is never supplied presumably stays
        ``{}``. Confirm whether callers rely on that.
        """
        if not isinstance(v, dict):
            return v
        # Work on a copy so the caller's dict is not mutated. Keep extra keys
        # rather than discarding them when a required key is missing.
        normalized = dict(v)
        normalized.setdefault("system", "")
        normalized.setdefault("user", "")
        return normalized
class ContentGenerationConfig(BaseConfig):
    """Settings for topic generation, content generation and judging."""

    # --- Topic generation ---
    topic_count: int = Field(default=5, ge=1, le=20)
    topic_date_range: Optional[str] = Field(default=None)
    topic_style: str = Field(default="default")

    # --- Content generation ---
    # Must be "short", "medium" or "long" (enforced by the regex)
    content_length: str = Field(default="medium", regex="^(short|medium|long)$")
    content_style: str = Field(default="default")
    enable_auto_judge: bool = Field(default=True)

    # --- Judging ---
    judge_enabled: bool = Field(default=True)
    judge_threshold: float = Field(default=0.7, ge=0.0, le=1.0)

    # --- Reference material ---
    refer_sampling_rate: float = Field(default=0.5, ge=0.0, le=1.0)
    enable_refer_content: bool = Field(default=True)
class PosterGenerationConfig(BaseConfig):
    """Poster generation settings: sizes, templates, fonts, colors, effects, output.

    Moves template and resource choices out of hard-coded values.
    """

    # --- Canvas sizes (two-int lists, presumably [width, height]; verify) ---
    default_size: List[int] = Field(default_factory=lambda: [900, 1200])
    supported_sizes: List[List[int]] = Field(
        default_factory=lambda: [[900, 1200], [1080, 1080], [750, 1334]]
    )

    # --- Templates ---
    default_template: str = Field(default="vibrant")
    available_templates: List[str] = Field(
        default_factory=lambda: ["vibrant", "business", "collage"]
    )
    # Must be "manual", "random" or "auto" (enforced by the regex)
    template_selection_mode: str = Field(default="manual", regex="^(manual|random|auto)$")

    # --- Fonts (configurable instead of hard-coded font paths) ---
    font_configs: Dict[str, Dict[str, Any]] = Field(default_factory=dict)
    default_font_family: str = Field(default="default")

    # --- Color themes: theme name -> key -> list of integer lists
    # (presumably color components; verify against the renderer) ---
    color_themes: Dict[str, Dict[str, List[List[int]]]] = Field(default_factory=dict)
    default_color_theme: str = Field(default="ocean_deep")

    # --- Visual effects ---
    glass_effect: Dict[str, Any] = Field(default_factory=dict)
    text_effects: Dict[str, Any] = Field(default_factory=dict)

    # --- Output ---
    # Must be "PNG", "JPEG" or "WEBP" (enforced by the regex)
    output_format: str = Field(default="PNG", regex="^(PNG|JPEG|WEBP)$")
    output_quality: int = Field(default=95, ge=1, le=100)
    generate_thumbnail: bool = Field(default=True)
    thumbnail_size: List[int] = Field(default_factory=lambda: [300, 400])
class DocumentProcessingConfig(BaseConfig):
    """Document ingestion, text extraction and web-scraping settings."""

    # File extensions accepted for processing (including the leading dot)
    supported_extensions: List[str] = Field(
        default_factory=lambda: [".txt", ".md", ".pdf", ".docx", ".doc", ".xlsx", ".xls"]
    )

    # --- Extraction ---
    # Maximum file size in bytes (default 50 MiB)
    max_file_size: int = Field(default=50 * 1024 ** 2)
    encoding: str = Field(default="utf-8")
    extract_images: bool = Field(default=True)
    extract_tables: bool = Field(default=True)

    # --- Post-processing of extracted text ---
    enable_content_cleaning: bool = Field(default=True)
    remove_extra_whitespace: bool = Field(default=True)
    normalize_text: bool = Field(default=True)

    # --- Web scraping ---
    enable_web_scraping: bool = Field(default=True)
    scraping_config: Dict[str, Any] = Field(default_factory=dict)
class OutputConfig(BaseConfig):
    """Output location and file-naming settings.

    Makes result file names and directories configurable instead of hard-coded.
    """

    # --- Output directory layout ---
    base_output_directory: str = Field(default="result")
    enable_timestamped_folders: bool = Field(default=True)
    # Per-run folder name; "{timestamp}" is presumably filled in by the caller
    folder_name_pattern: str = Field(default="run_{timestamp}")

    # --- File name per artifact; "{stage}" is presumably filled in by the caller ---
    file_naming: Dict[str, str] = Field(default_factory=lambda: dict(
        topics="topics.json",
        content="article.json",
        poster="poster.png",
        raw_response="{stage}_raw_response.txt",
        system_prompt="{stage}_system_prompt.txt",
        user_prompt="{stage}_user_prompt.txt",
    ))

    # --- Which artifacts are saved alongside results ---
    save_raw_responses: bool = Field(default=True)
    save_prompts: bool = Field(default=True)
    save_metadata: bool = Field(default=True)

    # --- Compression and backups ---
    enable_compression: bool = Field(default=False)
    backup_count: int = Field(default=5)
class ResourceConfig(BaseConfig):
    """Locations of external resources (fonts, prompts, images, static files).

    All directory fields are optional. Relative values are resolved against
    ``resource_base_directory`` or looked up relative to the current working
    directory and a few of its parents (see the lookup methods below).
    """

    # Base directory used to resolve relative resource paths
    resource_base_directory: Optional[str] = None

    # --- Font resources ---
    font_directory: Optional[str] = None
    font_files: Dict[str, str] = Field(default_factory=dict)

    # --- Prompt template directory ---
    prompt_template_directory: Optional[str] = None

    # --- Image resources ---
    image_assets_directory: Optional[str] = None
    default_background_images: List[str] = Field(default_factory=list)

    # --- Static files (JavaScript helpers the crawler module depends on) ---
    static_files_directory: Optional[str] = None
    # Logical JS name -> file name inside the static directory
    required_js_files: Dict[str, str] = Field(default_factory=lambda: {
        "xhs_xs_xsc_56": "xhs_xs_xsc_56.js",
        "xhs_xray": "xhs_xray.js",
        "xhs_creator_xs": "xhs_creator_xs.js",
        "xhs_xray_pack1": "xhs_xray_pack1.js",
        "xhs_xray_pack2": "xhs_xray_pack2.js"
    })

    def get_js_file_path(self, js_name: str) -> Optional[str]:
        """Resolve the on-disk path of a required JavaScript file.

        Candidates are tried in this order:

        1. ``static_files_directory`` resolved via ``resolve_path`` (only
           when configured);
        2. ``static/``, ``../static/`` and ``../../static/`` relative to the
           current working directory;
        3. ``../<static_files_directory>`` (only when configured).

        Args:
            js_name: Key into ``required_js_files``.

        Returns:
            Absolute path of the first candidate that is an existing file,
            or None if ``js_name`` is unknown or no candidate exists.
        """
        if js_name not in self.required_js_files:
            return None
        filename = self.required_js_files[js_name]

        candidates: List[Path] = []
        if self.static_files_directory:
            candidates.append(Path(self.resolve_path(self.static_files_directory)) / filename)
        candidates.extend([
            Path("static") / filename,
            Path("../static") / filename,
            Path("../../static") / filename,
        ])
        if self.static_files_directory:
            # Join with "/" rather than formatting f"../{dir}": an absolute
            # directory then stays absolute instead of becoming "../abs/...".
            candidates.append(Path("..") / self.static_files_directory / filename)

        for candidate in candidates:
            # is_file(): a directory with the JS file's name is not a match
            if candidate.is_file():
                return str(candidate.resolve())
        return None

    def resolve_path(self, relative_path: str) -> str:
        """Resolve ``relative_path`` to an absolute path string.

        Absolute inputs are returned unchanged. Relative inputs are joined
        onto ``resource_base_directory`` when it is set, otherwise taken
        relative to the current working directory. Either way the result is
        made absolute and normalized via ``Path.resolve``.

        Args:
            relative_path: Path to resolve; may be empty.

        Returns:
            The absolute path as a string, or "" when ``relative_path`` is
            empty.
        """
        if not relative_path:
            return ""

        path = Path(relative_path)
        if path.is_absolute():
            return str(path)

        if self.resource_base_directory:
            path = Path(self.resource_base_directory) / path
        # Resolve in both branches so a relative base directory still yields
        # an absolute path, as documented.
        return str(path.resolve())

    def find_resource_directory(self, resource_type: str) -> Optional[Path]:
        """Locate the directory for a resource type.

        Candidates are tried in this order:

        1. ``resource_base_directory / <configured dir>`` (when both are set);
        2. ``<configured dir>`` relative to the cwd and up to three parent
           levels (when configured);
        3. conventional locations, ``resource/prompt`` for "prompts" and
           ``static`` for "static", relative to the cwd and up to three
           parent levels. These are searched even when no directory is
           configured for the type.

        Args:
            resource_type: One of "fonts", "prompts", "images", "static".

        Returns:
            The resolved path of the first existing directory, or None if
            ``resource_type`` is unknown or nothing matches.
        """
        configured_dirs = {
            "fonts": self.font_directory,
            "prompts": self.prompt_template_directory,
            "images": self.image_assets_directory,
            "static": self.static_files_directory,
        }
        # Conventional fallback locations for types that have one
        fallback_dirs = {
            "prompts": Path("resource/prompt"),
            "static": Path("static"),
        }
        # cwd, then up to three parent levels
        levels = (Path("."), Path(".."), Path("../.."), Path("../../.."))

        search_paths: List[Path] = []

        relative_dir = configured_dirs.get(resource_type)
        if relative_dir:
            if self.resource_base_directory:
                search_paths.append(Path(self.resource_base_directory) / relative_dir)
            search_paths.extend(level / relative_dir for level in levels)

        fallback = fallback_dirs.get(resource_type)
        if fallback is not None:
            search_paths.extend(level / fallback for level in levels)

        for path in search_paths:
            if path.is_dir():
                return path.resolve()
        return None

    def validate_js_files(self) -> Dict[str, bool]:
        """Check which required JavaScript files can be located.

        Returns:
            Mapping of each ``required_js_files`` key to whether
            ``get_js_file_path`` found a file for it.
        """
        return {
            js_name: self.get_js_file_path(js_name) is not None
            for js_name in self.required_js_files
        }
class WebCrawlingConfig(BaseModel):
    """Web crawler settings: pacing, browser identity, proxy, retries, filtering."""

    # --- Request pacing (max_requests: 0 means unlimited) ---
    request_interval: float = Field(1.0, ge=0.1, description="请求间隔(秒)")
    max_requests: int = Field(100, ge=0, description="最大请求数0表示无限制")
    request_timeout: float = Field(30.0, ge=1.0, description="请求超时时间")

    # --- Browser identity ---
    user_agent: str = Field(
        "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36",
        description="用户代理字符串",
    )

    # --- Proxy ---
    enable_proxy: bool = Field(False, description="是否启用代理")
    proxy_config: Optional[Dict[str, str]] = Field(None, description="代理配置")

    # --- Error handling ---
    max_retries: int = Field(3, ge=0, description="最大重试次数")
    retry_delay: float = Field(2.0, ge=0.1, description="重试延迟")

    # --- Content length filtering ---
    min_content_length: int = Field(50, ge=0, description="最小内容长度")
    max_content_length: int = Field(10000, ge=100, description="最大内容长度")

    # --- Xiaohongshu (XHS) search options ---
    xhs_search_config: Dict[str, Any] = Field(
        default_factory=lambda: dict(
            default_sort_type=0,  # default sort type
            default_note_type=0,  # default note type
            enable_image_download=True,
            enable_video_download=False,
            search_result_limit=50,
        ),
        description="小红书搜索配置",
    )
class KeywordAnalysisConfig(BaseModel):
    """Keyword extraction, scoring, categorization and suggestion settings."""

    # --- Limits ---
    max_keywords: int = Field(20, ge=1, le=100, description="最大关键词数量")
    min_keyword_length: int = Field(2, ge=1, description="最小关键词长度")
    max_keyword_length: int = Field(10, ge=2, description="最大关键词长度")

    # --- AI-assisted analysis (threshold: minimum content length) ---
    enable_ai_analysis: bool = Field(True, description="是否启用AI分析")
    ai_analysis_threshold: int = Field(100, ge=0, description="启用AI分析的最小内容长度")

    # --- Score weights (defaults sum to 1.0; not enforced) ---
    frequency_weight: float = Field(0.4, ge=0.0, le=1.0, description="频率权重")
    position_weight: float = Field(0.3, ge=0.0, le=1.0, description="位置权重")
    type_weight: float = Field(0.3, ge=0.0, le=1.0, description="类型权重")

    # --- Categorization ---
    enable_categorization: bool = Field(True, description="是否启用关键词分类")
    category_threshold: int = Field(2, ge=1, description="分类阈值")

    # --- Search suggestions ---
    max_suggestions: int = Field(20, ge=1, description="最大搜索建议数量")
    enable_combination_suggestions: bool = Field(True, description="是否启用组合建议")
class ContentAnalysisConfig(BaseModel):
    """Content quality, engagement and theme analysis settings."""

    # --- Quality score weights (defaults sum to 1.0; not enforced) ---
    content_length_weight: float = Field(0.25, ge=0.0, le=1.0, description="内容长度权重")
    title_quality_weight: float = Field(0.20, ge=0.0, le=1.0, description="标题质量权重")
    media_richness_weight: float = Field(0.20, ge=0.0, le=1.0, description="媒体丰富度权重")
    structure_weight: float = Field(0.20, ge=0.0, le=1.0, description="结构权重")
    tag_completeness_weight: float = Field(0.15, ge=0.0, le=1.0, description="标签完整性权重")

    # --- Engagement score weights (defaults sum to 1.0; not enforced) ---
    likes_weight: float = Field(0.40, ge=0.0, le=1.0, description="点赞权重")
    comments_weight: float = Field(0.35, ge=0.0, le=1.0, description="评论权重")
    shares_weight: float = Field(0.25, ge=0.0, le=1.0, description="分享权重")

    # --- Analysis toggles ---
    enable_sentiment_analysis: bool = Field(True, description="是否启用情感分析")
    enable_theme_extraction: bool = Field(True, description="是否启用主题提取")
    enable_readability_analysis: bool = Field(True, description="是否启用可读性分析")

    # --- Theme extraction ---
    max_themes: int = Field(10, ge=1, description="最大主题数量")
    theme_frequency_threshold: int = Field(2, ge=1, description="主题频率阈值")