"""Tests for the Batch client module.

Exercises the BatchClient class: its private upload / create / status /
download / wait helpers, and the ILLMClient ``classify_products`` entry point.
"""

import json
from unittest.mock import MagicMock, patch

import pytest

from src.core.batch_client import BatchClient
from src.models import BatchTaskInfo, ClassificationResult, ProductInput


def _make_mock_batch(
    status,
    *,
    batch_id="batch-xyz789",
    input_file_id="file-abc123",
    output_file_id=None,
    error_file_id=None,
    request_counts=None,
):
    """Return a MagicMock standing in for a batch object from ``batches.retrieve``.

    Every attribute the tests rely on is assigned explicitly. A bare MagicMock
    would otherwise auto-create truthy child mocks for ``output_file_id``,
    ``error_file_id`` and ``request_counts`` instead of ``None``.
    """
    batch = MagicMock()
    batch.id = batch_id
    batch.status = status
    batch.input_file_id = input_file_id
    batch.output_file_id = output_file_id
    batch.error_file_id = error_file_id
    batch.request_counts = request_counts
    return batch


def _mock_openai_for_classify(mock_openai, result_text):
    """Configure a patched ``OpenAI`` class for a full ``classify_products`` run.

    The client instance returned by ``mock_openai()`` is set up as follows:
    - uploads yield file id ``file-123``;
    - batch creation yields batch id ``batch-456``;
    - status queries report that batch as ``completed`` with output file
      ``file-output-789``;
    - downloads return *result_text* as the file content.

    Returns the configured client instance so tests can assert on its calls.
    """
    client_instance = MagicMock()
    mock_openai.return_value = client_instance

    # File upload
    uploaded_file = MagicMock()
    uploaded_file.id = "file-123"
    client_instance.files.create.return_value = uploaded_file

    # Batch creation
    created_batch = MagicMock()
    created_batch.id = "batch-456"
    client_instance.batches.create.return_value = created_batch

    # Status query: report "completed" straight away so polling ends at once.
    client_instance.batches.retrieve.return_value = _make_mock_batch(
        "completed",
        batch_id="batch-456",
        input_file_id="file-123",
        output_file_id="file-output-789",
    )

    # Result download
    downloaded = MagicMock()
    downloaded.text = result_text
    client_instance.files.content.return_value = downloaded

    return client_instance


