311 lines
10 KiB
Python
311 lines
10 KiB
Python
|
|
"""DirectClient的单元测试"""
|
|||
|
|
|
|||
|
|
import pytest
|
|||
|
|
import json
|
|||
|
|
import logging
|
|||
|
|
from unittest.mock import MagicMock, patch, Mock
|
|||
|
|
from openai.types.chat import ChatCompletion, ChatCompletionMessage
|
|||
|
|
from openai.types.chat.chat_completion import Choice
|
|||
|
|
|
|||
|
|
from src.core.direct_client import DirectClient
|
|||
|
|
from src.models import ProductInput
|
|||
|
|
|
|||
|
|
|
|||
|
|
@pytest.fixture
def mock_logger():
    """Provide a MagicMock restricted to the ``logging.Logger`` interface."""
    fake_logger = MagicMock(spec=logging.Logger)
    return fake_logger
|
|||
|
|
|
|||
|
|
|
|||
|
|
@pytest.fixture
def sample_products():
    """Return two ``ProductInput`` records (ids P001 and P002) shared by the tests."""
    rows = [
        ("P001", "黄山风景区门票", "黄山"),
        ("P002", "豪华酒店标准间", "杭州"),
    ]
    return [
        ProductInput(product_id=pid, product_name=name, scenic_spot=spot)
        for pid, name, spot in rows
    ]
|
|||
|
|
|
|||
|
|
|
|||
|
|
@pytest.fixture
def sample_requests():
    """Return request payloads whose ``custom_id`` values match ``sample_products``.

    Each payload is a POST to ``/v1/chat/completions`` whose ``body`` names the
    model and carries a system message plus one user message for the product.
    """

    def _chat_request(custom_id, user_content):
        # Every request shares the same envelope; only the id and user text differ.
        return {
            "custom_id": custom_id,
            "method": "POST",
            "url": "/v1/chat/completions",
            "body": {
                "model": "qwen-flash",
                "messages": [
                    {"role": "system", "content": "你是一个产品分类助手"},
                    {"role": "user", "content": user_content},
                ],
            },
        }

    return [
        _chat_request("P001", "产品:黄山风景区门票"),
        _chat_request("P002", "产品:豪华酒店标准间"),
    ]
