405 lines
14 KiB
Python
405 lines
14 KiB
Python
|
"""Tests for the main program.

Covers the main flow and the integration behaviour of main.py.
"""
|
|||
|
|
|
|||
|
|
import logging
import os
import tempfile
from unittest.mock import MagicMock, patch

import pytest

from src.config import AppConfig, FilesConfig, APIConfig
from src.main import main, setup_logger, get_env_api_key, cleanup_temp_files
from src.models import (
    ProductInput,
    ProductCategory,
    ClassificationResult,
    BatchTaskInfo,
)
|
|||
|
|
|
|||
|
|
|
|||
|
|
class TestHelperFunctions:
    """Tests for the helper functions imported from ``src.main``."""

    def test_setup_logger(self):
        """``setup_logger`` returns the configured ``product_classifier`` logger."""
        # Given & When: set up logging inside a throwaway directory
        with tempfile.TemporaryDirectory() as temp_dir:
            logger = setup_logger(log_dir=temp_dir)
            try:
                # Then: the logger should be configured correctly
                assert logger.name == "product_classifier"
                assert logger.level == logging.INFO
                assert len(logger.handlers) >= 2  # file and console handlers
            finally:
                # Close and detach the handlers even when an assertion fails,
                # so the log file is released before the temporary directory
                # is removed.
                for handler in logger.handlers[:]:
                    handler.close()
                    logger.removeHandler(handler)

    def test_get_env_api_key_success(self):
        """The API key is read from the ``DASHSCOPE_API_KEY`` environment variable."""
        # Given: the environment variable is set
        with patch.dict(os.environ, {"DASHSCOPE_API_KEY": "test_key_123"}):
            # When: the API key is fetched
            api_key = get_env_api_key()

            # Then: the configured value is returned
            assert api_key == "test_key_123"

    def test_get_env_api_key_not_set(self):
        """A missing environment variable raises ``ValueError``."""
        # Given: the environment is empty
        with patch.dict(os.environ, {}, clear=True):
            # When & Then: fetching the key should raise ValueError
            with pytest.raises(ValueError) as exc_info:
                get_env_api_key()

        # Checked after the ``raises`` block: a statement placed after the
        # raising call inside that block would never execute.
        assert "未设置" in str(exc_info.value)

    def test_cleanup_temp_files(self):
        """``cleanup_temp_files`` deletes the files it is given."""
        # Given: a temporary file that survives closing (delete=False)
        with tempfile.NamedTemporaryFile(delete=False) as f:
            temp_file = f.name
            f.write(b"test content")

        logger = MagicMock()

        try:
            # When: the file is cleaned up
            cleanup_temp_files([temp_file], logger)

            # Then: the file should be gone
            assert not os.path.exists(temp_file)
        finally:
            # Fallback so a failing run does not leave the file behind.
            if os.path.exists(temp_file):
                os.remove(temp_file)
