#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
AIGC engine test script.

Runs a sequence of smoke checks against the ``domain.aigc`` engine system and
the ``api.routers.aigc`` router, prints a per-check report and a summary, and
exits with status 0 only if every check passed.

NOTE(review): each ``test_*`` function reports failure by returning False
instead of raising. If pytest ever collects this file, those functions would
be reported as passing regardless of outcome — run it directly instead.
"""
import asyncio  # NOTE(review): not used anywhere in this script — candidate for removal.
import sys
import traceback
from pathlib import Path

# Add the project root (two directories above this file) to the import path so
# `domain` and `api` resolve. resolve() first: with a relative __file__
# (possible before Python 3.9), Path('.').parent is still '.', so the bare
# .parent.parent would not reach the intended directory.
project_root = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(project_root))

# Engine IDs the registry checks expect to find.
EXPECTED_ENGINE_IDS = ('poster_generate', 'content_generate', 'topic_generate')

# Keys that ComponentFactory.create_components() must return with a non-None
# value, each paired with the failure message printed when it is missing.
REQUIRED_COMPONENTS = (
    ('llm', "LLM 客户端为空"),
    ('prompt', "提示词构建器为空"),
    ('image', "图片处理器为空"),
    ('db', "数据库访问器为空"),
    ('storage', "文件存储为空"),
)


def _print_header(title, *, width=50, leading_newline=True):
    """Print *title* framed by '=' rules of *width* characters."""
    rule = "=" * width
    print("\n" + rule if leading_newline else rule)
    print(title)
    print(rule)


def _report_failure(exc, label="测试失败"):
    """Print a one-line failure message for *exc* followed by its traceback.

    Must be called from inside an ``except`` block so print_exc() has an
    active exception to show.
    """
    print(f"❌ {label}: {exc}")
    traceback.print_exc()


def _missing_components(components):
    """Return the failure messages for required components that are absent or None.

    Uses an explicit check rather than ``assert`` so it still runs under
    ``python -O``. *components* is expected to support ``.get`` (a mapping).
    """
    return [message for key, message in REQUIRED_COMPONENTS
            if components.get(key) is None]


def _count_results(results):
    """Return ``(passed, failed)`` counts for a sequence of ``(name, ok)`` pairs."""
    passed = sum(1 for _, ok in results if ok)
    return passed, len(results) - passed


def test_imports() -> bool:
    """Check that the engine package, shared components and every engine import.

    Returns:
        True if all imports succeed, False (after printing the error and
        traceback) otherwise.
    """
    _print_header("测试模块导入", leading_newline=False)
    try:
        # Imported only to prove the modules load; the names themselves are unused.
        from domain.aigc import EngineRegistry, EngineExecutor, BaseAIGCEngine, EngineResult  # noqa: F401
        print("✅ domain.aigc 导入成功")

        from domain.aigc.shared import (  # noqa: F401
            ComponentFactory, LLMClient, PromptBuilder,
            ImageProcessor, DatabaseAccessor, FileStorage,
        )
        print("✅ domain.aigc.shared 导入成功")

        from domain.aigc.engines.poster_generate import PosterGenerateEngine  # noqa: F401
        from domain.aigc.engines.content_generate import ContentGenerateEngine  # noqa: F401
        from domain.aigc.engines.topic_generate import TopicGenerateEngine  # noqa: F401
        print("✅ 所有引擎导入成功")
        return True
    except Exception as e:  # test boundary: any import-time error is a failure
        _report_failure(e, "导入失败")
        return False


def test_engine_registry() -> bool:
    """Auto-discover engines, list them, and check every expected ID is registered.

    Returns:
        True only if discovery succeeds and every ID in EXPECTED_ENGINE_IDS
        is registered.
    """
    _print_header("测试引擎注册表")
    try:
        from domain.aigc import EngineRegistry

        registry = EngineRegistry()
        registry.auto_discover()

        engines = registry.list_engines()
        print(f"✅ 发现 {len(engines)} 个引擎:")
        for engine in engines:
            print(f" - {engine['id']}: {engine['name']} (v{engine['version']})")

        missing = []
        for engine_id in EXPECTED_ENGINE_IDS:
            if registry.has(engine_id):
                print(f"✅ 引擎 {engine_id} 已注册")
            else:
                print(f"❌ 引擎 {engine_id} 未注册")
                missing.append(engine_id)
        # A missing engine must fail the check, not just be printed.
        return not missing
    except Exception as e:
        _report_failure(e)
        return False


def test_component_factory() -> bool:
    """Build the shared components and check every required one is present.

    Returns:
        True if all components in REQUIRED_COMPONENTS are created (non-None).
    """
    _print_header("测试组件工厂")
    try:
        from domain.aigc.shared import ComponentFactory

        factory = ComponentFactory()
        print(f"✅ 项目根目录: {factory.project_root}")
        print(f"✅ 路径配置: {list(factory.paths.keys())}")

        components = factory.create_components()
        print(f"✅ 创建组件: {list(components.keys())}")

        missing = _missing_components(components)
        if missing:
            for message in missing:
                print(f"❌ 测试失败: {message}")
            return False

        print("✅ 所有组件创建成功")
        return True
    except Exception as e:
        _report_failure(e)
        return False


def test_engine_param_schema() -> bool:
    """Print the parameter schema of each expected engine.

    Returns:
        True if engine info was found for every ID in EXPECTED_ENGINE_IDS.
    """
    _print_header("测试引擎参数 Schema")
    try:
        from domain.aigc import EngineRegistry

        registry = EngineRegistry()
        registry.auto_discover()

        all_found = True
        for engine_id in EXPECTED_ENGINE_IDS:
            info = registry.get_engine_info(engine_id)
            if not info:
                print(f"❌ 未找到引擎 {engine_id} 的信息")
                all_found = False
                continue

            print(f"\n📋 {engine_id} 参数:")
            # `or {}` also covers an explicit None value, not just a missing key.
            schema = info.get('param_schema') or {}
            for param_name, param_def in schema.items():
                required = "必填" if param_def.get('required') else "可选"
                param_type = param_def.get('type', 'any')
                desc = param_def.get('desc', '')
                print(f" - {param_name} ({param_type}, {required}): {desc}")
        return all_found
    except Exception as e:
        _report_failure(e)
        return False


def test_param_validation() -> bool:
    """Check that the poster engine accepts valid params and rejects incomplete ones.

    Returns:
        True only if the valid parameter set is accepted AND the set missing
        'images' is rejected.
    """
    _print_header("测试参数验证")
    try:
        from domain.aigc import EngineRegistry
        from domain.aigc.shared import ComponentFactory

        registry = EngineRegistry()
        factory = ComponentFactory()
        components = factory.create_components()
        # Shared components are registered before discovery; order kept as-is,
        # presumably so discovered engines receive them — confirm in EngineRegistry.
        registry.set_shared_components(components)
        registry.auto_discover()

        engine = registry.get('poster_generate')
        # NOTE(review): assumes get() returns None for an unknown ID; if it
        # raises instead, the except below reports it.
        if engine is None:
            print("❌ 测试失败: 引擎 poster_generate 未注册")
            return False

        # Valid parameters must be accepted.
        valid_params = {
            'template_id': 'vibrant',
            'images': ['base64_image_data'],
        }
        is_valid, error = engine.validate_params(valid_params)
        valid_accepted = bool(is_valid)
        if valid_accepted:
            print(f"✅ 有效参数验证: {is_valid}")
        else:
            print(f"❌ 有效参数验证: {is_valid}, 错误: {error}")

        # Parameters missing the required 'images' key must be rejected.
        invalid_params = {
            'template_id': 'vibrant',
        }
        is_valid, error = engine.validate_params(invalid_params)
        invalid_rejected = not is_valid
        mark = "✅" if invalid_rejected else "❌"
        print(f"{mark} 无效参数验证: {invalid_rejected}, 错误: {error}")

        return valid_accepted and invalid_rejected
    except Exception as e:
        _report_failure(e)
        return False


def test_api_router() -> bool:
    """Import the AIGC API router and print its prefix and routes.

    Returns:
        True if the router imports and its routes can be listed.
    """
    _print_header("测试 API 路由")
    try:
        from api.routers import aigc

        routes = aigc.router.routes
        print(f"✅ 路由前缀: {aigc.router.prefix}")
        print(f"✅ 路由数量: {len(routes)}")
        print("✅ 路由列表:")
        for route in routes:
            # Not every route type exposes methods/path, so fall back to empties.
            methods = getattr(route, 'methods', set())
            path = getattr(route, 'path', '')
            print(f" - {methods} {path}")
        return True
    except Exception as e:
        _report_failure(e)
        return False


def main() -> bool:
    """Run every check in order, print a summary, and return True if all passed."""
    _print_header(" AIGC 引擎系统测试", width=60)

    checks = (
        ("模块导入", test_imports),
        ("引擎注册表", test_engine_registry),
        ("组件工厂", test_component_factory),
        ("参数 Schema", test_engine_param_schema),
        ("参数验证", test_param_validation),
        ("API 路由", test_api_router),
    )
    results = [(name, check()) for name, check in checks]

    _print_header(" 测试结果汇总", width=60)
    for name, ok in results:
        status = "✅ 通过" if ok else "❌ 失败"
        print(f" {name}: {status}")

    passed, failed = _count_results(results)
    print(f"\n 总计: {passed} 通过, {failed} 失败")
    print("=" * 60)
    return failed == 0


if __name__ == "__main__":
    sys.exit(0 if main() else 1)