CategorizeLabel/tests/test_batch_client.py

489 lines
18 KiB
Python
Raw Normal View History

2025-10-15 17:19:26 +08:00
"""测试Batch客户端模块
测试BatchClient类的所有功能
"""
import pytest
from unittest.mock import MagicMock, patch
from src.core.batch_client import BatchClient
from src.models import BatchTaskInfo
class TestBatchClient:
    """Tests for the BatchClient class, exercising its private helper methods."""

    @staticmethod
    def _make_batch(
        status,
        *,
        batch_id="batch-xyz789",
        input_file_id="file-abc123",
        output_file_id=None,
        error_file_id=None,
        request_counts=None,
    ):
        """Build a MagicMock standing in for a batch object from ``batches.retrieve``.

        Sets exactly the attributes the status-related tests rely on; the
        defaults match the ids used throughout this class.
        """
        batch = MagicMock()
        batch.id = batch_id
        batch.status = status
        batch.input_file_id = input_file_id
        batch.output_file_id = output_file_id
        batch.error_file_id = error_file_id
        batch.request_counts = request_counts
        return batch

    @pytest.fixture
    def mock_openai_client(self):
        """Patch ``OpenAI`` in the batch_client module and yield the instance it returns."""
        with patch("src.core.batch_client.OpenAI") as mock_openai:
            mock_client = MagicMock()
            mock_openai.return_value = mock_client
            yield mock_client

    @pytest.fixture
    def batch_client(self, mock_openai_client):
        """Create a BatchClient while ``OpenAI`` is patched (see ``mock_openai_client``)."""
        return BatchClient(api_key="test_api_key")

    def test_init_batch_client(self):
        """Constructing a BatchClient should build an OpenAI client from its credentials."""
        with patch("src.core.batch_client.OpenAI") as mock_openai:
            # When: create a BatchClient
            BatchClient(api_key="test_key", base_url="https://test.com")

            # Then: the OpenAI constructor ran exactly once and received both
            # credentials (accepted either positionally or by keyword).
            mock_openai.assert_called_once()
            call = mock_openai.call_args
            passed = [*call.args, *call.kwargs.values()]
            assert "test_key" in passed
            assert "https://test.com" in passed

    def test_upload_file_success(self, batch_client, mock_openai_client):
        """A successful upload returns the id of the created file (private method)."""
        # Given: the SDK returns a file object with an id
        uploaded = MagicMock()
        uploaded.id = "file-abc123"
        mock_openai_client.files.create.return_value = uploaded

        # When: upload the file
        file_id = batch_client._upload_file("test.jsonl")

        # Then: the file id is returned and the SDK was called once
        assert file_id == "file-abc123"
        mock_openai_client.files.create.assert_called_once()

    def test_upload_file_failure(self, batch_client, mock_openai_client):
        """An SDK error during upload is re-raised with an upload-failure message."""
        # Given: the SDK raises
        mock_openai_client.files.create.side_effect = Exception("Upload failed")

        # When & Then: the client raises, wrapping the failure
        with pytest.raises(Exception) as exc_info:
            batch_client._upload_file("test.jsonl")
        assert "上传文件失败" in str(exc_info.value)

    def test_create_batch_success(self, batch_client, mock_openai_client):
        """Creating a batch forwards its arguments to the SDK and returns the batch id."""
        # Given: the SDK returns a batch with an id
        created = MagicMock()
        created.id = "batch-xyz789"
        mock_openai_client.batches.create.return_value = created

        # When: create the batch
        batch_id = batch_client._create_batch(
            input_file_id="file-abc123",
            endpoint="/v1/chat/completions",
            completion_window="24h",
        )

        # Then: the id is returned and the arguments were forwarded unchanged
        assert batch_id == "batch-xyz789"
        mock_openai_client.batches.create.assert_called_once_with(
            input_file_id="file-abc123",
            endpoint="/v1/chat/completions",
            completion_window="24h",
        )

    def test_create_batch_failure(self, batch_client, mock_openai_client):
        """An SDK error during batch creation is re-raised with a creation-failure message."""
        # Given: the SDK raises
        mock_openai_client.batches.create.side_effect = Exception("Create failed")

        # When & Then: the client raises, wrapping the failure
        with pytest.raises(Exception) as exc_info:
            batch_client._create_batch(input_file_id="file-abc123")
        assert "创建Batch任务失败" in str(exc_info.value)

    def test_get_batch_status_success(self, batch_client, mock_openai_client):
        """Querying status maps the SDK batch object onto a BatchTaskInfo."""
        # Given: an in-progress batch
        mock_openai_client.batches.retrieve.return_value = self._make_batch("in_progress")

        # When: query the status
        task_info = batch_client._get_batch_status("batch-xyz789")

        # Then: a BatchTaskInfo mirroring the SDK fields is returned
        assert isinstance(task_info, BatchTaskInfo)
        assert task_info.task_id == "batch-xyz789"
        assert task_info.status == "in_progress"
        assert task_info.input_file_id == "file-abc123"

    def test_get_batch_status_completed(self, batch_client, mock_openai_client):
        """A completed batch exposes its output file id and request counts."""
        # Given: a completed batch carrying request_counts
        counts = MagicMock()
        counts.total = 100
        counts.completed = 100
        counts.failed = 0
        mock_openai_client.batches.retrieve.return_value = self._make_batch(
            "completed",
            output_file_id="file-output-123",
            request_counts=counts,
        )

        # When: query the status
        task_info = batch_client._get_batch_status("batch-xyz789")

        # Then: status, output file and request counts are all populated
        assert task_info.status == "completed"
        assert task_info.output_file_id == "file-output-123"
        assert task_info.request_counts is not None
        assert task_info.request_counts["total"] == 100
        assert task_info.request_counts["completed"] == 100

    def test_get_batch_status_failure(self, batch_client, mock_openai_client):
        """An SDK error during status lookup is re-raised with a query-failure message."""
        # Given: the SDK raises
        mock_openai_client.batches.retrieve.side_effect = Exception("Retrieve failed")

        # When & Then: the client raises, wrapping the failure
        with pytest.raises(Exception) as exc_info:
            batch_client._get_batch_status("batch-xyz789")
        assert "查询任务状态失败" in str(exc_info.value)

    def test_download_result_success(self, batch_client, mock_openai_client):
        """Downloading a result returns the text of the fetched file content."""
        # Given: the SDK returns content with a text payload
        content_obj = MagicMock()
        content_obj.text = "result content"
        mock_openai_client.files.content.return_value = content_obj

        # When: download the result
        content = batch_client._download_result("file-output-123")

        # Then: the text is returned and the right file was requested
        assert content == "result content"
        mock_openai_client.files.content.assert_called_once_with("file-output-123")

    def test_download_result_failure(self, batch_client, mock_openai_client):
        """An SDK error during download is re-raised with a download-failure message."""
        # Given: the SDK raises
        mock_openai_client.files.content.side_effect = Exception("Download failed")

        # When & Then: the client raises, wrapping the failure
        with pytest.raises(Exception) as exc_info:
            batch_client._download_result("file-output-123")
        assert "下载结果文件失败" in str(exc_info.value)

    def test_wait_for_completion_success(self, batch_client, mock_openai_client):
        """Polling continues through in_progress and returns once the batch completes."""
        # Given: first poll is in_progress, second is completed
        mock_openai_client.batches.retrieve.side_effect = [
            self._make_batch("in_progress"),
            self._make_batch("completed", output_file_id="file-output-123"),
        ]

        # When: wait for completion
        task_info = batch_client._wait_for_completion(
            batch_id="batch-xyz789", poll_interval=0.1, max_wait_time=10
        )

        # Then: the completed task info is returned after exactly two polls
        assert task_info.status == "completed"
        assert task_info.output_file_id == "file-output-123"
        assert mock_openai_client.batches.retrieve.call_count == 2

    def test_wait_for_completion_timeout(self, batch_client, mock_openai_client):
        """A batch that never finishes raises TimeoutError once max_wait_time elapses."""
        # Given: every poll reports in_progress
        mock_openai_client.batches.retrieve.return_value = self._make_batch("in_progress")

        # When & Then: a TimeoutError is raised
        with pytest.raises(TimeoutError) as exc_info:
            batch_client._wait_for_completion(
                batch_id="batch-xyz789", poll_interval=0.1, max_wait_time=0.5
            )
        assert "超时" in str(exc_info.value)

    def test_wait_for_completion_failed(self, batch_client, mock_openai_client):
        """A batch reporting status 'failed' makes the wait raise."""
        # Given: the batch has failed
        mock_openai_client.batches.retrieve.return_value = self._make_batch(
            "failed", error_file_id="file-error-123"
        )

        # When & Then: an execution-failure error is raised
        with pytest.raises(Exception) as exc_info:
            batch_client._wait_for_completion(
                batch_id="batch-xyz789", poll_interval=0.1, max_wait_time=10
            )
        assert "执行失败" in str(exc_info.value)

    def test_wait_for_completion_expired(self, batch_client, mock_openai_client):
        """A batch reporting status 'expired' makes the wait raise."""
        # Given: the batch has expired
        mock_openai_client.batches.retrieve.return_value = self._make_batch("expired")

        # When & Then: an expiry error is raised
        with pytest.raises(Exception) as exc_info:
            batch_client._wait_for_completion(
                batch_id="batch-xyz789", poll_interval=0.1, max_wait_time=10
            )
        assert "超时过期" in str(exc_info.value)

    def test_wait_for_completion_cancelled(self, batch_client, mock_openai_client):
        """A batch reporting status 'cancelled' makes the wait raise."""
        # Given: the batch was cancelled
        mock_openai_client.batches.retrieve.return_value = self._make_batch("cancelled")

        # When & Then: a cancellation error is raised
        with pytest.raises(Exception) as exc_info:
            batch_client._wait_for_completion(
                batch_id="batch-xyz789", poll_interval=0.1, max_wait_time=10
            )
        assert "已取消" in str(exc_info.value)
