175 lines
4.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
测试AI_Agent的流式处理方法
此脚本演示TravelContentCreator中AI_Agent类的三种流式输出处理方法:
1. 同步流式响应 (generate_text_stream)
2. 回调式流式响应 (generate_text_stream_with_callback)
3. 异步流式响应 (async_generate_text_stream)
"""
import os
import sys
import asyncio
import time
from pathlib import Path
# Put the project root (the parent of this script's directory) on sys.path so
# the project's packages can be imported regardless of the current directory.
# resolve() is required: before Python 3.9 the __file__ of a directly-run
# script may be relative (e.g. "demo.py"), and Path("demo.py").parent.parent
# collapses to "." instead of the real parent directory.
project_root = str(Path(__file__).resolve().parent.parent)
if project_root not in sys.path:
    sys.path.insert(0, project_root)
from core.ai_agent import AI_Agent
# Example prompts shared by all demos below.
# System prompt (Chinese): casts the model as a professional travel-content
# writing assistant that generates content from the user's prompt.
SYSTEM_PROMPT = """你是一个专业的旅游内容创作助手,请根据用户的提示生成相关内容。"""
# User prompt (Chinese): asks for a travel guide to Taining Ancient Town in
# Fujian (sights, food, best season), limited to 300 characters.
USER_PROMPT = """请为我生成一篇关于福建泰宁古城的旅游攻略包括著名景点、美食推荐和最佳游玩季节。字数控制在300字以内。"""
def print_separator(title, width=50):
    """Print a three-line banner with ``title`` centred between two rules.

    Output is: a blank line, a rule of ``=``, the title padded with ``=``,
    another rule, and a trailing blank line.

    Args:
        title: Text shown in the middle line of the banner.
        width: Banner width in terminal columns. Defaults to 50.
    """
    import unicodedata  # local import: only this helper needs it

    rule = "=" * width
    label = f" {title} "
    # str.center() pads by code-point count, but East Asian wide/fullwidth
    # characters occupy two terminal columns each. Shrink the target by one
    # per wide character so the title line spans the same columns as the rules.
    # Pure-ASCII titles are unaffected (extra_columns == 0).
    extra_columns = sum(
        unicodedata.east_asian_width(ch) in ("W", "F") for ch in label
    )
    print("\n" + rule)
    print(label.center(width - extra_columns, "="))
    print(rule + "\n")
def demo_sync_stream():
    """Demonstrate the synchronous streaming method ``generate_text_stream``.

    Creates an ``AI_Agent``, sends the module-level prompts, then prints the
    returned text and the elapsed time. The agent is closed even when
    generation raises, so a failed run does not leak it.
    """
    print_separator("同步流式响应 (generate_text_stream)")

    agent = AI_Agent(
        base_url="vllm",  # use the local vLLM service
        model_name="qwen2-7b-instruct",  # or any other configured model name
        api="EMPTY",  # vLLM does not need an API key
        timeout=60,  # overall request timeout (seconds)
        stream_chunk_timeout=10,  # per-chunk stream timeout (seconds)
    )
    try:
        print("开始生成内容...")
        # perf_counter() is monotonic, so the measured duration cannot be
        # skewed by wall-clock adjustments during the request.
        start_time = time.perf_counter()

        result = agent.generate_text_stream(
            SYSTEM_PROMPT,
            USER_PROMPT,
            temperature=0.7,
            top_p=0.9,
            presence_penalty=0.0,
        )

        end_time = time.perf_counter()
        print(f"\n\n完整生成内容:\n{result}")
        print(f"\n生成完成! 耗时: {end_time - start_time:.2f}")
    finally:
        # Always release the agent, including on errors and Ctrl-C.
        agent.close()
def demo_callback_stream():
    """Demonstrate ``generate_text_stream_with_callback``.

    Creates an ``AI_Agent`` and passes a callback that echoes each piece of
    content to stdout as it is delivered, then prints the elapsed time. The
    agent is closed even when generation raises.
    """
    print_separator("回调式流式响应 (generate_text_stream_with_callback)")

    agent = AI_Agent(
        base_url="vllm",
        model_name="qwen2-7b-instruct",
        api="EMPTY",
        timeout=60,
        stream_chunk_timeout=10,
    )

    def my_callback(content, is_last=False, is_timeout=False, is_error=False, error=None):
        """Handle one streaming event: echo content and report status flags."""
        if content:
            # Print immediately, without a newline, so output appears live.
            print(content, end="", flush=True)
        if is_last:
            print("\n")
        if is_timeout:
            print("警告: 响应流超时")
        if is_error:
            print(f"错误: {error}")

    try:
        print("开始生成内容...")
        # perf_counter() is monotonic, so the measured duration cannot be
        # skewed by wall-clock adjustments during the request.
        start_time = time.perf_counter()

        # The return value is not needed here: the callback already printed
        # everything as it arrived.
        agent.generate_text_stream_with_callback(
            SYSTEM_PROMPT,
            USER_PROMPT,
            my_callback,
            temperature=0.7,
            top_p=0.9,
            presence_penalty=0.0,
        )

        end_time = time.perf_counter()
        print(f"\n生成完成! 耗时: {end_time - start_time:.2f}")
    finally:
        # Always release the agent, including on errors and Ctrl-C.
        agent.close()
async def demo_async_stream():
    """Demonstrate the asynchronous streaming method ``async_generate_text_stream``.

    Creates an ``AI_Agent``, iterates the async stream and echoes each piece
    of content as it arrives, then prints the elapsed time. Errors raised
    while streaming are reported on stdout rather than propagated. The agent
    is closed on every exit path, including cancellation.
    """
    print_separator("异步流式响应 (async_generate_text_stream)")

    agent = AI_Agent(
        base_url="vllm",
        model_name="qwen2-7b-instruct",
        api="EMPTY",
        timeout=60,
        stream_chunk_timeout=10,
    )
    try:
        print("开始生成内容...")
        # perf_counter() is monotonic, so the measured duration cannot be
        # skewed by wall-clock adjustments during the request.
        start_time = time.perf_counter()

        try:
            async_stream = agent.async_generate_text_stream(
                SYSTEM_PROMPT,
                USER_PROMPT,
                temperature=0.7,
                top_p=0.9,
                presence_penalty=0.0,
            )
            async for content in async_stream:
                # Echo each piece immediately so output appears live.
                print(content, end="", flush=True)
        except Exception as e:
            # Deliberate best-effort: a demo run reports the failure and
            # still prints the timing line below.
            print(f"\n生成过程中出错: {e}")

        end_time = time.perf_counter()
        print(f"\n\n生成完成! 耗时: {end_time - start_time:.2f}")
    finally:
        # `except Exception` does not catch KeyboardInterrupt or
        # asyncio.CancelledError; finally guarantees the agent is released
        # on those paths too.
        agent.close()
async def main():
    """Run every streaming demo in order: synchronous, callback, asynchronous."""
    print("Testing AI_Agent streaming methods...")

    # Demos 1 and 2 are plain blocking calls.
    for blocking_demo in (demo_sync_stream, demo_callback_stream):
        blocking_demo()

    # Demo 3 is a coroutine and must be awaited.
    await demo_async_stream()

    print("\n所有测试完成!")
if __name__ == "__main__":
    # Script entry point: run the async main() to completion.
    asyncio.run(main())