317 lines
9.1 KiB
Python
317 lines
9.1 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
|
||
"""
|
||
新引擎集成测试脚本
|
||
测试 V2 引擎(现在是默认引擎)
|
||
"""
|
||
|
||
import asyncio
|
||
import sys
|
||
import json
|
||
from pathlib import Path
|
||
|
||
# 添加项目根目录到路径
|
||
project_root = Path(__file__).parent.parent
|
||
sys.path.insert(0, str(project_root))
|
||
|
||
|
||
# 测试数据
|
||
SAMPLE_SCENIC_SPOT = {
|
||
"id": 1,
|
||
"name": "天津冒险湾",
|
||
"description": "天津最大的水上乐园,拥有多种刺激水上项目和亲子设施",
|
||
"address": "天津市滨海新区海滨大道",
|
||
"traffic_info": "地铁9号线直达,自驾可走津滨高速",
|
||
"highlights": ["水上过山车", "儿童戏水区", "漂流河"],
|
||
"opening_hours": "09:00-18:00"
|
||
}
|
||
|
||
SAMPLE_PRODUCT = {
|
||
"id": 10,
|
||
"name": "家庭套票",
|
||
"price": 299,
|
||
"original_price": 399,
|
||
"package_info": "含2大1小门票+午餐套餐",
|
||
"usage_rules": "需提前1天预约,有效期30天"
|
||
}
|
||
|
||
SAMPLE_STYLE = {
|
||
"id": 1,
|
||
"name": "攻略风",
|
||
"prompt_key": "gonglue"
|
||
}
|
||
|
||
SAMPLE_AUDIENCE = {
|
||
"id": 1,
|
||
"name": "亲子向",
|
||
"prompt_key": "qinzi"
|
||
}
|
||
|
||
|
||
def test_engine_registry():
|
||
"""测试引擎注册"""
|
||
print("=" * 60)
|
||
print("测试引擎注册")
|
||
print("=" * 60)
|
||
|
||
from domain.aigc import EngineRegistry
|
||
|
||
registry = EngineRegistry()
|
||
registry.auto_discover()
|
||
|
||
engines = registry.list_engines()
|
||
print(f"\n已注册 {len(engines)} 个引擎:")
|
||
for e in engines:
|
||
print(f" - {e['id']}: {e['name']} v{e['version']}")
|
||
|
||
# 验证默认引擎名称
|
||
assert registry.has("topic_generate"), "topic_generate 未注册"
|
||
assert registry.has("content_generate"), "content_generate 未注册"
|
||
assert registry.has("poster_generate"), "poster_generate 未注册"
|
||
|
||
# 验证 V2 后缀已移除
|
||
assert not registry.has("topic_generate_v2"), "topic_generate_v2 不应存在"
|
||
|
||
print("\n✅ 引擎注册测试通过")
|
||
return True
|
||
|
||
|
||
def test_prompt_registry():
|
||
"""测试 Prompt 注册"""
|
||
print("\n" + "=" * 60)
|
||
print("测试 Prompt 注册")
|
||
print("=" * 60)
|
||
|
||
from domain.prompt import PromptRegistry
|
||
|
||
registry = PromptRegistry('prompts')
|
||
prompts = registry.list_prompts()
|
||
|
||
print(f"\n已注册 {len(prompts)} 个 Prompt:")
|
||
for p in prompts:
|
||
versions = registry.list_versions(p)
|
||
print(f" - {p}: {versions}")
|
||
|
||
# 验证核心 Prompt
|
||
assert "topic_generate" in prompts, "topic_generate prompt 未找到"
|
||
assert "content_generate" in prompts, "content_generate prompt 未找到"
|
||
|
||
# 风格和人群使用嵌套路径
|
||
style_prompts = [p for p in prompts if p.startswith('style/')]
|
||
audience_prompts = [p for p in prompts if p.startswith('audience/')]
|
||
assert len(style_prompts) >= 2, "style prompt 未找到"
|
||
assert len(audience_prompts) >= 3, "audience prompt 未找到"
|
||
|
||
print("\n✅ Prompt 注册测试通过")
|
||
return True
|
||
|
||
|
||
def test_prompt_render():
|
||
"""测试 Prompt 渲染"""
|
||
print("\n" + "=" * 60)
|
||
print("测试 Prompt 渲染")
|
||
print("=" * 60)
|
||
|
||
from domain.prompt import PromptRegistry
|
||
|
||
registry = PromptRegistry('prompts')
|
||
|
||
# 测试 topic_generate - 使用完整对象
|
||
context = {
|
||
"scenic_spot": SAMPLE_SCENIC_SPOT,
|
||
"product": SAMPLE_PRODUCT,
|
||
"style": SAMPLE_STYLE,
|
||
"audience": SAMPLE_AUDIENCE,
|
||
"month": "2025-01",
|
||
"num_topics": 5
|
||
}
|
||
|
||
system, user = registry.render("topic_generate", context)
|
||
|
||
print(f"\n渲染 topic_generate:")
|
||
print(f" System prompt 长度: {len(system)} 字符")
|
||
print(f" User prompt 长度: {len(user)} 字符")
|
||
print(f" User prompt 预览: {user[:200]}...")
|
||
|
||
assert len(system) > 100, "System prompt 太短"
|
||
assert len(user) > 50, "User prompt 太短"
|
||
assert SAMPLE_SCENIC_SPOT["name"] in user, "景区名称未出现在 user prompt"
|
||
|
||
print("\n✅ Prompt 渲染测试通过")
|
||
return True
|
||
|
||
|
||
def test_style_audience_prompts():
|
||
"""测试风格和人群 Prompt"""
|
||
print("\n" + "=" * 60)
|
||
print("测试风格和人群 Prompt")
|
||
print("=" * 60)
|
||
|
||
from domain.prompt import PromptRegistry
|
||
|
||
registry = PromptRegistry('prompts')
|
||
|
||
# 获取所有 prompt,筛选风格和人群
|
||
all_prompts = registry.list_prompts()
|
||
|
||
# 测试风格
|
||
style_prompts = [p for p in all_prompts if p.startswith('style/')]
|
||
print(f"\n风格 Prompt ({len(style_prompts)} 个):")
|
||
|
||
for prompt_name in style_prompts:
|
||
config = registry.get(prompt_name)
|
||
style_name = config.meta.get('style_name', prompt_name)
|
||
print(f" - {prompt_name}: {style_name} ({len(config.content)} 字符)")
|
||
assert len(config.content) > 0, f"{prompt_name} 内容为空"
|
||
|
||
# 测试人群
|
||
audience_prompts = [p for p in all_prompts if p.startswith('audience/')]
|
||
print(f"\n人群 Prompt ({len(audience_prompts)} 个):")
|
||
|
||
for prompt_name in audience_prompts:
|
||
config = registry.get(prompt_name)
|
||
audience_name = config.meta.get('audience_name', prompt_name)
|
||
print(f" - {prompt_name}: {audience_name} ({len(config.content)} 字符)")
|
||
assert len(config.content) > 0, f"{prompt_name} 内容为空"
|
||
|
||
assert len(style_prompts) >= 2, "风格 Prompt 数量不足"
|
||
assert len(audience_prompts) >= 3, "人群 Prompt 数量不足"
|
||
|
||
print("\n✅ 风格和人群 Prompt 测试通过")
|
||
return True
|
||
|
||
|
||
def test_engine_param_schema():
|
||
"""测试引擎参数 Schema"""
|
||
print("\n" + "=" * 60)
|
||
print("测试引擎参数 Schema")
|
||
print("=" * 60)
|
||
|
||
from domain.aigc import EngineRegistry
|
||
|
||
registry = EngineRegistry()
|
||
registry.auto_discover()
|
||
|
||
for engine_id in ["topic_generate", "content_generate", "poster_generate"]:
|
||
engine = registry.get(engine_id)
|
||
schema = engine.get_param_schema()
|
||
|
||
print(f"\n{engine_id} 参数:")
|
||
for key, info in schema.items():
|
||
required = "必填" if info.get("required") else "可选"
|
||
print(f" - {key}: {info.get('type')} ({required})")
|
||
|
||
print("\n✅ 引擎参数 Schema 测试通过")
|
||
return True
|
||
|
||
|
||
def test_api_request_format():
|
||
"""测试 API 请求格式"""
|
||
print("\n" + "=" * 60)
|
||
print("测试 API 请求格式")
|
||
print("=" * 60)
|
||
|
||
# 选题生成请求
|
||
topic_request = {
|
||
"engine": "topic_generate",
|
||
"params": {
|
||
"num_topics": 5,
|
||
"month": "2025-01",
|
||
"scenic_spot": SAMPLE_SCENIC_SPOT,
|
||
"product": SAMPLE_PRODUCT,
|
||
"style": SAMPLE_STYLE,
|
||
"audience": SAMPLE_AUDIENCE
|
||
}
|
||
}
|
||
|
||
print("\n选题生成请求示例:")
|
||
print(json.dumps(topic_request, indent=2, ensure_ascii=False)[:500] + "...")
|
||
|
||
# 内容生成请求
|
||
content_request = {
|
||
"engine": "content_generate",
|
||
"params": {
|
||
"topic": {
|
||
"index": 1,
|
||
"date": "2025-01-15",
|
||
"title": "寒假遛娃好去处"
|
||
},
|
||
"scenic_spot": SAMPLE_SCENIC_SPOT,
|
||
"product": SAMPLE_PRODUCT,
|
||
"style": SAMPLE_STYLE,
|
||
"audience": SAMPLE_AUDIENCE,
|
||
"need_judge": True
|
||
}
|
||
}
|
||
|
||
print("\n内容生成请求示例:")
|
||
print(json.dumps(content_request, indent=2, ensure_ascii=False)[:500] + "...")
|
||
|
||
# 海报生成请求
|
||
poster_request = {
|
||
"engine": "poster_generate",
|
||
"params": {
|
||
"template_id": "poster-template-1",
|
||
"poster_content": {
|
||
"title": "寒假特惠",
|
||
"subtitle": "天津冒险湾家庭套票"
|
||
},
|
||
"image_urls": [
|
||
"https://example.com/image1.jpg"
|
||
]
|
||
}
|
||
}
|
||
|
||
print("\n海报生成请求示例:")
|
||
print(json.dumps(poster_request, indent=2, ensure_ascii=False))
|
||
|
||
print("\n✅ API 请求格式测试通过")
|
||
return True
|
||
|
||
|
||
def main():
|
||
"""运行所有测试"""
|
||
print("\n" + "=" * 60)
|
||
print("AIGC 新引擎集成测试")
|
||
print("=" * 60)
|
||
|
||
tests = [
|
||
("引擎注册", test_engine_registry),
|
||
("Prompt 注册", test_prompt_registry),
|
||
("Prompt 渲染", test_prompt_render),
|
||
("风格人群 Prompt", test_style_audience_prompts),
|
||
("引擎参数 Schema", test_engine_param_schema),
|
||
("API 请求格式", test_api_request_format),
|
||
]
|
||
|
||
results = []
|
||
for name, test_func in tests:
|
||
try:
|
||
result = test_func()
|
||
results.append((name, result))
|
||
except Exception as e:
|
||
print(f"\n❌ {name} 测试失败: {e}")
|
||
results.append((name, False))
|
||
|
||
# 汇总
|
||
print("\n" + "=" * 60)
|
||
print("测试结果汇总")
|
||
print("=" * 60)
|
||
|
||
passed = sum(1 for _, r in results if r)
|
||
total = len(results)
|
||
|
||
for name, result in results:
|
||
status = "✅" if result else "❌"
|
||
print(f" {status} {name}")
|
||
|
||
print(f"\n总计: {passed}/{total} 通过")
|
||
|
||
return passed == total
|
||
|
||
|
||
if __name__ == "__main__":
|
||
success = main()
|
||
sys.exit(0 if success else 1)
|