389 lines
12 KiB
Python
389 lines
12 KiB
Python
"""测试数据模型
|
|
|
|
测试所有Pydantic模型的验证和行为
|
|
"""
|
|
|
|
import pytest
|
|
from pydantic import ValidationError
|
|
|
|
from src.models import (
|
|
ProductInput,
|
|
ProductCategory,
|
|
ClassificationResult,
|
|
BatchTaskInfo,
|
|
BatchRequest,
|
|
BatchResponse,
|
|
)
|
|
|
|
|
|
class TestProductInput:
|
|
"""测试ProductInput模型"""
|
|
|
|
def test_create_valid_product_input(self):
|
|
"""测试创建有效的产品输入"""
|
|
# Given: 有效的产品数据
|
|
data = {
|
|
"product_id": "P001",
|
|
"product_name": "黄山风景区门票",
|
|
"scenic_spot": "黄山",
|
|
}
|
|
|
|
# When: 创建ProductInput对象
|
|
product = ProductInput(**data)
|
|
|
|
# Then: 所有字段应该正确设置
|
|
assert product.product_id == "P001"
|
|
assert product.product_name == "黄山风景区门票"
|
|
assert product.scenic_spot == "黄山"
|
|
|
|
def test_product_input_with_empty_strings(self):
|
|
"""测试产品输入包含空字符串"""
|
|
# Given: 包含空字符串的数据
|
|
data = {"product_id": "", "product_name": "", "scenic_spot": ""}
|
|
|
|
# When: 创建ProductInput对象
|
|
product = ProductInput(**data)
|
|
|
|
# Then: 应该成功创建(空字符串是有效的)
|
|
assert product.product_id == ""
|
|
assert product.product_name == ""
|
|
assert product.scenic_spot == ""
|
|
|
|
def test_product_input_missing_required_field(self):
|
|
"""测试缺少必需字段"""
|
|
# Given: 缺少product_name字段
|
|
data = {"product_id": "P001", "scenic_spot": "黄山"}
|
|
|
|
# When & Then: 应该抛出ValidationError
|
|
with pytest.raises(ValidationError) as exc_info:
|
|
ProductInput(**data)
|
|
|
|
# 验证错误消息包含缺失字段信息
|
|
assert "product_name" in str(exc_info.value)
|
|
|
|
def test_product_input_to_dict(self):
|
|
"""测试转换为字典"""
|
|
# Given: 有效的ProductInput对象
|
|
product = ProductInput(
|
|
product_id="P001", product_name="黄山风景区门票", scenic_spot="黄山"
|
|
)
|
|
|
|
# When: 转换为字典
|
|
data = product.model_dump()
|
|
|
|
# Then: 字典应该包含所有字段
|
|
assert data == {
|
|
"product_id": "P001",
|
|
"product_name": "黄山风景区门票",
|
|
"scenic_spot": "黄山",
|
|
}
|
|
|
|
def test_product_input_from_json(self):
|
|
"""测试从JSON字符串创建"""
|
|
# Given: JSON字符串
|
|
json_str = (
|
|
'{"product_id":"P001","product_name":"黄山风景区门票","scenic_spot":"黄山"}'
|
|
)
|
|
|
|
# When: 从JSON创建对象
|
|
product = ProductInput.model_validate_json(json_str)
|
|
|
|
# Then: 对象应该正确创建
|
|
assert product.product_id == "P001"
|
|
assert product.product_name == "黄山风景区门票"
|
|
|
|
|
|
class TestProductCategory:
|
|
"""测试ProductCategory模型"""
|
|
|
|
def test_create_valid_product_category(self):
|
|
"""测试创建有效的产品类别"""
|
|
# Given: 有效的类别数据
|
|
data = {"category": "门票", "type": "自然类", "sub_type": "自然风光"}
|
|
|
|
# When: 创建ProductCategory对象
|
|
category = ProductCategory(**data)
|
|
|
|
# Then: 所有字段应该正确设置
|
|
assert category.category == "门票"
|
|
assert category.type == "自然类"
|
|
assert category.sub_type == "自然风光"
|
|
|
|
def test_product_category_with_default_sub_type(self):
|
|
"""测试sub_type使用默认值"""
|
|
# Given: 不提供sub_type字段
|
|
data = {"category": "住宿", "type": "商务酒店"}
|
|
|
|
# When: 创建ProductCategory对象
|
|
category = ProductCategory(**data)
|
|
|
|
# Then: sub_type应该是默认空字符串
|
|
assert category.category == "住宿"
|
|
assert category.type == "商务酒店"
|
|
assert category.sub_type == ""
|
|
|
|
def test_product_category_with_empty_sub_type(self):
|
|
"""测试显式设置空sub_type"""
|
|
# Given: 显式提供空sub_type
|
|
data = {"category": "餐饮", "type": "快餐", "sub_type": ""}
|
|
|
|
# When: 创建ProductCategory对象
|
|
category = ProductCategory(**data)
|
|
|
|
# Then: sub_type应该是空字符串
|
|
assert category.sub_type == ""
|
|
|
|
def test_product_category_missing_required_field(self):
|
|
"""测试缺少必需字段"""
|
|
# Given: 缺少category字段
|
|
data = {"type": "自然类", "sub_type": "自然风光"}
|
|
|
|
# When & Then: 应该抛出ValidationError
|
|
with pytest.raises(ValidationError) as exc_info:
|
|
ProductCategory(**data)
|
|
|
|
assert "category" in str(exc_info.value)
|
|
|
|
|
|
class TestClassificationResult:
|
|
"""测试ClassificationResult模型"""
|
|
|
|
def test_create_valid_classification_result(self):
|
|
"""测试创建有效的分类结果"""
|
|
# Given: 有效的分类结果数据
|
|
data = {
|
|
"product_id": "P001",
|
|
"category": "门票",
|
|
"type": "自然类",
|
|
"sub_type": "自然风光",
|
|
}
|
|
|
|
# When: 创建ClassificationResult对象
|
|
result = ClassificationResult(**data)
|
|
|
|
# Then: 所有字段应该正确设置
|
|
assert result.product_id == "P001"
|
|
assert result.category == "门票"
|
|
assert result.type == "自然类"
|
|
assert result.sub_type == "自然风光"
|
|
|
|
def test_classification_result_with_default_sub_type(self):
|
|
"""测试分类结果使用默认sub_type"""
|
|
# Given: 不提供sub_type字段
|
|
data = {"product_id": "P002", "category": "住宿", "type": "商务酒店"}
|
|
|
|
# When: 创建ClassificationResult对象
|
|
result = ClassificationResult(**data)
|
|
|
|
# Then: sub_type应该是默认空字符串
|
|
assert result.sub_type == ""
|
|
|
|
def test_classification_result_to_json(self):
|
|
"""测试转换为JSON"""
|
|
# Given: 有效的ClassificationResult对象
|
|
result = ClassificationResult(
|
|
product_id="P001", category="门票", type="自然类", sub_type="自然风光"
|
|
)
|
|
|
|
# When: 转换为JSON
|
|
json_str = result.model_dump_json()
|
|
|
|
# Then: JSON字符串应该包含所有字段
|
|
assert "P001" in json_str
|
|
assert "门票" in json_str
|
|
assert "自然类" in json_str
|
|
|
|
|
|
class TestBatchTaskInfo:
|
|
"""测试BatchTaskInfo模型"""
|
|
|
|
def test_create_valid_batch_task_info(self):
|
|
"""测试创建有效的Batch任务信息"""
|
|
# Given: 有效的任务信息
|
|
data = {
|
|
"task_id": "batch_abc123",
|
|
"status": "completed",
|
|
"input_file_id": "file-batch-xyz",
|
|
"output_file_id": "file-batch_output-xyz",
|
|
"error_file_id": None,
|
|
}
|
|
|
|
# When: 创建BatchTaskInfo对象
|
|
task = BatchTaskInfo(**data)
|
|
|
|
# Then: 所有字段应该正确设置
|
|
assert task.task_id == "batch_abc123"
|
|
assert task.status == "completed"
|
|
assert task.input_file_id == "file-batch-xyz"
|
|
assert task.output_file_id == "file-batch_output-xyz"
|
|
assert task.error_file_id is None
|
|
|
|
def test_batch_task_info_with_default_optional_fields(self):
|
|
"""测试可选字段使用默认值"""
|
|
# Given: 只提供必需字段
|
|
data = {
|
|
"task_id": "batch_def456",
|
|
"status": "in_progress",
|
|
"input_file_id": "file-batch-abc",
|
|
}
|
|
|
|
# When: 创建BatchTaskInfo对象
|
|
task = BatchTaskInfo(**data)
|
|
|
|
# Then: 可选字段应该是None
|
|
assert task.output_file_id is None
|
|
assert task.error_file_id is None
|
|
|
|
def test_batch_task_info_with_error_file(self):
|
|
"""测试包含错误文件的任务信息"""
|
|
# Given: 任务有错误文件
|
|
data = {
|
|
"task_id": "batch_error123",
|
|
"status": "completed",
|
|
"input_file_id": "file-batch-input",
|
|
"output_file_id": "file-batch_output",
|
|
"error_file_id": "file-batch_error",
|
|
}
|
|
|
|
# When: 创建BatchTaskInfo对象
|
|
task = BatchTaskInfo(**data)
|
|
|
|
# Then: error_file_id应该正确设置
|
|
assert task.error_file_id == "file-batch_error"
|
|
|
|
def test_batch_task_info_different_statuses(self):
|
|
"""测试不同的任务状态"""
|
|
# Given: 不同状态的任务
|
|
statuses = [
|
|
"validating",
|
|
"in_progress",
|
|
"completed",
|
|
"failed",
|
|
"expired",
|
|
"cancelled",
|
|
]
|
|
|
|
for status in statuses:
|
|
# When: 创建不同状态的任务
|
|
task = BatchTaskInfo(
|
|
task_id=f"batch_{status}", status=status, input_file_id="file-input"
|
|
)
|
|
|
|
# Then: 状态应该正确设置
|
|
assert task.status == status
|
|
|
|
|
|
class TestBatchRequest:
|
|
"""测试BatchRequest模型"""
|
|
|
|
def test_create_valid_batch_request(self):
|
|
"""测试创建有效的Batch请求"""
|
|
# Given: 有效的请求数据
|
|
data = {
|
|
"custom_id": "P001",
|
|
"method": "POST",
|
|
"url": "/v1/chat/completions",
|
|
"body": {
|
|
"model": "qwen-flash",
|
|
"messages": [
|
|
{"role": "system", "content": "You are a helpful assistant."},
|
|
{"role": "user", "content": "分类产品"},
|
|
],
|
|
},
|
|
}
|
|
|
|
# When: 创建BatchRequest对象
|
|
request = BatchRequest(**data)
|
|
|
|
# Then: 所有字段应该正确设置
|
|
assert request.custom_id == "P001"
|
|
assert request.method == "POST"
|
|
assert request.url == "/v1/chat/completions"
|
|
assert request.body["model"] == "qwen-flash"
|
|
|
|
def test_batch_request_with_default_fields(self):
|
|
"""测试使用默认字段值"""
|
|
# Given: 只提供必需字段
|
|
data = {"custom_id": "P002", "body": {"model": "qwen-flash", "messages": []}}
|
|
|
|
# When: 创建BatchRequest对象
|
|
request = BatchRequest(**data)
|
|
|
|
# Then: 默认字段应该正确设置
|
|
assert request.method == "POST"
|
|
assert request.url == "/v1/chat/completions"
|
|
|
|
def test_batch_request_to_dict(self):
|
|
"""测试转换为字典"""
|
|
# Given: 有效的BatchRequest对象
|
|
request = BatchRequest(
|
|
custom_id="P003",
|
|
method="POST",
|
|
url="/v1/chat/completions",
|
|
body={"model": "qwen-flash", "messages": []},
|
|
)
|
|
|
|
# When: 转换为字典
|
|
data = request.model_dump()
|
|
|
|
# Then: 字典应该包含所有字段
|
|
assert data["custom_id"] == "P003"
|
|
assert data["method"] == "POST"
|
|
assert data["url"] == "/v1/chat/completions"
|
|
assert "body" in data
|
|
|
|
|
|
class TestBatchResponse:
|
|
"""测试BatchResponse模型"""
|
|
|
|
def test_create_valid_batch_response(self):
|
|
"""测试创建有效的Batch响应"""
|
|
# Given: 有效的响应数据
|
|
data = {
|
|
"id": "req_123",
|
|
"custom_id": "P001",
|
|
"response": {
|
|
"status_code": 200,
|
|
"body": {"choices": [{"message": {"content": "分类结果"}}]},
|
|
},
|
|
"error": None,
|
|
}
|
|
|
|
# When: 创建BatchResponse对象
|
|
response = BatchResponse(**data)
|
|
|
|
# Then: 所有字段应该正确设置
|
|
assert response.id == "req_123"
|
|
assert response.custom_id == "P001"
|
|
assert response.response["status_code"] == 200
|
|
assert response.error is None
|
|
|
|
def test_batch_response_with_error(self):
|
|
"""测试包含错误的响应"""
|
|
# Given: 包含错误的响应数据
|
|
data = {
|
|
"id": "req_456",
|
|
"custom_id": "P002",
|
|
"response": {},
|
|
"error": {"code": "InvalidRequest", "message": "请求参数错误"},
|
|
}
|
|
|
|
# When: 创建BatchResponse对象
|
|
response = BatchResponse(**data)
|
|
|
|
# Then: 错误信息应该正确设置
|
|
assert response.error is not None
|
|
assert response.error["code"] == "InvalidRequest"
|
|
assert response.error["message"] == "请求参数错误"
|
|
|
|
def test_batch_response_with_default_error(self):
|
|
"""测试默认error为None"""
|
|
# Given: 不提供error字段
|
|
data = {"id": "req_789", "custom_id": "P003", "response": {"status_code": 200}}
|
|
|
|
# When: 创建BatchResponse对象
|
|
response = BatchResponse(**data)
|
|
|
|
# Then: error应该是None
|
|
assert response.error is None
|