TravelContentCreator/tests/test_aigc_engines.py

246 lines
7.2 KiB
Python
Raw Normal View History

2025-12-08 14:58:35 +08:00
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
AIGC 引擎测试脚本
"""
import asyncio
import sys
from pathlib import Path
# Put the project root (two levels above this file) first on the import path so
# that `domain` and `api` resolve when the script is run directly.
# resolve() turns a relative or symlinked __file__ into a stable absolute path.
project_root = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(project_root))
def test_imports():
    """Check that the AIGC package, its shared components and all engines import.

    Returns:
        True when every import succeeds. False otherwise; the error and its
        traceback are printed, as in the other checks in this script.
    """
    print("=" * 50)
    print("测试模块导入")
    print("=" * 50)
    try:
        # The imported names are intentionally unused: importing is the test.
        from domain.aigc import EngineRegistry, EngineExecutor, BaseAIGCEngine, EngineResult  # noqa: F401
        print("✅ domain.aigc 导入成功")
        from domain.aigc.shared import (  # noqa: F401
            ComponentFactory, LLMClient, PromptBuilder,
            ImageProcessor, DatabaseAccessor, FileStorage,
        )
        print("✅ domain.aigc.shared 导入成功")
        from domain.aigc.engines.poster_generate import PosterGenerateEngine  # noqa: F401
        from domain.aigc.engines.content_generate import ContentGenerateEngine  # noqa: F401
        from domain.aigc.engines.topic_generate import TopicGenerateEngine  # noqa: F401
        print("✅ 所有引擎导入成功")
        return True
    except Exception as e:
        # Broad catch is deliberate: module import can fail with any error at
        # import time, and this harness reports rather than aborts.
        print(f"❌ 导入失败: {e}")
        import traceback
        traceback.print_exc()
        return False
def test_engine_registry():
    """Discover engines into a fresh registry and check the built-in ones exist.

    Lists every discovered engine (id, name, version), then checks that
    ``poster_generate``, ``content_generate`` and ``topic_generate`` are
    registered.

    Returns:
        True only if all expected engines are registered. False if any is
        missing or an exception occurs (the traceback is printed).
    """
    print("\n" + "=" * 50)
    print("测试引擎注册表")
    print("=" * 50)
    try:
        from domain.aigc import EngineRegistry
        registry = EngineRegistry()
        registry.auto_discover()
        engines = registry.list_engines()
        print(f"✅ 发现 {len(engines)} 个引擎:")
        for engine in engines:
            print(f" - {engine['id']}: {engine['name']} (v{engine['version']})")
        # Every expected engine must be registered. A missing one fails the
        # check instead of only being printed.
        missing = []
        for engine_id in ['poster_generate', 'content_generate', 'topic_generate']:
            if registry.has(engine_id):
                print(f"✅ 引擎 {engine_id} 已注册")
            else:
                print(f"❌ 引擎 {engine_id} 未注册")
                missing.append(engine_id)
        return not missing
    except Exception as e:
        print(f"❌ 测试失败: {e}")
        import traceback
        traceback.print_exc()
        return False
def test_component_factory():
    """Build the shared components and check that none of them is missing.

    Prints the factory's project root and configured path keys, calls
    ``create_components()``, and checks the ``llm``, ``prompt``, ``image``,
    ``db`` and ``storage`` entries.

    Returns:
        True if every required component is present and not None. False
        otherwise (each missing component is reported) or on any exception.
    """
    print("\n" + "=" * 50)
    print("测试组件工厂")
    print("=" * 50)
    try:
        from domain.aigc.shared import ComponentFactory
        factory = ComponentFactory()
        print(f"✅ 项目根目录: {factory.project_root}")
        print(f"✅ 路径配置: {list(factory.paths.keys())}")
        components = factory.create_components()
        print(f"✅ 创建组件: {list(components.keys())}")
        # Explicit checks instead of `assert`: asserts are stripped under
        # `python -O`, which would make this check pass unconditionally.
        # .get() reports an absent key the same way as a None value.
        required = {
            'llm': "LLM 客户端为空",
            'prompt': "提示词构建器为空",
            'image': "图片处理器为空",
            'db': "数据库访问器为空",
            'storage': "文件存储为空",
        }
        problems = [msg for key, msg in required.items() if components.get(key) is None]
        if problems:
            for msg in problems:
                print(f"❌ 测试失败: {msg}")
            return False
        print("✅ 所有组件创建成功")
        return True
    except Exception as e:
        print(f"❌ 测试失败: {e}")
        import traceback
        traceback.print_exc()
        return False
def test_engine_param_schema():
    """Print the parameter schema of each built-in engine.

    For each expected engine, reads ``param_schema`` from
    ``registry.get_engine_info`` and prints each parameter's type, whether it
    is required, and its description.

    Returns:
        True if info is available for every expected engine. False if any
        engine has no info (it is reported) or an exception occurs.
    """
    print("\n" + "=" * 50)
    print("测试引擎参数 Schema")
    print("=" * 50)
    try:
        from domain.aigc import EngineRegistry
        registry = EngineRegistry()
        registry.auto_discover()
        missing = []
        for engine_id in ['poster_generate', 'content_generate', 'topic_generate']:
            info = registry.get_engine_info(engine_id)
            if not info:
                # Report a missing engine instead of silently skipping it.
                print(f"❌ 引擎 {engine_id} 未注册")
                missing.append(engine_id)
                continue
            print(f"\n📋 {engine_id} 参数:")
            # `or {}` also covers an explicit None schema.
            schema = info.get('param_schema') or {}
            for param_name, param_def in schema.items():
                required = "必填" if param_def.get('required') else "可选"
                param_type = param_def.get('type', 'any')
                desc = param_def.get('desc', '')
                print(f" - {param_name} ({param_type}, {required}): {desc}")
        return not missing
    except Exception as e:
        print(f"❌ 测试失败: {e}")
        import traceback
        traceback.print_exc()
        return False
def test_param_validation():
    """Check that the poster engine accepts complete params and rejects incomplete ones.

    Creates the shared components, hands them to a fresh registry, discovers
    engines, then calls ``validate_params`` on ``poster_generate``:
    once with ``template_id`` and ``images``, and once with ``images`` omitted.
    ``validate_params`` is unpacked as an ``(is_valid, error)`` pair.

    Returns:
        True only if the complete params are accepted AND the incomplete
        params are rejected. False otherwise, including when the engine is
        not registered or an exception occurs.
    """
    print("\n" + "=" * 50)
    print("测试参数验证")
    print("=" * 50)
    try:
        from domain.aigc import EngineRegistry
        from domain.aigc.shared import ComponentFactory
        registry = EngineRegistry()
        factory = ComponentFactory()
        components = factory.create_components()
        registry.set_shared_components(components)
        registry.auto_discover()
        # Fail with a clear message rather than an AttributeError on a missing engine.
        if not registry.has('poster_generate'):
            print("❌ 引擎 poster_generate 未注册")
            return False
        engine = registry.get('poster_generate')

        # Complete parameter set: must be accepted.
        valid_params = {
            'template_id': 'vibrant',
            'images': ['base64_image_data'],
        }
        accepted, error = engine.validate_params(valid_params)
        if accepted:
            print(f"✅ 有效参数验证: {accepted}")
        else:
            print(f"❌ 有效参数验证: {accepted}, 错误: {error}")

        # Required 'images' omitted: must be rejected.
        invalid_params = {
            'template_id': 'vibrant',
        }
        accepted_invalid, error = engine.validate_params(invalid_params)
        rejected = not accepted_invalid
        print(f"{'✅' if rejected else '❌'} 无效参数验证: {rejected}, 错误: {error}")

        return bool(accepted) and rejected
    except Exception as e:
        print(f"❌ 测试失败: {e}")
        import traceback
        traceback.print_exc()
        return False
def test_api_router():
    """Import the AIGC API router and print its prefix and routes.

    Returns:
        True if the router imports and can be inspected. False on any
        exception (the message and traceback are printed).
    """
    banner = "=" * 50
    print("\n" + banner)
    print("测试 API 路由")
    print(banner)
    try:
        from api.routers import aigc
        router = aigc.router
        print(f"✅ 路由前缀: {router.prefix}")
        routes = router.routes
        print(f"✅ 路由数量: {len(routes)}")
        print("✅ 路由列表:")
        for r in routes:
            # getattr defaults keep route objects without `methods`/`path` printable.
            print(f" - {getattr(r, 'methods', set())} {getattr(r, 'path', '')}")
        return True
    except Exception as exc:
        print(f"❌ 测试失败: {exc}")
        import traceback
        traceback.print_exc()
        return False
def main():
    """Run every check in order, then print a pass/fail summary.

    Returns:
        True only if every check returned a truthy result.
    """
    rule = "=" * 60
    print("\n" + rule)
    print(" AIGC 引擎系统测试")
    print(rule)

    checks = (
        ("模块导入", test_imports),
        ("引擎注册表", test_engine_registry),
        ("组件工厂", test_component_factory),
        ("参数 Schema", test_engine_param_schema),
        ("参数验证", test_param_validation),
        ("API 路由", test_api_router),
    )
    # Run all checks first so their output comes before the summary.
    outcomes = [(label, check()) for label, check in checks]

    print("\n" + rule)
    print(" 测试结果汇总")
    print(rule)
    for label, ok in outcomes:
        print(f" {label}: {'✅ 通过' if ok else '❌ 失败'}")

    passed = sum(1 for _, ok in outcomes if ok)
    failed = len(outcomes) - passed
    print(f"\n 总计: {passed} 通过, {failed} 失败")
    print(rule)
    return failed == 0
if __name__ == "__main__":
    # Exit status 0 means every check passed; 1 means at least one failed.
    all_passed = main()
    sys.exit(0 if all_passed else 1)