CategorizeLabel/tests/test_config.py
2025-10-15 17:19:26 +08:00

287 lines
8.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""测试配置管理模块
测试配置文件加载和参数合并功能
"""
import pytest
import tempfile
import os
from src.config import (
load_config,
merge_config_with_args,
FilesConfig,
APIConfig,
LoggingConfig,
BatchConfig,
AppConfig,
)
class TestConfigModels:
    """Unit tests for the individual configuration model classes."""

    def test_files_config_creation(self):
        """FilesConfig exposes exactly the paths it was constructed with."""
        # Given: a complete set of file paths
        paths = {
            "input_file": "test.xlsx",
            "category_file": "test.jsonl",
            "output_file": "output.jsonl",
            "temp_request_file": "temp.jsonl",
        }
        files = FilesConfig(**paths)

        # Then: every attribute mirrors the value that was passed in
        for field_name, expected in paths.items():
            assert getattr(files, field_name) == expected

    def test_api_config_with_defaults(self):
        """APIConfig falls back to its defaults when no field is supplied."""
        # Given: no optional field is provided
        api = APIConfig()

        # Then: the defaults are in effect
        assert api.api_key is None
        assert api.base_url == "https://dashscope.aliyuncs.com/compatible-mode/v1"
        assert (api.model, api.completion_window) == ("qwen-flash", "24h")

    def test_logging_config_defaults(self):
        """LoggingConfig has a default directory and a default level."""
        # Given: a logging config built without arguments
        log_cfg = LoggingConfig()

        # Then: both defaults are present
        assert (log_cfg.log_dir, log_cfg.log_level) == ("logs", "INFO")

    def test_batch_config_defaults(self):
        """BatchConfig has default polling and maximum-wait values."""
        # Given: a batch config built without arguments
        batch_cfg = BatchConfig()

        # Then: both defaults are present
        assert (batch_cfg.poll_interval, batch_cfg.max_wait_time) == (60, 86400)
class TestLoadConfig:
    """Tests for loading an application configuration from a YAML file."""

    @pytest.fixture
    def temp_config_file(self):
        """Yield the path of a temporary YAML config file, then delete it.

        The YAML below must keep its two-space nesting: each key has to sit
        inside its section (files / api / logging / batch), otherwise the
        sections parse as null and the keys end up at the top level.
        """
        config_content = """
files:
  input_file: "input/test.xlsx"
  category_file: "config/test.jsonl"
  output_file: "output/test.jsonl"
  temp_request_file: "temp_test.jsonl"
api:
  api_key: null
  base_url: "https://test.com"
  model: "qwen-plus"
  completion_window: "12h"
logging:
  log_dir: "test_logs"
  log_level: "DEBUG"
batch:
  poll_interval: 30
  max_wait_time: 3600
