CategorizeLabel/tests/test_main.py
2025-10-15 17:19:26 +08:00

405 lines
14 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""测试主程序
测试main.py的主流程和集成功能
"""
import logging
import os
import tempfile
from unittest.mock import MagicMock, patch

import pytest

from src.config import AppConfig, FilesConfig, APIConfig
from src.main import main, setup_logger, get_env_api_key, cleanup_temp_files
from src.models import (
    ProductInput,
    ProductCategory,
    ClassificationResult,
    BatchTaskInfo,
)
class TestHelperFunctions:
    """Unit tests for the helper functions imported from ``src.main``."""

    def test_setup_logger(self):
        """``setup_logger`` returns a logger with the expected name, level and handlers."""
        with tempfile.TemporaryDirectory() as temp_dir:
            # Given & When: the logger is set up against a throwaway log directory
            logger = setup_logger(log_dir=temp_dir)
            try:
                # Then: the logger should be configured as expected
                assert logger.name == "product_classifier"
                assert logger.level == logging.INFO
                # The test expects at least a file handler and a console handler.
                assert len(logger.handlers) >= 2
            finally:
                # Detach and close the handlers even when an assertion fails, so
                # the log file is released before the temporary directory is
                # removed and no handlers linger on the logger for later tests.
                for handler in logger.handlers[:]:
                    handler.close()
                    logger.removeHandler(handler)

    def test_get_env_api_key_success(self):
        """The API key is read from the ``DASHSCOPE_API_KEY`` environment variable."""
        # Given: the environment variable is set
        with patch.dict(os.environ, {"DASHSCOPE_API_KEY": "test_key_123"}):
            # When: the API key is fetched
            api_key = get_env_api_key()
            # Then: the configured value is returned
            assert api_key == "test_key_123"

    def test_get_env_api_key_not_set(self):
        """A ``ValueError`` is raised when the environment holds no API key."""
        # Given: an empty environment
        with patch.dict(os.environ, {}, clear=True):
            # When & Then: fetching the key raises ValueError
            with pytest.raises(ValueError) as exc_info:
                get_env_api_key()
        assert "未设置" in str(exc_info.value)

    def test_cleanup_temp_files(self):
        """``cleanup_temp_files`` deletes the files it is given."""
        # Given: a temporary file that survives closing (delete=False)
        with tempfile.NamedTemporaryFile(delete=False) as f:
            temp_file = f.name
            f.write(b"test content")
        logger = MagicMock()
        try:
            # When: the file is cleaned up
            cleanup_temp_files([temp_file], logger)
            # Then: the file no longer exists
            assert not os.path.exists(temp_file)
        finally:
            # Fallback so a failing cleanup does not leak the file on disk.
            if os.path.exists(temp_file):
                os.remove(temp_file)
class TestMainFunction:
    """Tests for ``main`` with its collaborators in ``src.main`` replaced by mocks."""

    @staticmethod
    def _make_config(files, api_key="test_api_key"):
        """Build an ``AppConfig`` from a mapping with 'input', 'category' and 'output' paths."""
        return AppConfig(
            files=FilesConfig(
                input_file=files["input"],
                category_file=files["category"],
                output_file=files["output"],
            ),
            api=APIConfig(api_key=api_key),
        )

    @pytest.fixture
    def work_files(self, tmp_path):
        """Create empty input/category/output files and return their paths as strings.

        The files live under pytest's ``tmp_path``, so they (and anything
        ``main`` writes next to them) are cleaned up by pytest without any
        per-test ``try/finally`` bookkeeping.
        """
        paths = {
            "input": tmp_path / "input.xlsx",
            "category": tmp_path / "categories.jsonl",
            "output": tmp_path / "output.jsonl",
        }
        for path in paths.values():
            path.touch()
        return {key: str(path) for key, path in paths.items()}

    @pytest.fixture
    def mock_dependencies(self):
        """Patch the collaborators of ``main`` and preconfigure a successful run.

        Yields a dict with the ``file_handler``, ``batch_client`` and
        ``prompt_builder`` mock instances so tests can adjust return values
        and inspect calls. The patches stay active for the whole test.
        """
        with (
            patch("src.main.ProductReader"),
            patch("src.main.CategoryReader"),
            patch("src.main.BatchFileWriter"),
            patch("src.main.ResultParser"),
            patch("src.main.ResultWriter"),
            patch("src.main.FileHandler") as MockFileHandler,
            patch("src.main.BatchClient") as MockBatchClient,
            patch("src.main.PromptBuilder") as MockPromptBuilder,
        ):
            # Instances returned when main() instantiates the patched classes
            mock_file_handler = MagicMock()
            MockFileHandler.return_value = mock_file_handler
            mock_batch_client = MagicMock()
            MockBatchClient.return_value = mock_batch_client
            mock_prompt_builder = MagicMock()
            MockPromptBuilder.return_value = mock_prompt_builder

            # Default return values describing a fully successful run
            mock_file_handler.read_products.return_value = [
                ProductInput(
                    product_id="P001", product_name="产品1", scenic_spot="景区1"
                ),
                ProductInput(
                    product_id="P002", product_name="产品2", scenic_spot="景区2"
                ),
            ]
            mock_file_handler.read_categories.return_value = [
                ProductCategory(category="门票", type="自然类", sub_type="自然风光")
            ]
            mock_prompt_builder.build_system_prompt.return_value = "系统提示词"
            mock_prompt_builder.build_classification_request.return_value = {
                "custom_id": "P001",
                "method": "POST",
                "url": "/v1/chat/completions",
                "body": {},
            }
            mock_file_handler.write_batch_requests.return_value = "temp.jsonl"
            mock_batch_client.upload_file.return_value = "file-123"
            mock_batch_client.create_batch.return_value = "batch-456"
            mock_batch_client.wait_for_completion.return_value = BatchTaskInfo(
                task_id="batch-456",
                status="completed",
                input_file_id="file-123",
                output_file_id="file-output-789",
                error_file_id=None,
            )
            mock_batch_client.download_result.return_value = "result content"
            mock_file_handler.parse_batch_responses.return_value = [
                ClassificationResult(
                    product_id="P001",
                    category="门票",
                    type="自然类",
                    sub_type="自然风光",
                )
            ]
            yield {
                "file_handler": mock_file_handler,
                "batch_client": mock_batch_client,
                "prompt_builder": mock_prompt_builder,
            }

    def test_main_success_flow(self, mock_dependencies, work_files):
        """A successful run calls every pipeline step exactly once."""
        # Given: all mocks are configured for success
        config = self._make_config(work_files)

        # When: main is executed
        main(config)

        # Then: every step should have been invoked once
        file_handler = mock_dependencies["file_handler"]
        batch_client = mock_dependencies["batch_client"]
        prompt_builder = mock_dependencies["prompt_builder"]
        file_handler.read_products.assert_called_once()
        file_handler.read_categories.assert_called_once()
        prompt_builder.build_system_prompt.assert_called_once()
        batch_client.upload_file.assert_called_once()
        batch_client.create_batch.assert_called_once()
        batch_client.wait_for_completion.assert_called_once()
        batch_client.download_result.assert_called_once()
        file_handler.parse_batch_responses.assert_called_once()
        file_handler.write_results.assert_called_once()

    def test_main_with_env_api_key(self, mock_dependencies, work_files):
        """``main`` runs when the config has no API key but the environment does."""
        # Given: the API key is only available through the environment
        with patch.dict(os.environ, {"DASHSCOPE_API_KEY": "env_api_key"}):
            # api_key=None means "read the key from the environment"
            config = self._make_config(work_files, api_key=None)

            # When: main is executed
            # Then: it completes (a failure would surface as an exception)
            main(config)

    def test_main_batch_task_failed(self, mock_dependencies, work_files):
        """A batch task ending in ``failed`` makes ``main`` raise ``RuntimeError``."""
        # Given: the batch task reports a failed status
        mock_dependencies["batch_client"].wait_for_completion.return_value = (
            BatchTaskInfo(
                task_id="batch-456",
                status="failed",
                input_file_id="file-123",
                output_file_id=None,
                error_file_id="file-error-999",
            )
        )
        config = self._make_config(work_files)

        # When & Then: main raises RuntimeError with the expected message
        with pytest.raises(RuntimeError) as exc_info:
            main(config)
        assert "未成功完成" in str(exc_info.value)

    def test_main_file_not_found(self, mock_dependencies, work_files, tmp_path):
        """A ``FileNotFoundError`` from reading products propagates out of ``main``."""
        # Given: reading the products fails because the file is missing
        mock_dependencies["file_handler"].read_products.side_effect = FileNotFoundError(
            "文件不存在"
        )
        # Point the input at a path under tmp_path that is never created, so the
        # test does not depend on the current working directory.
        files = dict(work_files, input=str(tmp_path / "nonexistent.xlsx"))
        config = self._make_config(files)

        # When & Then: main raises FileNotFoundError
        with pytest.raises(FileNotFoundError):
            main(config)

    def test_main_with_error_file(self, mock_dependencies, work_files):
        """A completed task with an error file triggers a second download and an error file on disk."""
        # Given: the task completed but also produced an error file
        mock_dependencies["batch_client"].wait_for_completion.return_value = (
            BatchTaskInfo(
                task_id="batch-456",
                status="completed",
                input_file_id="file-123",
                output_file_id="file-output-789",
                error_file_id="file-error-999",
            )
        )
        mock_dependencies["batch_client"].download_result.side_effect = [
            "result content",  # first call: the result payload
            "error content",  # second call: the error payload
        ]
        config = self._make_config(work_files)

        # When: main is executed
        main(config)

        # Then: download_result is called twice (results + errors)
        assert mock_dependencies["batch_client"].download_result.call_count == 2
        # And the error file exists next to the output file; it lives under
        # tmp_path, so pytest removes it even if an assertion above fails.
        error_file = work_files["output"].replace(".jsonl", "_errors.jsonl")
        assert os.path.exists(error_file)