295 lines
9.8 KiB
Python
295 lines
9.8 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
完整流程测试: 选题生成 → 正文生成
|
||
|
||
测试 V2 API 的完整业务流程:
|
||
1. 生成选题
|
||
2. 基于选题生成正文 (原创模式)
|
||
3. 基于选题生成正文 (reference 模式)
|
||
4. 基于选题生成正文 (rewrite 模式)
|
||
|
||
使用方法:
|
||
python tests/test_topic_to_content.py [--base-url http://localhost:8001]
|
||
"""
|
||
|
||
import argparse
|
||
import json
|
||
import requests
|
||
import time
|
||
from typing import Dict, Any, Optional
|
||
|
||
|
||
class Colors:
|
||
"""终端颜色"""
|
||
GREEN = '\033[92m'
|
||
RED = '\033[91m'
|
||
YELLOW = '\033[93m'
|
||
BLUE = '\033[94m'
|
||
CYAN = '\033[96m'
|
||
RESET = '\033[0m'
|
||
BOLD = '\033[1m'
|
||
|
||
|
||
def print_header(text: str):
|
||
print(f"\n{Colors.BOLD}{Colors.BLUE}{'='*60}{Colors.RESET}")
|
||
print(f"{Colors.BOLD}{Colors.BLUE}{text}{Colors.RESET}")
|
||
print(f"{Colors.BOLD}{Colors.BLUE}{'='*60}{Colors.RESET}")
|
||
|
||
|
||
def print_step(step: int, text: str):
|
||
print(f"\n{Colors.CYAN}[Step {step}] {text}{Colors.RESET}")
|
||
|
||
|
||
def print_success(text: str):
|
||
print(f"{Colors.GREEN}✅ {text}{Colors.RESET}")
|
||
|
||
|
||
def print_error(text: str):
|
||
print(f"{Colors.RED}❌ {text}{Colors.RESET}")
|
||
|
||
|
||
def print_info(text: str):
|
||
print(f"{Colors.YELLOW}ℹ️ {text}{Colors.RESET}")
|
||
|
||
|
||
def call_api(base_url: str, endpoint: str, method: str = "GET", data: Dict = None) -> Dict:
|
||
"""调用 API"""
|
||
url = f"{base_url}{endpoint}"
|
||
try:
|
||
if method == "GET":
|
||
resp = requests.get(url, timeout=120)
|
||
else:
|
||
resp = requests.post(url, json=data, timeout=120)
|
||
return resp.json()
|
||
except requests.exceptions.Timeout:
|
||
return {"success": False, "error": "请求超时"}
|
||
except Exception as e:
|
||
return {"success": False, "error": str(e)}
|
||
|
||
|
||
def test_topic_generate(base_url: str, subject: Dict, style: Dict, audience: Dict) -> Optional[Dict]:
|
||
"""测试选题生成"""
|
||
print_step(1, "生成选题")
|
||
|
||
params = {
|
||
"engine": "topic_generate",
|
||
"params": {
|
||
"month": "2025-02",
|
||
"num_topics": 2,
|
||
"prompt_version": "v2.0.0",
|
||
"subject": subject,
|
||
"style": style,
|
||
"audience": audience,
|
||
"hot_topics": {
|
||
"events": [],
|
||
"festivals": [
|
||
{"name": "春节", "date": "2025-01-29", "marketing_angle": "团圆、年味"}
|
||
],
|
||
"trending": []
|
||
}
|
||
}
|
||
}
|
||
|
||
print_info(f"请求参数: subject={subject['name']}, style={style['name']}, audience={audience['name']}")
|
||
|
||
start_time = time.time()
|
||
result = call_api(base_url, "/api/v2/aigc/execute", "POST", params)
|
||
elapsed = time.time() - start_time
|
||
|
||
if result.get("success"):
|
||
topics = result.get("data", {}).get("topics", [])
|
||
print_success(f"生成 {len(topics)} 个选题 (耗时 {elapsed:.1f}s)")
|
||
|
||
for i, topic in enumerate(topics):
|
||
print(f"\n {Colors.BOLD}选题 {i+1}:{Colors.RESET}")
|
||
print(f" 标题: {topic.get('title', 'N/A')}")
|
||
print(f" 日期: {topic.get('date', 'N/A')}")
|
||
print(f" 风格: {topic.get('style', 'N/A')}")
|
||
print(f" 人群: {topic.get('targetAudience', 'N/A')}")
|
||
print(f" 逻辑: {topic.get('logic', 'N/A')[:50]}...")
|
||
|
||
return topics[0] if topics else None
|
||
else:
|
||
print_error(f"选题生成失败: {result.get('error', 'Unknown error')}")
|
||
return None
|
||
|
||
|
||
def test_content_generate(base_url: str, topic: Dict, subject: Dict, style: Dict, audience: Dict,
|
||
reference: Optional[Dict] = None, mode_name: str = "原创") -> Optional[Dict]:
|
||
"""测试正文生成"""
|
||
print_step(2 if reference is None else (3 if reference.get('mode') == 'reference' else 4),
|
||
f"生成正文 ({mode_name}模式)")
|
||
|
||
params = {
|
||
"engine": "content_generate",
|
||
"params": {
|
||
"prompt_version": "v2.0.0",
|
||
"topic": topic,
|
||
"subject": subject,
|
||
"style": style,
|
||
"audience": audience,
|
||
"enable_judge": False
|
||
}
|
||
}
|
||
|
||
if reference:
|
||
params["params"]["reference"] = reference
|
||
print_info(f"参考模式: {reference.get('mode')}")
|
||
else:
|
||
print_info("原创模式,无参考内容")
|
||
|
||
start_time = time.time()
|
||
result = call_api(base_url, "/api/v2/aigc/execute", "POST", params)
|
||
elapsed = time.time() - start_time
|
||
|
||
if result.get("success"):
|
||
content = result.get("data", {}).get("content", {})
|
||
print_success(f"正文生成成功 (耗时 {elapsed:.1f}s)")
|
||
|
||
print(f"\n {Colors.BOLD}标题:{Colors.RESET} {content.get('title', 'N/A')}")
|
||
|
||
body = content.get('content', '')
|
||
# 截取前 200 字符显示
|
||
preview = body[:200] + "..." if len(body) > 200 else body
|
||
print(f"\n {Colors.BOLD}正文预览:{Colors.RESET}")
|
||
for line in preview.split('\n')[:8]:
|
||
print(f" {line}")
|
||
|
||
print(f"\n {Colors.BOLD}TAG:{Colors.RESET} {content.get('tag', 'N/A')}")
|
||
print(f"\n {Colors.BOLD}参考模式:{Colors.RESET} {content.get('referenceMode', 'none')}")
|
||
|
||
return content
|
||
else:
|
||
print_error(f"正文生成失败: {result.get('error', 'Unknown error')}")
|
||
return None
|
||
|
||
|
||
def run_full_test(base_url: str):
|
||
"""运行完整测试"""
|
||
print_header("TravelContentCreator 完整流程测试")
|
||
print(f"API 地址: {base_url}")
|
||
|
||
# 测试数据
|
||
subject = {
|
||
"id": 1,
|
||
"name": "北京环球影城",
|
||
"type": "scenic_spot",
|
||
"description": "亚洲最大的环球影城主题公园,拥有哈利波特魔法世界、变形金刚基地等七大主题景区",
|
||
"location": "北京市通州区",
|
||
"traffic_info": "地铁1号线/7号线环球度假区站",
|
||
"highlights": ["哈利波特魔法世界", "变形金刚基地", "小黄人乐园", "侏罗纪世界"],
|
||
"advantages": "全球顶级主题乐园,沉浸式体验",
|
||
"products": [
|
||
{
|
||
"id": 10,
|
||
"name": "单日票",
|
||
"price": 638,
|
||
"original_price": 738,
|
||
"sales_period": "2025-01-01 至 2025-03-31",
|
||
"package_info": "含一次入园",
|
||
"usage_rules": "入园当日有效,不可退改"
|
||
}
|
||
]
|
||
}
|
||
|
||
style = {"id": "gonglue", "name": "攻略风"}
|
||
audience = {"id": "qinzi", "name": "亲子向"}
|
||
|
||
# 参考内容 (用于 reference 和 rewrite 测试)
|
||
reference_content = {
|
||
"title": "上海迪士尼遛娃天花板!一日游保姆级攻略🏰",
|
||
"content": """带娃去迪士尼,这篇攻略你一定要收藏!
|
||
|
||
🎯 必玩项目TOP5
|
||
1. 创极速光轮 - 全球最快迪士尼过山车
|
||
2. 加勒比海盗 - 沉浸式体验超震撼
|
||
3. 飞越地平线 - 裸眼4D环游世界
|
||
4. 小飞侠天空奇遇 - 适合小朋友
|
||
5. 七个小矮人矿山车 - 刺激又不吓人
|
||
|
||
⏰ 时间规划
|
||
8:30 到达门口排队
|
||
9:00 开园直冲创极速光轮
|
||
12:00 午餐(建议自带)
|
||
14:00 花车巡游
|
||
20:00 烟花秀
|
||
|
||
💡 省钱tips
|
||
- 门票提前买,便宜100+
|
||
- 自带水和零食
|
||
- 下载APP领快速通道"""
|
||
}
|
||
|
||
# Step 1: 生成选题
|
||
topic = test_topic_generate(base_url, subject, style, audience)
|
||
if not topic:
|
||
print_error("选题生成失败,终止测试")
|
||
return False
|
||
|
||
# Step 2: 原创模式生成正文
|
||
content1 = test_content_generate(base_url, topic, subject, style, audience,
|
||
reference=None, mode_name="原创")
|
||
|
||
# Step 3: reference 模式生成正文 (参考风格,原创内容)
|
||
content2 = test_content_generate(base_url, topic, subject, style, audience,
|
||
reference={"mode": "reference", **reference_content},
|
||
mode_name="reference 参考")
|
||
|
||
# Step 4: rewrite 模式生成正文 (保留框架,换主体)
|
||
content3 = test_content_generate(base_url, topic, subject, style, audience,
|
||
reference={"mode": "rewrite", **reference_content},
|
||
mode_name="rewrite 改写")
|
||
|
||
# 结果汇总
|
||
print_header("测试结果汇总")
|
||
|
||
results = [
|
||
("选题生成", topic is not None),
|
||
("原创模式正文", content1 is not None),
|
||
("reference 模式正文", content2 is not None),
|
||
("rewrite 模式正文", content3 is not None),
|
||
]
|
||
|
||
all_passed = True
|
||
for name, passed in results:
|
||
if passed:
|
||
print_success(f"{name}: 通过")
|
||
else:
|
||
print_error(f"{name}: 失败")
|
||
all_passed = False
|
||
|
||
if all_passed:
|
||
print(f"\n{Colors.GREEN}{Colors.BOLD}🎉 所有测试通过!{Colors.RESET}")
|
||
else:
|
||
print(f"\n{Colors.RED}{Colors.BOLD}⚠️ 部分测试失败{Colors.RESET}")
|
||
|
||
return all_passed
|
||
|
||
|
||
def main():
|
||
parser = argparse.ArgumentParser(description="TravelContentCreator 完整流程测试")
|
||
parser.add_argument("--base-url", default="http://localhost:8001", help="API 基础地址")
|
||
args = parser.parse_args()
|
||
|
||
# 检查服务是否可用
|
||
print_info(f"检查服务状态: {args.base_url}")
|
||
try:
|
||
resp = requests.get(f"{args.base_url}/", timeout=5)
|
||
if resp.status_code == 200:
|
||
print_success("服务运行正常")
|
||
else:
|
||
print_error(f"服务异常: HTTP {resp.status_code}")
|
||
return
|
||
except Exception as e:
|
||
print_error(f"无法连接服务: {e}")
|
||
print_info("请先启动服务: PYTHONPATH=. uvicorn api.main:app --port 8001")
|
||
return
|
||
|
||
# 运行测试
|
||
run_full_test(args.base_url)
|
||
|
||
|
||
if __name__ == "__main__":
|
||
main()
|