198 lines
7.0 KiB
Python
198 lines
7.0 KiB
Python
#!/usr/bin/env python
|
||
# -*- coding: utf-8 -*-
|
||
|
||
import os
|
||
import sys
|
||
import logging
|
||
import asyncio
|
||
import time
|
||
from datetime import datetime
|
||
|
||
# 添加项目根目录到路径,确保可以导入核心模块
|
||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||
|
||
from core.ai_agent import AI_Agent, Timeout
|
||
|
||
# 配置日志
|
||
logging.basicConfig(
|
||
level=logging.INFO,
|
||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||
handlers=[logging.StreamHandler()]
|
||
)
|
||
|
||
# 从环境变量获取API密钥,或使用默认值
|
||
API_KEY = os.environ.get("OPENAI_API_KEY", "your_api_key_here")
|
||
# 使用API的基础URL
|
||
BASE_URL = os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1")
|
||
# 使用的模型名称
|
||
MODEL_NAME = os.environ.get("OPENAI_MODEL", "gpt-3.5-turbo")
|
||
|
||
def print_with_timestamp(message, end='\n'):
|
||
"""打印带有时间戳的消息"""
|
||
timestamp = datetime.now().strftime("%H:%M:%S")
|
||
print(f"[{timestamp}] {message}", end=end, flush=True)
|
||
|
||
def test_sync_stream_with_timeouts():
|
||
"""测试同步流式响应模式下的超时处理"""
|
||
print_with_timestamp("开始测试同步流式响应的超时处理...")
|
||
|
||
# 创建 AI_Agent 实例,设置较短的超时时间以便测试
|
||
agent = AI_Agent(
|
||
base_url=BASE_URL,
|
||
model_name=MODEL_NAME,
|
||
api=API_KEY,
|
||
timeout=10, # API 请求整体超时时间 (秒)
|
||
max_retries=2, # 最大重试次数
|
||
stream_chunk_timeout=5 # 流块超时时间 (秒)
|
||
)
|
||
|
||
system_prompt = "你是一个有用的助手。"
|
||
user_prompt = "请详细描述中国的长城,至少500字。"
|
||
|
||
try:
|
||
print_with_timestamp("正在生成内容...")
|
||
start_time = time.time()
|
||
|
||
# 使用同步流式响应方法
|
||
response = agent.generate_text_stream(
|
||
system_prompt=system_prompt,
|
||
user_prompt=user_prompt
|
||
)
|
||
|
||
# 输出完整响应和耗时
|
||
print_with_timestamp(f"完成! 耗时: {time.time() - start_time:.2f}秒")
|
||
|
||
# 检查响应中是否包含超时或错误提示
|
||
if "[注意:" in response:
|
||
print_with_timestamp("检测到警告信息:")
|
||
warning_start = response.find("[注意:")
|
||
warning_end = response.find("]", warning_start)
|
||
if warning_end != -1:
|
||
print_with_timestamp(f"警告内容: {response[warning_start:warning_end+1]}")
|
||
|
||
except Timeout as e:
|
||
print_with_timestamp(f"捕获到超时异常: {e}")
|
||
except Exception as e:
|
||
print_with_timestamp(f"捕获到异常: {type(e).__name__} - {e}")
|
||
finally:
|
||
agent.close()
|
||
|
||
def test_callback_stream_with_timeouts():
|
||
"""测试回调流式响应模式下的超时处理"""
|
||
print_with_timestamp("开始测试回调流式响应的超时处理...")
|
||
|
||
# 创建 AI_Agent 实例,设置较短的超时时间以便测试
|
||
agent = AI_Agent(
|
||
base_url=BASE_URL,
|
||
model_name=MODEL_NAME,
|
||
api=API_KEY,
|
||
timeout=10, # API 请求整体超时时间 (秒)
|
||
max_retries=2, # 最大重试次数
|
||
stream_chunk_timeout=5 # 流块超时时间 (秒)
|
||
)
|
||
|
||
system_prompt = "你是一个有用的助手。"
|
||
user_prompt = "请详细描述中国的长城,至少500字。"
|
||
|
||
# 定义回调函数
|
||
def callback(chunk, accumulated=None):
|
||
print_with_timestamp(f"收到块: 「{chunk}」", end="")
|
||
|
||
try:
|
||
print_with_timestamp("正在通过回调生成内容...")
|
||
start_time = time.time()
|
||
|
||
# 使用回调流式响应方法
|
||
response = agent.generate_text_stream_with_callback(
|
||
system_prompt=system_prompt,
|
||
user_prompt=user_prompt,
|
||
temperature=0.7,
|
||
top_p=0.9,
|
||
presence_penalty=0.0,
|
||
callback=callback,
|
||
accumulate=True # 启用累积模式
|
||
)
|
||
|
||
# 输出完整响应和耗时
|
||
print_with_timestamp(f"\n完成! 耗时: {time.time() - start_time:.2f}秒")
|
||
print_with_timestamp("回调累积的响应:")
|
||
|
||
# 检查响应中是否包含超时或错误提示
|
||
if "[注意:" in response:
|
||
print_with_timestamp("检测到警告信息:")
|
||
warning_start = response.find("[注意:")
|
||
warning_end = response.find("]", warning_start)
|
||
if warning_end != -1:
|
||
print_with_timestamp(f"警告内容: {response[warning_start:warning_end+1]}")
|
||
|
||
except Timeout as e:
|
||
print_with_timestamp(f"捕获到超时异常: {e}")
|
||
except Exception as e:
|
||
print_with_timestamp(f"捕获到异常: {type(e).__name__} - {e}")
|
||
finally:
|
||
agent.close()
|
||
|
||
async def test_async_stream_with_timeouts():
|
||
"""测试异步流式响应模式下的超时处理"""
|
||
print_with_timestamp("开始测试异步流式响应的超时处理...")
|
||
|
||
# 创建 AI_Agent 实例,设置较短的超时时间以便测试
|
||
agent = AI_Agent(
|
||
base_url=BASE_URL,
|
||
model_name=MODEL_NAME,
|
||
api=API_KEY,
|
||
timeout=10, # API 请求整体超时时间 (秒)
|
||
max_retries=2, # 最大重试次数
|
||
stream_chunk_timeout=5 # 流块超时时间 (秒)
|
||
)
|
||
|
||
system_prompt = "你是一个有用的助手。"
|
||
user_prompt = "请详细描述中国的长城,至少500字。"
|
||
|
||
try:
|
||
print_with_timestamp("正在异步生成内容...")
|
||
start_time = time.time()
|
||
|
||
# 使用异步流式响应方法
|
||
full_response = ""
|
||
async for chunk in agent.async_generate_text_stream(
|
||
system_prompt=system_prompt,
|
||
user_prompt=user_prompt
|
||
):
|
||
full_response += chunk
|
||
print_with_timestamp(f"收到块: 「{chunk}」", end="")
|
||
|
||
# 输出完整响应和耗时
|
||
print_with_timestamp(f"\n完成! 耗时: {time.time() - start_time:.2f}秒")
|
||
|
||
# 检查响应中是否包含超时或错误提示
|
||
if "[注意:" in full_response:
|
||
print_with_timestamp("检测到警告信息:")
|
||
warning_start = full_response.find("[注意:")
|
||
warning_end = full_response.find("]", warning_start)
|
||
if warning_end != -1:
|
||
print_with_timestamp(f"警告内容: {full_response[warning_start:warning_end+1]}")
|
||
|
||
except Timeout as e:
|
||
print_with_timestamp(f"捕获到超时异常: {e}")
|
||
except Exception as e:
|
||
print_with_timestamp(f"捕获到异常: {type(e).__name__} - {e}")
|
||
finally:
|
||
agent.close()
|
||
|
||
async def run_all_tests():
|
||
"""运行所有测试"""
|
||
# 测试同步模式
|
||
test_sync_stream_with_timeouts()
|
||
print("\n" + "-"*50 + "\n")
|
||
|
||
# 测试回调模式
|
||
test_callback_stream_with_timeouts()
|
||
print("\n" + "-"*50 + "\n")
|
||
|
||
# 测试异步模式
|
||
await test_async_stream_with_timeouts()
|
||
|
||
if __name__ == "__main__":
|
||
# 运行所有测试
|
||
asyncio.run(run_all_tests()) |