"""ClientFactory的单元测试"""

import logging
from unittest.mock import MagicMock, patch

import pytest

from src.core.client_factory import LLMClientFactory
from src.models import LLMClientMode
from src.core.batch_client import BatchClient
from src.core.direct_client import DirectClient


@pytest.fixture()
def mock_logger():
    """Provide a MagicMock restricted to the ``logging.Logger`` interface.

    Using ``spec`` makes attribute typos (e.g. ``logger.inf``) fail loudly
    instead of silently returning another mock.
    """
    logger_double = MagicMock(spec=logging.Logger)
    return logger_double


@pytest.fixture()
def mock_file_handler():
    """Provide an unconstrained MagicMock standing in for the file handler."""
    handler_double = MagicMock()
    return handler_double


@pytest.fixture()
def mock_result_parser():
    """Provide an unconstrained MagicMock standing in for the result parser."""
    parser_double = MagicMock()
    return parser_double


class TestLLMClientFactory:
    """Unit tests for ``LLMClientFactory.create_client``."""

    # Connection settings used by every scenario that does not override them.
    _DEFAULT_CONN = {
        "api_key": "test-key",
        "base_url": "https://test.com",
        "model": "qwen-flash",
    }

    # Connection settings used by the "custom parameters" scenarios.
    _CUSTOM_CONN = {
        "api_key": "custom-key",
        "base_url": "https://custom.com",
        "model": "qwen-max",
    }

    def _build(self, mode, logger, conn=None, **extra):
        """Invoke the factory with ``mode``, ``logger``, connection settings and extras.

        ``conn`` defaults to ``_DEFAULT_CONN``; ``extra`` carries the
        mode-specific keyword arguments (retry or batch-polling settings).
        Not collected by pytest because the name does not start with ``test``.
        """
        settings = self._DEFAULT_CONN if conn is None else conn
        return LLMClientFactory.create_client(
            mode=mode, logger=logger, **settings, **extra
        )

    @patch("src.core.batch_client.OpenAI")
    def test_create_batch_client(
        self, mock_openai, mock_logger, mock_file_handler, mock_result_parser
    ):
        """BATCH mode with explicit collaborators and polling settings yields a BatchClient."""
        client = self._build(
            LLMClientMode.BATCH,
            mock_logger,
            file_handler=mock_file_handler,
            result_parser=mock_result_parser,
            completion_window="24h",
            poll_interval=60,
            max_wait_time=3600,
        )

        assert isinstance(client, BatchClient)
        # At least two info-level log calls are expected while building the client.
        assert mock_logger.info.call_count >= 2

    @patch("src.core.direct_client.OpenAI")
    def test_create_direct_client(self, mock_openai, mock_logger):
        """DIRECT mode with explicit retry settings yields a DirectClient."""
        client = self._build(
            LLMClientMode.DIRECT,
            mock_logger,
            max_retries=3,
            retry_delay=1,
        )

        assert isinstance(client, DirectClient)
        # At least two info-level log calls are expected while building the client.
        assert mock_logger.info.call_count >= 2

    def test_create_client_invalid_mode(self, mock_logger):
        """An unrecognised mode object makes the factory raise ValueError."""

        class FakeMode:
            # Mimics an enum member: exposes only a ``value`` attribute.
            value = "invalid"

        # The literal contains no regex metacharacters, so ``match`` behaves
        # like a plain substring check on the exception message.
        with pytest.raises(ValueError, match="不支持的LLM客户端模式"):
            self._build(FakeMode(), mock_logger)  # type: ignore

    @patch("src.core.batch_client.OpenAI")
    def test_create_batch_client_with_defaults(self, mock_openai, mock_logger):
        """BATCH mode without optional arguments still yields a BatchClient."""
        client = self._build(LLMClientMode.BATCH, mock_logger)

        assert isinstance(client, BatchClient)

    @patch("src.core.direct_client.OpenAI")
    def test_create_direct_client_with_defaults(self, mock_openai, mock_logger):
        """DIRECT mode without optional arguments still yields a DirectClient."""
        client = self._build(LLMClientMode.DIRECT, mock_logger)

        assert isinstance(client, DirectClient)

    # Stacked patches are applied bottom-up, so the batch patch's mock is the
    # first positional argument after ``self``.
    @patch("src.core.direct_client.OpenAI")
    @patch("src.core.batch_client.OpenAI")
    def test_factory_returns_different_instances(
        self, mock_batch_openai, mock_direct_openai, mock_logger
    ):
        """BATCH and DIRECT modes produce clients of distinct types."""
        batch_client = self._build(LLMClientMode.BATCH, mock_logger)
        direct_client = self._build(LLMClientMode.DIRECT, mock_logger)

        assert isinstance(batch_client, BatchClient)
        assert isinstance(direct_client, DirectClient)
        assert not isinstance(batch_client, type(direct_client))

    @patch("src.core.batch_client.OpenAI")
    def test_create_batch_client_custom_params(
        self, mock_openai, mock_logger, mock_file_handler, mock_result_parser
    ):
        """BATCH mode with non-default connection and polling values yields a BatchClient."""
        client = self._build(
            LLMClientMode.BATCH,
            mock_logger,
            conn=self._CUSTOM_CONN,
            file_handler=mock_file_handler,
            result_parser=mock_result_parser,
            completion_window="48h",
            poll_interval=120,
            max_wait_time=7200,
        )

        assert isinstance(client, BatchClient)

    @patch("src.core.direct_client.OpenAI")
    def test_create_direct_client_custom_params(self, mock_openai, mock_logger):
        """DIRECT mode with non-default connection and retry values yields a DirectClient."""
        client = self._build(
            LLMClientMode.DIRECT,
            mock_logger,
            conn=self._CUSTOM_CONN,
            max_retries=5,
            retry_delay=2,
        )

        assert isinstance(client, DirectClient)