CategorizeLabel/tests/test_file_handler.py
2025-10-15 17:19:26 +08:00

457 lines
16 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""测试文件处理模块
测试所有文件处理相关的类和函数
"""
import pytest
import json
import os
import tempfile
from unittest.mock import MagicMock
from src.core.file_handler import (
ProductReader,
CategoryReader,
BatchFileWriter,
ResultParser,
ResultWriter,
FileHandler,
)
from src.models import ProductInput, ProductCategory, ClassificationResult
class TestProductReader:
    """Tests for ProductReader.read against Excel workbooks."""

    @pytest.fixture
    def valid_excel_path(self):
        """Return the path of the sample workbook under ``fixtures/``.

        The path is anchored to this test module instead of the current
        working directory, so the fixture is found regardless of where
        pytest is launched from.
        NOTE(review): assumes this module sits in ``tests/`` next to the
        ``fixtures/`` directory -- confirm against the repository layout.
        """
        fixtures_dir = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), "fixtures"
        )
        return os.path.join(fixtures_dir, "sample_products.xlsx")

    def test_read_valid_excel(self, valid_excel_path):
        """Reading the sample workbook yields the three expected products."""
        # Given: a valid Excel file
        reader = ProductReader()

        # When: the file is read
        products = reader.read(valid_excel_path)

        # Then: a list of products is returned
        assert len(products) == 3
        first = products[0]
        assert first.product_id == "P001"
        assert first.product_name == "黄山风景区门票"
        assert first.scenic_spot == "黄山"

    def test_read_nonexistent_file(self, tmp_path):
        """Reading a missing file raises FileNotFoundError with a clear message."""
        # Given: a path inside a fresh temporary directory, so it is
        # guaranteed not to exist (a bare relative name could collide with a
        # stray file in the working directory).
        reader = ProductReader()
        missing_path = str(tmp_path / "nonexistent.xlsx")

        # When & Then: FileNotFoundError is raised and mentions "不存在"
        with pytest.raises(FileNotFoundError, match="不存在"):
            reader.read(missing_path)

    def test_read_excel_missing_columns(self, tmp_path):
        """Reading a workbook without a required column raises ValueError."""
        # Imported locally, as in the rest of this module's usage of pandas.
        import pandas as pd

        # Given: a workbook that lacks the "产品名称" column. tmp_path is
        # removed by pytest, so no manual cleanup is needed.
        reader = ProductReader()
        workbook_path = str(tmp_path / "missing_columns.xlsx")
        frame = pd.DataFrame({"产品编号": ["P001"], "景区名称": ["黄山"]})
        frame.to_excel(workbook_path, index=False)

        # When & Then: ValueError is raised and mentions "缺少必要列"
        with pytest.raises(ValueError, match="缺少必要列"):
            reader.read(workbook_path)
class TestCategoryReader:
    """Tests for CategoryReader.read against JSONL category files."""

    @pytest.fixture
    def valid_jsonl_path(self):
        """Return the path of the sample JSONL file under ``fixtures/``.

        The path is anchored to this test module instead of the current
        working directory, so the fixture is found regardless of where
        pytest is launched from.
        NOTE(review): assumes this module sits in ``tests/`` next to the
        ``fixtures/`` directory -- confirm against the repository layout.
        """
        fixtures_dir = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), "fixtures"
        )
        return os.path.join(fixtures_dir, "sample_categories.jsonl")

    def test_read_valid_jsonl(self, valid_jsonl_path):
        """Reading the sample file yields the three expected categories."""
        # Given: a valid JSONL file
        reader = CategoryReader()

        # When: the file is read
        categories = reader.read(valid_jsonl_path)

        # Then: a list of categories is returned
        assert len(categories) == 3
        first = categories[0]
        assert first.category == "门票"
        assert first.type == "自然类"
        assert first.sub_type == "自然风光"

    def test_read_nonexistent_file(self, tmp_path):
        """Reading a missing file raises FileNotFoundError with a clear message."""
        # Given: a path inside a fresh temporary directory, so it is
        # guaranteed not to exist.
        reader = CategoryReader()
        missing_path = str(tmp_path / "nonexistent.jsonl")

        # When & Then: FileNotFoundError is raised and mentions "不存在"
        with pytest.raises(FileNotFoundError, match="不存在"):
            reader.read(missing_path)

    def test_read_invalid_json(self, tmp_path):
        """A line that is not valid JSON makes read() raise ValueError."""
        # Given: a file whose only record is missing its closing brace
        reader = CategoryReader()
        broken_file = tmp_path / "broken.jsonl"
        broken_file.write_text(
            '{"category":"门票","type":"自然类"\n', encoding="utf-8"
        )

        # When & Then: ValueError is raised and mentions "JSON格式错误"
        with pytest.raises(ValueError, match="JSON格式错误"):
            reader.read(str(broken_file))

    def test_read_with_empty_lines(self, tmp_path):
        """Blank lines between records are skipped rather than parsed."""
        # Given: two records separated by an empty line
        reader = CategoryReader()
        records = [
            '{"category":"门票","type":"自然类","sub_type":"自然风光"}',
            "",  # the blank line the reader is expected to ignore
            '{"category":"住宿","type":"商务酒店","sub_type":""}',
        ]
        sparse_file = tmp_path / "with_blank_lines.jsonl"
        sparse_file.write_text("\n".join(records) + "\n", encoding="utf-8")

        # When: the file is read
        categories = reader.read(str(sparse_file))

        # Then: only the two real records come back
        assert len(categories) == 2
class TestBatchFileWriter:
    """Tests for BatchFileWriter.write, which emits Batch requests as JSONL."""

    @staticmethod
    def _make_request(custom_id, body):
        """Build one Batch request dict with the fields shared by all tests."""
        return {
            "custom_id": custom_id,
            "method": "POST",
            "url": "/v1/chat/completions",
            "body": body,
        }

    def test_write_batch_requests(self, tmp_path):
        """Two requests are written as two JSON lines carrying their keys."""
        # Given: a list of Batch requests and a target that does not exist
        # yet, so the existence check below proves the writer created it.
        writer = BatchFileWriter()
        requests = [
            self._make_request("P001", {"model": "qwen-flash", "messages": []}),
            self._make_request("P002", {"model": "qwen-flash", "messages": []}),
        ]
        target_path = str(tmp_path / "batch_requests.jsonl")

        # When: the requests are written
        result_path = writer.write(requests, target_path)

        # Then: the file exists and holds one record per request
        assert os.path.exists(result_path)
        with open(result_path, "r", encoding="utf-8") as f:
            lines = f.readlines()
        assert len(lines) == 2

        # Every line must be a standalone, valid JSON document
        for line in lines:
            data = json.loads(line)
            assert "custom_id" in data
            assert "method" in data

    def test_write_to_nested_directory(self, tmp_path):
        """Missing parent directories are created on the way to the file."""
        # Given: a target whose parent directories do not exist yet
        writer = BatchFileWriter()
        requests = [self._make_request("P001", {})]
        nested_path = os.path.join(str(tmp_path), "sub1", "sub2", "test.jsonl")

        # When: the requests are written
        result_path = writer.write(requests, nested_path)

        # Then: the directories were created and the file exists
        assert os.path.exists(result_path)
class TestResultParser:
    """Tests for ResultParser.parse on raw Batch response text."""

    # Each constant is a single JSONL record. In the successful records the
    # classification is carried in ``content`` as an escaped JSON string.
    _OK_P001 = '{"id":"req_001","custom_id":"P001","response":{"status_code":200,"body":{"choices":[{"message":{"content":"{\\"category\\":\\"门票\\",\\"type\\":\\"自然类\\",\\"sub_type\\":\\"自然风光\\"}"}}]}},"error":null}'
    _OK_P002 = '{"id":"req_002","custom_id":"P002","response":{"status_code":200,"body":{"choices":[{"message":{"content":"{\\"category\\":\\"住宿\\",\\"type\\":\\"商务酒店\\",\\"sub_type\\":\\"\\"}"}}]}},"error":null}'
    _FAILED_P001 = '{"id":"req_001","custom_id":"P001","response":{},"error":{"code":"InvalidRequest","message":"错误"}}'
    _PLAIN_TEXT_P001 = '{"id":"req_001","custom_id":"P001","response":{"status_code":200,"body":{"choices":[{"message":{"content":"这是普通文本不是JSON"}}]}},"error":null}'

    @pytest.fixture
    def valid_batch_response(self):
        """Provide a two-record response body in which both requests succeeded."""
        return "\n".join((self._OK_P001, self._OK_P002))

    def test_parse_valid_response(self, valid_batch_response):
        """Both successful records are parsed, in order, into results."""
        # Given: a parser and a fully valid response body
        parsed = ResultParser().parse(valid_batch_response)

        # Then: one result per record, with the first matching P001's answer
        assert len(parsed) == 2
        head = parsed[0]
        assert head.product_id == "P001"
        assert head.category == "门票"
        assert head.type == "自然类"
        assert head.sub_type == "自然风光"

    def test_parse_response_with_errors(self):
        """A record carrying an error is dropped; the good one survives."""
        # Given: a failed record followed by a successful one
        response_with_error = "\n".join((self._FAILED_P001, self._OK_P002))

        # When: the body is parsed
        parsed = ResultParser().parse(response_with_error)

        # Then: only the successful record is returned
        assert len(parsed) == 1
        assert parsed[0].product_id == "P002"

    def test_parse_empty_content(self):
        """An empty body parses to an empty result collection."""
        # When: an empty string is parsed
        parsed = ResultParser().parse("")

        # Then: nothing is returned
        assert len(parsed) == 0

    def test_parse_response_with_non_json_content(self):
        """A record whose ``content`` is plain text is skipped."""
        # Given: a single record whose content is not a JSON document
        response = self._PLAIN_TEXT_P001

        # When: the body is parsed
        parsed = ResultParser().parse(response)

        # Then: the record is skipped
        assert len(parsed) == 0
class TestResultWriter:
    """Tests for ResultWriter.write, which emits classification results as JSONL."""

    def test_write_results(self, tmp_path):
        """Two results are written as two JSON lines carrying their keys."""
        # Given: a list of classification results and a target that does not
        # exist yet, so the existence check proves the writer created it.
        writer = ResultWriter()
        results = [
            ClassificationResult(
                product_id="P001", category="门票", type="自然类", sub_type="自然风光"
            ),
            ClassificationResult(
                product_id="P002", category="住宿", type="商务酒店", sub_type=""
            ),
        ]
        target_path = str(tmp_path / "results.jsonl")

        # When: the results are written
        writer.write(results, target_path)

        # Then: the file exists and holds one record per result
        assert os.path.exists(target_path)
        with open(target_path, "r", encoding="utf-8") as f:
            lines = f.readlines()
        assert len(lines) == 2

        # Every line must be a standalone, valid JSON document
        for line in lines:
            data = json.loads(line)
            assert "product_id" in data
            assert "category" in data

    def test_write_empty_results(self, tmp_path):
        """Writing an empty list still creates the file, with no content."""
        # Given: an empty result list and a target that does not exist yet.
        # A pre-created empty file would make the assertions below pass even
        # if write() did nothing at all.
        writer = ResultWriter()
        results = []
        target_path = str(tmp_path / "empty_results.jsonl")

        # When: the empty list is written
        writer.write(results, target_path)

        # Then: the file was created and is empty
        assert os.path.exists(target_path)
        with open(target_path, "r", encoding="utf-8") as f:
            content = f.read()
        assert content == ""
class TestFileHandler:
    """Tests that FileHandler forwards each call to the matching collaborator."""

    @pytest.fixture
    def mock_handlers(self):
        """Build the five collaborators as spec'd mocks, keyed by constructor argument."""
        collaborator_specs = {
            "product_reader": ProductReader,
            "category_reader": CategoryReader,
            "batch_file_writer": BatchFileWriter,
            "result_parser": ResultParser,
            "result_writer": ResultWriter,
        }
        return {
            name: MagicMock(spec=spec_cls)
            for name, spec_cls in collaborator_specs.items()
        }

    def test_read_products_delegates_correctly(self, mock_handlers):
        """read_products passes the path to product_reader.read and returns its value."""
        # Given: the product reader mock is primed with a known product list
        expected = [
            ProductInput(product_id="P001", product_name="测试", scenic_spot="测试景区")
        ]
        product_reader = mock_handlers["product_reader"]
        product_reader.read.return_value = expected
        facade = FileHandler(**mock_handlers)

        # When: read_products is called
        actual = facade.read_products("test.xlsx")

        # Then: the call was forwarded once and the result passed through
        product_reader.read.assert_called_once_with("test.xlsx")
        assert actual == expected

    def test_read_categories_delegates_correctly(self, mock_handlers):
        """read_categories passes the path to category_reader.read and returns its value."""
        # Given: the category reader mock is primed with a known category list
        expected = [ProductCategory(category="门票", type="自然类", sub_type="")]
        category_reader = mock_handlers["category_reader"]
        category_reader.read.return_value = expected
        facade = FileHandler(**mock_handlers)

        # When: read_categories is called
        actual = facade.read_categories("test.jsonl")

        # Then: the call was forwarded once and the result passed through
        category_reader.read.assert_called_once_with("test.jsonl")
        assert actual == expected

    def test_write_batch_requests_delegates_correctly(self, mock_handlers):
        """write_batch_requests forwards both arguments to batch_file_writer.write."""
        # Given: the batch writer mock reports the path it was asked to write
        payload = [{"custom_id": "P001"}]
        batch_writer = mock_handlers["batch_file_writer"]
        batch_writer.write.return_value = "output.jsonl"
        facade = FileHandler(**mock_handlers)

        # When: write_batch_requests is called
        actual = facade.write_batch_requests(payload, "output.jsonl")

        # Then: the call was forwarded once and the returned path passed through
        batch_writer.write.assert_called_once_with(payload, "output.jsonl")
        assert actual == "output.jsonl"

    def test_parse_batch_responses_delegates_correctly(self, mock_handlers):
        """parse_batch_responses passes the content to result_parser.parse and returns its value."""
        # Given: the result parser mock is primed with a known result list
        expected = [
            ClassificationResult(
                product_id="P001", category="门票", type="自然类", sub_type=""
            )
        ]
        result_parser = mock_handlers["result_parser"]
        result_parser.parse.return_value = expected
        facade = FileHandler(**mock_handlers)

        # When: parse_batch_responses is called
        actual = facade.parse_batch_responses("response content")

        # Then: the call was forwarded once and the result passed through
        result_parser.parse.assert_called_once_with("response content")
        assert actual == expected

    def test_write_results_delegates_correctly(self, mock_handlers):
        """write_results forwards both arguments to result_writer.write."""
        # Given: a FileHandler wired with mocks and one result to persist
        outgoing = [
            ClassificationResult(
                product_id="P001", category="门票", type="自然类", sub_type=""
            )
        ]
        facade = FileHandler(**mock_handlers)

        # When: write_results is called
        facade.write_results(outgoing, "output.jsonl")

        # Then: the call was forwarded exactly once with the same arguments
        mock_handlers["result_writer"].write.assert_called_once_with(
            outgoing, "output.jsonl"
        )