"""DirectClient的单元测试""" import pytest import json import logging from unittest.mock import MagicMock, patch, Mock from openai.types.chat import ChatCompletion, ChatCompletionMessage from openai.types.chat.chat_completion import Choice from src.core.direct_client import DirectClient from src.models import ProductInput @pytest.fixture def mock_logger(): """创建mock logger""" return MagicMock(spec=logging.Logger) @pytest.fixture def sample_products(): """创建测试用的产品数据""" return [ ProductInput( product_id="P001", product_name="黄山风景区门票", scenic_spot="黄山" ), ProductInput( product_id="P002", product_name="豪华酒店标准间", scenic_spot="杭州" ), ] @pytest.fixture def sample_requests(): """创建测试用的请求数据""" return [ { "custom_id": "P001", "method": "POST", "url": "/v1/chat/completions", "body": { "model": "qwen-flash", "messages": [ {"role": "system", "content": "你是一个产品分类助手"}, {"role": "user", "content": "产品:黄山风景区门票"}, ], }, }, { "custom_id": "P002", "method": "POST", "url": "/v1/chat/completions", "body": { "model": "qwen-flash", "messages": [ {"role": "system", "content": "你是一个产品分类助手"}, {"role": "user", "content": "产品:豪华酒店标准间"}, ], }, }, ] class TestDirectClient: """测试DirectClient类""" def test_init(self, mock_logger): """测试DirectClient初始化""" client = DirectClient( api_key="test-key", base_url="https://test.com", model="qwen-flash", logger=mock_logger, max_retries=5, retry_delay=2, ) assert client._model == "qwen-flash" assert client._max_retries == 5 assert client._retry_delay == 2 assert client._logger == mock_logger @patch("src.core.direct_client.OpenAI") def test_classify_products_success( self, mock_openai, sample_products, sample_requests, mock_logger ): """测试成功分类产品""" # 创建mock的API响应 mock_response = Mock(spec=ChatCompletion) mock_message = Mock(spec=ChatCompletionMessage) mock_message.content = json.dumps( {"category": "门票", "type": "自然类", "sub_type": "自然风光"} ) mock_choice = Mock(spec=Choice) mock_choice.message = mock_message mock_response.choices = [mock_choice] # 配置mock客户端 mock_client_instance = MagicMock() mock_client_instance.chat.completions.create.return_value = mock_response mock_openai.return_value = mock_client_instance # 创建DirectClient并调用 client = DirectClient( api_key="test-key", base_url="https://test.com", model="qwen-flash", logger=mock_logger, ) results = client.classify_products( sample_products[:1], "test system prompt", sample_requests[:1] ) # 验证结果 assert len(results) == 1 assert results[0].product_id == "P001" assert results[0].category == "门票" assert results[0].type == "自然类" assert results[0].sub_type == "自然风光" # 验证API被调用 mock_client_instance.chat.completions.create.assert_called_once() @patch("src.core.direct_client.OpenAI") def test_classify_products_with_retry( self, mock_openai, sample_products, sample_requests, mock_logger ): """测试带重试的产品分类""" # 第一次调用失败,第二次成功 mock_response = Mock(spec=ChatCompletion) mock_message = Mock(spec=ChatCompletionMessage) mock_message.content = json.dumps( {"category": "门票", "type": "自然类", "sub_type": "自然风光"} ) mock_choice = Mock(spec=Choice) mock_choice.message = mock_message mock_response.choices = [mock_choice] mock_client_instance = MagicMock() mock_client_instance.chat.completions.create.side_effect = [ Exception("API Error"), # 第一次失败 mock_response, # 第二次成功 ] mock_openai.return_value = mock_client_instance client = DirectClient( api_key="test-key", base_url="https://test.com", model="qwen-flash", logger=mock_logger, max_retries=3, retry_delay=0, # 测试时不延迟 ) results = client.classify_products( sample_products[:1], "test system prompt", sample_requests[:1] ) # 验证结果 assert len(results) == 1 assert results[0].product_id == "P001" # 验证API被调用了2次(1次失败+1次成功) assert mock_client_instance.chat.completions.create.call_count == 2 @patch("src.core.direct_client.OpenAI") def test_classify_products_all_retries_failed( self, mock_openai, sample_products, sample_requests, mock_logger ): """测试所有重试都失败的情况""" mock_client_instance = MagicMock() mock_client_instance.chat.completions.create.side_effect = Exception( "API Error" ) mock_openai.return_value = mock_client_instance client = DirectClient( api_key="test-key", base_url="https://test.com", model="qwen-flash", logger=mock_logger, max_retries=2, retry_delay=0, ) results = client.classify_products( sample_products[:1], "test system prompt", sample_requests[:1] ) # 所有重试都失败,返回空列表 assert len(results) == 0 # 验证API被调用了max_retries次 assert mock_client_instance.chat.completions.create.call_count == 2 @patch("src.core.direct_client.OpenAI") def test_parse_response_invalid_json( self, mock_openai, sample_products, sample_requests, mock_logger ): """测试解析无效JSON响应""" # 创建返回无效JSON的mock响应 mock_response = Mock(spec=ChatCompletion) mock_message = Mock(spec=ChatCompletionMessage) mock_message.content = "这不是一个JSON格式的响应" mock_choice = Mock(spec=Choice) mock_choice.message = mock_message mock_response.choices = [mock_choice] mock_client_instance = MagicMock() mock_client_instance.chat.completions.create.return_value = mock_response mock_openai.return_value = mock_client_instance client = DirectClient( api_key="test-key", base_url="https://test.com", model="qwen-flash", logger=mock_logger, ) results = client.classify_products( sample_products[:1], "test system prompt", sample_requests[:1] ) # 无法解析,返回空列表 assert len(results) == 0 @patch("src.core.direct_client.OpenAI") def test_classify_multiple_products( self, mock_openai, sample_products, sample_requests, mock_logger ): """测试批量分类多个产品""" # 为两个产品创建不同的响应 def create_response(category, type_val, sub_type): mock_response = Mock(spec=ChatCompletion) mock_message = Mock(spec=ChatCompletionMessage) mock_message.content = json.dumps( {"category": category, "type": type_val, "sub_type": sub_type} ) mock_choice = Mock(spec=Choice) mock_choice.message = mock_message mock_response.choices = [mock_choice] return mock_response mock_client_instance = MagicMock() mock_client_instance.chat.completions.create.side_effect = [ create_response("门票", "自然类", "自然风光"), create_response("住宿", "商务酒店", ""), ] mock_openai.return_value = mock_client_instance client = DirectClient( api_key="test-key", base_url="https://test.com", model="qwen-flash", logger=mock_logger, ) results = client.classify_products( sample_products, "test system prompt", sample_requests ) # 验证结果 assert len(results) == 2 assert results[0].product_id == "P001" assert results[0].category == "门票" assert results[1].product_id == "P002" assert results[1].category == "住宿" def test_classify_products_missing_request(self, sample_products, mock_logger): """测试产品没有对应请求的情况""" client = DirectClient( api_key="test-key", base_url="https://test.com", model="qwen-flash", logger=mock_logger, ) # 提供空的请求列表 results = client.classify_products(sample_products, "test system prompt", []) # 没有请求,无法处理任何产品 assert len(results) == 0 @patch("src.core.direct_client.OpenAI") def test_classify_products_empty_choices( self, mock_openai, sample_products, sample_requests, mock_logger ): """测试API响应没有choices的情况""" mock_response = Mock(spec=ChatCompletion) mock_response.choices = [] # 空的choices列表 mock_client_instance = MagicMock() mock_client_instance.chat.completions.create.return_value = mock_response mock_openai.return_value = mock_client_instance client = DirectClient( api_key="test-key", base_url="https://test.com", model="qwen-flash", logger=mock_logger, ) results = client.classify_products( sample_products[:1], "test system prompt", sample_requests[:1] ) # 无法解析,返回空列表 assert len(results) == 0