TravelContentCreator/examples/demo_robust_streaming.py

315 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
演示改进后的流式处理超时机制
此脚本展示TravelContentCreator中AI_Agent类中优化后的流式处理超时机制:
1. 更智能的流块超时检测 - 只在收到初始块后才检测流块超时
2. 首次响应超时检测 - 检测是否在合理时间内收到第一个响应块
3. 全局超时保护 - 防止整体处理时间过长
"""
import os
import sys
import asyncio
import time
import logging
import random
from pathlib import Path
from datetime import datetime
# Make the directory two levels above this file importable, so that the
# project-local ``core`` package resolves when this script is run directly.
project_root = str(Path(__file__).parent.parent)
if project_root not in sys.path:
    sys.path[0:0] = [project_root]  # prepend so the local package takes priority

# This import must stay after the sys.path adjustment above.
# NOTE(review): ``Timeout`` is not referenced anywhere in this script.
from core.ai_agent import AI_Agent, Timeout

# Console logging: time of day, level, message.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%H:%M:%S',
    level=logging.INFO,
)
# Prompts shared by all demos.
# System prompt: frames the assistant as a travel-content writer.
SYSTEM_PROMPT = "你是一个专业的旅游内容创作助手,请根据用户的提示生成相关内容。"

# Short request: a brief introduction to West Lake, Hangzhou (at most 200 characters).
NORMAL_PROMPT = "请生成一段关于杭州西湖的简短介绍不超过200字。"

# Long, multi-section request (a 7-day Yunnan itinerary) that is likely to take
# noticeably longer to generate. Assembled line by line; every line, including
# the last one, ends with "\n".
LONG_PROMPT = (
    "请按照以下格式为我创建一份详尽的云南七日旅游攻略:\n"
    "1. 每日行程计划(早中晚三餐和活动)\n"
    "2. 每个景点的历史背景和特色\n"
    "3. 当地特色美食推荐和品尝地点\n"
    "4. 交通建议(包括景点间如何移动)\n"
    "5. 住宿推荐(每个地区的优质酒店或民宿)\n"
    "6. 注意事项(天气、海拔、装备等)\n"
    "7. 购物指南(值得购买的纪念品和特产)\n"
    "请确保涵盖昆明、大理、丽江和香格里拉等主要旅游地点,并考虑季节特点和当地习俗。提供实用且详细的信息。\n"
)
def print_separator(title):
    """Print a section banner.

    Output is: a blank line, a 60-character '=' rule, the title padded with
    one space on each side and centered in '=' fill to 60 columns, a second
    rule, and a trailing blank line.
    """
    width = 60
    rule = "=" * width
    heading = f" {title} ".center(width, "=")
    print("\n" + rule, heading, rule + "\n", sep="\n")
def print_with_timestamp(message, end='\n'):
    """Print *message* prefixed with the local time as ``[HH:MM:SS.mmm]``.

    Args:
        message: Value to print; it is interpolated with an f-string.
        end: Line terminator forwarded to ``print``. Callers pass a carriage
            return to keep overwriting a single status line.

    Output is flushed immediately so progress messages appear in real time.
    """
    now = datetime.now()
    millis = now.microsecond // 1000  # truncate microseconds to milliseconds
    stamp = f"{now:%H:%M:%S}.{millis:03d}"
    print(f"[{stamp}] {message}", end=end, flush=True)
def demo_sync_stream():
    """Demonstrate timeout handling of the synchronous streaming API.

    Creates an ``AI_Agent`` with deliberately short timeouts (endpoint, model
    and key come from the ``API_BASE``, ``MODEL_NAME`` and ``API_KEY``
    environment variables, with placeholder defaults). It calls
    ``generate_text_stream`` with ``NORMAL_PROMPT``, then prints the elapsed
    time, the result length, and the first ``[注意:...]`` warning marker found
    in the result, if any. Any ``Exception`` from the call is printed, not
    propagated. The agent is always closed, including on ``KeyboardInterrupt``.
    """
    print_separator("同步流式响应的超时机制")
    # Short timeouts make the timeout paths easy to trigger in a demo.
    agent = AI_Agent(
        base_url=os.environ.get("API_BASE", "https://api.openai.com/v1"),
        model_name=os.environ.get("MODEL_NAME", "gpt-3.5-turbo"),
        api=os.environ.get("API_KEY", "your_api_key"),
        timeout=15,  # 15 s request timeout
        stream_chunk_timeout=5,  # 5 s stream-chunk timeout
        max_retries=1,  # minimal retries to keep the demo short
    )
    try:
        print_with_timestamp("开始生成内容(同步流式方法)...")
        start_time = time.time()
        result = agent.generate_text_stream(
            SYSTEM_PROMPT,
            NORMAL_PROMPT,
            temperature=0.7,
            top_p=0.9,
            presence_penalty=0.0,
        )
        end_time = time.time()
        print_with_timestamp(f"生成完成! 耗时: {end_time - start_time:.2f}")
        print_with_timestamp(f"内容长度: {len(result)} 字符")
        # Report a timeout/error notice that the agent may have embedded in the text.
        if "[注意:" in result:
            print_with_timestamp("检测到警告信息:")
            warning_start = result.find("[注意:")
            warning_end = result.find("]", warning_start)
            if warning_end != -1:
                print_with_timestamp(f"警告内容: {result[warning_start:warning_end+1]}")
    except Exception as e:
        print_with_timestamp(f"生成过程中出错: {type(e).__name__} - {e}")
    finally:
        # Release the agent even for BaseExceptions (e.g. KeyboardInterrupt)
        # that the ``except Exception`` clause intentionally lets through.
        agent.close()
def demo_callback_stream():
    """Demonstrate timeout handling of the callback-based streaming API.

    Creates an ``AI_Agent`` with short timeouts and calls
    ``generate_text_stream_with_callback`` on ``LONG_PROMPT`` with
    ``accumulate=True``. A callback overwrites one status line with a
    preview of each chunk. Afterwards it prints the elapsed time, the
    result length, and the first ``[注意:...]`` warning marker found, if any.
    Any ``Exception`` from the call is printed, not propagated. The agent is
    always closed.
    """
    print_separator("回调流式响应的超时机制")
    agent = AI_Agent(
        base_url=os.environ.get("API_BASE", "https://api.openai.com/v1"),
        model_name=os.environ.get("MODEL_NAME", "gpt-3.5-turbo"),
        api=os.environ.get("API_KEY", "your_api_key"),
        timeout=15,
        stream_chunk_timeout=5,
        max_retries=1,
    )

    def callback(chunk, accumulated=None):
        """Show a short preview of each non-empty chunk on a single status line.

        ``accumulated`` is unused here; presumably the agent passes it when
        ``accumulate=True`` (confirm against core.ai_agent).
        """
        if chunk:
            # "\r" so the next preview overwrites this one in place.
            print_with_timestamp(f"收到块({len(chunk)}字符): {chunk[:10]}...", end="\r")

    try:
        print_with_timestamp("开始生成内容(回调流式方法)...")
        start_time = time.time()
        result = agent.generate_text_stream_with_callback(
            SYSTEM_PROMPT,
            LONG_PROMPT,  # the longer prompt exercises the chunk timeout more
            temperature=0.7,
            top_p=0.9,
            presence_penalty=0.0,
            callback=callback,
            accumulate=True,
        )
        end_time = time.time()
        # Terminate the "\r" preview line first. Otherwise the timestamp below
        # would be printed over stale preview text.
        print()
        print_with_timestamp(f"生成完成! 耗时: {end_time - start_time:.2f}")
        print_with_timestamp(f"内容长度: {len(result)} 字符")
        # Report a timeout/error notice that the agent may have embedded in the text.
        if "[注意:" in result:
            print_with_timestamp("检测到警告信息:")
            warning_start = result.find("[注意:")
            warning_end = result.find("]", warning_start)
            if warning_end != -1:
                print_with_timestamp(f"警告内容: {result[warning_start:warning_end+1]}")
    except Exception as e:
        print()  # a "\r" preview line may still be pending
        print_with_timestamp(f"生成过程中出错: {type(e).__name__} - {e}")
    finally:
        agent.close()
async def demo_async_stream():
    """Demonstrate timeout handling of the async streaming API.

    Creates an ``AI_Agent`` with short timeouts and iterates
    ``async_generate_text_stream`` over ``NORMAL_PROMPT``. It accumulates the
    text and reports progress every 5 chunks. Afterwards it prints the elapsed
    time, the chunk count, the content length, and the first ``[注意:...]``
    warning marker found, if any. A stream ``Exception`` is printed and the
    summary is still shown. The agent is always closed, including when the
    coroutine is cancelled (``asyncio.CancelledError`` is not an
    ``Exception`` subclass).
    """
    print_separator("异步流式响应的超时机制")
    agent = AI_Agent(
        base_url=os.environ.get("API_BASE", "https://api.openai.com/v1"),
        model_name=os.environ.get("MODEL_NAME", "gpt-3.5-turbo"),
        api=os.environ.get("API_KEY", "your_api_key"),
        timeout=15,
        stream_chunk_timeout=5,
        max_retries=1,
    )
    try:
        print_with_timestamp("开始生成内容(异步流式方法)...")
        start_time = time.time()
        full_response = ""
        chunk_count = 0
        try:
            async_stream = agent.async_generate_text_stream(
                SYSTEM_PROMPT,
                NORMAL_PROMPT,
                temperature=0.7,
                top_p=0.9,
                presence_penalty=0.0,
            )
            async for content in async_stream:
                full_response += content
                chunk_count += 1
                # Progress report every 5 chunks.
                if chunk_count % 5 == 0:
                    print_with_timestamp(f"已收到 {chunk_count} 个块,当前内容长度: {len(full_response)}字符")
        except Exception as e:
            print_with_timestamp(f"生成过程中出错: {type(e).__name__} - {e}")
        end_time = time.time()
        print_with_timestamp(f"生成完成! 耗时: {end_time - start_time:.2f}")
        print_with_timestamp(f"共收到 {chunk_count} 个块,内容长度: {len(full_response)}字符")
        # Report a timeout/error notice that the agent may have embedded in the text.
        if "[注意:" in full_response:
            print_with_timestamp("检测到警告信息:")
            warning_start = full_response.find("[注意:")
            warning_end = full_response.find("]", warning_start)
            if warning_end != -1:
                print_with_timestamp(f"警告内容: {full_response[warning_start:warning_end+1]}")
    finally:
        agent.close()
async def simulate_delayed_response():
    """Simulate a slow consumer of the async stream.

    Uses an ``AI_Agent`` with a 3-second stream-chunk timeout. For every chunk
    received it sleeps a random 0.1-4.0 s before pulling the next one, so some
    pauses exceed the configured chunk timeout. Whether that actually trips the
    agent's timeout depends on how ``async_generate_text_stream`` measures it
    (NOTE(review): confirm against core.ai_agent). It prints progress, the
    elapsed time, chunk statistics, and the first ``[注意:...]`` marker found,
    if any. The agent is always closed, including on cancellation.
    """
    print_separator("模拟延迟响应")
    # Very short timeouts so the slow consumer below can provoke them.
    agent = AI_Agent(
        base_url=os.environ.get("API_BASE", "https://api.openai.com/v1"),
        model_name=os.environ.get("MODEL_NAME", "gpt-3.5-turbo"),
        api=os.environ.get("API_KEY", "your_api_key"),
        timeout=10,
        stream_chunk_timeout=3,  # very short stream-chunk timeout
        max_retries=1,
    )

    class DelayedProcessor:
        """Record incoming chunks and add an artificial per-chunk delay."""

        def __init__(self):
            self.chunks = []  # raw chunks in arrival order
            self.total_chars = 0  # running total of len(chunk)

        async def process_chunk(self, chunk):
            """Store *chunk*, sleep 0.1-4.0 s, and return it wrapped in a marker."""
            self.chunks.append(chunk)
            self.total_chars += len(chunk)
            # "\r" keeps overwriting one status line with the latest chunk.
            print_with_timestamp(f"收到块 #{len(self.chunks)}, 内容: {chunk[:10]}...", end="\r")
            # 0.1-4.0 s: sometimes longer than the 3 s stream_chunk_timeout.
            delay = random.uniform(0.1, 4.0)
            if len(self.chunks) % 5 == 0:
                # Finish the "\r" status line before printing a separate note.
                print()
                print_with_timestamp(f"模拟处理延迟: {delay:.2f}")
            await asyncio.sleep(delay)
            return f"[处理完成: {chunk}]"

    processor = DelayedProcessor()
    try:
        print_with_timestamp("开始生成内容(带处理延迟)...")
        start_time = time.time()
        full_response = ""
        try:
            async_stream = agent.async_generate_text_stream(
                SYSTEM_PROMPT,
                NORMAL_PROMPT,
                temperature=0.7,
                top_p=0.9,
                presence_penalty=0.0,
            )
            # The next chunk is only requested after processing finishes.
            async for content in async_stream:
                processed = await processor.process_chunk(content)
                full_response += processed
        except Exception as e:
            print()  # leave the pending "\r" status line intact
            print_with_timestamp(f"生成过程中出错: {type(e).__name__} - {e}")
        end_time = time.time()
        print()  # terminate the "\r" status line (blank separator after an error)
        print_with_timestamp(f"生成完成! 耗时: {end_time - start_time:.2f}")
        print_with_timestamp(f"共处理 {len(processor.chunks)} 个块,总字符数: {processor.total_chars}")
        # Report a timeout/error notice that the agent may have embedded in the text.
        if "[注意:" in full_response:
            print_with_timestamp("检测到警告信息:")
            warning_start = full_response.find("[注意:")
            warning_end = full_response.find("]", warning_start)
            if warning_end != -1:
                print_with_timestamp(f"警告内容: {full_response[warning_start:warning_end+1]}")
    finally:
        agent.close()
async def main():
    """Run the four streaming-timeout demos one after another.

    The two synchronous demos run directly on the event-loop thread and block
    it while they execute. That is acceptable here because ``main`` awaits
    each demo sequentially and schedules no other tasks.
    """
    print_with_timestamp("测试改进后的流式处理超时机制...")
    # 1. Synchronous streaming
    demo_sync_stream()
    # 2. Callback-based streaming
    demo_callback_stream()
    # 3. Async streaming
    await demo_async_stream()
    # 4. Slow consumer simulation
    await simulate_delayed_response()
    # Print the blank separator on its own so the timestamp stays on the same
    # line as the completion message instead of dangling on an empty line.
    print()
    print_with_timestamp("所有测试完成!")


if __name__ == "__main__":
    # Run the async entry point on a fresh event loop.
    asyncio.run(main())