|
|||
|
|
|
|||
|
|
|
|||
|
|
class TestMainFunction:
    """Tests for the ``main`` entry point with its collaborators mocked out."""

    @pytest.fixture
    def mock_dependencies(self):
        """Patch the collaborators referenced on ``src.main`` and preset a successful run.

        Yields:
            dict: the mock instances under the keys ``"file_handler"``,
            ``"batch_client"`` and ``"prompt_builder"``, so individual tests
            can override return values or inspect calls.
        """
        with (
            patch("src.main.ProductReader"),
            patch("src.main.CategoryReader"),
            patch("src.main.BatchFileWriter"),
            patch("src.main.ResultParser"),
            patch("src.main.ResultWriter"),
            patch("src.main.FileHandler") as MockFileHandler,
            patch("src.main.BatchClient") as MockBatchClient,
            patch("src.main.PromptBuilder") as MockPromptBuilder,
        ):
            # Instances returned when the patched classes are instantiated
            mock_file_handler = MagicMock()
            MockFileHandler.return_value = mock_file_handler

            mock_batch_client = MagicMock()
            MockBatchClient.return_value = mock_batch_client

            mock_prompt_builder = MagicMock()
            MockPromptBuilder.return_value = mock_prompt_builder

            # Default return values describing a fully successful batch run
            mock_file_handler.read_products.return_value = [
                ProductInput(
                    product_id="P001", product_name="产品1", scenic_spot="景区1"
                ),
                ProductInput(
                    product_id="P002", product_name="产品2", scenic_spot="景区2"
                ),
            ]

            mock_file_handler.read_categories.return_value = [
                ProductCategory(category="门票", type="自然类", sub_type="自然风光")
            ]

            mock_prompt_builder.build_system_prompt.return_value = "系统提示词"
            mock_prompt_builder.build_classification_request.return_value = {
                "custom_id": "P001",
                "method": "POST",
                "url": "/v1/chat/completions",
                "body": {},
            }

            mock_file_handler.write_batch_requests.return_value = "temp.jsonl"

            mock_batch_client.upload_file.return_value = "file-123"
            mock_batch_client.create_batch.return_value = "batch-456"
            mock_batch_client.wait_for_completion.return_value = BatchTaskInfo(
                task_id="batch-456",
                status="completed",
                input_file_id="file-123",
                output_file_id="file-output-789",
                error_file_id=None,
            )

            mock_batch_client.download_result.return_value = "result content"

            mock_file_handler.parse_batch_responses.return_value = [
                ClassificationResult(
                    product_id="P001",
                    category="门票",
                    type="自然类",
                    sub_type="自然风光",
                )
            ]

            yield {
                "file_handler": mock_file_handler,
                "batch_client": mock_batch_client,
                "prompt_builder": mock_prompt_builder,
            }

    @pytest.fixture
    def temp_files(self):
        """Create empty input/category/output files and delete them after the test.

        Yields:
            dict: file paths under the keys ``"input"`` (``.xlsx``),
            ``"category"`` (``.jsonl``) and ``"output"`` (``.jsonl``).
        """
        paths = {}
        try:
            for key, suffix in (
                ("input", ".xlsx"),
                ("category", ".jsonl"),
                ("output", ".jsonl"),
            ):
                # delete=False keeps the file on disk once the handle is
                # closed; the handle is closed right away so nothing holds
                # the file open while main() runs.
                with tempfile.NamedTemporaryFile(
                    mode="w", suffix=suffix, delete=False
                ) as handle:
                    paths[key] = handle.name
            yield paths
        finally:
            for path in paths.values():
                if os.path.exists(path):
                    os.remove(path)

    @staticmethod
    def _build_config(paths, api_key="test_api_key", input_file=None):
        """Build an ``AppConfig`` pointing at the temporary files in ``paths``.

        Args:
            paths: mapping with ``"input"``, ``"category"`` and ``"output"`` paths.
            api_key: value passed to ``APIConfig``; ``None`` is forwarded as-is.
            input_file: optional override for the input path.
        """
        return AppConfig(
            files=FilesConfig(
                input_file=paths["input"] if input_file is None else input_file,
                category_file=paths["category"],
                output_file=paths["output"],
            ),
            api=APIConfig(api_key=api_key),
        )

    def test_main_success_flow(self, mock_dependencies, temp_files):
        """A fully successful run calls every pipeline step exactly once."""
        # Given: every mock is configured for success
        config = self._build_config(temp_files)

        # When: main is executed
        main(config)

        # Then: every step should have been called
        file_handler = mock_dependencies["file_handler"]
        batch_client = mock_dependencies["batch_client"]
        prompt_builder = mock_dependencies["prompt_builder"]

        file_handler.read_products.assert_called_once()
        file_handler.read_categories.assert_called_once()
        prompt_builder.build_system_prompt.assert_called_once()
        batch_client.upload_file.assert_called_once()
        batch_client.create_batch.assert_called_once()
        batch_client.wait_for_completion.assert_called_once()
        batch_client.download_result.assert_called_once()
        file_handler.parse_batch_responses.assert_called_once()
        file_handler.write_results.assert_called_once()

    def test_main_with_env_api_key(self, mock_dependencies, temp_files):
        """``main`` runs when the config carries no API key but the environment does."""
        # Given: the API key is provided through the environment
        with patch.dict(os.environ, {"DASHSCOPE_API_KEY": "env_api_key"}):
            # api_key=None means the key is read from the environment
            config = self._build_config(temp_files, api_key=None)

            # When: main is executed
            # Then: it should complete (a failure would raise)
            main(config)

    def test_main_batch_task_failed(self, mock_dependencies, temp_files):
        """A batch task ending in ``failed`` status raises ``RuntimeError``."""
        # Given: the batch task reports a failed status
        mock_dependencies["batch_client"].wait_for_completion.return_value = (
            BatchTaskInfo(
                task_id="batch-456",
                status="failed",
                input_file_id="file-123",
                output_file_id=None,
                error_file_id="file-error-999",
            )
        )
        config = self._build_config(temp_files)

        # When & Then: main should raise RuntimeError
        with pytest.raises(RuntimeError) as exc_info:
            main(config)

        assert "未成功完成" in str(exc_info.value)

    def test_main_file_not_found(self, mock_dependencies, temp_files):
        """A ``FileNotFoundError`` from reading products propagates out of ``main``."""
        # Given: reading the products fails because the file is missing
        mock_dependencies["file_handler"].read_products.side_effect = FileNotFoundError(
            "文件不存在"
        )
        config = self._build_config(temp_files, input_file="nonexistent.xlsx")

        # When & Then: main should raise FileNotFoundError
        with pytest.raises(FileNotFoundError):
            main(config)

    def test_main_with_error_file(self, mock_dependencies, temp_files):
        """A completed task with an error file downloads both files and writes the errors."""
        # Given: the task completed but also produced an error file
        batch_client = mock_dependencies["batch_client"]
        batch_client.wait_for_completion.return_value = BatchTaskInfo(
            task_id="batch-456",
            status="completed",
            input_file_id="file-123",
            output_file_id="file-output-789",
            error_file_id="file-error-999",
        )
        batch_client.download_result.side_effect = [
            "result content",  # first call returns the result content
            "error content",  # second call returns the error content
        ]
        config = self._build_config(temp_files)
        error_file = temp_files["output"].replace(".jsonl", "_errors.jsonl")

        try:
            # When: main is executed
            main(config)

            # Then: download_result should have been called twice
            assert batch_client.download_result.call_count == 2

            # The error file should have been created
            assert os.path.exists(error_file)
        finally:
            # Remove the error file even when an assertion above fails.
            if os.path.exists(error_file):
                os.remove(error_file)
|