287 lines
8.8 KiB
Python
287 lines
8.8 KiB
Python
"""测试配置管理模块
|
||
|
||
测试配置文件加载和参数合并功能
|
||
"""
|
||
|
||
import pytest
|
||
import tempfile
|
||
import os
|
||
|
||
from src.config import (
|
||
load_config,
|
||
merge_config_with_args,
|
||
FilesConfig,
|
||
APIConfig,
|
||
LoggingConfig,
|
||
BatchConfig,
|
||
AppConfig,
|
||
)
|
||
|
||
|
||
class TestConfigModels:
    """Unit tests for the individual configuration models and their defaults."""

    def test_files_config_creation(self):
        """FilesConfig keeps every path it is constructed with."""
        paths = {
            "input_file": "test.xlsx",
            "category_file": "test.jsonl",
            "output_file": "output.jsonl",
            "temp_request_file": "temp.jsonl",
        }

        config = FilesConfig(**paths)

        # Every field must echo back exactly the value that was passed in.
        for field_name, expected in paths.items():
            assert getattr(config, field_name) == expected

    def test_api_config_with_defaults(self):
        """APIConfig() without arguments falls back to its default values."""
        config = APIConfig()

        assert config.api_key is None
        expected_defaults = (
            ("base_url", "https://dashscope.aliyuncs.com/compatible-mode/v1"),
            ("model", "qwen-flash"),
            ("completion_window", "24h"),
        )
        for field_name, expected in expected_defaults:
            assert getattr(config, field_name) == expected

    def test_logging_config_defaults(self):
        """LoggingConfig() defaults to the "logs" directory at INFO level."""
        config = LoggingConfig()

        assert (config.log_dir, config.log_level) == ("logs", "INFO")

    def test_batch_config_defaults(self):
        """BatchConfig() polls every 60 s and waits at most 86400 s."""
        config = BatchConfig()

        assert (config.poll_interval, config.max_wait_time) == (60, 86400)
|
||
|
||
|
||
class TestLoadConfig:
    """Tests for loading a YAML configuration file with ``load_config``."""

    @pytest.fixture
    def temp_config_file(self, tmp_path):
        """Write a fully populated YAML config file and return its path as a string.

        The file lives in pytest's per-test ``tmp_path`` directory, so pytest
        handles cleanup and nothing is left behind if a test fails midway.
        """
        config_content = """
        files:
            input_file: "input/test.xlsx"
            category_file: "config/test.jsonl"
            output_file: "output/test.jsonl"
            temp_request_file: "temp_test.jsonl"

        api:
            api_key: null
            base_url: "https://test.com"
            model: "qwen-plus"
            completion_window: "12h"

        logging:
            log_dir: "test_logs"
            log_level: "DEBUG"

        batch:
            poll_interval: 30
            max_wait_time: 3600
        """
        config_path = tmp_path / "config.yaml"
        config_path.write_text(config_content, encoding="utf-8")
        return str(config_path)

    def test_load_config_success(self, temp_config_file):
        """Every section of a valid file is loaded into the matching model."""
        config = load_config(temp_config_file)

        # files section
        assert config.files.input_file == "input/test.xlsx"
        assert config.files.category_file == "config/test.jsonl"
        assert config.files.output_file == "output/test.jsonl"
        assert config.files.temp_request_file == "temp_test.jsonl"
        # api section (api_key is not asserted: it is null in the file and
        # load_config may legitimately fill it from elsewhere)
        assert config.api.base_url == "https://test.com"
        assert config.api.model == "qwen-plus"
        assert config.api.completion_window == "12h"
        # logging section
        assert config.logging.log_dir == "test_logs"
        assert config.logging.log_level == "DEBUG"
        # batch section
        assert config.batch.poll_interval == 30
        assert config.batch.max_wait_time == 3600

    def test_load_config_file_not_found(self, tmp_path):
        """A missing file raises FileNotFoundError whose message says "does not exist"."""
        # Use a path inside the empty per-test directory so the outcome does not
        # depend on what happens to exist in the current working directory.
        missing_path = tmp_path / "nonexistent.yaml"

        with pytest.raises(FileNotFoundError, match="不存在"):
            load_config(str(missing_path))

    def test_load_config_invalid_yaml(self, tmp_path):
        """Malformed YAML is reported as ValueError mentioning a format error."""
        bad_path = tmp_path / "invalid.yaml"
        bad_path.write_text("invalid: yaml: content: [", encoding="utf-8")

        with pytest.raises(ValueError, match="格式错误"):
            load_config(str(bad_path))
