CategorizeLabel/tests/test_main.py

405 lines
14 KiB
Python
Raw Normal View History

2025-10-15 17:19:26 +08:00
"""测试主程序
测试main.py的主流程和集成功能
"""
import logging
import os
import tempfile
from unittest.mock import MagicMock, patch

import pytest

from src.main import main, setup_logger, get_env_api_key, cleanup_temp_files
from src.config import AppConfig, FilesConfig, APIConfig
from src.models import (
    ProductInput,
    ProductCategory,
    ClassificationResult,
    BatchTaskInfo,
)
class TestHelperFunctions:
    """Tests for the helper functions exported by ``src.main``."""

    def test_setup_logger(self):
        """``setup_logger`` returns the INFO-level "product_classifier" logger
        with at least two handlers (intended: one file, one console)."""
        with tempfile.TemporaryDirectory() as temp_dir:
            logger = setup_logger(log_dir=temp_dir)
            try:
                assert logger.name == "product_classifier"
                assert logger.level == logging.INFO
                # Expected: a file handler plus a console handler.
                assert len(logger.handlers) >= 2
            finally:
                # Always detach and close the handlers, even when an assertion
                # fails: named loggers are shared process-wide, so leftover
                # handlers would leak into later tests, and an open log file
                # can prevent the temporary directory from being removed
                # (notably on Windows).
                for handler in logger.handlers[:]:
                    handler.close()
                    logger.removeHandler(handler)

    def test_get_env_api_key_success(self):
        """The API key is read from the DASHSCOPE_API_KEY environment variable."""
        with patch.dict(os.environ, {"DASHSCOPE_API_KEY": "test_key_123"}):
            api_key = get_env_api_key()
        assert api_key == "test_key_123"

    def test_get_env_api_key_not_set(self):
        """A missing DASHSCOPE_API_KEY raises a ValueError whose message
        contains "未设置" ("not set")."""
        with patch.dict(os.environ, {}, clear=True):
            with pytest.raises(ValueError, match="未设置"):
                get_env_api_key()

    def test_cleanup_temp_files(self):
        """``cleanup_temp_files`` deletes the file path it is given."""
        # Create (and close) a real file that the function must remove.
        with tempfile.NamedTemporaryFile(delete=False) as f:
            f.write(b"test content")
            temp_file = f.name
        try:
            cleanup_temp_files([temp_file], MagicMock())
            assert not os.path.exists(temp_file)
        finally:
            # Don't leave the file behind in the system temp directory if the
            # function under test failed to remove it.
            if os.path.exists(temp_file):
                os.remove(temp_file)
