CategorizeLabel/tests/test_direct_client.py

311 lines
10 KiB
Python
Raw Permalink Normal View History

2025-10-15 17:19:26 +08:00
"""DirectClient的单元测试"""
import pytest
import json
import logging
from unittest.mock import MagicMock, patch, Mock
from openai.types.chat import ChatCompletion, ChatCompletionMessage
from openai.types.chat.chat_completion import Choice
from src.core.direct_client import DirectClient
from src.models import ProductInput
@pytest.fixture
def mock_logger():
    """Provide a test double that only accepts the ``logging.Logger`` interface."""
    logger_double = MagicMock(spec=logging.Logger)
    return logger_double
@pytest.fixture
def sample_products():
    """Provide two ``ProductInput`` records, product ids P001 and P002."""
    # (product_id, product_name, scenic_spot) for each sample product.
    rows = (
        ("P001", "黄山风景区门票", "黄山"),
        ("P002", "豪华酒店标准间", "杭州"),
    )
    return [
        ProductInput(product_id=pid, product_name=name, scenic_spot=spot)
        for pid, name, spot in rows
    ]
@pytest.fixture
def sample_requests():
    """Provide chat-completion request dicts keyed by ``custom_id`` P001 and P002."""

    def build(custom_id, user_content):
        # Every request shares the same endpoint, model and system message;
        # only the id and the user message differ.
        return {
            "custom_id": custom_id,
            "method": "POST",
            "url": "/v1/chat/completions",
            "body": {
                "model": "qwen-flash",
                "messages": [
                    {"role": "system", "content": "你是一个产品分类助手"},
                    {"role": "user", "content": user_content},
                ],
            },
        }

    first = build("P001", "产品:黄山风景区门票")
    second = build("P002", "产品:豪华酒店标准间")
    return [first, second]
