# CategorizeLabel/tests/test_file_handler.py
#
# 457 lines
# 16 KiB
# Python
# Raw Normal View History
#
# 2025-10-15 17:19:26 +08:00
"""测试文件处理模块
测试所有文件处理相关的类和函数
"""
import pytest
import json
import os
import tempfile
from unittest.mock import MagicMock
from src.core.file_handler import (
ProductReader,
CategoryReader,
BatchFileWriter,
ResultParser,
ResultWriter,
FileHandler,
)
from src.models import ProductInput, ProductCategory, ClassificationResult
class TestProductReader:
    """Tests for ``ProductReader``, which loads products from an Excel file."""

    @pytest.fixture
    def valid_excel_path(self):
        """Return the path of the sample workbook under ``tests/fixtures``.

        The path is relative to the current working directory.
        """
        return "tests/fixtures/sample_products.xlsx"

    def test_read_valid_excel(self, valid_excel_path):
        """Reading the sample workbook yields the expected products."""
        # Given: a reader and a valid Excel file
        reader = ProductReader()

        # When: the file is read
        products = reader.read(valid_excel_path)

        # Then: a list of products is returned with the expected first row
        assert len(products) == 3
        assert products[0].product_id == "P001"
        assert products[0].product_name == "黄山风景区门票"
        assert products[0].scenic_spot == "黄山"

    def test_read_nonexistent_file(self):
        """Reading a missing path raises ``FileNotFoundError``."""
        # Given: a reader and a path that does not exist
        reader = ProductReader()

        # When / Then: the read fails with FileNotFoundError
        with pytest.raises(FileNotFoundError) as exc_info:
            reader.read("nonexistent.xlsx")
        assert "不存在" in str(exc_info.value)

    def test_read_excel_missing_columns(self, tmp_path):
        """A workbook lacking a required column raises ``ValueError``.

        The workbook is created inside pytest's ``tmp_path`` directory, so no
        manual clean-up is needed and no text-mode handle is ever opened on
        the binary ``.xlsx`` file.
        """
        import pandas as pd

        # Given: a workbook that has no "产品名称" column
        reader = ProductReader()
        workbook_path = str(tmp_path / "missing_columns.xlsx")
        df = pd.DataFrame({"产品编号": ["P001"], "景区名称": ["黄山"]})
        df.to_excel(workbook_path, index=False)

        # When / Then: the read fails with ValueError
        with pytest.raises(ValueError) as exc_info:
            reader.read(workbook_path)
        assert "缺少必要列" in str(exc_info.value)
class TestCategoryReader:
    """Tests for ``CategoryReader``, which loads categories from a JSONL file."""

    @pytest.fixture
    def valid_jsonl_path(self):
        """Return the path of the sample JSONL file under ``tests/fixtures``.

        The path is relative to the current working directory.
        """
        return "tests/fixtures/sample_categories.jsonl"

    def test_read_valid_jsonl(self, valid_jsonl_path):
        """Reading the sample JSONL file yields the expected categories."""
        # Given: a reader and a valid JSONL file
        reader = CategoryReader()

        # When: the file is read
        categories = reader.read(valid_jsonl_path)

        # Then: a list of categories is returned with the expected first row
        assert len(categories) == 3
        assert categories[0].category == "门票"
        assert categories[0].type == "自然类"
        assert categories[0].sub_type == "自然风光"

    def test_read_nonexistent_file(self):
        """Reading a missing path raises ``FileNotFoundError``."""
        # Given: a reader and a path that does not exist
        reader = CategoryReader()

        # When / Then: the read fails with FileNotFoundError
        with pytest.raises(FileNotFoundError) as exc_info:
            reader.read("nonexistent.jsonl")
        assert "不存在" in str(exc_info.value)

    def test_read_invalid_json(self, tmp_path):
        """A line that is not valid JSON raises ``ValueError``.

        The input file lives in pytest's ``tmp_path`` directory, which pytest
        manages itself, so no manual clean-up is required.
        """
        # Given: a file whose only line is missing its closing brace
        reader = CategoryReader()
        jsonl_path = tmp_path / "invalid.jsonl"
        jsonl_path.write_text(
            '{"category":"门票","type":"自然类"\n', encoding="utf-8"
        )

        # When / Then: the read fails with ValueError
        with pytest.raises(ValueError) as exc_info:
            reader.read(str(jsonl_path))
        assert "JSON格式错误" in str(exc_info.value)

    def test_read_with_empty_lines(self, tmp_path):
        """Blank lines between records are skipped."""
        # Given: two records separated by an empty line
        reader = CategoryReader()
        jsonl_path = tmp_path / "with_blank_lines.jsonl"
        jsonl_path.write_text(
            '{"category":"门票","type":"自然类","sub_type":"自然风光"}\n'
            "\n"  # the blank line under test
            '{"category":"住宿","type":"商务酒店","sub_type":""}\n',
            encoding="utf-8",
        )

        # When: the file is read
        categories = reader.read(str(jsonl_path))

        # Then: only the two real records are returned
        assert len(categories) == 2