class TestMainFunction:
    """Tests for ``main`` with every collaborator in ``src.main`` mocked."""

    @pytest.fixture
    def mock_dependencies(self):
        """Patch the collaborators looked up in ``src.main`` and preconfigure a
        successful run.

        Yields a dict holding the mock *instances* returned by ``FileHandler``,
        ``BatchClient`` and ``PromptBuilder``. Tests may override their return
        values or side effects before calling ``main``.
        """
        with (
            # Patched only so main() never reaches the real implementations.
            patch("src.main.ProductReader"),
            patch("src.main.CategoryReader"),
            patch("src.main.BatchFileWriter"),
            patch("src.main.ResultParser"),
            patch("src.main.ResultWriter"),
            patch("src.main.FileHandler") as MockFileHandler,
            patch("src.main.BatchClient") as MockBatchClient,
            patch("src.main.PromptBuilder") as MockPromptBuilder,
        ):
            # Instances main() gets when it constructs each collaborator.
            mock_file_handler = MagicMock()
            MockFileHandler.return_value = mock_file_handler
            mock_batch_client = MagicMock()
            MockBatchClient.return_value = mock_batch_client
            mock_prompt_builder = MagicMock()
            MockPromptBuilder.return_value = mock_prompt_builder

            # Input data: two products and one category.
            mock_file_handler.read_products.return_value = [
                ProductInput(
                    product_id="P001", product_name="产品1", scenic_spot="景区1"
                ),
                ProductInput(
                    product_id="P002", product_name="产品2", scenic_spot="景区2"
                ),
            ]
            mock_file_handler.read_categories.return_value = [
                ProductCategory(category="门票", type="自然类", sub_type="自然风光")
            ]

            # Prompt / request construction.
            mock_prompt_builder.build_system_prompt.return_value = "系统提示词"
            mock_prompt_builder.build_classification_request.return_value = {
                "custom_id": "P001",
                "method": "POST",
                "url": "/v1/chat/completions",
                "body": {},
            }
            mock_file_handler.write_batch_requests.return_value = "temp.jsonl"

            # Batch lifecycle: upload -> create -> completed -> download.
            mock_batch_client.upload_file.return_value = "file-123"
            mock_batch_client.create_batch.return_value = "batch-456"
            mock_batch_client.wait_for_completion.return_value = BatchTaskInfo(
                task_id="batch-456",
                status="completed",
                input_file_id="file-123",
                output_file_id="file-output-789",
                error_file_id=None,
            )
            mock_batch_client.download_result.return_value = "result content"

            # Parsed classification results.
            mock_file_handler.parse_batch_responses.return_value = [
                ClassificationResult(
                    product_id="P001",
                    category="门票",
                    type="自然类",
                    sub_type="自然风光",
                )
            ]

            yield {
                "file_handler": mock_file_handler,
                "batch_client": mock_batch_client,
                "prompt_builder": mock_prompt_builder,
            }

    @staticmethod
    def _make_config(tmp_path, *, api_key="test_api_key", input_exists=True):
        """Build an ``AppConfig`` whose files all live under pytest's ``tmp_path``.

        The category and output files are created empty. The input file is
        created empty too, unless ``input_exists`` is False, in which case the
        config points at a path that does not exist. Keeping everything under
        ``tmp_path`` means files main() writes next to the output (such as an
        errors file) need no manual cleanup.
        """
        input_file = tmp_path / ("input.xlsx" if input_exists else "nonexistent.xlsx")
        category_file = tmp_path / "categories.jsonl"
        output_file = tmp_path / "output.jsonl"
        category_file.touch()
        output_file.touch()
        if input_exists:
            input_file.touch()
        return AppConfig(
            files=FilesConfig(
                input_file=str(input_file),
                category_file=str(category_file),
                output_file=str(output_file),
            ),
            api=APIConfig(api_key=api_key),
        )

    def test_main_success_flow(self, mock_dependencies, tmp_path):
        """A completed batch drives every pipeline step exactly once."""
        config = self._make_config(tmp_path)

        main(config)

        file_handler = mock_dependencies["file_handler"]
        batch_client = mock_dependencies["batch_client"]
        prompt_builder = mock_dependencies["prompt_builder"]
        file_handler.read_products.assert_called_once()
        file_handler.read_categories.assert_called_once()
        prompt_builder.build_system_prompt.assert_called_once()
        batch_client.upload_file.assert_called_once()
        batch_client.create_batch.assert_called_once()
        batch_client.wait_for_completion.assert_called_once()
        batch_client.download_result.assert_called_once()
        file_handler.parse_batch_responses.assert_called_once()
        file_handler.write_results.assert_called_once()

    def test_main_with_env_api_key(self, mock_dependencies, tmp_path):
        """With ``api_key=None``, main uses DASHSCOPE_API_KEY from the
        environment and runs to completion (a failure would raise)."""
        with patch.dict(os.environ, {"DASHSCOPE_API_KEY": "env_api_key"}):
            # Built inside the patch in case config construction reads the env.
            config = self._make_config(tmp_path, api_key=None)
            main(config)

    def test_main_batch_task_failed(self, mock_dependencies, tmp_path):
        """A batch that ends in "failed" makes main raise RuntimeError whose
        message contains "未成功完成" ("did not complete successfully")."""
        mock_dependencies["batch_client"].wait_for_completion.return_value = (
            BatchTaskInfo(
                task_id="batch-456",
                status="failed",
                input_file_id="file-123",
                output_file_id=None,
                error_file_id="file-error-999",
            )
        )
        config = self._make_config(tmp_path)

        with pytest.raises(RuntimeError, match="未成功完成"):
            main(config)

    def test_main_file_not_found(self, mock_dependencies, tmp_path):
        """A FileNotFoundError while reading products propagates out of main."""
        mock_dependencies["file_handler"].read_products.side_effect = FileNotFoundError(
            "文件不存在"
        )
        config = self._make_config(tmp_path, input_exists=False)

        with pytest.raises(FileNotFoundError):
            main(config)

    def test_main_with_error_file(self, mock_dependencies, tmp_path):
        """When the batch reports an error file, both the results and the
        errors are downloaded and an errors file is written."""
        batch_client = mock_dependencies["batch_client"]
        batch_client.wait_for_completion.return_value = BatchTaskInfo(
            task_id="batch-456",
            status="completed",
            input_file_id="file-123",
            output_file_id="file-output-789",
            error_file_id="file-error-999",
        )
        batch_client.download_result.side_effect = [
            "result content",  # first call: the result file
            "error content",  # second call: the error file
        ]
        config = self._make_config(tmp_path)

        main(config)

        assert batch_client.download_result.call_count == 2
        # The errors file is expected next to the output file, with ".jsonl"
        # replaced by "_errors.jsonl" (output.jsonl -> output_errors.jsonl).
        # It lives under tmp_path, so no manual cleanup is required.
        assert (tmp_path / "output_errors.jsonl").exists()