class TestDirectClient:
    """Unit tests for ``DirectClient``.

    Every test patches ``src.core.direct_client.OpenAI`` so that no real
    HTTP client is constructed and no network traffic can occur.
    """

    @staticmethod
    def _make_response(content):
        """Build a mock ``ChatCompletion`` with one choice whose message text is ``content``."""
        message = Mock(spec=ChatCompletionMessage)
        message.content = content
        choice = Mock(spec=Choice)
        choice.message = message
        response = Mock(spec=ChatCompletion)
        response.choices = [choice]
        return response

    @classmethod
    def _make_json_response(cls, category, type_val, sub_type):
        """Build a mock response whose content is a JSON classification payload."""
        return cls._make_response(
            json.dumps({"category": category, "type": type_val, "sub_type": sub_type})
        )

    @staticmethod
    def _install_fake_openai(mock_openai):
        """Make the patched ``OpenAI`` constructor return a fresh ``MagicMock`` client.

        Must be called before ``DirectClient`` is constructed so the client
        under test receives the fake instance.
        """
        fake_client = MagicMock()
        mock_openai.return_value = fake_client
        return fake_client

    @staticmethod
    def _make_client(logger, **overrides):
        """Construct a ``DirectClient`` with the shared test credentials.

        ``overrides`` are forwarded as extra keyword arguments
        (e.g. ``max_retries``, ``retry_delay``).
        """
        return DirectClient(
            api_key="test-key",
            base_url="https://test.com",
            model="qwen-flash",
            logger=logger,
            **overrides,
        )

    @patch("src.core.direct_client.OpenAI")
    def test_init(self, mock_openai, mock_logger):
        """The constructor stores the model, retry settings and logger it is given."""
        self._install_fake_openai(mock_openai)
        client = self._make_client(mock_logger, max_retries=5, retry_delay=2)
        assert client._model == "qwen-flash"
        assert client._max_retries == 5
        assert client._retry_delay == 2
        assert client._logger == mock_logger

    @patch("src.core.direct_client.OpenAI")
    def test_classify_products_success(
        self, mock_openai, sample_products, sample_requests, mock_logger
    ):
        """A valid JSON reply is parsed into one classification result."""
        fake_client = self._install_fake_openai(mock_openai)
        fake_client.chat.completions.create.return_value = self._make_json_response(
            "门票", "自然类", "自然风光"
        )
        client = self._make_client(mock_logger)

        results = client.classify_products(
            sample_products[:1], "test system prompt", sample_requests[:1]
        )

        assert len(results) == 1
        assert results[0].product_id == "P001"
        assert results[0].category == "门票"
        assert results[0].type == "自然类"
        assert results[0].sub_type == "自然风光"
        fake_client.chat.completions.create.assert_called_once()

    @patch("src.core.direct_client.OpenAI")
    def test_classify_products_with_retry(
        self, mock_openai, sample_products, sample_requests, mock_logger
    ):
        """A failed first attempt is retried and the second attempt succeeds."""
        fake_client = self._install_fake_openai(mock_openai)
        fake_client.chat.completions.create.side_effect = [
            Exception("API Error"),  # first call fails
            self._make_json_response("门票", "自然类", "自然风光"),  # second succeeds
        ]
        # retry_delay=0 keeps the test from sleeping between attempts.
        client = self._make_client(mock_logger, max_retries=3, retry_delay=0)

        results = client.classify_products(
            sample_products[:1], "test system prompt", sample_requests[:1]
        )

        assert len(results) == 1
        assert results[0].product_id == "P001"
        # The API was called twice (one failure + one success).
        assert fake_client.chat.completions.create.call_count == 2

    @patch("src.core.direct_client.OpenAI")
    def test_classify_products_all_retries_failed(
        self, mock_openai, sample_products, sample_requests, mock_logger
    ):
        """When every attempt fails, no result is returned."""
        fake_client = self._install_fake_openai(mock_openai)
        fake_client.chat.completions.create.side_effect = Exception("API Error")
        client = self._make_client(mock_logger, max_retries=2, retry_delay=0)

        results = client.classify_products(
            sample_products[:1], "test system prompt", sample_requests[:1]
        )

        # Every retry failed, so the result list is empty.
        assert len(results) == 0
        # The API was attempted exactly max_retries (2) times.
        assert fake_client.chat.completions.create.call_count == 2

    @patch("src.core.direct_client.OpenAI")
    def test_parse_response_invalid_json(
        self, mock_openai, sample_products, sample_requests, mock_logger
    ):
        """A non-JSON reply cannot be parsed and yields no result."""
        fake_client = self._install_fake_openai(mock_openai)
        fake_client.chat.completions.create.return_value = self._make_response(
            "这不是一个JSON格式的响应"
        )
        client = self._make_client(mock_logger)

        results = client.classify_products(
            sample_products[:1], "test system prompt", sample_requests[:1]
        )

        assert len(results) == 0

    @patch("src.core.direct_client.OpenAI")
    def test_classify_multiple_products(
        self, mock_openai, sample_products, sample_requests, mock_logger
    ):
        """Several products are classified, each from its own API reply."""
        fake_client = self._install_fake_openai(mock_openai)
        fake_client.chat.completions.create.side_effect = [
            self._make_json_response("门票", "自然类", "自然风光"),
            self._make_json_response("住宿", "商务酒店", ""),
        ]
        client = self._make_client(mock_logger)

        results = client.classify_products(
            sample_products, "test system prompt", sample_requests
        )

        assert len(results) == 2
        assert results[0].product_id == "P001"
        assert results[0].category == "门票"
        assert results[1].product_id == "P002"
        assert results[1].category == "住宿"

    @patch("src.core.direct_client.OpenAI")
    def test_classify_products_missing_request(
        self, mock_openai, sample_products, mock_logger
    ):
        """Products without a matching request are not processed at all."""
        # Patched (unlike the original version of this test) so that a
        # regression which does call the API fails fast instead of hitting
        # the network.
        fake_client = self._install_fake_openai(mock_openai)
        client = self._make_client(mock_logger)

        results = client.classify_products(sample_products, "test system prompt", [])

        # With no requests, nothing can be classified and no call is made.
        assert len(results) == 0
        fake_client.chat.completions.create.assert_not_called()

    @patch("src.core.direct_client.OpenAI")
    def test_classify_products_empty_choices(
        self, mock_openai, sample_products, sample_requests, mock_logger
    ):
        """A reply with an empty ``choices`` list yields no result."""
        fake_client = self._install_fake_openai(mock_openai)
        empty_response = Mock(spec=ChatCompletion)
        empty_response.choices = []
        fake_client.chat.completions.create.return_value = empty_response
        client = self._make_client(mock_logger)

        results = client.classify_products(
            sample_products[:1], "test system prompt", sample_requests[:1]
        )

        assert len(results) == 0