#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Configuration Models

Configuration data models - define the configuration structure of every
algorithm component.
"""

import os
from pathlib import Path
from typing import Dict, List, Any, Optional, Union

from pydantic import BaseModel, Field, validator


class BaseConfig(BaseModel):
    """Base class shared by most configuration models."""

    class Config:
        extra = "allow"  # accept fields that are not declared on the model
        validate_assignment = True  # re-validate when an attribute is assigned


class TaskModelConfig(BaseModel):
    """Per-task model parameters; every field is optional (None = not set)."""

    temperature: Optional[float] = Field(None, ge=0.0, le=2.0)
    top_p: Optional[float] = Field(None, ge=0.0, le=1.0)
    presence_penalty: Optional[float] = Field(None, ge=-2.0, le=2.0)
    frequency_penalty: Optional[float] = Field(None, ge=-2.0, le=2.0)
    max_tokens: Optional[int] = Field(None, gt=0)
    enable_stream: Optional[bool] = None

    def merge_with_default(self, default_config: "AIModelConfig") -> Dict[str, Any]:
        """Merge this task configuration with the defaults; task values win.

        Args:
            default_config: Configuration supplying the value for every
                parameter that is None on this task configuration.

        Returns:
            Parameter dictionary. Entries whose merged value is still None
            are left out of the result.
        """
        param_names = (
            "temperature",
            "top_p",
            "presence_penalty",
            "frequency_penalty",
            "max_tokens",
            "enable_stream",
        )
        merged: Dict[str, Any] = {}
        for name in param_names:
            own_value = getattr(self, name)
            # Only an explicit None falls back; 0, 0.0 and False are kept.
            merged[name] = (
                own_value if own_value is not None else getattr(default_config, name)
            )
        return {key: value for key, value in merged.items() if value is not None}


class AIModelConfig(BaseConfig):
    """AI model configuration with support for per-task parameters."""

    # Basic settings
    model: str = "qwq-plus"
    api_url: str = ""
    api_key: str = ""
    timeout: int = Field(60, gt=0)
    max_retries: int = Field(3, ge=0)

    # Default model parameters
    temperature: float = Field(0.7, ge=0.0, le=2.0)
    top_p: float = Field(0.5, ge=0.0, le=1.0)
    presence_penalty: float = Field(1.2, ge=-2.0, le=2.0)
    frequency_penalty: float = Field(0.0, ge=-2.0, le=2.0)
    max_tokens: Optional[int] = Field(None, gt=0)

    # Streaming output settings
    enable_stream: bool = True
    stream_chunk_size: int = 1024

    # Per-task model parameter overrides, keyed by task name
    task_configs: Dict[str, TaskModelConfig] = Field(default_factory=dict)

    def get_task_config(self, task_name: str) -> Dict[str, Any]:
        """Return the model parameters to use for a given task.

        Args:
            task_name: Task name (topic_generation, content_generation,
                content_judging, poster_generation).

        Returns:
            Parameter dictionary for the task: the task override merged with
            the defaults when an override exists, otherwise the defaults.
        """
        task_override = self.task_configs.get(task_name)
        if task_name in self.task_configs:
            return task_override.merge_with_default(self)

        # No override registered: return the default parameters.
        # NOTE(review): unlike the override branch above, this branch keeps a
        # None max_tokens in the result - confirm whether callers rely on it.
        return {
            "temperature": self.temperature,
            "top_p": self.top_p,
            "presence_penalty": self.presence_penalty,
            "frequency_penalty": self.frequency_penalty,
            "max_tokens": self.max_tokens,
            "enable_stream": self.enable_stream,
        }

    def set_task_config(self, task_name: str, **params) -> None:
        """Register (or replace) the model parameters of a task.

        Args:
            task_name: Task name.
            **params: Model parameters, passed to TaskModelConfig.
        """
        self.task_configs[task_name] = TaskModelConfig(**params)


class PromptConfig(BaseConfig):
    """Prompt configuration - replaces hard-coded prompt paths."""

    # Content generation prompts
    topic_generation: Dict[str, str] = Field(default_factory=dict)
    content_generation: Dict[str, str] = Field(default_factory=dict)
    content_judging: Dict[str, str] = Field(default_factory=dict)
    poster_generation: Dict[str, str] = Field(default_factory=dict)

    # Prompt template directory (customizable)
    template_directory: Optional[str] = None

    @validator("*", pre=True)
    def validate_prompts(cls, v):
        """Normalize prompt dictionaries so they have "system" and "user".

        Runs before type validation on every field. A dict missing either
        key is replaced by a dict holding only those two keys (absent ones
        become empty strings); any other value is returned unchanged.
        """
        if isinstance(v, dict):
            # Make sure both the system and user keys are present.
            if "system" not in v or "user" not in v:
                return {
                    "system": v.get("system", ""),
                    "user": v.get("user", ""),
                }
        return v


class ContentGenerationConfig(BaseConfig):
    """Content generation configuration."""

    # Topic generation
    topic_count: int = Field(5, ge=1, le=20)
    topic_date_range: Optional[str] = None
    topic_style: str = "default"

    # Content generation
    content_length: str = Field("medium", regex="^(short|medium|long)$")
    content_style: str = "default"
    enable_auto_judge: bool = True

    # Review (judging)
    judge_enabled: bool = True
    judge_threshold: float = Field(0.7, ge=0.0, le=1.0)

    # Reference material
    refer_sampling_rate: float = Field(0.5, ge=0.0, le=1.0)
    enable_refer_content: bool = True


class PosterGenerationConfig(BaseConfig):
    """Poster generation configuration - replaces hard-coded templates/assets."""

    # Sizes
    default_size: List[int] = Field(default_factory=lambda: [900, 1200])
    supported_sizes: List[List[int]] = Field(
        default_factory=lambda: [[900, 1200], [1080, 1080], [750, 1334]]
    )

    # Templates
    default_template: str = "vibrant"
    available_templates: List[str] = Field(
        default_factory=lambda: ["vibrant", "business", "collage"]
    )
    template_selection_mode: str = Field("manual", regex="^(manual|random|auto)$")

    # Fonts - replaces hard-coded font paths
    font_configs: Dict[str, Dict[str, Any]] = Field(default_factory=dict)
    default_font_family: str = "default"

    # Color themes - extracted from hard-coded values
    color_themes: Dict[str, Dict[str, List[List[int]]]] = Field(default_factory=dict)
    default_color_theme: str = "ocean_deep"

    # Effects
    glass_effect: Dict[str, Any] = Field(default_factory=dict)
    text_effects: Dict[str, Any] = Field(default_factory=dict)

    # Output
    output_format: str = Field("PNG", regex="^(PNG|JPEG|WEBP)$")
    output_quality: int = Field(95, ge=1, le=100)
    generate_thumbnail: bool = True
    thumbnail_size: List[int] = Field(default_factory=lambda: [300, 400])


class DocumentProcessingConfig(BaseConfig):
    """Document processing configuration."""

    # Supported file types
    supported_extensions: List[str] = Field(
        default_factory=lambda: [".txt", ".md", ".pdf", ".docx", ".doc", ".xlsx", ".xls"]
    )

    # Extraction
    max_file_size: int = Field(50 * 1024 * 1024)  # 50MB
    encoding: str = "utf-8"
    extract_images: bool = True
    extract_tables: bool = True

    # Processing
    enable_content_cleaning: bool = True
    remove_extra_whitespace: bool = True
    normalize_text: bool = True

    # Crawler-related settings
    enable_web_scraping: bool = True
    scraping_config: Dict[str, Any] = Field(default_factory=dict)


class OutputConfig(BaseConfig):
    """Output configuration - replaces hard-coded file names and paths."""

    # Base paths
    base_output_directory: str = "result"
    enable_timestamped_folders: bool = True
    folder_name_pattern: str = "run_{timestamp}"

    # File naming
    file_naming: Dict[str, str] = Field(default_factory=lambda: {
        "topics": "topics.json",
        "content": "article.json",
        "poster": "poster.png",
        "raw_response": "{stage}_raw_response.txt",
        "system_prompt": "{stage}_system_prompt.txt",
        "user_prompt": "{stage}_user_prompt.txt"
    })

    # What gets saved
    save_raw_responses: bool = True
    save_prompts: bool = True
    save_metadata: bool = True

    # Compression and backups
    enable_compression: bool = False
    backup_count: int = 5


class ResourceConfig(BaseConfig):
    """Resource configuration - manages external resource paths, incl. static files."""

    # Base directory for resources
    resource_base_directory: Optional[str] = None

    # Font resources
    font_directory: Optional[str] = None
    font_files: Dict[str, str] = Field(default_factory=dict)

    # Prompt template directory
    prompt_template_directory: Optional[str] = None

    # Image resources
    image_assets_directory: Optional[str] = None
    default_background_images: List[str] = Field(default_factory=list)

    # Static files - resolves the crawler module's dependency on JS files
    static_files_directory: Optional[str] = None
    required_js_files: Dict[str, str] = Field(default_factory=lambda: {
        "xhs_xs_xsc_56": "xhs_xs_xsc_56.js",
        "xhs_xray": "xhs_xray.js",
        "xhs_creator_xs": "xhs_creator_xs.js",
        "xhs_xray_pack1": "xhs_xray_pack1.js",
        "xhs_xray_pack2": "xhs_xray_pack2.js"
    })

    def get_js_file_path(self, js_name: str) -> Optional[str]:
        """Resolve the path of a required JavaScript file.

        The configured static directory (resolved through resolve_path) is
        checked first; after that a few conventional locations relative to
        the current working directory are tried in order.

        Args:
            js_name: Identifier of the JS file (a key of required_js_files).

        Returns:
            Path of the first existing candidate, or None when the
            identifier is unknown or no candidate exists. Fallback hits are
            returned as resolved absolute paths; a hit in the configured
            directory is returned as built, without resolve().
        """
        if js_name not in self.required_js_files:
            return None
        filename = self.required_js_files[js_name]

        if self.static_files_directory:
            configured = Path(self.resolve_path(self.static_files_directory)) / filename
            if configured.exists():
                return str(configured)

        # Conventional fallback locations, relative to the working directory.
        candidates = [
            Path("static") / filename,
            Path("../static") / filename,
            Path("../../static") / filename,
        ]
        if self.static_files_directory:
            # Configured directory name, looked up one level above the cwd.
            candidates.append(Path(f"../{self.static_files_directory}") / filename)

        for candidate in candidates:
            if candidate.exists():
                return str(candidate.resolve())
        return None

    def resolve_path(self, relative_path: str) -> str:
        """Resolve a relative path to a usable path string.

        Args:
            relative_path: Path to resolve.

        Returns:
            "" for an empty input; the path unchanged when it is already
            absolute; the path joined onto resource_base_directory when that
            is set (the join is returned as-is, so it is only absolute if
            the base directory is); otherwise the path resolved against the
            current working directory.
        """
        if not relative_path:
            return ""

        path = Path(relative_path)

        # Already absolute: return it untouched.
        if path.is_absolute():
            return str(path)

        # Resolve against resource_base_directory when one is configured.
        if self.resource_base_directory:
            return str(Path(self.resource_base_directory) / path)

        # Otherwise resolve against the current working directory.
        return str(path.resolve())

    def find_resource_directory(self, resource_type: str) -> Optional[Path]:
        """Locate the directory holding a given kind of resource.

        Args:
            resource_type: Resource type (fonts, prompts, images, static).

        Returns:
            Resolved path of the first existing candidate directory, or None
            when the type is unknown, its directory is not configured, or no
            candidate exists.
        """
        # Map each resource type to its configured directory.
        dir_mapping = {
            "fonts": self.font_directory,
            "prompts": self.prompt_template_directory,
            "images": self.image_assets_directory,
            "static": self.static_files_directory
        }

        relative_dir = dir_mapping.get(resource_type)
        # NOTE(review): because of this early return, the hard-coded
        # prompts/static fallbacks below are only consulted when the matching
        # directory is configured - confirm that this is intended.
        if not relative_dir:
            return None

        search_paths: List[Path] = []

        # The configured base directory, when present, is searched first.
        if self.resource_base_directory:
            search_paths.append(Path(self.resource_base_directory) / relative_dir)

        # Then the directory relative to the cwd and up to three levels above.
        search_paths.extend([
            Path(relative_dir),
            Path("..") / relative_dir,
            Path("../..") / relative_dir,
            Path("../../..") / relative_dir
        ])

        # Special case: resource directories of the external project.
        if resource_type == "prompts":
            search_paths.extend([
                Path("resource/prompt"),
                Path("../resource/prompt"),
                Path("../../resource/prompt"),
                Path("../../../resource/prompt")
            ])
        elif resource_type == "static":
            search_paths.extend([
                Path("static"),
                Path("../static"),
                Path("../../static"),
                Path("../../../static")
            ])

        # Return the first candidate that exists and is a directory.
        for candidate in search_paths:
            if candidate.exists() and candidate.is_dir():
                return candidate.resolve()

        return None

    def validate_js_files(self) -> Dict[str, bool]:
        """Check whether every required JS file can be located.

        Returns:
            Mapping of JS file identifier -> whether get_js_file_path found it.
        """
        return {
            js_name: self.get_js_file_path(js_name) is not None
            for js_name in self.required_js_files
        }


class WebCrawlingConfig(BaseModel):
    """Web crawler configuration."""

    # Request settings
    request_interval: float = Field(default=1.0, ge=0.1, description="请求间隔(秒)")
    max_requests: int = Field(default=100, ge=0, description="最大请求数,0表示无限制")
    request_timeout: float = Field(default=30.0, ge=1.0, description="请求超时时间")

    # Browser settings
    user_agent: str = Field(
        default="Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36",
        description="用户代理字符串"
    )

    # Proxy settings
    enable_proxy: bool = Field(default=False, description="是否启用代理")
    proxy_config: Optional[Dict[str, str]] = Field(default=None, description="代理配置")

    # Error handling
    max_retries: int = Field(default=3, ge=0, description="最大重试次数")
    retry_delay: float = Field(default=2.0, ge=0.1, description="重试延迟")

    # Content filtering
    min_content_length: int = Field(default=50, ge=0, description="最小内容长度")
    max_content_length: int = Field(default=10000, ge=100, description="最大内容长度")

    # Xiaohongshu-specific settings
    xhs_search_config: Dict[str, Any] = Field(
        default_factory=lambda: {
            "default_sort_type": 0,  # default sort type
            "default_note_type": 0,  # default note type
            "enable_image_download": True,
            "enable_video_download": False,
            "search_result_limit": 50
        },
        description="小红书搜索配置"
    )


class KeywordAnalysisConfig(BaseModel):
    """Keyword analysis configuration."""

    # Basic settings
    max_keywords: int = Field(default=20, ge=1, le=100, description="最大关键词数量")
    min_keyword_length: int = Field(default=2, ge=1, description="最小关键词长度")
    max_keyword_length: int = Field(default=10, ge=2, description="最大关键词长度")

    # AI analysis
    enable_ai_analysis: bool = Field(default=True, description="是否启用AI分析")
    ai_analysis_threshold: int = Field(default=100, ge=0, description="启用AI分析的最小内容长度")

    # Scoring weights (each bounded to [0, 1]; their sum is not validated)
    frequency_weight: float = Field(default=0.4, ge=0.0, le=1.0, description="频率权重")
    position_weight: float = Field(default=0.3, ge=0.0, le=1.0, description="位置权重")
    type_weight: float = Field(default=0.3, ge=0.0, le=1.0, description="类型权重")

    # Categorization
    enable_categorization: bool = Field(default=True, description="是否启用关键词分类")
    category_threshold: int = Field(default=2, ge=1, description="分类阈值")

    # Search suggestions
    max_suggestions: int = Field(default=20, ge=1, description="最大搜索建议数量")
    enable_combination_suggestions: bool = Field(default=True, description="是否启用组合建议")


class ContentAnalysisConfig(BaseModel):
    """Content analysis configuration."""

    # Quality scoring weights (each bounded to [0, 1]; sum is not validated)
    content_length_weight: float = Field(default=0.25, ge=0.0, le=1.0, description="内容长度权重")
    title_quality_weight: float = Field(default=0.20, ge=0.0, le=1.0, description="标题质量权重")
    media_richness_weight: float = Field(default=0.20, ge=0.0, le=1.0, description="媒体丰富度权重")
    structure_weight: float = Field(default=0.20, ge=0.0, le=1.0, description="结构权重")
    tag_completeness_weight: float = Field(default=0.15, ge=0.0, le=1.0, description="标签完整性权重")

    # Engagement scoring weights
    likes_weight: float = Field(default=0.40, ge=0.0, le=1.0, description="点赞权重")
    comments_weight: float = Field(default=0.35, ge=0.0, le=1.0, description="评论权重")
    shares_weight: float = Field(default=0.25, ge=0.0, le=1.0, description="分享权重")

    # Analysis switches
    enable_sentiment_analysis: bool = Field(default=True, description="是否启用情感分析")
    enable_theme_extraction: bool = Field(default=True, description="是否启用主题提取")
    enable_readability_analysis: bool = Field(default=True, description="是否启用可读性分析")

    # Theme extraction
    max_themes: int = Field(default=10, ge=1, description="最大主题数量")
    theme_frequency_threshold: int = Field(default=2, ge=1, description="主题频率阈值")