Rewrote streaming output
This commit is contained in:
parent 1919c2e7a8
commit 15176e2eaf

14 README.md
@@ -11,6 +11,20 @@
 - **Modular design**: core features (config loading, prompt management, AI interaction, topic selection, content generation, poster creation) are kept separate for easier maintenance and extension
 - **Config-driven**: all runtime parameters are managed centrally through a configuration file
 
+## New Feature: Streaming Output Handling
+
+TravelContentCreator now supports three streaming output handling methods, offering a more flexible AI text generation experience:
+
+- **Synchronous streaming response**: uses the streaming API but returns the complete response
+- **Callback-based streaming response**: processes each text chunk through a callback function
+- **Asynchronous streaming response**: returns the text stream via an async generator
+
+These features greatly improve the user experience and system responsiveness when generating long texts.
+
+See the detailed documentation:
+- [Streaming documentation](docs/streaming.md)
+- [Streaming demo](examples/test_stream.py)
+
 ## Quick Start
 
 ### 1. Environment Setup
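For a quick taste of the three methods added above, a minimal sketch of the call styles side by side (the connection settings are placeholders; see the docs/streaming.md diff below for the full examples):

```python
import asyncio
from core.ai_agent import AI_Agent

# Placeholder connection settings; use your own endpoint, model, and key
agent = AI_Agent("your_api_base_url", "your_model_name", "your_api_key",
                 timeout=60, stream_chunk_timeout=10)

# 1. Synchronous: streams internally, returns the complete text
text = agent.generate_text_stream("You are helpful.", "Say hi.", 0.7, 0.9, 0.0)

# 2. Callback-based: each chunk is pushed to a function as it arrives
def on_chunk(content, is_last=False, is_timeout=False, is_error=False, error=None):
    print(content, end="", flush=True)

text = agent.generate_text_stream_with_callback("You are helpful.", "Say hi.", on_chunk)

# 3. Asynchronous: consume chunks with `async for`
async def demo():
    async for chunk in agent.async_generate_text_stream("You are helpful.", "Say hi."):
        print(chunk, end="", flush=True)

asyncio.run(demo())
agent.close()
```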
Binary file not shown.
Binary file not shown.
270 core/ai_agent.py
@@ -5,6 +5,8 @@ import random
 import traceback
 import logging
 import tiktoken
+import asyncio
+from asyncio import TimeoutError as AsyncTimeoutError
 
 # Configure basic logging for this module (or rely on root logger config)
 # logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
@@ -16,6 +18,16 @@ INITIAL_BACKOFF = 1 # Initial backoff time in seconds
 MAX_BACKOFF = 16 # Maximum backoff time in seconds
+STREAM_CHUNK_TIMEOUT = 10 # Timeout in seconds for receiving a chunk in stream
 
+# Custom timeout exception
+class Timeout(Exception):
+    """Raised when a stream chunk timeout occurs."""
+    def __init__(self, message="Stream chunk timeout occurred."):
+        self.message = message
+        super().__init__(self.message)
+
+    def __str__(self):
+        return self.message
 
 class AI_Agent():
     """AI agent class responsible for interacting with the AI model to generate text content."""
@@ -198,7 +210,7 @@ class AI_Agent():
 
         time_end = time.time()
         time_cost = time_end - time_start
-        tokens = None
+        tokens = len(result) * 1.3  # rough estimate: ~1.3 tokens per character
         logging.info(f"'work' completed in {time_cost:.2f}s. Estimated tokens: {tokens}")
 
         return result, tokens, time_cost
@@ -220,9 +232,10 @@ class AI_Agent():
             user_prompt: The user prompt for the AI.
             temperature: Sampling temperature.
             top_p: Nucleus sampling parameter.
             presence_penalty: Presence penalty parameter.
 
-        Yields:
-            str: Chunks of the generated text.
+        Returns:
+            str: The complete generated text.
 
         Raises:
             Exception: If the API call fails after all retries.
@@ -236,6 +249,7 @@ class AI_Agent():
         retries = 0
         backoff_time = INITIAL_BACKOFF
         last_exception = None
+        full_response = ""
 
         while retries < self.max_retries:
             try:
@@ -245,6 +259,7 @@ class AI_Agent():
                     messages=messages,
                     temperature=temperature,
                     top_p=top_p,
                     presence_penalty=presence_penalty,
                     stream=True,
+                    timeout=self.timeout # Overall request timeout
                 )
@@ -263,16 +278,16 @@ class AI_Agent():
 
                     if chunk.choices and chunk.choices[0].delta and chunk.choices[0].delta.content:
                         content = chunk.choices[0].delta.content
-                        yield content
+                        full_response += content
                         # logging.debug(f"Received chunk: {content}") # Potentially very verbose
                     elif chunk.choices and chunk.choices[0].finish_reason == 'stop':
                         logging.info("Stream finished.")
-                        return # Successful completion
+                        return full_response # Return complete response
                         # Handle other finish reasons if needed, e.g., 'length'
 
             except StopIteration:
                 logging.info("Stream iterator exhausted.")
-                return # End of stream normally
+                return full_response # Return complete response
             except Timeout as e:
                 logging.warning(f"Stream chunk timeout: {e}. Retrying if possible ({retries + 1}/{self.max_retries}).")
                 last_exception = e
@@ -300,7 +315,6 @@ class AI_Agent():
                 logging.error(f"Stream generation failed after {self.max_retries} retries.")
                 raise last_exception or Exception("Stream generation failed after max retries.")
 
             except (Timeout, APITimeoutError, APIConnectionError, RateLimitError) as e:
                 retries += 1
                 last_exception = e
@@ -311,16 +325,246 @@ class AI_Agent():
                     backoff_time = min(backoff_time * 2, MAX_BACKOFF)
                 else:
                     logging.error(f"API call failed after {self.max_retries} retries.")
-                    raise last_exception
+                    return f"[生成内容失败: {str(last_exception)}]" # Return error message as string
             except Exception as e:
                 # Catch unexpected errors during stream setup
                 logging.error(f"Unexpected error setting up stream: {traceback.format_exc()}")
-                raise e # Re-raise unexpected errors immediately
+                return f"[生成内容失败: {str(e)}]" # Return error message as string
 
         # Should not be reached if logic is correct, but as a safeguard:
         logging.error("Exited stream generation loop unexpectedly.")
-        raise last_exception or Exception("Stream generation failed.")
+        return f"[生成内容失败: 超过最大重试次数]" # Return error message as string
+    def generate_text_stream_with_callback(self, system_prompt, user_prompt,
+                                           callback_fn, temperature=0.7, top_p=0.9,
+                                           presence_penalty=0.0):
+        """Generate a text stream and process each chunk through a callback function.
+
+        Args:
+            system_prompt: The system prompt.
+            user_prompt: The user prompt.
+            callback_fn: Callback invoked for each chunk, receiving
+                (content, is_last, is_timeout, is_error, error).
+            temperature: Sampling temperature.
+            top_p: Nucleus sampling parameter.
+            presence_penalty: Presence penalty parameter.
+
+        Returns:
+            str: The complete response text.
+        """
+        messages = [
+            {"role": "system", "content": system_prompt},
+            {"role": "user", "content": user_prompt},
+        ]
+        logging.info(f"Generating text stream with callback using model: {self.model_name}")
+
+        retries = 0
+        backoff_time = INITIAL_BACKOFF
+        last_exception = None
+        full_response = ""
+
+        while retries < self.max_retries:
+            try:
+                logging.debug(f"Attempt {retries + 1}/{self.max_retries} to generate text stream with callback.")
+                stream = self.client.chat.completions.create(
+                    model=self.model_name,
+                    messages=messages,
+                    temperature=temperature,
+                    top_p=top_p,
+                    presence_penalty=presence_penalty,
+                    stream=True,
+                    timeout=self.timeout
+                )
+
+                chunk_iterator = iter(stream)
+                last_chunk_time = time.time()
+
+                try:
+                    while True:
+                        try:
+                            # Check for a timeout since the last received chunk
+                            if time.time() - last_chunk_time > self.stream_chunk_timeout:
+                                callback_fn("", is_last=True, is_timeout=True, is_error=False, error=None)
+                                raise Timeout(f"No chunk received for {self.stream_chunk_timeout} seconds.")
+
+                            chunk = next(chunk_iterator)
+                            last_chunk_time = time.time()  # Reset the timer after each received chunk
+
+                            content = ""
+                            is_last = False
+
+                            if chunk.choices and chunk.choices[0].delta and chunk.choices[0].delta.content:
+                                content = chunk.choices[0].delta.content
+                                full_response += content
+
+                            if chunk.choices and chunk.choices[0].finish_reason == 'stop':
+                                is_last = True
+
+                            # Hand the chunk to the callback
+                            callback_fn(content, is_last=is_last, is_timeout=False, is_error=False, error=None)
+
+                            if is_last:
+                                logging.info("Stream with callback finished normally.")
+                                return full_response  # Finished successfully; return the complete response
+
+                        except StopIteration:
+                            # The stream ended normally
+                            callback_fn("", is_last=True, is_timeout=False, is_error=False, error=None)
+                            logging.info("Stream iterator exhausted normally.")
+                            return full_response
+
+                except Timeout as e:
+                    logging.warning(f"Stream chunk timeout: {e}")
+                    last_exception = e
+                    # The timeout was already reported through the callback; no need to call it again
+                except (APITimeoutError, APIConnectionError, RateLimitError) as e:
+                    logging.warning(f"API error during streaming with callback: {type(e).__name__} - {e}")
+                    callback_fn("", is_last=True, is_timeout=False, is_error=True, error=str(e))
+                    last_exception = e
+                except Exception as e:
+                    logging.error(f"Unexpected error during streaming with callback: {traceback.format_exc()}")
+                    callback_fn("", is_last=True, is_timeout=False, is_error=True, error=str(e))
+                    last_exception = e
+
+                # Retry logic
+                retries += 1
+                if retries < self.max_retries:
+                    retry_msg = f"将在 {backoff_time:.2f} 秒后重试..."
+                    logging.info(f"Retrying stream in {backoff_time} seconds...")
+                    callback_fn(f"\n[{retry_msg}]\n", is_last=False, is_timeout=False, is_error=False, error=None)
+                    time.sleep(backoff_time + random.uniform(0, 1))  # Add random jitter
+                    backoff_time = min(backoff_time * 2, MAX_BACKOFF)
+                else:
+                    error_msg = f"API call failed after {self.max_retries} retries: {str(last_exception)}"
+                    logging.error(error_msg)
+                    callback_fn(f"\n[{error_msg}]\n", is_last=True, is_timeout=False, is_error=True, error=str(last_exception))
+                    return full_response  # Return whatever partial response was collected
+
+            except Exception as e:
+                logging.error(f"Error setting up stream with callback: {traceback.format_exc()}")
+                callback_fn("", is_last=True, is_timeout=False, is_error=True, error=str(e))
+                return f"[生成内容失败: {str(e)}]"
+
+        # Safeguard
+        error_msg = "超出最大重试次数"
+        logging.error(error_msg)
+        callback_fn(f"\n[{error_msg}]\n", is_last=True, is_timeout=False, is_error=True, error=error_msg)
+        return full_response
+    async def async_generate_text_stream(self, system_prompt, user_prompt, temperature=0.7, top_p=0.9, presence_penalty=0.0):
+        """Asynchronously generate a text stream.
+
+        Args:
+            system_prompt: The system prompt.
+            user_prompt: The user prompt.
+            temperature: Sampling temperature.
+            top_p: Nucleus sampling parameter.
+            presence_penalty: Presence penalty parameter.
+
+        Yields:
+            str: Chunks of the generated text.
+
+        Raises:
+            Exception: If the API call fails after all retries.
+        """
+        messages = [
+            {"role": "system", "content": system_prompt},
+            {"role": "user", "content": user_prompt},
+        ]
+        logging.info(f"Asynchronously generating text stream with model: {self.model_name}")
+
+        retries = 0
+        backoff_time = INITIAL_BACKOFF
+        last_exception = None
+
+        while retries < self.max_retries:
+            try:
+                logging.debug(f"Async attempt {retries + 1}/{self.max_retries} to generate text stream.")
+                # Create a fresh client for the async operation
+                async_client = OpenAI(
+                    api_key=self.api,
+                    base_url=self.base_url,
+                    timeout=self.timeout
+                )
+
+                stream = await async_client.chat.completions.create(
+                    model=self.model_name,
+                    messages=messages,
+                    temperature=temperature,
+                    top_p=top_p,
+                    presence_penalty=presence_penalty,
+                    stream=True,
+                    timeout=self.timeout
+                )
+
+                last_chunk_time = time.time()
+                try:
+                    async for chunk in stream:
+                        # Check for a timeout since the last received chunk
+                        current_time = time.time()
+                        if current_time - last_chunk_time > self.stream_chunk_timeout:
+                            raise Timeout(f"No chunk received for {self.stream_chunk_timeout} seconds.")
+
+                        last_chunk_time = current_time  # Reset the timer after each received chunk
+
+                        if chunk.choices and chunk.choices[0].delta and chunk.choices[0].delta.content:
+                            content = chunk.choices[0].delta.content
+                            yield content
+
+                    logging.info("Async stream finished normally.")
+                    return  # Finished successfully
+
+                except AsyncTimeoutError as e:
+                    logging.warning(f"Async stream timeout: {e}")
+                    last_exception = e
+                except Timeout as e:
+                    logging.warning(f"Async stream chunk timeout: {e}")
+                    last_exception = e
+                except Exception as e:
+                    logging.error(f"Error during async streaming: {traceback.format_exc()}")
+                    last_exception = e
+
+                # Retry logic
+                retries += 1
+                if retries < self.max_retries:
+                    logging.info(f"Retrying async stream in {backoff_time} seconds...")
+                    await asyncio.sleep(backoff_time + random.uniform(0, 1))  # Async sleep
+                    backoff_time = min(backoff_time * 2, MAX_BACKOFF)
+                else:
+                    logging.error(f"Async API call failed after {self.max_retries} retries.")
+                    raise last_exception or Exception("Async stream generation failed after max retries.")
+
+            except (Timeout, AsyncTimeoutError, APITimeoutError, APIConnectionError, RateLimitError) as e:
+                retries += 1
+                last_exception = e
+                logging.warning(f"Async attempt {retries}/{self.max_retries} failed: {type(e).__name__} - {e}")
+                if retries < self.max_retries:
+                    logging.info(f"Retrying async stream in {backoff_time} seconds...")
+                    await asyncio.sleep(backoff_time + random.uniform(0, 1))  # Async sleep
+                    backoff_time = min(backoff_time * 2, MAX_BACKOFF)
+                else:
+                    logging.error(f"Async API call failed after {self.max_retries} retries.")
+                    raise last_exception
+            except Exception as e:
+                logging.error(f"Unexpected error setting up async stream: {traceback.format_exc()}")
+                raise e  # Re-raise unexpected errors immediately
+
+        # Safeguard
+        logging.error("Exited async stream generation loop unexpectedly.")
+        raise last_exception or Exception("Async stream generation failed.")
+    async def async_work_stream(self, system_prompt, user_prompt, file_folder, temperature, top_p, presence_penalty):
+        """Async version of the full workflow: read the folder (if provided), then generate a text stream."""
+        logging.info(f"Starting async 'work_stream' process. File folder: {file_folder}")
+        if file_folder:
+            logging.info(f"Reading context from folder: {file_folder}")
+            context = self.read_folder(file_folder)
+            if context:
+                user_prompt = f"{user_prompt.strip()}\n\n--- 参考资料 ---\n{context.strip()}"
+            else:
+                logging.warning(f"Folder {file_folder} provided but no content read.")
+
+        logging.info("Calling async_generate_text_stream...")
+        return self.async_generate_text_stream(system_prompt, user_prompt, temperature, top_p, presence_penalty)
+
     def work_stream(self, system_prompt, user_prompt, file_folder, temperature, top_p, presence_penalty):
         """Full workflow (streaming): read the folder (if provided), then generate a text stream."""
@@ -333,6 +577,6 @@ class AI_Agent():
         else:
             logging.warning(f"Folder {file_folder} provided but no content read.")
 
         logging.info("Calling generate_text_stream...")
-        return self.generate_text_stream(system_prompt, user_prompt, temperature, top_p, presence_penalty)
-        # --- End Added Streaming Methods ---
+        full_response = self.generate_text_stream(system_prompt, user_prompt, temperature, top_p, presence_penalty)
+        return full_response
+        # --- End Streaming Methods ---
@@ -45,6 +45,11 @@ class ContentGenerator:
         self.top_p = 0.8
         self.presence_penalty = 1.2
 
+        # Set up logging
+        logging.basicConfig(level=logging.INFO,
+                            format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+        self.logger = logging.getLogger(__name__)
+
 
     def load_infomation(self, info_directory_path):
         """
@@ -193,142 +198,149 @@ class ContentGenerator:
         Returns:
             The generated poster content
         """
-        if system_prompt is None:
-            system_prompt = """
-你是一名资深海报设计师,有丰富的爆款海报设计经验,你现在要为旅游景点做宣传,在小红书上发布大量宣传海报。你的主要工作目标有2个:
-1、你要根据我给你的图片描述和笔记推文内容,设计图文匹配的海报。
-2、为海报设计文案,文案的<第一个小标题>和<第二个小标题>之间你需要检查是否逻辑关系合理,你将通过先去生成<第二个小标题>关于景区亮点的部分,再去综合判断<第一个小标题>应该如何搭配组合更符合两个小标题的逻辑再生成<第一个小标题>。
-
-其中,生成三类标题文案的通用性要求如下:
-1、生成的<大标题>字数必须小于8个字符
-2、生成的<第一个小标题>字数和<第二个小标题>字数,两者都必须小8个字符
-3、标题和文案都应符合中国社会主义核心价值观
-
-接下来先开始生成<大标题>部分,由于海报是用来宣传旅游景点,生成的海报<大标题>必须使用以下8种格式之一:
-①地名+景点名(例如福建厦门鼓浪屿/厦门鼓浪屿);
-②地名+景点名+plog;
-③拿捏+地名+景点名;
-④地名+景点名+攻略;
-⑤速通+地名+景点名
-⑥推荐!+地名+景点名
-⑦勇闯!+地名+景点名
-⑧收藏!+地名+景点名
-你需要随机挑选一种格式生成对应景点的文案,但是格式除了上面8种不可以有其他任何格式;同时尽量保证每一种格式出现的频率均衡。
-接下来先去生成<第二个小标题>,<第二个小标题>文案的创作必须遵循以下原则:
-请根据笔记内容和图片识别,用极简的文字概括这篇笔记和图片中景点的特色亮点,其中你可以参考以下词汇进行创作,这段文案字数控制6-8字符以内;
-
-特色亮点可能会出现的词汇不完全举例:非遗、古建、绝佳山水、祈福圣地、研学圣地、解压天堂、中国小瑞士、秘境竹筏游等等类型词汇
-
-接下来再去生成<第一个小标题>,<第一个小标题>文案的创作必须遵循以下原则:
-这部分文案创作公式有5种,分别为:
-①<受众人群画像>+<痛点词>
-②<受众人群画像>
-③<痛点词>
-④<受众人群画像>+ | +<痛点词>
-⑤<痛点词>+ | +<受众人群画像>
-请你根据实际笔记内容,结合这部分文案创作公式,需要结合<受众人群画像>和<痛点词>时,必须根据<第二个小标题>的景点特征和所对应的完整笔记推文内容主旨,特征挑选对应<受众人群画像>和<痛点词>。
-
-我给你提供受众人群画像库和痛点词库如下:
-1、受众人群画像库:情侣党、亲子游、合家游、银发族、亲子研学、学生党、打工人、周边游、本地人、穷游党、性价比、户外人、美食党、出片
-2、痛点词库:3天2夜、必去、看了都哭了、不能错过、一定要来、问爆了、超全攻略、必打卡、强推、懒人攻略、必游榜、小众打卡、狂喜等等。
-
-你需要为每个请求至少生成{poster_num}个海报设计。请使用JSON格式输出结果,结构如下:
-
-```json
+        full_response = ""
+        timeout = 60  # Request timeout in seconds
+
+        if not system_prompt:
+            # Use the default system prompt
+            system_prompt = f"""
+你是一个专业的文案处理专家,擅长从文章中提取关键信息并生成吸引人的标题和简短描述。
+现在,我需要你根据提供的文章内容,生成{poster_num}个海报的文案配置。
+
+每个配置包含:
+1. main_title:主标题,简短有力,突出景点特点
+2. texts:两句简短文本,每句不超过15字,描述景点特色或游玩体验
+
+以JSON数组格式返回配置,示例:
 [
-    {
-        "index": 1,
-        "main_title": "主标题内容",
-        "texts": ["第一个小标题", "第二个小标题"]
-    },
-    {
-        "index": 2,
-        "main_title": "主标题内容",
-        "texts": ["第一个小标题", "第二个小标题"]
-    }
-    // ... 更多海报
+    {{
+        "main_title": "泰宁古城",
+        "texts": ["千年古韵","匠心独运"]
+    }},
+    ...
 ]
-```
-
-确保生成的数量与用户要求的数量一致。只生成上述JSON格式内容,不要有其他任何额外内容。
+仅返回JSON数据,不需要任何额外解释。确保生成的标题和文本能够准确反映文章提到的景点特色。
 """
-        user_content = f"""
-        海报数量:{poster_num};
-        景区介绍:{self.add_description};
-        推文内容:{tweet_content};
-        """
-
-        # Final response content
-        full_response = ""
+        if self.add_description:
+            # Build the user content from the info description and tweet_content
+            user_content = f"""
+以下是需要你处理的信息:
+
+关于景点的描述:
+{self.add_description}
+
+推文内容:
+{tweet_content}
+
+请根据这些信息,生成{poster_num}个海报文案配置,以JSON数组格式返回。
+确保主标题(main_title)简短有力,每个text不超过15字,并能准确反映景点特色。
+"""
+        else:
+            # Use only tweet_content
+            user_content = f"""
+以下是需要你处理的推文内容:
+{tweet_content}
+
+请根据这些信息,生成{poster_num}个海报文案配置,以JSON数组格式返回。
+确保主标题(main_title)简短有力,每个text不超过15字,并能准确反映景点特色。
+"""
+
+        self.logger.info(f"正在生成{poster_num}个海报文案配置")
 
         # Create the temporary client
         temp_client = self._create_temp_client()
 
-        # Add a retry mechanism
-        if temp_client:
+        # If client creation failed, use the fallback immediately
+        if temp_client is None:
+            print("创建OpenAI客户端失败,使用备用方案生成内容")
+            return 404
+        else:
+            # Retry logic
             for retry in range(max_retries):
                 try:
-                    print(f"尝试连接API (尝试 {retry+1}/{max_retries})...")
+                    self.logger.info(f"尝试生成内容 (尝试 {retry+1}/{max_retries})")
 
                     # Compute the backoff time (exponential backoff): 0, 2, 4, 8, 16... seconds
                     if retry > 0:
                         backoff_time = min(2 ** (retry - 1) * 2, 30)  # Wait at most 30 seconds
                         print(f"等待 {backoff_time} 秒后重试...")
                         time.sleep(backoff_time)
 
-                    # Increase the timeout with each retry
-                    timeout = 30 + (retry * 30)  # 30, 60, 90, ... seconds
+                    # Callback for handling the streamed response
+                    def handle_stream_chunk(chunk, is_last=False, is_timeout=False, is_error=False, error=None):
+                        nonlocal full_response
 
-                    chat_response = temp_client.chat.completions.create(
-                        model=self.model_name,
-                        messages=[
-                            {"role": "system", "content": system_prompt},
-                            {"role": "user", "content": user_content}
-                        ],
-                        stream=True,
-                        temperature=self.temperature,
-                        top_p=self.top_p,
-                        presence_penalty=self.presence_penalty,
-                        timeout=timeout  # Set the request timeout
-                    )
+                        if chunk:
+                            full_response += chunk
+                            # Print to the console in real time
+                            print(chunk, end="", flush=True)
+
+                        if is_last:
+                            print("\n")  # Newline once output completes
+                        if is_timeout:
+                            print("警告: 响应流超时")
+                        if is_error:
+                            print(f"错误: {error}")
+
+                    # Use AI_Agent's new callback-based method
+                    from core.ai_agent import AI_Agent
+                    ai_agent = AI_Agent(
+                        self.api_base_url,
+                        self.model_name,
+                        self.api_key,
+                        timeout=timeout,
+                        max_retries=max_retries,
+                        stream_chunk_timeout=30  # Stream chunk timeout (seconds)
+                    )
 
-                    # Collect the response content
-                    for chunk in chat_response:
-                        if chunk.choices and len(chunk.choices) > 0 and chunk.choices[0].delta.content is not None:
-                            content = chunk.choices[0].delta.content
-                            full_response += content
-                            print(content, end="", flush=True)
-                        if chunk.choices and len(chunk.choices) > 0 and chunk.choices[0].finish_reason == "stop":
-                            break
+                    # Process the streamed response through the callback
+                    try:
+                        full_response = ai_agent.generate_text_stream_with_callback(
+                            system_prompt,
+                            user_content,
+                            handle_stream_chunk,
+                            temperature=self.temperature,
+                            top_p=self.top_p,
+                            presence_penalty=self.presence_penalty
+                        )
 
-                    print("\n")  # Newline after output completes
+                        # Content generated successfully; leave the retry loop
+                        ai_agent.close()
+                        break
 
-                    # Got the response; leave the retry loop
-                    break
+                    except Exception as e:
+                        error_msg = str(e)
+                        self.logger.error(f"AI生成错误: {error_msg}")
+                        ai_agent.close()
+
+                        # Continue with the retry logic
+                        if retry + 1 >= max_retries:
+                            self.logger.warning("已达到最大重试次数,使用备用方案...")
+                            # Generate fallback content
+                            full_response = self._generate_fallback_content(poster_num)
+                        else:
+                            self.logger.info(f"将在稍后重试,还剩 {max_retries - retry - 1} 次重试机会")
 
                 except Exception as e:
                     error_msg = str(e)
                     print(f"API连接错误 (尝试 {retry+1}/{max_retries}): {error_msg}")
+                    self.logger.error(f"API连接错误 (尝试 {retry+1}/{max_retries}): {error_msg}")
 
                     # If the maximum number of retries has been reached
                     if retry + 1 >= max_retries:
                         print("已达到最大重试次数,使用备用方案...")
+                        self.logger.warning("已达到最大重试次数,使用备用方案...")
                         # Generate fallback content (simple template)
                         full_response = self._generate_fallback_content(poster_num)
                     else:
                         print(f"将在稍后重试,还剩 {max_retries - retry - 1} 次重试机会")
+                        self.logger.info(f"将在稍后重试,还剩 {max_retries - retry - 1} 次重试机会")
 
         # Close the temporary client
         self._close_client(temp_client)
 
-        # Generate a timestamp
         return full_response
 
     def _generate_fallback_content(self, poster_num):
         """Generate fallback content, used when the API call fails."""
         self.logger.info("生成备用内容")
         default_configs = []
         for i in range(poster_num):
             default_configs.append({
                 "main_title": f"景点风光 {i+1}",
                 "texts": ["自然美景", "人文体验"]
             })
         return json.dumps(default_configs, ensure_ascii=False)
 
     def save_result(self, full_response):
         """
         Save the generated result to a file
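A note on the flow above: whether `full_response` came from the model or from `_generate_fallback_content`, it is expected to be a JSON array of poster configs, so downstream code can parse both the same way. A minimal consumption sketch, assuming the model honors the JSON-only instruction in the prompt:

```python
import json

configs = json.loads(full_response)  # the fallback is json.dumps(...) of the same shape
for cfg in configs:
    print(cfg["main_title"], cfg["texts"][0], cfg["texts"][1])
```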
209 docs/streaming.md (new file)
@@ -0,0 +1,209 @@
# Streaming Output Handling

TravelContentCreator now supports three different methods for handling streamed output, giving you more flexibility in processing the text generated by AI models. All three methods are implemented in the `AI_Agent` class; pick whichever fits your usage scenario.

## Why streaming?

Compared with the traditional approach of returning one complete response, streaming has the following advantages:

1. **Real-time processing**: content can be processed as it is generated, without waiting for the full response
2. **Better user experience**: enables a "typewriter effect" so users watch the text appear incrementally
3. **Earlier error detection**: problems can be spotted while the response is still being generated
4. **More efficient for long texts**: especially suitable for generating long content, avoiding long waits

## Streaming methods

The `AI_Agent` class provides three streaming modes:
### 1. Synchronous streaming response (generate_text_stream)

This method uses a streaming API connection under the hood, but assembles all output and returns it in one piece; it suits simple API calls.

```python
def generate_text_stream(self, system_prompt, user_prompt, temperature, top_p, presence_penalty):
    """
    Generate text content (uses the streaming API but returns the complete response)

    Args:
        system_prompt: The system prompt
        user_prompt: The user prompt
        temperature: Sampling temperature
        top_p: Nucleus sampling parameter
        presence_penalty: Presence penalty parameter

    Returns:
        str: The complete generated text
    """
```

Usage example:

```python
agent = AI_Agent(base_url, model_name, api_key, timeout=30, stream_chunk_timeout=10)
result = agent.generate_text_stream(system_prompt, user_prompt, 0.7, 0.9, 0.0)
print(result)  # Print the complete generated result
```
### 2. Callback-based streaming response (generate_text_stream_with_callback)

This method processes each text chunk in the stream through a callback function. It is the most flexible option, well suited to real-time display, analysis, or saving intermediate data.

```python
def generate_text_stream_with_callback(self, system_prompt, user_prompt,
                                       callback_fn, temperature=0.7, top_p=0.9,
                                       presence_penalty=0.0):
    """
    Generate a text stream and process each chunk through a callback function

    Args:
        system_prompt: The system prompt
        user_prompt: The user prompt
        callback_fn: Callback invoked for each chunk, receiving
            (content, is_last, is_timeout, is_error, error)
        temperature: Sampling temperature
        top_p: Nucleus sampling parameter
        presence_penalty: Presence penalty parameter

    Returns:
        str: The complete response text
    """
```

The callback function should have the following shape:

```python
def my_callback(content, is_last=False, is_timeout=False, is_error=False, error=None):
    """
    Callback for handling the streamed response

    Args:
        content: The text chunk
        is_last: Whether this is the last chunk
        is_timeout: Whether a timeout occurred
        is_error: Whether an error occurred
        error: The error message
    """
    if content:
        print(content, end="", flush=True)  # Print in real time

    # Handle the special cases
    if is_last:
        print("\n完成生成")
    if is_timeout:
        print("警告: 响应流超时")
    if is_error:
        print(f"错误: {error}")
```

Usage example:

```python
agent = AI_Agent(base_url, model_name, api_key, timeout=30, stream_chunk_timeout=10)
result = agent.generate_text_stream_with_callback(
    system_prompt,
    user_prompt,
    my_callback,  # Pass in the callback function
    temperature=0.7,
    top_p=0.9,
    presence_penalty=0.0
)
```
### 3. Asynchronous streaming response (async_generate_text_stream)

This method is built on `asyncio` and returns an async generator; it integrates well with other asynchronous operations, for example in async web applications.

```python
async def async_generate_text_stream(self, system_prompt, user_prompt, temperature=0.7, top_p=0.9, presence_penalty=0.0):
    """
    Asynchronously generate a text stream

    Args:
        system_prompt: The system prompt
        user_prompt: The user prompt
        temperature: Sampling temperature
        top_p: Nucleus sampling parameter
        presence_penalty: Presence penalty parameter

    Yields:
        str: Chunks of the generated text

    Raises:
        Exception: If the API call fails after all retries
    """
```

Usage example:

```python
async def demo_async_stream():
    agent = AI_Agent(base_url, model_name, api_key, timeout=30, stream_chunk_timeout=10)

    full_response = ""
    try:
        # Consume the async generator
        async for chunk in agent.async_generate_text_stream(
            system_prompt,
            user_prompt,
            temperature=0.7,
            top_p=0.9,
            presence_penalty=0.0
        ):
            print(chunk, end="", flush=True)  # Display in real time
            full_response += chunk

            # Other async operations can run concurrently here
            # await some_other_async_operation()

    except Exception as e:
        print(f"错误: {e}")
    finally:
        agent.close()

# Run in an async environment
asyncio.run(demo_async_stream())
```
## Timeout handling

All streaming methods support two timeout settings:

1. **Overall request timeout**: caps the total duration of the API request
2. **Stream chunk timeout**: caps the wait between two consecutive data chunks

Both are set when creating the `AI_Agent` instance:

```python
agent = AI_Agent(
    base_url="your_api_base_url",
    model_name="your_model_name",
    api="your_api_key",
    timeout=60,               # Overall request timeout (seconds)
    max_retries=3,            # Maximum number of retries
    stream_chunk_timeout=10   # Stream chunk timeout (seconds)
)
```
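The chunk timeout is enforced inside the consumption loop itself. A simplified sketch of the pattern (mirroring the loop in `core/ai_agent.py` with details trimmed; it assumes the custom `Timeout` exception from that module is in scope):

```python
import time

def consume_stream(stream, stream_chunk_timeout):
    # Gate each next() on the time elapsed since the last chunk arrived
    chunk_iterator = iter(stream)
    last_chunk_time = time.time()
    full_response = ""
    while True:
        if time.time() - last_chunk_time > stream_chunk_timeout:
            raise Timeout(f"No chunk received for {stream_chunk_timeout} seconds.")
        try:
            chunk = next(chunk_iterator)
        except StopIteration:
            return full_response  # stream exhausted normally
        last_chunk_time = time.time()  # reset the timer after every chunk
        if chunk.choices and chunk.choices[0].delta.content:
            full_response += chunk.choices[0].delta.content
```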
## Complete example

A full demo script walks through all three streaming methods:

```
examples/test_stream.py
```

Run it with:

```bash
cd TravelContentCreator
python examples/test_stream.py
```

This demonstrates the three streaming methods in turn and shows their output and performance differences.

## Use in the WebUI

The TravelContentCreator WebUI already integrates callback-based streaming to display generated content in real time, which greatly improves the user experience, especially when generating long-form content.
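The WebUI code itself is not shown here, but the wiring follows the same callback pattern; a hypothetical sketch (the queue and sentinel are illustrative, not the actual WebUI implementation):

```python
import queue

chunks = queue.Queue()

def ui_callback(content, is_last=False, is_timeout=False, is_error=False, error=None):
    # Push each chunk to the UI side; the frontend drains this queue to render text
    if content:
        chunks.put(content)
    if is_last:
        chunks.put(None)  # sentinel: generation finished

# agent.generate_text_stream_with_callback(system_prompt, user_prompt, ui_callback)
```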
## Use in your own projects

To use these streaming features in your own project, simply import the `AI_Agent` class and call the methods as shown above. All streaming methods have built-in error handling and retry logic, which improves stability in production environments.
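One integration detail worth noting: after this commit the synchronous methods return an error string of the form `[生成内容失败: ...]` instead of raising once retries are exhausted, so callers should check for it. A minimal sketch (`handle_failure` is a placeholder for your own fallback logic):

```python
result = agent.generate_text_stream(system_prompt, user_prompt, 0.7, 0.9, 0.0)
if result.startswith("[生成内容失败"):
    handle_failure(result)  # retries exhausted; fall back or surface the error
else:
    print(result)
```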
@@ -156,6 +156,31 @@ python examples/test_stream.py
 - Demonstrates how to process text chunks output in real time
 - Shows how to collect the complete response and estimate the token count
 
+## Streaming demo (test_stream.py)
+
+`test_stream.py` is a demo script showing the three streaming output methods newly added to the AI_Agent class:
+
+1. **Synchronous streaming response** (generate_text_stream) - uses the streaming API but returns the complete response
+2. **Callback-based streaming response** (generate_text_stream_with_callback) - processes each text chunk through a callback function
+3. **Asynchronous streaming response** (async_generate_text_stream) - returns the text stream via an async generator
+
+### Running the demo
+
+```bash
+cd TravelContentCreator
+python examples/test_stream.py
+```
+
+The demo runs the three methods in turn and shows their output and performance differences.
+
+### More information
+
+For detailed documentation of the streaming features, see:
+
+```
+docs/streaming.md
+```
+
 ## Other examples
 
 `generate_poster.py` is a simplified poster-generation example, mainly for quickly testing the poster effect for a specific image.
103 examples/README_streaming.md (new file)
@@ -0,0 +1,103 @@
# Streaming Features

This document introduces the three new streaming output processing approaches in Travel Content Creator and how to use them to improve the content-generation experience.

## The new streaming approaches

We implemented three different streaming approaches in the `AI_Agent` class (item 1 involves a behavioral change; see the sketch after this list):

1. **Synchronous streaming response** (`generate_text_stream`)
   - Changed to return the complete response; it is no longer a generator
   - Internally still uses a streaming request for better timeout control
   - Suitable for simple API-call scenarios

2. **Callback-based streaming response** (`generate_text_stream_with_callback`)
   - Processes each text chunk through a callback function
   - The callback can display, analyze, or save content in real time
   - The most flexible option, suited to custom processing pipelines

3. **Asynchronous streaming response** (`async_generate_text_stream`)
   - An async generator built on `asyncio`
   - Suited to scenarios that integrate with other asynchronous operations
   - Handles long-running requests while staying responsive
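Because `generate_text_stream` is no longer a generator, old call sites that iterated over it must be updated; a minimal before/after sketch:

```python
# Before this commit: generate_text_stream was a generator
# for chunk in agent.generate_text_stream(system_prompt, user_prompt, 0.7, 0.9, 0.0):
#     print(chunk, end="", flush=True)

# After this commit: it streams internally and returns the complete text
text = agent.generate_text_stream(system_prompt, user_prompt, 0.7, 0.9, 0.0)
print(text)
```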
## Demo script

The example script `test_stream.py` demonstrates all three streaming approaches:

```bash
# Run from the project root
python examples/test_stream.py
```

## Callback example

Handling a streamed response with a callback function:

```python
def handle_chunk(chunk, is_last=False, is_timeout=False, is_error=False, error=None):
    # Handle the text chunk
    if chunk:
        print(chunk, end="", flush=True)  # Display in real time

    # Handle the end states
    if is_last:
        if is_timeout:
            print("\n流式响应超时")
        if is_error:
            print(f"\n发生错误: {error}")

# Stream with the callback
response = agent.generate_text_stream_with_callback(
    system_prompt,
    user_prompt,
    handle_chunk,  # Pass in the callback function
    temperature=0.7
)
```

## Async streaming example

Handling a streamed response asynchronously:

```python
async def process_stream():
    async for chunk in agent.async_generate_text_stream(
        system_prompt,
        user_prompt,
        temperature=0.7
    ):
        print(chunk, end="", flush=True)  # Display in real time

        # Other async tasks can run at the same time
        await other_async_task()

# Run in an async environment
asyncio.run(process_stream())
```
## Timeout handling

All streaming methods support timeout control:

1. **Overall request timeout**: the maximum wait for the whole request
2. **Stream chunk timeout**: the maximum wait between two consecutive text chunks

Both parameters are set when creating the `AI_Agent` instance:

```python
agent = AI_Agent(
    api_url="http://localhost:8000/v1/",
    model="qwen",
    api_key="EMPTY",
    timeout=30,               # Overall request timeout (seconds)
    stream_chunk_timeout=10   # Stream chunk timeout (seconds)
)
```

## Notes

1. The new streaming methods have built-in retry and error handling
2. The callback approach offers the most flexibility and control
3. The async approach is best for applications that must keep the UI responsive
4. All methods close the streaming request automatically on completion
examples/test_stream.py
@@ -1,129 +1,175 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
 
+"""
+Test AI_Agent's streaming methods
+
+This script demonstrates the three streaming output methods of the AI_Agent class in TravelContentCreator:
+1. Synchronous streaming response (generate_text_stream)
+2. Callback-based streaming response (generate_text_stream_with_callback)
+3. Asynchronous streaming response (async_generate_text_stream)
+"""
+
 import os
 import sys
 import json
 import asyncio
 import time
 import logging
+from pathlib import Path
 
-# Determine the project root directory (assuming examples/ is one level down)
-PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
-if PROJECT_ROOT not in sys.path:
-    sys.path.append(PROJECT_ROOT)
+# Add the project root to the Python path
+project_root = str(Path(__file__).parent.parent)
+if project_root not in sys.path:
+    sys.path.insert(0, project_root)
 
-# Now import from core
-try:
-    from core.ai_agent import AI_Agent
-except ImportError as e:
-    logging.critical(f"Failed to import AI_Agent. Ensure '{PROJECT_ROOT}' is in sys.path and core/ai_agent.py exists. Error: {e}")
-    sys.exit(1)
+from core.ai_agent import AI_Agent
-def load_config(config_path):
-    """Loads configuration from a JSON file."""
-    try:
-        with open(config_path, 'r', encoding='utf-8') as f:
-            config = json.load(f)
-        logging.info(f"Config loaded successfully from {config_path}")
-        return config
-    except FileNotFoundError:
-        logging.error(f"Error: Configuration file not found at {config_path}")
-        return None
-    except json.JSONDecodeError:
-        logging.error(f"Error: Could not decode JSON from {config_path}")
-        return None
-    except Exception as e:
-        logging.exception(f"An unexpected error occurred loading config {config_path}:")
-        return None
+# Example prompts
+SYSTEM_PROMPT = """你是一个专业的旅游内容创作助手,请根据用户的提示生成相关内容。"""
+USER_PROMPT = """请为我生成一篇关于福建泰宁古城的旅游攻略,包括著名景点、美食推荐和最佳游玩季节。字数控制在300字以内。"""
-def main():
-    # --- Basic Logging Setup ---
-    logging.basicConfig(
-        level=logging.INFO,
-        format='%(asctime)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s',
-        datefmt='%Y-%m-%d %H:%M:%S'
-    )
-    # --- End Logging Setup ---
-
-    logging.info("Starting AI Agent Stream Test...")
+def print_separator(title):
+    """Print a separator line with a title."""
+    print("\n" + "="*50)
+    print(f" {title} ".center(50, "="))
+    print("="*50 + "\n")
+
+def demo_sync_stream():
+    """Demonstrate the synchronous streaming method."""
+    print_separator("同步流式响应 (generate_text_stream)")
+
+    # Create an AI_Agent instance
+    agent = AI_Agent(
+        base_url="vllm",                 # Use the local vLLM service
+        model_name="qwen2-7b-instruct",  # Or whatever model you have configured
+        api="EMPTY",                     # vLLM does not need an API key
+        timeout=60,                      # Overall request timeout (seconds)
+        stream_chunk_timeout=10          # Stream chunk timeout (seconds)
+    )
+
+    print("开始生成内容...")
+    start_time = time.time()
 
-    # Load configuration (adjust path relative to this script)
-    config_path = os.path.join(PROJECT_ROOT, "poster_gen_config.json")
-    config = load_config(config_path)
-    if config is None:
-        logging.critical("Failed to load configuration. Exiting test.")
-        sys.exit(1)
+    # Use the synchronous streaming method
+    result = agent.generate_text_stream(
+        SYSTEM_PROMPT,
+        USER_PROMPT,
+        temperature=0.7,
+        top_p=0.9,
+        presence_penalty=0.0
+    )
 
-    # Example Prompts
-    system_prompt = "你是一个乐于助人的AI助手,擅长写短篇故事。"
-    user_prompt = "请写一个关于旅行机器人的短篇故事,它在一个充满异国情调的星球上发现了新的生命形式。"
+    end_time = time.time()
 
-    ai_agent = None
+    print(f"\n\n完整生成内容:\n{result}")
+    print(f"\n生成完成! 耗时: {end_time - start_time:.2f}秒")
+
+    # Close the agent
+    agent.close()
+def demo_callback_stream():
+    """Demonstrate the callback-based streaming method."""
+    print_separator("回调式流式响应 (generate_text_stream_with_callback)")
+
+    # Create an AI_Agent instance
+    agent = AI_Agent(
+        base_url="vllm",
+        model_name="qwen2-7b-instruct",
+        api="EMPTY",
+        timeout=60,
+        stream_chunk_timeout=10
+    )
+
+    # Define the callback function
+    def my_callback(content, is_last=False, is_timeout=False, is_error=False, error=None):
+        """Callback for handling the streamed response."""
+        if content:
+            # Print content in real time, without newlines
+            print(content, end="", flush=True)
+
+        if is_last:
+            print("\n")
+        if is_timeout:
+            print("警告: 响应流超时")
+        if is_error:
+            print(f"错误: {error}")
+
+    print("开始生成内容...")
+    start_time = time.time()
+
+    # Use the callback-based streaming method
+    result = agent.generate_text_stream_with_callback(
+        SYSTEM_PROMPT,
+        USER_PROMPT,
+        my_callback,
+        temperature=0.7,
+        top_p=0.9,
+        presence_penalty=0.0
+    )
+
+    end_time = time.time()
+    print(f"\n生成完成! 耗时: {end_time - start_time:.2f}秒")
+
+    # Close the agent
+    agent.close()
+async def demo_async_stream():
+    """Demonstrate the asynchronous streaming method."""
+    print_separator("异步流式响应 (async_generate_text_stream)")
+
+    # Create an AI_Agent instance
+    agent = AI_Agent(
+        base_url="vllm",
+        model_name="qwen2-7b-instruct",
+        api="EMPTY",
+        timeout=60,
+        stream_chunk_timeout=10
+    )
+
+    print("开始生成内容...")
+    start_time = time.time()
+    full_response = ""
+    # Use the asynchronous streaming method
+    try:
-        # --- Extract AI Agent parameters from config ---
-        ai_api_url = config.get("api_url")
-        ai_model = config.get("model")
-        ai_api_key = config.get("api_key")
-        request_timeout = config.get("request_timeout", 30)
-        max_retries = config.get("max_retries", 3)
-        stream_chunk_timeout = config.get("stream_chunk_timeout", 60) # Get stream chunk timeout
-
-        # Check for required AI params
-        if not all([ai_api_url, ai_model, ai_api_key]):
-            logging.critical("Missing required AI configuration (api_url, model, api_key) in config. Exiting test.")
-            sys.exit(1)
-        # --- End Extract AI Agent params ---
-
-        logging.info("Initializing AI Agent for stream test...")
-        # Initialize AI_Agent using extracted parameters
-        ai_agent = AI_Agent(
-            api_url=ai_api_url,   # Use extracted var
-            model=ai_model,       # Use extracted var
-            api_key=ai_api_key,   # Use extracted var
-            timeout=request_timeout,
-            max_retries=max_retries,
-            stream_chunk_timeout=stream_chunk_timeout # Pass it here
-        )
+        async_stream = agent.async_generate_text_stream(
+            SYSTEM_PROMPT,
+            USER_PROMPT,
+            temperature=0.7,
+            top_p=0.9,
+            presence_penalty=0.0
+        )
 
-        # Example call to work_stream
-        logging.info("Calling ai_agent.work_stream...")
-        # Extract generation parameters from config
-        temperature = config.get("content_temperature", 0.7) # Use a relevant temperature setting
-        top_p = config.get("content_top_p", 0.9)
-        presence_penalty = config.get("content_presence_penalty", 0.0)
+        # Iterate over the stream asynchronously
+        async for content in async_stream:
+            # Accumulate the complete response
+            full_response += content
+            # Print content in real time
+            print(content, end="", flush=True)
 
-        start_time = time.time()
-        stream_generator = ai_agent.work_stream(
-            system_prompt=system_prompt,
-            user_prompt=user_prompt,
-            info_directory=None, # No extra context folder for this test
-            temperature=temperature,
-            top_p=top_p,
-            presence_penalty=presence_penalty
-        )
-
-        # Process the stream
-        logging.info("Processing stream response:")
-        full_response = ""
-        for chunk in stream_generator:
-            print(chunk, end="", flush=True) # Keep print for stream output
-            full_response += chunk
-
-        end_time = time.time()
-        logging.info(f"\n--- Stream Finished ---")
-        logging.info(f"Total time: {end_time - start_time:.2f} seconds")
-        logging.info(f"Total characters received: {len(full_response)}")
-
-    except KeyError as e:
-        logging.error(f"Configuration error: Missing key '{e}'. Please check '{config_path}'.")
     except Exception as e:
-        logging.exception("An error occurred during the stream test:")
-    finally:
-        # Ensure the agent is closed
-        if ai_agent:
-            logging.info("Closing AI Agent...")
-            ai_agent.close()
-            logging.info("AI Agent closed.")
+        print(f"\n生成过程中出错: {e}")
 
+    end_time = time.time()
+    print(f"\n\n生成完成! 耗时: {end_time - start_time:.2f}秒")
+
+    # Close the agent
+    agent.close()
+async def main():
+    """Main entry point."""
+    print("Testing AI_Agent streaming methods...")
+
+    # 1. Test the synchronous streaming response
+    demo_sync_stream()
+
+    # 2. Test the callback-based streaming response
+    demo_callback_stream()
+
+    # 3. Test the asynchronous streaming response
+    await demo_async_stream()
+
+    print("\n所有测试完成!")
 
 if __name__ == "__main__":
-    main()
+    # Run the async main function
+    asyncio.run(main())