# Source: CategorizeLabel/tests/test_batch_client.py
# Snapshot: 2025-10-15 17:19:26 +08:00 (489 lines, 18 KiB, Python)
#
# Note: the repository viewer flagged this file as containing ambiguous
# Unicode characters (characters that can be confused with other characters).
# If that is intentional, the warning can be ignored.

"""测试Batch客户端模块
测试BatchClient类的所有功能
"""
import pytest
from unittest.mock import MagicMock, patch
from src.core.batch_client import BatchClient
from src.models import BatchTaskInfo
class TestBatchClient:
    """Unit tests for the private helper methods of ``BatchClient``.

    ``OpenAI`` is patched inside ``src.core.batch_client`` for every test, so
    the tests only observe how ``BatchClient`` talks to the mocked SDK client.
    """

    @staticmethod
    def _make_batch(
        status, output_file_id=None, error_file_id=None, request_counts=None
    ):
        """Build a mock batch object to hand back from ``batches.retrieve``.

        Args:
            status: Value for the mock's ``status`` attribute.
            output_file_id: Value for ``output_file_id`` (``None`` by default).
            error_file_id: Value for ``error_file_id`` (``None`` by default).
            request_counts: Value for ``request_counts`` (``None`` by default).

        Returns:
            A ``MagicMock`` with ``id`` fixed to ``"batch-xyz789"`` and
            ``input_file_id`` fixed to ``"file-abc123"``.
        """
        batch = MagicMock()
        batch.id = "batch-xyz789"
        batch.status = status
        batch.input_file_id = "file-abc123"
        batch.output_file_id = output_file_id
        batch.error_file_id = error_file_id
        batch.request_counts = request_counts
        return batch

    @pytest.fixture
    def mock_openai_client(self):
        """Patch ``OpenAI`` in the module under test and yield the mock client."""
        with patch("src.core.batch_client.OpenAI") as mock_openai:
            mock_client = MagicMock()
            mock_openai.return_value = mock_client
            yield mock_client

    @pytest.fixture
    def batch_client(self, mock_openai_client):
        """Create a ``BatchClient`` while the ``OpenAI`` patch is active."""
        return BatchClient(api_key="test_api_key")

    def test_init_batch_client(self, mock_openai_client):
        """Creating a ``BatchClient`` should initialise the OpenAI client."""
        # A local patch gives access to the patched class itself; the fixture
        # only exposes the instance it returns.
        with patch("src.core.batch_client.OpenAI") as mock_openai:
            # When: a BatchClient is created
            BatchClient(api_key="test_key", base_url="https://test.com")

        # Then: the OpenAI client was constructed with the given settings.
        # NOTE(review): assumes BatchClient builds its OpenAI client eagerly in
        # __init__ and forwards api_key/base_url as keywords - confirm against
        # src.core.batch_client.
        mock_openai.assert_called_once()
        init_kwargs = mock_openai.call_args.kwargs
        assert init_kwargs.get("api_key") == "test_key"
        assert init_kwargs.get("base_url") == "https://test.com"

    def test_upload_file_success(self, batch_client, mock_openai_client):
        """``_upload_file`` returns the id of the uploaded file."""
        # Given: the SDK returns a file object carrying an id
        mock_file_object = MagicMock()
        mock_file_object.id = "file-abc123"
        mock_openai_client.files.create.return_value = mock_file_object

        # When: the file is uploaded through the private method
        file_id = batch_client._upload_file("test.jsonl")

        # Then: the file id is returned
        assert file_id == "file-abc123"
        mock_openai_client.files.create.assert_called_once()

    def test_upload_file_failure(self, batch_client, mock_openai_client):
        """``_upload_file`` raises with an upload-failure message on SDK errors."""
        # Given: the SDK raises
        mock_openai_client.files.create.side_effect = Exception("Upload failed")

        # When & Then: an exception with the expected message is raised
        with pytest.raises(Exception, match="上传文件失败"):
            batch_client._upload_file("test.jsonl")

    def test_create_batch_success(self, batch_client, mock_openai_client):
        """``_create_batch`` returns the batch id and forwards its arguments."""
        # Given: the SDK returns a batch object carrying an id
        mock_batch = MagicMock()
        mock_batch.id = "batch-xyz789"
        mock_openai_client.batches.create.return_value = mock_batch

        # When: the batch task is created
        batch_id = batch_client._create_batch(
            input_file_id="file-abc123",
            endpoint="/v1/chat/completions",
            completion_window="24h",
        )

        # Then: the batch id is returned
        assert batch_id == "batch-xyz789"
        mock_openai_client.batches.create.assert_called_once_with(
            input_file_id="file-abc123",
            endpoint="/v1/chat/completions",
            completion_window="24h",
        )

    def test_create_batch_failure(self, batch_client, mock_openai_client):
        """``_create_batch`` raises with a creation-failure message on SDK errors."""
        # Given: the SDK raises
        mock_openai_client.batches.create.side_effect = Exception("Create failed")

        # When & Then: an exception with the expected message is raised
        with pytest.raises(Exception, match="创建Batch任务失败"):
            batch_client._create_batch(input_file_id="file-abc123")

    def test_get_batch_status_success(self, batch_client, mock_openai_client):
        """``_get_batch_status`` maps the SDK batch onto a ``BatchTaskInfo``."""
        # Given: the SDK returns an in-progress batch
        mock_openai_client.batches.retrieve.return_value = self._make_batch(
            "in_progress"
        )

        # When: the status is queried
        task_info = batch_client._get_batch_status("batch-xyz789")

        # Then: a BatchTaskInfo mirroring the batch is returned
        assert isinstance(task_info, BatchTaskInfo)
        assert task_info.task_id == "batch-xyz789"
        assert task_info.status == "in_progress"
        assert task_info.input_file_id == "file-abc123"

    def test_get_batch_status_completed(self, batch_client, mock_openai_client):
        """``_get_batch_status`` exposes output file id and request counts."""
        # Given: the SDK returns a completed batch that carries request counts
        mock_request_counts = MagicMock()
        mock_request_counts.total = 100
        mock_request_counts.completed = 100
        mock_request_counts.failed = 0
        mock_openai_client.batches.retrieve.return_value = self._make_batch(
            "completed",
            output_file_id="file-output-123",
            request_counts=mock_request_counts,
        )

        # When: the status is queried
        task_info = batch_client._get_batch_status("batch-xyz789")

        # Then: the status is completed and the request counts are included
        assert task_info.status == "completed"
        assert task_info.output_file_id == "file-output-123"
        assert task_info.request_counts is not None
        assert task_info.request_counts["total"] == 100
        assert task_info.request_counts["completed"] == 100

    def test_get_batch_status_failure(self, batch_client, mock_openai_client):
        """``_get_batch_status`` raises with a query-failure message on SDK errors."""
        # Given: the SDK raises
        mock_openai_client.batches.retrieve.side_effect = Exception("Retrieve failed")

        # When & Then: an exception with the expected message is raised
        with pytest.raises(Exception, match="查询任务状态失败"):
            batch_client._get_batch_status("batch-xyz789")

    def test_download_result_success(self, batch_client, mock_openai_client):
        """``_download_result`` returns the text of the downloaded file."""
        # Given: the SDK returns file content
        mock_content = MagicMock()
        mock_content.text = "result content"
        mock_openai_client.files.content.return_value = mock_content

        # When: the result is downloaded
        content = batch_client._download_result("file-output-123")

        # Then: the file text is returned
        assert content == "result content"
        mock_openai_client.files.content.assert_called_once_with("file-output-123")

    def test_download_result_failure(self, batch_client, mock_openai_client):
        """``_download_result`` raises with a download-failure message on SDK errors."""
        # Given: the SDK raises
        mock_openai_client.files.content.side_effect = Exception("Download failed")

        # When & Then: an exception with the expected message is raised
        with pytest.raises(Exception, match="下载结果文件失败"):
            batch_client._download_result("file-output-123")

    def test_wait_for_completion_success(self, batch_client, mock_openai_client):
        """``_wait_for_completion`` polls until the batch reports completed."""
        # Given: the first poll returns in_progress, the second completed
        mock_openai_client.batches.retrieve.side_effect = [
            self._make_batch("in_progress"),
            self._make_batch("completed", output_file_id="file-output-123"),
        ]

        # When: waiting for the task to finish
        task_info = batch_client._wait_for_completion(
            batch_id="batch-xyz789", poll_interval=0.1, max_wait_time=10
        )

        # Then: the completed task info is returned
        assert task_info.status == "completed"
        assert task_info.output_file_id == "file-output-123"

    def test_wait_for_completion_timeout(self, batch_client, mock_openai_client):
        """``_wait_for_completion`` raises ``TimeoutError`` when never finished."""
        # Given: every poll returns in_progress
        mock_openai_client.batches.retrieve.return_value = self._make_batch(
            "in_progress"
        )

        # When & Then: a TimeoutError with the expected message is raised
        with pytest.raises(TimeoutError, match="超时"):
            batch_client._wait_for_completion(
                batch_id="batch-xyz789", poll_interval=0.1, max_wait_time=0.5
            )

    def test_wait_for_completion_failed(self, batch_client, mock_openai_client):
        """``_wait_for_completion`` raises when the batch status is failed."""
        # Given: the SDK reports a failed batch
        mock_openai_client.batches.retrieve.return_value = self._make_batch(
            "failed", error_file_id="file-error-123"
        )

        # When & Then: an exception with the expected message is raised
        with pytest.raises(Exception, match="执行失败"):
            batch_client._wait_for_completion(
                batch_id="batch-xyz789", poll_interval=0.1, max_wait_time=10
            )

    def test_wait_for_completion_expired(self, batch_client, mock_openai_client):
        """``_wait_for_completion`` raises when the batch status is expired."""
        # Given: the SDK reports an expired batch
        mock_openai_client.batches.retrieve.return_value = self._make_batch("expired")

        # When & Then: an exception with the expected message is raised
        with pytest.raises(Exception, match="超时过期"):
            batch_client._wait_for_completion(
                batch_id="batch-xyz789", poll_interval=0.1, max_wait_time=10
            )

    def test_wait_for_completion_cancelled(self, batch_client, mock_openai_client):
        """``_wait_for_completion`` raises when the batch status is cancelled."""
        # Given: the SDK reports a cancelled batch
        mock_openai_client.batches.retrieve.return_value = self._make_batch(
            "cancelled"
        )

        # When & Then: an exception with the expected message is raised
        with pytest.raises(Exception, match="已取消"):
            batch_client._wait_for_completion(
                batch_id="batch-xyz789", poll_interval=0.1, max_wait_time=10
            )