class TestBatchClient:
    """Tests for BatchClient's private helper methods."""

    @pytest.fixture
    def mock_openai_client(self):
        """Patch ``OpenAI`` inside batch_client and yield the client it returns."""
        with patch("src.core.batch_client.OpenAI") as mock_openai:
            mock_client = MagicMock()
            mock_openai.return_value = mock_client
            yield mock_client

    @pytest.fixture
    def batch_client(self, mock_openai_client):
        """Create a BatchClient backed by the patched OpenAI client."""
        return BatchClient(api_key="test_api_key")

    def test_init_batch_client(self, mock_openai_client):
        """Constructing a BatchClient wires in the (patched) OpenAI client."""
        # When: create a BatchClient with an explicit key and base URL
        client = BatchClient(api_key="test_key", base_url="https://test.com")

        # Then: because OpenAI is patched, client.client is the mock instance
        assert client.client is mock_openai_client

    def test_upload_file_success(self, batch_client, mock_openai_client):
        """_upload_file returns the id of the uploaded file."""
        # Given: files.create returns an object carrying the file id
        mock_file_object = MagicMock()
        mock_file_object.id = "file-abc123"
        mock_openai_client.files.create.return_value = mock_file_object

        # When: upload a file through the private helper
        file_id = batch_client._upload_file("test.jsonl")

        # Then: the id is returned and the API was called exactly once
        assert file_id == "file-abc123"
        mock_openai_client.files.create.assert_called_once()

    def test_upload_file_failure(self, batch_client, mock_openai_client):
        """An upload API error surfaces as an exception mentioning '上传文件失败'."""
        # Given: files.create raises
        mock_openai_client.files.create.side_effect = Exception("Upload failed")

        # When & Then: the helper raises with its own message
        with pytest.raises(Exception) as exc_info:
            batch_client._upload_file("test.jsonl")
        assert "上传文件失败" in str(exc_info.value)

    def test_create_batch_success(self, batch_client, mock_openai_client):
        """_create_batch forwards its arguments and returns the new batch id."""
        # Given: batches.create returns an object carrying the batch id
        mock_batch = MagicMock()
        mock_batch.id = "batch-xyz789"
        mock_openai_client.batches.create.return_value = mock_batch

        # When: create a batch task
        batch_id = batch_client._create_batch(
            input_file_id="file-abc123",
            endpoint="/v1/chat/completions",
            completion_window="24h",
        )

        # Then: the id is returned and the arguments are passed through as-is
        assert batch_id == "batch-xyz789"
        mock_openai_client.batches.create.assert_called_once_with(
            input_file_id="file-abc123",
            endpoint="/v1/chat/completions",
            completion_window="24h",
        )

    def test_create_batch_failure(self, batch_client, mock_openai_client):
        """A batch-creation API error surfaces as an exception mentioning '创建Batch任务失败'."""
        # Given: batches.create raises
        mock_openai_client.batches.create.side_effect = Exception("Create failed")

        # When & Then: the helper raises with its own message
        with pytest.raises(Exception) as exc_info:
            batch_client._create_batch(input_file_id="file-abc123")
        assert "创建Batch任务失败" in str(exc_info.value)

    def test_get_batch_status_success(self, batch_client, mock_openai_client):
        """_get_batch_status maps an in-progress batch onto BatchTaskInfo."""
        # Given: batches.retrieve returns an in-progress batch
        mock_openai_client.batches.retrieve.return_value = _make_mock_batch(
            "in_progress"
        )

        # When: query the task status
        task_info = batch_client._get_batch_status("batch-xyz789")

        # Then: a BatchTaskInfo with the mapped fields is returned
        assert isinstance(task_info, BatchTaskInfo)
        assert task_info.task_id == "batch-xyz789"
        assert task_info.status == "in_progress"
        assert task_info.input_file_id == "file-abc123"

    def test_get_batch_status_completed(self, batch_client, mock_openai_client):
        """A completed batch carries its output file id and request counts."""
        # Given: a completed batch with request_counts populated
        request_counts = MagicMock()
        request_counts.total = 100
        request_counts.completed = 100
        request_counts.failed = 0
        mock_openai_client.batches.retrieve.return_value = _make_mock_batch(
            "completed",
            output_file_id="file-output-123",
            request_counts=request_counts,
        )

        # When: query the task status
        task_info = batch_client._get_batch_status("batch-xyz789")

        # Then: status, output file and the request_counts mapping are exposed
        assert task_info.status == "completed"
        assert task_info.output_file_id == "file-output-123"
        assert task_info.request_counts is not None
        assert task_info.request_counts["total"] == 100
        assert task_info.request_counts["completed"] == 100

    def test_get_batch_status_failure(self, batch_client, mock_openai_client):
        """A status-query API error surfaces as an exception mentioning '查询任务状态失败'."""
        # Given: batches.retrieve raises
        mock_openai_client.batches.retrieve.side_effect = Exception("Retrieve failed")

        # When & Then: the helper raises with its own message
        with pytest.raises(Exception) as exc_info:
            batch_client._get_batch_status("batch-xyz789")
        assert "查询任务状态失败" in str(exc_info.value)

    def test_download_result_success(self, batch_client, mock_openai_client):
        """_download_result returns the text of the requested file."""
        # Given: files.content returns an object with a .text payload
        mock_content = MagicMock()
        mock_content.text = "result content"
        mock_openai_client.files.content.return_value = mock_content

        # When: download the result file
        content = batch_client._download_result("file-output-123")

        # Then: the text is returned and the right file id was requested
        assert content == "result content"
        mock_openai_client.files.content.assert_called_once_with("file-output-123")

    def test_download_result_failure(self, batch_client, mock_openai_client):
        """A download API error surfaces as an exception mentioning '下载结果文件失败'."""
        # Given: files.content raises
        mock_openai_client.files.content.side_effect = Exception("Download failed")

        # When & Then: the helper raises with its own message
        with pytest.raises(Exception) as exc_info:
            batch_client._download_result("file-output-123")
        assert "下载结果文件失败" in str(exc_info.value)

    def test_wait_for_completion_success(self, batch_client, mock_openai_client):
        """Polling stops and returns once the batch reports 'completed'."""
        # Given: first poll is in_progress, second poll is completed
        mock_openai_client.batches.retrieve.side_effect = [
            _make_mock_batch("in_progress"),
            _make_mock_batch("completed", output_file_id="file-output-123"),
        ]

        # When: wait for the task to finish
        task_info = batch_client._wait_for_completion(
            batch_id="batch-xyz789", poll_interval=0.1, max_wait_time=10
        )

        # Then: the completed task info is returned
        assert task_info.status == "completed"
        assert task_info.output_file_id == "file-output-123"

    def test_wait_for_completion_timeout(self, batch_client, mock_openai_client):
        """Polling past max_wait_time raises TimeoutError mentioning '超时'."""
        # Given: the batch never leaves in_progress
        mock_openai_client.batches.retrieve.return_value = _make_mock_batch(
            "in_progress"
        )

        # When & Then: a TimeoutError is raised
        with pytest.raises(TimeoutError) as exc_info:
            batch_client._wait_for_completion(
                batch_id="batch-xyz789", poll_interval=0.1, max_wait_time=0.5
            )
        assert "超时" in str(exc_info.value)

    def test_wait_for_completion_failed(self, batch_client, mock_openai_client):
        """A 'failed' batch raises an exception mentioning '执行失败'."""
        # Given: the batch reports failed, with an error file attached
        mock_openai_client.batches.retrieve.return_value = _make_mock_batch(
            "failed", error_file_id="file-error-123"
        )

        # When & Then: an exception is raised
        with pytest.raises(Exception) as exc_info:
            batch_client._wait_for_completion(
                batch_id="batch-xyz789", poll_interval=0.1, max_wait_time=10
            )
        assert "执行失败" in str(exc_info.value)

    def test_wait_for_completion_expired(self, batch_client, mock_openai_client):
        """An 'expired' batch raises an exception mentioning '超时过期'."""
        # Given: the batch reports expired
        mock_openai_client.batches.retrieve.return_value = _make_mock_batch("expired")

        # When & Then: an exception is raised
        with pytest.raises(Exception) as exc_info:
            batch_client._wait_for_completion(
                batch_id="batch-xyz789", poll_interval=0.1, max_wait_time=10
            )
        assert "超时过期" in str(exc_info.value)

    def test_wait_for_completion_cancelled(self, batch_client, mock_openai_client):
        """A 'cancelled' batch raises an exception mentioning '已取消'."""
        # Given: the batch reports cancelled
        mock_openai_client.batches.retrieve.return_value = _make_mock_batch(
            "cancelled"
        )

        # When & Then: an exception is raised
        with pytest.raises(Exception) as exc_info:
            batch_client._wait_for_completion(
                batch_id="batch-xyz789", poll_interval=0.1, max_wait_time=10
            )
        assert "已取消" in str(exc_info.value)


