191 lines
6.1 KiB
Python
191 lines
6.1 KiB
Python
|
|
"""ClientFactory的单元测试"""
|
|||
|
|
|
|||
|
|
import pytest
|
|||
|
|
import logging
|
|||
|
|
from unittest.mock import MagicMock, patch
|
|||
|
|
|
|||
|
|
from src.core.client_factory import LLMClientFactory
|
|||
|
|
from src.models import LLMClientMode
|
|||
|
|
from src.core.batch_client import BatchClient
|
|||
|
|
from src.core.direct_client import DirectClient
|
|||
|
|
|
|||
|
|
|
|||
|
|
@pytest.fixture
def mock_logger():
    """Provide a fake logger restricted to the ``logging.Logger`` interface.

    ``spec=logging.Logger`` makes accessing an attribute that a real
    ``Logger`` does not have raise ``AttributeError``, while calls such as
    ``info`` are recorded for later assertions.
    """
    fake_logger = MagicMock(spec=logging.Logger)
    return fake_logger
|
|||
|
|
|
|||
|
|
|
|||
|
|
@pytest.fixture
def mock_file_handler():
    """Provide an unconstrained mock passed as ``file_handler`` in batch-mode tests."""
    handler_stub = MagicMock()
    return handler_stub
|
|||
|
|
|
|||
|
|
|
|||
|
|
@pytest.fixture
def mock_result_parser():
    """Provide an unconstrained mock passed as ``result_parser`` in batch-mode tests."""
    parser_stub = MagicMock()
    return parser_stub
|
|||
|
|
|
|||
|
|
|
|||
|
|
class TestLLMClientFactory:
    """Unit tests for ``LLMClientFactory.create_client``."""

    @staticmethod
    def _create(
        mode,
        logger,
        api_key="test-key",
        base_url="https://test.com",
        model="qwen-flash",
        **options,
    ):
        """Invoke the factory with the shared connection settings.

        Keyword arguments reach ``create_client`` in a fixed order: mode,
        api_key, base_url, model, logger, then any mode-specific
        ``options`` in the order the caller listed them.
        """
        return LLMClientFactory.create_client(
            mode=mode,
            api_key=api_key,
            base_url=base_url,
            model=model,
            logger=logger,
            **options,
        )

    @patch("src.core.batch_client.OpenAI")
    def test_create_batch_client(
        self, mock_openai, mock_logger, mock_file_handler, mock_result_parser
    ):
        """BATCH mode with explicit collaborators yields a BatchClient."""
        client = self._create(
            LLMClientMode.BATCH,
            mock_logger,
            file_handler=mock_file_handler,
            result_parser=mock_result_parser,
            completion_window="24h",
            poll_interval=60,
            max_wait_time=3600,
        )

        assert isinstance(client, BatchClient)
        # Building the client is expected to emit at least two info-level log calls.
        assert mock_logger.info.call_count >= 2

    @patch("src.core.direct_client.OpenAI")
    def test_create_direct_client(self, mock_openai, mock_logger):
        """DIRECT mode with explicit retry settings yields a DirectClient."""
        client = self._create(
            LLMClientMode.DIRECT,
            mock_logger,
            max_retries=3,
            retry_delay=1,
        )

        assert isinstance(client, DirectClient)
        # Building the client is expected to emit at least two info-level log calls.
        assert mock_logger.info.call_count >= 2

    def test_create_client_invalid_mode(self, mock_logger):
        """A mode that is not a known LLMClientMode is rejected with ValueError."""

        class UnknownMode:
            # Imitates an enum member just enough: it exposes a ``value`` attribute.
            value = "invalid"

        with pytest.raises(ValueError) as exc_info:
            self._create(UnknownMode(), mock_logger)

        # The expected message text means "unsupported LLM client mode".
        error_text = str(exc_info.value)
        assert "不支持的LLM客户端模式" in error_text

    @patch("src.core.batch_client.OpenAI")
    def test_create_batch_client_with_defaults(self, mock_openai, mock_logger):
        """Omitting every batch-specific option still produces a BatchClient."""
        client = self._create(LLMClientMode.BATCH, mock_logger)
        assert isinstance(client, BatchClient)

    @patch("src.core.direct_client.OpenAI")
    def test_create_direct_client_with_defaults(self, mock_openai, mock_logger):
        """Omitting every direct-specific option still produces a DirectClient."""
        client = self._create(LLMClientMode.DIRECT, mock_logger)
        assert isinstance(client, DirectClient)

    # Stacked @patch decorators apply bottom-up, so the batch_client patch
    # (closest to the function) supplies the first mock argument.
    @patch("src.core.direct_client.OpenAI")
    @patch("src.core.batch_client.OpenAI")
    def test_factory_returns_different_instances(
        self, mock_batch_openai, mock_direct_openai, mock_logger
    ):
        """Each mode maps to its own, distinct client class."""
        batch_client = self._create(LLMClientMode.BATCH, mock_logger)
        direct_client = self._create(LLMClientMode.DIRECT, mock_logger)

        assert isinstance(batch_client, BatchClient)
        assert isinstance(direct_client, DirectClient)
        assert not isinstance(batch_client, type(direct_client))

    @patch("src.core.batch_client.OpenAI")
    def test_create_batch_client_custom_params(
        self, mock_openai, mock_logger, mock_file_handler, mock_result_parser
    ):
        """Non-default connection and polling settings still yield a BatchClient."""
        client = self._create(
            LLMClientMode.BATCH,
            mock_logger,
            api_key="custom-key",
            base_url="https://custom.com",
            model="qwen-max",
            file_handler=mock_file_handler,
            result_parser=mock_result_parser,
            completion_window="48h",
            poll_interval=120,
            max_wait_time=7200,
        )

        # NOTE(review): only the returned type is checked; the custom values
        # themselves are not verified to reach the client.
        assert isinstance(client, BatchClient)

    @patch("src.core.direct_client.OpenAI")
    def test_create_direct_client_custom_params(self, mock_openai, mock_logger):
        """Non-default connection and retry settings still yield a DirectClient."""
        client = self._create(
            LLMClientMode.DIRECT,
            mock_logger,
            api_key="custom-key",
            base_url="https://custom.com",
            model="qwen-max",
            max_retries=5,
            retry_delay=2,
        )

        # NOTE(review): only the returned type is checked; the custom values
        # themselves are not verified to reach the client.
        assert isinstance(client, DirectClient)
|