315 lines
11 KiB
Python
315 lines
11 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
|
||
"""
|
||
演示改进后的流式处理超时机制
|
||
|
||
此脚本展示TravelContentCreator中AI_Agent类中优化后的流式处理超时机制:
|
||
1. 更智能的流块超时检测 - 只在收到初始块后才检测流块超时
|
||
2. 首次响应超时检测 - 检测是否在合理时间内收到第一个响应块
|
||
3. 全局超时保护 - 防止整体处理时间过长
|
||
"""
|
||
|
||
import os
|
||
import sys
|
||
import asyncio
|
||
import time
|
||
import logging
|
||
import random
|
||
from pathlib import Path
|
||
from datetime import datetime
|
||
|
||
# 添加项目根目录到Python路径
|
||
project_root = str(Path(__file__).parent.parent)
|
||
if project_root not in sys.path:
|
||
sys.path.insert(0, project_root)
|
||
|
||
from core.ai_agent import AI_Agent, Timeout
|
||
|
||
# 配置日志
|
||
logging.basicConfig(
|
||
level=logging.INFO,
|
||
format='%(asctime)s - %(levelname)s - %(message)s',
|
||
datefmt='%H:%M:%S'
|
||
)
|
||
|
||
# 示例提示词
|
||
SYSTEM_PROMPT = "你是一个专业的旅游内容创作助手,请根据用户的提示生成相关内容。"
|
||
# 普通提示词
|
||
NORMAL_PROMPT = "请生成一段关于杭州西湖的简短介绍,不超过200字。"
|
||
# 长提示词,可能需要更长处理时间
|
||
LONG_PROMPT = """请按照以下格式为我创建一份详尽的云南七日旅游攻略:
|
||
1. 每日行程计划(早中晚三餐和活动)
|
||
2. 每个景点的历史背景和特色
|
||
3. 当地特色美食推荐和品尝地点
|
||
4. 交通建议(包括景点间如何移动)
|
||
5. 住宿推荐(每个地区的优质酒店或民宿)
|
||
6. 注意事项(天气、海拔、装备等)
|
||
7. 购物指南(值得购买的纪念品和特产)
|
||
|
||
请确保涵盖昆明、大理、丽江和香格里拉等主要旅游地点,并考虑季节特点和当地习俗。提供实用且详细的信息。
|
||
"""
|
||
|
||
def print_separator(title):
|
||
"""打印分隔线和标题"""
|
||
print("\n" + "="*60)
|
||
print(f" {title} ".center(60, "="))
|
||
print("="*60 + "\n")
|
||
|
||
def print_with_timestamp(message, end='\n'):
|
||
"""打印带有时间戳的消息"""
|
||
timestamp = datetime.now().strftime("%H:%M:%S.%f")[:-3]
|
||
print(f"[{timestamp}] {message}", end=end, flush=True)
|
||
|
||
def demo_sync_stream():
|
||
"""演示同步流式响应方法的改进超时机制"""
|
||
print_separator("同步流式响应的超时机制")
|
||
|
||
# 创建AI_Agent实例,设置较短的超时便于测试
|
||
agent = AI_Agent(
|
||
base_url=os.environ.get("API_BASE", "https://api.openai.com/v1"),
|
||
model_name=os.environ.get("MODEL_NAME", "gpt-3.5-turbo"),
|
||
api=os.environ.get("API_KEY", "your_api_key"),
|
||
timeout=15, # 15秒请求超时
|
||
stream_chunk_timeout=5, # 5秒流块超时
|
||
max_retries=1 # 最小重试减少测试时间
|
||
)
|
||
|
||
print_with_timestamp("开始生成内容(同步流式方法)...")
|
||
start_time = time.time()
|
||
|
||
try:
|
||
# 使用同步流式方法
|
||
result = agent.generate_text_stream(
|
||
SYSTEM_PROMPT,
|
||
NORMAL_PROMPT,
|
||
temperature=0.7,
|
||
top_p=0.9,
|
||
presence_penalty=0.0
|
||
)
|
||
|
||
end_time = time.time()
|
||
|
||
print_with_timestamp(f"生成完成! 耗时: {end_time - start_time:.2f}秒")
|
||
print_with_timestamp(f"内容长度: {len(result)} 字符")
|
||
|
||
# 检查是否包含超时或错误提示
|
||
if "[注意:" in result:
|
||
print_with_timestamp("检测到警告信息:")
|
||
warning_start = result.find("[注意:")
|
||
warning_end = result.find("]", warning_start)
|
||
if warning_end != -1:
|
||
print_with_timestamp(f"警告内容: {result[warning_start:warning_end+1]}")
|
||
|
||
except Exception as e:
|
||
print_with_timestamp(f"生成过程中出错: {type(e).__name__} - {e}")
|
||
|
||
# 关闭agent
|
||
agent.close()
|
||
|
||
def demo_callback_stream():
|
||
"""演示回调流式响应方法的改进超时机制"""
|
||
print_separator("回调流式响应的超时机制")
|
||
|
||
# 创建AI_Agent实例
|
||
agent = AI_Agent(
|
||
base_url=os.environ.get("API_BASE", "https://api.openai.com/v1"),
|
||
model_name=os.environ.get("MODEL_NAME", "gpt-3.5-turbo"),
|
||
api=os.environ.get("API_KEY", "your_api_key"),
|
||
timeout=15,
|
||
stream_chunk_timeout=5,
|
||
max_retries=1
|
||
)
|
||
|
||
# 定义回调函数
|
||
def callback(chunk, accumulated=None):
|
||
"""处理流式响应的回调函数"""
|
||
if chunk:
|
||
# 实时打印内容,不换行
|
||
print_with_timestamp(f"收到块({len(chunk)}字符): {chunk[:10]}...", end="\r")
|
||
|
||
print_with_timestamp("开始生成内容(回调流式方法)...")
|
||
start_time = time.time()
|
||
|
||
try:
|
||
# 使用回调流式方法
|
||
result = agent.generate_text_stream_with_callback(
|
||
SYSTEM_PROMPT,
|
||
LONG_PROMPT, # 使用更长的提示
|
||
temperature=0.7,
|
||
top_p=0.9,
|
||
presence_penalty=0.0,
|
||
callback=callback,
|
||
accumulate=True # 启用累积模式
|
||
)
|
||
|
||
end_time = time.time()
|
||
print_with_timestamp(f"\n生成完成! 耗时: {end_time - start_time:.2f}秒")
|
||
print_with_timestamp(f"内容长度: {len(result)} 字符")
|
||
|
||
# 检查是否包含超时或错误提示
|
||
if "[注意:" in result:
|
||
print_with_timestamp("检测到警告信息:")
|
||
warning_start = result.find("[注意:")
|
||
warning_end = result.find("]", warning_start)
|
||
if warning_end != -1:
|
||
print_with_timestamp(f"警告内容: {result[warning_start:warning_end+1]}")
|
||
|
||
except Exception as e:
|
||
print_with_timestamp(f"生成过程中出错: {type(e).__name__} - {e}")
|
||
|
||
# 关闭agent
|
||
agent.close()
|
||
|
||
async def demo_async_stream():
|
||
"""演示异步流式响应方法的改进超时机制"""
|
||
print_separator("异步流式响应的超时机制")
|
||
|
||
# 创建AI_Agent实例
|
||
agent = AI_Agent(
|
||
base_url=os.environ.get("API_BASE", "https://api.openai.com/v1"),
|
||
model_name=os.environ.get("MODEL_NAME", "gpt-3.5-turbo"),
|
||
api=os.environ.get("API_KEY", "your_api_key"),
|
||
timeout=15,
|
||
stream_chunk_timeout=5,
|
||
max_retries=1
|
||
)
|
||
|
||
print_with_timestamp("开始生成内容(异步流式方法)...")
|
||
start_time = time.time()
|
||
full_response = ""
|
||
chunk_count = 0
|
||
|
||
try:
|
||
# 使用异步流式方法
|
||
async_stream = agent.async_generate_text_stream(
|
||
SYSTEM_PROMPT,
|
||
NORMAL_PROMPT,
|
||
temperature=0.7,
|
||
top_p=0.9,
|
||
presence_penalty=0.0
|
||
)
|
||
|
||
# 异步迭代流
|
||
async for content in async_stream:
|
||
# 累积完整响应
|
||
full_response += content
|
||
chunk_count += 1
|
||
|
||
# 每收到5个块打印一次进度
|
||
if chunk_count % 5 == 0:
|
||
print_with_timestamp(f"已收到 {chunk_count} 个块,当前内容长度: {len(full_response)}字符")
|
||
|
||
except Exception as e:
|
||
print_with_timestamp(f"生成过程中出错: {type(e).__name__} - {e}")
|
||
|
||
end_time = time.time()
|
||
print_with_timestamp(f"生成完成! 耗时: {end_time - start_time:.2f}秒")
|
||
print_with_timestamp(f"共收到 {chunk_count} 个块,内容长度: {len(full_response)}字符")
|
||
|
||
# 检查是否包含超时或错误提示
|
||
if "[注意:" in full_response:
|
||
print_with_timestamp("检测到警告信息:")
|
||
warning_start = full_response.find("[注意:")
|
||
warning_end = full_response.find("]", warning_start)
|
||
if warning_end != -1:
|
||
print_with_timestamp(f"警告内容: {full_response[warning_start:warning_end+1]}")
|
||
|
||
# 关闭agent
|
||
agent.close()
|
||
|
||
async def simulate_delayed_response():
|
||
"""模拟延迟响应的场景"""
|
||
print_separator("模拟延迟响应")
|
||
|
||
# 创建带有非常短超时的AI_Agent实例
|
||
agent = AI_Agent(
|
||
base_url=os.environ.get("API_BASE", "https://api.openai.com/v1"),
|
||
model_name=os.environ.get("MODEL_NAME", "gpt-3.5-turbo"),
|
||
api=os.environ.get("API_KEY", "your_api_key"),
|
||
timeout=10,
|
||
stream_chunk_timeout=3, # 非常短的流块超时
|
||
max_retries=1
|
||
)
|
||
|
||
# 自定义处理器模拟延迟
|
||
class DelayedProcessor:
|
||
def __init__(self):
|
||
self.chunks = []
|
||
self.total_chars = 0
|
||
|
||
async def process_chunk(self, chunk):
|
||
self.chunks.append(chunk)
|
||
self.total_chars += len(chunk)
|
||
# 打印进度
|
||
print_with_timestamp(f"收到块 #{len(self.chunks)}, 内容: {chunk[:10]}...", end="\r")
|
||
|
||
# 随机延迟模拟处理时间
|
||
delay = random.uniform(0.1, 4.0) # 0.1-4秒,有时会超过超时设置
|
||
if len(self.chunks) % 5 == 0:
|
||
print_with_timestamp(f"\n模拟处理延迟: {delay:.2f}秒")
|
||
await asyncio.sleep(delay)
|
||
|
||
# 返回处理后的块
|
||
return f"[处理完成: {chunk}]"
|
||
|
||
processor = DelayedProcessor()
|
||
print_with_timestamp("开始生成内容(带处理延迟)...")
|
||
start_time = time.time()
|
||
full_response = ""
|
||
|
||
try:
|
||
# 使用异步流式方法
|
||
async_stream = agent.async_generate_text_stream(
|
||
SYSTEM_PROMPT,
|
||
NORMAL_PROMPT,
|
||
temperature=0.7,
|
||
top_p=0.9,
|
||
presence_penalty=0.0
|
||
)
|
||
|
||
# 异步迭代流,添加处理延迟
|
||
async for content in async_stream:
|
||
# 模拟处理延迟
|
||
processed = await processor.process_chunk(content)
|
||
full_response += processed
|
||
|
||
except Exception as e:
|
||
print_with_timestamp(f"\n生成过程中出错: {type(e).__name__} - {e}")
|
||
|
||
end_time = time.time()
|
||
print_with_timestamp(f"\n生成完成! 耗时: {end_time - start_time:.2f}秒")
|
||
print_with_timestamp(f"共处理 {len(processor.chunks)} 个块,总字符数: {processor.total_chars}")
|
||
|
||
# 检查是否包含超时或错误提示
|
||
if "[注意:" in full_response:
|
||
print_with_timestamp("检测到警告信息:")
|
||
warning_start = full_response.find("[注意:")
|
||
warning_end = full_response.find("]", warning_start)
|
||
if warning_end != -1:
|
||
print_with_timestamp(f"警告内容: {full_response[warning_start:warning_end+1]}")
|
||
|
||
# 关闭agent
|
||
agent.close()
|
||
|
||
async def main():
|
||
"""主函数"""
|
||
print_with_timestamp("测试改进后的流式处理超时机制...")
|
||
|
||
# 1. 测试同步流式响应
|
||
demo_sync_stream()
|
||
|
||
# 2. 测试回调流式响应
|
||
demo_callback_stream()
|
||
|
||
# 3. 测试异步流式响应
|
||
await demo_async_stream()
|
||
|
||
# 4. 模拟延迟响应
|
||
await simulate_delayed_response()
|
||
|
||
print_with_timestamp("\n所有测试完成!")
|
||
|
||
if __name__ == "__main__":
|
||
# 运行异步主函数
|
||
asyncio.run(main()) |