class TestBatchFileWriter:
    """Tests for ``BatchFileWriter``, which writes batch requests as JSONL."""

    def test_write_batch_requests(self, tmp_path):
        """Each request is written as one JSON object per line.

        The target path is a fresh file name inside ``tmp_path`` that does not
        exist beforehand, so the existence assertion genuinely checks that the
        writer created the file.
        """
        # Given: a writer and two batch requests
        writer = BatchFileWriter()
        requests = [
            {
                "custom_id": "P001",
                "method": "POST",
                "url": "/v1/chat/completions",
                "body": {"model": "qwen-flash", "messages": []},
            },
            {
                "custom_id": "P002",
                "method": "POST",
                "url": "/v1/chat/completions",
                "body": {"model": "qwen-flash", "messages": []},
            },
        ]
        target_path = str(tmp_path / "batch_requests.jsonl")

        # When: the requests are written
        result_path = writer.write(requests, target_path)

        # Then: the file exists and holds one valid JSON document per request
        assert os.path.exists(result_path)
        with open(result_path, "r", encoding="utf-8") as f:
            lines = f.readlines()
        assert len(lines) == 2
        for line in lines:
            data = json.loads(line)
            assert "custom_id" in data
            assert "method" in data

    def test_write_to_nested_directory(self, tmp_path):
        """Writing to a path whose parent directories are missing succeeds."""
        # Given: a writer, one request and a path two directories deep
        writer = BatchFileWriter()
        requests = [
            {
                "custom_id": "P001",
                "method": "POST",
                "url": "/v1/chat/completions",
                "body": {},
            }
        ]
        nested_path = os.path.join(str(tmp_path), "sub1", "sub2", "test.jsonl")

        # When: the request is written
        result_path = writer.write(requests, nested_path)

        # Then: the directories were created and the file exists
        assert os.path.exists(result_path)
class TestResultParser:
    """Tests for ``ResultParser``, which turns batch response text into results."""

    # One successful response line per product; the message content is itself
    # a JSON document embedded as an escaped string.
    _SUCCESS_P001 = '{"id":"req_001","custom_id":"P001","response":{"status_code":200,"body":{"choices":[{"message":{"content":"{\\"category\\":\\"门票\\",\\"type\\":\\"自然类\\",\\"sub_type\\":\\"自然风光\\"}"}}]}},"error":null}'
    _SUCCESS_P002 = '{"id":"req_002","custom_id":"P002","response":{"status_code":200,"body":{"choices":[{"message":{"content":"{\\"category\\":\\"住宿\\",\\"type\\":\\"商务酒店\\",\\"sub_type\\":\\"\\"}"}}]}},"error":null}'
    # A response line that carries an error object instead of a body.
    _ERROR_P001 = '{"id":"req_001","custom_id":"P001","response":{},"error":{"code":"InvalidRequest","message":"错误"}}'
    # A successful response whose message content is plain text, not JSON.
    _PLAIN_TEXT_P001 = '{"id":"req_001","custom_id":"P001","response":{"status_code":200,"body":{"choices":[{"message":{"content":"这是普通文本不是JSON"}}]}},"error":null}'

    @pytest.fixture
    def valid_batch_response(self):
        """Return batch response text made of two successful lines."""
        return "\n".join((self._SUCCESS_P001, self._SUCCESS_P002))

    def test_parse_valid_response(self, valid_batch_response):
        """Two successful lines produce two classification results."""
        # Given: a parser and valid response text
        parser = ResultParser()

        # When: the text is parsed
        parsed = parser.parse(valid_batch_response)

        # Then: both results come back and the first matches product P001
        assert len(parsed) == 2
        first = parsed[0]
        assert first.product_id == "P001"
        assert first.category == "门票"
        assert first.type == "自然类"
        assert first.sub_type == "自然风光"

    def test_parse_response_with_errors(self):
        """A line carrying an error is skipped; successful lines are kept."""
        # Given: one failed line followed by one successful line
        parser = ResultParser()
        response_with_error = self._ERROR_P001 + "\n" + self._SUCCESS_P002

        # When: the text is parsed
        parsed = parser.parse(response_with_error)

        # Then: only the successful product remains
        assert len(parsed) == 1
        assert parsed[0].product_id == "P002"

    def test_parse_empty_content(self):
        """Empty input produces no results."""
        # Given: a parser
        parser = ResultParser()

        # When: an empty string is parsed
        parsed = parser.parse("")

        # Then: nothing is returned
        assert len(parsed) == 0

    def test_parse_response_with_non_json_content(self):
        """A response whose message content is not JSON is skipped."""
        # Given: a response line whose content is plain text
        parser = ResultParser()

        # When: the text is parsed
        parsed = parser.parse(self._PLAIN_TEXT_P001)

        # Then: the line is dropped
        assert len(parsed) == 0
