"""测试数据模型 测试所有Pydantic模型的验证和行为 """ import pytest from pydantic import ValidationError from src.models import ( ProductInput, ProductCategory, ClassificationResult, BatchTaskInfo, BatchRequest, BatchResponse, ) class TestProductInput: """测试ProductInput模型""" def test_create_valid_product_input(self): """测试创建有效的产品输入""" # Given: 有效的产品数据 data = { "product_id": "P001", "product_name": "黄山风景区门票", "scenic_spot": "黄山", } # When: 创建ProductInput对象 product = ProductInput(**data) # Then: 所有字段应该正确设置 assert product.product_id == "P001" assert product.product_name == "黄山风景区门票" assert product.scenic_spot == "黄山" def test_product_input_with_empty_strings(self): """测试产品输入包含空字符串""" # Given: 包含空字符串的数据 data = {"product_id": "", "product_name": "", "scenic_spot": ""} # When: 创建ProductInput对象 product = ProductInput(**data) # Then: 应该成功创建(空字符串是有效的) assert product.product_id == "" assert product.product_name == "" assert product.scenic_spot == "" def test_product_input_missing_required_field(self): """测试缺少必需字段""" # Given: 缺少product_name字段 data = {"product_id": "P001", "scenic_spot": "黄山"} # When & Then: 应该抛出ValidationError with pytest.raises(ValidationError) as exc_info: ProductInput(**data) # 验证错误消息包含缺失字段信息 assert "product_name" in str(exc_info.value) def test_product_input_to_dict(self): """测试转换为字典""" # Given: 有效的ProductInput对象 product = ProductInput( product_id="P001", product_name="黄山风景区门票", scenic_spot="黄山" ) # When: 转换为字典 data = product.model_dump() # Then: 字典应该包含所有字段 assert data == { "product_id": "P001", "product_name": "黄山风景区门票", "scenic_spot": "黄山", } def test_product_input_from_json(self): """测试从JSON字符串创建""" # Given: JSON字符串 json_str = ( '{"product_id":"P001","product_name":"黄山风景区门票","scenic_spot":"黄山"}' ) # When: 从JSON创建对象 product = ProductInput.model_validate_json(json_str) # Then: 对象应该正确创建 assert product.product_id == "P001" assert product.product_name == "黄山风景区门票" class TestProductCategory: """测试ProductCategory模型""" def test_create_valid_product_category(self): """测试创建有效的产品类别""" # Given: 有效的类别数据 data = {"category": "门票", "type": "自然类", "sub_type": "自然风光"} # When: 创建ProductCategory对象 category = ProductCategory(**data) # Then: 所有字段应该正确设置 assert category.category == "门票" assert category.type == "自然类" assert category.sub_type == "自然风光" def test_product_category_with_default_sub_type(self): """测试sub_type使用默认值""" # Given: 不提供sub_type字段 data = {"category": "住宿", "type": "商务酒店"} # When: 创建ProductCategory对象 category = ProductCategory(**data) # Then: sub_type应该是默认空字符串 assert category.category == "住宿" assert category.type == "商务酒店" assert category.sub_type == "" def test_product_category_with_empty_sub_type(self): """测试显式设置空sub_type""" # Given: 显式提供空sub_type data = {"category": "餐饮", "type": "快餐", "sub_type": ""} # When: 创建ProductCategory对象 category = ProductCategory(**data) # Then: sub_type应该是空字符串 assert category.sub_type == "" def test_product_category_missing_required_field(self): """测试缺少必需字段""" # Given: 缺少category字段 data = {"type": "自然类", "sub_type": "自然风光"} # When & Then: 应该抛出ValidationError with pytest.raises(ValidationError) as exc_info: ProductCategory(**data) assert "category" in str(exc_info.value) class TestClassificationResult: """测试ClassificationResult模型""" def test_create_valid_classification_result(self): """测试创建有效的分类结果""" # Given: 有效的分类结果数据 data = { "product_id": "P001", "category": "门票", "type": "自然类", "sub_type": "自然风光", } # When: 创建ClassificationResult对象 result = ClassificationResult(**data) # Then: 所有字段应该正确设置 assert result.product_id == "P001" assert result.category == "门票" assert result.type == "自然类" assert result.sub_type == "自然风光" def test_classification_result_with_default_sub_type(self): """测试分类结果使用默认sub_type""" # Given: 不提供sub_type字段 data = {"product_id": "P002", "category": "住宿", "type": "商务酒店"} # When: 创建ClassificationResult对象 result = ClassificationResult(**data) # Then: sub_type应该是默认空字符串 assert result.sub_type == "" def test_classification_result_to_json(self): """测试转换为JSON""" # Given: 有效的ClassificationResult对象 result = ClassificationResult( product_id="P001", category="门票", type="自然类", sub_type="自然风光" ) # When: 转换为JSON json_str = result.model_dump_json() # Then: JSON字符串应该包含所有字段 assert "P001" in json_str assert "门票" in json_str assert "自然类" in json_str class TestBatchTaskInfo: """测试BatchTaskInfo模型""" def test_create_valid_batch_task_info(self): """测试创建有效的Batch任务信息""" # Given: 有效的任务信息 data = { "task_id": "batch_abc123", "status": "completed", "input_file_id": "file-batch-xyz", "output_file_id": "file-batch_output-xyz", "error_file_id": None, } # When: 创建BatchTaskInfo对象 task = BatchTaskInfo(**data) # Then: 所有字段应该正确设置 assert task.task_id == "batch_abc123" assert task.status == "completed" assert task.input_file_id == "file-batch-xyz" assert task.output_file_id == "file-batch_output-xyz" assert task.error_file_id is None def test_batch_task_info_with_default_optional_fields(self): """测试可选字段使用默认值""" # Given: 只提供必需字段 data = { "task_id": "batch_def456", "status": "in_progress", "input_file_id": "file-batch-abc", } # When: 创建BatchTaskInfo对象 task = BatchTaskInfo(**data) # Then: 可选字段应该是None assert task.output_file_id is None assert task.error_file_id is None def test_batch_task_info_with_error_file(self): """测试包含错误文件的任务信息""" # Given: 任务有错误文件 data = { "task_id": "batch_error123", "status": "completed", "input_file_id": "file-batch-input", "output_file_id": "file-batch_output", "error_file_id": "file-batch_error", } # When: 创建BatchTaskInfo对象 task = BatchTaskInfo(**data) # Then: error_file_id应该正确设置 assert task.error_file_id == "file-batch_error" def test_batch_task_info_different_statuses(self): """测试不同的任务状态""" # Given: 不同状态的任务 statuses = [ "validating", "in_progress", "completed", "failed", "expired", "cancelled", ] for status in statuses: # When: 创建不同状态的任务 task = BatchTaskInfo( task_id=f"batch_{status}", status=status, input_file_id="file-input" ) # Then: 状态应该正确设置 assert task.status == status class TestBatchRequest: """测试BatchRequest模型""" def test_create_valid_batch_request(self): """测试创建有效的Batch请求""" # Given: 有效的请求数据 data = { "custom_id": "P001", "method": "POST", "url": "/v1/chat/completions", "body": { "model": "qwen-flash", "messages": [ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "分类产品"}, ], }, } # When: 创建BatchRequest对象 request = BatchRequest(**data) # Then: 所有字段应该正确设置 assert request.custom_id == "P001" assert request.method == "POST" assert request.url == "/v1/chat/completions" assert request.body["model"] == "qwen-flash" def test_batch_request_with_default_fields(self): """测试使用默认字段值""" # Given: 只提供必需字段 data = {"custom_id": "P002", "body": {"model": "qwen-flash", "messages": []}} # When: 创建BatchRequest对象 request = BatchRequest(**data) # Then: 默认字段应该正确设置 assert request.method == "POST" assert request.url == "/v1/chat/completions" def test_batch_request_to_dict(self): """测试转换为字典""" # Given: 有效的BatchRequest对象 request = BatchRequest( custom_id="P003", method="POST", url="/v1/chat/completions", body={"model": "qwen-flash", "messages": []}, ) # When: 转换为字典 data = request.model_dump() # Then: 字典应该包含所有字段 assert data["custom_id"] == "P003" assert data["method"] == "POST" assert data["url"] == "/v1/chat/completions" assert "body" in data class TestBatchResponse: """测试BatchResponse模型""" def test_create_valid_batch_response(self): """测试创建有效的Batch响应""" # Given: 有效的响应数据 data = { "id": "req_123", "custom_id": "P001", "response": { "status_code": 200, "body": {"choices": [{"message": {"content": "分类结果"}}]}, }, "error": None, } # When: 创建BatchResponse对象 response = BatchResponse(**data) # Then: 所有字段应该正确设置 assert response.id == "req_123" assert response.custom_id == "P001" assert response.response["status_code"] == 200 assert response.error is None def test_batch_response_with_error(self): """测试包含错误的响应""" # Given: 包含错误的响应数据 data = { "id": "req_456", "custom_id": "P002", "response": {}, "error": {"code": "InvalidRequest", "message": "请求参数错误"}, } # When: 创建BatchResponse对象 response = BatchResponse(**data) # Then: 错误信息应该正确设置 assert response.error is not None assert response.error["code"] == "InvalidRequest" assert response.error["message"] == "请求参数错误" def test_batch_response_with_default_error(self): """测试默认error为None""" # Given: 不提供error字段 data = {"id": "req_789", "custom_id": "P003", "response": {"status_code": 200}} # When: 创建BatchResponse对象 response = BatchResponse(**data) # Then: error应该是None assert response.error is None