# File: CategorizeLabel/tests/test_prompt_builder.py (Python, 227 lines, 8.7 KiB)
# Repository-viewer timestamp: 2025-10-15 17:19:26 +08:00
"""测试Prompt构建模块
测试PromptBuilder类的所有功能
"""
import pytest
import json
from src.core.prompt_builder import PromptBuilder
from src.models import ProductInput, ProductCategory
class TestPromptBuilder:
    """Unit tests for ``PromptBuilder``.

    Covers the three builder entry points exercised below:
    ``build_system_prompt``, ``build_user_prompt`` and
    ``build_classification_request``.
    """

    @pytest.fixture
    def sample_categories(self):
        """Return a sample category list spanning three top-level categories."""
        return [
            ProductCategory(category="门票", type="自然类", sub_type="自然风光"),
            ProductCategory(category="门票", type="文化类", sub_type="历史遗迹"),
            ProductCategory(category="住宿", type="商务酒店", sub_type=""),
            ProductCategory(category="住宿", type="度假酒店", sub_type=""),
            ProductCategory(category="餐饮", type="地方特色", sub_type="徽菜"),
        ]

    @pytest.fixture
    def sample_product(self):
        """Return a sample product with id ``P001``."""
        return ProductInput(
            product_id="P001", product_name="黄山风景区门票", scenic_spot="黄山"
        )

    @pytest.fixture
    def prompt_builder(self):
        """Return a default-constructed ``PromptBuilder``."""
        return PromptBuilder()

    def test_build_system_prompt(self, prompt_builder, sample_categories):
        """The system prompt mentions the categories, types and JSON."""
        # When: the system prompt is built
        system_prompt = prompt_builder.build_system_prompt(sample_categories)

        # Then: it contains the category information
        assert "门票" in system_prompt
        assert "住宿" in system_prompt
        assert "餐饮" in system_prompt
        assert "自然类" in system_prompt
        assert "商务酒店" in system_prompt
        assert "JSON" in system_prompt

    def test_build_system_prompt_with_empty_categories(self, prompt_builder):
        """An empty category list still yields the basic instructions."""
        # When: an empty category list is used
        system_prompt = prompt_builder.build_system_prompt([])

        # Then: the basic instructions are still present.
        # "助手" alone is checked: the longer "分类助手" contains it, so an
        # additional `or` on the longer phrase could never change the result.
        assert "助手" in system_prompt
        assert "JSON" in system_prompt

    def test_build_system_prompt_includes_sub_types(
        self, prompt_builder, sample_categories
    ):
        """The system prompt includes the non-empty sub-type names."""
        # When: the system prompt is built
        system_prompt = prompt_builder.build_system_prompt(sample_categories)

        # Then: every non-empty sub-type appears
        assert "自然风光" in system_prompt
        assert "历史遗迹" in system_prompt
        assert "徽菜" in system_prompt

    def test_build_user_prompt(self, prompt_builder, sample_product):
        """The user prompt contains the product's fields."""
        # When: the user prompt is built
        user_prompt = prompt_builder.build_user_prompt(sample_product)

        # Then: it contains the product information
        assert "P001" in user_prompt
        assert "黄山风景区门票" in user_prompt
        # NOTE(review): this check is implied by the previous one because the
        # product name itself contains "黄山". The scenic spot is verified
        # independently in test_build_user_prompt_with_different_product.
        assert "黄山" in user_prompt

    def test_build_user_prompt_with_different_product(self, prompt_builder):
        """The user prompt reflects whichever product is passed in."""
        # Given: another product whose scenic spot is not part of its name
        product = ProductInput(
            product_id="P002", product_name="温泉酒店套房", scenic_spot="华清池"
        )

        # When: the user prompt is built
        user_prompt = prompt_builder.build_user_prompt(product)

        # Then: it contains the new product's information
        assert "P002" in user_prompt
        assert "温泉酒店套房" in user_prompt
        assert "华清池" in user_prompt

    def test_build_classification_request(
        self, prompt_builder, sample_product, sample_categories
    ):
        """The complete classification request has the Batch API shape."""
        # Given: a system prompt
        system_prompt = prompt_builder.build_system_prompt(sample_categories)

        # When: the classification request is built
        request = prompt_builder.build_classification_request(
            product=sample_product, system_prompt=system_prompt, model="qwen-flash"
        )

        # Then: the request follows the Batch API format
        assert "custom_id" in request
        assert request["custom_id"] == "P001"
        assert request["method"] == "POST"
        assert request["url"] == "/v1/chat/completions"
        assert "body" in request

        # Check the body structure
        body = request["body"]
        assert body["model"] == "qwen-flash"
        assert "messages" in body
        assert len(body["messages"]) == 2

        # Check the messages: system first, then user
        assert body["messages"][0]["role"] == "system"
        assert body["messages"][1]["role"] == "user"
        assert "黄山风景区门票" in body["messages"][1]["content"]

    def test_build_classification_request_with_different_model(
        self, prompt_builder, sample_product
    ):
        """The ``model`` argument is passed through to the request body."""
        # Given: a fixed system prompt and a different model name
        system_prompt = "测试系统提示词"

        # When: the classification request is built
        request = prompt_builder.build_classification_request(
            product=sample_product, system_prompt=system_prompt, model="qwen-plus"
        )

        # Then: the model name is set accordingly
        assert request["body"]["model"] == "qwen-plus"

    def test_build_classification_request_is_valid_json(
        self, prompt_builder, sample_product, sample_categories
    ):
        """The built request survives a JSON serialization round trip."""
        # Given: a system prompt
        system_prompt = prompt_builder.build_system_prompt(sample_categories)

        # When: the request is built (``model`` is not passed here)
        request = prompt_builder.build_classification_request(
            product=sample_product, system_prompt=system_prompt
        )

        # Then: it serializes to JSON; json.dumps raising is the real failure
        # mode, the isinstance check merely pins the result type.
        json_str = json.dumps(request, ensure_ascii=False)
        assert isinstance(json_str, str)

        # And it can be parsed back
        parsed = json.loads(json_str)
        assert parsed["custom_id"] == "P001"

    def test_build_system_prompt_categories_grouped_by_category(self, prompt_builder):
        """Types of an earlier category appear before those of a later one."""
        # Given: several entries, two of which share a category
        categories = [
            ProductCategory(category="门票", type="类型A", sub_type=""),
            ProductCategory(category="门票", type="类型B", sub_type=""),
            ProductCategory(category="住宿", type="类型C", sub_type=""),
        ]

        # When: the system prompt is built
        system_prompt = prompt_builder.build_system_prompt(categories)

        # Then: every category and type is present
        assert "门票" in system_prompt
        assert "住宿" in system_prompt
        assert "类型A" in system_prompt
        assert "类型B" in system_prompt
        assert "类型C" in system_prompt

        # The presence checks above guarantee find() does not return -1 here.
        type_a_pos = system_prompt.find("类型A")
        type_b_pos = system_prompt.find("类型B")
        type_c_pos = system_prompt.find("类型C")

        # 类型A and 类型B (门票) must both precede 类型C (住宿).
        # NOTE(review): the input is already ordered A, B, C, so these
        # assertions would also hold for a builder that merely preserves
        # input order; they do not by themselves prove grouping.
        assert type_a_pos < type_c_pos
        assert type_b_pos < type_c_pos

    def test_build_user_prompt_format(self, prompt_builder, sample_product):
        """The user prompt carries a label for each required field."""
        # When: the user prompt is built
        user_prompt = prompt_builder.build_user_prompt(sample_product)

        # Then: labels for all required fields are present.
        # Only the short forms are checked: "产品编号" contains "编号" and
        # "产品名称" contains "名称", so testing the long forms as an `or`
        # alternative could never change the result.
        assert "编号" in user_prompt
        assert "名称" in user_prompt
        assert "景区" in user_prompt

    def test_multiple_products_same_system_prompt(
        self, prompt_builder, sample_categories
    ):
        """Several products can share one system prompt."""
        # Given: one system prompt
        system_prompt = prompt_builder.build_system_prompt(sample_categories)

        # Given: several products
        products = [
            ProductInput(product_id="P001", product_name="产品1", scenic_spot="景区1"),
            ProductInput(product_id="P002", product_name="产品2", scenic_spot="景区2"),
            ProductInput(product_id="P003", product_name="产品3", scenic_spot="景区3"),
        ]

        # When: a request is built for each product
        requests = [
            prompt_builder.build_classification_request(product, system_prompt)
            for product in products
        ]

        # Then: every request carries the same system prompt
        system_prompts = [req["body"]["messages"][0]["content"] for req in requests]
        assert len(set(system_prompts)) == 1

        # Then: every request has its own custom_id
        custom_ids = [req["custom_id"] for req in requests]
        assert len(set(custom_ids)) == len(products)