TravelContentCreator/examples/test_stream_with_timeout_handling.py

198 lines
7.0 KiB
Python
Raw Normal View History

2025-04-23 20:03:00 +08:00
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import sys
import logging
import asyncio
import time
from datetime import datetime
# 添加项目根目录到路径,确保可以导入核心模块
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from core.ai_agent import AI_Agent, Timeout
# 配置日志
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[logging.StreamHandler()]
)
# 从环境变量获取API密钥或使用默认值
API_KEY = os.environ.get("OPENAI_API_KEY", "your_api_key_here")
# 使用API的基础URL
BASE_URL = os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1")
# 使用的模型名称
MODEL_NAME = os.environ.get("OPENAI_MODEL", "gpt-3.5-turbo")
def print_with_timestamp(message, end='\n'):
"""打印带有时间戳的消息"""
timestamp = datetime.now().strftime("%H:%M:%S")
print(f"[{timestamp}] {message}", end=end, flush=True)
def test_sync_stream_with_timeouts():
"""测试同步流式响应模式下的超时处理"""
print_with_timestamp("开始测试同步流式响应的超时处理...")
# 创建 AI_Agent 实例,设置较短的超时时间以便测试
agent = AI_Agent(
base_url=BASE_URL,
model_name=MODEL_NAME,
api=API_KEY,
timeout=10, # API 请求整体超时时间 (秒)
max_retries=2, # 最大重试次数
stream_chunk_timeout=5 # 流块超时时间 (秒)
)
system_prompt = "你是一个有用的助手。"
user_prompt = "请详细描述中国的长城至少500字。"
try:
print_with_timestamp("正在生成内容...")
start_time = time.time()
# 使用同步流式响应方法
response = agent.generate_text_stream(
system_prompt=system_prompt,
user_prompt=user_prompt
)
# 输出完整响应和耗时
print_with_timestamp(f"完成! 耗时: {time.time() - start_time:.2f}")
# 检查响应中是否包含超时或错误提示
if "[注意:" in response:
print_with_timestamp("检测到警告信息:")
warning_start = response.find("[注意:")
warning_end = response.find("]", warning_start)
if warning_end != -1:
print_with_timestamp(f"警告内容: {response[warning_start:warning_end+1]}")
except Timeout as e:
print_with_timestamp(f"捕获到超时异常: {e}")
except Exception as e:
print_with_timestamp(f"捕获到异常: {type(e).__name__} - {e}")
finally:
agent.close()
def test_callback_stream_with_timeouts():
"""测试回调流式响应模式下的超时处理"""
print_with_timestamp("开始测试回调流式响应的超时处理...")
# 创建 AI_Agent 实例,设置较短的超时时间以便测试
agent = AI_Agent(
base_url=BASE_URL,
model_name=MODEL_NAME,
api=API_KEY,
timeout=10, # API 请求整体超时时间 (秒)
max_retries=2, # 最大重试次数
stream_chunk_timeout=5 # 流块超时时间 (秒)
)
system_prompt = "你是一个有用的助手。"
user_prompt = "请详细描述中国的长城至少500字。"
# 定义回调函数
def callback(chunk, accumulated=None):
print_with_timestamp(f"收到块: 「{chunk}", end="")
try:
print_with_timestamp("正在通过回调生成内容...")
start_time = time.time()
# 使用回调流式响应方法
response = agent.generate_text_stream_with_callback(
system_prompt=system_prompt,
user_prompt=user_prompt,
temperature=0.7,
top_p=0.9,
presence_penalty=0.0,
callback=callback,
accumulate=True # 启用累积模式
)
# 输出完整响应和耗时
print_with_timestamp(f"\n完成! 耗时: {time.time() - start_time:.2f}")
print_with_timestamp("回调累积的响应:")
# 检查响应中是否包含超时或错误提示
if "[注意:" in response:
print_with_timestamp("检测到警告信息:")
warning_start = response.find("[注意:")
warning_end = response.find("]", warning_start)
if warning_end != -1:
print_with_timestamp(f"警告内容: {response[warning_start:warning_end+1]}")
except Timeout as e:
print_with_timestamp(f"捕获到超时异常: {e}")
except Exception as e:
print_with_timestamp(f"捕获到异常: {type(e).__name__} - {e}")
finally:
agent.close()
async def test_async_stream_with_timeouts():
"""测试异步流式响应模式下的超时处理"""
print_with_timestamp("开始测试异步流式响应的超时处理...")
# 创建 AI_Agent 实例,设置较短的超时时间以便测试
agent = AI_Agent(
base_url=BASE_URL,
model_name=MODEL_NAME,
api=API_KEY,
timeout=10, # API 请求整体超时时间 (秒)
max_retries=2, # 最大重试次数
stream_chunk_timeout=5 # 流块超时时间 (秒)
)
system_prompt = "你是一个有用的助手。"
user_prompt = "请详细描述中国的长城至少500字。"
try:
print_with_timestamp("正在异步生成内容...")
start_time = time.time()
# 使用异步流式响应方法
full_response = ""
async for chunk in agent.async_generate_text_stream(
system_prompt=system_prompt,
user_prompt=user_prompt
):
full_response += chunk
print_with_timestamp(f"收到块: 「{chunk}", end="")
# 输出完整响应和耗时
print_with_timestamp(f"\n完成! 耗时: {time.time() - start_time:.2f}")
# 检查响应中是否包含超时或错误提示
if "[注意:" in full_response:
print_with_timestamp("检测到警告信息:")
warning_start = full_response.find("[注意:")
warning_end = full_response.find("]", warning_start)
if warning_end != -1:
print_with_timestamp(f"警告内容: {full_response[warning_start:warning_end+1]}")
except Timeout as e:
print_with_timestamp(f"捕获到超时异常: {e}")
except Exception as e:
print_with_timestamp(f"捕获到异常: {type(e).__name__} - {e}")
finally:
agent.close()
async def run_all_tests():
"""运行所有测试"""
# 测试同步模式
test_sync_stream_with_timeouts()
print("\n" + "-"*50 + "\n")
# 测试回调模式
test_callback_stream_with_timeouts()
print("\n" + "-"*50 + "\n")
# 测试异步模式
await test_async_stream_with_timeouts()
if __name__ == "__main__":
# 运行所有测试
asyncio.run(run_all_tests())