CategorizeLabel/tests/test_models.py

389 lines
12 KiB
Python
Raw Permalink Normal View History

2025-10-15 17:19:26 +08:00
"""测试数据模型
测试所有Pydantic模型的验证和行为
"""
import pytest
from pydantic import ValidationError
from src.models import (
ProductInput,
ProductCategory,
ClassificationResult,
BatchTaskInfo,
BatchRequest,
BatchResponse,
)
class TestProductInput:
"""测试ProductInput模型"""
def test_create_valid_product_input(self):
"""测试创建有效的产品输入"""
# Given: 有效的产品数据
data = {
"product_id": "P001",
"product_name": "黄山风景区门票",
"scenic_spot": "黄山",
}
# When: 创建ProductInput对象
product = ProductInput(**data)
# Then: 所有字段应该正确设置
assert product.product_id == "P001"
assert product.product_name == "黄山风景区门票"
assert product.scenic_spot == "黄山"
def test_product_input_with_empty_strings(self):
"""测试产品输入包含空字符串"""
# Given: 包含空字符串的数据
data = {"product_id": "", "product_name": "", "scenic_spot": ""}
# When: 创建ProductInput对象
product = ProductInput(**data)
# Then: 应该成功创建(空字符串是有效的)
assert product.product_id == ""
assert product.product_name == ""
assert product.scenic_spot == ""
def test_product_input_missing_required_field(self):
"""测试缺少必需字段"""
# Given: 缺少product_name字段
data = {"product_id": "P001", "scenic_spot": "黄山"}
# When & Then: 应该抛出ValidationError
with pytest.raises(ValidationError) as exc_info:
ProductInput(**data)
# 验证错误消息包含缺失字段信息
assert "product_name" in str(exc_info.value)
def test_product_input_to_dict(self):
"""测试转换为字典"""
# Given: 有效的ProductInput对象
product = ProductInput(
product_id="P001", product_name="黄山风景区门票", scenic_spot="黄山"
)
# When: 转换为字典
data = product.model_dump()
# Then: 字典应该包含所有字段
assert data == {
"product_id": "P001",
"product_name": "黄山风景区门票",
"scenic_spot": "黄山",
}
def test_product_input_from_json(self):
"""测试从JSON字符串创建"""
# Given: JSON字符串
json_str = (
'{"product_id":"P001","product_name":"黄山风景区门票","scenic_spot":"黄山"}'
)
# When: 从JSON创建对象
product = ProductInput.model_validate_json(json_str)
# Then: 对象应该正确创建
assert product.product_id == "P001"
assert product.product_name == "黄山风景区门票"
class TestProductCategory:
"""测试ProductCategory模型"""
def test_create_valid_product_category(self):
"""测试创建有效的产品类别"""
# Given: 有效的类别数据
data = {"category": "门票", "type": "自然类", "sub_type": "自然风光"}
# When: 创建ProductCategory对象
category = ProductCategory(**data)
# Then: 所有字段应该正确设置
assert category.category == "门票"
assert category.type == "自然类"
assert category.sub_type == "自然风光"
def test_product_category_with_default_sub_type(self):
"""测试sub_type使用默认值"""
# Given: 不提供sub_type字段
data = {"category": "住宿", "type": "商务酒店"}
# When: 创建ProductCategory对象
category = ProductCategory(**data)
# Then: sub_type应该是默认空字符串
assert category.category == "住宿"
assert category.type == "商务酒店"
assert category.sub_type == ""
def test_product_category_with_empty_sub_type(self):
"""测试显式设置空sub_type"""
# Given: 显式提供空sub_type
data = {"category": "餐饮", "type": "快餐", "sub_type": ""}
# When: 创建ProductCategory对象
category = ProductCategory(**data)
# Then: sub_type应该是空字符串
assert category.sub_type == ""
def test_product_category_missing_required_field(self):
"""测试缺少必需字段"""
# Given: 缺少category字段
data = {"type": "自然类", "sub_type": "自然风光"}
# When & Then: 应该抛出ValidationError
with pytest.raises(ValidationError) as exc_info:
ProductCategory(**data)
assert "category" in str(exc_info.value)
class TestClassificationResult:
"""测试ClassificationResult模型"""
def test_create_valid_classification_result(self):
"""测试创建有效的分类结果"""
# Given: 有效的分类结果数据
data = {
"product_id": "P001",
"category": "门票",
"type": "自然类",
"sub_type": "自然风光",
}
# When: 创建ClassificationResult对象
result = ClassificationResult(**data)
# Then: 所有字段应该正确设置
assert result.product_id == "P001"
assert result.category == "门票"
assert result.type == "自然类"
assert result.sub_type == "自然风光"
def test_classification_result_with_default_sub_type(self):
"""测试分类结果使用默认sub_type"""
# Given: 不提供sub_type字段
data = {"product_id": "P002", "category": "住宿", "type": "商务酒店"}
# When: 创建ClassificationResult对象
result = ClassificationResult(**data)
# Then: sub_type应该是默认空字符串
assert result.sub_type == ""
def test_classification_result_to_json(self):
"""测试转换为JSON"""
# Given: 有效的ClassificationResult对象
result = ClassificationResult(
product_id="P001", category="门票", type="自然类", sub_type="自然风光"
)
# When: 转换为JSON
json_str = result.model_dump_json()
# Then: JSON字符串应该包含所有字段
assert "P001" in json_str
assert "门票" in json_str
assert "自然类" in json_str
class TestBatchTaskInfo:
"""测试BatchTaskInfo模型"""
def test_create_valid_batch_task_info(self):
"""测试创建有效的Batch任务信息"""
# Given: 有效的任务信息
data = {
"task_id": "batch_abc123",
"status": "completed",
"input_file_id": "file-batch-xyz",
"output_file_id": "file-batch_output-xyz",
"error_file_id": None,
}
# When: 创建BatchTaskInfo对象
task = BatchTaskInfo(**data)
# Then: 所有字段应该正确设置
assert task.task_id == "batch_abc123"
assert task.status == "completed"
assert task.input_file_id == "file-batch-xyz"
assert task.output_file_id == "file-batch_output-xyz"
assert task.error_file_id is None
def test_batch_task_info_with_default_optional_fields(self):
"""测试可选字段使用默认值"""
# Given: 只提供必需字段
data = {
"task_id": "batch_def456",
"status": "in_progress",
"input_file_id": "file-batch-abc",
}
# When: 创建BatchTaskInfo对象
task = BatchTaskInfo(**data)
# Then: 可选字段应该是None
assert task.output_file_id is None
assert task.error_file_id is None
def test_batch_task_info_with_error_file(self):
"""测试包含错误文件的任务信息"""
# Given: 任务有错误文件
data = {
"task_id": "batch_error123",
"status": "completed",
"input_file_id": "file-batch-input",
"output_file_id": "file-batch_output",
"error_file_id": "file-batch_error",
}
# When: 创建BatchTaskInfo对象
task = BatchTaskInfo(**data)
# Then: error_file_id应该正确设置
assert task.error_file_id == "file-batch_error"
def test_batch_task_info_different_statuses(self):
"""测试不同的任务状态"""
# Given: 不同状态的任务
statuses = [
"validating",
"in_progress",
"completed",
"failed",
"expired",
"cancelled",
]
for status in statuses:
# When: 创建不同状态的任务
task = BatchTaskInfo(
task_id=f"batch_{status}", status=status, input_file_id="file-input"
)
# Then: 状态应该正确设置
assert task.status == status
class TestBatchRequest:
"""测试BatchRequest模型"""
def test_create_valid_batch_request(self):
"""测试创建有效的Batch请求"""
# Given: 有效的请求数据
data = {
"custom_id": "P001",
"method": "POST",
"url": "/v1/chat/completions",
"body": {
"model": "qwen-flash",
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "分类产品"},
],
},
}
# When: 创建BatchRequest对象
request = BatchRequest(**data)
# Then: 所有字段应该正确设置
assert request.custom_id == "P001"
assert request.method == "POST"
assert request.url == "/v1/chat/completions"
assert request.body["model"] == "qwen-flash"
def test_batch_request_with_default_fields(self):
"""测试使用默认字段值"""
# Given: 只提供必需字段
data = {"custom_id": "P002", "body": {"model": "qwen-flash", "messages": []}}
# When: 创建BatchRequest对象
request = BatchRequest(**data)
# Then: 默认字段应该正确设置
assert request.method == "POST"
assert request.url == "/v1/chat/completions"
def test_batch_request_to_dict(self):
"""测试转换为字典"""
# Given: 有效的BatchRequest对象
request = BatchRequest(
custom_id="P003",
method="POST",
url="/v1/chat/completions",
body={"model": "qwen-flash", "messages": []},
)
# When: 转换为字典
data = request.model_dump()
# Then: 字典应该包含所有字段
assert data["custom_id"] == "P003"
assert data["method"] == "POST"
assert data["url"] == "/v1/chat/completions"
assert "body" in data
class TestBatchResponse:
"""测试BatchResponse模型"""
def test_create_valid_batch_response(self):
"""测试创建有效的Batch响应"""
# Given: 有效的响应数据
data = {
"id": "req_123",
"custom_id": "P001",
"response": {
"status_code": 200,
"body": {"choices": [{"message": {"content": "分类结果"}}]},
},
"error": None,
}
# When: 创建BatchResponse对象
response = BatchResponse(**data)
# Then: 所有字段应该正确设置
assert response.id == "req_123"
assert response.custom_id == "P001"
assert response.response["status_code"] == 200
assert response.error is None
def test_batch_response_with_error(self):
"""测试包含错误的响应"""
# Given: 包含错误的响应数据
data = {
"id": "req_456",
"custom_id": "P002",
"response": {},
"error": {"code": "InvalidRequest", "message": "请求参数错误"},
}
# When: 创建BatchResponse对象
response = BatchResponse(**data)
# Then: 错误信息应该正确设置
assert response.error is not None
assert response.error["code"] == "InvalidRequest"
assert response.error["message"] == "请求参数错误"
def test_batch_response_with_default_error(self):
"""测试默认error为None"""
# Given: 不提供error字段
data = {"id": "req_789", "custom_id": "P003", "response": {"status_code": 200}}
# When: 创建BatchResponse对象
response = BatchResponse(**data)
# Then: error应该是None
assert response.error is None