CategorizeLabel/tests/test_client_factory.py
2025-10-15 17:19:26 +08:00

191 lines
6.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""ClientFactory的单元测试"""
import pytest
import logging
from unittest.mock import MagicMock, patch
from src.core.client_factory import LLMClientFactory
from src.models import LLMClientMode
from src.core.batch_client import BatchClient
from src.core.direct_client import DirectClient
@pytest.fixture
def mock_logger():
    """Provide a test double constrained to the ``logging.Logger`` interface.

    Because the mock is built with ``spec=logging.Logger``, reading an
    attribute that ``logging.Logger`` does not define raises
    ``AttributeError`` instead of silently returning a child mock.
    """
    logger_double = MagicMock(spec=logging.Logger)
    return logger_double
@pytest.fixture
def mock_file_handler():
    """Provide a spec-less ``MagicMock`` standing in for the file handler.

    No spec is applied, so any attribute or method access on it succeeds.
    NOTE(review): presumably it only needs to be passed through to the
    factory; add a spec if tests start asserting on its calls.
    """
    handler_double = MagicMock()
    return handler_double
@pytest.fixture
def mock_result_parser():
    """Provide a spec-less ``MagicMock`` standing in for the result parser.

    No spec is applied, so any attribute or method access on it succeeds.
    NOTE(review): presumably it only needs to be passed through to the
    factory; add a spec if tests start asserting on its calls.
    """
    parser_double = MagicMock()
    return parser_double
class TestLLMClientFactory:
    """Tests for ``LLMClientFactory.create_client`` mode dispatch.

    Tests that actually build a client patch the ``OpenAI`` name inside the
    matching client module (``src.core.batch_client`` and/or
    ``src.core.direct_client``), so those modules see a mock instead of the
    real ``OpenAI`` class. The injected mocks themselves are never inspected.
    """

    @staticmethod
    def _make_client(
        mode,
        logger,
        api_key="test-key",
        base_url="https://test.com",
        model="qwen-flash",
        **extra_options,
    ):
        """Call the factory with shared connection defaults.

        Args:
            mode: Value forwarded as ``mode``.
            logger: Value forwarded as ``logger``.
            api_key, base_url, model: Connection settings; the defaults are
                the values most tests use.
            **extra_options: Mode-specific keyword arguments forwarded
                unchanged (e.g. ``poll_interval`` or ``max_retries``).

        Returns:
            Whatever ``LLMClientFactory.create_client`` returns.
        """
        return LLMClientFactory.create_client(
            mode=mode,
            api_key=api_key,
            base_url=base_url,
            model=model,
            logger=logger,
            **extra_options,
        )

    @patch("src.core.batch_client.OpenAI")
    def test_create_batch_client(
        self, mock_openai, mock_logger, mock_file_handler, mock_result_parser
    ):
        """BATCH mode with explicit batch options yields a BatchClient."""
        created = self._make_client(
            LLMClientMode.BATCH,
            mock_logger,
            file_handler=mock_file_handler,
            result_parser=mock_result_parser,
            completion_window="24h",
            poll_interval=60,
            max_wait_time=3600,
        )

        assert isinstance(created, BatchClient)
        # Construction must emit at least two info-level log records.
        assert mock_logger.info.call_count >= 2

    @patch("src.core.direct_client.OpenAI")
    def test_create_direct_client(self, mock_openai, mock_logger):
        """DIRECT mode with explicit retry options yields a DirectClient."""
        created = self._make_client(
            LLMClientMode.DIRECT,
            mock_logger,
            max_retries=3,
            retry_delay=1,
        )

        assert isinstance(created, DirectClient)
        # Construction must emit at least two info-level log records.
        assert mock_logger.info.call_count >= 2

    def test_create_client_invalid_mode(self, mock_logger):
        """An unrecognised mode object makes the factory raise ValueError."""

        class FakeMode:
            # Looks like an enum member (it exposes ``.value``) but is not an
            # ``LLMClientMode`` member.
            value = "invalid"

        with pytest.raises(ValueError) as exc_info:
            self._make_client(FakeMode(), mock_logger)

        # The expected message text means "unsupported LLM client mode".
        error_text = str(exc_info.value)
        assert "不支持的LLM客户端模式" in error_text

    @patch("src.core.batch_client.OpenAI")
    def test_create_batch_client_with_defaults(self, mock_openai, mock_logger):
        """BATCH mode works when no batch-specific options are supplied."""
        created = self._make_client(LLMClientMode.BATCH, mock_logger)

        assert isinstance(created, BatchClient)

    @patch("src.core.direct_client.OpenAI")
    def test_create_direct_client_with_defaults(self, mock_openai, mock_logger):
        """DIRECT mode works when no retry options are supplied."""
        created = self._make_client(LLMClientMode.DIRECT, mock_logger)

        assert isinstance(created, DirectClient)

    @patch("src.core.direct_client.OpenAI")
    @patch("src.core.batch_client.OpenAI")
    def test_factory_returns_different_instances(
        self, mock_batch_openai, mock_direct_openai, mock_logger
    ):
        """Each mode is mapped to its own client class."""
        # Stacked @patch decorators inject their mocks bottom-up, so the
        # batch_client patch (closest to ``def``) is the first mock argument.
        batch_client, direct_client = (
            self._make_client(mode, mock_logger)
            for mode in (LLMClientMode.BATCH, LLMClientMode.DIRECT)
        )

        assert isinstance(batch_client, BatchClient)
        assert isinstance(direct_client, DirectClient)
        assert not isinstance(batch_client, type(direct_client))

    @patch("src.core.batch_client.OpenAI")
    def test_create_batch_client_custom_params(
        self, mock_openai, mock_logger, mock_file_handler, mock_result_parser
    ):
        """BATCH mode accepts non-default connection and batch settings.

        NOTE(review): only the returned type is checked; the custom values
        themselves are not asserted on the client.
        """
        created = self._make_client(
            LLMClientMode.BATCH,
            mock_logger,
            api_key="custom-key",
            base_url="https://custom.com",
            model="qwen-max",
            file_handler=mock_file_handler,
            result_parser=mock_result_parser,
            completion_window="48h",
            poll_interval=120,
            max_wait_time=7200,
        )

        assert isinstance(created, BatchClient)

    @patch("src.core.direct_client.OpenAI")
    def test_create_direct_client_custom_params(self, mock_openai, mock_logger):
        """DIRECT mode accepts non-default connection and retry settings.

        NOTE(review): only the returned type is checked; the custom values
        themselves are not asserted on the client.
        """
        created = self._make_client(
            LLMClientMode.DIRECT,
            mock_logger,
            api_key="custom-key",
            base_url="https://custom.com",
            model="qwen-max",
            max_retries=5,
            retry_delay=2,
        )

        assert isinstance(created, DirectClient)