CategorizeLabel/tests/test_config.py

287 lines
8.8 KiB
Python
Raw Permalink Normal View History

2025-10-15 17:19:26 +08:00
"""测试配置管理模块
测试配置文件加载和参数合并功能
"""
import os
import tempfile
import textwrap

import pytest

from src.config import (
    load_config,
    merge_config_with_args,
    FilesConfig,
    APIConfig,
    LoggingConfig,
    BatchConfig,
    AppConfig,
)
class TestConfigModels:
    """Unit tests for the individual configuration models."""

    def test_files_config_creation(self):
        """FilesConfig keeps every path it is constructed with."""
        # Arrange: a fully specified files section.
        paths = {
            "input_file": "test.xlsx",
            "category_file": "test.jsonl",
            "output_file": "output.jsonl",
            "temp_request_file": "temp.jsonl",
        }
        files_cfg = FilesConfig(**paths)

        # Assert: each field round-trips unchanged.
        for field_name, expected in paths.items():
            assert getattr(files_cfg, field_name) == expected

    def test_api_config_with_defaults(self):
        """APIConfig falls back to its defaults when no fields are supplied."""
        api_cfg = APIConfig()

        assert api_cfg.api_key is None
        expected_defaults = {
            "base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1",
            "model": "qwen-flash",
            "completion_window": "24h",
        }
        for field_name, expected in expected_defaults.items():
            assert getattr(api_cfg, field_name) == expected

    def test_logging_config_defaults(self):
        """LoggingConfig defaults to log_dir='logs' and log_level='INFO'."""
        logging_cfg = LoggingConfig()
        assert logging_cfg.log_dir == "logs"
        assert logging_cfg.log_level == "INFO"

    def test_batch_config_defaults(self):
        """BatchConfig defaults to poll_interval=60 and max_wait_time=86400."""
        batch_cfg = BatchConfig()
        assert batch_cfg.poll_interval == 60
        assert batch_cfg.max_wait_time == 86400
class TestLoadConfig:
    """Tests for loading the YAML configuration file via ``load_config``."""

    @pytest.fixture
    def temp_config_file(self, tmp_path):
        """Write a fully populated YAML config into a temp dir and return its path.

        Uses pytest's per-test ``tmp_path`` directory, so no manual cleanup
        (and no ``delete=False`` temp-file bookkeeping) is needed.

        Returns:
            str: path to the written ``config.yaml``.
        """
        config_path = tmp_path / "config.yaml"
        config_path.write_text(
            textwrap.dedent(
                """\
                files:
                  input_file: "input/test.xlsx"
                  category_file: "config/test.jsonl"
                  output_file: "output/test.jsonl"
                  temp_request_file: "temp_test.jsonl"
                api:
                  api_key: null
                  base_url: "https://test.com"
                  model: "qwen-plus"
                  completion_window: "12h"
                logging:
                  log_dir: "test_logs"
                  log_level: "DEBUG"
                batch:
                  poll_interval: 30
                  max_wait_time: 3600
                """
            ),
            encoding="utf-8",
        )
        return str(config_path)

    def test_load_config_success(self, temp_config_file):
        """Every value written to the YAML file is reflected in the loaded config."""
        # When: the config file is loaded
        config = load_config(temp_config_file)

        # Then: every section carries the values from the file
        assert config.files.input_file == "input/test.xlsx"
        assert config.files.category_file == "config/test.jsonl"
        assert config.files.output_file == "output/test.jsonl"
        assert config.files.temp_request_file == "temp_test.jsonl"
        assert config.api.base_url == "https://test.com"
        assert config.api.model == "qwen-plus"
        assert config.api.completion_window == "12h"
        assert config.logging.log_dir == "test_logs"
        assert config.logging.log_level == "DEBUG"
        assert config.batch.poll_interval == 30
        assert config.batch.max_wait_time == 3600

    def test_load_config_file_not_found(self, tmp_path):
        """A missing config file raises FileNotFoundError with a '不存在' message."""
        # A path inside a fresh temp dir is guaranteed not to exist, unlike a
        # bare relative name that depends on the current working directory.
        missing_path = tmp_path / "nonexistent.yaml"

        # `match` checks the message without relying on code placed after
        # the raising call inside the `with` block (which would never run).
        with pytest.raises(FileNotFoundError, match="不存在"):
            load_config(str(missing_path))

    def test_load_config_invalid_yaml(self, tmp_path):
        """Malformed YAML raises ValueError with a '格式错误' message."""
        # Given: a deliberately malformed YAML document
        bad_path = tmp_path / "invalid.yaml"
        bad_path.write_text("invalid: yaml: content: [", encoding="utf-8")

        # When & Then: loading it is reported as a format error
        with pytest.raises(ValueError, match="格式错误"):
            load_config(str(bad_path))
