重写了流式输出

This commit is contained in:
jinye_huang 2025-04-23 16:18:02 +08:00
parent 1919c2e7a8
commit 15176e2eaf
9 changed files with 884 additions and 231 deletions

View File

@ -11,6 +11,20 @@
- **模块化设计**核心功能配置加载、提示词管理、AI交互、选题、内容生成、海报制作分离方便维护和扩展
- **配置驱动**:通过配置文件集中管理所有运行参数
## 新功能: 流式输出处理
TravelContentCreator 现已支持三种流式输出处理方法,提供了更灵活的 AI 文本生成体验:
- **同步流式响应**: 使用流式 API 但返回完整响应
- **回调式流式响应**: 通过回调函数处理每个文本块
- **异步流式响应**: 使用异步生成器返回文本流
这些功能大大提升了长文本生成的用户体验和系统响应性。
详细文档请参阅:
- [流式处理文档](docs/streaming.md)
- [流式处理演示](examples/test_stream.py)
## 快速开始
### 1. 环境准备

View File

@ -5,6 +5,8 @@ import random
import traceback
import logging
import tiktoken
import asyncio
from asyncio import TimeoutError as AsyncTimeoutError
# Configure basic logging for this module (or rely on root logger config)
# logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
@ -16,6 +18,16 @@ INITIAL_BACKOFF = 1 # Initial backoff time in seconds
MAX_BACKOFF = 16 # Maximum backoff time in seconds
STREAM_CHUNK_TIMEOUT = 10 # Timeout in seconds for receiving a chunk in stream
# 自定义超时异常
class Timeout(Exception):
"""Raised when a stream chunk timeout occurs."""
def __init__(self, message="Stream chunk timeout occurred."):
self.message = message
super().__init__(self.message)
def __str__(self):
return self.message
class AI_Agent():
"""AI代理类负责与AI模型交互生成文本内容"""
@ -198,7 +210,7 @@ class AI_Agent():
time_end = time.time()
time_cost = time_end - time_start
tokens = None
tokens = len(result) * 1.3
logging.info(f"'work' completed in {time_cost:.2f}s. Estimated tokens: {tokens}")
return result, tokens, time_cost
@ -220,9 +232,10 @@ class AI_Agent():
user_prompt: The user prompt for the AI.
temperature: Sampling temperature.
top_p: Nucleus sampling parameter.
presence_penalty: Presence penalty parameter.
Yields:
str: Chunks of the generated text.
Returns:
str: The complete generated text.
Raises:
Exception: If the API call fails after all retries.
@ -236,6 +249,7 @@ class AI_Agent():
retries = 0
backoff_time = INITIAL_BACKOFF
last_exception = None
full_response = ""
while retries < self.max_retries:
try:
@ -245,6 +259,7 @@ class AI_Agent():
messages=messages,
temperature=temperature,
top_p=top_p,
presence_penalty=presence_penalty,
stream=True,
timeout=self.timeout # Overall request timeout
)
@ -263,16 +278,16 @@ class AI_Agent():
if chunk.choices and chunk.choices[0].delta and chunk.choices[0].delta.content:
content = chunk.choices[0].delta.content
full_response += content
# logging.debug(f"Received chunk: {content}") # Potentially very verbose
yield content
elif chunk.choices and chunk.choices[0].finish_reason == 'stop':
logging.info("Stream finished.")
return # Successful completion
return full_response # Return complete response
# Handle other finish reasons if needed, e.g., 'length'
except StopIteration:
logging.info("Stream iterator exhausted.")
return # End of stream normally
return full_response # Return complete response
except Timeout as e:
logging.warning(f"Stream chunk timeout: {e}. Retrying if possible ({retries + 1}/{self.max_retries}).")
last_exception = e
@ -300,7 +315,6 @@ class AI_Agent():
logging.error(f"Stream generation failed after {self.max_retries} retries.")
raise last_exception or Exception("Stream generation failed after max retries.")
except (Timeout, APITimeoutError, APIConnectionError, RateLimitError) as e:
retries += 1
last_exception = e
@ -311,16 +325,246 @@ class AI_Agent():
backoff_time = min(backoff_time * 2, MAX_BACKOFF)
else:
logging.error(f"API call failed after {self.max_retries} retries.")
raise last_exception
return f"[生成内容失败: {str(last_exception)}]" # Return error message as string
except Exception as e:
# Catch unexpected errors during stream setup
logging.error(f"Unexpected error setting up stream: {traceback.format_exc()}")
raise e # Re-raise unexpected errors immediately
return f"[生成内容失败: {str(e)}]" # Return error message as string
# Should not be reached if logic is correct, but as a safeguard:
logging.error("Exited stream generation loop unexpectedly.")
raise last_exception or Exception("Stream generation failed.")
return f"[生成内容失败: 超过最大重试次数]" # Return error message as string
def generate_text_stream_with_callback(self, system_prompt, user_prompt,
callback_fn, temperature=0.7, top_p=0.9,
presence_penalty=0.0):
"""生成文本流并通过回调函数处理每个块
Args:
system_prompt: 系统提示词
user_prompt: 用户提示词
callback_fn: 处理每个文本块的回调函数接收(content, is_last, is_timeout, is_error, error)参数
temperature: 温度参数
top_p: 核采样参数
presence_penalty: 存在惩罚参数
Returns:
str: 完整的响应文本
"""
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
]
logging.info(f"Generating text stream with callback using model: {self.model_name}")
retries = 0
backoff_time = INITIAL_BACKOFF
last_exception = None
full_response = ""
while retries < self.max_retries:
try:
logging.debug(f"Attempt {retries + 1}/{self.max_retries} to generate text stream with callback.")
stream = self.client.chat.completions.create(
model=self.model_name,
messages=messages,
temperature=temperature,
top_p=top_p,
presence_penalty=presence_penalty,
stream=True,
timeout=self.timeout
)
chunk_iterator = iter(stream)
last_chunk_time = time.time()
try:
while True:
try:
# 检查上次接收块以来的超时
if time.time() - last_chunk_time > self.stream_chunk_timeout:
callback_fn("", is_last=True, is_timeout=True, is_error=False, error=None)
raise Timeout(f"No chunk received for {self.stream_chunk_timeout} seconds.")
chunk = next(chunk_iterator)
last_chunk_time = time.time() # 成功接收块后重置计时器
content = ""
is_last = False
if chunk.choices and chunk.choices[0].delta and chunk.choices[0].delta.content:
content = chunk.choices[0].delta.content
full_response += content
if chunk.choices and chunk.choices[0].finish_reason == 'stop':
is_last = True
# 调用回调函数处理块
callback_fn(content, is_last=is_last, is_timeout=False, is_error=False, error=None)
if is_last:
logging.info("Stream with callback finished normally.")
return full_response # 成功完成,返回完整响应
except StopIteration:
# 正常结束流
callback_fn("", is_last=True, is_timeout=False, is_error=False, error=None)
logging.info("Stream iterator exhausted normally.")
return full_response
except Timeout as e:
logging.warning(f"Stream chunk timeout: {e}")
last_exception = e
# 超时信息已通过回调传递,此处不需要再次调用回调
except (APITimeoutError, APIConnectionError, RateLimitError) as e:
logging.warning(f"API error during streaming with callback: {type(e).__name__} - {e}")
callback_fn("", is_last=True, is_timeout=False, is_error=True, error=str(e))
last_exception = e
except Exception as e:
logging.error(f"Unexpected error during streaming with callback: {traceback.format_exc()}")
callback_fn("", is_last=True, is_timeout=False, is_error=True, error=str(e))
last_exception = e
# 执行重试逻辑
retries += 1
if retries < self.max_retries:
retry_msg = f"将在 {backoff_time:.2f} 秒后重试..."
logging.info(f"Retrying stream in {backoff_time} seconds...")
callback_fn(f"\n[{retry_msg}]\n", is_last=False, is_timeout=False, is_error=False, error=None)
time.sleep(backoff_time + random.uniform(0, 1)) # 添加随机抖动
backoff_time = min(backoff_time * 2, MAX_BACKOFF)
else:
error_msg = f"API call failed after {self.max_retries} retries: {str(last_exception)}"
logging.error(error_msg)
callback_fn(f"\n[{error_msg}]\n", is_last=True, is_timeout=False, is_error=True, error=str(last_exception))
return full_response # 返回已收集的部分响应
except Exception as e:
logging.error(f"Error setting up stream with callback: {traceback.format_exc()}")
callback_fn("", is_last=True, is_timeout=False, is_error=True, error=str(e))
return f"[生成内容失败: {str(e)}]"
# 作为安全措施
error_msg = "超出最大重试次数"
logging.error(error_msg)
callback_fn(f"\n[{error_msg}]\n", is_last=True, is_timeout=False, is_error=True, error=error_msg)
return full_response
async def async_generate_text_stream(self, system_prompt, user_prompt, temperature=0.7, top_p=0.9, presence_penalty=0.0):
"""异步生成文本流
Args:
system_prompt: 系统提示词
user_prompt: 用户提示词
temperature: 温度参数
top_p: 核采样参数
presence_penalty: 存在惩罚参数
Yields:
str: 生成的文本块
Raises:
Exception: 如果API调用在所有重试后失败
"""
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
]
logging.info(f"Asynchronously generating text stream with model: {self.model_name}")
retries = 0
backoff_time = INITIAL_BACKOFF
last_exception = None
while retries < self.max_retries:
try:
logging.debug(f"Async attempt {retries + 1}/{self.max_retries} to generate text stream.")
# 创建新的客户端用于异步操作
async_client = OpenAI(
api_key=self.api,
base_url=self.base_url,
timeout=self.timeout
)
stream = await async_client.chat.completions.create(
model=self.model_name,
messages=messages,
temperature=temperature,
top_p=top_p,
presence_penalty=presence_penalty,
stream=True,
timeout=self.timeout
)
last_chunk_time = time.time()
try:
async for chunk in stream:
# 检查上次接收块以来的超时
current_time = time.time()
if current_time - last_chunk_time > self.stream_chunk_timeout:
raise Timeout(f"No chunk received for {self.stream_chunk_timeout} seconds.")
last_chunk_time = current_time # 成功接收块后重置计时器
if chunk.choices and chunk.choices[0].delta and chunk.choices[0].delta.content:
content = chunk.choices[0].delta.content
yield content
logging.info("Async stream finished normally.")
return # 成功完成
except AsyncTimeoutError as e:
logging.warning(f"Async stream timeout: {e}")
last_exception = e
except Timeout as e:
logging.warning(f"Async stream chunk timeout: {e}")
last_exception = e
except Exception as e:
logging.error(f"Error during async streaming: {traceback.format_exc()}")
last_exception = e
# 执行重试逻辑
retries += 1
if retries < self.max_retries:
logging.info(f"Retrying async stream in {backoff_time} seconds...")
await asyncio.sleep(backoff_time + random.uniform(0, 1)) # 使用异步睡眠
backoff_time = min(backoff_time * 2, MAX_BACKOFF)
else:
logging.error(f"Async API call failed after {self.max_retries} retries.")
raise last_exception or Exception("Async stream generation failed after max retries.")
except (Timeout, AsyncTimeoutError, APITimeoutError, APIConnectionError, RateLimitError) as e:
retries += 1
last_exception = e
logging.warning(f"Async attempt {retries}/{self.max_retries} failed: {type(e).__name__} - {e}")
if retries < self.max_retries:
logging.info(f"Retrying async stream in {backoff_time} seconds...")
await asyncio.sleep(backoff_time + random.uniform(0, 1)) # 使用异步睡眠
backoff_time = min(backoff_time * 2, MAX_BACKOFF)
else:
logging.error(f"Async API call failed after {self.max_retries} retries.")
raise last_exception
except Exception as e:
logging.error(f"Unexpected error setting up async stream: {traceback.format_exc()}")
raise e # 立即重新引发意外错误
# 作为安全措施
logging.error("Exited async stream generation loop unexpectedly.")
raise last_exception or Exception("Async stream generation failed.")
async def async_work_stream(self, system_prompt, user_prompt, file_folder, temperature, top_p, presence_penalty):
"""异步完整工作流程:读取文件夹(如果提供),然后生成文本流"""
logging.info(f"Starting async 'work_stream' process. File folder: {file_folder}")
if file_folder:
logging.info(f"Reading context from folder: {file_folder}")
context = self.read_folder(file_folder)
if context:
user_prompt = f"{user_prompt.strip()}\n\n--- 参考资料 ---\n{context.strip()}"
else:
logging.warning(f"Folder {file_folder} provided but no content read.")
logging.info("Calling async_generate_text_stream...")
return self.async_generate_text_stream(system_prompt, user_prompt, temperature, top_p, presence_penalty)
def work_stream(self, system_prompt, user_prompt, file_folder, temperature, top_p, presence_penalty):
"""完整的工作流程(流式):读取文件夹(如果提供),然后生成文本流"""
@ -333,6 +577,6 @@ class AI_Agent():
else:
logging.warning(f"Folder {file_folder} provided but no content read.")
logging.info("Calling generate_text_stream...")
return self.generate_text_stream(system_prompt, user_prompt, temperature, top_p, presence_penalty)
# --- End Added Streaming Methods ---
full_response = self.generate_text_stream(system_prompt, user_prompt, temperature, top_p, presence_penalty)
return full_response
# --- End Streaming Methods ---

