TravelContentCreator/tests/test_new_engines.py

317 lines
9.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
新引擎集成测试脚本
测试 V2 引擎(现在是默认引擎)
"""
import asyncio
import sys
import json
from pathlib import Path
# 添加项目根目录到路径
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
# 测试数据
SAMPLE_SCENIC_SPOT = {
"id": 1,
"name": "天津冒险湾",
"description": "天津最大的水上乐园,拥有多种刺激水上项目和亲子设施",
"address": "天津市滨海新区海滨大道",
"traffic_info": "地铁9号线直达自驾可走津滨高速",
"highlights": ["水上过山车", "儿童戏水区", "漂流河"],
"opening_hours": "09:00-18:00"
}
SAMPLE_PRODUCT = {
"id": 10,
"name": "家庭套票",
"price": 299,
"original_price": 399,
"package_info": "含2大1小门票+午餐套餐",
"usage_rules": "需提前1天预约有效期30天"
}
SAMPLE_STYLE = {
"id": 1,
"name": "攻略风",
"prompt_key": "gonglue"
}
SAMPLE_AUDIENCE = {
"id": 1,
"name": "亲子向",
"prompt_key": "qinzi"
}
def test_engine_registry():
"""测试引擎注册"""
print("=" * 60)
print("测试引擎注册")
print("=" * 60)
from domain.aigc import EngineRegistry
registry = EngineRegistry()
registry.auto_discover()
engines = registry.list_engines()
print(f"\n已注册 {len(engines)} 个引擎:")
for e in engines:
print(f" - {e['id']}: {e['name']} v{e['version']}")
# 验证默认引擎名称
assert registry.has("topic_generate"), "topic_generate 未注册"
assert registry.has("content_generate"), "content_generate 未注册"
assert registry.has("poster_generate"), "poster_generate 未注册"
# 验证 V2 后缀已移除
assert not registry.has("topic_generate_v2"), "topic_generate_v2 不应存在"
print("\n✅ 引擎注册测试通过")
return True
def test_prompt_registry():
"""测试 Prompt 注册"""
print("\n" + "=" * 60)
print("测试 Prompt 注册")
print("=" * 60)
from domain.prompt import PromptRegistry
registry = PromptRegistry('prompts')
prompts = registry.list_prompts()
print(f"\n已注册 {len(prompts)} 个 Prompt:")
for p in prompts:
versions = registry.list_versions(p)
print(f" - {p}: {versions}")
# 验证核心 Prompt
assert "topic_generate" in prompts, "topic_generate prompt 未找到"
assert "content_generate" in prompts, "content_generate prompt 未找到"
# 风格和人群使用嵌套路径
style_prompts = [p for p in prompts if p.startswith('style/')]
audience_prompts = [p for p in prompts if p.startswith('audience/')]
assert len(style_prompts) >= 2, "style prompt 未找到"
assert len(audience_prompts) >= 3, "audience prompt 未找到"
print("\n✅ Prompt 注册测试通过")
return True
def test_prompt_render():
"""测试 Prompt 渲染"""
print("\n" + "=" * 60)
print("测试 Prompt 渲染")
print("=" * 60)
from domain.prompt import PromptRegistry
registry = PromptRegistry('prompts')
# 测试 topic_generate - 使用完整对象
context = {
"scenic_spot": SAMPLE_SCENIC_SPOT,
"product": SAMPLE_PRODUCT,
"style": SAMPLE_STYLE,
"audience": SAMPLE_AUDIENCE,
"month": "2025-01",
"num_topics": 5
}
system, user = registry.render("topic_generate", context)
print(f"\n渲染 topic_generate:")
print(f" System prompt 长度: {len(system)} 字符")
print(f" User prompt 长度: {len(user)} 字符")
print(f" User prompt 预览: {user[:200]}...")
assert len(system) > 100, "System prompt 太短"
assert len(user) > 50, "User prompt 太短"
assert SAMPLE_SCENIC_SPOT["name"] in user, "景区名称未出现在 user prompt"
print("\n✅ Prompt 渲染测试通过")
return True
def test_style_audience_prompts():
"""测试风格和人群 Prompt"""
print("\n" + "=" * 60)
print("测试风格和人群 Prompt")
print("=" * 60)
from domain.prompt import PromptRegistry
registry = PromptRegistry('prompts')
# 获取所有 prompt筛选风格和人群
all_prompts = registry.list_prompts()
# 测试风格
style_prompts = [p for p in all_prompts if p.startswith('style/')]
print(f"\n风格 Prompt ({len(style_prompts)} 个):")
for prompt_name in style_prompts:
config = registry.get(prompt_name)
style_name = config.meta.get('style_name', prompt_name)
print(f" - {prompt_name}: {style_name} ({len(config.content)} 字符)")
assert len(config.content) > 0, f"{prompt_name} 内容为空"
# 测试人群
audience_prompts = [p for p in all_prompts if p.startswith('audience/')]
print(f"\n人群 Prompt ({len(audience_prompts)} 个):")
for prompt_name in audience_prompts:
config = registry.get(prompt_name)
audience_name = config.meta.get('audience_name', prompt_name)
print(f" - {prompt_name}: {audience_name} ({len(config.content)} 字符)")
assert len(config.content) > 0, f"{prompt_name} 内容为空"
assert len(style_prompts) >= 2, "风格 Prompt 数量不足"
assert len(audience_prompts) >= 3, "人群 Prompt 数量不足"
print("\n✅ 风格和人群 Prompt 测试通过")
return True
def test_engine_param_schema():
"""测试引擎参数 Schema"""
print("\n" + "=" * 60)
print("测试引擎参数 Schema")
print("=" * 60)
from domain.aigc import EngineRegistry
registry = EngineRegistry()
registry.auto_discover()
for engine_id in ["topic_generate", "content_generate", "poster_generate"]:
engine = registry.get(engine_id)
schema = engine.get_param_schema()
print(f"\n{engine_id} 参数:")
for key, info in schema.items():
required = "必填" if info.get("required") else "可选"
print(f" - {key}: {info.get('type')} ({required})")
print("\n✅ 引擎参数 Schema 测试通过")
return True
def test_api_request_format():
"""测试 API 请求格式"""
print("\n" + "=" * 60)
print("测试 API 请求格式")
print("=" * 60)
# 选题生成请求
topic_request = {
"engine": "topic_generate",
"params": {
"num_topics": 5,
"month": "2025-01",
"scenic_spot": SAMPLE_SCENIC_SPOT,
"product": SAMPLE_PRODUCT,
"style": SAMPLE_STYLE,
"audience": SAMPLE_AUDIENCE
}
}
print("\n选题生成请求示例:")
print(json.dumps(topic_request, indent=2, ensure_ascii=False)[:500] + "...")
# 内容生成请求
content_request = {
"engine": "content_generate",
"params": {
"topic": {
"index": 1,
"date": "2025-01-15",
"title": "寒假遛娃好去处"
},
"scenic_spot": SAMPLE_SCENIC_SPOT,
"product": SAMPLE_PRODUCT,
"style": SAMPLE_STYLE,
"audience": SAMPLE_AUDIENCE,
"need_judge": True
}
}
print("\n内容生成请求示例:")
print(json.dumps(content_request, indent=2, ensure_ascii=False)[:500] + "...")
# 海报生成请求
poster_request = {
"engine": "poster_generate",
"params": {
"template_id": "poster-template-1",
"poster_content": {
"title": "寒假特惠",
"subtitle": "天津冒险湾家庭套票"
},
"image_urls": [
"https://example.com/image1.jpg"
]
}
}
print("\n海报生成请求示例:")
print(json.dumps(poster_request, indent=2, ensure_ascii=False))
print("\n✅ API 请求格式测试通过")
return True
def main():
"""运行所有测试"""
print("\n" + "=" * 60)
print("AIGC 新引擎集成测试")
print("=" * 60)
tests = [
("引擎注册", test_engine_registry),
("Prompt 注册", test_prompt_registry),
("Prompt 渲染", test_prompt_render),
("风格人群 Prompt", test_style_audience_prompts),
("引擎参数 Schema", test_engine_param_schema),
("API 请求格式", test_api_request_format),
]
results = []
for name, test_func in tests:
try:
result = test_func()
results.append((name, result))
except Exception as e:
print(f"\n{name} 测试失败: {e}")
results.append((name, False))
# 汇总
print("\n" + "=" * 60)
print("测试结果汇总")
print("=" * 60)
passed = sum(1 for _, r in results if r)
total = len(results)
for name, result in results:
status = "" if result else ""
print(f" {status} {name}")
print(f"\n总计: {passed}/{total} 通过")
return passed == total
if __name__ == "__main__":
success = main()
sys.exit(0 if success else 1)