246 lines
7.2 KiB
Python
246 lines
7.2 KiB
Python
#!/usr/bin/env python3
|
|
# -*- coding: utf-8 -*-
|
|
|
|
"""
|
|
AIGC 引擎测试脚本
|
|
"""
|
|
|
|
import asyncio
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
# Make the directory two levels above this script importable, ahead of every
# other entry on sys.path, so the project packages (``domain``, ``api``)
# resolve when this file is executed directly.
project_root = Path(__file__).parent.parent
sys.path[:0] = [str(project_root)]
|
|
|
|
|
|
def test_imports():
    """Check that the AIGC packages and all built-in engines can be imported.

    Imports the public names of ``domain.aigc`` and ``domain.aigc.shared``
    plus the poster/content/topic engine classes, printing a line per stage.

    Returns:
        True when every import succeeds; False otherwise. On failure the
        traceback is printed (as the other checks do) so the module that
        actually broke inside the import chain is visible, not just the
        exception message.
    """
    print("=" * 50)
    print("测试模块导入")
    print("=" * 50)

    try:
        # The names are imported only to prove they exist; they are
        # intentionally unused afterwards.
        from domain.aigc import EngineRegistry, EngineExecutor, BaseAIGCEngine, EngineResult  # noqa: F401
        print("✅ domain.aigc 导入成功")

        from domain.aigc.shared import (  # noqa: F401
            ComponentFactory, LLMClient, PromptBuilder,
            ImageProcessor, DatabaseAccessor, FileStorage,
        )
        print("✅ domain.aigc.shared 导入成功")

        from domain.aigc.engines.poster_generate import PosterGenerateEngine  # noqa: F401
        from domain.aigc.engines.content_generate import ContentGenerateEngine  # noqa: F401
        from domain.aigc.engines.topic_generate import TopicGenerateEngine  # noqa: F401
        print("✅ 所有引擎导入成功")

        return True

    except Exception as e:
        print(f"❌ 导入失败: {e}")
        import traceback
        traceback.print_exc()
        return False
|
|
|
|
|
|
def test_engine_registry():
    """Check that auto-discovery registers the expected engines.

    Builds a fresh ``EngineRegistry``, runs ``auto_discover()``, prints every
    engine reported by ``list_engines()``, then checks ``registry.has(id)``
    for each expected engine id.

    Returns:
        True when every expected engine is registered; False when any is
        missing or an exception is raised (the traceback is printed).
    """
    print("\n" + "=" * 50)
    print("测试引擎注册表")
    print("=" * 50)

    try:
        from domain.aigc import EngineRegistry

        registry = EngineRegistry()
        registry.auto_discover()

        engines = registry.list_engines()
        print(f"✅ 发现 {len(engines)} 个引擎:")
        for engine in engines:
            # Each listed engine is read as a dict with 'id', 'name', 'version'.
            print(f" - {engine['id']}: {engine['name']} (v{engine['version']})")

        # A missing engine must fail the check, not just print a warning;
        # otherwise the summary in main() would report a pass.
        missing = []
        for engine_id in ('poster_generate', 'content_generate', 'topic_generate'):
            if registry.has(engine_id):
                print(f"✅ 引擎 {engine_id} 已注册")
            else:
                print(f"❌ 引擎 {engine_id} 未注册")
                missing.append(engine_id)

        return not missing

    except Exception as e:
        print(f"❌ 测试失败: {e}")
        import traceback
        traceback.print_exc()
        return False
|
|
|
|
|
|
def test_component_factory():
    """Check that ComponentFactory builds every shared component.

    Instantiates ``ComponentFactory``, prints its ``project_root`` and the
    keys of its ``paths`` mapping, calls ``create_components()`` and verifies
    that each expected component key is present and not None.

    Returns:
        True when every expected component exists; False when any is missing
        or None, or when an exception is raised (the traceback is printed).
    """
    print("\n" + "=" * 50)
    print("测试组件工厂")
    print("=" * 50)

    # Expected component key -> message printed when it is missing or None.
    required = {
        'llm': "LLM 客户端为空",
        'prompt': "提示词构建器为空",
        'image': "图片处理器为空",
        'db': "数据库访问器为空",
        'storage': "文件存储为空",
    }

    try:
        from domain.aigc.shared import ComponentFactory

        factory = ComponentFactory()
        print(f"✅ 项目根目录: {factory.project_root}")
        print(f"✅ 路径配置: {list(factory.paths.keys())}")

        components = factory.create_components()
        print(f"✅ 创建组件: {list(components.keys())}")

        # Explicit checks rather than `assert`: asserts are stripped when
        # Python runs with -O, which would let empty components pass.
        problems = [
            message
            for key, message in required.items()
            if key not in components or components[key] is None
        ]
        if problems:
            for message in problems:
                print(f"❌ 测试失败: {message}")
            return False

        print("✅ 所有组件创建成功")
        return True

    except Exception as e:
        print(f"❌ 测试失败: {e}")
        import traceback
        traceback.print_exc()
        return False