|
||
|
||
|
||
class TestMergeConfig:
    """Tests for overlaying command-line arguments onto a loaded AppConfig."""

    @pytest.fixture
    def base_config(self):
        """Build a baseline AppConfig with distinctive values so overrides are detectable."""
        return AppConfig(
            files=FilesConfig(
                input_file="base_input.xlsx",
                category_file="base_category.jsonl",
                output_file="base_output.jsonl",
            ),
            api=APIConfig(
                api_key=None, base_url="https://base.com", model="base-model"
            ),
            logging=LoggingConfig(log_dir="base_logs"),
            batch=BatchConfig(),
        )

    def test_merge_with_no_overrides(self, base_config):
        """Merging with no arguments leaves the configuration unchanged."""
        merged = merge_config_with_args(base_config)

        assert merged.files.input_file == "base_input.xlsx"
        assert merged.files.category_file == "base_category.jsonl"
        assert merged.files.output_file == "base_output.jsonl"
        assert merged.api.base_url == "https://base.com"
        assert merged.api.model == "base-model"

    def test_merge_with_input_file_override(self, base_config):
        """Overriding input_file replaces only that field."""
        merged = merge_config_with_args(base_config, input_file="override_input.xlsx")

        assert merged.files.input_file == "override_input.xlsx"
        # Everything else is untouched.
        assert merged.files.category_file == "base_category.jsonl"
        assert merged.api.base_url == "https://base.com"

    def test_merge_with_api_overrides(self, base_config):
        """Several API arguments can be overridden in a single call."""
        merged = merge_config_with_args(
            base_config,
            api_key="test_key",
            base_url="https://override.com",
            model="override-model",
        )

        assert merged.api.api_key == "test_key"
        assert merged.api.base_url == "https://override.com"
        assert merged.api.model == "override-model"

    def test_merge_with_all_overrides(self, base_config):
        """Every supported argument overrides its corresponding field."""
        merged = merge_config_with_args(
            base_config,
            input_file="new_input.xlsx",
            category_file="new_category.jsonl",
            output_file="new_output.jsonl",
            temp_request_file="new_temp.jsonl",
            api_key="new_key",
            base_url="https://new.com",
            model="new-model",
            completion_window="6h",
        )

        assert merged.files.input_file == "new_input.xlsx"
        assert merged.files.category_file == "new_category.jsonl"
        assert merged.files.output_file == "new_output.jsonl"
        assert merged.files.temp_request_file == "new_temp.jsonl"
        assert merged.api.api_key == "new_key"
        assert merged.api.base_url == "https://new.com"
        assert merged.api.model == "new-model"
        assert merged.api.completion_window == "6h"

    def test_merge_with_none_values(self, base_config):
        """Arguments passed as None must not overwrite existing values."""
        # The fixture's api_key is already None, so without a real value first
        # the "None does not override api_key" check could never fail.
        with_key = merge_config_with_args(base_config, api_key="existing_key")

        merged = merge_config_with_args(
            with_key, input_file=None, api_key=None, model=None
        )

        assert merged.files.input_file == "base_input.xlsx"
        assert merged.api.api_key == "existing_key"
        assert merged.api.model == "base-model"
|
||
|
||
|
||
class TestConfigIntegration:
    """End-to-end test: load a YAML file, then apply command-line overrides."""

    def test_full_config_workflow(self, tmp_path):
        """A loaded config can be partially overridden while other values survive."""
        # Given: a minimal config file in pytest's per-test directory, which
        # pytest cleans up (no manual os.remove / try-finally needed).
        config_content = """
        files:
            input_file: "input/products.xlsx"
            category_file: "config/categories.jsonl"
            output_file: "output/results.jsonl"

        api:
            model: "qwen-flash"
            completion_window: "24h"

        logging:
            log_level: "INFO"

        batch:
            poll_interval: 60
        """
        config_path = tmp_path / "config.yaml"
        config_path.write_text(config_content, encoding="utf-8")

        # When: the file is loaded
        config = load_config(str(config_path))

        # Then: values come from the file
        assert config.files.input_file == "input/products.xlsx"
        assert config.api.model == "qwen-flash"

        # When: command-line arguments are merged on top
        merged = merge_config_with_args(
            config, model="qwen-plus", api_key="override_key"
        )

        # Then: the overridden values win...
        assert merged.api.model == "qwen-plus"
        assert merged.api.api_key == "override_key"
        # ...and values that were not overridden are preserved.
        assert merged.files.input_file == "input/products.xlsx"
        assert merged.files.output_file == "output/results.jsonl"
        assert merged.api.completion_window == "24h"
|