|
|||
|
|
|
|||
|
|
|
|||
|
|
class TestDirectClient:
    """Tests for the ``DirectClient`` class."""

    # ------------------------------------------------------------------
    # Helpers (not collected by pytest: names do not start with ``test``)
    # ------------------------------------------------------------------

    @staticmethod
    def _make_response(content):
        """Build a ChatCompletion-shaped mock with one choice carrying *content*.

        Args:
            content: value assigned to ``choices[0].message.content``.

        Returns:
            A ``Mock(spec=ChatCompletion)`` whose ``choices`` list has one element.
        """
        message = Mock(spec=ChatCompletionMessage)
        message.content = content
        choice = Mock(spec=Choice)
        choice.message = message
        response = Mock(spec=ChatCompletion)
        response.choices = [choice]
        return response

    @classmethod
    def _make_classification_response(cls, category, type_val, sub_type):
        """Build a mock response whose content is a JSON classification payload."""
        return cls._make_response(
            json.dumps({"category": category, "type": type_val, "sub_type": sub_type})
        )

    @staticmethod
    def _install_openai(mock_openai):
        """Make the patched ``OpenAI`` constructor return a fresh MagicMock.

        Returns:
            The MagicMock client, so the test can configure
            ``chat.completions.create``.
        """
        openai_client = MagicMock()
        mock_openai.return_value = openai_client
        return openai_client

    @staticmethod
    def _make_client(logger, **overrides):
        """Construct a ``DirectClient`` with the shared test connection settings.

        ``retry_delay`` defaults to 0 (no delay during tests), so failure-path
        tests cannot stall if the client retries. Any constructor keyword,
        including ``retry_delay`` and ``max_retries``, can be overridden.
        """
        kwargs = {
            "api_key": "test-key",
            "base_url": "https://test.com",
            "model": "qwen-flash",
            "logger": logger,
            "retry_delay": 0,
        }
        kwargs.update(overrides)
        return DirectClient(**kwargs)

    # ------------------------------------------------------------------
    # Tests
    # ------------------------------------------------------------------

    def test_init(self, mock_logger):
        """Constructor arguments are stored on the instance."""
        client = self._make_client(mock_logger, max_retries=5, retry_delay=2)

        assert client._model == "qwen-flash"
        assert client._max_retries == 5
        assert client._retry_delay == 2
        # Identity, not equality: the very logger object must be kept.
        assert client._logger is mock_logger

    @patch("src.core.direct_client.OpenAI")
    def test_classify_products_success(
        self, mock_openai, sample_products, sample_requests, mock_logger
    ):
        """A well-formed JSON reply is parsed into one classification result."""
        openai_client = self._install_openai(mock_openai)
        openai_client.chat.completions.create.return_value = (
            self._make_classification_response("门票", "自然类", "自然风光")
        )

        client = self._make_client(mock_logger)
        results = client.classify_products(
            sample_products[:1], "test system prompt", sample_requests[:1]
        )

        assert len(results) == 1
        assert results[0].product_id == "P001"
        assert results[0].category == "门票"
        assert results[0].type == "自然类"
        assert results[0].sub_type == "自然风光"
        openai_client.chat.completions.create.assert_called_once()

    @patch("src.core.direct_client.OpenAI")
    def test_classify_products_with_retry(
        self, mock_openai, sample_products, sample_requests, mock_logger
    ):
        """The first call fails and the retry succeeds, producing a result."""
        openai_client = self._install_openai(mock_openai)
        openai_client.chat.completions.create.side_effect = [
            Exception("API Error"),  # first attempt fails
            self._make_classification_response("门票", "自然类", "自然风光"),
        ]

        client = self._make_client(mock_logger, max_retries=3, retry_delay=0)
        results = client.classify_products(
            sample_products[:1], "test system prompt", sample_requests[:1]
        )

        assert len(results) == 1
        assert results[0].product_id == "P001"
        # One failed call plus one successful call.
        assert openai_client.chat.completions.create.call_count == 2

    @patch("src.core.direct_client.OpenAI")
    def test_classify_products_all_retries_failed(
        self, mock_openai, sample_products, sample_requests, mock_logger
    ):
        """When every attempt raises, no result is produced."""
        max_retries = 2
        openai_client = self._install_openai(mock_openai)
        openai_client.chat.completions.create.side_effect = Exception("API Error")

        client = self._make_client(
            mock_logger, max_retries=max_retries, retry_delay=0
        )
        results = client.classify_products(
            sample_products[:1], "test system prompt", sample_requests[:1]
        )

        assert len(results) == 0
        # The API is attempted exactly max_retries times.
        assert openai_client.chat.completions.create.call_count == max_retries

    @patch("src.core.direct_client.OpenAI")
    def test_parse_response_invalid_json(
        self, mock_openai, sample_products, sample_requests, mock_logger
    ):
        """A non-JSON reply cannot be parsed and yields no result."""
        openai_client = self._install_openai(mock_openai)
        openai_client.chat.completions.create.return_value = self._make_response(
            "这不是一个JSON格式的响应"
        )

        client = self._make_client(mock_logger)
        results = client.classify_products(
            sample_products[:1], "test system prompt", sample_requests[:1]
        )

        assert len(results) == 0
        # The empty result must come from parsing, not from skipping the request.
        openai_client.chat.completions.create.assert_called()

    @patch("src.core.direct_client.OpenAI")
    def test_classify_multiple_products(
        self, mock_openai, sample_products, sample_requests, mock_logger
    ):
        """Each product in a batch gets its own response, in order."""
        openai_client = self._install_openai(mock_openai)
        openai_client.chat.completions.create.side_effect = [
            self._make_classification_response("门票", "自然类", "自然风光"),
            self._make_classification_response("住宿", "商务酒店", ""),
        ]

        client = self._make_client(mock_logger)
        results = client.classify_products(
            sample_products, "test system prompt", sample_requests
        )

        assert len(results) == 2
        assert results[0].product_id == "P001"
        assert results[0].category == "门票"
        assert results[1].product_id == "P002"
        assert results[1].category == "住宿"
        assert results[1].type == "商务酒店"
        assert openai_client.chat.completions.create.call_count == 2

    @patch("src.core.direct_client.OpenAI")
    def test_classify_products_missing_request(
        self, mock_openai, sample_products, mock_logger
    ):
        """Products with no matching request are not processed."""
        # Patched so the test never builds a real network client.
        openai_client = self._install_openai(mock_openai)

        client = self._make_client(mock_logger)
        # Pass an empty request list.
        results = client.classify_products(sample_products, "test system prompt", [])

        # With no requests, no product can be processed and the API is never hit.
        assert len(results) == 0
        openai_client.chat.completions.create.assert_not_called()

    @patch("src.core.direct_client.OpenAI")
    def test_classify_products_empty_choices(
        self, mock_openai, sample_products, sample_requests, mock_logger
    ):
        """A response with an empty ``choices`` list yields no result."""
        empty_response = Mock(spec=ChatCompletion)
        empty_response.choices = []  # no choices at all

        openai_client = self._install_openai(mock_openai)
        openai_client.chat.completions.create.return_value = empty_response

        client = self._make_client(mock_logger)
        results = client.classify_products(
            sample_products[:1], "test system prompt", sample_requests[:1]
        )

        # Nothing to parse, so the result list is empty.
        assert len(results) == 0
|