566 lines
18 KiB
Python
566 lines
18 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
|
||
"""
|
||
V2 API 端到端测试
|
||
|
||
模拟 Java 端调用 Python V2 接口的完整流程
|
||
用于验证 Python 端改造完成后的功能正确性
|
||
|
||
使用方法:
|
||
# 启动服务后测试
|
||
python tests/test_v2_api_e2e.py --server http://localhost:8000
|
||
|
||
# 直接测试 (不启动服务)
|
||
python tests/test_v2_api_e2e.py --direct
|
||
"""
|
||
|
||
import asyncio
|
||
import json
|
||
import argparse
|
||
import sys
|
||
from typing import Dict, Any, Optional
|
||
from datetime import datetime
|
||
|
||
# ========== 测试数据 (模拟 Java 端传入的完整对象) ==========
|
||
|
||
SAMPLE_SCENIC_SPOT = {
|
||
"id": 1,
|
||
"name": "天津冒险湾",
|
||
"description": "天津最大的水上乐园,拥有多种刺激水上项目和亲子设施,包括水上过山车、漂流河、儿童戏水区等",
|
||
"address": "天津市滨海新区海滨大道",
|
||
"location": "天津市",
|
||
"traffic_info": "地铁9号线直达,自驾可走津滨高速",
|
||
"highlights": ["水上过山车", "儿童戏水区", "漂流河", "海浪池"],
|
||
"opening_hours": "09:00-18:00 (夏季延长至21:00)",
|
||
"ticket_info": "成人票 199 元,儿童票 99 元",
|
||
"tips": "建议自带泳衣,园区内也有售卖;建议避开周末高峰"
|
||
}
|
||
|
||
SAMPLE_PRODUCT = {
|
||
"id": 10,
|
||
"name": "家庭套票",
|
||
"price": 299,
|
||
"original_price": 399,
|
||
"description": "含2大1小门票,赠送储物柜一个",
|
||
"includes": ["2张成人票", "1张儿童票", "储物柜1个"],
|
||
"valid_period": "2025-01-01 至 2025-03-31",
|
||
"usage_rules": "需提前1天预约,入园当日有效"
|
||
}
|
||
|
||
SAMPLE_STYLE = {
|
||
"id": "gonglue",
|
||
"name": "攻略风"
|
||
}
|
||
|
||
SAMPLE_AUDIENCE = {
|
||
"id": "qinzi",
|
||
"name": "亲子向"
|
||
}
|
||
|
||
|
||
def print_section(title: str):
|
||
"""打印分隔标题"""
|
||
print("\n" + "=" * 60)
|
||
print(f" {title}")
|
||
print("=" * 60)
|
||
|
||
|
||
def print_json(data: Any, title: str = None):
|
||
"""格式化打印 JSON"""
|
||
if title:
|
||
print(f"\n{title}:")
|
||
print(json.dumps(data, ensure_ascii=False, indent=2))
|
||
|
||
|
||
# ========== 直接测试 (不启动服务) ==========
|
||
|
||
async def test_direct_config_api():
|
||
"""测试配置查询 API (直接调用)"""
|
||
print_section("测试配置查询 API")
|
||
|
||
from api.routers.aigc import get_styles, get_audiences, get_all_config
|
||
|
||
# 测试风格列表
|
||
print("\n1. 获取风格列表...")
|
||
styles = await get_styles()
|
||
print(f" ✅ 获取到 {styles['count']} 个风格:")
|
||
for s in styles['styles']:
|
||
print(f" {s['icon']} {s['id']}: {s['name']}")
|
||
|
||
# 测试人群列表
|
||
print("\n2. 获取人群列表...")
|
||
audiences = await get_audiences()
|
||
print(f" ✅ 获取到 {audiences['count']} 个人群:")
|
||
for a in audiences['audiences']:
|
||
print(f" {a['icon']} {a['id']}: {a['name']}")
|
||
|
||
# 测试全部配置
|
||
print("\n3. 获取全部配置...")
|
||
all_config = await get_all_config()
|
||
print(f" ✅ 风格: {all_config['styles_count']}, 人群: {all_config['audiences_count']}")
|
||
|
||
return True
|
||
|
||
|
||
async def test_direct_topic_generate():
|
||
"""测试选题生成 (直接调用引擎)"""
|
||
print_section("测试选题生成引擎")
|
||
|
||
from domain.aigc import EngineRegistry
|
||
from domain.aigc.shared import ComponentFactory
|
||
|
||
# 初始化
|
||
print("\n1. 初始化引擎...")
|
||
registry = EngineRegistry()
|
||
factory = ComponentFactory()
|
||
|
||
# 创建共享组件 (使用 Mock LLM)
|
||
components = factory.create_components()
|
||
registry.set_shared_components(components)
|
||
registry.auto_discover()
|
||
|
||
# 获取引擎
|
||
engine = registry.get('topic_generate')
|
||
if not engine:
|
||
print(" ❌ 引擎不存在")
|
||
return False
|
||
print(f" ✅ 引擎已加载: {engine.engine_name} v{engine.version}")
|
||
|
||
# 构建请求参数 (模拟 Java 端传入)
|
||
params = {
|
||
"month": "2025-01",
|
||
"num_topics": 3,
|
||
"scenic_spot": SAMPLE_SCENIC_SPOT,
|
||
"product": SAMPLE_PRODUCT,
|
||
"style": SAMPLE_STYLE,
|
||
"audience": SAMPLE_AUDIENCE,
|
||
"prompt_version": "latest"
|
||
}
|
||
|
||
print("\n2. 请求参数:")
|
||
print(f" 月份: {params['month']}")
|
||
print(f" 数量: {params['num_topics']}")
|
||
print(f" 景区: {params['scenic_spot']['name']}")
|
||
print(f" 产品: {params['product']['name']}")
|
||
print(f" 风格: {params['style']['name']}")
|
||
print(f" 人群: {params['audience']['name']}")
|
||
|
||
# 验证参数
|
||
print("\n3. 验证参数...")
|
||
valid, error = engine.validate_params(params)
|
||
if not valid:
|
||
print(f" ❌ 参数验证失败: {error}")
|
||
return False
|
||
print(" ✅ 参数验证通过")
|
||
|
||
# 测试 Prompt 渲染
|
||
print("\n4. 测试 Prompt 渲染...")
|
||
from domain.prompt import PromptRegistry
|
||
prompt_registry = PromptRegistry('prompts')
|
||
|
||
context = {
|
||
'num_topics': params['num_topics'],
|
||
'month': params['month'],
|
||
'scenic_spot': params['scenic_spot'],
|
||
'product': params['product'],
|
||
'style': params['style'],
|
||
'audience': params['audience'],
|
||
'styles_list': '',
|
||
'audiences_list': '',
|
||
}
|
||
|
||
system_prompt, user_prompt = prompt_registry.render('topic_generate', context)
|
||
print(f" System Prompt 长度: {len(system_prompt)} 字符")
|
||
print(f" User Prompt 长度: {len(user_prompt)} 字符")
|
||
print(f" User Prompt 预览:\n {user_prompt[:300]}...")
|
||
|
||
# 验证关键信息是否在 prompt 中
|
||
assert SAMPLE_SCENIC_SPOT['name'] in user_prompt, "景区名称未出现在 prompt 中"
|
||
assert SAMPLE_PRODUCT['name'] in user_prompt, "产品名称未出现在 prompt 中"
|
||
print(" ✅ Prompt 渲染正确,包含所有关键信息")
|
||
|
||
return True
|
||
|
||
|
||
async def test_direct_content_generate():
|
||
"""测试内容生成 (直接调用引擎)"""
|
||
print_section("测试内容生成引擎")
|
||
|
||
from domain.aigc import EngineRegistry
|
||
from domain.aigc.shared import ComponentFactory
|
||
|
||
# 初始化
|
||
print("\n1. 初始化引擎...")
|
||
registry = EngineRegistry()
|
||
factory = ComponentFactory()
|
||
components = factory.create_components()
|
||
registry.set_shared_components(components)
|
||
registry.auto_discover()
|
||
|
||
engine = registry.get('content_generate')
|
||
if not engine:
|
||
print(" ❌ 引擎不存在")
|
||
return False
|
||
print(f" ✅ 引擎已加载: {engine.engine_name} v{engine.version}")
|
||
|
||
# 构建请求参数
|
||
params = {
|
||
"topic": {
|
||
"index": 1,
|
||
"date": "2025-01-15",
|
||
"title": "寒假遛娃好去处!天津冒险湾家庭套票超值体验",
|
||
"object": "天津冒险湾",
|
||
"product": "家庭套票",
|
||
"style": "攻略风",
|
||
"targetAudience": "亲子向",
|
||
"logic": "寒假期间家庭出游需求旺盛,水上乐园是热门选择"
|
||
},
|
||
"scenic_spot": SAMPLE_SCENIC_SPOT,
|
||
"product": SAMPLE_PRODUCT,
|
||
"style": SAMPLE_STYLE,
|
||
"audience": SAMPLE_AUDIENCE,
|
||
"enable_judge": False, # 跳过审核以加快测试
|
||
"prompt_version": "latest"
|
||
}
|
||
|
||
print("\n2. 请求参数:")
|
||
print(f" 选题: {params['topic']['title']}")
|
||
print(f" 景区: {params['scenic_spot']['name']}")
|
||
print(f" 产品: {params['product']['name']}")
|
||
|
||
# 验证参数
|
||
print("\n3. 验证参数...")
|
||
valid, error = engine.validate_params(params)
|
||
if not valid:
|
||
print(f" ❌ 参数验证失败: {error}")
|
||
return False
|
||
print(" ✅ 参数验证通过")
|
||
|
||
# 测试 Prompt 渲染
|
||
print("\n4. 测试 Prompt 渲染...")
|
||
from domain.prompt import PromptRegistry
|
||
prompt_registry = PromptRegistry('prompts')
|
||
|
||
context = {
|
||
'style_content': f"{params['style']['name']}",
|
||
'demand_content': f"{params['audience']['name']}",
|
||
'object_content': f"{params['scenic_spot']['name']}\n{params['scenic_spot']['description']}",
|
||
'product_content': f"{params['product']['name']}\n价格: {params['product']['price']}",
|
||
'refer_content': '',
|
||
}
|
||
|
||
system_prompt, user_prompt = prompt_registry.render('content_generate', context)
|
||
print(f" System Prompt 长度: {len(system_prompt)} 字符")
|
||
print(f" User Prompt 长度: {len(user_prompt)} 字符")
|
||
print(" ✅ Prompt 渲染成功")
|
||
|
||
return True
|
||
|
||
|
||
async def test_direct_execute_api():
|
||
"""测试执行 API (直接调用)"""
|
||
print_section("测试执行 API")
|
||
|
||
from api.routers.aigc import execute_engine, ExecuteRequest
|
||
|
||
# 构建请求
|
||
request = ExecuteRequest(
|
||
engine="topic_generate",
|
||
params={
|
||
"month": "2025-01",
|
||
"num_topics": 2,
|
||
"scenic_spot": SAMPLE_SCENIC_SPOT,
|
||
"product": SAMPLE_PRODUCT,
|
||
"style": SAMPLE_STYLE,
|
||
"audience": SAMPLE_AUDIENCE,
|
||
},
|
||
async_mode=False
|
||
)
|
||
|
||
print("\n请求体:")
|
||
print_json({
|
||
"engine": request.engine,
|
||
"params": request.params,
|
||
"async_mode": request.async_mode
|
||
})
|
||
|
||
print("\n执行中... (需要 LLM 调用,可能需要等待)")
|
||
print("(如果没有配置 LLM,会返回错误,这是预期行为)")
|
||
|
||
try:
|
||
response = await execute_engine(request)
|
||
print("\n响应:")
|
||
print_json({
|
||
"success": response.success,
|
||
"data": response.data,
|
||
"error": response.error
|
||
})
|
||
return response.success
|
||
except Exception as e:
|
||
print(f"\n⚠️ 执行失败 (预期行为,因为没有真实 LLM): {e}")
|
||
return True # 这是预期的,因为没有配置真实 LLM
|
||
|
||
|
||
# ========== HTTP 测试 (需要启动服务) ==========
|
||
|
||
async def test_http_api(base_url: str):
|
||
"""通过 HTTP 测试 API"""
|
||
import aiohttp
|
||
|
||
print_section(f"HTTP API 测试 ({base_url})")
|
||
|
||
async with aiohttp.ClientSession() as session:
|
||
# 1. 测试健康检查
|
||
print("\n1. 健康检查...")
|
||
try:
|
||
async with session.get(f"{base_url}/") as resp:
|
||
if resp.status == 200:
|
||
print(" ✅ 服务正常")
|
||
else:
|
||
print(f" ❌ 服务异常: {resp.status}")
|
||
return False
|
||
except Exception as e:
|
||
print(f" ❌ 连接失败: {e}")
|
||
return False
|
||
|
||
# 2. 测试配置 API
|
||
print("\n2. 获取风格配置...")
|
||
async with session.get(f"{base_url}/api/v2/aigc/config/styles") as resp:
|
||
data = await resp.json()
|
||
print(f" ✅ 获取到 {data['count']} 个风格")
|
||
|
||
print("\n3. 获取人群配置...")
|
||
async with session.get(f"{base_url}/api/v2/aigc/config/audiences") as resp:
|
||
data = await resp.json()
|
||
print(f" ✅ 获取到 {data['count']} 个人群")
|
||
|
||
# 3. 测试引擎列表
|
||
print("\n4. 获取引擎列表...")
|
||
async with session.get(f"{base_url}/api/v2/aigc/engines") as resp:
|
||
data = await resp.json()
|
||
print(f" ✅ 获取到 {data['count']} 个引擎:")
|
||
for e in data['engines']:
|
||
print(f" - {e['id']}: {e['name']}")
|
||
|
||
# 4. 测试执行 API
|
||
print("\n5. 测试选题生成...")
|
||
request_body = {
|
||
"engine": "topic_generate",
|
||
"params": {
|
||
"month": "2025-01",
|
||
"num_topics": 2,
|
||
"scenic_spot": SAMPLE_SCENIC_SPOT,
|
||
"product": SAMPLE_PRODUCT,
|
||
"style": SAMPLE_STYLE,
|
||
"audience": SAMPLE_AUDIENCE,
|
||
},
|
||
"async_mode": False
|
||
}
|
||
|
||
async with session.post(
|
||
f"{base_url}/api/v2/aigc/execute",
|
||
json=request_body,
|
||
timeout=aiohttp.ClientTimeout(total=120)
|
||
) as resp:
|
||
data = await resp.json()
|
||
if data.get('success'):
|
||
print(" ✅ 选题生成成功!")
|
||
print_json(data['data'], " 生成结果")
|
||
else:
|
||
print(f" ⚠️ 生成失败: {data.get('error')}")
|
||
|
||
return True
|
||
|
||
|
||
# ========== 生成 Java 端调用示例 ==========
|
||
|
||
def generate_java_example():
|
||
"""生成 Java 端调用示例代码"""
|
||
print_section("Java 端调用示例")
|
||
|
||
example = '''
|
||
// ========== Java 端调用 Python V2 API 示例 ==========
|
||
|
||
// 1. 获取风格/人群配置 (启动时缓存)
|
||
@Service
|
||
public class AIGCConfigService {
|
||
|
||
@Autowired
|
||
private ExternalServiceClient pythonClient;
|
||
|
||
public List<StyleConfig> getStyles() {
|
||
Map<String, Object> response = pythonClient.get(
|
||
"content-generate",
|
||
"/api/v2/aigc/config/styles",
|
||
Map.class
|
||
);
|
||
return parseStyles(response);
|
||
}
|
||
|
||
public List<AudienceConfig> getAudiences() {
|
||
Map<String, Object> response = pythonClient.get(
|
||
"content-generate",
|
||
"/api/v2/aigc/config/audiences",
|
||
Map.class
|
||
);
|
||
return parseAudiences(response);
|
||
}
|
||
}
|
||
|
||
// 2. 调用选题生成
|
||
public TopicGenerateResponse generateTopics(TopicGenerateRequest request) {
|
||
// 从数据库查询完整对象
|
||
ScenicSpot scenicSpot = scenicSpotService.getById(request.getScenicSpotId());
|
||
Product product = productService.getById(request.getProductId());
|
||
|
||
// 构建 V2 请求
|
||
Map<String, Object> requestBody = new HashMap<>();
|
||
requestBody.put("engine", "topic_generate");
|
||
requestBody.put("params", Map.of(
|
||
"month", request.getMonth(),
|
||
"num_topics", request.getNumTopics(),
|
||
"scenic_spot", Map.of(
|
||
"id", scenicSpot.getId(),
|
||
"name", scenicSpot.getName(),
|
||
"description", scenicSpot.getDescription(),
|
||
"address", scenicSpot.getAddress(),
|
||
"highlights", scenicSpot.getHighlights()
|
||
),
|
||
"product", Map.of(
|
||
"id", product.getId(),
|
||
"name", product.getName(),
|
||
"price", product.getPrice(),
|
||
"description", product.getDescription()
|
||
),
|
||
"style", Map.of(
|
||
"id", request.getStyleId(),
|
||
"name", styleConfigService.getById(request.getStyleId()).getName()
|
||
),
|
||
"audience", Map.of(
|
||
"id", request.getAudienceId(),
|
||
"name", audienceConfigService.getById(request.getAudienceId()).getName()
|
||
)
|
||
));
|
||
requestBody.put("async_mode", false);
|
||
|
||
// 调用 Python V2 API
|
||
Map<String, Object> response = pythonClient.post(
|
||
"content-generate",
|
||
"/api/v2/aigc/execute",
|
||
requestBody,
|
||
Map.class
|
||
);
|
||
|
||
// 解析响应
|
||
if ((Boolean) response.get("success")) {
|
||
Map<String, Object> data = (Map<String, Object>) response.get("data");
|
||
List<Map<String, Object>> topics = (List<Map<String, Object>>) data.get("topics");
|
||
return TopicGenerateResponse.fromTopics(topics);
|
||
} else {
|
||
throw new BusinessException(ErrorCode.OPERATION_ERROR,
|
||
(String) response.get("error"));
|
||
}
|
||
}
|
||
|
||
// 3. 调用内容生成
|
||
public ContentGenerateResponse generateContent(ContentGenerateRequest request) {
|
||
// 从数据库查询完整对象
|
||
ScenicSpot scenicSpot = scenicSpotService.getById(request.getScenicSpotId());
|
||
Product product = productService.getById(request.getProductId());
|
||
|
||
Map<String, Object> requestBody = new HashMap<>();
|
||
requestBody.put("engine", "content_generate");
|
||
requestBody.put("params", Map.of(
|
||
"topic", request.getTopic(), // 选题信息
|
||
"scenic_spot", convertToMap(scenicSpot),
|
||
"product", convertToMap(product),
|
||
"style", Map.of("id", request.getStyleId(), "name", "攻略风"),
|
||
"audience", Map.of("id", request.getAudienceId(), "name", "亲子向"),
|
||
"enable_judge", true
|
||
));
|
||
|
||
return pythonClient.post(
|
||
"content-generate",
|
||
"/api/v2/aigc/execute",
|
||
requestBody,
|
||
ContentGenerateResponse.class
|
||
);
|
||
}
|
||
'''
|
||
print(example)
|
||
|
||
|
||
# ========== 主函数 ==========
|
||
|
||
async def main():
|
||
parser = argparse.ArgumentParser(description='V2 API 端到端测试')
|
||
parser.add_argument('--server', type=str, help='服务器地址 (如 http://localhost:8000)')
|
||
parser.add_argument('--direct', action='store_true', help='直接测试 (不启动服务)')
|
||
parser.add_argument('--java-example', action='store_true', help='生成 Java 调用示例')
|
||
args = parser.parse_args()
|
||
|
||
print("\n" + "=" * 60)
|
||
print(" TravelContentCreator V2 API 端到端测试")
|
||
print("=" * 60)
|
||
print(f" 时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
|
||
|
||
results = []
|
||
|
||
if args.java_example:
|
||
generate_java_example()
|
||
return
|
||
|
||
if args.server:
|
||
# HTTP 测试
|
||
success = await test_http_api(args.server)
|
||
results.append(("HTTP API 测试", success))
|
||
else:
|
||
# 直接测试
|
||
print("\n模式: 直接测试 (不启动服务)")
|
||
|
||
try:
|
||
success = await test_direct_config_api()
|
||
results.append(("配置查询 API", success))
|
||
except Exception as e:
|
||
print(f"❌ 配置查询 API 测试失败: {e}")
|
||
results.append(("配置查询 API", False))
|
||
|
||
try:
|
||
success = await test_direct_topic_generate()
|
||
results.append(("选题生成引擎", success))
|
||
except Exception as e:
|
||
print(f"❌ 选题生成引擎测试失败: {e}")
|
||
results.append(("选题生成引擎", False))
|
||
|
||
try:
|
||
success = await test_direct_content_generate()
|
||
results.append(("内容生成引擎", success))
|
||
except Exception as e:
|
||
print(f"❌ 内容生成引擎测试失败: {e}")
|
||
results.append(("内容生成引擎", False))
|
||
|
||
# 打印结果汇总
|
||
print_section("测试结果汇总")
|
||
|
||
passed = 0
|
||
for name, success in results:
|
||
status = "✅" if success else "❌"
|
||
print(f" {status} {name}")
|
||
if success:
|
||
passed += 1
|
||
|
||
print(f"\n总计: {passed}/{len(results)} 通过")
|
||
|
||
# 生成 Java 示例
|
||
if args.direct or args.server:
|
||
print("\n" + "-" * 60)
|
||
print("提示: 运行 --java-example 查看 Java 端调用示例")
|
||
|
||
return passed == len(results)
|
||
|
||
|
||
if __name__ == "__main__":
|
||
success = asyncio.run(main())
|
||
sys.exit(0 if success else 1)
|