457 lines
16 KiB
Python
457 lines
16 KiB
Python
"""测试文件处理模块
|
||
|
||
测试所有文件处理相关的类和函数
|
||
"""
|
||
|
||
import pytest
|
||
import json
|
||
import os
|
||
import tempfile
|
||
from unittest.mock import MagicMock
|
||
|
||
from src.core.file_handler import (
|
||
ProductReader,
|
||
CategoryReader,
|
||
BatchFileWriter,
|
||
ResultParser,
|
||
ResultWriter,
|
||
FileHandler,
|
||
)
|
||
from src.models import ProductInput, ProductCategory, ClassificationResult
|
||
|
||
|
||
class TestProductReader:
    """Tests for ProductReader, which loads product rows from an Excel file."""

    @pytest.fixture
    def valid_excel_path(self):
        """Return the path of the sample product workbook.

        NOTE(review): the path is relative to the current working directory,
        so the suite presumably has to be run from the repository root.
        """
        return "tests/fixtures/sample_products.xlsx"

    def test_read_valid_excel(self, valid_excel_path):
        """Reading a valid workbook yields the expected products."""
        # Given: a valid Excel file
        reader = ProductReader()

        # When: the file is read
        products = reader.read(valid_excel_path)

        # Then: the fixture's rows come back as products
        assert len(products) == 3
        assert products[0].product_id == "P001"
        assert products[0].product_name == "黄山风景区门票"
        assert products[0].scenic_spot == "黄山"

    def test_read_nonexistent_file(self):
        """Reading a path that does not exist raises FileNotFoundError."""
        # Given: a path to a file that does not exist
        reader = ProductReader()

        # When & Then: FileNotFoundError is raised with a descriptive message
        with pytest.raises(FileNotFoundError) as exc_info:
            reader.read("nonexistent.xlsx")
        assert "不存在" in str(exc_info.value)

    def test_read_excel_missing_columns(self, tmp_path):
        """A workbook missing a required column is rejected with ValueError."""
        # pandas is imported locally; only this test needs it to build a workbook.
        import pandas as pd

        # Given: an Excel file without the "产品名称" (product name) column.
        # pytest's tmp_path is created and removed automatically, so no manual
        # temp-file reservation or try/finally cleanup is needed.
        reader = ProductReader()
        excel_path = str(tmp_path / "missing_columns.xlsx")
        df = pd.DataFrame({"产品编号": ["P001"], "景区名称": ["黄山"]})
        df.to_excel(excel_path, index=False)

        # When & Then: ValueError reports the missing columns
        with pytest.raises(ValueError) as exc_info:
            reader.read(excel_path)
        assert "缺少必要列" in str(exc_info.value)
|
||
|
||
|
||
class TestCategoryReader:
    """Tests for CategoryReader, which loads categories from a JSONL file."""

    @pytest.fixture
    def valid_jsonl_path(self):
        """Return the path of the sample category JSONL file.

        NOTE(review): the path is relative to the current working directory,
        so the suite presumably has to be run from the repository root.
        """
        return "tests/fixtures/sample_categories.jsonl"

    def test_read_valid_jsonl(self, valid_jsonl_path):
        """Reading a valid JSONL file yields the expected categories."""
        # Given: a valid JSONL file
        reader = CategoryReader()

        # When: the file is read
        categories = reader.read(valid_jsonl_path)

        # Then: the fixture's lines come back as categories
        assert len(categories) == 3
        assert categories[0].category == "门票"
        assert categories[0].type == "自然类"
        assert categories[0].sub_type == "自然风光"

    def test_read_nonexistent_file(self):
        """Reading a path that does not exist raises FileNotFoundError."""
        # Given: a path to a file that does not exist
        reader = CategoryReader()

        # When & Then: FileNotFoundError is raised with a descriptive message
        with pytest.raises(FileNotFoundError) as exc_info:
            reader.read("nonexistent.jsonl")
        assert "不存在" in str(exc_info.value)

    def test_read_invalid_json(self, tmp_path):
        """A malformed JSON line is reported as ValueError."""
        # Given: a file whose only line lacks its closing brace.
        # tmp_path is cleaned up by pytest, so no try/finally is needed.
        reader = CategoryReader()
        jsonl_path = tmp_path / "invalid.jsonl"
        jsonl_path.write_text('{"category":"门票","type":"自然类"\n', encoding="utf-8")

        # When & Then: ValueError mentions the JSON format error
        with pytest.raises(ValueError) as exc_info:
            reader.read(str(jsonl_path))
        assert "JSON格式错误" in str(exc_info.value)

    def test_read_with_empty_lines(self, tmp_path):
        """Blank lines in a JSONL file are skipped."""
        # Given: two valid records separated by a blank line
        reader = CategoryReader()
        jsonl_path = tmp_path / "with_empty_lines.jsonl"
        jsonl_path.write_text(
            '{"category":"门票","type":"自然类","sub_type":"自然风光"}\n'
            "\n"  # blank line the reader must ignore
            '{"category":"住宿","type":"商务酒店","sub_type":""}\n',
            encoding="utf-8",
        )

        # When: the file is read
        categories = reader.read(str(jsonl_path))

        # Then: only the two real records are returned
        assert len(categories) == 2
