2025-07-31 15:35:23 +08:00

487 lines
18 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Configuration Models
配置数据模型 - 定义所有算法组件的配置结构
"""
from typing import Dict, List, Any, Optional, Union
from pydantic import BaseModel, Field, validator
from pathlib import Path
import os
class BaseConfig(BaseModel):
    """Common base for the configuration models in this module.

    Undeclared keys are accepted and kept on the instance, and assignments
    made after construction are validated against the declared field types.
    """

    class Config:
        # Validate values assigned to attributes after the model is built.
        validate_assignment = True
        # Accept keys that are not declared as fields instead of rejecting them.
        extra = "allow"
class TaskModelConfig(BaseModel):
    """Per-task overrides for model sampling parameters.

    Every field is optional; ``None`` means "fall back to the value from the
    default ``AIModelConfig``".
    """

    temperature: Optional[float] = Field(default=None, ge=0.0, le=2.0)
    top_p: Optional[float] = Field(default=None, ge=0.0, le=1.0)
    presence_penalty: Optional[float] = Field(default=None, ge=-2.0, le=2.0)
    frequency_penalty: Optional[float] = Field(default=None, ge=-2.0, le=2.0)
    max_tokens: Optional[int] = Field(default=None, gt=0)
    enable_stream: Optional[bool] = None

    def merge_with_default(self, default_config: "AIModelConfig") -> Dict[str, Any]:
        """Combine these overrides with ``default_config``.

        Args:
            default_config: Configuration supplying the fallback values.

        Returns:
            A dict keyed by parameter name. A non-``None`` task value wins;
            otherwise the default's value is used. Parameters whose resolved
            value is still ``None`` are left out of the result.
        """
        parameter_names = (
            "temperature",
            "top_p",
            "presence_penalty",
            "frequency_penalty",
            "max_tokens",
            "enable_stream",
        )
        merged: Dict[str, Any] = {}
        for name in parameter_names:
            override = getattr(self, name)
            # Only consult the default when the task does not set the value.
            chosen = getattr(default_config, name) if override is None else override
            if chosen is not None:
                merged[name] = chosen
        return merged
class AIModelConfig(BaseConfig):
    """AI model configuration with optional per-task parameter overrides."""

    # Connection settings
    model: str = "qwq-plus"
    api_url: str = ""
    api_key: str = ""
    timeout: int = Field(60, gt=0)
    max_retries: int = Field(3, ge=0)

    # Default sampling parameters, used when a task does not override them
    temperature: float = Field(0.7, ge=0.0, le=2.0)
    top_p: float = Field(0.5, ge=0.0, le=1.0)
    presence_penalty: float = Field(1.2, ge=-2.0, le=2.0)
    frequency_penalty: float = Field(0.0, ge=-2.0, le=2.0)
    max_tokens: Optional[int] = Field(None, gt=0)

    # Streaming output
    enable_stream: bool = True
    stream_chunk_size: int = 1024

    # Per-task parameter overrides, keyed by task name
    task_configs: Dict[str, TaskModelConfig] = Field(default_factory=dict)

    def get_task_config(self, task_name: str) -> Dict[str, Any]:
        """Return the effective model parameters for a task.

        Args:
            task_name: Task name, e.g. ``topic_generation``,
                ``content_generation``, ``content_judging`` or
                ``poster_generation``.

        Returns:
            Dict of model parameters. Values from the task's override (if any)
            take precedence over this config's defaults. A parameter whose
            resolved value is ``None`` (e.g. ``max_tokens`` when unset) is
            omitted. This now holds for both configured and unconfigured tasks.
        """
        task_config = self.task_configs.get(task_name)
        if task_config is None:
            # An empty override inherits every value from this config. This
            # routes unconfigured tasks through the same merge, including the
            # same None-filtering.
            task_config = TaskModelConfig()
        return task_config.merge_with_default(self)

    def set_task_config(self, task_name: str, **params) -> None:
        """Set (or replace) the parameter overrides for a task.

        Args:
            task_name: Task name to configure.
            **params: ``TaskModelConfig`` field values, e.g. ``temperature=0.3``.

        Raises:
            pydantic.ValidationError: If a value violates the field constraints.
        """
        self.task_configs[task_name] = TaskModelConfig(**params)
class PromptConfig(BaseConfig):
    """Prompt configuration; replaces hard-coded prompt paths.

    Each prompt field holds a dict that is normalised to contain both a
    ``"system"`` and a ``"user"`` entry.
    """

    # Prompts for each generation stage
    topic_generation: Dict[str, str] = Field(default_factory=dict)
    content_generation: Dict[str, str] = Field(default_factory=dict)
    content_judging: Dict[str, str] = Field(default_factory=dict)
    poster_generation: Dict[str, str] = Field(default_factory=dict)

    # Optional custom directory for prompt templates
    template_directory: Optional[str] = None

    @validator("*", pre=True)
    def validate_prompts(cls, v):
        """Fill in missing ``"system"``/``"user"`` keys on dict-valued fields.

        Missing keys default to ``""``. Any other keys in the input dict are
        preserved. Non-dict values (e.g. ``template_directory``) are returned
        unchanged.
        """
        if isinstance(v, dict) and ("system" not in v or "user" not in v):
            # Copy instead of rebuilding so that extra keys are not lost and the
            # caller's dict is not mutated.
            filled = dict(v)
            filled.setdefault("system", "")
            filled.setdefault("user", "")
            return filled
        return v
class ContentGenerationConfig(BaseConfig):
    """Settings for topic generation, article generation, judging and reference material."""

    # --- Topic generation ---
    topic_count: int = Field(default=5, ge=1, le=20)
    topic_date_range: Optional[str] = None
    topic_style: str = "default"

    # --- Article generation ---
    # Restricted to one of: short, medium, long
    content_length: str = Field(default="medium", regex="^(short|medium|long)$")
    content_style: str = "default"
    enable_auto_judge: bool = True

    # --- Judging ---
    judge_enabled: bool = True
    judge_threshold: float = Field(default=0.7, ge=0.0, le=1.0)

    # --- Reference material ---
    # Constrained to [0, 1]; presumably the fraction of reference items sampled (confirm with caller)
    refer_sampling_rate: float = Field(default=0.5, ge=0.0, le=1.0)
    enable_refer_content: bool = True
class PosterGenerationConfig(BaseConfig):
    """Poster rendering settings; replaces hard-coded templates and resources."""

    # --- Canvas sizes (presumably [width, height] pairs; confirm with the renderer) ---
    default_size: List[int] = Field(default_factory=lambda: [900, 1200])
    supported_sizes: List[List[int]] = Field(
        default_factory=lambda: [
            [900, 1200],
            [1080, 1080],
            [750, 1334],
        ]
    )

    # --- Templates ---
    default_template: str = "vibrant"
    available_templates: List[str] = Field(
        default_factory=lambda: [
            "vibrant",
            "business",
            "collage",
        ]
    )
    # Restricted to one of: manual, random, auto
    template_selection_mode: str = Field(default="manual", regex="^(manual|random|auto)$")

    # --- Fonts (replaces hard-coded font paths) ---
    font_configs: Dict[str, Dict[str, Any]] = Field(default_factory=dict)
    default_font_family: str = "default"

    # --- Colour themes (extracted from previously hard-coded values) ---
    color_themes: Dict[str, Dict[str, List[List[int]]]] = Field(default_factory=dict)
    default_color_theme: str = "ocean_deep"

    # --- Visual effects ---
    glass_effect: Dict[str, Any] = Field(default_factory=dict)
    text_effects: Dict[str, Any] = Field(default_factory=dict)

    # --- Output ---
    # Restricted to one of: PNG, JPEG, WEBP
    output_format: str = Field(default="PNG", regex="^(PNG|JPEG|WEBP)$")
    output_quality: int = Field(default=95, ge=1, le=100)
    generate_thumbnail: bool = True
    thumbnail_size: List[int] = Field(default_factory=lambda: [300, 400])
class DocumentProcessingConfig(BaseConfig):
    """Settings for reading, extracting and cleaning input documents."""

    # --- Accepted file extensions ---
    supported_extensions: List[str] = Field(
        default_factory=lambda: [
            ".txt",
            ".md",
            ".pdf",
            ".docx",
            ".doc",
            ".xlsx",
            ".xls",
        ]
    )

    # --- Extraction ---
    max_file_size: int = Field(default=50 * 1024 * 1024)  # 50 MB
    encoding: str = "utf-8"
    extract_images: bool = True
    extract_tables: bool = True

    # --- Text clean-up ---
    enable_content_cleaning: bool = True
    remove_extra_whitespace: bool = True
    normalize_text: bool = True

    # --- Web scraping ---
    enable_web_scraping: bool = True
    scraping_config: Dict[str, Any] = Field(default_factory=dict)
class OutputConfig(BaseConfig):
    """Output settings; replaces hard-coded output file names and paths."""

    # --- Output location ---
    base_output_directory: str = "result"
    enable_timestamped_folders: bool = True
    folder_name_pattern: str = "run_{timestamp}"

    # --- File names; "{stage}" is a placeholder in the per-stage names ---
    file_naming: Dict[str, str] = Field(
        default_factory=lambda: dict(
            topics="topics.json",
            content="article.json",
            poster="poster.png",
            raw_response="{stage}_raw_response.txt",
            system_prompt="{stage}_system_prompt.txt",
            user_prompt="{stage}_user_prompt.txt",
        )
    )

    # --- What to save ---
    save_raw_responses: bool = True
    save_prompts: bool = True
    save_metadata: bool = True

    # --- Compression and backups ---
    enable_compression: bool = False
    backup_count: int = 5
class ResourceConfig(BaseConfig):
    """Locations of external resources: fonts, prompt templates, images and static JS files."""

    # Base directory that relative resource paths are resolved against
    resource_base_directory: Optional[str] = None

    # Font resources
    font_directory: Optional[str] = None
    font_files: Dict[str, str] = Field(default_factory=dict)

    # Prompt template directory
    prompt_template_directory: Optional[str] = None

    # Image resources
    image_assets_directory: Optional[str] = None
    default_background_images: List[str] = Field(default_factory=list)

    # Static files required by the crawler module
    static_files_directory: Optional[str] = None
    required_js_files: Dict[str, str] = Field(default_factory=lambda: {
        "xhs_xs_xsc_56": "xhs_xs_xsc_56.js",
        "xhs_xray": "xhs_xray.js",
        "xhs_creator_xs": "xhs_creator_xs.js",
        "xhs_xray_pack1": "xhs_xray_pack1.js",
        "xhs_xray_pack2": "xhs_xray_pack2.js",
    })

    def get_js_file_path(self, js_name: str) -> Optional[str]:
        """Locate a required JavaScript file on disk.

        Search order:
        1. The configured ``static_files_directory``, resolved via ``resolve_path``.
        2. ``static/`` in the current directory and up to two parent directories.
        3. ``../<static_files_directory>``.

        Args:
            js_name: Key into ``required_js_files``.

        Returns:
            Absolute path of the first matching file, or ``None`` if
            ``js_name`` is unknown or no candidate file exists.
        """
        if js_name not in self.required_js_files:
            return None
        filename = self.required_js_files[js_name]
        static_dir = self.static_files_directory

        candidates: List[Path] = []
        if static_dir:
            candidates.append(Path(self.resolve_path(static_dir)) / filename)
        candidates += [
            Path("static") / filename,
            Path("../static") / filename,
            Path("../../static") / filename,
        ]
        if static_dir:
            candidates.append(Path("..") / static_dir / filename)

        for candidate in candidates:
            # is_file() rather than exists(): a same-named directory is not a match.
            if candidate.is_file():
                # Always return an absolute path, whichever candidate matched.
                return str(candidate.resolve())
        return None

    def resolve_path(self, relative_path: str) -> str:
        """Resolve a path to an absolute path string.

        Args:
            relative_path: Path to resolve. Absolute paths are returned as-is.
                Relative paths are joined onto ``resource_base_directory`` when
                it is set, otherwise onto the current working directory.

        Returns:
            The absolute path as a string, or ``""`` if ``relative_path`` is empty.
        """
        if not relative_path:
            return ""

        path = Path(relative_path)
        if path.is_absolute():
            return str(path)

        if self.resource_base_directory:
            path = Path(self.resource_base_directory) / path
        # Resolve in both branches so that a relative base directory still
        # yields an absolute path, as documented.
        return str(path.resolve())

    def find_resource_directory(self, resource_type: str) -> Optional[Path]:
        """Find an existing directory for a resource type.

        Args:
            resource_type: One of ``"fonts"``, ``"prompts"``, ``"images"``, ``"static"``.

        Returns:
            Absolute path of the first existing directory among the candidates.
            Returns ``None`` for an unknown type or when nothing is found.

        Candidates, in order:
        1. The configured directory under ``resource_base_directory``.
        2. The configured directory relative to the current directory and up
           to three parent directories.
        3. For ``"prompts"`` and ``"static"``, the built-in fallback
           (``resource/prompt`` or ``static``) at the same levels. The fallback
           is searched even when no directory is configured.
        """
        dir_mapping = {
            "fonts": self.font_directory,
            "prompts": self.prompt_template_directory,
            "images": self.image_assets_directory,
            "static": self.static_files_directory,
        }
        if resource_type not in dir_mapping:
            return None

        configured_dir = dir_mapping[resource_type]
        fallback_dir = {"prompts": "resource/prompt", "static": "static"}.get(resource_type)
        parent_levels = (".", "..", "../..", "../../..")

        search_paths: List[Path] = []
        if configured_dir:
            if self.resource_base_directory:
                search_paths.append(Path(self.resource_base_directory) / configured_dir)
            search_paths.extend(Path(level) / configured_dir for level in parent_levels)
        if fallback_dir:
            search_paths.extend(Path(level) / fallback_dir for level in parent_levels)

        for path in search_paths:
            if path.is_dir():
                return path.resolve()
        return None

    def validate_js_files(self) -> Dict[str, bool]:
        """Check that every required JS file can be located.

        Returns:
            Mapping of JS identifier (key of ``required_js_files``) to whether
            ``get_js_file_path`` found it.
        """
        return {
            js_name: self.get_js_file_path(js_name) is not None
            for js_name in self.required_js_files
        }
class WebCrawlingConfig(BaseModel):
    """Web crawler settings: request pacing, browser identity, proxy, retries, content filters."""

    # --- Requests ---
    request_interval: float = Field(default=1.0, ge=0.1, description="请求间隔(秒)")
    max_requests: int = Field(default=100, ge=0, description="最大请求数0表示无限制")
    request_timeout: float = Field(default=30.0, ge=1.0, description="请求超时时间")

    # --- Browser ---
    user_agent: str = Field(
        default="Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36",
        description="用户代理字符串"
    )

    # --- Proxy ---
    enable_proxy: bool = Field(default=False, description="是否启用代理")
    proxy_config: Optional[Dict[str, str]] = Field(default=None, description="代理配置")

    # --- Error handling ---
    max_retries: int = Field(default=3, ge=0, description="最大重试次数")
    retry_delay: float = Field(default=2.0, ge=0.1, description="重试延迟")

    # --- Content filtering ---
    min_content_length: int = Field(default=50, ge=0, description="最小内容长度")
    max_content_length: int = Field(default=10000, ge=100, description="最大内容长度")

    # --- Xiaohongshu (XHS) specific settings ---
    xhs_search_config: Dict[str, Any] = Field(
        default_factory=lambda: {
            "default_sort_type": 0,  # default sort type
            "default_note_type": 0,  # default note type
            "enable_image_download": True,
            "enable_video_download": False,
            "search_result_limit": 50
        },
        description="小红书搜索配置"
    )

    @validator("max_content_length", always=True)
    def check_content_length_range(cls, v, values):
        """Reject a ``max_content_length`` smaller than ``min_content_length``.

        ``always=True`` also runs this check when ``max_content_length`` keeps
        its default, e.g. when only a large minimum is supplied.

        Raises:
            ValueError: If the range is inverted. Pydantic reports it as a
                validation error.
        """
        # min_content_length is declared first, so it is already in ``values``
        # unless its own validation failed.
        min_length = values.get("min_content_length")
        if min_length is not None and v < min_length:
            raise ValueError(
                f"max_content_length ({v}) must be >= min_content_length ({min_length})"
            )
        return v
class KeywordAnalysisConfig(BaseModel):
    """Keyword analysis settings: limits, AI analysis, scoring weights, categorisation, suggestions."""

    # --- Basic limits ---
    max_keywords: int = Field(default=20, ge=1, le=100, description="最大关键词数量")
    min_keyword_length: int = Field(default=2, ge=1, description="最小关键词长度")
    max_keyword_length: int = Field(default=10, ge=2, description="最大关键词长度")

    # --- AI analysis ---
    enable_ai_analysis: bool = Field(default=True, description="是否启用AI分析")
    ai_analysis_threshold: int = Field(default=100, ge=0, description="启用AI分析的最小内容长度")

    # --- Scoring weights (the defaults sum to 1.0) ---
    frequency_weight: float = Field(default=0.4, ge=0.0, le=1.0, description="频率权重")
    position_weight: float = Field(default=0.3, ge=0.0, le=1.0, description="位置权重")
    type_weight: float = Field(default=0.3, ge=0.0, le=1.0, description="类型权重")

    # --- Categorisation ---
    enable_categorization: bool = Field(default=True, description="是否启用关键词分类")
    category_threshold: int = Field(default=2, ge=1, description="分类阈值")

    # --- Search suggestions ---
    max_suggestions: int = Field(default=20, ge=1, description="最大搜索建议数量")
    enable_combination_suggestions: bool = Field(default=True, description="是否启用组合建议")

    @validator("max_keyword_length", always=True)
    def check_keyword_length_range(cls, v, values):
        """Reject a ``max_keyword_length`` smaller than ``min_keyword_length``.

        ``always=True`` also runs this check when ``max_keyword_length`` keeps
        its default.

        Raises:
            ValueError: If the range is inverted. Pydantic reports it as a
                validation error.
        """
        # min_keyword_length is declared first, so it is already in ``values``
        # unless its own validation failed.
        min_length = values.get("min_keyword_length")
        if min_length is not None and v < min_length:
            raise ValueError(
                f"max_keyword_length ({v}) must be >= min_keyword_length ({min_length})"
            )
        return v
class ContentAnalysisConfig(BaseModel):
    """Content analysis settings: quality/engagement score weights, analysis toggles, theme extraction."""

    # --- Quality score weights (the defaults sum to 1.0) ---
    content_length_weight: float = Field(0.25, ge=0.0, le=1.0, description="内容长度权重")
    title_quality_weight: float = Field(0.20, ge=0.0, le=1.0, description="标题质量权重")
    media_richness_weight: float = Field(0.20, ge=0.0, le=1.0, description="媒体丰富度权重")
    structure_weight: float = Field(0.20, ge=0.0, le=1.0, description="结构权重")
    tag_completeness_weight: float = Field(0.15, ge=0.0, le=1.0, description="标签完整性权重")

    # --- Engagement score weights (the defaults sum to 1.0) ---
    likes_weight: float = Field(0.40, ge=0.0, le=1.0, description="点赞权重")
    comments_weight: float = Field(0.35, ge=0.0, le=1.0, description="评论权重")
    shares_weight: float = Field(0.25, ge=0.0, le=1.0, description="分享权重")

    # --- Analysis toggles ---
    enable_sentiment_analysis: bool = Field(True, description="是否启用情感分析")
    enable_theme_extraction: bool = Field(True, description="是否启用主题提取")
    enable_readability_analysis: bool = Field(True, description="是否启用可读性分析")

    # --- Theme extraction ---
    max_themes: int = Field(10, ge=1, description="最大主题数量")
    theme_frequency_threshold: int = Field(2, ge=1, description="主题频率阈值")