TravelContentCreator/tests/test_v2_api_e2e.py

566 lines
18 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
V2 API 端到端测试
模拟 Java 端调用 Python V2 接口的完整流程
用于验证 Python 端改造完成后的功能正确性
使用方法:
# 启动服务后测试
python tests/test_v2_api_e2e.py --server http://localhost:8000
# 直接测试 (不启动服务)
python tests/test_v2_api_e2e.py --direct
"""
import asyncio
import json
import argparse
import sys
from typing import Dict, Any, Optional
from datetime import datetime
# ========== 测试数据 (模拟 Java 端传入的完整对象) ==========
SAMPLE_SCENIC_SPOT = {
"id": 1,
"name": "天津冒险湾",
"description": "天津最大的水上乐园,拥有多种刺激水上项目和亲子设施,包括水上过山车、漂流河、儿童戏水区等",
"address": "天津市滨海新区海滨大道",
"location": "天津市",
"traffic_info": "地铁9号线直达自驾可走津滨高速",
"highlights": ["水上过山车", "儿童戏水区", "漂流河", "海浪池"],
"opening_hours": "09:00-18:00 (夏季延长至21:00)",
"ticket_info": "成人票 199 元,儿童票 99 元",
"tips": "建议自带泳衣,园区内也有售卖;建议避开周末高峰"
}
SAMPLE_PRODUCT = {
"id": 10,
"name": "家庭套票",
"price": 299,
"original_price": 399,
"description": "含2大1小门票赠送储物柜一个",
"includes": ["2张成人票", "1张儿童票", "储物柜1个"],
"valid_period": "2025-01-01 至 2025-03-31",
"usage_rules": "需提前1天预约入园当日有效"
}
SAMPLE_STYLE = {
"id": "gonglue",
"name": "攻略风"
}
SAMPLE_AUDIENCE = {
"id": "qinzi",
"name": "亲子向"
}
def print_section(title: str):
"""打印分隔标题"""
print("\n" + "=" * 60)
print(f" {title}")
print("=" * 60)
def print_json(data: Any, title: str = None):
"""格式化打印 JSON"""
if title:
print(f"\n{title}:")
print(json.dumps(data, ensure_ascii=False, indent=2))
# ========== 直接测试 (不启动服务) ==========
async def test_direct_config_api():
"""测试配置查询 API (直接调用)"""
print_section("测试配置查询 API")
from api.routers.aigc import get_styles, get_audiences, get_all_config
# 测试风格列表
print("\n1. 获取风格列表...")
styles = await get_styles()
print(f" ✅ 获取到 {styles['count']} 个风格:")
for s in styles['styles']:
print(f" {s['icon']} {s['id']}: {s['name']}")
# 测试人群列表
print("\n2. 获取人群列表...")
audiences = await get_audiences()
print(f" ✅ 获取到 {audiences['count']} 个人群:")
for a in audiences['audiences']:
print(f" {a['icon']} {a['id']}: {a['name']}")
# 测试全部配置
print("\n3. 获取全部配置...")
all_config = await get_all_config()
print(f" ✅ 风格: {all_config['styles_count']}, 人群: {all_config['audiences_count']}")
return True
async def test_direct_topic_generate():
"""测试选题生成 (直接调用引擎)"""
print_section("测试选题生成引擎")
from domain.aigc import EngineRegistry
from domain.aigc.shared import ComponentFactory
# 初始化
print("\n1. 初始化引擎...")
registry = EngineRegistry()
factory = ComponentFactory()
# 创建共享组件 (使用 Mock LLM)
components = factory.create_components()
registry.set_shared_components(components)
registry.auto_discover()
# 获取引擎
engine = registry.get('topic_generate')
if not engine:
print(" ❌ 引擎不存在")
return False
print(f" ✅ 引擎已加载: {engine.engine_name} v{engine.version}")
# 构建请求参数 (模拟 Java 端传入)
params = {
"month": "2025-01",
"num_topics": 3,
"scenic_spot": SAMPLE_SCENIC_SPOT,
"product": SAMPLE_PRODUCT,
"style": SAMPLE_STYLE,
"audience": SAMPLE_AUDIENCE,
"prompt_version": "latest"
}
print("\n2. 请求参数:")
print(f" 月份: {params['month']}")
print(f" 数量: {params['num_topics']}")
print(f" 景区: {params['scenic_spot']['name']}")
print(f" 产品: {params['product']['name']}")
print(f" 风格: {params['style']['name']}")
print(f" 人群: {params['audience']['name']}")
# 验证参数
print("\n3. 验证参数...")
valid, error = engine.validate_params(params)
if not valid:
print(f" ❌ 参数验证失败: {error}")
return False
print(" ✅ 参数验证通过")
# 测试 Prompt 渲染
print("\n4. 测试 Prompt 渲染...")
from domain.prompt import PromptRegistry
prompt_registry = PromptRegistry('prompts')
context = {
'num_topics': params['num_topics'],
'month': params['month'],
'scenic_spot': params['scenic_spot'],
'product': params['product'],
'style': params['style'],
'audience': params['audience'],
'styles_list': '',
'audiences_list': '',
}
system_prompt, user_prompt = prompt_registry.render('topic_generate', context)
print(f" System Prompt 长度: {len(system_prompt)} 字符")
print(f" User Prompt 长度: {len(user_prompt)} 字符")
print(f" User Prompt 预览:\n {user_prompt[:300]}...")
# 验证关键信息是否在 prompt 中
assert SAMPLE_SCENIC_SPOT['name'] in user_prompt, "景区名称未出现在 prompt 中"
assert SAMPLE_PRODUCT['name'] in user_prompt, "产品名称未出现在 prompt 中"
print(" ✅ Prompt 渲染正确,包含所有关键信息")
return True
async def test_direct_content_generate():
"""测试内容生成 (直接调用引擎)"""
print_section("测试内容生成引擎")
from domain.aigc import EngineRegistry
from domain.aigc.shared import ComponentFactory
# 初始化
print("\n1. 初始化引擎...")
registry = EngineRegistry()
factory = ComponentFactory()
components = factory.create_components()
registry.set_shared_components(components)
registry.auto_discover()
engine = registry.get('content_generate')
if not engine:
print(" ❌ 引擎不存在")
return False
print(f" ✅ 引擎已加载: {engine.engine_name} v{engine.version}")
# 构建请求参数
params = {
"topic": {
"index": 1,
"date": "2025-01-15",
"title": "寒假遛娃好去处!天津冒险湾家庭套票超值体验",
"object": "天津冒险湾",
"product": "家庭套票",
"style": "攻略风",
"targetAudience": "亲子向",
"logic": "寒假期间家庭出游需求旺盛,水上乐园是热门选择"
},
"scenic_spot": SAMPLE_SCENIC_SPOT,
"product": SAMPLE_PRODUCT,
"style": SAMPLE_STYLE,
"audience": SAMPLE_AUDIENCE,
"enable_judge": False, # 跳过审核以加快测试
"prompt_version": "latest"
}
print("\n2. 请求参数:")
print(f" 选题: {params['topic']['title']}")
print(f" 景区: {params['scenic_spot']['name']}")
print(f" 产品: {params['product']['name']}")
# 验证参数
print("\n3. 验证参数...")
valid, error = engine.validate_params(params)
if not valid:
print(f" ❌ 参数验证失败: {error}")
return False
print(" ✅ 参数验证通过")
# 测试 Prompt 渲染
print("\n4. 测试 Prompt 渲染...")
from domain.prompt import PromptRegistry
prompt_registry = PromptRegistry('prompts')
context = {
'style_content': f"{params['style']['name']}",
'demand_content': f"{params['audience']['name']}",
'object_content': f"{params['scenic_spot']['name']}\n{params['scenic_spot']['description']}",
'product_content': f"{params['product']['name']}\n价格: {params['product']['price']}",
'refer_content': '',
}
system_prompt, user_prompt = prompt_registry.render('content_generate', context)
print(f" System Prompt 长度: {len(system_prompt)} 字符")
print(f" User Prompt 长度: {len(user_prompt)} 字符")
print(" ✅ Prompt 渲染成功")
return True
async def test_direct_execute_api():
"""测试执行 API (直接调用)"""
print_section("测试执行 API")
from api.routers.aigc import execute_engine, ExecuteRequest
# 构建请求
request = ExecuteRequest(
engine="topic_generate",
params={
"month": "2025-01",
"num_topics": 2,
"scenic_spot": SAMPLE_SCENIC_SPOT,
"product": SAMPLE_PRODUCT,
"style": SAMPLE_STYLE,
"audience": SAMPLE_AUDIENCE,
},
async_mode=False
)
print("\n请求体:")
print_json({
"engine": request.engine,
"params": request.params,
"async_mode": request.async_mode
})
print("\n执行中... (需要 LLM 调用,可能需要等待)")
print("(如果没有配置 LLM会返回错误这是预期行为)")
try:
response = await execute_engine(request)
print("\n响应:")
print_json({
"success": response.success,
"data": response.data,
"error": response.error
})
return response.success
except Exception as e:
print(f"\n⚠️ 执行失败 (预期行为,因为没有真实 LLM): {e}")
return True # 这是预期的,因为没有配置真实 LLM
# ========== HTTP 测试 (需要启动服务) ==========
async def test_http_api(base_url: str):
"""通过 HTTP 测试 API"""
import aiohttp
print_section(f"HTTP API 测试 ({base_url})")
async with aiohttp.ClientSession() as session:
# 1. 测试健康检查
print("\n1. 健康检查...")
try:
async with session.get(f"{base_url}/") as resp:
if resp.status == 200:
print(" ✅ 服务正常")
else:
print(f" ❌ 服务异常: {resp.status}")
return False
except Exception as e:
print(f" ❌ 连接失败: {e}")
return False
# 2. 测试配置 API
print("\n2. 获取风格配置...")
async with session.get(f"{base_url}/api/v2/aigc/config/styles") as resp:
data = await resp.json()
print(f" ✅ 获取到 {data['count']} 个风格")
print("\n3. 获取人群配置...")
async with session.get(f"{base_url}/api/v2/aigc/config/audiences") as resp:
data = await resp.json()
print(f" ✅ 获取到 {data['count']} 个人群")
# 3. 测试引擎列表
print("\n4. 获取引擎列表...")
async with session.get(f"{base_url}/api/v2/aigc/engines") as resp:
data = await resp.json()
print(f" ✅ 获取到 {data['count']} 个引擎:")
for e in data['engines']:
print(f" - {e['id']}: {e['name']}")
# 4. 测试执行 API
print("\n5. 测试选题生成...")
request_body = {
"engine": "topic_generate",
"params": {
"month": "2025-01",
"num_topics": 2,
"scenic_spot": SAMPLE_SCENIC_SPOT,
"product": SAMPLE_PRODUCT,
"style": SAMPLE_STYLE,
"audience": SAMPLE_AUDIENCE,
},
"async_mode": False
}
async with session.post(
f"{base_url}/api/v2/aigc/execute",
json=request_body,
timeout=aiohttp.ClientTimeout(total=120)
) as resp:
data = await resp.json()
if data.get('success'):
print(" ✅ 选题生成成功!")
print_json(data['data'], " 生成结果")
else:
print(f" ⚠️ 生成失败: {data.get('error')}")
return True
# ========== 生成 Java 端调用示例 ==========
def generate_java_example():
"""生成 Java 端调用示例代码"""
print_section("Java 端调用示例")
example = '''
// ========== Java 端调用 Python V2 API 示例 ==========
// 1. 获取风格/人群配置 (启动时缓存)
@Service
public class AIGCConfigService {
@Autowired
private ExternalServiceClient pythonClient;
public List<StyleConfig> getStyles() {
Map<String, Object> response = pythonClient.get(
"content-generate",
"/api/v2/aigc/config/styles",
Map.class
);
return parseStyles(response);
}
public List<AudienceConfig> getAudiences() {
Map<String, Object> response = pythonClient.get(
"content-generate",
"/api/v2/aigc/config/audiences",
Map.class
);
return parseAudiences(response);
}
}
// 2. 调用选题生成
public TopicGenerateResponse generateTopics(TopicGenerateRequest request) {
// 从数据库查询完整对象
ScenicSpot scenicSpot = scenicSpotService.getById(request.getScenicSpotId());
Product product = productService.getById(request.getProductId());
// 构建 V2 请求
Map<String, Object> requestBody = new HashMap<>();
requestBody.put("engine", "topic_generate");
requestBody.put("params", Map.of(
"month", request.getMonth(),
"num_topics", request.getNumTopics(),
"scenic_spot", Map.of(
"id", scenicSpot.getId(),
"name", scenicSpot.getName(),
"description", scenicSpot.getDescription(),
"address", scenicSpot.getAddress(),
"highlights", scenicSpot.getHighlights()
),
"product", Map.of(
"id", product.getId(),
"name", product.getName(),
"price", product.getPrice(),
"description", product.getDescription()
),
"style", Map.of(
"id", request.getStyleId(),
"name", styleConfigService.getById(request.getStyleId()).getName()
),
"audience", Map.of(
"id", request.getAudienceId(),
"name", audienceConfigService.getById(request.getAudienceId()).getName()
)
));
requestBody.put("async_mode", false);
// 调用 Python V2 API
Map<String, Object> response = pythonClient.post(
"content-generate",
"/api/v2/aigc/execute",
requestBody,
Map.class
);
// 解析响应
if ((Boolean) response.get("success")) {
Map<String, Object> data = (Map<String, Object>) response.get("data");
List<Map<String, Object>> topics = (List<Map<String, Object>>) data.get("topics");
return TopicGenerateResponse.fromTopics(topics);
} else {
throw new BusinessException(ErrorCode.OPERATION_ERROR,
(String) response.get("error"));
}
}
// 3. 调用内容生成
public ContentGenerateResponse generateContent(ContentGenerateRequest request) {
// 从数据库查询完整对象
ScenicSpot scenicSpot = scenicSpotService.getById(request.getScenicSpotId());
Product product = productService.getById(request.getProductId());
Map<String, Object> requestBody = new HashMap<>();
requestBody.put("engine", "content_generate");
requestBody.put("params", Map.of(
"topic", request.getTopic(), // 选题信息
"scenic_spot", convertToMap(scenicSpot),
"product", convertToMap(product),
"style", Map.of("id", request.getStyleId(), "name", "攻略风"),
"audience", Map.of("id", request.getAudienceId(), "name", "亲子向"),
"enable_judge", true
));
return pythonClient.post(
"content-generate",
"/api/v2/aigc/execute",
requestBody,
ContentGenerateResponse.class
);
}
'''
print(example)
# ========== 主函数 ==========
async def main():
parser = argparse.ArgumentParser(description='V2 API 端到端测试')
parser.add_argument('--server', type=str, help='服务器地址 (如 http://localhost:8000)')
parser.add_argument('--direct', action='store_true', help='直接测试 (不启动服务)')
parser.add_argument('--java-example', action='store_true', help='生成 Java 调用示例')
args = parser.parse_args()
print("\n" + "=" * 60)
print(" TravelContentCreator V2 API 端到端测试")
print("=" * 60)
print(f" 时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
results = []
if args.java_example:
generate_java_example()
return
if args.server:
# HTTP 测试
success = await test_http_api(args.server)
results.append(("HTTP API 测试", success))
else:
# 直接测试
print("\n模式: 直接测试 (不启动服务)")
try:
success = await test_direct_config_api()
results.append(("配置查询 API", success))
except Exception as e:
print(f"❌ 配置查询 API 测试失败: {e}")
results.append(("配置查询 API", False))
try:
success = await test_direct_topic_generate()
results.append(("选题生成引擎", success))
except Exception as e:
print(f"❌ 选题生成引擎测试失败: {e}")
results.append(("选题生成引擎", False))
try:
success = await test_direct_content_generate()
results.append(("内容生成引擎", success))
except Exception as e:
print(f"❌ 内容生成引擎测试失败: {e}")
results.append(("内容生成引擎", False))
# 打印结果汇总
print_section("测试结果汇总")
passed = 0
for name, success in results:
status = "" if success else ""
print(f" {status} {name}")
if success:
passed += 1
print(f"\n总计: {passed}/{len(results)} 通过")
# 生成 Java 示例
if args.direct or args.server:
print("\n" + "-" * 60)
print("提示: 运行 --java-example 查看 Java 端调用示例")
return passed == len(results)
if __name__ == "__main__":
success = asyncio.run(main())
sys.exit(0 if success else 1)