#!/usr/bin/env python # -*- coding: utf-8 -*- import os import sys import logging import asyncio import time from datetime import datetime # 添加项目根目录到路径,确保可以导入核心模块 sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from core.ai_agent import AI_Agent, Timeout # 配置日志 logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', handlers=[logging.StreamHandler()] ) # 从环境变量获取API密钥,或使用默认值 API_KEY = os.environ.get("OPENAI_API_KEY", "your_api_key_here") # 使用API的基础URL BASE_URL = os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1") # 使用的模型名称 MODEL_NAME = os.environ.get("OPENAI_MODEL", "gpt-3.5-turbo") def print_with_timestamp(message, end='\n'): """打印带有时间戳的消息""" timestamp = datetime.now().strftime("%H:%M:%S") print(f"[{timestamp}] {message}", end=end, flush=True) def test_sync_stream_with_timeouts(): """测试同步流式响应模式下的超时处理""" print_with_timestamp("开始测试同步流式响应的超时处理...") # 创建 AI_Agent 实例,设置较短的超时时间以便测试 agent = AI_Agent( base_url=BASE_URL, model_name=MODEL_NAME, api=API_KEY, timeout=10, # API 请求整体超时时间 (秒) max_retries=2, # 最大重试次数 stream_chunk_timeout=5 # 流块超时时间 (秒) ) system_prompt = "你是一个有用的助手。" user_prompt = "请详细描述中国的长城,至少500字。" try: print_with_timestamp("正在生成内容...") start_time = time.time() # 使用同步流式响应方法 response = agent.generate_text_stream( system_prompt=system_prompt, user_prompt=user_prompt ) # 输出完整响应和耗时 print_with_timestamp(f"完成! 耗时: {time.time() - start_time:.2f}秒") # 检查响应中是否包含超时或错误提示 if "[注意:" in response: print_with_timestamp("检测到警告信息:") warning_start = response.find("[注意:") warning_end = response.find("]", warning_start) if warning_end != -1: print_with_timestamp(f"警告内容: {response[warning_start:warning_end+1]}") except Timeout as e: print_with_timestamp(f"捕获到超时异常: {e}") except Exception as e: print_with_timestamp(f"捕获到异常: {type(e).__name__} - {e}") finally: agent.close() def test_callback_stream_with_timeouts(): """测试回调流式响应模式下的超时处理""" print_with_timestamp("开始测试回调流式响应的超时处理...") # 创建 AI_Agent 实例,设置较短的超时时间以便测试 agent = AI_Agent( base_url=BASE_URL, model_name=MODEL_NAME, api=API_KEY, timeout=10, # API 请求整体超时时间 (秒) max_retries=2, # 最大重试次数 stream_chunk_timeout=5 # 流块超时时间 (秒) ) system_prompt = "你是一个有用的助手。" user_prompt = "请详细描述中国的长城,至少500字。" # 定义回调函数 def callback(chunk, accumulated=None): print_with_timestamp(f"收到块: 「{chunk}」", end="") try: print_with_timestamp("正在通过回调生成内容...") start_time = time.time() # 使用回调流式响应方法 response = agent.generate_text_stream_with_callback( system_prompt=system_prompt, user_prompt=user_prompt, temperature=0.7, top_p=0.9, presence_penalty=0.0, callback=callback, accumulate=True # 启用累积模式 ) # 输出完整响应和耗时 print_with_timestamp(f"\n完成! 耗时: {time.time() - start_time:.2f}秒") print_with_timestamp("回调累积的响应:") # 检查响应中是否包含超时或错误提示 if "[注意:" in response: print_with_timestamp("检测到警告信息:") warning_start = response.find("[注意:") warning_end = response.find("]", warning_start) if warning_end != -1: print_with_timestamp(f"警告内容: {response[warning_start:warning_end+1]}") except Timeout as e: print_with_timestamp(f"捕获到超时异常: {e}") except Exception as e: print_with_timestamp(f"捕获到异常: {type(e).__name__} - {e}") finally: agent.close() async def test_async_stream_with_timeouts(): """测试异步流式响应模式下的超时处理""" print_with_timestamp("开始测试异步流式响应的超时处理...") # 创建 AI_Agent 实例,设置较短的超时时间以便测试 agent = AI_Agent( base_url=BASE_URL, model_name=MODEL_NAME, api=API_KEY, timeout=10, # API 请求整体超时时间 (秒) max_retries=2, # 最大重试次数 stream_chunk_timeout=5 # 流块超时时间 (秒) ) system_prompt = "你是一个有用的助手。" user_prompt = "请详细描述中国的长城,至少500字。" try: print_with_timestamp("正在异步生成内容...") start_time = time.time() # 使用异步流式响应方法 full_response = "" async for chunk in agent.async_generate_text_stream( system_prompt=system_prompt, user_prompt=user_prompt ): full_response += chunk print_with_timestamp(f"收到块: 「{chunk}」", end="") # 输出完整响应和耗时 print_with_timestamp(f"\n完成! 耗时: {time.time() - start_time:.2f}秒") # 检查响应中是否包含超时或错误提示 if "[注意:" in full_response: print_with_timestamp("检测到警告信息:") warning_start = full_response.find("[注意:") warning_end = full_response.find("]", warning_start) if warning_end != -1: print_with_timestamp(f"警告内容: {full_response[warning_start:warning_end+1]}") except Timeout as e: print_with_timestamp(f"捕获到超时异常: {e}") except Exception as e: print_with_timestamp(f"捕获到异常: {type(e).__name__} - {e}") finally: agent.close() async def run_all_tests(): """运行所有测试""" # 测试同步模式 test_sync_stream_with_timeouts() print("\n" + "-"*50 + "\n") # 测试回调模式 test_callback_stream_with_timeouts() print("\n" + "-"*50 + "\n") # 测试异步模式 await test_async_stream_with_timeouts() if __name__ == "__main__": # 运行所有测试 asyncio.run(run_all_tests())