修改了标题refer文档
This commit is contained in:
parent
15176e2eaf
commit
3037984fbe
451
core/async_content_generator.py
Normal file
451
core/async_content_generator.py
Normal file
@ -0,0 +1,451 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
异步内容生成器
|
||||
|
||||
基于AI_Agent的async_generate_text_stream方法实现的异步内容生成器类,
|
||||
支持异步API调度和流式输出,适合集成到高并发应用中。
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any, Optional, AsyncGenerator, Callable, Union
|
||||
|
||||
# Ensure the core package is importable when this file is run directly
# (adds the repository root, i.e. the parent of this file's directory, to sys.path).
current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(current_dir)
if parent_dir not in sys.path:
    sys.path.insert(0, parent_dir)
|
||||
|
||||
from core.ai_agent import AI_Agent
|
||||
|
||||
class AsyncContentGenerator:
    """Asynchronous content generator with streaming output.

    Wraps ``AI_Agent.async_generate_text_stream`` to provide coroutine-based
    content/poster generation suitable for high-concurrency applications.
    """

    def __init__(self,
                 model_name: str = "qwen2-7b-instruct",
                 api_base_url: str = "vllm",
                 api_key: str = "EMPTY",
                 output_dir: str = "/root/autodl-tmp/poster_generate_result",
                 timeout: int = 60,
                 max_retries: int = 3,
                 stream_chunk_timeout: int = 10):
        """Initialize the async content generator.

        Args:
            model_name: Name of the model to use.
            api_base_url: API base URL or preset name ('deepseek', 'vllm').
            api_key: API key.
            output_dir: Directory where generated results are saved.
            timeout: API request timeout in seconds.
            max_retries: Maximum number of retries.
            stream_chunk_timeout: Per-chunk streaming timeout in seconds.
        """
        self.output_dir = output_dir
        self.model_name = model_name
        self.api_base_url = api_base_url
        self.api_key = api_key
        self.timeout = timeout
        self.max_retries = max_retries
        self.stream_chunk_timeout = stream_chunk_timeout

        # Sampling parameters (tunable via the set_* methods below).
        self.temperature = 0.7
        self.top_p = 0.9
        self.presence_penalty = 0.0

        # Extra contextual description prepended to prompts
        # (populated by load_information / load_information_sync).
        self.add_description = ""

        # Logging setup.
        logging.basicConfig(level=logging.INFO,
                            format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        self.logger = logging.getLogger(__name__)

        self.logger.info(f"已初始化异步内容生成器: model={model_name}, api_url={api_base_url}")

    async def create_agent(self) -> "AI_Agent":
        """Create an ``AI_Agent`` instance configured from this generator.

        Returns:
            AI_Agent: A configured AI agent instance.

        Note: the return annotation is a string (forward reference) so the
        class can be defined even if ``AI_Agent`` failed to import.
        """
        agent = AI_Agent(
            base_url=self.api_base_url,
            model_name=self.model_name,
            api=self.api_key,
            timeout=self.timeout,
            max_retries=self.max_retries,
            stream_chunk_timeout=self.stream_chunk_timeout
        )
        return agent

    @staticmethod
    def _read_text_file(path: str) -> str:
        """Synchronously read a UTF-8 text file (executor helper)."""
        with open(path, 'r', encoding='utf-8') as f:
            return f.read()

    async def load_information(self, info_paths: List[str]) -> None:
        """Asynchronously load extra description files into ``add_description``.

        BUG FIX: the original called ``asyncio.open_file``, which does not
        exist (asyncio has no file API), so every load raised AttributeError
        that the broad except silently downgraded to a warning and nothing was
        ever loaded. Reads are now delegated to a thread executor so the event
        loop is not blocked.

        Args:
            info_paths: List of information file paths.
        """
        self.add_description = ""
        loop = asyncio.get_running_loop()

        for path in info_paths:
            try:
                # Read in a worker thread to keep the event loop responsive.
                content = await loop.run_in_executor(None, self._read_text_file, path)
                self.add_description += content
                self.logger.info(f"已加载信息文件: {path}")
            except Exception as e:
                # Best-effort: a missing/unreadable file is logged, not fatal.
                self.logger.warning(f"无法加载信息文件 {path}: {e}")

    async def load_information_sync(self, info_paths: List[str]) -> None:
        """Load extra description files with blocking I/O.

        Kept async (as in the original interface) for drop-in compatibility
        with environments where executor-based file reads are undesirable.

        Args:
            info_paths: List of information file paths.
        """
        self.add_description = ""

        for path in info_paths:
            try:
                with open(path, 'r', encoding='utf-8') as f:
                    self.add_description += f.read()
                self.logger.info(f"已加载信息文件: {path}")
            except Exception as e:
                self.logger.warning(f"无法加载信息文件 {path}: {e}")

    def prepare_prompt(self, tweet_content: str) -> tuple:
        """Build the system/user prompt pair for content generation.

        Args:
            tweet_content: The tweet text to process.

        Returns:
            tuple: (system_prompt, user_prompt)
        """
        # Default system prompt.
        system_prompt = """
        你是一个专业的文案处理专家,擅长从文章中提取关键信息并生成吸引人的标题和简短描述。
        请根据提供的文章内容,提取关键信息并整理成易于理解的格式。
        """

        # User prompt — include the loaded description when available.
        if self.add_description:
            user_prompt = f"""
        以下是需要你处理的信息:

        关于景点的描述:
        {self.add_description}

        推文内容:
        {tweet_content}

        请根据这些信息,生成简洁明了的旅游内容,突出景点特色和旅行建议。
        """
        else:
            user_prompt = f"""
        以下是需要你处理的推文内容:
        {tweet_content}

        请根据这些信息,生成简洁明了的旅游内容,突出景点特色和旅行建议。
        """

        return system_prompt, user_prompt

    async def generate_content_stream(self,
                                      tweet_content: str,
                                      system_prompt: Optional[str] = None) -> AsyncGenerator[str, None]:
        """Asynchronously stream generated content for a tweet.

        Args:
            tweet_content: The tweet text to process.
            system_prompt: Optional custom system prompt.

        Yields:
            str: Generated text chunks (an inline error marker on failure).
        """
        self.logger.info("开始异步生成内容")

        # Fall back to the default system prompt when none is provided.
        if system_prompt is None:
            system_prompt, user_prompt = self.prepare_prompt(tweet_content)
        else:
            _, user_prompt = self.prepare_prompt(tweet_content)

        agent = await self.create_agent()

        try:
            async for chunk in agent.async_generate_text_stream(
                system_prompt,
                user_prompt,
                temperature=self.temperature,
                top_p=self.top_p,
                presence_penalty=self.presence_penalty
            ):
                yield chunk
        except Exception as e:
            self.logger.error(f"内容生成出错: {e}")
            yield f"\n[生成内容出错: {str(e)}]"
        finally:
            # Always release the agent's resources.
            agent.close()

    async def generate_poster_stream(self,
                                     tweet_content: str,
                                     poster_num: int = 3,
                                     system_prompt: Optional[str] = None) -> AsyncGenerator[str, None]:
        """Asynchronously stream poster-copy configurations as JSON text.

        Args:
            tweet_content: The tweet text to process.
            poster_num: Number of poster configurations to generate.
            system_prompt: Optional custom system prompt.

        Yields:
            str: Generated text chunks (an inline error marker on failure).
        """
        self.logger.info(f"开始异步生成{poster_num}个海报配置")

        # Default poster-generation system prompt (note the {{ }} escapes for
        # literal braces inside this f-string).
        if system_prompt is None:
            system_prompt = f"""
        你是一个专业的文案处理专家,擅长从文章中提取关键信息并生成吸引人的标题和简短描述。
        现在,我需要你根据提供的文章内容,生成{poster_num}个海报的文案配置。

        每个配置包含:
        1. main_title:主标题,简短有力,突出景点特点
        2. texts:两句简短文本,每句不超过15字,描述景点特色或游玩体验

        以JSON数组格式返回配置,示例:
        [
        {{
            "main_title": "泰宁古城",
            "texts": ["千年古韵","匠心独运"]
        }},
        ...
        ]

        仅返回JSON数据,不需要任何额外解释。确保生成的标题和文本能够准确反映文章提到的景点特色。
        """

        # User prompt — include the loaded description when available.
        if self.add_description:
            user_prompt = f"""
        以下是需要你处理的信息:

        关于景点的描述:
        {self.add_description}

        推文内容:
        {tweet_content}

        请根据这些信息,生成{poster_num}个海报文案配置,以JSON数组格式返回。
        确保主标题(main_title)简短有力,每个text不超过15字,并能准确反映景点特色。
        """
        else:
            user_prompt = f"""
        以下是需要你处理的推文内容:
        {tweet_content}

        请根据这些信息,生成{poster_num}个海报文案配置,以JSON数组格式返回。
        确保主标题(main_title)简短有力,每个text不超过15字,并能准确反映景点特色。
        """

        agent = await self.create_agent()

        try:
            async for chunk in agent.async_generate_text_stream(
                system_prompt,
                user_prompt,
                temperature=self.temperature,
                top_p=self.top_p,
                presence_penalty=self.presence_penalty
            ):
                yield chunk
        except Exception as e:
            self.logger.error(f"海报配置生成出错: {e}")
            yield f"\n[生成内容出错: {str(e)}]"
        finally:
            # Always release the agent's resources.
            agent.close()

    async def parse_json_stream(self, content_stream: AsyncGenerator[str, None]) -> AsyncGenerator[Dict[str, Any], None]:
        """Parse a text stream, yielding raw chunks plus any JSON found.

        Args:
            content_stream: Async generator yielding text chunks.

        Yields:
            Dict: one of
                {"type": "chunk", "content": str}       every raw text chunk
                {"type": "partial_json", "data": list}  JSON array parsed mid-stream
                {"type": "final_json", "data": Any}     JSON parsed from the full response
                {"type": "raw_text", "data": str}       fallback when no JSON pattern found
                {"type": "error", "message": str}       final parse failed
        """
        full_response = ""

        # Collect the stream, opportunistically parsing the accumulated text.
        async for chunk in content_stream:
            full_response += chunk

            try:
                # Look for a bracketed JSON-array candidate.
                if "[" in full_response and "]" in full_response:
                    start_idx = full_response.find("[")
                    end_idx = full_response.rfind("]") + 1
                    json_str = full_response[start_idx:end_idx]

                    json_data = json.loads(json_str)
                    if isinstance(json_data, list) and json_data:
                        yield {"type": "partial_json", "data": json_data}
            except json.JSONDecodeError:
                pass  # not complete yet — keep accumulating
            except Exception as e:
                self.logger.warning(f"JSON解析尝试出错: {e}")

            # Always forward the raw chunk.
            yield {"type": "chunk", "content": chunk}

        # Final parse: try the whole response verbatim first.
        try:
            result = json.loads(full_response)
            yield {"type": "final_json", "data": result}
            return
        except json.JSONDecodeError:
            pass

        # Then try common wrapping patterns.
        try:
            # Markdown-fenced JSON block.
            if "```json" in full_response:
                json_str = full_response.split("```json")[1].split("```")[0].strip()
                result = json.loads(json_str)
                yield {"type": "final_json", "data": result}
                return

            # Bracket-delimited array embedded in prose.
            if "[" in full_response and "]" in full_response:
                start_idx = full_response.find("[")
                end_idx = full_response.rfind("]") + 1
                json_str = full_response[start_idx:end_idx]

                result = json.loads(json_str)
                yield {"type": "final_json", "data": result}
                return

            # BUG FIX: the original yielded nothing at all when no JSON
            # pattern was present; surface the raw text so callers always
            # receive a terminal item.
            yield {"type": "raw_text", "data": full_response}
        except Exception as e:
            self.logger.error(f"最终JSON解析失败: {e}")
            yield {"type": "error", "message": f"JSON解析失败: {str(e)}"}

    async def save_result(self, content: Union[str, Dict, List]) -> str:
        """Save a generation result as a timestamped JSON file.

        Args:
            content: String or JSON-serializable object to persist.

        Returns:
            str: Path of the saved file, or "" on failure.
        """
        # Timestamped file name under the configured output directory.
        date_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        result_path = os.path.join(self.output_dir, f"{date_time}.json")

        # BUG FIX: makedirs(os.path.dirname(result_path)) raised for an empty
        # output_dir; create the configured directory directly, guarded.
        if self.output_dir:
            os.makedirs(self.output_dir, exist_ok=True)

        try:
            if isinstance(content, str):
                try:
                    # Strings that are valid JSON are stored as parsed JSON.
                    json_content = json.loads(content)
                except json.JSONDecodeError:
                    # Otherwise wrap the raw text in a JSON envelope.
                    json_content = {"content": content, "type": "raw_text"}
            else:
                # Already a dict or list.
                json_content = content

            # Synchronous write — async file I/O is not universally available.
            with open(result_path, "w", encoding="utf-8") as f:
                json.dump(json_content, f, ensure_ascii=False, indent=2)

            self.logger.info(f"结果已保存到: {result_path}")
            return result_path
        except Exception as e:
            self.logger.error(f"保存结果失败: {e}")
            return ""

    # Sampling-parameter setters
    def set_temperature(self, temperature: float) -> None:
        """Set the sampling temperature."""
        self.temperature = temperature

    def set_top_p(self, top_p: float) -> None:
        """Set the nucleus-sampling top_p."""
        self.top_p = top_p

    def set_presence_penalty(self, presence_penalty: float) -> None:
        """Set the presence penalty."""
        self.presence_penalty = presence_penalty

    def set_model_params(self, temperature: float, top_p: float, presence_penalty: float) -> None:
        """Set all sampling parameters at once."""
        self.temperature = temperature
        self.top_p = top_p
        self.presence_penalty = presence_penalty
|
||||
|
||||
|
||||
# Demo code
async def demo():
    """Showcase asynchronous poster generation and result saving."""
    # Point the generator at a local OpenAI-compatible endpoint.
    generator = AsyncContentGenerator(
        model_name="qwenQWQ",
        api_base_url="http://localhost:8000/v1",
        api_key="EMPTY"
    )

    # Sample tweet used as the generation source.
    tweet_content = """
    清明假期带娃哪里玩?泰宁甘露寺藏着明代建筑奇迹!一柱擎天的悬空阁楼+状元祈福传说,让孩子边玩边涨知识✨

    🎒行程亮点:
    ✅ 安全科普第一站:讲解"一柱插地"千年不倒的秘密,用乐高积木模型让孩子理解力学原理
    ✅ 文化沉浸体验:穿汉服听"叶状元还愿建寺"故事,触摸3.38米粗的"状元柱"许愿
    ✅ 自然探索路线:连接金湖栈道徒步,观察丹霞地貌与古建筑的巧妙融合
    """

    print("\n===== 测试异步海报生成 =====\n")

    # Stream poster configurations, echoing chunks as they arrive.
    pieces = []
    async for piece in generator.generate_poster_stream(tweet_content, poster_num=3):
        print(piece, end="", flush=True)
        pieces.append(piece)
    full_response = "".join(pieces)

    # Persist the accumulated response.
    result_path = await generator.save_result(full_response)
    print(f"\n\n结果已保存到: {result_path}")


if __name__ == "__main__":
    # Run the demo when executed as a script.
    asyncio.run(demo())
|
||||
288
examples/async_content_api.py
Normal file
288
examples/async_content_api.py
Normal file
@ -0,0 +1,288 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
异步内容生成API示例
|
||||
|
||||
演示如何使用AI_Agent的async_generate_text_stream方法实现异步内容生成API,
|
||||
支持实时流式输出,适合集成到Web服务或其他异步应用中。
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import time
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
import uvicorn
|
||||
|
||||
# Add the project root (parent of examples/) to the Python path so that
# `core.ai_agent` can be imported when this example is run directly.
project_root = str(Path(__file__).parent.parent)
if project_root not in sys.path:
    sys.path.insert(0, project_root)
|
||||
|
||||
from core.ai_agent import AI_Agent
|
||||
|
||||
# Create the FastAPI application instance.
app = FastAPI(title="Travel Content Generator API",
              description="提供异步内容生成的API服务")

# Add CORS middleware so browser clients on other origins can call the API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # allows all origins; restrict this in production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global configuration shared by all endpoints.
CONFIG = {
    "base_url": "vllm",  # use the local vLLM service
    "model_name": "qwen2-7b-instruct",  # or another configured model name
    "api_key": "EMPTY",  # vLLM does not require an API key
    "timeout": 60,  # overall request timeout (seconds)
    "stream_chunk_timeout": 10,  # per-chunk streaming timeout (seconds)
    "max_retries": 3  # maximum retry count
}
|
||||
|
||||
# API routes

@app.get("/")
async def root():
    """Root endpoint: return basic service metadata."""
    payload = {
        "message": "Travel Content Generator API",
        "version": "1.0.0",
    }
    payload["docs_url"] = "/docs"
    return payload
|
||||
|
||||
@app.post("/generate/text")
async def generate_text(request: Request):
    """Stream generated content to the client as server-sent events (SSE).

    Expected JSON body:
    {
        "system_prompt": "你是一个专业的旅游内容创作助手",
        "user_prompt": "请为我生成一篇关于福建泰宁古城的旅游攻略",
        "temperature": 0.7,
        "top_p": 0.9,
        "presence_penalty": 0.0
    }
    """
    body = await request.json()

    # Pull generation parameters, falling back to sensible defaults.
    system_prompt = body.get("system_prompt", "你是一个专业的旅游内容创作助手")
    user_prompt = body.get("user_prompt", "")
    sampling = {
        "temperature": body.get("temperature", 0.7),
        "top_p": body.get("top_p", 0.9),
        "presence_penalty": body.get("presence_penalty", 0.0),
    }

    async def event_stream():
        # One agent per request; released in the finally block below.
        agent = AI_Agent(
            base_url=CONFIG["base_url"],
            model_name=CONFIG["model_name"],
            api=CONFIG["api_key"],
            timeout=CONFIG["timeout"],
            max_retries=CONFIG["max_retries"],
            stream_chunk_timeout=CONFIG["stream_chunk_timeout"]
        )

        try:
            # Forward each generated chunk as an SSE event.
            async for chunk in agent.async_generate_text_stream(
                system_prompt,
                user_prompt,
                **sampling
            ):
                yield f"data: {json.dumps({'chunk': chunk})}\n\n"

            # End-of-stream marker.
            yield f"data: {json.dumps({'done': True})}\n\n"

        except Exception as e:
            # Report the failure to the client instead of silently closing.
            yield f"data: {json.dumps({'error': str(e)})}\n\n"
        finally:
            # Make sure resources are released.
            agent.close()

    # Hand the generator to Starlette's streaming machinery.
    return StreamingResponse(
        event_stream(),
        media_type="text/event-stream"
    )
|
||||
|
||||
# WebSocket route

@app.websocket("/ws/generate")
async def websocket_generate(websocket: WebSocket):
    """
    Generate content over a WebSocket with real-time bidirectional messaging.

    Protocol: the client sends one JSON message with generation parameters
    (system_prompt, user_prompt, temperature, top_p, presence_penalty);
    the server replies with {"status": "generating"}, then a series of
    {"type": "chunk"} messages, and finally {"type": "complete"} with the
    full text (or {"type": "error"} on failure).
    """
    await websocket.accept()

    try:
        # Receive the generation parameters from the client.
        data = await websocket.receive_json()

        # Extract parameters with defaults.
        system_prompt = data.get("system_prompt", "你是一个专业的旅游内容创作助手")
        user_prompt = data.get("user_prompt", "")
        temperature = data.get("temperature", 0.7)
        top_p = data.get("top_p", 0.9)
        presence_penalty = data.get("presence_penalty", 0.0)

        # Create the AI_Agent instance (one per connection).
        agent = AI_Agent(
            base_url=CONFIG["base_url"],
            model_name=CONFIG["model_name"],
            api=CONFIG["api_key"],
            timeout=CONFIG["timeout"],
            max_retries=CONFIG["max_retries"],
            stream_chunk_timeout=CONFIG["stream_chunk_timeout"]
        )

        try:
            # Tell the client generation has started.
            await websocket.send_json({"status": "generating"})

            # Stream the generated content chunk by chunk.
            full_response = ""
            async for chunk in agent.async_generate_text_stream(
                system_prompt,
                user_prompt,
                temperature=temperature,
                top_p=top_p,
                presence_penalty=presence_penalty
            ):
                # Accumulate the complete response.
                full_response += chunk

                # Send each text chunk.
                await websocket.send_json({
                    "type": "chunk",
                    "content": chunk
                })

            # Placeholder for client interrupt handling —
            # pause/stop control commands could be processed here.

            # Send the completion message with the full text.
            await websocket.send_json({
                "type": "complete",
                "full_content": full_response
            })

        except Exception as e:
            # Report the failure to the client.
            await websocket.send_json({
                "type": "error",
                "message": str(e)
            })
        finally:
            # Make sure resources are released.
            agent.close()

    except WebSocketDisconnect:
        print("WebSocket客户端断开连接")
    except Exception as e:
        print(f"WebSocket错误: {e}")
|
||||
|
||||
# Command-line interface, for testing

async def test_async_generation():
    """Exercise the async streaming generation path end-to-end."""
    print("\n===== 测试异步内容生成 =====\n")

    # Build an agent from the shared configuration.
    agent = AI_Agent(
        base_url=CONFIG["base_url"],
        model_name=CONFIG["model_name"],
        api=CONFIG["api_key"],
        timeout=CONFIG["timeout"],
        max_retries=CONFIG["max_retries"],
        stream_chunk_timeout=CONFIG["stream_chunk_timeout"]
    )

    # Example prompts.
    system_prompt = "你是一个专业的旅游内容创作助手,请根据用户的提示生成相关内容。"
    user_prompt = "请为我生成一篇关于福建泰宁古城的旅游攻略,包括著名景点、美食推荐和最佳游玩季节。字数控制在300字以内。"

    print("开始生成内容...")
    start_time = time.time()
    pieces = []

    try:
        # Consume the async stream, echoing chunks as they arrive.
        async for chunk in agent.async_generate_text_stream(
            system_prompt,
            user_prompt,
            temperature=0.7,
            top_p=0.9,
            presence_penalty=0.0
        ):
            pieces.append(chunk)
            print(chunk, end="", flush=True)

    except Exception as e:
        print(f"\n生成过程中出错: {e}")
    finally:
        # Make sure resources are released.
        agent.close()

    full_response = "".join(pieces)
    end_time = time.time()
    print(f"\n\n生成完成! 耗时: {end_time - start_time:.2f}秒")
|
||||
|
||||
def start_api_server():
    """Launch the FastAPI app with uvicorn (blocking call)."""
    uvicorn.run("async_content_api:app", host="0.0.0.0", port=8800, reload=True)
|
||||
|
||||
async def main():
    """Command-line entry point: dispatch on the first CLI argument."""
    # No argument: default to the generation smoke test.
    if len(sys.argv) <= 1:
        await test_async_generation()
        return

    command = sys.argv[1]
    if command == "server":
        # Start the API server.
        print("启动API服务器...")
        start_api_server()
    elif command == "test":
        # Run the async-generation test.
        await test_async_generation()
    else:
        print(f"未知命令: {command}")
        print("可用命令: server, test")
|
||||
|
||||
if __name__ == "__main__":
    # Verify optional web dependencies before doing anything else.
    try:
        import fastapi
        import uvicorn
    except ImportError:
        print("请先安装依赖: pip install fastapi uvicorn")
        sys.exit(1)

    # Hand control to the async entry point.
    asyncio.run(main())
|
||||
File diff suppressed because it is too large
Load Diff
Binary file not shown.
Loading…
x
Reference in New Issue
Block a user