TravelContentCreator/tests/test_new_engines.py

317 lines
9.1 KiB
Python
Raw Normal View History

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
新引擎集成测试脚本
测试 V2 引擎现在是默认引擎
"""
import asyncio
import sys
import json
from pathlib import Path
# 添加项目根目录到路径
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
# 测试数据
SAMPLE_SCENIC_SPOT = {
"id": 1,
"name": "天津冒险湾",
"description": "天津最大的水上乐园,拥有多种刺激水上项目和亲子设施",
"address": "天津市滨海新区海滨大道",
"traffic_info": "地铁9号线直达自驾可走津滨高速",
"highlights": ["水上过山车", "儿童戏水区", "漂流河"],
"opening_hours": "09:00-18:00"
}
SAMPLE_PRODUCT = {
"id": 10,
"name": "家庭套票",
"price": 299,
"original_price": 399,
"package_info": "含2大1小门票+午餐套餐",
"usage_rules": "需提前1天预约有效期30天"
}
SAMPLE_STYLE = {
"id": 1,
"name": "攻略风",
"prompt_key": "gonglue"
}
SAMPLE_AUDIENCE = {
"id": 1,
"name": "亲子向",
"prompt_key": "qinzi"
}
def test_engine_registry():
"""测试引擎注册"""
print("=" * 60)
print("测试引擎注册")
print("=" * 60)
from domain.aigc import EngineRegistry
registry = EngineRegistry()
registry.auto_discover()
engines = registry.list_engines()
print(f"\n已注册 {len(engines)} 个引擎:")
for e in engines:
print(f" - {e['id']}: {e['name']} v{e['version']}")
# 验证默认引擎名称
assert registry.has("topic_generate"), "topic_generate 未注册"
assert registry.has("content_generate"), "content_generate 未注册"
assert registry.has("poster_generate"), "poster_generate 未注册"
# 验证 V2 后缀已移除
assert not registry.has("topic_generate_v2"), "topic_generate_v2 不应存在"
print("\n✅ 引擎注册测试通过")
return True
def test_prompt_registry():
"""测试 Prompt 注册"""
print("\n" + "=" * 60)
print("测试 Prompt 注册")
print("=" * 60)
from domain.prompt import PromptRegistry
registry = PromptRegistry('prompts')
prompts = registry.list_prompts()
print(f"\n已注册 {len(prompts)} 个 Prompt:")
for p in prompts:
versions = registry.list_versions(p)
print(f" - {p}: {versions}")
# 验证核心 Prompt
assert "topic_generate" in prompts, "topic_generate prompt 未找到"
assert "content_generate" in prompts, "content_generate prompt 未找到"
# 风格和人群使用嵌套路径
style_prompts = [p for p in prompts if p.startswith('style/')]
audience_prompts = [p for p in prompts if p.startswith('audience/')]
assert len(style_prompts) >= 2, "style prompt 未找到"
assert len(audience_prompts) >= 3, "audience prompt 未找到"
print("\n✅ Prompt 注册测试通过")
return True
def test_prompt_render():
"""测试 Prompt 渲染"""
print("\n" + "=" * 60)
print("测试 Prompt 渲染")
print("=" * 60)
from domain.prompt import PromptRegistry
registry = PromptRegistry('prompts')
# 测试 topic_generate - 使用完整对象
context = {
"scenic_spot": SAMPLE_SCENIC_SPOT,
"product": SAMPLE_PRODUCT,
"style": SAMPLE_STYLE,
"audience": SAMPLE_AUDIENCE,
"month": "2025-01",
"num_topics": 5
}
system, user = registry.render("topic_generate", context)
print(f"\n渲染 topic_generate:")
print(f" System prompt 长度: {len(system)} 字符")
print(f" User prompt 长度: {len(user)} 字符")
print(f" User prompt 预览: {user[:200]}...")
assert len(system) > 100, "System prompt 太短"
assert len(user) > 50, "User prompt 太短"
assert SAMPLE_SCENIC_SPOT["name"] in user, "景区名称未出现在 user prompt"
print("\n✅ Prompt 渲染测试通过")
return True
def test_style_audience_prompts():
"""测试风格和人群 Prompt"""
print("\n" + "=" * 60)
print("测试风格和人群 Prompt")
print("=" * 60)
from domain.prompt import PromptRegistry
registry = PromptRegistry('prompts')
# 获取所有 prompt筛选风格和人群
all_prompts = registry.list_prompts()
# 测试风格
style_prompts = [p for p in all_prompts if p.startswith('style/')]
print(f"\n风格 Prompt ({len(style_prompts)} 个):")
for prompt_name in style_prompts:
config = registry.get(prompt_name)
style_name = config.meta.get('style_name', prompt_name)
print(f" - {prompt_name}: {style_name} ({len(config.content)} 字符)")
assert len(config.content) > 0, f"{prompt_name} 内容为空"
# 测试人群
audience_prompts = [p for p in all_prompts if p.startswith('audience/')]
print(f"\n人群 Prompt ({len(audience_prompts)} 个):")
for prompt_name in audience_prompts:
config = registry.get(prompt_name)
audience_name = config.meta.get('audience_name', prompt_name)
print(f" - {prompt_name}: {audience_name} ({len(config.content)} 字符)")
assert len(config.content) > 0, f"{prompt_name} 内容为空"
assert len(style_prompts) >= 2, "风格 Prompt 数量不足"
assert len(audience_prompts) >= 3, "人群 Prompt 数量不足"
print("\n✅ 风格和人群 Prompt 测试通过")
return True
def test_engine_param_schema():
"""测试引擎参数 Schema"""
print("\n" + "=" * 60)
print("测试引擎参数 Schema")
print("=" * 60)
from domain.aigc import EngineRegistry
registry = EngineRegistry()
registry.auto_discover()
for engine_id in ["topic_generate", "content_generate", "poster_generate"]:
engine = registry.get(engine_id)
schema = engine.get_param_schema()
print(f"\n{engine_id} 参数:")
for key, info in schema.items():
required = "必填" if info.get("required") else "可选"
print(f" - {key}: {info.get('type')} ({required})")
print("\n✅ 引擎参数 Schema 测试通过")
return True
def test_api_request_format():
"""测试 API 请求格式"""
print("\n" + "=" * 60)
print("测试 API 请求格式")
print("=" * 60)
# 选题生成请求
topic_request = {
"engine": "topic_generate",
"params": {
"num_topics": 5,
"month": "2025-01",
"scenic_spot": SAMPLE_SCENIC_SPOT,
"product": SAMPLE_PRODUCT,
"style": SAMPLE_STYLE,
"audience": SAMPLE_AUDIENCE
}
}
print("\n选题生成请求示例:")
print(json.dumps(topic_request, indent=2, ensure_ascii=False)[:500] + "...")
# 内容生成请求
content_request = {
"engine": "content_generate",
"params": {
"topic": {
"index": 1,
"date": "2025-01-15",
"title": "寒假遛娃好去处"
},
"scenic_spot": SAMPLE_SCENIC_SPOT,
"product": SAMPLE_PRODUCT,
"style": SAMPLE_STYLE,
"audience": SAMPLE_AUDIENCE,
"need_judge": True
}
}
print("\n内容生成请求示例:")
print(json.dumps(content_request, indent=2, ensure_ascii=False)[:500] + "...")
# 海报生成请求
poster_request = {
"engine": "poster_generate",
"params": {
"template_id": "poster-template-1",
"poster_content": {
"title": "寒假特惠",
"subtitle": "天津冒险湾家庭套票"
},
"image_urls": [
"https://example.com/image1.jpg"
]
}
}
print("\n海报生成请求示例:")
print(json.dumps(poster_request, indent=2, ensure_ascii=False))
print("\n✅ API 请求格式测试通过")
return True
def main():
"""运行所有测试"""
print("\n" + "=" * 60)
print("AIGC 新引擎集成测试")
print("=" * 60)
tests = [
("引擎注册", test_engine_registry),
("Prompt 注册", test_prompt_registry),
("Prompt 渲染", test_prompt_render),
("风格人群 Prompt", test_style_audience_prompts),
("引擎参数 Schema", test_engine_param_schema),
("API 请求格式", test_api_request_format),
]
results = []
for name, test_func in tests:
try:
result = test_func()
results.append((name, result))
except Exception as e:
print(f"\n{name} 测试失败: {e}")
results.append((name, False))
# 汇总
print("\n" + "=" * 60)
print("测试结果汇总")
print("=" * 60)
passed = sum(1 for _, r in results if r)
total = len(results)
for name, result in results:
status = "" if result else ""
print(f" {status} {name}")
print(f"\n总计: {passed}/{total} 通过")
return passed == total
if __name__ == "__main__":
success = main()
sys.exit(0 if success else 1)