class TestBatchClientILLMClient:
    """Tests for the ``ILLMClient`` interface as implemented by ``BatchClient``.

    Each test patches ``OpenAI`` inside ``src.core.batch_client`` and drives
    ``classify_products`` through a batch run that completes on the first poll.
    """

    @staticmethod
    def _stub_completed_run(mock_openai, result_text):
        """Configure the patched ``OpenAI`` class for an immediately-completed run.

        Args:
            mock_openai: The mock that replaces ``OpenAI`` in the module under test.
            result_text: Text exposed as ``.text`` of the downloaded result file.

        Returns:
            The mock SDK client instance that ``mock_openai()`` will return.
        """
        sdk_client = MagicMock()
        mock_openai.return_value = sdk_client

        # File upload
        uploaded_file = MagicMock()
        uploaded_file.id = "file-123"
        sdk_client.files.create.return_value = uploaded_file

        # Batch creation
        created_batch = MagicMock()
        created_batch.id = "batch-456"
        sdk_client.batches.create.return_value = created_batch

        # Status query: report "completed" straight away
        finished_batch = MagicMock()
        finished_batch.id = "batch-456"
        finished_batch.status = "completed"
        finished_batch.input_file_id = "file-123"
        finished_batch.output_file_id = "file-output-789"
        finished_batch.error_file_id = None
        finished_batch.request_counts = None
        sdk_client.batches.retrieve.return_value = finished_batch

        # Result download
        downloaded = MagicMock()
        downloaded.text = result_text
        sdk_client.files.content.return_value = downloaded

        return sdk_client

    @pytest.fixture
    def mock_file_handler(self):
        """Provide a mock file handler."""
        return MagicMock()

    @pytest.fixture
    def mock_result_parser(self):
        """Provide a mock result parser."""
        return MagicMock()

    @pytest.fixture
    def mock_logger(self):
        """Provide a mock logger."""
        return MagicMock()

    @pytest.fixture
    def sample_products(self):
        """Provide a single-item list of product inputs."""
        from src.models import ProductInput

        return [
            ProductInput(
                product_id="P001", product_name="黄山风景区门票", scenic_spot="黄山"
            )
        ]

    @pytest.fixture
    def sample_requests(self):
        """Provide a single batch request entry for product ``P001``."""
        return [
            {
                "custom_id": "P001",
                "method": "POST",
                "url": "/v1/chat/completions",
                "body": {
                    "model": "qwen-flash",
                    "messages": [
                        {"role": "system", "content": "你是助手"},
                        {"role": "user", "content": "分类产品"},
                    ],
                },
            }
        ]

    @patch("src.core.batch_client.OpenAI")
    def test_classify_products_success(
        self,
        mock_openai,
        sample_products,
        sample_requests,
        mock_file_handler,
        mock_result_parser,
        mock_logger,
    ):
        """``classify_products`` succeeds with injected handler and parser."""
        from src.models import ClassificationResult

        # Given: an SDK that completes the batch and returns one result line.
        # The literal is only split across lines; its value is unchanged.
        result_text = (
            '{"id":"1","custom_id":"P001","response":{"body":{"choices":'
            '[{"message":{"content":"{\\"category\\":\\"门票\\",'
            '\\"type\\":\\"自然类\\",\\"sub_type\\":\\"自然风光\\"}"}}]}},'
            '"error":null}'
        )
        mock_client_instance = self._stub_completed_run(mock_openai, result_text)

        # Given: the injected parser returns one classification result
        mock_result = ClassificationResult(
            product_id="P001", category="门票", type="自然类", sub_type="自然风光"
        )
        mock_result_parser.parse.return_value = [mock_result]

        client = BatchClient(
            api_key="test-key",
            base_url="https://test.com",
            file_handler=mock_file_handler,
            result_parser=mock_result_parser,
            logger=mock_logger,
            poll_interval=0.1,
            max_wait_time=10,
        )

        # When: classify_products is called
        results = client.classify_products(
            sample_products, "system prompt", sample_requests
        )

        # Then: the parser's result is returned
        assert len(results) == 1
        assert results[0].product_id == "P001"
        assert results[0].category == "门票"

        # And: each collaborator was used exactly once
        mock_file_handler.write_batch_requests.assert_called_once()
        mock_client_instance.files.create.assert_called_once()
        mock_client_instance.batches.create.assert_called_once()
        mock_result_parser.parse.assert_called_once()

    @patch("src.core.batch_client.OpenAI")
    @patch("builtins.open")
    def test_classify_products_without_file_handler(
        self, mock_open, mock_openai, sample_products, sample_requests, mock_logger
    ):
        """``classify_products`` falls back to a plain write without a file handler."""
        import json

        # Given: an SDK that completes the batch and returns one JSON result line
        result_json = json.dumps(
            {
                "id": "1",
                "custom_id": "P001",
                "response": {
                    "body": {
                        "choices": [
                            {
                                "message": {
                                    "content": json.dumps(
                                        {
                                            "category": "门票",
                                            "type": "自然类",
                                            "sub_type": "自然风光",
                                        }
                                    )
                                }
                            }
                        ]
                    }
                },
                "error": None,
            }
        )
        self._stub_completed_run(mock_openai, result_json)

        # Given: a BatchClient built without file_handler and result_parser
        client = BatchClient(
            api_key="test-key",
            base_url="https://test.com",
            logger=mock_logger,
            poll_interval=0.1,
            max_wait_time=10,
        )

        # When: classify_products is called
        results = client.classify_products(
            sample_products, "system prompt", sample_requests
        )

        # Then: the downloaded line is parsed into one result
        assert len(results) == 1
        assert results[0].product_id == "P001"
        assert results[0].category == "门票"

        # And: the fallback path went through the patched built-in open().
        # NOTE(review): presumes the fallback writes the request file with the
        # built-in open(), which is why it is patched here - confirm against
        # src.core.batch_client.
        mock_open.assert_called()