class TestResultWriter:
    """Tests for ``ResultWriter``, which writes classification results as JSONL."""

    def test_write_results(self, tmp_path):
        """Each result is written as one JSON object per line.

        The target path is a fresh file name inside ``tmp_path`` that does not
        exist beforehand, so the existence assertion genuinely checks that the
        writer created the file.
        """
        # Given: a writer and two classification results
        writer = ResultWriter()
        results = [
            ClassificationResult(
                product_id="P001", category="门票", type="自然类", sub_type="自然风光"
            ),
            ClassificationResult(
                product_id="P002", category="住宿", type="商务酒店", sub_type=""
            ),
        ]
        output_path = str(tmp_path / "results.jsonl")

        # When: the results are written
        writer.write(results, output_path)

        # Then: the file exists and holds one valid JSON document per result
        assert os.path.exists(output_path)
        with open(output_path, "r", encoding="utf-8") as f:
            lines = f.readlines()
        assert len(lines) == 2
        for line in lines:
            data = json.loads(line)
            assert "product_id" in data
            assert "category" in data

    def test_write_empty_results(self, tmp_path):
        """Writing an empty result list produces an existing, empty file.

        The target path does not exist before the call, so this test only
        passes when the writer itself creates the file.
        """
        # Given: a writer and no results
        writer = ResultWriter()
        results = []
        output_path = str(tmp_path / "empty_results.jsonl")

        # When: the empty list is written
        writer.write(results, output_path)

        # Then: the file exists and has no content
        assert os.path.exists(output_path)
        with open(output_path, "r", encoding="utf-8") as f:
            content = f.read()
        assert content == ""
class TestFileHandler:
    """Tests that the ``FileHandler`` facade forwards calls to its components."""

    @pytest.fixture
    def mock_handlers(self):
        """Build one spec'd ``MagicMock`` per ``FileHandler`` constructor argument."""
        component_specs = {
            "product_reader": ProductReader,
            "category_reader": CategoryReader,
            "batch_file_writer": BatchFileWriter,
            "result_parser": ResultParser,
            "result_writer": ResultWriter,
        }
        return {
            name: MagicMock(spec=component)
            for name, component in component_specs.items()
        }

    def test_read_products_delegates_correctly(self, mock_handlers):
        """``read_products`` forwards to ``product_reader.read``."""
        # Given: a handler whose product reader returns a known list
        handler = FileHandler(**mock_handlers)
        product_reader = mock_handlers["product_reader"]
        expected = [
            ProductInput(product_id="P001", product_name="测试", scenic_spot="测试景区")
        ]
        product_reader.read.return_value = expected

        # When: products are read through the handler
        actual = handler.read_products("test.xlsx")

        # Then: the reader was called once with the path and its value returned
        product_reader.read.assert_called_once_with("test.xlsx")
        assert actual == expected

    def test_read_categories_delegates_correctly(self, mock_handlers):
        """``read_categories`` forwards to ``category_reader.read``."""
        # Given: a handler whose category reader returns a known list
        handler = FileHandler(**mock_handlers)
        category_reader = mock_handlers["category_reader"]
        expected = [ProductCategory(category="门票", type="自然类", sub_type="")]
        category_reader.read.return_value = expected

        # When: categories are read through the handler
        actual = handler.read_categories("test.jsonl")

        # Then: the reader was called once with the path and its value returned
        category_reader.read.assert_called_once_with("test.jsonl")
        assert actual == expected

    def test_write_batch_requests_delegates_correctly(self, mock_handlers):
        """``write_batch_requests`` forwards to ``batch_file_writer.write``."""
        # Given: a handler whose batch writer reports a known output path
        handler = FileHandler(**mock_handlers)
        batch_writer = mock_handlers["batch_file_writer"]
        payload = [{"custom_id": "P001"}]
        batch_writer.write.return_value = "output.jsonl"

        # When: requests are written through the handler
        actual = handler.write_batch_requests(payload, "output.jsonl")

        # Then: the writer was called once and its return value passed through
        batch_writer.write.assert_called_once_with(payload, "output.jsonl")
        assert actual == "output.jsonl"

    def test_parse_batch_responses_delegates_correctly(self, mock_handlers):
        """``parse_batch_responses`` forwards to ``result_parser.parse``."""
        # Given: a handler whose parser returns a known result list
        handler = FileHandler(**mock_handlers)
        result_parser = mock_handlers["result_parser"]
        expected = [
            ClassificationResult(
                product_id="P001", category="门票", type="自然类", sub_type=""
            )
        ]
        result_parser.parse.return_value = expected

        # When: response text is parsed through the handler
        actual = handler.parse_batch_responses("response content")

        # Then: the parser was called once with the text and its value returned
        result_parser.parse.assert_called_once_with("response content")
        assert actual == expected

    def test_write_results_delegates_correctly(self, mock_handlers):
        """``write_results`` forwards to ``result_writer.write``."""
        # Given: a handler and one classification result
        handler = FileHandler(**mock_handlers)
        payload = [
            ClassificationResult(
                product_id="P001", category="门票", type="自然类", sub_type=""
            )
        ]

        # When: results are written through the handler
        handler.write_results(payload, "output.jsonl")

        # Then: the result writer was called once with the same arguments
        mock_handlers["result_writer"].write.assert_called_once_with(
            payload, "output.jsonl"
        )