|
||
|
||
|
||
class TestBatchFileWriter:
    """Tests for BatchFileWriter, which writes Batch API requests as JSONL."""

    def test_write_batch_requests(self, tmp_path):
        """Each request is written as one JSON object per line."""
        # Given: a list of Batch requests
        writer = BatchFileWriter()
        requests = [
            {
                "custom_id": "P001",
                "method": "POST",
                "url": "/v1/chat/completions",
                "body": {"model": "qwen-flash", "messages": []},
            },
            {
                "custom_id": "P002",
                "method": "POST",
                "url": "/v1/chat/completions",
                "body": {"model": "qwen-flash", "messages": []},
            },
        ]
        # The target path does not exist beforehand, so the existence check
        # below really proves the writer created the file. tmp_path also
        # removes whatever path the writer returns.
        output_path = str(tmp_path / "batch_requests.jsonl")

        # When: the requests are written
        result_path = writer.write(requests, output_path)

        # Then: the file exists and holds one valid JSON object per request
        assert os.path.exists(result_path)
        with open(result_path, "r", encoding="utf-8") as f:
            lines = f.readlines()
        assert len(lines) == 2
        records = [json.loads(line) for line in lines]
        for data in records:
            assert "custom_id" in data
            assert "method" in data
        # Order-agnostic check that every request made it into the file.
        assert sorted(data["custom_id"] for data in records) == ["P001", "P002"]

    def test_write_to_nested_directory(self, tmp_path):
        """Missing parent directories are created when writing."""
        # Given: a target path two directory levels below an empty dir
        writer = BatchFileWriter()
        requests = [
            {
                "custom_id": "P001",
                "method": "POST",
                "url": "/v1/chat/completions",
                "body": {},
            }
        ]
        nested_path = str(tmp_path / "sub1" / "sub2" / "test.jsonl")

        # When: the requests are written
        result_path = writer.write(requests, nested_path)

        # Then: the directories were created and the file exists
        assert os.path.exists(result_path)
|
||
|
||
|
||
class TestResultParser:
    """Tests for ResultParser, which turns Batch API output into results."""

    @pytest.fixture
    def valid_batch_response(self):
        """Two successful Batch response records, one JSON object per line."""
        return """{"id":"req_001","custom_id":"P001","response":{"status_code":200,"body":{"choices":[{"message":{"content":"{\\"category\\":\\"门票\\",\\"type\\":\\"自然类\\",\\"sub_type\\":\\"自然风光\\"}"}}]}},"error":null}
{"id":"req_002","custom_id":"P002","response":{"status_code":200,"body":{"choices":[{"message":{"content":"{\\"category\\":\\"住宿\\",\\"type\\":\\"商务酒店\\",\\"sub_type\\":\\"\\"}"}}]}},"error":null}"""

    def test_parse_valid_response(self, valid_batch_response):
        """Every successful record becomes a ClassificationResult."""
        result_parser = ResultParser()

        parsed = result_parser.parse(valid_batch_response)

        assert len(parsed) == 2
        first = parsed[0]
        assert first.product_id == "P001"
        assert first.category == "门票"
        assert first.type == "自然类"
        assert first.sub_type == "自然风光"

    def test_parse_response_with_errors(self):
        """Records carrying an error are dropped; successful ones are kept."""
        result_parser = ResultParser()
        mixed_payload = """{"id":"req_001","custom_id":"P001","response":{},"error":{"code":"InvalidRequest","message":"错误"}}
{"id":"req_002","custom_id":"P002","response":{"status_code":200,"body":{"choices":[{"message":{"content":"{\\"category\\":\\"住宿\\",\\"type\\":\\"商务酒店\\",\\"sub_type\\":\\"\\"}"}}]}},"error":null}"""

        parsed = result_parser.parse(mixed_payload)

        assert len(parsed) == 1
        assert parsed[0].product_id == "P002"

    def test_parse_empty_content(self):
        """An empty payload parses to an empty collection."""
        parsed = ResultParser().parse("")

        assert len(parsed) == 0

    def test_parse_response_with_non_json_content(self):
        """A record whose message content is plain text is skipped."""
        result_parser = ResultParser()
        plain_text_payload = """{"id":"req_001","custom_id":"P001","response":{"status_code":200,"body":{"choices":[{"message":{"content":"这是普通文本,不是JSON"}}]}},"error":null}"""

        parsed = result_parser.parse(plain_text_payload)

        assert len(parsed) == 0