View File

@ -45,6 +45,11 @@ class ContentGenerator:
self.top_p = 0.8
self.presence_penalty = 1.2
# 设置日志
logging.basicConfig(level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
self.logger = logging.getLogger(__name__)
def load_infomation(self, info_directory_path):
"""
@ -193,142 +198,149 @@ class ContentGenerator:
返回:
生成的海报内容
"""
if system_prompt is None:
system_prompt = """
你是一名资深海报设计师有丰富的爆款海报设计经验你现在要为旅游景点做宣传在小红书上发布大量宣传海报你的主要工作目标有2个
1你要根据我给你的图片描述和笔记推文内容设计图文匹配的海报
2为海报设计文案文案的<第一个小标题><第二个小标题>之间你需要检查是否逻辑关系合理你将通过先去生成<第二个小标题>关于景区亮点的部分再去综合判断<第一个小标题>应该如何搭配组合更符合两个小标题的逻辑再生成<第一个小标题>
full_response = ""
timeout = 60 # 请求超时时间(秒)
其中生成三类标题文案的通用性要求如下
1生成的<大标题>字数必须小于8个字符
2生成的<第一个小标题>字数和<第二个小标题>字数两者都必须小8个字符
3标题和文案都应符合中国社会主义核心价值观
if not system_prompt:
# 使用默认系统提示词
system_prompt = f"""
你是一个专业的文案处理专家擅长从文章中提取关键信息并生成吸引人的标题和简短描述
现在我需要你根据提供的文章内容生成{poster_num}个海报的文案配置
接下来先开始生成<大标题>部分由于海报是用来宣传旅游景点生成的海报<大标题>必须使用以下8种格式之一
地名景点名例如福建厦门鼓浪屿/厦门鼓浪屿
地名+景点名+plog
拿捏+地名+景点名
地名+景点名+攻略
速通+地名+景点名
推荐+地名+景点名
勇闯+地名+景点名
收藏+地名+景点名
你需要随机挑选一种格式生成对应景点的文案但是格式除了上面8种不可以有其他任何格式同时尽量保证每一种格式出现的频率均衡
接下来先去生成<第二个小标题><第二个小标题>文案的创作必须遵循以下原则
请根据笔记内容和图片识别用极简的文字概括这篇笔记和图片中景点的特色亮点其中你可以参考以下词汇进行创作这段文案字数控制6-8字符以内
每个配置包含
1. main_title主标题简短有力突出景点特点
2. texts两句简短文本每句不超过15字描述景点特色或游玩体验
特色亮点可能会出现的词汇不完全举例非遗古建绝佳山水祈福圣地研学圣地解压天堂中国小瑞士秘境竹筏游等等类型词汇
接下来再去生成<第一个小标题><第一个小标题>文案的创作必须遵循以下原则
这部分文案创作公式有5种分别为
<受众人群画像>+<痛点词>
<受众人群画像>
<痛点词>
<受众人群画像>+ | +<痛点词>
<痛点词>+ | +<受众人群画像>
请你根据实际笔记内容结合这部分文案创作公式需要结合<受众人群画像><痛点词>必须根据<第二个小标题>的景点特征和所对应的完整笔记推文内容主旨特征挑选对应<受众人群画像><痛点词>
我给你提供受众人群画像库和痛点词库如下
1受众人群画像库情侣党亲子游合家游银发族亲子研学学生党打工人周边游本地人穷游党性价比户外人美食党出片
2痛点词库3天2夜必去看了都哭了不能错过一定要来问爆了超全攻略必打卡强推懒人攻略必游榜小众打卡狂喜等等
你需要为每个请求至少生成{poster_num}个海报设计请使用JSON格式输出结果结构如下
```json
以JSON数组格式返回配置示例
[
{
"index": 1,
"main_title": "主标题内容",
"texts": ["第一个小标题", "第二个小标题"]
},
{
"index": 2,
"main_title": "主标题内容",
"texts": ["第一个小标题", "第二个小标题"]
}
// ... 更多海报
{{
"main_title": "泰宁古城",
"texts": ["千年古韵","匠心独运"]
}},
...
]
```
确保生成的数量与用户要求的数量一致只生成上述JSON格式内容不要有其他任何额外内容
仅返回JSON数据不需要任何额外解释确保生成的标题和文本能够准确反映文章提到的景点特色
"""
user_content = f"""
海报数量{poster_num};
景区介绍{self.add_description};
推文内容{tweet_content};
"""
if self.add_description:
# 创建用户内容包括info信息和tweet_content
user_content = f"""
以下是需要你处理的信息
# 最终响应内容
full_response = ""
关于景点的描述:
{self.add_description}
推文内容:
{tweet_content}
请根据这些信息生成{poster_num}个海报文案配置以JSON数组格式返回
确保主标题(main_title)简短有力每个text不超过15字并能准确反映景点特色
"""
else:
# 仅使用tweet_content
user_content = f"""
以下是需要你处理的推文内容:
{tweet_content}
请根据这些信息生成{poster_num}个海报文案配置以JSON数组格式返回
确保主标题(main_title)简短有力每个text不超过15字并能准确反映景点特色
"""
self.logger.info(f"正在生成{poster_num}个海报文案配置")
# 创建临时客户端
temp_client = self._create_temp_client()
# 如果创建客户端失败,直接使用备用方案
if temp_client is None:
print("创建OpenAI客户端失败使用备用方案生成内容")
return 404
else:
# 添加重试机制
if temp_client:
# 重试逻辑
for retry in range(max_retries):
try:
print(f"尝试连接API (尝试 {retry+1}/{max_retries})...")
self.logger.info(f"尝试生成内容 (尝试 {retry+1}/{max_retries})")
# 计算退避时间指数退避策略0, 2, 4, 8, 16...秒
if retry > 0:
backoff_time = min(2 ** (retry - 1) * 2, 30) # 最大等待30秒
print(f"等待 {backoff_time} 秒后重试...")
time.sleep(backoff_time)
# 定义流式响应处理回调函数
def handle_stream_chunk(chunk, is_last=False, is_timeout=False, is_error=False, error=None):
nonlocal full_response
# 设置超时时间随重试次数递增
timeout = 30 + (retry * 30) # 30, 60, 90, ...秒
if chunk:
full_response += chunk
# 实时输出到控制台
print(chunk, end="", flush=True)
chat_response = temp_client.chat.completions.create(
model=self.model_name,
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_content}
],
stream=True,
temperature=self.temperature,
top_p=self.top_p,
presence_penalty=self.presence_penalty,
timeout=timeout # 设置请求超时时间
if is_last:
print("\n") # 输出完成后换行
if is_timeout:
print("警告: 响应流超时")
if is_error:
print(f"错误: {error}")
# 使用AI_Agent的新回调方式
from core.ai_agent import AI_Agent
ai_agent = AI_Agent(
self.api_base_url,
self.model_name,
self.api_key,
timeout=timeout,
max_retries=max_retries,
stream_chunk_timeout=30 # 流式块超时时间
)
# 获取响应内容
for chunk in chat_response:
if chunk.choices and len(chunk.choices) > 0 and chunk.choices[0].delta.content is not None:
content = chunk.choices[0].delta.content
full_response += content
print(content, end="", flush=True)
if chunk.choices and len(chunk.choices) > 0 and chunk.choices[0].finish_reason == "stop":
break
# 使用回调方式处理流式响应
try:
full_response = ai_agent.generate_text_stream_with_callback(
system_prompt,
user_content,
handle_stream_chunk,
temperature=self.temperature,
top_p=self.top_p,
presence_penalty=self.presence_penalty
)
print("\n") # 输出完成后换行
# 如果成功生成内容,跳出重试循环
ai_agent.close()
break
# 成功获取响应,跳出重试循环
break
except Exception as e:
error_msg = str(e)
self.logger.error(f"AI生成错误: {error_msg}")
ai_agent.close()
# 继续重试逻辑
if retry + 1 >= max_retries:
self.logger.warning("已达到最大重试次数,使用备用方案...")
# 生成备用内容
full_response = self._generate_fallback_content(poster_num)
else:
self.logger.info(f"将在稍后重试,还剩 {max_retries - retry - 1} 次重试机会")
except Exception as e:
error_msg = str(e)
print(f"API连接错误 (尝试 {retry+1}/{max_retries}): {error_msg}")
self.logger.error(f"API连接错误 (尝试 {retry+1}/{max_retries}): {error_msg}")
# 如果已经达到最大重试次数
if retry + 1 >= max_retries:
print("已达到最大重试次数,使用备用方案...")
self.logger.warning("已达到最大重试次数,使用备用方案...")
# 生成备用内容(简单模板)
full_response = self._generate_fallback_content(poster_num)
else:
print(f"将在稍后重试,还剩 {max_retries - retry - 1} 次重试机会")
self.logger.info(f"将在稍后重试,还剩 {max_retries - retry - 1} 次重试机会")
# 关闭临时客户端
self._close_client(temp_client)
# 生成时间戳
return full_response
def _generate_fallback_content(self, poster_num):
"""生成备用内容当API调用失败时使用"""
self.logger.info("生成备用内容")
default_configs = []
for i in range(poster_num):
default_configs.append({
"main_title": f"景点风光 {i+1}",
"texts": ["自然美景", "人文体验"]
})
return json.dumps(default_configs, ensure_ascii=False)
def save_result(self, full_response):
"""
保存生成结果到文件

209
docs/streaming.md Normal file
View File

@ -0,0 +1,209 @@
# 流式输出处理功能
TravelContentCreator 现在支持三种不同的流式输出处理方法,让您能够更灵活地处理 AI 模型生成的文本内容。这些方法都在 `AI_Agent` 类中实现,可以根据不同的使用场景进行选择。
## 为什么需要流式处理?
流式处理Streaming相比于传统的一次性返回完整响应的方式有以下优势
1. **实时性**:内容生成的同时即可开始处理,无需等待完整响应
2. **用户体验更好**:可以实现"打字机效果",让用户看到文本逐步生成的过程
3. **更早检测错误**:可以在响应生成过程中及早发现问题
4. **长文本处理更高效**:特别适合生成较长的内容,避免长时间等待
## 流式处理方法
`AI_Agent` 类提供了三种不同模式的流式处理方法:
### 1. 同步流式响应 (generate_text_stream)
这种方法虽然使用了流式 API 连接,但会将所有的输出整合后一次性返回,适合简单的 API 调用。
```python
def generate_text_stream(self, system_prompt, user_prompt, temperature, top_p, presence_penalty):
"""
生成文本内容使用流式API但返回完整响应
Args:
system_prompt: 系统提示词
user_prompt: 用户提示词
temperature: 温度参数
top_p: 核采样参数
presence_penalty: 存在惩罚参数
Returns:
str: 完整的生成文本
"""
```
使用示例:
```python
agent = AI_Agent(base_url, model_name, api_key, timeout=30, stream_chunk_timeout=10)
result = agent.generate_text_stream(system_prompt, user_prompt, 0.7, 0.9, 0.0)
print(result) # 输出完整的生成结果
```
### 2. 回调式流式响应 (generate_text_stream_with_callback)
这种方法使用回调函数来处理流中的每个文本块,非常适合实时显示、分析或保存过程数据,更加灵活。
```python
def generate_text_stream_with_callback(self, system_prompt, user_prompt,
callback_fn, temperature=0.7, top_p=0.9,
presence_penalty=0.0):
"""
生成文本流并通过回调函数处理每个块
Args:
system_prompt: 系统提示词
user_prompt: 用户提示词
callback_fn: 处理每个文本块的回调函数,接收(content, is_last, is_timeout, is_error, error)参数
temperature: 温度参数
top_p: 核采样参数
presence_penalty: 存在惩罚参数
Returns:
str: 完整的响应文本
"""
```
回调函数应符合以下格式:
```python
def my_callback(content, is_last=False, is_timeout=False, is_error=False, error=None):
"""
处理流式响应的回调函数
Args:
content: 文本块内容
is_last: 是否为最后一个块
is_timeout: 是否发生超时
is_error: 是否发生错误
error: 错误信息
"""
if content:
print(content, end="", flush=True) # 实时打印
# 处理特殊情况
if is_last:
print("\n完成生成")
if is_timeout:
print("警告: 响应流超时")
if is_error:
print(f"错误: {error}")
```
使用示例:
```python
agent = AI_Agent(base_url, model_name, api_key, timeout=30, stream_chunk_timeout=10)
result = agent.generate_text_stream_with_callback(
system_prompt,
user_prompt,
my_callback, # 传入回调函数
temperature=0.7,
top_p=0.9,
presence_penalty=0.0
)
```
### 3. 异步流式响应 (async_generate_text_stream)
这种方法基于 `asyncio`,返回一个异步生成器,非常适合与其他异步操作集成,例如在异步网络应用中使用。
```python
async def async_generate_text_stream(self, system_prompt, user_prompt, temperature=0.7, top_p=0.9, presence_penalty=0.0):
"""
异步生成文本流
Args:
system_prompt: 系统提示词
user_prompt: 用户提示词
temperature: 温度参数
top_p: 核采样参数
presence_penalty: 存在惩罚参数
Yields:
str: 生成的文本块
Raises:
Exception: 如果API调用在所有重试后失败
"""
```
使用示例:
```python
async def demo_async_stream():
agent = AI_Agent(base_url, model_name, api_key, timeout=30, stream_chunk_timeout=10)
full_response = ""
try:
# 使用异步生成器
async for chunk in agent.async_generate_text_stream(
system_prompt,
user_prompt,
temperature=0.7,
top_p=0.9,
presence_penalty=0.0
):
print(chunk, end="", flush=True) # 实时显示
full_response += chunk
# 这里可以同时执行其他异步操作
# await some_other_async_operation()
except Exception as e:
print(f"错误: {e}")
finally:
agent.close()
# 在异步环境中运行
asyncio.run(demo_async_stream())
```
## 超时处理
所有流式处理方法都支持两种超时设置:
1. **全局请求超时**控制整个API请求的最大持续时间
2. **流块超时**:控制接收连续两个数据块之间的最大等待时间
在创建 `AI_Agent` 实例时设置:
```python
agent = AI_Agent(
base_url="your_api_base_url",
model_name="your_model_name",
api="your_api_key",
timeout=60, # 整体请求超时(秒)
max_retries=3, # 最大重试次数
stream_chunk_timeout=10 # 流块超时(秒)
)
```
## 完整示例
我们提供了一个完整的示例脚本,演示所有三种流式处理方法的使用:
```
examples/test_stream.py
```
运行方式:
```bash
cd TravelContentCreator
python examples/test_stream.py
```
这将依次演示三种流式处理方法,并展示它们的输出和性能差异。
## 在WebUI中的应用
TravelContentCreator的WebUI已经集成了基于回调的流式处理实现了生成内容的实时显示大大提升了用户体验特别是在生成长篇内容时。
## 在自定义项目中使用
如果您想在自己的项目中使用这些流式处理功能,只需导入 `AI_Agent` 类并按照上述示例使用相应的方法即可。所有流式处理方法都内置了完善的错误处理和重试机制,提高了生产环境中的稳定性。

View File

@ -156,6 +156,31 @@ python examples/test_stream.py
- 演示实时输出文本块的处理方式
- 说明如何收集完整响应和估计token数
## 流式处理演示 (test_stream.py)
`test_stream.py` 是一个演示脚本展示了AI_Agent类中新增的三种不同流式输出处理方法
1. **同步流式响应** (generate_text_stream) - 使用流式API但返回完整响应
2. **回调式流式响应** (generate_text_stream_with_callback) - 通过回调函数处理每个文本块
3. **异步流式响应** (async_generate_text_stream) - 使用异步生成器返回文本流
### 运行演示
```bash
cd TravelContentCreator
python examples/test_stream.py
```
演示脚本会依次执行这三种方法,并展示它们的输出和性能差异。
### 更多信息
关于流式处理功能的详细文档,请参阅:
```
docs/streaming.md
```
## 其他示例
`generate_poster.py` 是一个简化的海报生成示例,主要用于快速测试特定图片的海报效果。

View File

@ -0,0 +1,103 @@
# 流式处理功能说明
本文档介绍了Travel Content Creator中新增的三种流式输出处理方式以及如何使用它们来优化内容生成体验。
## 新增的流式处理方式
我们在`AI_Agent`类中实现了三种不同的流式处理方式:
1. **同步流式响应** (`generate_text_stream`)
- 已修改为返回完整响应,不再是生成器
- 内部仍使用流式请求以获得更好的超时控制
- 适用于简单的API调用场景
2. **基于回调的流式响应** (`generate_text_stream_with_callback`)
- 通过回调函数处理每个文本块
- 可在回调中实现实时显示、分析或保存
- 最灵活的选项,适合需要自定义处理流程的场景
3. **异步流式响应** (`async_generate_text_stream`)
- 基于`asyncio`的异步生成器
- 适用于需要与其他异步操作集成的场景
- 可在保持响应性的同时处理长时间运行的请求
## 演示脚本
我们提供了一个示例脚本`test_stream.py`,演示了如何使用这三种流式处理方式:
```bash
# 从项目根目录运行
python examples/test_stream.py
```
## 回调函数示例
以下是使用回调函数处理流式响应的例子:
```python
def handle_chunk(chunk, is_last=False, is_timeout=False, is_error=False, error=None):
# 处理文本块
if chunk:
print(chunk, end="", flush=True) # 实时显示
# 处理结束状态
if is_last:
if is_timeout:
print("\n流式响应超时")
if is_error:
print(f"\n发生错误: {error}")
# 使用回调函数进行流式处理
response = agent.generate_text_stream_with_callback(
system_prompt,
user_prompt,
handle_chunk, # 传入回调函数
temperature=0.7
)
```
## 异步流式处理示例
以下是使用异步方式处理流式响应的例子:
```python
async def process_stream():
async for chunk in agent.async_generate_text_stream(
system_prompt,
user_prompt,
temperature=0.7
):
print(chunk, end="", flush=True) # 实时显示
# 可以同时执行其他异步操作
await other_async_task()
# 在异步环境中运行
asyncio.run(process_stream())
```
## 超时处理
所有流式处理方法都支持超时控制:
1. **全局请求超时**:控制整个请求的最长等待时间
2. **流式块超时**:控制两个连续文本块之间的最长等待时间
这些参数可以在创建`AI_Agent`实例时设置:
```python
agent = AI_Agent(
api_url="http://localhost:8000/v1/",
model="qwen",
api_key="EMPTY",
timeout=30, # 全局请求超时(秒)
stream_chunk_timeout=10 # 流式块超时(秒)
)
```
## 注意事项
1. 新的流式处理方法内置了重试机制和错误处理
2. 回调方式提供了最佳的灵活性和控制
3. 异步方式最适合需要保持UI响应性的应用
4. 所有方法都会在完成时自动关闭流式请求

View File

@ -1,129 +1,175 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
测试AI_Agent的流式处理方法
此脚本演示TravelContentCreator中AI_Agent类的三种流式输出处理方法:
1. 同步流式响应 (generate_text_stream)
2. 回调式流式响应 (generate_text_stream_with_callback)
3. 异步流式响应 (async_generate_text_stream)
"""
import os
import sys
import json
import asyncio
import time
import logging
from pathlib import Path
# Determine the project root directory (assuming examples/ is one level down)
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if PROJECT_ROOT not in sys.path:
sys.path.append(PROJECT_ROOT)
# 添加项目根目录到Python路径
project_root = str(Path(__file__).parent.parent)
if project_root not in sys.path:
sys.path.insert(0, project_root)
# Now import from core
try:
from core.ai_agent import AI_Agent
except ImportError as e:
logging.critical(f"Failed to import AI_Agent. Ensure '{PROJECT_ROOT}' is in sys.path and core/ai_agent.py exists. Error: {e}")
sys.exit(1)
from core.ai_agent import AI_Agent
def load_config(config_path):
"""Loads configuration from a JSON file."""
try:
with open(config_path, 'r', encoding='utf-8') as f:
config = json.load(f)
logging.info(f"Config loaded successfully from {config_path}")
return config
except FileNotFoundError:
logging.error(f"Error: Configuration file not found at {config_path}")
return None
except json.JSONDecodeError:
logging.error(f"Error: Could not decode JSON from {config_path}")
return None
except Exception as e:
logging.exception(f"An unexpected error occurred loading config {config_path}:")
return None
# 示例提示词
SYSTEM_PROMPT = """你是一个专业的旅游内容创作助手,请根据用户的提示生成相关内容。"""
USER_PROMPT = """请为我生成一篇关于福建泰宁古城的旅游攻略包括著名景点、美食推荐和最佳游玩季节。字数控制在300字以内。"""
def main():
# --- Basic Logging Setup ---
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
def print_separator(title):
"""打印分隔线和标题"""
print("\n" + "="*50)
print(f" {title} ".center(50, "="))
print("="*50 + "\n")
def demo_sync_stream():
"""演示同步流式响应方法"""
print_separator("同步流式响应 (generate_text_stream)")
# 创建AI_Agent实例
agent = AI_Agent(
base_url="vllm", # 使用本地vLLM服务
model_name="qwen2-7b-instruct", # 或其他您配置的模型名称
api="EMPTY", # vLLM不需要API key
timeout=60, # 整体请求超时时间(秒)
stream_chunk_timeout=10 # 流式块超时时间(秒)
)
# --- End Logging Setup ---
logging.info("Starting AI Agent Stream Test...")
print("开始生成内容...")
start_time = time.time()
# Load configuration (adjust path relative to this script)
config_path = os.path.join(PROJECT_ROOT, "poster_gen_config.json")
config = load_config(config_path)
if config is None:
logging.critical("Failed to load configuration. Exiting test.")
sys.exit(1)
# 使用同步流式方法
result = agent.generate_text_stream(
SYSTEM_PROMPT,
USER_PROMPT,
temperature=0.7,
top_p=0.9,
presence_penalty=0.0
)
# Example Prompts
system_prompt = "你是一个乐于助人的AI助手擅长写短篇故事。"
user_prompt = "请写一个关于旅行机器人的短篇故事,它在一个充满异国情调的星球上发现了新的生命形式。"
end_time = time.time()
ai_agent = None
print(f"\n\n完整生成内容:\n{result}")
print(f"\n生成完成! 耗时: {end_time - start_time:.2f}")
# 关闭agent
agent.close()
def demo_callback_stream():
"""演示回调式流式响应方法"""
print_separator("回调式流式响应 (generate_text_stream_with_callback)")
# 创建AI_Agent实例
agent = AI_Agent(
base_url="vllm",
model_name="qwen2-7b-instruct",
api="EMPTY",
timeout=60,
stream_chunk_timeout=10
)
# 定义回调函数
def my_callback(content, is_last=False, is_timeout=False, is_error=False, error=None):
"""处理流式响应的回调函数"""
if content:
# 实时打印内容,不换行
print(content, end="", flush=True)
if is_last:
print("\n")
if is_timeout:
print("警告: 响应流超时")
if is_error:
print(f"错误: {error}")
print("开始生成内容...")
start_time = time.time()
# 使用回调式流式方法
result = agent.generate_text_stream_with_callback(
SYSTEM_PROMPT,
USER_PROMPT,
my_callback,
temperature=0.7,
top_p=0.9,
presence_penalty=0.0
)
end_time = time.time()
print(f"\n生成完成! 耗时: {end_time - start_time:.2f}")
# 关闭agent
agent.close()
async def demo_async_stream():
"""演示异步流式响应方法"""
print_separator("异步流式响应 (async_generate_text_stream)")
# 创建AI_Agent实例
agent = AI_Agent(
base_url="vllm",
model_name="qwen2-7b-instruct",
api="EMPTY",
timeout=60,
stream_chunk_timeout=10
)
print("开始生成内容...")
start_time = time.time()
full_response = ""
# 使用异步流式方法
try:
# --- Extract AI Agent parameters from config ---
ai_api_url = config.get("api_url")
ai_model = config.get("model")
ai_api_key = config.get("api_key")
request_timeout = config.get("request_timeout", 30)
max_retries = config.get("max_retries", 3)
stream_chunk_timeout = config.get("stream_chunk_timeout", 60) # Get stream chunk timeout
# Check for required AI params
if not all([ai_api_url, ai_model, ai_api_key]):
logging.critical("Missing required AI configuration (api_url, model, api_key) in config. Exiting test.")
sys.exit(1)
# --- End Extract AI Agent params ---
logging.info("Initializing AI Agent for stream test...")
# Initialize AI_Agent using extracted parameters
ai_agent = AI_Agent(
api_url=ai_api_url, # Use extracted var
model=ai_model, # Use extracted var
api_key=ai_api_key, # Use extracted var
timeout=request_timeout,
max_retries=max_retries,
stream_chunk_timeout=stream_chunk_timeout # Pass it here
async_stream = agent.async_generate_text_stream(
SYSTEM_PROMPT,
USER_PROMPT,
temperature=0.7,
top_p=0.9,
presence_penalty=0.0
)
# Example call to work_stream
logging.info("Calling ai_agent.work_stream...")
# Extract generation parameters from config
temperature = config.get("content_temperature", 0.7) # Use a relevant temperature setting
top_p = config.get("content_top_p", 0.9)
presence_penalty = config.get("content_presence_penalty", 0.0)
# 异步迭代流
async for content in async_stream:
# 累积完整响应
full_response += content
# 实时打印内容
print(content, end="", flush=True)
start_time = time.time()
stream_generator = ai_agent.work_stream(
system_prompt=system_prompt,
user_prompt=user_prompt,
info_directory=None, # No extra context folder for this test
temperature=temperature,
top_p=top_p,
presence_penalty=presence_penalty
)
# Process the stream
logging.info("Processing stream response:")
full_response = ""
for chunk in stream_generator:
print(chunk, end="", flush=True) # Keep print for stream output
full_response += chunk
end_time = time.time()
logging.info(f"\n--- Stream Finished ---")
logging.info(f"Total time: {end_time - start_time:.2f} seconds")
logging.info(f"Total characters received: {len(full_response)}")
except KeyError as e:
logging.error(f"Configuration error: Missing key '{e}'. Please check '{config_path}'.")
except Exception as e:
logging.exception("An error occurred during the stream test:")
finally:
# Ensure the agent is closed
if ai_agent:
logging.info("Closing AI Agent...")
ai_agent.close()
logging.info("AI Agent closed.")
print(f"\n生成过程中出错: {e}")
end_time = time.time()
print(f"\n\n生成完成! 耗时: {end_time - start_time:.2f}")
# 关闭agent
agent.close()
async def main():
"""主函数"""
print("Testing AI_Agent streaming methods...")
# 1. 测试同步流式响应
demo_sync_stream()
# 2. 测试回调式流式响应
demo_callback_stream()
# 3. 测试异步流式响应
await demo_async_stream()
print("\n所有测试完成!")
if __name__ == "__main__":
main()
# 运行异步主函数
asyncio.run(main())