#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ 新引擎集成测试脚本 测试 V2 引擎(现在是默认引擎) """ import asyncio import sys import json from pathlib import Path # 添加项目根目录到路径 project_root = Path(__file__).parent.parent sys.path.insert(0, str(project_root)) # 测试数据 SAMPLE_SCENIC_SPOT = { "id": 1, "name": "天津冒险湾", "description": "天津最大的水上乐园,拥有多种刺激水上项目和亲子设施", "address": "天津市滨海新区海滨大道", "traffic_info": "地铁9号线直达,自驾可走津滨高速", "highlights": ["水上过山车", "儿童戏水区", "漂流河"], "opening_hours": "09:00-18:00" } SAMPLE_PRODUCT = { "id": 10, "name": "家庭套票", "price": 299, "original_price": 399, "package_info": "含2大1小门票+午餐套餐", "usage_rules": "需提前1天预约,有效期30天" } SAMPLE_STYLE = { "id": 1, "name": "攻略风", "prompt_key": "gonglue" } SAMPLE_AUDIENCE = { "id": 1, "name": "亲子向", "prompt_key": "qinzi" } def test_engine_registry(): """测试引擎注册""" print("=" * 60) print("测试引擎注册") print("=" * 60) from domain.aigc import EngineRegistry registry = EngineRegistry() registry.auto_discover() engines = registry.list_engines() print(f"\n已注册 {len(engines)} 个引擎:") for e in engines: print(f" - {e['id']}: {e['name']} v{e['version']}") # 验证默认引擎名称 assert registry.has("topic_generate"), "topic_generate 未注册" assert registry.has("content_generate"), "content_generate 未注册" assert registry.has("poster_generate"), "poster_generate 未注册" # 验证 V2 后缀已移除 assert not registry.has("topic_generate_v2"), "topic_generate_v2 不应存在" print("\n✅ 引擎注册测试通过") return True def test_prompt_registry(): """测试 Prompt 注册""" print("\n" + "=" * 60) print("测试 Prompt 注册") print("=" * 60) from domain.prompt import PromptRegistry registry = PromptRegistry('prompts') prompts = registry.list_prompts() print(f"\n已注册 {len(prompts)} 个 Prompt:") for p in prompts: versions = registry.list_versions(p) print(f" - {p}: {versions}") # 验证核心 Prompt assert "topic_generate" in prompts, "topic_generate prompt 未找到" assert "content_generate" in prompts, "content_generate prompt 未找到" # 风格和人群使用嵌套路径 style_prompts = [p for p in prompts if p.startswith('style/')] audience_prompts = [p for p in prompts if p.startswith('audience/')] assert len(style_prompts) >= 2, "style prompt 未找到" assert len(audience_prompts) >= 3, "audience prompt 未找到" print("\n✅ Prompt 注册测试通过") return True def test_prompt_render(): """测试 Prompt 渲染""" print("\n" + "=" * 60) print("测试 Prompt 渲染") print("=" * 60) from domain.prompt import PromptRegistry registry = PromptRegistry('prompts') # 测试 topic_generate - 使用完整对象 context = { "scenic_spot": SAMPLE_SCENIC_SPOT, "product": SAMPLE_PRODUCT, "style": SAMPLE_STYLE, "audience": SAMPLE_AUDIENCE, "month": "2025-01", "num_topics": 5 } system, user = registry.render("topic_generate", context) print(f"\n渲染 topic_generate:") print(f" System prompt 长度: {len(system)} 字符") print(f" User prompt 长度: {len(user)} 字符") print(f" User prompt 预览: {user[:200]}...") assert len(system) > 100, "System prompt 太短" assert len(user) > 50, "User prompt 太短" assert SAMPLE_SCENIC_SPOT["name"] in user, "景区名称未出现在 user prompt" print("\n✅ Prompt 渲染测试通过") return True def test_style_audience_prompts(): """测试风格和人群 Prompt""" print("\n" + "=" * 60) print("测试风格和人群 Prompt") print("=" * 60) from domain.prompt import PromptRegistry registry = PromptRegistry('prompts') # 获取所有 prompt,筛选风格和人群 all_prompts = registry.list_prompts() # 测试风格 style_prompts = [p for p in all_prompts if p.startswith('style/')] print(f"\n风格 Prompt ({len(style_prompts)} 个):") for prompt_name in style_prompts: config = registry.get(prompt_name) style_name = config.meta.get('style_name', prompt_name) print(f" - {prompt_name}: {style_name} ({len(config.content)} 字符)") assert len(config.content) > 0, f"{prompt_name} 内容为空" # 测试人群 audience_prompts = [p for p in all_prompts if p.startswith('audience/')] print(f"\n人群 Prompt ({len(audience_prompts)} 个):") for prompt_name in audience_prompts: config = registry.get(prompt_name) audience_name = config.meta.get('audience_name', prompt_name) print(f" - {prompt_name}: {audience_name} ({len(config.content)} 字符)") assert len(config.content) > 0, f"{prompt_name} 内容为空" assert len(style_prompts) >= 2, "风格 Prompt 数量不足" assert len(audience_prompts) >= 3, "人群 Prompt 数量不足" print("\n✅ 风格和人群 Prompt 测试通过") return True def test_engine_param_schema(): """测试引擎参数 Schema""" print("\n" + "=" * 60) print("测试引擎参数 Schema") print("=" * 60) from domain.aigc import EngineRegistry registry = EngineRegistry() registry.auto_discover() for engine_id in ["topic_generate", "content_generate", "poster_generate"]: engine = registry.get(engine_id) schema = engine.get_param_schema() print(f"\n{engine_id} 参数:") for key, info in schema.items(): required = "必填" if info.get("required") else "可选" print(f" - {key}: {info.get('type')} ({required})") print("\n✅ 引擎参数 Schema 测试通过") return True def test_api_request_format(): """测试 API 请求格式""" print("\n" + "=" * 60) print("测试 API 请求格式") print("=" * 60) # 选题生成请求 topic_request = { "engine": "topic_generate", "params": { "num_topics": 5, "month": "2025-01", "scenic_spot": SAMPLE_SCENIC_SPOT, "product": SAMPLE_PRODUCT, "style": SAMPLE_STYLE, "audience": SAMPLE_AUDIENCE } } print("\n选题生成请求示例:") print(json.dumps(topic_request, indent=2, ensure_ascii=False)[:500] + "...") # 内容生成请求 content_request = { "engine": "content_generate", "params": { "topic": { "index": 1, "date": "2025-01-15", "title": "寒假遛娃好去处" }, "scenic_spot": SAMPLE_SCENIC_SPOT, "product": SAMPLE_PRODUCT, "style": SAMPLE_STYLE, "audience": SAMPLE_AUDIENCE, "need_judge": True } } print("\n内容生成请求示例:") print(json.dumps(content_request, indent=2, ensure_ascii=False)[:500] + "...") # 海报生成请求 poster_request = { "engine": "poster_generate", "params": { "template_id": "poster-template-1", "poster_content": { "title": "寒假特惠", "subtitle": "天津冒险湾家庭套票" }, "image_urls": [ "https://example.com/image1.jpg" ] } } print("\n海报生成请求示例:") print(json.dumps(poster_request, indent=2, ensure_ascii=False)) print("\n✅ API 请求格式测试通过") return True def main(): """运行所有测试""" print("\n" + "=" * 60) print("AIGC 新引擎集成测试") print("=" * 60) tests = [ ("引擎注册", test_engine_registry), ("Prompt 注册", test_prompt_registry), ("Prompt 渲染", test_prompt_render), ("风格人群 Prompt", test_style_audience_prompts), ("引擎参数 Schema", test_engine_param_schema), ("API 请求格式", test_api_request_format), ] results = [] for name, test_func in tests: try: result = test_func() results.append((name, result)) except Exception as e: print(f"\n❌ {name} 测试失败: {e}") results.append((name, False)) # 汇总 print("\n" + "=" * 60) print("测试结果汇总") print("=" * 60) passed = sum(1 for _, r in results if r) total = len(results) for name, result in results: status = "✅" if result else "❌" print(f" {status} {name}") print(f"\n总计: {passed}/{total} 通过") return passed == total if __name__ == "__main__": success = main() sys.exit(0 if success else 1)