#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
测试AI_Agent的流式处理方法

此脚本演示TravelContentCreator中AI_Agent类的三种流式输出处理方法:
1. 同步流式响应 (generate_text_stream)
2. 回调式流式响应 (generate_text_stream_with_callback)
3. 异步流式响应 (async_generate_text_stream)
"""

import asyncio
import os
import sys
import time
import unicodedata
from pathlib import Path

# Make the project root (the parent of this script's directory) importable so
# that the project-local `core` package can be imported when this script is
# run directly.
# resolve() turns a possibly-relative or symlinked __file__ into an absolute,
# canonical path, so the sys.path entry stays valid even if the working
# directory changes later.
project_root = str(Path(__file__).resolve().parent.parent)
if project_root not in sys.path:
    sys.path.insert(0, project_root)

from core.ai_agent import AI_Agent
|
||
|
||
# Sample prompts shared by every demo: the system prompt sets the assistant's
# role, the user prompt requests a short travel guide.
SYSTEM_PROMPT = "你是一个专业的旅游内容创作助手,请根据用户的提示生成相关内容。"
USER_PROMPT = (
    "请为我生成一篇关于福建泰宁古城的旅游攻略,"
    "包括著名景点、美食推荐和最佳游玩季节。"
    "字数控制在300字以内。"
)


def _display_width(text):
    """Return the number of terminal columns *text* occupies.

    East Asian wide/fullwidth characters (e.g. CJK ideographs) count as two
    columns; every other character counts as one. This is an approximation:
    zero-width combining marks are not special-cased.
    """
    return sum(
        2 if unicodedata.east_asian_width(ch) in ("W", "F") else 1
        for ch in text
    )


def print_separator(title, width=50, fill="="):
    """Print a banner: a rule line, the centred title, another rule line.

    ``str.center`` pads by character count, so a title containing Chinese
    text (two terminal columns per character) would stick out past the
    rules. Here the padding is computed from the title's display width.

    Args:
        title: Text shown in the middle line.
        width: Total banner width in terminal columns (default 50).
        fill: Single character used for the rules and padding (default "=").
    """
    rule = fill * width
    label = f" {title} "
    # Clamp at zero: an over-long title is printed unpadded, as str.center does.
    padding = max(width - _display_width(label), 0)
    left = padding // 2
    print("\n" + rule)
    print(fill * left + label + fill * (padding - left))
    print(rule + "\n")


def demo_sync_stream():
    """Demonstrate the synchronous streaming API (``generate_text_stream``).

    Creates an ``AI_Agent`` configured for a vLLM backend, requests the
    sample travel guide, then prints the returned text and the elapsed
    generation time. The agent is closed even if generation raises; the
    exception itself still propagates to the caller.

    Returns:
        The value returned by ``generate_text_stream``.
    """
    print_separator("同步流式响应 (generate_text_stream)")

    agent = AI_Agent(
        base_url="vllm",                 # use the local vLLM service
        model_name="qwen2-7b-instruct",  # or whichever model name you configured
        api="EMPTY",                     # vLLM does not need an API key
        timeout=60,                      # overall request timeout (seconds)
        stream_chunk_timeout=10,         # per-chunk stream timeout (seconds)
    )

    try:
        print("开始生成内容...")
        # perf_counter is monotonic, so the measurement cannot be skewed by
        # wall-clock adjustments during the request.
        start_time = time.perf_counter()
        result = agent.generate_text_stream(
            SYSTEM_PROMPT,
            USER_PROMPT,
            temperature=0.7,
            top_p=0.9,
            presence_penalty=0.0,
        )
        elapsed = time.perf_counter() - start_time

        print(f"\n\n完整生成内容:\n{result}")
        print(f"\n生成完成! 耗时: {elapsed:.2f}秒")
        return result
    finally:
        # Release the agent even when generation fails.
        agent.close()


def demo_callback_stream():
    """Demonstrate callback-driven streaming (``generate_text_stream_with_callback``).

    A local callback is passed to the agent; it echoes each chunk
    immediately and prints notices when the agent signals end of stream,
    timeout or error. The agent is closed even if generation raises; the
    exception itself still propagates to the caller.

    Returns:
        The value returned by ``generate_text_stream_with_callback``.
    """
    print_separator("回调式流式响应 (generate_text_stream_with_callback)")

    agent = AI_Agent(
        base_url="vllm",
        model_name="qwen2-7b-instruct",
        api="EMPTY",
        timeout=60,
        stream_chunk_timeout=10,
    )

    def my_callback(content, is_last=False, is_timeout=False, is_error=False, error=None):
        """Handle one streaming event from the agent.

        Args:
            content: Text chunk to echo; nothing is printed when it is falsy.
            is_last: When true, a blank line is printed after the content.
            is_timeout: When true, a stream-timeout warning is printed.
            is_error: When true, ``error`` is printed as an error notice.
            error: Error value reported together with ``is_error``.
        """
        if content:
            # Echo the chunk in real time without a trailing newline.
            print(content, end="", flush=True)
        if is_last:
            print("\n")
        if is_timeout:
            print("警告: 响应流超时")
        if is_error:
            print(f"错误: {error}")

    try:
        print("开始生成内容...")
        # Monotonic clock for measuring the duration.
        start_time = time.perf_counter()
        result = agent.generate_text_stream_with_callback(
            SYSTEM_PROMPT,
            USER_PROMPT,
            my_callback,
            temperature=0.7,
            top_p=0.9,
            presence_penalty=0.0,
        )
        elapsed = time.perf_counter() - start_time

        print(f"\n生成完成! 耗时: {elapsed:.2f}秒")
        return result
    finally:
        # Release the agent even when generation fails.
        agent.close()


async def demo_async_stream():
    """Demonstrate asynchronous streaming (``async_generate_text_stream``).

    Iterates the agent's async stream with ``async for``, echoing each chunk
    as it arrives and collecting it. An ``Exception`` raised while creating
    or consuming the stream is reported rather than propagated. The agent is
    closed in every case, including task cancellation, which
    ``except Exception`` does not catch.

    Returns:
        The concatenation of all chunks received. It is partial if the
        stream failed midway and empty if nothing arrived.
    """
    print_separator("异步流式响应 (async_generate_text_stream)")

    agent = AI_Agent(
        base_url="vllm",
        model_name="qwen2-7b-instruct",
        api="EMPTY",
        timeout=60,
        stream_chunk_timeout=10,
    )

    # Gather chunks in a list and join once at the end; repeated `str +=` is
    # not guaranteed to be linear. Assumes the stream yields str chunks.
    chunks = []
    try:
        print("开始生成内容...")
        start_time = time.perf_counter()
        try:
            async_stream = agent.async_generate_text_stream(
                SYSTEM_PROMPT,
                USER_PROMPT,
                temperature=0.7,
                top_p=0.9,
                presence_penalty=0.0,
            )
            async for content in async_stream:
                chunks.append(content)
                # Echo the chunk in real time.
                print(content, end="", flush=True)
        except Exception as e:
            print(f"\n生成过程中出错: {e}")

        elapsed = time.perf_counter() - start_time
        print(f"\n\n生成完成! 耗时: {elapsed:.2f}秒")
    finally:
        agent.close()

    return "".join(chunks)


async def main():
    """Run the three streaming demos in order and report completion.

    The two blocking demos (plain synchronous stream, then callback-based
    stream) are called directly, on the event loop's thread. They are
    followed by the awaited asynchronous demo.
    """
    print("Testing AI_Agent streaming methods...")

    # Blocking demos: synchronous stream first, then the callback variant.
    for run_demo in (demo_sync_stream, demo_callback_stream):
        run_demo()

    # The asynchronous demo is a coroutine and must be awaited.
    await demo_async_stream()

    print("\n所有测试完成!")


if __name__ == "__main__":
    # Script entry point: drive main() on a fresh event loop.
    asyncio.run(main())