#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Async content-generation API example.

Demonstrates how to build an asynchronous content-generation API on top of
AI_Agent.async_generate_text_stream, with real-time streaming output suitable
for integration into web services and other async applications.
"""

import os
import sys
import json
import time
import asyncio
from pathlib import Path

from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
import uvicorn

# Make the project root importable so `core.ai_agent` resolves when this
# script is run directly from its examples directory.
project_root = str(Path(__file__).parent.parent)
if project_root not in sys.path:
    sys.path.insert(0, project_root)

from core.ai_agent import AI_Agent
# Create the FastAPI application instance.
app = FastAPI(title="Travel Content Generator API",
              description="提供异步内容生成的API服务")

# Enable CORS so browser-based clients can reach the API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # allows every origin; restrict this in production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global back-end configuration shared by all endpoints.
CONFIG = {
    "base_url": "vllm",                 # use the local vLLM service
    "model_name": "qwen2-7b-instruct",  # or any other configured model name
    "api_key": "EMPTY",                 # vLLM does not require an API key
    "timeout": 60,                      # overall request timeout (seconds)
    "stream_chunk_timeout": 10,         # per-chunk streaming timeout (seconds)
    "max_retries": 3,                   # maximum retry attempts
}
# --- HTTP API routes ---

@app.get("/")
async def root():
    """Root endpoint: return a small welcome/info payload."""
    return {
        "message": "Travel Content Generator API",
        "version": "1.0.0",
        "docs_url": "/docs",
    }
@app.post("/generate/text")
async def generate_text(request: Request):
    """
    Content-generation endpoint with a streaming (SSE) response.

    Example request body:
    {
        "system_prompt": "你是一个专业的旅游内容创作助手",
        "user_prompt": "请为我生成一篇关于福建泰宁古城的旅游攻略",
        "temperature": 0.7,
        "top_p": 0.9,
        "presence_penalty": 0.0
    }
    """
    # Parse the JSON request body and pull out the generation parameters.
    payload = await request.json()
    system_prompt = payload.get("system_prompt", "你是一个专业的旅游内容创作助手")
    user_prompt = payload.get("user_prompt", "")
    gen_kwargs = {
        "temperature": payload.get("temperature", 0.7),
        "top_p": payload.get("top_p", 0.9),
        "presence_penalty": payload.get("presence_penalty", 0.0),
    }

    async def response_generator():
        # One AI_Agent per request; released in the finally block below.
        agent = AI_Agent(
            base_url=CONFIG["base_url"],
            model_name=CONFIG["model_name"],
            api=CONFIG["api_key"],
            timeout=CONFIG["timeout"],
            max_retries=CONFIG["max_retries"],
            stream_chunk_timeout=CONFIG["stream_chunk_timeout"]
        )

        try:
            # Relay each generated chunk to the client as an SSE event.
            async for chunk in agent.async_generate_text_stream(
                system_prompt,
                user_prompt,
                **gen_kwargs
            ):
                yield f"data: {json.dumps({'chunk': chunk})}\n\n"

            # End-of-stream marker.
            yield f"data: {json.dumps({'done': True})}\n\n"

        except Exception as e:
            # Surface the failure to the client as an SSE error event.
            yield f"data: {json.dumps({'error': str(e)})}\n\n"
        finally:
            # Always release the agent's resources.
            agent.close()

    # Hand the generator to Starlette as a server-sent-events stream.
    return StreamingResponse(
        response_generator(),
        media_type="text/event-stream"
    )
# --- WebSocket routes ---

@app.websocket("/ws/generate")
async def websocket_generate(websocket: WebSocket):
    """
    Generate content over a WebSocket, enabling real-time two-way messaging.
    """
    await websocket.accept()

    try:
        # Read the generation parameters sent by the client.
        params = await websocket.receive_json()
        system_prompt = params.get("system_prompt", "你是一个专业的旅游内容创作助手")
        user_prompt = params.get("user_prompt", "")
        temperature = params.get("temperature", 0.7)
        top_p = params.get("top_p", 0.9)
        presence_penalty = params.get("presence_penalty", 0.0)

        # One agent instance per connection.
        agent = AI_Agent(
            base_url=CONFIG["base_url"],
            model_name=CONFIG["model_name"],
            api=CONFIG["api_key"],
            timeout=CONFIG["timeout"],
            max_retries=CONFIG["max_retries"],
            stream_chunk_timeout=CONFIG["stream_chunk_timeout"]
        )

        try:
            # Tell the client that generation has started.
            await websocket.send_json({"status": "generating"})

            # Stream chunks to the client while accumulating the full text.
            full_response = ""
            async for chunk in agent.async_generate_text_stream(
                system_prompt,
                user_prompt,
                temperature=temperature,
                top_p=top_p,
                presence_penalty=presence_penalty
            ):
                full_response += chunk
                await websocket.send_json({
                    "type": "chunk",
                    "content": chunk
                })

            # Placeholder for handling client control commands during
            # generation (e.g. pause / stop).

            # Generation finished: send the complete text in one message.
            await websocket.send_json({
                "type": "complete",
                "full_content": full_response
            })

        except Exception as e:
            # Report the failure to the client.
            await websocket.send_json({
                "type": "error",
                "message": str(e)
            })
        finally:
            # Always release the agent's resources.
            agent.close()

    except WebSocketDisconnect:
        print("WebSocket客户端断开连接")
    except Exception as e:
        print(f"WebSocket错误: {e}")
# --- Command-line interface for testing ---

async def test_async_generation():
    """Exercise the async content-generation flow from the command line."""
    print("\n===== 测试异步内容生成 =====\n")

    # Build the agent from the shared configuration.
    agent = AI_Agent(
        base_url=CONFIG["base_url"],
        model_name=CONFIG["model_name"],
        api=CONFIG["api_key"],
        timeout=CONFIG["timeout"],
        max_retries=CONFIG["max_retries"],
        stream_chunk_timeout=CONFIG["stream_chunk_timeout"]
    )

    # Example prompts.
    system_prompt = "你是一个专业的旅游内容创作助手,请根据用户的提示生成相关内容。"
    user_prompt = "请为我生成一篇关于福建泰宁古城的旅游攻略,包括著名景点、美食推荐和最佳游玩季节。字数控制在300字以内。"

    print("开始生成内容...")
    started_at = time.time()
    full_response = ""

    try:
        # Stream chunks, echoing each one to the console as it arrives.
        async for piece in agent.async_generate_text_stream(
            system_prompt,
            user_prompt,
            temperature=0.7,
            top_p=0.9,
            presence_penalty=0.0
        ):
            full_response += piece
            print(piece, end="", flush=True)

    except Exception as e:
        print(f"\n生成过程中出错: {e}")
    finally:
        # Always release the agent's resources.
        agent.close()

    finished_at = time.time()
    print(f"\n\n生成完成! 耗时: {finished_at - started_at:.2f}秒")
def start_api_server():
    """Launch the FastAPI application with uvicorn (blocks until shutdown)."""
    uvicorn.run(
        "async_content_api:app",
        host="0.0.0.0",
        port=8800,
        reload=True
    )
async def main():
    """CLI entry point: dispatch on the first command-line argument."""
    if len(sys.argv) <= 1:
        # No command given: default to the async-generation test.
        await test_async_generation()
        return

    command = sys.argv[1]
    if command == "server":
        # NOTE(review): start_api_server() runs uvicorn, which creates its own
        # event loop; reaching this from inside asyncio.run() likely raises
        # "asyncio.run() cannot be called from a running event loop" — confirm.
        print("启动API服务器...")
        start_api_server()
    elif command == "test":
        await test_async_generation()
    else:
        print(f"未知命令: {command}")
        print("可用命令: server, test")
if __name__ == "__main__":
    # Dependency check. NOTE: fastapi/uvicorn are already imported at module
    # top, so a missing install would normally fail there first; this guard
    # only helps if those imports are later made lazy.
    try:
        import fastapi
        import uvicorn
    except ImportError:
        print("请先安装依赖: pip install fastapi uvicorn")
        sys.exit(1)

    # Bug fix: uvicorn.run() creates and runs its own event loop, so the
    # "server" command must be dispatched BEFORE entering asyncio.run(main());
    # otherwise uvicorn is started from inside an already-running loop and
    # crashes with "asyncio.run() cannot be called from a running event loop".
    if len(sys.argv) > 1 and sys.argv[1] == "server":
        print("启动API服务器...")
        start_api_server()
    else:
        # Everything else (test / default / unknown command) is async.
        asyncio.run(main())