CategorizeLabel/tests/test_direct_client.py
2025-10-15 17:19:26 +08:00

311 lines
10 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""DirectClient的单元测试"""
import pytest
import json
import logging
from unittest.mock import MagicMock, patch, Mock
from openai.types.chat import ChatCompletion, ChatCompletionMessage
from openai.types.chat.chat_completion import Choice
from src.core.direct_client import DirectClient
from src.models import ProductInput
@pytest.fixture
def mock_logger():
    """Provide a MagicMock constrained to the ``logging.Logger`` interface."""
    logger_double = MagicMock(spec=logging.Logger)
    return logger_double
@pytest.fixture
def sample_products():
    """Provide two ProductInput objects (ids P001 and P002) used as test data."""
    # (product_id, product_name, scenic_spot) rows for the sample products.
    rows = (
        ("P001", "黄山风景区门票", "黄山"),
        ("P002", "豪华酒店标准间", "杭州"),
    )
    return [
        ProductInput(product_id=pid, product_name=name, scenic_spot=spot)
        for pid, name, spot in rows
    ]
@pytest.fixture
def sample_requests():
    """Provide two chat-completion request dicts with custom_ids P001 and P002."""
    # custom_id -> user message content; insertion order fixes the list order.
    user_prompts = {
        "P001": "产品:黄山风景区门票",
        "P002": "产品:豪华酒店标准间",
    }
    return [
        {
            "custom_id": custom_id,
            "method": "POST",
            "url": "/v1/chat/completions",
            "body": {
                "model": "qwen-flash",
                "messages": [
                    {"role": "system", "content": "你是一个产品分类助手"},
                    {"role": "user", "content": user_prompt},
                ],
            },
        }
        for custom_id, user_prompt in user_prompts.items()
    ]
class TestDirectClient:
    """Unit tests for the DirectClient class.

    Every test patches ``src.core.direct_client.OpenAI`` so that no real SDK
    client is ever constructed, keeping the suite hermetic.
    """

    # ------------------------------------------------------------------
    # Private helpers (not collected by pytest: names do not start with
    # ``test``).
    # ------------------------------------------------------------------

    @staticmethod
    def _make_response(content):
        """Build a mock ChatCompletion whose single choice carries *content*."""
        message = Mock(spec=ChatCompletionMessage)
        message.content = content
        choice = Mock(spec=Choice)
        choice.message = message
        response = Mock(spec=ChatCompletion)
        response.choices = [choice]
        return response

    @classmethod
    def _make_classification_response(cls, category, type_val, sub_type):
        """Build a mock ChatCompletion whose content is a JSON classification."""
        payload = {"category": category, "type": type_val, "sub_type": sub_type}
        return cls._make_response(json.dumps(payload))

    @staticmethod
    def _install_fake_api(mock_openai):
        """Make the patched ``OpenAI`` class return a MagicMock and return it."""
        api = MagicMock()
        mock_openai.return_value = api
        return api

    @staticmethod
    def _make_client(logger, **overrides):
        """Create a DirectClient with the shared test settings.

        Extra keyword arguments (e.g. ``max_retries``, ``retry_delay``) are
        forwarded unchanged; anything omitted is simply not passed.
        """
        return DirectClient(
            api_key="test-key",
            base_url="https://test.com",
            model="qwen-flash",
            logger=logger,
            **overrides,
        )

    # ------------------------------------------------------------------
    # Tests
    # ------------------------------------------------------------------

    @patch("src.core.direct_client.OpenAI")
    def test_init(self, mock_openai, mock_logger):
        """DirectClient stores the constructor arguments on the instance."""
        client = self._make_client(mock_logger, max_retries=5, retry_delay=2)

        assert client._model == "qwen-flash"
        assert client._max_retries == 5
        assert client._retry_delay == 2
        assert client._logger == mock_logger

    @patch("src.core.direct_client.OpenAI")
    def test_classify_products_success(
        self, mock_openai, sample_products, sample_requests, mock_logger
    ):
        """A valid JSON reply is turned into one classification result."""
        # Configure the fake API before the client is created.
        api = self._install_fake_api(mock_openai)
        api.chat.completions.create.return_value = (
            self._make_classification_response("门票", "自然类", "自然风光")
        )

        client = self._make_client(mock_logger)
        results = client.classify_products(
            sample_products[:1], "test system prompt", sample_requests[:1]
        )

        # Verify the parsed result.
        assert len(results) == 1
        assert results[0].product_id == "P001"
        assert results[0].category == "门票"
        assert results[0].type == "自然类"
        assert results[0].sub_type == "自然风光"
        # Verify the API was called exactly once.
        api.chat.completions.create.assert_called_once()

    @patch("src.core.direct_client.OpenAI")
    def test_classify_products_with_retry(
        self, mock_openai, sample_products, sample_requests, mock_logger
    ):
        """A failed first call followed by a successful one yields a result."""
        api = self._install_fake_api(mock_openai)
        api.chat.completions.create.side_effect = [
            Exception("API Error"),  # first call fails
            self._make_classification_response(
                "门票", "自然类", "自然风光"
            ),  # second call succeeds
        ]

        client = self._make_client(
            mock_logger,
            max_retries=3,
            retry_delay=0,  # no delay while testing
        )
        results = client.classify_products(
            sample_products[:1], "test system prompt", sample_requests[:1]
        )

        assert len(results) == 1
        assert results[0].product_id == "P001"
        # The API was called twice: one failure plus one success.
        assert api.chat.completions.create.call_count == 2

    @patch("src.core.direct_client.OpenAI")
    def test_classify_products_all_retries_failed(
        self, mock_openai, sample_products, sample_requests, mock_logger
    ):
        """When every attempt raises, the result list is empty."""
        max_retries = 2
        api = self._install_fake_api(mock_openai)
        api.chat.completions.create.side_effect = Exception("API Error")

        client = self._make_client(
            mock_logger, max_retries=max_retries, retry_delay=0
        )
        results = client.classify_products(
            sample_products[:1], "test system prompt", sample_requests[:1]
        )

        # Every retry failed, so an empty list is returned.
        assert len(results) == 0
        # The API was called max_retries times.
        assert api.chat.completions.create.call_count == max_retries

    @patch("src.core.direct_client.OpenAI")
    def test_parse_response_invalid_json(
        self, mock_openai, sample_products, sample_requests, mock_logger
    ):
        """A reply whose content is not JSON produces no results."""
        api = self._install_fake_api(mock_openai)
        api.chat.completions.create.return_value = self._make_response(
            "这不是一个JSON格式的响应"
        )

        client = self._make_client(mock_logger)
        results = client.classify_products(
            sample_products[:1], "test system prompt", sample_requests[:1]
        )

        # The content cannot be parsed, so an empty list is returned.
        assert len(results) == 0

    @patch("src.core.direct_client.OpenAI")
    def test_classify_multiple_products(
        self, mock_openai, sample_products, sample_requests, mock_logger
    ):
        """Two products with two distinct replies yield two ordered results."""
        api = self._install_fake_api(mock_openai)
        # One distinct response per product, consumed in call order.
        api.chat.completions.create.side_effect = [
            self._make_classification_response("门票", "自然类", "自然风光"),
            self._make_classification_response("住宿", "商务酒店", ""),
        ]

        client = self._make_client(mock_logger)
        results = client.classify_products(
            sample_products, "test system prompt", sample_requests
        )

        assert len(results) == 2
        assert results[0].product_id == "P001"
        assert results[0].category == "门票"
        assert results[1].product_id == "P002"
        assert results[1].category == "住宿"

    @patch("src.core.direct_client.OpenAI")
    def test_classify_products_missing_request(
        self, mock_openai, sample_products, mock_logger
    ):
        """Products without a matching request produce no results."""
        client = self._make_client(mock_logger)

        # Pass an empty request list.
        results = client.classify_products(sample_products, "test system prompt", [])

        # With no requests, no product can be processed.
        assert len(results) == 0

    @patch("src.core.direct_client.OpenAI")
    def test_classify_products_empty_choices(
        self, mock_openai, sample_products, sample_requests, mock_logger
    ):
        """A reply with an empty ``choices`` list produces no results."""
        empty_response = Mock(spec=ChatCompletion)
        empty_response.choices = []  # no choices at all
        api = self._install_fake_api(mock_openai)
        api.chat.completions.create.return_value = empty_response

        client = self._make_client(mock_logger)
        results = client.classify_products(
            sample_products[:1], "test system prompt", sample_requests[:1]
        )

        # Nothing can be parsed, so an empty list is returned.
        assert len(results) == 0