|
|
|
|
|
|
def test_engine_param_schema():
    """Print the parameter schema exposed by each expected engine.

    For every expected engine id, fetches ``registry.get_engine_info(id)``
    and prints each entry of its ``param_schema`` mapping as
    ``name (type, required/optional): description``.

    Returns:
        True when info is available for every expected engine; False when any
        engine has no info or an exception is raised (the traceback is
        printed).
    """
    print("\n" + "=" * 50)
    print("测试引擎参数 Schema")
    print("=" * 50)

    try:
        from domain.aigc import EngineRegistry

        registry = EngineRegistry()
        registry.auto_discover()

        missing = []
        for engine_id in ('poster_generate', 'content_generate', 'topic_generate'):
            info = registry.get_engine_info(engine_id)
            if not info:
                # Report and fail instead of silently skipping the engine.
                print(f"❌ 引擎 {engine_id} 未注册")
                missing.append(engine_id)
                continue

            print(f"\n📋 {engine_id} 参数:")
            # `or {}` also covers an explicit `param_schema: None`.
            schema = info.get('param_schema') or {}
            for param_name, param_def in schema.items():
                # Each schema value is read as a dict with optional
                # 'required', 'type' and 'desc' keys.
                required = "必填" if param_def.get('required') else "可选"
                param_type = param_def.get('type', 'any')
                desc = param_def.get('desc', '')
                print(f" - {param_name} ({param_type}, {required}): {desc}")

        return not missing

    except Exception as e:
        print(f"❌ 测试失败: {e}")
        import traceback
        traceback.print_exc()
        return False
|
|
|
|
|
|
def test_param_validation():
    """Check poster_generate's ``validate_params`` on a valid and an invalid input.

    Passes ComponentFactory's components to a fresh EngineRegistry, runs
    auto-discovery and fetches the ``poster_generate`` engine. It then calls
    ``validate_params``, which returns an ``(is_valid, error)`` pair, twice:
    once with ``template_id`` + ``images`` (expected to be accepted) and once
    without ``images`` (expected to be rejected).

    Returns:
        True only when the valid input is accepted AND the invalid input is
        rejected; False otherwise, including when the engine is missing or an
        exception is raised (the traceback is printed).
    """
    print("\n" + "=" * 50)
    print("测试参数验证")
    print("=" * 50)

    try:
        from domain.aigc import EngineRegistry
        from domain.aigc.shared import ComponentFactory

        registry = EngineRegistry()
        factory = ComponentFactory()
        components = factory.create_components()
        # Shared components are set before auto_discover(); presumably the
        # discovered engines are built with them — confirm in EngineRegistry.
        registry.set_shared_components(components)
        registry.auto_discover()

        # Fetch the poster engine.
        engine = registry.get('poster_generate')
        if engine is None:
            print("❌ 引擎 poster_generate 未注册")
            return False

        # Valid parameters: must be accepted.
        valid_params = {
            'template_id': 'vibrant',
            'images': ['base64_image_data'],
        }
        is_valid, error = engine.validate_params(valid_params)
        accepted = bool(is_valid)
        print(f"{'✅' if accepted else '❌'} 有效参数验证: {is_valid}")
        if not accepted:
            print(f"   错误: {error}")

        # Missing required parameter ('images' deliberately omitted): must be rejected.
        invalid_params = {
            'template_id': 'vibrant',
        }
        is_valid, error = engine.validate_params(invalid_params)
        rejected = not is_valid
        print(f"{'✅' if rejected else '❌'} 无效参数验证: {rejected}, 错误: {error}")

        # Previously the result was always True regardless of these outcomes.
        return accepted and rejected

    except Exception as e:
        print(f"❌ 测试失败: {e}")
        import traceback
        traceback.print_exc()
        return False
|
|
|
|
|
|
def test_api_router():
    """Import the AIGC API router and print its prefix and registered routes.

    Returns:
        True when ``api.routers.aigc`` imports and its router can be listed;
        False when anything raises (the error and traceback are printed).
    """
    banner = "=" * 50
    print("\n" + banner)
    print("测试 API 路由")
    print(banner)

    try:
        from api.routers import aigc

        router = aigc.router
        print(f"✅ 路由前缀: {router.prefix}")

        routes = router.routes
        print(f"✅ 路由数量: {len(routes)}")

        print("✅ 路由列表:")
        for route_entry in routes:
            # Not every route object exposes `methods`/`path`; fall back to
            # an empty set / empty string for those.
            print(f" - {getattr(route_entry, 'methods', set())} {getattr(route_entry, 'path', '')}")

        return True

    except Exception as exc:
        print(f"❌ 测试失败: {exc}")
        import traceback
        traceback.print_exc()
        return False
|
|
|
|
|
|
def main():
    """Run every check in order, print a pass/fail summary, and report success.

    Returns:
        True when every check returned a truthy result, False otherwise.
    """
    wide_rule = "=" * 60
    print("\n" + wide_rule)
    print(" AIGC 引擎系统测试")
    print(wide_rule)

    # (display label, zero-argument check) pairs, executed in this order.
    checks = [
        ("模块导入", test_imports),
        ("引擎注册表", test_engine_registry),
        ("组件工厂", test_component_factory),
        ("参数 Schema", test_engine_param_schema),
        ("参数验证", test_param_validation),
        ("API 路由", test_api_router),
    ]
    outcomes = [(label, check()) for label, check in checks]

    print("\n" + wide_rule)
    print(" 测试结果汇总")
    print(wide_rule)

    for label, ok in outcomes:
        verdict = "✅ 通过" if ok else "❌ 失败"
        print(f" {label}: {verdict}")

    passed = sum(1 for _, ok in outcomes if ok)
    failed = len(outcomes) - passed

    print(f"\n 总计: {passed} 通过, {failed} 失败")
    print(wide_rule)

    return failed == 0
|
|
|
|
|
|
if __name__ == "__main__":
    # Map the overall result to a process exit status for shells/CI:
    # 0 when every check passed, 1 when any check failed.
    exit_code = 0 if main() else 1
    sys.exit(exit_code)
|