TravelContentCreator/tests/test_topic_to_content.py

295 lines
9.8 KiB
Python
Raw Normal View History

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
完整流程测试: 选题生成 正文生成
测试 V2 API 的完整业务流程:
1. 生成选题
2. 基于选题生成正文 (原创模式)
3. 基于选题生成正文 (reference 模式)
4. 基于选题生成正文 (rewrite 模式)
使用方法:
python tests/test_topic_to_content.py [--base-url http://localhost:8001]
"""
import argparse
import json
import requests
import time
from typing import Dict, Any, Optional
class Colors:
"""终端颜色"""
GREEN = '\033[92m'
RED = '\033[91m'
YELLOW = '\033[93m'
BLUE = '\033[94m'
CYAN = '\033[96m'
RESET = '\033[0m'
BOLD = '\033[1m'
def print_header(text: str):
print(f"\n{Colors.BOLD}{Colors.BLUE}{'='*60}{Colors.RESET}")
print(f"{Colors.BOLD}{Colors.BLUE}{text}{Colors.RESET}")
print(f"{Colors.BOLD}{Colors.BLUE}{'='*60}{Colors.RESET}")
def print_step(step: int, text: str):
print(f"\n{Colors.CYAN}[Step {step}] {text}{Colors.RESET}")
def print_success(text: str):
print(f"{Colors.GREEN}{text}{Colors.RESET}")
def print_error(text: str):
print(f"{Colors.RED}{text}{Colors.RESET}")
def print_info(text: str):
print(f"{Colors.YELLOW} {text}{Colors.RESET}")
def call_api(base_url: str, endpoint: str, method: str = "GET", data: Dict = None) -> Dict:
"""调用 API"""
url = f"{base_url}{endpoint}"
try:
if method == "GET":
resp = requests.get(url, timeout=120)
else:
resp = requests.post(url, json=data, timeout=120)
return resp.json()
except requests.exceptions.Timeout:
return {"success": False, "error": "请求超时"}
except Exception as e:
return {"success": False, "error": str(e)}
def test_topic_generate(base_url: str, subject: Dict, style: Dict, audience: Dict) -> Optional[Dict]:
"""测试选题生成"""
print_step(1, "生成选题")
params = {
"engine": "topic_generate",
"params": {
"month": "2025-02",
"num_topics": 2,
"prompt_version": "v2.0.0",
"subject": subject,
"style": style,
"audience": audience,
"hot_topics": {
"events": [],
"festivals": [
{"name": "春节", "date": "2025-01-29", "marketing_angle": "团圆、年味"}
],
"trending": []
}
}
}
print_info(f"请求参数: subject={subject['name']}, style={style['name']}, audience={audience['name']}")
start_time = time.time()
result = call_api(base_url, "/api/v2/aigc/execute", "POST", params)
elapsed = time.time() - start_time
if result.get("success"):
topics = result.get("data", {}).get("topics", [])
print_success(f"生成 {len(topics)} 个选题 (耗时 {elapsed:.1f}s)")
for i, topic in enumerate(topics):
print(f"\n {Colors.BOLD}选题 {i+1}:{Colors.RESET}")
print(f" 标题: {topic.get('title', 'N/A')}")
print(f" 日期: {topic.get('date', 'N/A')}")
print(f" 风格: {topic.get('style', 'N/A')}")
print(f" 人群: {topic.get('targetAudience', 'N/A')}")
print(f" 逻辑: {topic.get('logic', 'N/A')[:50]}...")
return topics[0] if topics else None
else:
print_error(f"选题生成失败: {result.get('error', 'Unknown error')}")
return None
def test_content_generate(base_url: str, topic: Dict, subject: Dict, style: Dict, audience: Dict,
reference: Optional[Dict] = None, mode_name: str = "原创") -> Optional[Dict]:
"""测试正文生成"""
print_step(2 if reference is None else (3 if reference.get('mode') == 'reference' else 4),
f"生成正文 ({mode_name}模式)")
params = {
"engine": "content_generate",
"params": {
"prompt_version": "v2.0.0",
"topic": topic,
"subject": subject,
"style": style,
"audience": audience,
"enable_judge": False
}
}
if reference:
params["params"]["reference"] = reference
print_info(f"参考模式: {reference.get('mode')}")
else:
print_info("原创模式,无参考内容")
start_time = time.time()
result = call_api(base_url, "/api/v2/aigc/execute", "POST", params)
elapsed = time.time() - start_time
if result.get("success"):
content = result.get("data", {}).get("content", {})
print_success(f"正文生成成功 (耗时 {elapsed:.1f}s)")
print(f"\n {Colors.BOLD}标题:{Colors.RESET} {content.get('title', 'N/A')}")
body = content.get('content', '')
# 截取前 200 字符显示
preview = body[:200] + "..." if len(body) > 200 else body
print(f"\n {Colors.BOLD}正文预览:{Colors.RESET}")
for line in preview.split('\n')[:8]:
print(f" {line}")
print(f"\n {Colors.BOLD}TAG:{Colors.RESET} {content.get('tag', 'N/A')}")
print(f"\n {Colors.BOLD}参考模式:{Colors.RESET} {content.get('referenceMode', 'none')}")
return content
else:
print_error(f"正文生成失败: {result.get('error', 'Unknown error')}")
return None
def run_full_test(base_url: str):
"""运行完整测试"""
print_header("TravelContentCreator 完整流程测试")
print(f"API 地址: {base_url}")
# 测试数据
subject = {
"id": 1,
"name": "北京环球影城",
"type": "scenic_spot",
"description": "亚洲最大的环球影城主题公园,拥有哈利波特魔法世界、变形金刚基地等七大主题景区",
"location": "北京市通州区",
"traffic_info": "地铁1号线/7号线环球度假区站",
"highlights": ["哈利波特魔法世界", "变形金刚基地", "小黄人乐园", "侏罗纪世界"],
"advantages": "全球顶级主题乐园,沉浸式体验",
"products": [
{
"id": 10,
"name": "单日票",
"price": 638,
"original_price": 738,
"sales_period": "2025-01-01 至 2025-03-31",
"package_info": "含一次入园",
"usage_rules": "入园当日有效,不可退改"
}
]
}
style = {"id": "gonglue", "name": "攻略风"}
audience = {"id": "qinzi", "name": "亲子向"}
# 参考内容 (用于 reference 和 rewrite 测试)
reference_content = {
"title": "上海迪士尼遛娃天花板!一日游保姆级攻略🏰",
"content": """带娃去迪士尼,这篇攻略你一定要收藏!
🎯 必玩项目TOP5
1. 创极速光轮 - 全球最快迪士尼过山车
2. 加勒比海盗 - 沉浸式体验超震撼
3. 飞越地平线 - 裸眼4D环游世界
4. 小飞侠天空奇遇 - 适合小朋友
5. 七个小矮人矿山车 - 刺激又不吓人
时间规划
8:30 到达门口排队
9:00 开园直冲创极速光轮
12:00 午餐建议自带
14:00 花车巡游
20:00 烟花秀
💡 省钱tips
- 门票提前买便宜100+
- 自带水和零食
- 下载APP领快速通道"""
}
# Step 1: 生成选题
topic = test_topic_generate(base_url, subject, style, audience)
if not topic:
print_error("选题生成失败,终止测试")
return False
# Step 2: 原创模式生成正文
content1 = test_content_generate(base_url, topic, subject, style, audience,
reference=None, mode_name="原创")
# Step 3: reference 模式生成正文 (参考风格,原创内容)
content2 = test_content_generate(base_url, topic, subject, style, audience,
reference={"mode": "reference", **reference_content},
mode_name="reference 参考")
# Step 4: rewrite 模式生成正文 (保留框架,换主体)
content3 = test_content_generate(base_url, topic, subject, style, audience,
reference={"mode": "rewrite", **reference_content},
mode_name="rewrite 改写")
# 结果汇总
print_header("测试结果汇总")
results = [
("选题生成", topic is not None),
("原创模式正文", content1 is not None),
("reference 模式正文", content2 is not None),
("rewrite 模式正文", content3 is not None),
]
all_passed = True
for name, passed in results:
if passed:
print_success(f"{name}: 通过")
else:
print_error(f"{name}: 失败")
all_passed = False
if all_passed:
print(f"\n{Colors.GREEN}{Colors.BOLD}🎉 所有测试通过!{Colors.RESET}")
else:
print(f"\n{Colors.RED}{Colors.BOLD}⚠️ 部分测试失败{Colors.RESET}")
return all_passed
def main():
parser = argparse.ArgumentParser(description="TravelContentCreator 完整流程测试")
parser.add_argument("--base-url", default="http://localhost:8001", help="API 基础地址")
args = parser.parse_args()
# 检查服务是否可用
print_info(f"检查服务状态: {args.base_url}")
try:
resp = requests.get(f"{args.base_url}/", timeout=5)
if resp.status_code == 200:
print_success("服务运行正常")
else:
print_error(f"服务异常: HTTP {resp.status_code}")
return
except Exception as e:
print_error(f"无法连接服务: {e}")
print_info("请先启动服务: PYTHONPATH=. uvicorn api.main:app --port 8001")
return
# 运行测试
run_full_test(args.base_url)
if __name__ == "__main__":
main()