198 lines
7.0 KiB
Python
198 lines
7.0 KiB
Python
|
|
#!/usr/bin/env python
|
|||
|
|
# -*- coding: utf-8 -*-
|
|||
|
|
|
|||
|
|
import os
|
|||
|
|
import sys
|
|||
|
|
import logging
|
|||
|
|
import asyncio
|
|||
|
|
import time
|
|||
|
|
from datetime import datetime
|
|||
|
|
|
|||
|
|
# 添加项目根目录到路径,确保可以导入核心模块
|
|||
|
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|||
|
|
|
|||
|
|
from core.ai_agent import AI_Agent, Timeout
|
|||
|
|
|
|||
|
|
# 配置日志
|
|||
|
|
logging.basicConfig(
|
|||
|
|
level=logging.INFO,
|
|||
|
|
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
|||
|
|
handlers=[logging.StreamHandler()]
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
# 从环境变量获取API密钥,或使用默认值
|
|||
|
|
API_KEY = os.environ.get("OPENAI_API_KEY", "your_api_key_here")
|
|||
|
|
# 使用API的基础URL
|
|||
|
|
BASE_URL = os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1")
|
|||
|
|
# 使用的模型名称
|
|||
|
|
MODEL_NAME = os.environ.get("OPENAI_MODEL", "gpt-3.5-turbo")
|
|||
|
|
|
|||
|
|
def print_with_timestamp(message, end='\n'):
|
|||
|
|
"""打印带有时间戳的消息"""
|
|||
|
|
timestamp = datetime.now().strftime("%H:%M:%S")
|
|||
|
|
print(f"[{timestamp}] {message}", end=end, flush=True)
|
|||
|
|
|
|||
|
|
def test_sync_stream_with_timeouts():
|
|||
|
|
"""测试同步流式响应模式下的超时处理"""
|
|||
|
|
print_with_timestamp("开始测试同步流式响应的超时处理...")
|
|||
|
|
|
|||
|
|
# 创建 AI_Agent 实例,设置较短的超时时间以便测试
|
|||
|
|
agent = AI_Agent(
|
|||
|
|
base_url=BASE_URL,
|
|||
|
|
model_name=MODEL_NAME,
|
|||
|
|
api=API_KEY,
|
|||
|
|
timeout=10, # API 请求整体超时时间 (秒)
|
|||
|
|
max_retries=2, # 最大重试次数
|
|||
|
|
stream_chunk_timeout=5 # 流块超时时间 (秒)
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
system_prompt = "你是一个有用的助手。"
|
|||
|
|
user_prompt = "请详细描述中国的长城,至少500字。"
|
|||
|
|
|
|||
|
|
try:
|
|||
|
|
print_with_timestamp("正在生成内容...")
|
|||
|
|
start_time = time.time()
|
|||
|
|
|
|||
|
|
# 使用同步流式响应方法
|
|||
|
|
response = agent.generate_text_stream(
|
|||
|
|
system_prompt=system_prompt,
|
|||
|
|
user_prompt=user_prompt
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
# 输出完整响应和耗时
|
|||
|
|
print_with_timestamp(f"完成! 耗时: {time.time() - start_time:.2f}秒")
|
|||
|
|
|
|||
|
|
# 检查响应中是否包含超时或错误提示
|
|||
|
|
if "[注意:" in response:
|
|||
|
|
print_with_timestamp("检测到警告信息:")
|
|||
|
|
warning_start = response.find("[注意:")
|
|||
|
|
warning_end = response.find("]", warning_start)
|
|||
|
|
if warning_end != -1:
|
|||
|
|
print_with_timestamp(f"警告内容: {response[warning_start:warning_end+1]}")
|
|||
|
|
|
|||
|
|
except Timeout as e:
|
|||
|
|
print_with_timestamp(f"捕获到超时异常: {e}")
|
|||
|
|
except Exception as e:
|
|||
|
|
print_with_timestamp(f"捕获到异常: {type(e).__name__} - {e}")
|
|||
|
|
finally:
|
|||
|
|
agent.close()
|
|||
|
|
|
|||
|
|
def test_callback_stream_with_timeouts():
|
|||
|
|
"""测试回调流式响应模式下的超时处理"""
|
|||
|
|
print_with_timestamp("开始测试回调流式响应的超时处理...")
|
|||
|
|
|
|||
|
|
# 创建 AI_Agent 实例,设置较短的超时时间以便测试
|
|||
|
|
agent = AI_Agent(
|
|||
|
|
base_url=BASE_URL,
|
|||
|
|
model_name=MODEL_NAME,
|
|||
|
|
api=API_KEY,
|
|||
|
|
timeout=10, # API 请求整体超时时间 (秒)
|
|||
|
|
max_retries=2, # 最大重试次数
|
|||
|
|
stream_chunk_timeout=5 # 流块超时时间 (秒)
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
system_prompt = "你是一个有用的助手。"
|
|||
|
|
user_prompt = "请详细描述中国的长城,至少500字。"
|
|||
|
|
|
|||
|
|
# 定义回调函数
|
|||
|
|
def callback(chunk, accumulated=None):
|
|||
|
|
print_with_timestamp(f"收到块: 「{chunk}」", end="")
|
|||
|
|
|
|||
|
|
try:
|
|||
|
|
print_with_timestamp("正在通过回调生成内容...")
|
|||
|
|
start_time = time.time()
|
|||
|
|
|
|||
|
|
# 使用回调流式响应方法
|
|||
|
|
response = agent.generate_text_stream_with_callback(
|
|||
|
|
system_prompt=system_prompt,
|
|||
|
|
user_prompt=user_prompt,
|
|||
|
|
temperature=0.7,
|
|||
|
|
top_p=0.9,
|
|||
|
|
presence_penalty=0.0,
|
|||
|
|
callback=callback,
|
|||
|
|
accumulate=True # 启用累积模式
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
# 输出完整响应和耗时
|
|||
|
|
print_with_timestamp(f"\n完成! 耗时: {time.time() - start_time:.2f}秒")
|
|||
|
|
print_with_timestamp("回调累积的响应:")
|
|||
|
|
|
|||
|
|
# 检查响应中是否包含超时或错误提示
|
|||
|
|
if "[注意:" in response:
|
|||
|
|
print_with_timestamp("检测到警告信息:")
|
|||
|
|
warning_start = response.find("[注意:")
|
|||
|
|
warning_end = response.find("]", warning_start)
|
|||
|
|
if warning_end != -1:
|
|||
|
|
print_with_timestamp(f"警告内容: {response[warning_start:warning_end+1]}")
|
|||
|
|
|
|||
|
|
except Timeout as e:
|
|||
|
|
print_with_timestamp(f"捕获到超时异常: {e}")
|
|||
|
|
except Exception as e:
|
|||
|
|
print_with_timestamp(f"捕获到异常: {type(e).__name__} - {e}")
|
|||
|
|
finally:
|
|||
|
|
agent.close()
|
|||
|
|
|
|||
|
|
async def test_async_stream_with_timeouts():
|
|||
|
|
"""测试异步流式响应模式下的超时处理"""
|
|||
|
|
print_with_timestamp("开始测试异步流式响应的超时处理...")
|
|||
|
|
|
|||
|
|
# 创建 AI_Agent 实例,设置较短的超时时间以便测试
|
|||
|
|
agent = AI_Agent(
|
|||
|
|
base_url=BASE_URL,
|
|||
|
|
model_name=MODEL_NAME,
|
|||
|
|
api=API_KEY,
|
|||
|
|
timeout=10, # API 请求整体超时时间 (秒)
|
|||
|
|
max_retries=2, # 最大重试次数
|
|||
|
|
stream_chunk_timeout=5 # 流块超时时间 (秒)
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
system_prompt = "你是一个有用的助手。"
|
|||
|
|
user_prompt = "请详细描述中国的长城,至少500字。"
|
|||
|
|
|
|||
|
|
try:
|
|||
|
|
print_with_timestamp("正在异步生成内容...")
|
|||
|
|
start_time = time.time()
|
|||
|
|
|
|||
|
|
# 使用异步流式响应方法
|
|||
|
|
full_response = ""
|
|||
|
|
async for chunk in agent.async_generate_text_stream(
|
|||
|
|
system_prompt=system_prompt,
|
|||
|
|
user_prompt=user_prompt
|
|||
|
|
):
|
|||
|
|
full_response += chunk
|
|||
|
|
print_with_timestamp(f"收到块: 「{chunk}」", end="")
|
|||
|
|
|
|||
|
|
# 输出完整响应和耗时
|
|||
|
|
print_with_timestamp(f"\n完成! 耗时: {time.time() - start_time:.2f}秒")
|
|||
|
|
|
|||
|
|
# 检查响应中是否包含超时或错误提示
|
|||
|
|
if "[注意:" in full_response:
|
|||
|
|
print_with_timestamp("检测到警告信息:")
|
|||
|
|
warning_start = full_response.find("[注意:")
|
|||
|
|
warning_end = full_response.find("]", warning_start)
|
|||
|
|
if warning_end != -1:
|
|||
|
|
print_with_timestamp(f"警告内容: {full_response[warning_start:warning_end+1]}")
|
|||
|
|
|
|||
|
|
except Timeout as e:
|
|||
|
|
print_with_timestamp(f"捕获到超时异常: {e}")
|
|||
|
|
except Exception as e:
|
|||
|
|
print_with_timestamp(f"捕获到异常: {type(e).__name__} - {e}")
|
|||
|
|
finally:
|
|||
|
|
agent.close()
|
|||
|
|
|
|||
|
|
async def run_all_tests():
|
|||
|
|
"""运行所有测试"""
|
|||
|
|
# 测试同步模式
|
|||
|
|
test_sync_stream_with_timeouts()
|
|||
|
|
print("\n" + "-"*50 + "\n")
|
|||
|
|
|
|||
|
|
# 测试回调模式
|
|||
|
|
test_callback_stream_with_timeouts()
|
|||
|
|
print("\n" + "-"*50 + "\n")
|
|||
|
|
|
|||
|
|
# 测试异步模式
|
|||
|
|
await test_async_stream_with_timeouts()
|
|||
|
|
|
|||
|
|
if __name__ == "__main__":
|
|||
|
|
# 运行所有测试
|
|||
|
|
asyncio.run(run_all_tests())
|