"""Tests for the main program.

Covers the main.py entry flow and its integration with the file, prompt and
batch-client collaborators (all of which are mocked).
"""

import logging
import os
import tempfile
from unittest.mock import MagicMock, patch

import pytest

from src.main import main, setup_logger, get_env_api_key, cleanup_temp_files
from src.config import AppConfig, FilesConfig, APIConfig
from src.models import (
    ProductInput,
    ProductCategory,
    ClassificationResult,
    BatchTaskInfo,
)


class TestHelperFunctions:
    """Tests for the helper functions in src.main."""

    def test_setup_logger(self):
        """setup_logger returns the named INFO logger with file + console handlers."""
        # Given & When: configure logging into a throwaway directory
        with tempfile.TemporaryDirectory() as temp_dir:
            logger = setup_logger(log_dir=temp_dir)
            try:
                # Then: the logger is configured as expected
                assert logger.name == "product_classifier"
                assert logger.level == logging.INFO
                assert len(logger.handlers) >= 2  # file and console handlers
            finally:
                # Always release the handlers, even if an assertion failed:
                # an open file handler would block removal of temp_dir on
                # some platforms, and the named logger is process-global so
                # leftover handlers would leak into later tests.
                for handler in logger.handlers[:]:
                    handler.close()
                    logger.removeHandler(handler)

    def test_get_env_api_key_success(self):
        """get_env_api_key returns DASHSCOPE_API_KEY when it is set."""
        # Given: the environment variable is set
        with patch.dict(os.environ, {"DASHSCOPE_API_KEY": "test_key_123"}):
            # When: read the API key
            api_key = get_env_api_key()

        # Then: the configured value is returned
        assert api_key == "test_key_123"

    def test_get_env_api_key_not_set(self):
        """get_env_api_key raises ValueError when the variable is missing."""
        # Given: an empty environment
        with patch.dict(os.environ, {}, clear=True):
            # When & Then: a ValueError is raised
            with pytest.raises(ValueError) as exc_info:
                get_env_api_key()

        # The message must be checked outside pytest.raises: code after the
        # raising call inside that block would never execute.
        assert "未设置" in str(exc_info.value)

    def test_cleanup_temp_files(self, tmp_path):
        """cleanup_temp_files deletes the files it is given."""
        # Given: an existing temporary file (inside tmp_path, so it cannot
        # leak even if cleanup_temp_files fails)
        temp_file = tmp_path / "to_delete.tmp"
        temp_file.write_bytes(b"test content")
        logger = MagicMock()

        # When: clean it up
        cleanup_temp_files([str(temp_file)], logger)

        # Then: the file is gone
        assert not os.path.exists(temp_file)


def _make_config(tmp_path, api_key="test_api_key", input_file=None):
    """Build an AppConfig whose input/category/output files live in tmp_path.

    Empty category and output files are always created. The input file is
    created too unless ``input_file`` is given explicitly (used to point at a
    path that does not exist). Because everything lives under pytest's
    ``tmp_path``, no manual cleanup is needed even when an assertion fails.

    Args:
        tmp_path: pytest's per-test temporary directory (a ``pathlib.Path``).
        api_key: value for ``APIConfig.api_key``; ``None`` means the key is
            read from the environment.
        input_file: optional explicit input path; ``None`` creates an empty
            ``input.xlsx`` in tmp_path.

    Returns:
        AppConfig: configuration ready to pass to ``main``.
    """
    if input_file is None:
        input_path = tmp_path / "input.xlsx"
        input_path.touch()
        input_file = str(input_path)

    category_path = tmp_path / "categories.jsonl"
    output_path = tmp_path / "output.jsonl"
    category_path.touch()
    output_path.touch()

    return AppConfig(
        files=FilesConfig(
            input_file=input_file,
            category_file=str(category_path),
            output_file=str(output_path),
        ),
        api=APIConfig(api_key=api_key),
    )


