CategorizeLabel/tests/test_prompt_builder.py
2025-10-15 17:19:26 +08:00

227 lines
8.7 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""测试Prompt构建模块
测试PromptBuilder类的所有功能
"""
import pytest
import json
from src.core.prompt_builder import PromptBuilder
from src.models import ProductInput, ProductCategory
class TestPromptBuilder:
    """Tests for the public behaviour of ``PromptBuilder``."""

    @pytest.fixture
    def sample_categories(self):
        """Return a small category list spanning three top-level categories.

        The two 住宿 (lodging) entries deliberately use an empty ``sub_type``.
        """
        return [
            ProductCategory(category="门票", type="自然类", sub_type="自然风光"),
            ProductCategory(category="门票", type="文化类", sub_type="历史遗迹"),
            ProductCategory(category="住宿", type="商务酒店", sub_type=""),
            ProductCategory(category="住宿", type="度假酒店", sub_type=""),
            ProductCategory(category="餐饮", type="地方特色", sub_type="徽菜"),
        ]

    @pytest.fixture
    def sample_product(self):
        """Return a sample ticket product with ID ``P001``."""
        return ProductInput(
            product_id="P001", product_name="黄山风景区门票", scenic_spot="黄山"
        )

    @pytest.fixture
    def prompt_builder(self):
        """Return a default-constructed ``PromptBuilder``."""
        return PromptBuilder()

    def test_build_system_prompt(self, prompt_builder, sample_categories):
        """The system prompt contains the category information and mentions JSON."""
        # When: build the system prompt
        system_prompt = prompt_builder.build_system_prompt(sample_categories)

        # Then: every top-level category and sample types are present
        assert "门票" in system_prompt
        assert "住宿" in system_prompt
        assert "餐饮" in system_prompt
        assert "自然类" in system_prompt
        assert "商务酒店" in system_prompt
        assert "JSON" in system_prompt

    def test_build_system_prompt_with_empty_categories(self, prompt_builder):
        """An empty category list still yields the basic instructions."""
        # When: build with no categories
        system_prompt = prompt_builder.build_system_prompt([])

        # Then: the basic instructions are still present.
        # "助手" (assistant) is a substring of "分类助手", so this single check
        # accepts either wording.
        assert "助手" in system_prompt
        assert "JSON" in system_prompt

    def test_build_system_prompt_includes_sub_types(
        self, prompt_builder, sample_categories
    ):
        """Non-empty sub-types are listed in the system prompt."""
        # When: build the system prompt
        system_prompt = prompt_builder.build_system_prompt(sample_categories)

        # Then: every non-empty sub_type appears
        assert "自然风光" in system_prompt
        assert "历史遗迹" in system_prompt
        assert "徽菜" in system_prompt

    def test_build_user_prompt(self, prompt_builder, sample_product):
        """The user prompt contains the product ID, name and scenic spot."""
        # When: build the user prompt
        user_prompt = prompt_builder.build_user_prompt(sample_product)

        # Then: the product fields are present
        assert "P001" in user_prompt
        assert "黄山风景区门票" in user_prompt
        # The scenic spot "黄山" is a substring of the product name, so a plain
        # ``in`` check would pass even if the scenic spot were omitted. Require
        # a second occurrence so the scenic spot itself is actually verified.
        assert user_prompt.count("黄山") >= 2

    def test_build_user_prompt_with_different_product(self, prompt_builder):
        """A different product's fields show up in its user prompt."""
        # Given: another product whose fields share no substrings
        product = ProductInput(
            product_id="P002", product_name="温泉酒店套房", scenic_spot="华清池"
        )

        # When: build the user prompt
        user_prompt = prompt_builder.build_user_prompt(product)

        # Then: the new product's fields are present
        assert "P002" in user_prompt
        assert "温泉酒店套房" in user_prompt
        assert "华清池" in user_prompt

    def test_build_classification_request(
        self, prompt_builder, sample_product, sample_categories
    ):
        """A full classification request has the Batch API line shape."""
        # Given: a system prompt
        system_prompt = prompt_builder.build_system_prompt(sample_categories)

        # When: build the classification request
        request = prompt_builder.build_classification_request(
            product=sample_product, system_prompt=system_prompt, model="qwen-flash"
        )

        # Then: the top-level Batch API fields are set
        assert "custom_id" in request
        assert request["custom_id"] == "P001"
        assert request["method"] == "POST"
        assert request["url"] == "/v1/chat/completions"
        assert "body" in request

        # The body carries the model and exactly two messages
        body = request["body"]
        assert body["model"] == "qwen-flash"
        assert "messages" in body
        assert len(body["messages"]) == 2

        # Messages are ordered system first, then user
        assert body["messages"][0]["role"] == "system"
        assert body["messages"][1]["role"] == "user"
        assert "黄山风景区门票" in body["messages"][1]["content"]

    def test_build_classification_request_with_different_model(
        self, prompt_builder, sample_product
    ):
        """The ``model`` argument is copied into the request body."""
        # Given: an arbitrary system prompt
        system_prompt = "测试系统提示词"

        # When: build the request with a non-default model
        request = prompt_builder.build_classification_request(
            product=sample_product, system_prompt=system_prompt, model="qwen-plus"
        )

        # Then: the model name is set as given
        assert request["body"]["model"] == "qwen-plus"

    def test_build_classification_request_is_valid_json(
        self, prompt_builder, sample_product, sample_categories
    ):
        """The request serializes to JSON and round-trips unchanged."""
        # Given: a system prompt
        system_prompt = prompt_builder.build_system_prompt(sample_categories)

        # When: build the request (default model) and serialize it
        request = prompt_builder.build_classification_request(
            product=sample_product, system_prompt=system_prompt
        )
        json_str = json.dumps(request, ensure_ascii=False)

        # Then: with ensure_ascii=False the Chinese product name is kept as
        # literal text rather than \uXXXX escapes
        assert "黄山风景区门票" in json_str

        # ...and deserializing yields the same request. This assumes the
        # builder uses only JSON-native types; a tuple would come back as a list.
        parsed = json.loads(json_str)
        assert parsed == request
        assert parsed["custom_id"] == "P001"

    def test_build_system_prompt_categories_grouped_by_category(self, prompt_builder):
        """Types are listed in category order in the system prompt."""
        # Given: two types under 门票 followed by one under 住宿
        categories = [
            ProductCategory(category="门票", type="类型A", sub_type=""),
            ProductCategory(category="门票", type="类型B", sub_type=""),
            ProductCategory(category="住宿", type="类型C", sub_type=""),
        ]

        # When: build the system prompt
        system_prompt = prompt_builder.build_system_prompt(categories)

        # Then: all categories and types are present
        assert "门票" in system_prompt
        assert "住宿" in system_prompt
        assert "类型A" in system_prompt
        assert "类型B" in system_prompt
        assert "类型C" in system_prompt

        # 门票 comes before 住宿, so its types (A and B) must precede type C.
        # The presence asserts above guarantee find() does not return -1 here.
        type_a_pos = system_prompt.find("类型A")
        type_b_pos = system_prompt.find("类型B")
        type_c_pos = system_prompt.find("类型C")
        assert type_a_pos < type_c_pos
        assert type_b_pos < type_c_pos

    def test_build_user_prompt_format(self, prompt_builder, sample_product):
        """The user prompt labels each required field."""
        # When: build the user prompt
        user_prompt = prompt_builder.build_user_prompt(sample_product)

        # Then: a label exists for each field. "编号"/"名称" are substrings of
        # "产品编号"/"产品名称", so these checks accept either wording.
        assert "编号" in user_prompt
        assert "名称" in user_prompt
        # The product name "黄山风景区门票" already contains "景区", so a plain
        # ``in`` check could never fail. Require a second occurrence, which
        # must come from the scenic-spot label itself.
        assert user_prompt.count("景区") >= 2

    def test_multiple_products_same_system_prompt(
        self, prompt_builder, sample_categories
    ):
        """One system prompt can be reused across requests for several products."""
        # Given: a single system prompt
        system_prompt = prompt_builder.build_system_prompt(sample_categories)

        # Given: several products
        products = [
            ProductInput(product_id="P001", product_name="产品1", scenic_spot="景区1"),
            ProductInput(product_id="P002", product_name="产品2", scenic_spot="景区2"),
            ProductInput(product_id="P003", product_name="产品3", scenic_spot="景区3"),
        ]

        # When: build one request per product
        requests = [
            prompt_builder.build_classification_request(product, system_prompt)
            for product in products
        ]

        # Then: every request carries the same system message
        system_prompts = [req["body"]["messages"][0]["content"] for req in requests]
        assert len(set(system_prompts)) == 1

        # Then: each custom_id is the product's ID, in input order.
        # This is stricter than mere uniqueness.
        custom_ids = [req["custom_id"] for req in requests]
        assert custom_ids == [product.product_id for product in products]