|
||
|
||
|
||
class TestResultWriter:
    """Tests for ResultWriter, which writes classification results as JSONL."""

    def test_write_results(self, tmp_path):
        """Each result is written as one JSON object per line."""
        # Given: a list of classification results
        writer = ResultWriter()
        results = [
            ClassificationResult(
                product_id="P001", category="门票", type="自然类", sub_type="自然风光"
            ),
            ClassificationResult(
                product_id="P002", category="住宿", type="商务酒店", sub_type=""
            ),
        ]
        # A fresh, not-yet-existing path, so the existence check is meaningful.
        output_path = str(tmp_path / "results.jsonl")

        # When: the results are written
        writer.write(results, output_path)

        # Then: the file exists and every line is a JSON object
        assert os.path.exists(output_path)
        with open(output_path, "r", encoding="utf-8") as f:
            lines = f.readlines()
        assert len(lines) == 2
        for line in lines:
            data = json.loads(line)
            assert "product_id" in data
            assert "category" in data

    def test_write_empty_results(self, tmp_path):
        """Writing an empty list still creates an empty file."""
        # Given: no results
        writer = ResultWriter()
        results = []
        # The target must NOT pre-exist. A pre-created empty temp file would
        # satisfy both checks below even if the writer did nothing at all.
        output_path = str(tmp_path / "empty_results.jsonl")

        # When: the empty list is written
        writer.write(results, output_path)

        # Then: the writer created the file and left it empty
        assert os.path.exists(output_path)
        with open(output_path, "r", encoding="utf-8") as f:
            content = f.read()
        assert content == ""
|
||
|
||
|
||
class TestFileHandler:
    """Tests that the FileHandler facade forwards each call to its collaborator."""

    @pytest.fixture
    def mock_handlers(self):
        """Spec'd MagicMocks for every collaborator, keyed by FileHandler kwarg."""
        collaborator_types = {
            "product_reader": ProductReader,
            "category_reader": CategoryReader,
            "batch_file_writer": BatchFileWriter,
            "result_parser": ResultParser,
            "result_writer": ResultWriter,
        }
        return {
            kwarg: MagicMock(spec=cls) for kwarg, cls in collaborator_types.items()
        }

    def test_read_products_delegates_correctly(self, mock_handlers):
        """read_products forwards the path and returns the reader's result."""
        file_handler = FileHandler(**mock_handlers)
        product_reader = mock_handlers["product_reader"]
        expected = [
            ProductInput(product_id="P001", product_name="测试", scenic_spot="测试景区")
        ]
        product_reader.read.return_value = expected

        actual = file_handler.read_products("test.xlsx")

        product_reader.read.assert_called_once_with("test.xlsx")
        assert actual == expected

    def test_read_categories_delegates_correctly(self, mock_handlers):
        """read_categories forwards the path and returns the reader's result."""
        file_handler = FileHandler(**mock_handlers)
        category_reader = mock_handlers["category_reader"]
        expected = [ProductCategory(category="门票", type="自然类", sub_type="")]
        category_reader.read.return_value = expected

        actual = file_handler.read_categories("test.jsonl")

        category_reader.read.assert_called_once_with("test.jsonl")
        assert actual == expected

    def test_write_batch_requests_delegates_correctly(self, mock_handlers):
        """write_batch_requests forwards both arguments and returns the writer's path."""
        file_handler = FileHandler(**mock_handlers)
        batch_writer = mock_handlers["batch_file_writer"]
        payload = [{"custom_id": "P001"}]
        batch_writer.write.return_value = "output.jsonl"

        actual = file_handler.write_batch_requests(payload, "output.jsonl")

        batch_writer.write.assert_called_once_with(payload, "output.jsonl")
        assert actual == "output.jsonl"

    def test_parse_batch_responses_delegates_correctly(self, mock_handlers):
        """parse_batch_responses forwards the content and returns the parser's result."""
        file_handler = FileHandler(**mock_handlers)
        result_parser = mock_handlers["result_parser"]
        expected = [
            ClassificationResult(
                product_id="P001", category="门票", type="自然类", sub_type=""
            )
        ]
        result_parser.parse.return_value = expected

        actual = file_handler.parse_batch_responses("response content")

        result_parser.parse.assert_called_once_with("response content")
        assert actual == expected

    def test_write_results_delegates_correctly(self, mock_handlers):
        """write_results forwards the results and path to result_writer."""
        file_handler = FileHandler(**mock_handlers)
        result_writer = mock_handlers["result_writer"]
        payload = [
            ClassificationResult(
                product_id="P001", category="门票", type="自然类", sub_type=""
            )
        ]

        file_handler.write_results(payload, "output.jsonl")

        result_writer.write.assert_called_once_with(payload, "output.jsonl")
|