class TestMainFunction:
    """Tests for main()."""

    @pytest.fixture
    def mock_dependencies(self, tmp_path):
        """Patch every collaborator of src.main and pre-configure a happy path.

        Yields a dict with the mocked ``file_handler``, ``batch_client`` and
        ``prompt_builder`` instances so individual tests can override return
        values or inspect calls.
        """
        with (
            patch("src.main.ProductReader"),
            patch("src.main.CategoryReader"),
            patch("src.main.BatchFileWriter"),
            patch("src.main.ResultParser"),
            patch("src.main.ResultWriter"),
            patch("src.main.FileHandler") as MockFileHandler,
            patch("src.main.BatchClient") as MockBatchClient,
            patch("src.main.PromptBuilder") as MockPromptBuilder,
        ):
            # Mock instances returned by the patched constructors
            mock_file_handler = MagicMock()
            MockFileHandler.return_value = mock_file_handler
            mock_batch_client = MagicMock()
            MockBatchClient.return_value = mock_batch_client
            mock_prompt_builder = MagicMock()
            MockPromptBuilder.return_value = mock_prompt_builder

            # Input data
            mock_file_handler.read_products.return_value = [
                ProductInput(
                    product_id="P001", product_name="产品1", scenic_spot="景区1"
                ),
                ProductInput(
                    product_id="P002", product_name="产品2", scenic_spot="景区2"
                ),
            ]
            mock_file_handler.read_categories.return_value = [
                ProductCategory(category="门票", type="自然类", sub_type="自然风光")
            ]

            # Prompt construction
            mock_prompt_builder.build_system_prompt.return_value = "系统提示词"
            mock_prompt_builder.build_classification_request.return_value = {
                "custom_id": "P001",
                "method": "POST",
                "url": "/v1/chat/completions",
                "body": {},
            }

            # NOTE(review): main presumably hands this path to
            # cleanup_temp_files; keeping it inside tmp_path ensures a real
            # ./temp.jsonl in the working directory can never be deleted.
            mock_file_handler.write_batch_requests.return_value = str(
                tmp_path / "temp.jsonl"
            )

            # Batch API lifecycle
            mock_batch_client.upload_file.return_value = "file-123"
            mock_batch_client.create_batch.return_value = "batch-456"
            mock_batch_client.wait_for_completion.return_value = BatchTaskInfo(
                task_id="batch-456",
                status="completed",
                input_file_id="file-123",
                output_file_id="file-output-789",
                error_file_id=None,
            )
            mock_batch_client.download_result.return_value = "result content"

            # Result parsing
            mock_file_handler.parse_batch_responses.return_value = [
                ClassificationResult(
                    product_id="P001",
                    category="门票",
                    type="自然类",
                    sub_type="自然风光",
                )
            ]

            yield {
                "file_handler": mock_file_handler,
                "batch_client": mock_batch_client,
                "prompt_builder": mock_prompt_builder,
            }

    def test_main_success_flow(self, mock_dependencies, tmp_path):
        """The full pipeline runs every step exactly once."""
        # Given: every mock is configured for success
        config = _make_config(tmp_path)

        # When: run main
        main(config)

        # Then: every step was invoked
        file_handler = mock_dependencies["file_handler"]
        batch_client = mock_dependencies["batch_client"]
        prompt_builder = mock_dependencies["prompt_builder"]

        file_handler.read_products.assert_called_once()
        file_handler.read_categories.assert_called_once()
        prompt_builder.build_system_prompt.assert_called_once()
        batch_client.upload_file.assert_called_once()
        batch_client.create_batch.assert_called_once()
        batch_client.wait_for_completion.assert_called_once()
        batch_client.download_result.assert_called_once()
        file_handler.parse_batch_responses.assert_called_once()
        file_handler.write_results.assert_called_once()

    def test_main_with_env_api_key(self, mock_dependencies, tmp_path):
        """main runs when the API key comes from the environment."""
        # Given: the API key is only available via the environment
        with patch.dict(os.environ, {"DASHSCOPE_API_KEY": "env_api_key"}):
            # api_key=None means the key is read from the environment
            config = _make_config(tmp_path, api_key=None)

            # When: run main
            main(config)

        # Then: the pipeline completed through to the final write step
        mock_dependencies["file_handler"].write_results.assert_called_once()

    def test_main_batch_task_failed(self, mock_dependencies, tmp_path):
        """A batch task ending in 'failed' makes main raise RuntimeError."""
        # Given: the batch task reports a failed status
        mock_dependencies["batch_client"].wait_for_completion.return_value = (
            BatchTaskInfo(
                task_id="batch-456",
                status="failed",
                input_file_id="file-123",
                output_file_id=None,
                error_file_id="file-error-999",
            )
        )
        config = _make_config(tmp_path)

        # When & Then: RuntimeError is raised
        with pytest.raises(RuntimeError) as exc_info:
            main(config)
        assert "未成功完成" in str(exc_info.value)

    def test_main_file_not_found(self, mock_dependencies, tmp_path):
        """A FileNotFoundError while reading products propagates out of main."""
        # Given: reading the input file fails
        mock_dependencies["file_handler"].read_products.side_effect = (
            FileNotFoundError("文件不存在")
        )
        # A path under tmp_path is guaranteed not to exist
        config = _make_config(
            tmp_path, input_file=str(tmp_path / "nonexistent.xlsx")
        )

        # When & Then: FileNotFoundError is raised
        with pytest.raises(FileNotFoundError):
            main(config)

    def test_main_with_error_file(self, mock_dependencies, tmp_path):
        """A completed task with an error file downloads and saves both files."""
        # Given: the task completed but also produced an error file
        batch_client = mock_dependencies["batch_client"]
        batch_client.wait_for_completion.return_value = BatchTaskInfo(
            task_id="batch-456",
            status="completed",
            input_file_id="file-123",
            output_file_id="file-output-789",
            error_file_id="file-error-999",
        )
        batch_client.download_result.side_effect = [
            "result content",  # first call: result file content
            "error content",  # second call: error file content
        ]
        config = _make_config(tmp_path)

        # When: run main
        main(config)

        # Then: download_result was called twice
        assert batch_client.download_result.call_count == 2

        # and the error file was written next to the output file. It lives
        # in tmp_path, so it is cleaned up even if this assertion fails.
        output_file = str(tmp_path / "output.jsonl")  # path from _make_config
        error_file = output_file.replace(".jsonl", "_errors.jsonl")
        assert os.path.exists(error_file)