TravelContentCreator/tests/test_topic_to_content.py

295 lines
9.8 KiB
Python
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
完整流程测试: 选题生成 → 正文生成
测试 V2 API 的完整业务流程:
1. 生成选题
2. 基于选题生成正文 (原创模式)
3. 基于选题生成正文 (reference 模式)
4. 基于选题生成正文 (rewrite 模式)
使用方法:
python tests/test_topic_to_content.py [--base-url http://localhost:8001]
"""
import argparse
import json
import requests
import time
from typing import Dict, Any, Optional
class Colors:
"""终端颜色"""
GREEN = '\033[92m'
RED = '\033[91m'
YELLOW = '\033[93m'
BLUE = '\033[94m'
CYAN = '\033[96m'
RESET = '\033[0m'
BOLD = '\033[1m'
def print_header(text: str):
print(f"\n{Colors.BOLD}{Colors.BLUE}{'='*60}{Colors.RESET}")
print(f"{Colors.BOLD}{Colors.BLUE}{text}{Colors.RESET}")
print(f"{Colors.BOLD}{Colors.BLUE}{'='*60}{Colors.RESET}")
def print_step(step: int, text: str):
print(f"\n{Colors.CYAN}[Step {step}] {text}{Colors.RESET}")
def print_success(text: str):
print(f"{Colors.GREEN}{text}{Colors.RESET}")
def print_error(text: str):
print(f"{Colors.RED}{text}{Colors.RESET}")
def print_info(text: str):
print(f"{Colors.YELLOW} {text}{Colors.RESET}")
def call_api(base_url: str, endpoint: str, method: str = "GET", data: Dict = None) -> Dict:
"""调用 API"""
url = f"{base_url}{endpoint}"
try:
if method == "GET":
resp = requests.get(url, timeout=120)
else:
resp = requests.post(url, json=data, timeout=120)
return resp.json()
except requests.exceptions.Timeout:
return {"success": False, "error": "请求超时"}
except Exception as e:
return {"success": False, "error": str(e)}
def test_topic_generate(base_url: str, subject: Dict, style: Dict, audience: Dict) -> Optional[Dict]:
"""测试选题生成"""
print_step(1, "生成选题")
params = {
"engine": "topic_generate",
"params": {
"month": "2025-02",
"num_topics": 2,
"prompt_version": "v2.0.0",
"subject": subject,
"style": style,
"audience": audience,
"hot_topics": {
"events": [],
"festivals": [
{"name": "春节", "date": "2025-01-29", "marketing_angle": "团圆、年味"}
],
"trending": []
}
}
}
print_info(f"请求参数: subject={subject['name']}, style={style['name']}, audience={audience['name']}")
start_time = time.time()
result = call_api(base_url, "/api/v2/aigc/execute", "POST", params)
elapsed = time.time() - start_time
if result.get("success"):
topics = result.get("data", {}).get("topics", [])
print_success(f"生成 {len(topics)} 个选题 (耗时 {elapsed:.1f}s)")
for i, topic in enumerate(topics):
print(f"\n {Colors.BOLD}选题 {i+1}:{Colors.RESET}")
print(f" 标题: {topic.get('title', 'N/A')}")
print(f" 日期: {topic.get('date', 'N/A')}")
print(f" 风格: {topic.get('style', 'N/A')}")
print(f" 人群: {topic.get('targetAudience', 'N/A')}")
print(f" 逻辑: {topic.get('logic', 'N/A')[:50]}...")
return topics[0] if topics else None
else:
print_error(f"选题生成失败: {result.get('error', 'Unknown error')}")
return None
def test_content_generate(base_url: str, topic: Dict, subject: Dict, style: Dict, audience: Dict,
reference: Optional[Dict] = None, mode_name: str = "原创") -> Optional[Dict]:
"""测试正文生成"""
print_step(2 if reference is None else (3 if reference.get('mode') == 'reference' else 4),
f"生成正文 ({mode_name}模式)")
params = {
"engine": "content_generate",
"params": {
"prompt_version": "v2.0.0",
"topic": topic,
"subject": subject,
"style": style,
"audience": audience,
"enable_judge": False
}
}
if reference:
params["params"]["reference"] = reference
print_info(f"参考模式: {reference.get('mode')}")
else:
print_info("原创模式,无参考内容")
start_time = time.time()
result = call_api(base_url, "/api/v2/aigc/execute", "POST", params)
elapsed = time.time() - start_time
if result.get("success"):
content = result.get("data", {}).get("content", {})
print_success(f"正文生成成功 (耗时 {elapsed:.1f}s)")
print(f"\n {Colors.BOLD}标题:{Colors.RESET} {content.get('title', 'N/A')}")
body = content.get('content', '')
# 截取前 200 字符显示
preview = body[:200] + "..." if len(body) > 200 else body
print(f"\n {Colors.BOLD}正文预览:{Colors.RESET}")
for line in preview.split('\n')[:8]:
print(f" {line}")
print(f"\n {Colors.BOLD}TAG:{Colors.RESET} {content.get('tag', 'N/A')}")
print(f"\n {Colors.BOLD}参考模式:{Colors.RESET} {content.get('referenceMode', 'none')}")
return content
else:
print_error(f"正文生成失败: {result.get('error', 'Unknown error')}")
return None
def run_full_test(base_url: str):
"""运行完整测试"""
print_header("TravelContentCreator 完整流程测试")
print(f"API 地址: {base_url}")
# 测试数据
subject = {
"id": 1,
"name": "北京环球影城",
"type": "scenic_spot",
"description": "亚洲最大的环球影城主题公园,拥有哈利波特魔法世界、变形金刚基地等七大主题景区",
"location": "北京市通州区",
"traffic_info": "地铁1号线/7号线环球度假区站",
"highlights": ["哈利波特魔法世界", "变形金刚基地", "小黄人乐园", "侏罗纪世界"],
"advantages": "全球顶级主题乐园,沉浸式体验",
"products": [
{
"id": 10,
"name": "单日票",
"price": 638,
"original_price": 738,
"sales_period": "2025-01-01 至 2025-03-31",
"package_info": "含一次入园",
"usage_rules": "入园当日有效,不可退改"
}
]
}
style = {"id": "gonglue", "name": "攻略风"}
audience = {"id": "qinzi", "name": "亲子向"}
# 参考内容 (用于 reference 和 rewrite 测试)
reference_content = {
"title": "上海迪士尼遛娃天花板!一日游保姆级攻略🏰",
"content": """带娃去迪士尼,这篇攻略你一定要收藏!
🎯 必玩项目TOP5
1. 创极速光轮 - 全球最快迪士尼过山车
2. 加勒比海盗 - 沉浸式体验超震撼
3. 飞越地平线 - 裸眼4D环游世界
4. 小飞侠天空奇遇 - 适合小朋友
5. 七个小矮人矿山车 - 刺激又不吓人
⏰ 时间规划
8:30 到达门口排队
9:00 开园直冲创极速光轮
12:00 午餐(建议自带)
14:00 花车巡游
20:00 烟花秀
💡 省钱tips
- 门票提前买便宜100+
- 自带水和零食
- 下载APP领快速通道"""
}
# Step 1: 生成选题
topic = test_topic_generate(base_url, subject, style, audience)
if not topic:
print_error("选题生成失败,终止测试")
return False
# Step 2: 原创模式生成正文
content1 = test_content_generate(base_url, topic, subject, style, audience,
reference=None, mode_name="原创")
# Step 3: reference 模式生成正文 (参考风格,原创内容)
content2 = test_content_generate(base_url, topic, subject, style, audience,
reference={"mode": "reference", **reference_content},
mode_name="reference 参考")
# Step 4: rewrite 模式生成正文 (保留框架,换主体)
content3 = test_content_generate(base_url, topic, subject, style, audience,
reference={"mode": "rewrite", **reference_content},
mode_name="rewrite 改写")
# 结果汇总
print_header("测试结果汇总")
results = [
("选题生成", topic is not None),
("原创模式正文", content1 is not None),
("reference 模式正文", content2 is not None),
("rewrite 模式正文", content3 is not None),
]
all_passed = True
for name, passed in results:
if passed:
print_success(f"{name}: 通过")
else:
print_error(f"{name}: 失败")
all_passed = False
if all_passed:
print(f"\n{Colors.GREEN}{Colors.BOLD}🎉 所有测试通过!{Colors.RESET}")
else:
print(f"\n{Colors.RED}{Colors.BOLD}⚠️ 部分测试失败{Colors.RESET}")
return all_passed
def main():
parser = argparse.ArgumentParser(description="TravelContentCreator 完整流程测试")
parser.add_argument("--base-url", default="http://localhost:8001", help="API 基础地址")
args = parser.parse_args()
# 检查服务是否可用
print_info(f"检查服务状态: {args.base_url}")
try:
resp = requests.get(f"{args.base_url}/", timeout=5)
if resp.status_code == 200:
print_success("服务运行正常")
else:
print_error(f"服务异常: HTTP {resp.status_code}")
return
except Exception as e:
print_error(f"无法连接服务: {e}")
print_info("请先启动服务: PYTHONPATH=. uvicorn api.main:app --port 8001")
return
# 运行测试
run_full_test(args.base_url)
if __name__ == "__main__":
main()