class TestBatchClientILLMClient:
    """Tests for BatchClient's implementation of the ILLMClient interface."""

    @pytest.fixture
    def mock_file_handler(self):
        """Mock file handler."""
        return MagicMock()

    @pytest.fixture
    def mock_result_parser(self):
        """Mock result parser."""
        return MagicMock()

    @pytest.fixture
    def mock_logger(self):
        """Mock logger."""
        return MagicMock()

    @pytest.fixture
    def sample_products(self):
        """A single sample product to classify."""
        return [
            ProductInput(
                product_id="P001", product_name="黄山风景区门票", scenic_spot="黄山"
            )
        ]

    @pytest.fixture
    def sample_requests(self):
        """One batch request line matching the sample product."""
        return [
            {
                "custom_id": "P001",
                "method": "POST",
                "url": "/v1/chat/completions",
                "body": {
                    "model": "qwen-flash",
                    "messages": [
                        {"role": "system", "content": "你是助手"},
                        {"role": "user", "content": "分类产品"},
                    ],
                },
            }
        ]

    @patch("src.core.batch_client.OpenAI")
    def test_classify_products_success(
        self,
        mock_openai,
        sample_products,
        sample_requests,
        mock_file_handler,
        mock_result_parser,
        mock_logger,
    ):
        """classify_products runs the whole batch pipeline with injected collaborators."""
        # Given: the OpenAI client walks through upload -> create -> completed -> download
        mock_client_instance = _mock_openai_for_classify(
            mock_openai,
            '{"id":"1","custom_id":"P001","response":{"body":{"choices":[{"message":{"content":"{\\"category\\":\\"门票\\",\\"type\\":\\"自然类\\",\\"sub_type\\":\\"自然风光\\"}"}}]}},"error":null}',
        )

        # Given: the injected result parser yields one classification
        mock_result = ClassificationResult(
            product_id="P001", category="门票", type="自然类", sub_type="自然风光"
        )
        mock_result_parser.parse.return_value = [mock_result]

        client = BatchClient(
            api_key="test-key",
            base_url="https://test.com",
            file_handler=mock_file_handler,
            result_parser=mock_result_parser,
            logger=mock_logger,
            poll_interval=0.1,
            max_wait_time=10,
        )

        # When: classify the products
        results = client.classify_products(
            sample_products, "system prompt", sample_requests
        )

        # Then: the parsed result is returned
        assert len(results) == 1
        assert results[0].product_id == "P001"
        assert results[0].category == "门票"

        # ...and each collaborator was used exactly once
        mock_file_handler.write_batch_requests.assert_called_once()
        mock_client_instance.files.create.assert_called_once()
        mock_client_instance.batches.create.assert_called_once()
        mock_result_parser.parse.assert_called_once()

    @patch("src.core.batch_client.OpenAI")
    @patch("builtins.open", create=True)
    def test_classify_products_without_file_handler(
        self, mock_open, mock_openai, sample_products, sample_requests, mock_logger
    ):
        """Without file_handler/result_parser, the client falls back to its built-in writing and parsing."""
        # NOTE(review): builtins.open is patched process-wide for this test,
        # presumably so the fallback writer never touches the real filesystem.
        # Given: a real JSONL result line produced with json.dumps
        result_json = json.dumps(
            {
                "id": "1",
                "custom_id": "P001",
                "response": {
                    "body": {
                        "choices": [
                            {
                                "message": {
                                    "content": json.dumps(
                                        {
                                            "category": "门票",
                                            "type": "自然类",
                                            "sub_type": "自然风光",
                                        }
                                    )
                                }
                            }
                        ]
                    }
                },
                "error": None,
            }
        )
        _mock_openai_for_classify(mock_openai, result_json)

        # Given: a client with no file_handler and no result_parser
        client = BatchClient(
            api_key="test-key",
            base_url="https://test.com",
            logger=mock_logger,
            poll_interval=0.1,
            max_wait_time=10,
        )

        # When: classify the products
        results = client.classify_products(
            sample_products, "system prompt", sample_requests
        )

        # Then: the built-in parser decodes the result line
        assert len(results) == 1
        assert results[0].product_id == "P001"
        assert results[0].category == "门票"