TravelContentCreator/examples/test_stream_with_timeout_handling.py

198 lines
7.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import sys
import logging
import asyncio
import time
from datetime import datetime
# 添加项目根目录到路径,确保可以导入核心模块
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from core.ai_agent import AI_Agent, Timeout
# 配置日志
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[logging.StreamHandler()]
)
# 从环境变量获取API密钥或使用默认值
API_KEY = os.environ.get("OPENAI_API_KEY", "your_api_key_here")
# 使用API的基础URL
BASE_URL = os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1")
# 使用的模型名称
MODEL_NAME = os.environ.get("OPENAI_MODEL", "gpt-3.5-turbo")
def print_with_timestamp(message, end='\n'):
"""打印带有时间戳的消息"""
timestamp = datetime.now().strftime("%H:%M:%S")
print(f"[{timestamp}] {message}", end=end, flush=True)
def test_sync_stream_with_timeouts():
"""测试同步流式响应模式下的超时处理"""
print_with_timestamp("开始测试同步流式响应的超时处理...")
# 创建 AI_Agent 实例,设置较短的超时时间以便测试
agent = AI_Agent(
base_url=BASE_URL,
model_name=MODEL_NAME,
api=API_KEY,
timeout=10, # API 请求整体超时时间 (秒)
max_retries=2, # 最大重试次数
stream_chunk_timeout=5 # 流块超时时间 (秒)
)
system_prompt = "你是一个有用的助手。"
user_prompt = "请详细描述中国的长城至少500字。"
try:
print_with_timestamp("正在生成内容...")
start_time = time.time()
# 使用同步流式响应方法
response = agent.generate_text_stream(
system_prompt=system_prompt,
user_prompt=user_prompt
)
# 输出完整响应和耗时
print_with_timestamp(f"完成! 耗时: {time.time() - start_time:.2f}")
# 检查响应中是否包含超时或错误提示
if "[注意:" in response:
print_with_timestamp("检测到警告信息:")
warning_start = response.find("[注意:")
warning_end = response.find("]", warning_start)
if warning_end != -1:
print_with_timestamp(f"警告内容: {response[warning_start:warning_end+1]}")
except Timeout as e:
print_with_timestamp(f"捕获到超时异常: {e}")
except Exception as e:
print_with_timestamp(f"捕获到异常: {type(e).__name__} - {e}")
finally:
agent.close()
def test_callback_stream_with_timeouts():
"""测试回调流式响应模式下的超时处理"""
print_with_timestamp("开始测试回调流式响应的超时处理...")
# 创建 AI_Agent 实例,设置较短的超时时间以便测试
agent = AI_Agent(
base_url=BASE_URL,
model_name=MODEL_NAME,
api=API_KEY,
timeout=10, # API 请求整体超时时间 (秒)
max_retries=2, # 最大重试次数
stream_chunk_timeout=5 # 流块超时时间 (秒)
)
system_prompt = "你是一个有用的助手。"
user_prompt = "请详细描述中国的长城至少500字。"
# 定义回调函数
def callback(chunk, accumulated=None):
print_with_timestamp(f"收到块: 「{chunk}", end="")
try:
print_with_timestamp("正在通过回调生成内容...")
start_time = time.time()
# 使用回调流式响应方法
response = agent.generate_text_stream_with_callback(
system_prompt=system_prompt,
user_prompt=user_prompt,
temperature=0.7,
top_p=0.9,
presence_penalty=0.0,
callback=callback,
accumulate=True # 启用累积模式
)
# 输出完整响应和耗时
print_with_timestamp(f"\n完成! 耗时: {time.time() - start_time:.2f}")
print_with_timestamp("回调累积的响应:")
# 检查响应中是否包含超时或错误提示
if "[注意:" in response:
print_with_timestamp("检测到警告信息:")
warning_start = response.find("[注意:")
warning_end = response.find("]", warning_start)
if warning_end != -1:
print_with_timestamp(f"警告内容: {response[warning_start:warning_end+1]}")
except Timeout as e:
print_with_timestamp(f"捕获到超时异常: {e}")
except Exception as e:
print_with_timestamp(f"捕获到异常: {type(e).__name__} - {e}")
finally:
agent.close()
async def test_async_stream_with_timeouts():
"""测试异步流式响应模式下的超时处理"""
print_with_timestamp("开始测试异步流式响应的超时处理...")
# 创建 AI_Agent 实例,设置较短的超时时间以便测试
agent = AI_Agent(
base_url=BASE_URL,
model_name=MODEL_NAME,
api=API_KEY,
timeout=10, # API 请求整体超时时间 (秒)
max_retries=2, # 最大重试次数
stream_chunk_timeout=5 # 流块超时时间 (秒)
)
system_prompt = "你是一个有用的助手。"
user_prompt = "请详细描述中国的长城至少500字。"
try:
print_with_timestamp("正在异步生成内容...")
start_time = time.time()
# 使用异步流式响应方法
full_response = ""
async for chunk in agent.async_generate_text_stream(
system_prompt=system_prompt,
user_prompt=user_prompt
):
full_response += chunk
print_with_timestamp(f"收到块: 「{chunk}", end="")
# 输出完整响应和耗时
print_with_timestamp(f"\n完成! 耗时: {time.time() - start_time:.2f}")
# 检查响应中是否包含超时或错误提示
if "[注意:" in full_response:
print_with_timestamp("检测到警告信息:")
warning_start = full_response.find("[注意:")
warning_end = full_response.find("]", warning_start)
if warning_end != -1:
print_with_timestamp(f"警告内容: {full_response[warning_start:warning_end+1]}")
except Timeout as e:
print_with_timestamp(f"捕获到超时异常: {e}")
except Exception as e:
print_with_timestamp(f"捕获到异常: {type(e).__name__} - {e}")
finally:
agent.close()
async def run_all_tests():
"""运行所有测试"""
# 测试同步模式
test_sync_stream_with_timeouts()
print("\n" + "-"*50 + "\n")
# 测试回调模式
test_callback_stream_with_timeouts()
print("\n" + "-"*50 + "\n")
# 测试异步模式
await test_async_stream_with_timeouts()
if __name__ == "__main__":
# 运行所有测试
asyncio.run(run_all_tests())