"""
        # delete=False so the file survives the `with` block and can be
        # reopened by path inside the test.
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".yaml", delete=False, encoding="utf-8"
        ) as f:
            f.write(config_content)
            temp_path = f.name

        yield temp_path

        # Teardown: remove the temporary file if it is still there.
        if os.path.exists(temp_path):
            os.remove(temp_path)

    def test_load_config_success(self, temp_config_file):
        """A well-formed config file is loaded into the expected sections."""
        # When: the config file is loaded
        config = load_config(temp_config_file)

        # Then: each section carries the values from the file
        assert config.files.input_file == "input/test.xlsx"
        assert config.files.category_file == "config/test.jsonl"
        assert config.api.base_url == "https://test.com"
        assert config.api.model == "qwen-plus"
        assert config.logging.log_dir == "test_logs"
        assert config.logging.log_level == "DEBUG"
        assert config.batch.poll_interval == 30

    def test_load_config_file_not_found(self):
        """Loading a missing file raises FileNotFoundError with a clear message."""
        # When & Then: a FileNotFoundError is expected
        with pytest.raises(FileNotFoundError) as exc_info:
            load_config("nonexistent.yaml")

        # Checked after the `with` block: statements placed after the raising
        # call inside the block would never execute.
        assert "不存在" in str(exc_info.value)

    def test_load_config_invalid_yaml(self):
        """Loading syntactically invalid YAML raises ValueError."""
        # Given: a file whose content is not valid YAML
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".yaml", delete=False, encoding="utf-8"
        ) as f:
            f.write("invalid: yaml: content: [")
            temp_path = f.name

        try:
            # When & Then: a ValueError is expected
            with pytest.raises(ValueError) as exc_info:
                load_config(temp_path)

            # Checked outside the raises block so it actually runs.
            assert "格式错误" in str(exc_info.value)
        finally:
            # Always remove the temporary file, even when an assertion fails.
            if os.path.exists(temp_path):
                os.remove(temp_path)
class TestMergeConfig:
    """Tests for merging override arguments into a base configuration."""

    @pytest.fixture
    def base_config(self):
        """Build the baseline AppConfig that each merge test starts from."""
        files_section = FilesConfig(
            input_file="base_input.xlsx",
            category_file="base_category.jsonl",
            output_file="base_output.jsonl",
        )
        api_section = APIConfig(
            api_key=None, base_url="https://base.com", model="base-model"
        )
        return AppConfig(
            files=files_section,
            api=api_section,
            logging=LoggingConfig(log_dir="base_logs"),
            batch=BatchConfig(),
        )

    def test_merge_with_no_overrides(self, base_config):
        """Merging without any override leaves the configuration as it was."""
        # When: nothing is overridden
        result = merge_config_with_args(base_config)

        # Then: the base values are still in place
        assert result.files.input_file == "base_input.xlsx"
        assert result.api.base_url == "https://base.com"

    def test_merge_with_input_file_override(self, base_config):
        """Overriding input_file changes only that one value."""
        # When: only input_file is overridden
        result = merge_config_with_args(base_config, input_file="override_input.xlsx")

        # Then: input_file is replaced and the other values are untouched
        assert result.files.input_file == "override_input.xlsx"
        assert result.files.category_file == "base_category.jsonl"
        assert result.api.base_url == "https://base.com"

    def test_merge_with_api_overrides(self, base_config):
        """Several API parameters can be overridden in one call."""
        # When: multiple API parameters are overridden
        api_overrides = {
            "api_key": "test_key",
            "base_url": "https://override.com",
            "model": "override-model",
        }
        result = merge_config_with_args(base_config, **api_overrides)

        # Then: the API section reflects every override
        assert result.api.api_key == "test_key"
        assert result.api.base_url == "https://override.com"
        assert result.api.model == "override-model"

    def test_merge_with_all_overrides(self, base_config):
        """Every overridable file and API parameter can be replaced at once."""
        # When: all parameters are overridden
        file_overrides = {
            "input_file": "new_input.xlsx",
            "category_file": "new_category.jsonl",
            "output_file": "new_output.jsonl",
            "temp_request_file": "new_temp.jsonl",
        }
        api_overrides = {
            "api_key": "new_key",
            "base_url": "https://new.com",
            "model": "new-model",
            "completion_window": "6h",
        }
        result = merge_config_with_args(
            base_config, **file_overrides, **api_overrides
        )

        # Then: both sections carry the new values
        for field_name, expected in file_overrides.items():
            assert getattr(result.files, field_name) == expected
        for field_name, expected in api_overrides.items():
            assert getattr(result.api, field_name) == expected

    def test_merge_with_none_values(self, base_config):
        """Passing None for an argument does not overwrite the base value."""
        # When: None is passed for some arguments
        result = merge_config_with_args(base_config, input_file=None, api_key=None)

        # Then: the base values are kept
        assert result.files.input_file == "base_input.xlsx"
        assert result.api.api_key is None  # was already None in the base config
class TestConfigIntegration:
    """Integration test: load a YAML config file, then merge overrides."""

    def test_full_config_workflow(self):
        """Load a config from disk and apply command-line style overrides.

        The YAML below must keep its two-space nesting so that each key
        belongs to its section; flattened, the sections would parse as null.
        Keys that are omitted (for example temp_request_file or base_url)
        are presumably filled with defaults by load_config; this test does
        not assert their values.
        """
        # Given: a temporary config file
        config_content = """
files:
  input_file: "input/products.xlsx"
  category_file: "config/categories.jsonl"
  output_file: "output/results.jsonl"
api:
  model: "qwen-flash"
  completion_window: "24h"
logging:
  log_level: "INFO"
batch:
  poll_interval: 60
"""
        # delete=False so the file can be reopened by path after the
        # `with` block has closed it.
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".yaml", delete=False, encoding="utf-8"
        ) as f:
            f.write(config_content)
            temp_path = f.name

        try:
            # When: the config is loaded
            config = load_config(temp_path)

            # Then: the values from the file are present
            assert config.files.input_file == "input/products.xlsx"
            assert config.api.model == "qwen-flash"

            # When: command-line arguments are merged in
            merged = merge_config_with_args(
                config, model="qwen-plus", api_key="override_key"
            )

            # Then: the overridden parameters are replaced
            assert merged.api.model == "qwen-plus"
            assert merged.api.api_key == "override_key"
            # ...and the other settings are unchanged
            assert merged.files.input_file == "input/products.xlsx"
        finally:
            # Always remove the temporary file, even when an assertion fails.
            if os.path.exists(temp_path):
                os.remove(temp_path)