class TestMergeConfig:
    """Tests for ``merge_config_with_args`` (command-line overrides on top of a config)."""

    @pytest.fixture
    def base_config(self):
        """Baseline AppConfig used as the merge target (its api_key is None)."""
        return AppConfig(
            files=FilesConfig(
                input_file="base_input.xlsx",
                category_file="base_category.jsonl",
                output_file="base_output.jsonl",
            ),
            api=APIConfig(
                api_key=None, base_url="https://base.com", model="base-model"
            ),
            logging=LoggingConfig(log_dir="base_logs"),
            batch=BatchConfig(),
        )

    def test_merge_with_no_overrides(self, base_config):
        """Merging without any arguments leaves the configuration unchanged."""
        # When: merging with no overrides
        merged = merge_config_with_args(base_config)

        # Then: values from the base config are preserved
        assert merged.files.input_file == "base_input.xlsx"
        assert merged.files.category_file == "base_category.jsonl"
        assert merged.files.output_file == "base_output.jsonl"
        assert merged.api.base_url == "https://base.com"
        assert merged.api.model == "base-model"
        assert merged.logging.log_dir == "base_logs"

    def test_merge_with_input_file_override(self, base_config):
        """Overriding input_file changes only that field."""
        # When: overriding input_file
        merged = merge_config_with_args(base_config, input_file="override_input.xlsx")

        # Then: input_file is overridden, everything else is unchanged
        assert merged.files.input_file == "override_input.xlsx"
        assert merged.files.category_file == "base_category.jsonl"
        assert merged.api.base_url == "https://base.com"

    def test_merge_with_api_overrides(self, base_config):
        """Several API fields can be overridden in one call."""
        # When: overriding multiple API parameters
        merged = merge_config_with_args(
            base_config,
            api_key="test_key",
            base_url="https://override.com",
            model="override-model",
        )

        # Then: the API section reflects the overrides
        assert merged.api.api_key == "test_key"
        assert merged.api.base_url == "https://override.com"
        assert merged.api.model == "override-model"

    def test_merge_with_all_overrides(self, base_config):
        """Every supported override argument is applied."""
        # When: overriding every supported argument
        merged = merge_config_with_args(
            base_config,
            input_file="new_input.xlsx",
            category_file="new_category.jsonl",
            output_file="new_output.jsonl",
            temp_request_file="new_temp.jsonl",
            api_key="new_key",
            base_url="https://new.com",
            model="new-model",
            completion_window="6h",
        )

        # Then: all of them take effect
        assert merged.files.input_file == "new_input.xlsx"
        assert merged.files.category_file == "new_category.jsonl"
        assert merged.files.output_file == "new_output.jsonl"
        assert merged.files.temp_request_file == "new_temp.jsonl"
        assert merged.api.api_key == "new_key"
        assert merged.api.base_url == "https://new.com"
        assert merged.api.model == "new-model"
        assert merged.api.completion_window == "6h"

    def test_merge_with_none_values(self, base_config):
        """Arguments passed as None do not override existing config values."""
        # When: passing None for a field that has a non-None base value
        merged = merge_config_with_args(base_config, input_file=None)

        # Then: the base value survives
        assert merged.files.input_file == "base_input.xlsx"

        # The fixture's api_key is already None, so checking it there would
        # pass vacuously. Use a config with a real key so that "None does not
        # override" is actually observable.
        config_with_key = AppConfig(
            files=base_config.files,
            api=APIConfig(
                api_key="base_key", base_url="https://base.com", model="base-model"
            ),
            logging=base_config.logging,
            batch=base_config.batch,
        )
        merged_key = merge_config_with_args(config_with_key, api_key=None)
        assert merged_key.api.api_key == "base_key"
class TestConfigIntegration:
    """End-to-end tests: load a YAML config file, then apply overrides."""

    def test_full_config_workflow(self, tmp_path):
        """Loading a partial config file and merging overrides behave together."""
        # Given: a partial config file written into pytest's per-test temp dir
        # (no manual try/finally cleanup needed)
        config_path = tmp_path / "config.yaml"
        config_path.write_text(
            textwrap.dedent(
                """\
                files:
                  input_file: "input/products.xlsx"
                  category_file: "config/categories.jsonl"
                  output_file: "output/results.jsonl"
                api:
                  model: "qwen-flash"
                  completion_window: "24h"
                logging:
                  log_level: "INFO"
                batch:
                  poll_interval: 60
                """
            ),
            encoding="utf-8",
        )

        # When: the config is loaded
        config = load_config(str(config_path))

        # Then: values from the file are present
        assert config.files.input_file == "input/products.xlsx"
        assert config.api.model == "qwen-flash"

        # When: command-line style overrides are merged in
        merged = merge_config_with_args(
            config, model="qwen-plus", api_key="override_key"
        )

        # Then: overridden fields change...
        assert merged.api.model == "qwen-plus"
        assert merged.api.api_key == "override_key"
        # ...and values loaded from the file that were not overridden survive
        assert merged.files.input_file == "input/products.xlsx"
        assert merged.api.completion_window == "24h"
        assert merged.logging.log_level == "INFO"