"""Tests for the prompt construction module.

Exercises the functionality of the ``PromptBuilder`` class.
"""

import json

import pytest

from src.core.prompt_builder import PromptBuilder
from src.models import ProductInput, ProductCategory


class TestPromptBuilder:
    """Tests for the ``PromptBuilder`` class."""

    @pytest.fixture
    def sample_categories(self):
        """Return a sample list of product categories."""
        return [
            ProductCategory(category="门票", type="自然类", sub_type="自然风光"),
            ProductCategory(category="门票", type="文化类", sub_type="历史遗迹"),
            ProductCategory(category="住宿", type="商务酒店", sub_type=""),
            ProductCategory(category="住宿", type="度假酒店", sub_type=""),
            ProductCategory(category="餐饮", type="地方特色", sub_type="徽菜"),
        ]

    @pytest.fixture
    def sample_product(self):
        """Return a sample product."""
        return ProductInput(
            product_id="P001", product_name="黄山风景区门票", scenic_spot="黄山"
        )

    @pytest.fixture
    def prompt_builder(self):
        """Return a ``PromptBuilder`` instance."""
        return PromptBuilder()

    def test_build_system_prompt(self, prompt_builder, sample_categories):
        """The system prompt mentions the supplied categories and types."""
        # When: the system prompt is built
        system_prompt = prompt_builder.build_system_prompt(sample_categories)

        # Then: it contains the category information
        assert "门票" in system_prompt
        assert "住宿" in system_prompt
        assert "餐饮" in system_prompt
        assert "自然类" in system_prompt
        assert "商务酒店" in system_prompt
        assert "JSON" in system_prompt

    def test_build_system_prompt_with_empty_categories(self, prompt_builder):
        """The system prompt still carries the basic instructions when no
        categories are supplied."""
        # When: an empty category list is used
        system_prompt = prompt_builder.build_system_prompt([])

        # Then: the basic instructions are present.
        # "分类助手" already contains "助手", so checking the shorter
        # substring is equivalent to checking either of them.
        assert "助手" in system_prompt
        assert "JSON" in system_prompt

    def test_build_system_prompt_includes_sub_types(
        self, prompt_builder, sample_categories
    ):
        """The system prompt includes the sub-type names."""
        # When: the system prompt is built
        system_prompt = prompt_builder.build_system_prompt(sample_categories)

        # Then: the sub-types appear in it
        assert "自然风光" in system_prompt
        assert "历史遗迹" in system_prompt
        assert "徽菜" in system_prompt

    def test_build_user_prompt(self, prompt_builder, sample_product):
        """The user prompt contains the product's fields."""
        # When: the user prompt is built
        user_prompt = prompt_builder.build_user_prompt(sample_product)

        # Then: it contains the product information
        assert "P001" in user_prompt
        assert "黄山风景区门票" in user_prompt
        # NOTE: "黄山" is also a substring of the product name above, so this
        # check cannot fail on its own; the scenic spot is checked
        # independently in test_build_user_prompt_with_different_product.
        assert "黄山" in user_prompt

    def test_build_user_prompt_with_different_product(self, prompt_builder):
        """The user prompt reflects whichever product is passed in."""
        # Given: another product
        product = ProductInput(
            product_id="P002", product_name="温泉酒店套房", scenic_spot="华清池"
        )

        # When: the user prompt is built
        user_prompt = prompt_builder.build_user_prompt(product)

        # Then: it contains the new product's information
        assert "P002" in user_prompt
        assert "温泉酒店套房" in user_prompt
        assert "华清池" in user_prompt

    def test_build_classification_request(
        self, prompt_builder, sample_product, sample_categories
    ):
        """A complete classification request has the expected structure."""
        # Given: a system prompt
        system_prompt = prompt_builder.build_system_prompt(sample_categories)

        # When: the classification request is built
        request = prompt_builder.build_classification_request(
            product=sample_product, system_prompt=system_prompt, model="qwen-flash"
        )

        # Then: the request follows the Batch API format
        assert "custom_id" in request
        assert request["custom_id"] == "P001"
        assert request["method"] == "POST"
        assert request["url"] == "/v1/chat/completions"
        assert "body" in request

        # Check the body structure
        body = request["body"]
        assert body["model"] == "qwen-flash"
        assert "messages" in body
        assert len(body["messages"]) == 2

        # Check the messages
        assert body["messages"][0]["role"] == "system"
        assert body["messages"][1]["role"] == "user"
        assert "黄山风景区门票" in body["messages"][1]["content"]

    def test_build_classification_request_with_different_model(
        self, prompt_builder, sample_product
    ):
        """The requested model name is written into the request body."""
        # Given: a different model
        system_prompt = "测试系统提示词"

        # When: the classification request is built
        request = prompt_builder.build_classification_request(
            product=sample_product, system_prompt=system_prompt, model="qwen-plus"
        )

        # Then: the model name is set accordingly
        assert request["body"]["model"] == "qwen-plus"

    def test_build_classification_request_is_valid_json(
        self, prompt_builder, sample_product, sample_categories
    ):
        """The built request can be serialized to JSON and parsed back."""
        # Given: a system prompt
        system_prompt = prompt_builder.build_system_prompt(sample_categories)

        # When: the classification request is built
        request = prompt_builder.build_classification_request(
            product=sample_product, system_prompt=system_prompt
        )

        # Then: serialization succeeds. json.dumps either returns a string or
        # raises, so the call itself is the check.
        json_str = json.dumps(request, ensure_ascii=False)

        # And the serialized form can be parsed back
        parsed = json.loads(json_str)
        assert parsed["custom_id"] == "P001"

    def test_build_system_prompt_categories_grouped_by_category(self, prompt_builder):
        """Types of an earlier category appear before types of a later one."""
        # Given: several categories
        categories = [
            ProductCategory(category="门票", type="类型A", sub_type=""),
            ProductCategory(category="门票", type="类型B", sub_type=""),
            ProductCategory(category="住宿", type="类型C", sub_type=""),
        ]

        # When: the system prompt is built
        system_prompt = prompt_builder.build_system_prompt(categories)

        # Then: every category and type is present
        assert "门票" in system_prompt
        assert "住宿" in system_prompt
        assert "类型A" in system_prompt
        assert "类型B" in system_prompt
        assert "类型C" in system_prompt

        # Locate each type. str.index raises on a miss, whereas str.find
        # would return -1 and let the ordering comparison pass spuriously.
        type_a_pos = system_prompt.index("类型A")
        type_b_pos = system_prompt.index("类型B")
        type_c_pos = system_prompt.index("类型C")

        # 类型A and 类型B (under 门票) must both come before 类型C (under 住宿)
        assert type_a_pos < type_c_pos
        assert type_b_pos < type_c_pos

    def test_build_user_prompt_format(self, prompt_builder, sample_product):
        """The user prompt labels each of the product's fields."""
        # When: the user prompt is built
        user_prompt = prompt_builder.build_user_prompt(sample_product)

        # Then: a label is present for every field.
        # "产品编号" contains "编号" and "产品名称" contains "名称", so the
        # shorter substrings cover both accepted spellings.
        assert "编号" in user_prompt
        assert "名称" in user_prompt
        assert "景区" in user_prompt

    def test_multiple_products_same_system_prompt(
        self, prompt_builder, sample_categories
    ):
        """Several products can share one system prompt."""
        # Given: one system prompt
        system_prompt = prompt_builder.build_system_prompt(sample_categories)

        # Given: several products
        products = [
            ProductInput(product_id="P001", product_name="产品1", scenic_spot="景区1"),
            ProductInput(product_id="P002", product_name="产品2", scenic_spot="景区2"),
            ProductInput(product_id="P003", product_name="产品3", scenic_spot="景区3"),
        ]

        # When: a request is built for each product
        requests = [
            prompt_builder.build_classification_request(product, system_prompt)
            for product in products
        ]

        # Then: every request carries the same system prompt
        system_prompts = [req["body"]["messages"][0]["content"] for req in requests]
        assert len(set(system_prompts)) == 1

        # Then: every request has its own custom_id
        custom_ids = [req["custom_id"] for req in requests]
        assert len(set(custom_ids)) == len(products)