CategorizeLabel/tests/test_client_factory.py

191 lines
6.1 KiB
Python
Raw Normal View History

2025-10-15 17:19:26 +08:00
"""ClientFactory的单元测试"""
import pytest
import logging
from unittest.mock import MagicMock, patch
from src.core.client_factory import LLMClientFactory
from src.models import LLMClientMode
from src.core.batch_client import BatchClient
from src.core.direct_client import DirectClient
@pytest.fixture
def mock_logger():
    """Provide a MagicMock restricted to the ``logging.Logger`` interface.

    Because ``spec`` is set, accessing an attribute that a real Logger does
    not have raises AttributeError instead of silently returning a child mock.
    """
    logger_double = MagicMock(spec=logging.Logger)
    return logger_double
@pytest.fixture
def mock_file_handler():
    """Provide an unconstrained MagicMock standing in for the file handler.

    No ``spec`` is applied, so any attribute access or method call succeeds.
    """
    handler_double = MagicMock()
    return handler_double
@pytest.fixture
def mock_result_parser():
    """Provide an unconstrained MagicMock standing in for the result parser.

    No ``spec`` is applied, so any attribute access or method call succeeds.
    """
    parser_double = MagicMock()
    return parser_double
class TestLLMClientFactory:
    """Tests for ``LLMClientFactory.create_client`` mode dispatch."""

    # Connection settings shared by tests that do not care about specific
    # values. Never mutated; each call site unpacks it with ``**`` so every
    # call receives its own fresh keyword arguments.
    _BASE_KWARGS = {
        "api_key": "test-key",
        "base_url": "https://test.com",
        "model": "qwen-flash",
    }

    # Note: ``mock_openai`` parameters below are injected by ``@patch`` and are
    # unused in the bodies; the patch replaces the ``OpenAI`` name in the given
    # module so any ``OpenAI(...)`` lookup through it returns a mock.

    @patch("src.core.batch_client.OpenAI")
    def test_create_batch_client(
        self, mock_openai, mock_logger, mock_file_handler, mock_result_parser
    ):
        """BATCH mode with explicit batch settings yields a BatchClient."""
        client = LLMClientFactory.create_client(
            mode=LLMClientMode.BATCH,
            **self._BASE_KWARGS,
            logger=mock_logger,
            file_handler=mock_file_handler,
            result_parser=mock_result_parser,
            completion_window="24h",
            poll_interval=60,
            max_wait_time=3600,
        )

        assert isinstance(client, BatchClient)
        # Client creation is expected to emit at least two info-level logs.
        assert mock_logger.info.call_count >= 2

    @patch("src.core.direct_client.OpenAI")
    def test_create_direct_client(self, mock_openai, mock_logger):
        """DIRECT mode with explicit retry settings yields a DirectClient."""
        client = LLMClientFactory.create_client(
            mode=LLMClientMode.DIRECT,
            **self._BASE_KWARGS,
            logger=mock_logger,
            max_retries=3,
            retry_delay=1,
        )

        assert isinstance(client, DirectClient)
        # Client creation is expected to emit at least two info-level logs.
        assert mock_logger.info.call_count >= 2

    def test_create_client_invalid_mode(self, mock_logger):
        """An unsupported mode raises ValueError with a descriptive message."""

        # Duck-typed stand-in for an LLMClientMode member: it exposes a
        # ``.value`` attribute but matches none of the supported modes.
        class FakeMode:
            value = "invalid"

        # ``match`` checks the message as part of the raises-context, so the
        # assertion cannot be accidentally skipped by the raised exception.
        with pytest.raises(ValueError, match="不支持的LLM客户端模式"):
            LLMClientFactory.create_client(
                mode=FakeMode(),  # type: ignore
                **self._BASE_KWARGS,
                logger=mock_logger,
            )

    @patch("src.core.batch_client.OpenAI")
    def test_create_batch_client_with_defaults(self, mock_openai, mock_logger):
        """BATCH mode works when all optional arguments are omitted."""
        client = LLMClientFactory.create_client(
            mode=LLMClientMode.BATCH,
            **self._BASE_KWARGS,
            logger=mock_logger,
        )

        assert isinstance(client, BatchClient)

    @patch("src.core.direct_client.OpenAI")
    def test_create_direct_client_with_defaults(self, mock_openai, mock_logger):
        """DIRECT mode works when all optional arguments are omitted."""
        client = LLMClientFactory.create_client(
            mode=LLMClientMode.DIRECT,
            **self._BASE_KWARGS,
            logger=mock_logger,
        )

        assert isinstance(client, DirectClient)

    # Stacked @patch decorators inject mocks bottom-up: the innermost
    # (batch_client) patch is the first mock argument.
    @patch("src.core.direct_client.OpenAI")
    @patch("src.core.batch_client.OpenAI")
    def test_factory_returns_different_instances(
        self, mock_batch_openai, mock_direct_openai, mock_logger
    ):
        """Different modes produce clients of distinct, unrelated types."""
        batch_client = LLMClientFactory.create_client(
            mode=LLMClientMode.BATCH,
            **self._BASE_KWARGS,
            logger=mock_logger,
        )
        direct_client = LLMClientFactory.create_client(
            mode=LLMClientMode.DIRECT,
            **self._BASE_KWARGS,
            logger=mock_logger,
        )

        assert isinstance(batch_client, BatchClient)
        assert isinstance(direct_client, DirectClient)
        # Check both directions so neither client type is a subclass of the
        # other.
        assert not isinstance(batch_client, type(direct_client))
        assert not isinstance(direct_client, type(batch_client))

    @patch("src.core.batch_client.OpenAI")
    def test_create_batch_client_custom_params(
        self, mock_openai, mock_logger, mock_file_handler, mock_result_parser
    ):
        """BATCH mode accepts non-default connection and polling settings."""
        client = LLMClientFactory.create_client(
            mode=LLMClientMode.BATCH,
            api_key="custom-key",
            base_url="https://custom.com",
            model="qwen-max",
            logger=mock_logger,
            file_handler=mock_file_handler,
            result_parser=mock_result_parser,
            completion_window="48h",
            poll_interval=120,
            max_wait_time=7200,
        )

        assert isinstance(client, BatchClient)

    @patch("src.core.direct_client.OpenAI")
    def test_create_direct_client_custom_params(self, mock_openai, mock_logger):
        """DIRECT mode accepts non-default connection and retry settings."""
        client = LLMClientFactory.create_client(
            mode=LLMClientMode.DIRECT,
            api_key="custom-key",
            base_url="https://custom.com",
            model="qwen-max",
            logger=mock_logger,
            max_retries=5,
            retry_delay=2,
        )

        assert isinstance(client, DirectClient)