修改了标题refer文档
This commit is contained in:
parent
15176e2eaf
commit
3037984fbe
451
core/async_content_generator.py
Normal file
451
core/async_content_generator.py
Normal file
@ -0,0 +1,451 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
"""
|
||||||
|
异步内容生成器
|
||||||
|
|
||||||
|
基于AI_Agent的async_generate_text_stream方法实现的异步内容生成器类,
|
||||||
|
支持异步API调度和流式输出,适合集成到高并发应用中。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import json
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import List, Dict, Any, Optional, AsyncGenerator, Callable, Union
|
||||||
|
|
||||||
|
# 确保可以导入核心模块
|
||||||
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
parent_dir = os.path.dirname(current_dir)
|
||||||
|
if parent_dir not in sys.path:
|
||||||
|
sys.path.insert(0, parent_dir)
|
||||||
|
|
||||||
|
from core.ai_agent import AI_Agent
|
||||||
|
|
||||||
|
class AsyncContentGenerator:
|
||||||
|
"""异步内容生成器类,支持流式输出和异步调度"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
model_name: str = "qwen2-7b-instruct",
|
||||||
|
api_base_url: str = "vllm",
|
||||||
|
api_key: str = "EMPTY",
|
||||||
|
output_dir: str = "/root/autodl-tmp/poster_generate_result",
|
||||||
|
timeout: int = 60,
|
||||||
|
max_retries: int = 3,
|
||||||
|
stream_chunk_timeout: int = 10):
|
||||||
|
"""
|
||||||
|
初始化异步内容生成器
|
||||||
|
|
||||||
|
参数:
|
||||||
|
model_name: 使用的模型名称
|
||||||
|
api_base_url: API基础URL或预设名称 ('deepseek', 'vllm')
|
||||||
|
api_key: API密钥
|
||||||
|
output_dir: 输出结果保存目录
|
||||||
|
timeout: API请求超时时间(秒)
|
||||||
|
max_retries: 最大重试次数
|
||||||
|
stream_chunk_timeout: 流块超时时间(秒)
|
||||||
|
"""
|
||||||
|
self.output_dir = output_dir
|
||||||
|
self.model_name = model_name
|
||||||
|
self.api_base_url = api_base_url
|
||||||
|
self.api_key = api_key
|
||||||
|
self.timeout = timeout
|
||||||
|
self.max_retries = max_retries
|
||||||
|
self.stream_chunk_timeout = stream_chunk_timeout
|
||||||
|
|
||||||
|
# 参数设置
|
||||||
|
self.temperature = 0.7
|
||||||
|
self.top_p = 0.9
|
||||||
|
self.presence_penalty = 0.0
|
||||||
|
|
||||||
|
# 上下文管理
|
||||||
|
self.add_description = ""
|
||||||
|
|
||||||
|
# 设置日志
|
||||||
|
logging.basicConfig(level=logging.INFO,
|
||||||
|
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||||
|
self.logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
self.logger.info(f"已初始化异步内容生成器: model={model_name}, api_url={api_base_url}")
|
||||||
|
|
||||||
|
async def create_agent(self) -> AI_Agent:
|
||||||
|
"""
|
||||||
|
创建AI_Agent实例
|
||||||
|
|
||||||
|
返回:
|
||||||
|
AI_Agent: AI代理实例
|
||||||
|
"""
|
||||||
|
agent = AI_Agent(
|
||||||
|
base_url=self.api_base_url,
|
||||||
|
model_name=self.model_name,
|
||||||
|
api=self.api_key,
|
||||||
|
timeout=self.timeout,
|
||||||
|
max_retries=self.max_retries,
|
||||||
|
stream_chunk_timeout=self.stream_chunk_timeout
|
||||||
|
)
|
||||||
|
return agent
|
||||||
|
|
||||||
|
async def load_information(self, info_paths: List[str]) -> None:
|
||||||
|
"""
|
||||||
|
异步加载额外描述文件
|
||||||
|
|
||||||
|
参数:
|
||||||
|
info_paths: 信息文件路径列表
|
||||||
|
"""
|
||||||
|
self.add_description = ""
|
||||||
|
|
||||||
|
for path in info_paths:
|
||||||
|
try:
|
||||||
|
# 使用异步文件读取
|
||||||
|
async with asyncio.open_file(path, 'r') as f:
|
||||||
|
content = await f.read()
|
||||||
|
self.add_description += content
|
||||||
|
self.logger.info(f"已加载信息文件: {path}")
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.warning(f"无法加载信息文件 {path}: {e}")
|
||||||
|
|
||||||
|
async def load_information_sync(self, info_paths: List[str]) -> None:
|
||||||
|
"""
|
||||||
|
同步加载额外描述文件(兼容不支持异步文件操作的环境)
|
||||||
|
|
||||||
|
参数:
|
||||||
|
info_paths: 信息文件路径列表
|
||||||
|
"""
|
||||||
|
self.add_description = ""
|
||||||
|
|
||||||
|
for path in info_paths:
|
||||||
|
try:
|
||||||
|
with open(path, 'r', encoding='utf-8') as f:
|
||||||
|
self.add_description += f.read()
|
||||||
|
self.logger.info(f"已加载信息文件: {path}")
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.warning(f"无法加载信息文件 {path}: {e}")
|
||||||
|
|
||||||
|
def prepare_prompt(self, tweet_content: str) -> tuple:
|
||||||
|
"""
|
||||||
|
准备生成内容所需的提示词
|
||||||
|
|
||||||
|
参数:
|
||||||
|
tweet_content: 推文内容
|
||||||
|
|
||||||
|
返回:
|
||||||
|
tuple: (system_prompt, user_prompt)
|
||||||
|
"""
|
||||||
|
# 系统提示词
|
||||||
|
system_prompt = """
|
||||||
|
你是一个专业的文案处理专家,擅长从文章中提取关键信息并生成吸引人的标题和简短描述。
|
||||||
|
请根据提供的文章内容,提取关键信息并整理成易于理解的格式。
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 用户提示词
|
||||||
|
if self.add_description:
|
||||||
|
user_prompt = f"""
|
||||||
|
以下是需要你处理的信息:
|
||||||
|
|
||||||
|
关于景点的描述:
|
||||||
|
{self.add_description}
|
||||||
|
|
||||||
|
推文内容:
|
||||||
|
{tweet_content}
|
||||||
|
|
||||||
|
请根据这些信息,生成简洁明了的旅游内容,突出景点特色和旅行建议。
|
||||||
|
"""
|
||||||
|
else:
|
||||||
|
user_prompt = f"""
|
||||||
|
以下是需要你处理的推文内容:
|
||||||
|
{tweet_content}
|
||||||
|
|
||||||
|
请根据这些信息,生成简洁明了的旅游内容,突出景点特色和旅行建议。
|
||||||
|
"""
|
||||||
|
|
||||||
|
return system_prompt, user_prompt
|
||||||
|
|
||||||
|
async def generate_content_stream(self,
|
||||||
|
tweet_content: str,
|
||||||
|
system_prompt: Optional[str] = None) -> AsyncGenerator[str, None]:
|
||||||
|
"""
|
||||||
|
异步生成内容流
|
||||||
|
|
||||||
|
参数:
|
||||||
|
tweet_content: 推文内容
|
||||||
|
system_prompt: 自定义系统提示词(可选)
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
str: 生成的文本块
|
||||||
|
"""
|
||||||
|
self.logger.info("开始异步生成内容")
|
||||||
|
|
||||||
|
# 如果没有提供系统提示词,使用默认提示词
|
||||||
|
if system_prompt is None:
|
||||||
|
system_prompt, user_prompt = self.prepare_prompt(tweet_content)
|
||||||
|
else:
|
||||||
|
_, user_prompt = self.prepare_prompt(tweet_content)
|
||||||
|
|
||||||
|
# 创建AI_Agent实例
|
||||||
|
agent = await self.create_agent()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 使用异步流式方法生成内容
|
||||||
|
async for chunk in agent.async_generate_text_stream(
|
||||||
|
system_prompt,
|
||||||
|
user_prompt,
|
||||||
|
temperature=self.temperature,
|
||||||
|
top_p=self.top_p,
|
||||||
|
presence_penalty=self.presence_penalty
|
||||||
|
):
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"内容生成出错: {e}")
|
||||||
|
yield f"\n[生成内容出错: {str(e)}]"
|
||||||
|
finally:
|
||||||
|
# 确保资源被释放
|
||||||
|
agent.close()
|
||||||
|
|
||||||
|
async def generate_poster_stream(self,
|
||||||
|
tweet_content: str,
|
||||||
|
poster_num: int = 3,
|
||||||
|
system_prompt: Optional[str] = None) -> AsyncGenerator[str, None]:
|
||||||
|
"""
|
||||||
|
异步生成海报配置流
|
||||||
|
|
||||||
|
参数:
|
||||||
|
tweet_content: 推文内容
|
||||||
|
poster_num: 需要生成的海报数量
|
||||||
|
system_prompt: 自定义系统提示词(可选)
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
str: 生成的文本块
|
||||||
|
"""
|
||||||
|
self.logger.info(f"开始异步生成{poster_num}个海报配置")
|
||||||
|
|
||||||
|
# 如果没有提供系统提示词,创建默认系统提示词
|
||||||
|
if system_prompt is None:
|
||||||
|
system_prompt = f"""
|
||||||
|
你是一个专业的文案处理专家,擅长从文章中提取关键信息并生成吸引人的标题和简短描述。
|
||||||
|
现在,我需要你根据提供的文章内容,生成{poster_num}个海报的文案配置。
|
||||||
|
|
||||||
|
每个配置包含:
|
||||||
|
1. main_title:主标题,简短有力,突出景点特点
|
||||||
|
2. texts:两句简短文本,每句不超过15字,描述景点特色或游玩体验
|
||||||
|
|
||||||
|
以JSON数组格式返回配置,示例:
|
||||||
|
[
|
||||||
|
{{
|
||||||
|
"main_title": "泰宁古城",
|
||||||
|
"texts": ["千年古韵","匠心独运"]
|
||||||
|
}},
|
||||||
|
...
|
||||||
|
]
|
||||||
|
|
||||||
|
仅返回JSON数据,不需要任何额外解释。确保生成的标题和文本能够准确反映文章提到的景点特色。
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 准备用户提示词
|
||||||
|
if self.add_description:
|
||||||
|
user_prompt = f"""
|
||||||
|
以下是需要你处理的信息:
|
||||||
|
|
||||||
|
关于景点的描述:
|
||||||
|
{self.add_description}
|
||||||
|
|
||||||
|
推文内容:
|
||||||
|
{tweet_content}
|
||||||
|
|
||||||
|
请根据这些信息,生成{poster_num}个海报文案配置,以JSON数组格式返回。
|
||||||
|
确保主标题(main_title)简短有力,每个text不超过15字,并能准确反映景点特色。
|
||||||
|
"""
|
||||||
|
else:
|
||||||
|
user_prompt = f"""
|
||||||
|
以下是需要你处理的推文内容:
|
||||||
|
{tweet_content}
|
||||||
|
|
||||||
|
请根据这些信息,生成{poster_num}个海报文案配置,以JSON数组格式返回。
|
||||||
|
确保主标题(main_title)简短有力,每个text不超过15字,并能准确反映景点特色。
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 创建AI_Agent实例
|
||||||
|
agent = await self.create_agent()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 使用异步流式方法生成内容
|
||||||
|
async for chunk in agent.async_generate_text_stream(
|
||||||
|
system_prompt,
|
||||||
|
user_prompt,
|
||||||
|
temperature=self.temperature,
|
||||||
|
top_p=self.top_p,
|
||||||
|
presence_penalty=self.presence_penalty
|
||||||
|
):
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"海报配置生成出错: {e}")
|
||||||
|
yield f"\n[生成内容出错: {str(e)}]"
|
||||||
|
finally:
|
||||||
|
# 确保资源被释放
|
||||||
|
agent.close()
|
||||||
|
|
||||||
|
async def parse_json_stream(self, content_stream: AsyncGenerator[str, None]) -> AsyncGenerator[Dict[str, Any], None]:
|
||||||
|
"""
|
||||||
|
解析JSON流,尝试从流中提取JSON对象
|
||||||
|
|
||||||
|
参数:
|
||||||
|
content_stream: 内容流生成器
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
Dict: 解析出的JSON对象
|
||||||
|
"""
|
||||||
|
full_response = ""
|
||||||
|
|
||||||
|
# 收集完整响应
|
||||||
|
async for chunk in content_stream:
|
||||||
|
full_response += chunk
|
||||||
|
|
||||||
|
# 尝试解析当前累积的内容
|
||||||
|
try:
|
||||||
|
# 查找可能的JSON部分
|
||||||
|
if "[" in full_response and "]" in full_response:
|
||||||
|
start_idx = full_response.find("[")
|
||||||
|
end_idx = full_response.rfind("]") + 1
|
||||||
|
json_str = full_response[start_idx:end_idx]
|
||||||
|
|
||||||
|
# 尝试解析JSON
|
||||||
|
json_data = json.loads(json_str)
|
||||||
|
if isinstance(json_data, list) and json_data:
|
||||||
|
yield {"type": "partial_json", "data": json_data}
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass # 继续收集更多内容
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.warning(f"JSON解析尝试出错: {e}")
|
||||||
|
|
||||||
|
# 发送原始块
|
||||||
|
yield {"type": "chunk", "content": chunk}
|
||||||
|
|
||||||
|
# 最终解析
|
||||||
|
try:
|
||||||
|
# 尝试直接解析
|
||||||
|
result = json.loads(full_response)
|
||||||
|
yield {"type": "final_json", "data": result}
|
||||||
|
return
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# 尝试从文本中提取JSON
|
||||||
|
try:
|
||||||
|
# 查找常见的格式模式
|
||||||
|
if "```json" in full_response:
|
||||||
|
json_str = full_response.split("```json")[1].split("```")[0].strip()
|
||||||
|
result = json.loads(json_str)
|
||||||
|
yield {"type": "final_json", "data": result}
|
||||||
|
return
|
||||||
|
|
||||||
|
# 查找方括号包围的内容
|
||||||
|
if "[" in full_response and "]" in full_response:
|
||||||
|
start_idx = full_response.find("[")
|
||||||
|
end_idx = full_response.rfind("]") + 1
|
||||||
|
json_str = full_response[start_idx:end_idx]
|
||||||
|
|
||||||
|
result = json.loads(json_str)
|
||||||
|
yield {"type": "final_json", "data": result}
|
||||||
|
return
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"最终JSON解析失败: {e}")
|
||||||
|
yield {"type": "error", "message": f"JSON解析失败: {str(e)}"}
|
||||||
|
|
||||||
|
async def save_result(self, content: Union[str, Dict, List]) -> str:
|
||||||
|
"""
|
||||||
|
异步保存生成结果
|
||||||
|
|
||||||
|
参数:
|
||||||
|
content: 要保存的内容(字符串或JSON对象)
|
||||||
|
|
||||||
|
返回:
|
||||||
|
str: 保存的文件路径
|
||||||
|
"""
|
||||||
|
# 生成时间戳文件名
|
||||||
|
date_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
||||||
|
result_path = os.path.join(self.output_dir, f"{date_time}.json")
|
||||||
|
|
||||||
|
# 确保目录存在
|
||||||
|
os.makedirs(os.path.dirname(result_path), exist_ok=True)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 如果是字符串,尝试解析为JSON
|
||||||
|
if isinstance(content, str):
|
||||||
|
try:
|
||||||
|
# 尝试解析JSON
|
||||||
|
json_content = json.loads(content)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
# 如果解析失败,将原始字符串封装为JSON
|
||||||
|
json_content = {"content": content, "type": "raw_text"}
|
||||||
|
else:
|
||||||
|
# 已经是字典或列表
|
||||||
|
json_content = content
|
||||||
|
|
||||||
|
# 使用同步方式写入文件(异步文件写入可能不是所有环境都支持)
|
||||||
|
with open(result_path, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(json_content, f, ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
|
self.logger.info(f"结果已保存到: {result_path}")
|
||||||
|
return result_path
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"保存结果失败: {e}")
|
||||||
|
return ""
|
||||||
|
|
||||||
|
# 参数设置方法
|
||||||
|
def set_temperature(self, temperature: float) -> None:
|
||||||
|
"""设置温度参数"""
|
||||||
|
self.temperature = temperature
|
||||||
|
|
||||||
|
def set_top_p(self, top_p: float) -> None:
|
||||||
|
"""设置top_p参数"""
|
||||||
|
self.top_p = top_p
|
||||||
|
|
||||||
|
def set_presence_penalty(self, presence_penalty: float) -> None:
|
||||||
|
"""设置存在惩罚参数"""
|
||||||
|
self.presence_penalty = presence_penalty
|
||||||
|
|
||||||
|
def set_model_params(self, temperature: float, top_p: float, presence_penalty: float) -> None:
|
||||||
|
"""设置所有模型参数"""
|
||||||
|
self.temperature = temperature
|
||||||
|
self.top_p = top_p
|
||||||
|
self.presence_penalty = presence_penalty
|
||||||
|
|
||||||
|
|
||||||
|
# 演示代码
|
||||||
|
async def demo():
|
||||||
|
"""演示异步内容生成器的用法"""
|
||||||
|
# 创建异步内容生成器实例
|
||||||
|
generator = AsyncContentGenerator(
|
||||||
|
model_name="qwenQWQ",
|
||||||
|
api_base_url="http://localhost:8000/v1",
|
||||||
|
api_key="EMPTY"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 示例推文内容
|
||||||
|
tweet_content = """
|
||||||
|
清明假期带娃哪里玩?泰宁甘露寺藏着明代建筑奇迹!一柱擎天的悬空阁楼+状元祈福传说,让孩子边玩边涨知识✨
|
||||||
|
|
||||||
|
🎒行程亮点:
|
||||||
|
✅ 安全科普第一站:讲解"一柱插地"千年不倒的秘密,用乐高积木模型让孩子理解力学原理
|
||||||
|
✅ 文化沉浸体验:穿汉服听"叶状元还愿建寺"故事,触摸3.38米粗的"状元柱"许愿
|
||||||
|
✅ 自然探索路线:连接金湖栈道徒步,观察丹霞地貌与古建筑的巧妙融合
|
||||||
|
"""
|
||||||
|
|
||||||
|
print("\n===== 测试异步海报生成 =====\n")
|
||||||
|
|
||||||
|
# 生成海报配置
|
||||||
|
full_response = ""
|
||||||
|
async for chunk in generator.generate_poster_stream(tweet_content, poster_num=3):
|
||||||
|
print(chunk, end="", flush=True)
|
||||||
|
full_response += chunk
|
||||||
|
|
||||||
|
# 保存结果
|
||||||
|
result_path = await generator.save_result(full_response)
|
||||||
|
print(f"\n\n结果已保存到: {result_path}")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# 运行演示
|
||||||
|
asyncio.run(demo())
|
||||||
288
examples/async_content_api.py
Normal file
288
examples/async_content_api.py
Normal file
@ -0,0 +1,288 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
"""
|
||||||
|
异步内容生成API示例
|
||||||
|
|
||||||
|
演示如何使用AI_Agent的async_generate_text_stream方法实现异步内容生成API,
|
||||||
|
支持实时流式输出,适合集成到Web服务或其他异步应用中。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
import asyncio
|
||||||
|
from pathlib import Path
|
||||||
|
from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
import uvicorn
|
||||||
|
|
||||||
|
# 添加项目根目录到Python路径
|
||||||
|
project_root = str(Path(__file__).parent.parent)
|
||||||
|
if project_root not in sys.path:
|
||||||
|
sys.path.insert(0, project_root)
|
||||||
|
|
||||||
|
from core.ai_agent import AI_Agent
|
||||||
|
|
||||||
|
# 创建FastAPI实例
|
||||||
|
app = FastAPI(title="Travel Content Generator API",
|
||||||
|
description="提供异步内容生成的API服务")
|
||||||
|
|
||||||
|
# 添加CORS中间件
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=["*"], # 允许所有来源,生产环境中应限制
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["*"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# 全局配置
|
||||||
|
CONFIG = {
|
||||||
|
"base_url": "vllm", # 使用本地vLLM服务
|
||||||
|
"model_name": "qwen2-7b-instruct", # 或其他配置的模型名称
|
||||||
|
"api_key": "EMPTY", # vLLM不需要API key
|
||||||
|
"timeout": 60, # 整体请求超时时间(秒)
|
||||||
|
"stream_chunk_timeout": 10, # 流式块超时时间(秒)
|
||||||
|
"max_retries": 3 # 最大重试次数
|
||||||
|
}
|
||||||
|
|
||||||
|
# API路由
|
||||||
|
|
||||||
|
@app.get("/")
|
||||||
|
async def root():
|
||||||
|
"""API根路径,返回简单的欢迎信息"""
|
||||||
|
return {
|
||||||
|
"message": "Travel Content Generator API",
|
||||||
|
"version": "1.0.0",
|
||||||
|
"docs_url": "/docs"
|
||||||
|
}
|
||||||
|
|
||||||
|
@app.post("/generate/text")
|
||||||
|
async def generate_text(request: Request):
|
||||||
|
"""
|
||||||
|
生成内容的API端点,支持流式响应
|
||||||
|
|
||||||
|
请求体示例:
|
||||||
|
{
|
||||||
|
"system_prompt": "你是一个专业的旅游内容创作助手",
|
||||||
|
"user_prompt": "请为我生成一篇关于福建泰宁古城的旅游攻略",
|
||||||
|
"temperature": 0.7,
|
||||||
|
"top_p": 0.9,
|
||||||
|
"presence_penalty": 0.0
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
# 解析请求体
|
||||||
|
data = await request.json()
|
||||||
|
|
||||||
|
# 提取参数
|
||||||
|
system_prompt = data.get("system_prompt", "你是一个专业的旅游内容创作助手")
|
||||||
|
user_prompt = data.get("user_prompt", "")
|
||||||
|
temperature = data.get("temperature", 0.7)
|
||||||
|
top_p = data.get("top_p", 0.9)
|
||||||
|
presence_penalty = data.get("presence_penalty", 0.0)
|
||||||
|
|
||||||
|
# 创建响应生成器
|
||||||
|
async def response_generator():
|
||||||
|
# 创建AI_Agent实例
|
||||||
|
agent = AI_Agent(
|
||||||
|
base_url=CONFIG["base_url"],
|
||||||
|
model_name=CONFIG["model_name"],
|
||||||
|
api=CONFIG["api_key"],
|
||||||
|
timeout=CONFIG["timeout"],
|
||||||
|
max_retries=CONFIG["max_retries"],
|
||||||
|
stream_chunk_timeout=CONFIG["stream_chunk_timeout"]
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 使用异步流式方法生成内容
|
||||||
|
async for chunk in agent.async_generate_text_stream(
|
||||||
|
system_prompt,
|
||||||
|
user_prompt,
|
||||||
|
temperature=temperature,
|
||||||
|
top_p=top_p,
|
||||||
|
presence_penalty=presence_penalty
|
||||||
|
):
|
||||||
|
# 每个块都作为SSE事件发送
|
||||||
|
yield f"data: {json.dumps({'chunk': chunk})}\n\n"
|
||||||
|
|
||||||
|
# 流结束标记
|
||||||
|
yield f"data: {json.dumps({'done': True})}\n\n"
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# 发送错误信息
|
||||||
|
yield f"data: {json.dumps({'error': str(e)})}\n\n"
|
||||||
|
finally:
|
||||||
|
# 确保资源被释放
|
||||||
|
agent.close()
|
||||||
|
|
||||||
|
# 返回流式响应
|
||||||
|
return StreamingResponse(
|
||||||
|
response_generator(),
|
||||||
|
media_type="text/event-stream"
|
||||||
|
)
|
||||||
|
|
||||||
|
# WebSocket路由
|
||||||
|
|
||||||
|
@app.websocket("/ws/generate")
|
||||||
|
async def websocket_generate(websocket: WebSocket):
|
||||||
|
"""
|
||||||
|
通过WebSocket生成内容,支持实时双向通信
|
||||||
|
"""
|
||||||
|
await websocket.accept()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 接收客户端参数
|
||||||
|
data = await websocket.receive_json()
|
||||||
|
|
||||||
|
# 提取参数
|
||||||
|
system_prompt = data.get("system_prompt", "你是一个专业的旅游内容创作助手")
|
||||||
|
user_prompt = data.get("user_prompt", "")
|
||||||
|
temperature = data.get("temperature", 0.7)
|
||||||
|
top_p = data.get("top_p", 0.9)
|
||||||
|
presence_penalty = data.get("presence_penalty", 0.0)
|
||||||
|
|
||||||
|
# 创建AI_Agent实例
|
||||||
|
agent = AI_Agent(
|
||||||
|
base_url=CONFIG["base_url"],
|
||||||
|
model_name=CONFIG["model_name"],
|
||||||
|
api=CONFIG["api_key"],
|
||||||
|
timeout=CONFIG["timeout"],
|
||||||
|
max_retries=CONFIG["max_retries"],
|
||||||
|
stream_chunk_timeout=CONFIG["stream_chunk_timeout"]
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 发送开始生成的消息
|
||||||
|
await websocket.send_json({"status": "generating"})
|
||||||
|
|
||||||
|
# 使用异步流式方法生成内容
|
||||||
|
full_response = ""
|
||||||
|
async for chunk in agent.async_generate_text_stream(
|
||||||
|
system_prompt,
|
||||||
|
user_prompt,
|
||||||
|
temperature=temperature,
|
||||||
|
top_p=top_p,
|
||||||
|
presence_penalty=presence_penalty
|
||||||
|
):
|
||||||
|
# 累积完整响应
|
||||||
|
full_response += chunk
|
||||||
|
|
||||||
|
# 发送每个文本块
|
||||||
|
await websocket.send_json({
|
||||||
|
"type": "chunk",
|
||||||
|
"content": chunk
|
||||||
|
})
|
||||||
|
|
||||||
|
# 模拟处理客户端中断请求
|
||||||
|
# 这里可以添加处理客户端发送的控制命令,如暂停、停止等
|
||||||
|
|
||||||
|
# 发送完成消息
|
||||||
|
await websocket.send_json({
|
||||||
|
"type": "complete",
|
||||||
|
"full_content": full_response
|
||||||
|
})
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# 发送错误信息
|
||||||
|
await websocket.send_json({
|
||||||
|
"type": "error",
|
||||||
|
"message": str(e)
|
||||||
|
})
|
||||||
|
finally:
|
||||||
|
# 确保资源被释放
|
||||||
|
agent.close()
|
||||||
|
|
||||||
|
except WebSocketDisconnect:
|
||||||
|
print("WebSocket客户端断开连接")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"WebSocket错误: {e}")
|
||||||
|
|
||||||
|
# 命令行界面,用于测试
|
||||||
|
|
||||||
|
async def test_async_generation():
|
||||||
|
"""测试异步内容生成功能"""
|
||||||
|
print("\n===== 测试异步内容生成 =====\n")
|
||||||
|
|
||||||
|
# 创建AI_Agent实例
|
||||||
|
agent = AI_Agent(
|
||||||
|
base_url=CONFIG["base_url"],
|
||||||
|
model_name=CONFIG["model_name"],
|
||||||
|
api=CONFIG["api_key"],
|
||||||
|
timeout=CONFIG["timeout"],
|
||||||
|
max_retries=CONFIG["max_retries"],
|
||||||
|
stream_chunk_timeout=CONFIG["stream_chunk_timeout"]
|
||||||
|
)
|
||||||
|
|
||||||
|
# 示例提示词
|
||||||
|
system_prompt = "你是一个专业的旅游内容创作助手,请根据用户的提示生成相关内容。"
|
||||||
|
user_prompt = "请为我生成一篇关于福建泰宁古城的旅游攻略,包括著名景点、美食推荐和最佳游玩季节。字数控制在300字以内。"
|
||||||
|
|
||||||
|
print("开始生成内容...")
|
||||||
|
start_time = time.time()
|
||||||
|
full_response = ""
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 使用异步流式方法
|
||||||
|
async for chunk in agent.async_generate_text_stream(
|
||||||
|
system_prompt,
|
||||||
|
user_prompt,
|
||||||
|
temperature=0.7,
|
||||||
|
top_p=0.9,
|
||||||
|
presence_penalty=0.0
|
||||||
|
):
|
||||||
|
# 累积完整响应
|
||||||
|
full_response += chunk
|
||||||
|
# 实时打印内容
|
||||||
|
print(chunk, end="", flush=True)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\n生成过程中出错: {e}")
|
||||||
|
finally:
|
||||||
|
# 确保资源被释放
|
||||||
|
agent.close()
|
||||||
|
|
||||||
|
end_time = time.time()
|
||||||
|
print(f"\n\n生成完成! 耗时: {end_time - start_time:.2f}秒")
|
||||||
|
|
||||||
|
def start_api_server():
|
||||||
|
"""启动API服务器"""
|
||||||
|
# 使用uvicorn启动FastAPI服务
|
||||||
|
uvicorn.run(
|
||||||
|
"async_content_api:app",
|
||||||
|
host="0.0.0.0",
|
||||||
|
port=8800,
|
||||||
|
reload=True
|
||||||
|
)
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
"""主函数,提供命令行界面"""
|
||||||
|
if len(sys.argv) > 1:
|
||||||
|
command = sys.argv[1]
|
||||||
|
if command == "server":
|
||||||
|
# 启动API服务器
|
||||||
|
print("启动API服务器...")
|
||||||
|
start_api_server()
|
||||||
|
elif command == "test":
|
||||||
|
# 测试异步生成
|
||||||
|
await test_async_generation()
|
||||||
|
else:
|
||||||
|
print(f"未知命令: {command}")
|
||||||
|
print("可用命令: server, test")
|
||||||
|
else:
|
||||||
|
# 默认测试异步生成
|
||||||
|
await test_async_generation()
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# 检查依赖
|
||||||
|
try:
|
||||||
|
import fastapi
|
||||||
|
import uvicorn
|
||||||
|
except ImportError:
|
||||||
|
print("请先安装依赖: pip install fastapi uvicorn")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
# 运行主函数
|
||||||
|
asyncio.run(main())
|
||||||
File diff suppressed because it is too large
Load Diff
Binary file not shown.
Loading…
x
Reference in New Issue
Block a user