class TestBatchClientILLMClient:
    """Tests for BatchClient's implementation of the ILLMClient interface."""

    @staticmethod
    def _configure_openai(mock_openai, result_text):
        """Wire a patched ``OpenAI`` class so the whole batch flow succeeds at once.

        Upload returns ``file-123``, batch creation returns ``batch-456``, the
        first status poll already reports ``completed`` with output file
        ``file-output-789``, and downloading that file yields ``result_text``.

        Returns the mocked OpenAI client instance so tests can assert on calls.
        """
        client_instance = MagicMock()
        mock_openai.return_value = client_instance

        # File upload
        uploaded_file = MagicMock()
        uploaded_file.id = "file-123"
        client_instance.files.create.return_value = uploaded_file

        # Batch creation
        created_batch = MagicMock()
        created_batch.id = "batch-456"
        client_instance.batches.create.return_value = created_batch

        # Status query: completed on the first poll
        finished_batch = MagicMock()
        finished_batch.id = "batch-456"
        finished_batch.status = "completed"
        finished_batch.input_file_id = "file-123"
        finished_batch.output_file_id = "file-output-789"
        finished_batch.error_file_id = None
        finished_batch.request_counts = None
        client_instance.batches.retrieve.return_value = finished_batch

        # Result download
        downloaded = MagicMock()
        downloaded.text = result_text
        client_instance.files.content.return_value = downloaded

        return client_instance

    @staticmethod
    def _result_line(custom_id, classification):
        """Serialize one Batch API output line whose message content is ``classification`` as JSON."""
        import json

        return json.dumps(
            {
                "id": "1",
                "custom_id": custom_id,
                "response": {
                    "body": {
                        "choices": [
                            {"message": {"content": json.dumps(classification)}}
                        ]
                    }
                },
                "error": None,
            }
        )

    @pytest.fixture
    def mock_file_handler(self):
        """Mock file handler injected into BatchClient."""
        return MagicMock()

    @pytest.fixture
    def mock_result_parser(self):
        """Mock result parser injected into BatchClient."""
        return MagicMock()

    @pytest.fixture
    def mock_logger(self):
        """Mock logger injected into BatchClient."""
        return MagicMock()

    @pytest.fixture
    def sample_products(self):
        """A single sample product to classify."""
        from src.models import ProductInput

        return [
            ProductInput(
                product_id="P001", product_name="黄山风景区门票", scenic_spot="黄山"
            )
        ]

    @pytest.fixture
    def sample_requests(self):
        """One Batch API request line matching ``sample_products``."""
        return [
            {
                "custom_id": "P001",
                "method": "POST",
                "url": "/v1/chat/completions",
                "body": {
                    "model": "qwen-flash",
                    "messages": [
                        {"role": "system", "content": "你是助手"},
                        {"role": "user", "content": "分类产品"},
                    ],
                },
            }
        ]

    @patch("src.core.batch_client.OpenAI")
    def test_classify_products_success(
        self,
        mock_openai,
        sample_products,
        sample_requests,
        mock_file_handler,
        mock_result_parser,
        mock_logger,
    ):
        """classify_products runs the full Batch flow and returns the parser's results."""
        from src.models import ClassificationResult

        # Given: the batch completes immediately
        mock_client = self._configure_openai(
            mock_openai,
            self._result_line(
                "P001",
                {"category": "门票", "type": "自然类", "sub_type": "自然风光"},
            ),
        )
        # The injected parser decides the returned results
        mock_result_parser.parse.return_value = [
            ClassificationResult(
                product_id="P001", category="门票", type="自然类", sub_type="自然风光"
            )
        ]

        client = BatchClient(
            api_key="test-key",
            base_url="https://test.com",
            file_handler=mock_file_handler,
            result_parser=mock_result_parser,
            logger=mock_logger,
            poll_interval=0.1,
            max_wait_time=10,
        )

        # When: classify the products
        results = client.classify_products(
            sample_products, "system prompt", sample_requests
        )

        # Then: the parsed result comes back ...
        assert len(results) == 1
        assert results[0].product_id == "P001"
        assert results[0].category == "门票"
        # ... and every collaborator was used exactly once
        mock_file_handler.write_batch_requests.assert_called_once()
        mock_client.files.create.assert_called_once()
        mock_client.batches.create.assert_called_once()
        mock_result_parser.parse.assert_called_once()

    @patch("src.core.batch_client.OpenAI")
    @patch("builtins.open")
    def test_classify_products_without_file_handler(
        self, mock_builtin_open, mock_openai, sample_products, sample_requests, mock_logger
    ):
        """Without file_handler/result_parser, the client's built-in fallbacks are used."""
        # Given: the batch completes immediately with a real, parseable result line.
        # ``open`` is patched so the fallback request-file write never touches disk.
        mock_client = self._configure_openai(
            mock_openai,
            self._result_line(
                "P001",
                {"category": "门票", "type": "自然类", "sub_type": "自然风光"},
            ),
        )

        # No file_handler or result_parser supplied
        client = BatchClient(
            api_key="test-key",
            base_url="https://test.com",
            logger=mock_logger,
            poll_interval=0.1,
            max_wait_time=10,
        )

        # When: classify the products
        results = client.classify_products(
            sample_products, "system prompt", sample_requests
        )

        # Then: the built-in parsing yields the expected classification
        assert len(results) == 1
        assert results[0].product_id == "P001"
        assert results[0].category == "门票"
        mock_client.files.create.assert_called_once()
        mock_client.batches.create.assert_called_once()