Modified the async program
This commit is contained in:
parent
202ca01316
commit
ebf715dfcb
568
core/ai_agent.py
@@ -223,20 +223,169 @@ class AI_Agent():
            logging.error(f"Error during AI Agent close: {e}")

    # --- Streaming Methods ---

    def generate_text_stream(self, system_prompt, user_prompt, temperature, top_p, presence_penalty):
    def generate_text_stream(self, system_prompt, user_prompt, temperature=0.7, top_p=0.9, presence_penalty=0.0):
        """
        Generates text based on prompts using a streaming connection. Handles retries with exponential backoff.
        Generates a text stream, but returns the complete response rather than a generator.
        This method suits simple API calls that do not need to process the response in real time.

        Args:
            system_prompt: The system prompt.
            user_prompt: The user prompt.
            temperature: Sampling temperature.
            top_p: Nucleus sampling parameter.
            presence_penalty: Presence penalty parameter.

        Returns:
            str: The complete generated text response.

        Raises:
            Exception: If the API call still fails after all retries.
        """
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]
        logging.info(f"Generating text stream using model: {self.model_name}")

        retries = 0
        backoff_time = INITIAL_BACKOFF
        last_exception = None
        global_timeout = self.timeout * 2  # Global timeout: twice the API timeout
        partial_response = ""  # Stores any partial response

        while retries < self.max_retries:
            try:
                logging.debug(f"Attempt {retries + 1}/{self.max_retries} to generate text stream.")
                client = OpenAI(
                    api_key=self.api,
                    base_url=self.base_url,
                    timeout=self.timeout,
                    max_retries=0  # Disable built-in retries; we use our own retry logic
                )

                stream = client.chat.completions.create(
                    model=self.model_name,
                    messages=messages,
                    temperature=temperature,
                    top_p=top_p,
                    presence_penalty=presence_penalty,
                    stream=True,
                )

                response_text = ""
                stream_start_time = time.time()
                last_chunk_time = time.time()

                try:
                    for chunk in stream:
                        # Check the global timeout
                        current_time = time.time()
                        if current_time - stream_start_time > global_timeout:
                            logging.warning(f"Global timeout reached after {global_timeout} seconds.")
                            if response_text:
                                response_text += "\n\n[注意: 由于全局超时,内容可能不完整]"
                                partial_response = response_text
                                return partial_response
                            else:
                                logging.error("Global timeout with no content received.")
                                raise Timeout(f"Global timeout after {global_timeout} seconds with no content.")

                        # Check the per-chunk stream timeout
                        if current_time - last_chunk_time > self.stream_chunk_timeout:
                            logging.warning(f"Stream chunk timeout: No chunk received for {self.stream_chunk_timeout} seconds.")
                            if response_text:
                                response_text += "\n\n[注意: 由于流式传输超时,内容可能不完整]"
                                partial_response = response_text
                                return partial_response
                            raise Timeout(f"No chunk received for {self.stream_chunk_timeout} seconds.")

                        last_chunk_time = current_time  # Update the time of the last received chunk

                        if chunk.choices and chunk.choices[0].delta and chunk.choices[0].delta.content:
                            content = chunk.choices[0].delta.content
                            response_text += content

                    logging.info("Stream completed successfully.")
                    return response_text  # Completed successfully

                except Timeout as e:
                    logging.warning(f"Stream chunk timeout: {e}. Retrying if possible ({retries + 1}/{self.max_retries}).")
                    if response_text:
                        partial_response = response_text
                    last_exception = e
                except (APITimeoutError, APIConnectionError, RateLimitError) as e:
                    logging.warning(f"API error during streaming: {type(e).__name__} - {e}. Retrying if possible.")
                    if response_text:
                        partial_response = response_text
                    last_exception = e
                except Exception as e:
                    logging.error(f"Unexpected error during streaming: {traceback.format_exc()}")
                    if response_text:
                        partial_response = response_text
                        response_text += f"\n\n[注意: 由于错误中断,内容可能不完整: {str(e)}]"
                        return response_text
                    last_exception = e

                # Retry logic
                retries += 1
                if retries < self.max_retries:
                    logging.info(f"Retrying stream in {backoff_time} seconds...")
                    time.sleep(backoff_time + random.uniform(0, 1))  # Add random jitter
                    backoff_time = min(backoff_time * 2, MAX_BACKOFF)
                else:
                    logging.error(f"Stream generation failed after {self.max_retries} retries.")
                    # If a partial response was collected, return it marked as incomplete
                    if partial_response:
                        logging.info(f"Returning partial response of length {len(partial_response)}")
                        return partial_response + "\n\n[注意: 达到最大重试次数,内容可能不完整]"
                    # If no content was received at all, raise an exception
                    raise last_exception or Exception("Stream generation failed after max retries with no content.")

            except (Timeout, APITimeoutError, APIConnectionError, RateLimitError) as e:
                retries += 1
                last_exception = e
                logging.warning(f"Attempt {retries}/{self.max_retries} failed: {type(e).__name__} - {e}")
                if retries < self.max_retries:
                    logging.info(f"Retrying in {backoff_time} seconds...")
                    time.sleep(backoff_time + random.uniform(0, 1))
                    backoff_time = min(backoff_time * 2, MAX_BACKOFF)
                else:
                    logging.error(f"API call failed after {self.max_retries} retries.")
                    # If a partial response was collected, return it marked as incomplete
                    if partial_response:
                        return partial_response + f"\n\n[注意: API调用失败,内容可能不完整: {str(e)}]"
                    return f"[生成内容失败: {str(last_exception)}]"
            except Exception as e:
                logging.error(f"Unexpected error setting up stream: {traceback.format_exc()}")
                # If a partial response was collected, return it marked as incomplete
                if partial_response:
                    return partial_response + f"\n\n[注意: 意外错误,内容可能不完整: {str(e)}]"
                return f"[生成内容失败: {str(e)}]"

        # Safety net; this point should normally never be reached
        logging.error("Exited stream generation loop unexpectedly.")
        # If a partial response was collected, return it marked as incomplete
        if partial_response:
            return partial_response + "\n\n[注意: 意外退出流处理,内容可能不完整]"
        return "[生成内容失败: 超过最大重试次数]"

    def generate_text_stream_with_callback(self, system_prompt, user_prompt, temperature, top_p, presence_penalty,
                                           callback, accumulate=False):
        """
        Generates text based on prompts using a streaming connection with a callback.

        Args:
            system_prompt: The system prompt for the AI.
            user_prompt: The user prompt for the AI.
            temperature: Sampling temperature.
            top_p: Nucleus sampling parameter.
            presence_penalty: Presence penalty parameter.

            callback: Function to call with each chunk of text (callback(chunk, accumulated_response))
            accumulate: If True, pass the accumulated response to the callback function.

        Returns:
            str: The complete generated text.

        Raises:
            Exception: If the API call fails after all retries.
        """
@@ -244,123 +393,13 @@ class AI_Agent():
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]
        logging.info(f"Generating text stream with model: {self.model_name}")

        retries = 0
        backoff_time = INITIAL_BACKOFF
        last_exception = None
        full_response = ""

        while retries < self.max_retries:
            try:
                logging.debug(f"Attempt {retries + 1}/{self.max_retries} to generate text stream.")
                stream = self.client.chat.completions.create(
                    model=self.model_name,
                    messages=messages,
                    temperature=temperature,
                    top_p=top_p,
                    presence_penalty=presence_penalty,
                    stream=True,
                    timeout=self.timeout  # Overall request timeout
                )

                chunk_iterator = iter(stream)
                last_chunk_time = time.time()

                while True:
                    try:
                        # Check for timeout since last received chunk
                        if time.time() - last_chunk_time > self.stream_chunk_timeout:
                            raise Timeout(f"No chunk received for {self.stream_chunk_timeout} seconds.")

                        chunk = next(chunk_iterator)
                        last_chunk_time = time.time()  # Reset timer on successful chunk receipt

                        if chunk.choices and chunk.choices[0].delta and chunk.choices[0].delta.content:
                            content = chunk.choices[0].delta.content
                            full_response += content
                            # logging.debug(f"Received chunk: {content}")  # Potentially very verbose
                        elif chunk.choices and chunk.choices[0].finish_reason == 'stop':
                            logging.info("Stream finished.")
                            return full_response  # Return complete response
                        # Handle other finish reasons if needed, e.g., 'length'

                    except StopIteration:
                        logging.info("Stream iterator exhausted.")
                        return full_response  # Return complete response
                    except Timeout as e:
                        logging.warning(f"Stream chunk timeout: {e}. Retrying if possible ({retries + 1}/{self.max_retries}).")
                        last_exception = e
                        break  # Break inner loop to retry the stream creation
                    except (APITimeoutError, APIConnectionError, RateLimitError) as e:
                        logging.warning(f"API error during streaming: {type(e).__name__} - {e}. Retrying if possible ({retries + 1}/{self.max_retries}).")
                        last_exception = e
                        break  # Break inner loop to retry the stream creation
                    except Exception as e:
                        logging.error(f"Unexpected error during streaming: {traceback.format_exc()}")
                        # Decide if this unexpected error should be retried or raised immediately
                        last_exception = e
                        # Option 1: Raise immediately
                        # raise e
                        # Option 2: Treat as retryable (use with caution)
                        break  # Break inner loop to retry

                # If we broke from the inner loop due to an error that needs retry
                retries += 1
                if retries < self.max_retries:
                    logging.info(f"Retrying stream in {backoff_time} seconds...")
                    time.sleep(backoff_time + random.uniform(0, 1))  # Add jitter
                    backoff_time = min(backoff_time * 2, MAX_BACKOFF)
                else:
                    logging.error(f"Stream generation failed after {self.max_retries} retries.")
                    raise last_exception or Exception("Stream generation failed after max retries.")

            except (Timeout, APITimeoutError, APIConnectionError, RateLimitError) as e:
                retries += 1
                last_exception = e
                logging.warning(f"Attempt {retries}/{self.max_retries} failed: {type(e).__name__} - {e}")
                if retries < self.max_retries:
                    logging.info(f"Retrying in {backoff_time} seconds...")
                    time.sleep(backoff_time + random.uniform(0, 1))  # Add jitter
                    backoff_time = min(backoff_time * 2, MAX_BACKOFF)
                else:
                    logging.error(f"API call failed after {self.max_retries} retries.")
                    return f"[生成内容失败: {str(last_exception)}]"  # Return error message as string
            except Exception as e:
                # Catch unexpected errors during stream setup
                logging.error(f"Unexpected error setting up stream: {traceback.format_exc()}")
                return f"[生成内容失败: {str(e)}]"  # Return error message as string

        # Should not be reached if logic is correct, but as a safeguard:
        logging.error("Exited stream generation loop unexpectedly.")
        return f"[生成内容失败: 超过最大重试次数]"  # Return error message as string

    def generate_text_stream_with_callback(self, system_prompt, user_prompt,
                                           callback_fn, temperature=0.7, top_p=0.9,
                                           presence_penalty=0.0):
        """Generate a text stream and process each chunk through a callback function.

        Args:
            system_prompt: The system prompt.
            user_prompt: The user prompt.
            callback_fn: Callback invoked for each text chunk; receives (content, is_last, is_timeout, is_error, error).
            temperature: Sampling temperature.
            top_p: Nucleus sampling parameter.
            presence_penalty: Presence penalty parameter.

        Returns:
            str: The complete response text.
        """
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]
        logging.info(f"Generating text stream with callback using model: {self.model_name}")
        logging.info(f"Generating text stream with callback, model: {self.model_name}")

        retries = 0
        backoff_time = INITIAL_BACKOFF
        last_exception = None
        full_response = ""
        global_timeout = self.timeout * 2  # Global timeout: twice the API timeout

        while retries < self.max_retries:
            try:
@@ -377,81 +416,127 @@ class AI_Agent():

                chunk_iterator = iter(stream)
                last_chunk_time = time.time()
                start_time = time.time()  # Record the start time for the global timeout check

                try:
                    while True:
                        try:
                            # Check for a timeout since the last received chunk
                            if time.time() - last_chunk_time > self.stream_chunk_timeout:
                                callback_fn("", is_last=True, is_timeout=True, is_error=False, error=None)
                                raise Timeout(f"No chunk received for {self.stream_chunk_timeout} seconds.")

                            chunk = next(chunk_iterator)
                            last_chunk_time = time.time()  # Reset the timer after a chunk is received

                            content = ""
                            is_last = False

                            if chunk.choices and chunk.choices[0].delta and chunk.choices[0].delta.content:
                                content = chunk.choices[0].delta.content
                                full_response += content

                            if chunk.choices and chunk.choices[0].finish_reason == 'stop':
                                is_last = True
                while True:
                    try:
                        # Check the global timeout
                        current_time = time.time()
                        if current_time - start_time > global_timeout:
                            logging.warning(f"Global timeout reached after {global_timeout} seconds.")
                            if full_response:
                                logging.info(f"Returning partial response of length {len(full_response)} due to global timeout.")
                                # Notify the callback about the timeout
                                timeout_msg = "\n\n[注意: 由于全局超时,内容可能不完整]"
                                callback(timeout_msg, full_response + timeout_msg if accumulate else None)
                                return full_response + timeout_msg
                            else:
                                logging.error("Global timeout with no content received.")
                                raise Timeout(f"Global timeout after {global_timeout} seconds with no content.")

                            # Invoke the callback to process the chunk
                            callback_fn(content, is_last=is_last, is_timeout=False, is_error=False, error=None)
                        # Check the per-chunk stream timeout
                        if current_time - last_chunk_time > self.stream_chunk_timeout:
                            logging.warning(f"Stream chunk timeout: No chunk received for {self.stream_chunk_timeout} seconds.")
                            if full_response:
                                logging.info(f"Returning partial response of length {len(full_response)} due to chunk timeout.")
                                # Notify the callback about the timeout
                                timeout_msg = "\n\n[注意: 由于流式传输超时,内容可能不完整]"
                                callback(timeout_msg, full_response + timeout_msg if accumulate else None)
                                return full_response + timeout_msg
                            raise Timeout(f"No chunk received for {self.stream_chunk_timeout} seconds.")

                            if is_last:
                                logging.info("Stream with callback finished normally.")
                                return full_response  # Completed successfully; return the full response

                        except StopIteration:
                            # Stream ended normally
                            callback_fn("", is_last=True, is_timeout=False, is_error=False, error=None)
                            logging.info("Stream iterator exhausted normally.")
                            return full_response
                        chunk = next(chunk_iterator)
                        last_chunk_time = time.time()

                        except Timeout as e:
                            logging.warning(f"Stream chunk timeout: {e}")
                            last_exception = e
                            # The timeout was already reported via the callback; no need to call it again here
                        except (APITimeoutError, APIConnectionError, RateLimitError) as e:
                            logging.warning(f"API error during streaming with callback: {type(e).__name__} - {e}")
                            callback_fn("", is_last=True, is_timeout=False, is_error=True, error=str(e))
                            last_exception = e
                        except Exception as e:
                            logging.error(f"Unexpected error during streaming with callback: {traceback.format_exc()}")
                            callback_fn("", is_last=True, is_timeout=False, is_error=True, error=str(e))
                            last_exception = e

                # Retry logic
                        if chunk.choices and chunk.choices[0].delta and chunk.choices[0].delta.content:
                            content = chunk.choices[0].delta.content
                            full_response += content
                            callback(content, full_response if accumulate else None)
                        elif chunk.choices and chunk.choices[0].finish_reason == 'stop':
                            logging.info("Stream with callback finished normally.")
                            return full_response

                    except StopIteration:
                        logging.info("Stream iterator with callback exhausted normally.")
                        return full_response
                    except Timeout as e:
                        logging.warning(f"Stream chunk timeout: {e}. Retrying if possible ({retries + 1}/{self.max_retries}).")
                        last_exception = e
                        break
                    except (APITimeoutError, APIConnectionError, RateLimitError) as e:
                        logging.warning(f"API error during streaming with callback: {type(e).__name__} - {e}. Retrying if possible.")
                        last_exception = e
                        break
                    except Exception as e:
                        logging.error(f"Unexpected error during streaming with callback: {traceback.format_exc()}")
                        # If a partial response was collected, return it; otherwise retry
                        if full_response:
                            error_msg = f"\n\n[注意: 由于错误中断,内容可能不完整: {str(e)}]"
                            logging.info(f"Returning partial response of length {len(full_response)} due to error: {type(e).__name__}")
                            callback(error_msg, full_response + error_msg if accumulate else None)
                            return full_response + error_msg
                        last_exception = e
                        break

                # If we broke from the inner loop due to an error that needs retry
                retries += 1
                if retries < self.max_retries:
                    retry_msg = f"将在 {backoff_time:.2f} 秒后重试..."
                    logging.info(f"Retrying stream in {backoff_time} seconds...")
                    callback_fn(f"\n[{retry_msg}]\n", is_last=False, is_timeout=False, is_error=False, error=None)
                    time.sleep(backoff_time + random.uniform(0, 1))  # Add random jitter
                    logging.info(f"Retrying stream with callback in {backoff_time} seconds...")
                    time.sleep(backoff_time + random.uniform(0, 1))
                    backoff_time = min(backoff_time * 2, MAX_BACKOFF)
                else:
                    error_msg = f"API call failed after {self.max_retries} retries: {str(last_exception)}"
                    logging.error(error_msg)
                    callback_fn(f"\n[{error_msg}]\n", is_last=True, is_timeout=False, is_error=True, error=str(last_exception))
                    return full_response  # Return the partial response collected so far
                    logging.error(f"Stream generation with callback failed after {self.max_retries} retries.")
                    # If a partial response was collected, return it; otherwise raise
                    if full_response:
                        error_msg = "\n\n[注意: 达到最大重试次数,内容可能不完整]"
                        logging.info(f"Returning partial response of length {len(full_response)} after max retries.")
                        callback(error_msg, full_response + error_msg if accumulate else None)
                        return full_response + error_msg
                    raise last_exception or Exception("Stream generation with callback failed after max retries.")

            except (Timeout, APITimeoutError, APIConnectionError, RateLimitError) as e:
                retries += 1
                last_exception = e
                logging.warning(f"Attempt {retries}/{self.max_retries} failed: {type(e).__name__} - {e}")
                if retries < self.max_retries:
                    logging.info(f"Retrying in {backoff_time} seconds...")
                    time.sleep(backoff_time + random.uniform(0, 1))
                    backoff_time = min(backoff_time * 2, MAX_BACKOFF)
                else:
                    logging.error(f"API call with callback failed after {self.max_retries} retries.")
                    # If a partial response was collected, return it; otherwise return an error message
                    if full_response:
                        error_msg = f"\n\n[注意: API调用失败,内容可能不完整: {str(e)}]"
                        callback(error_msg, full_response + error_msg if accumulate else None)
                        return full_response + error_msg
                    error_text = f"[生成内容失败: {str(last_exception)}]"
                    callback(error_text, error_text if accumulate else None)
                    return error_text
            except Exception as e:
                logging.error(f"Error setting up stream with callback: {traceback.format_exc()}")
                callback_fn("", is_last=True, is_timeout=False, is_error=True, error=str(e))
                return f"[生成内容失败: {str(e)}]"

        # As a safety net
        error_msg = "超出最大重试次数"
        logging.error(error_msg)
        callback_fn(f"\n[{error_msg}]\n", is_last=True, is_timeout=False, is_error=True, error=error_msg)
        return full_response
                logging.error(f"Unexpected error setting up stream with callback: {traceback.format_exc()}")
                # If a partial response was collected, return it; otherwise return an error message
                if full_response:
                    error_msg = f"\n\n[注意: 意外错误,内容可能不完整: {str(e)}]"
                    callback(error_msg, full_response + error_msg if accumulate else None)
                    return full_response + error_msg
                error_text = f"[生成内容失败: {str(e)}]"
                callback(error_text, error_text if accumulate else None)
                return error_text

        # Should not be reached if logic is correct, but as a safeguard:
        logging.error("Exited stream with callback generation loop unexpectedly.")
        # If a partial response was collected, return it; otherwise return an error message
        if full_response:
            error_msg = "\n\n[注意: 意外退出流处理,内容可能不完整]"
            callback(error_msg, full_response + error_msg if accumulate else None)
            return full_response + error_msg
        error_text = "[生成内容失败: 超过最大重试次数]"
        callback(error_text, error_text if accumulate else None)
        return error_text

    async def async_generate_text_stream(self, system_prompt, user_prompt, temperature=0.7, top_p=0.9, presence_penalty=0.0):
        """Asynchronously generate a text stream
        """
        Asynchronously generates a text stream, returning an async iterator.

        Args:
            system_prompt: The system prompt.
@@ -460,97 +545,150 @@ class AI_Agent():
            top_p: Nucleus sampling parameter.
            presence_penalty: Presence penalty parameter.

        Yields:
            str: The generated text chunks.
        Returns:
            AsyncGenerator: An async generator that yields text chunks.

        Raises:
            Exception: If the API call fails after all retries.
            Exception: If the API call still fails after all retries.
        """
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]
        logging.info(f"Asynchronously generating text stream with model: {self.model_name}")
        logging.info(f"Async generating text stream using model: {self.model_name}")

        retries = 0
        backoff_time = INITIAL_BACKOFF
        last_exception = None
        full_response = ""
        global_timeout = self.timeout * 2  # Global timeout: twice the API timeout

        while retries < self.max_retries:
            try:
                logging.debug(f"Async attempt {retries + 1}/{self.max_retries} to generate text stream.")
                # Create a new client for async operations
                async_client = OpenAI(
                logging.debug(f"Attempt {retries + 1}/{self.max_retries} to async generate text stream.")
                aclient = OpenAI(
                    api_key=self.api,
                    base_url=self.base_url,
                    timeout=self.timeout
                    timeout=self.timeout,
                    max_retries=0  # Disable built-in retries; we use our own retry logic
                )

                stream = await async_client.chat.completions.create(
                stream = aclient.chat.completions.create(
                    model=self.model_name,
                    messages=messages,
                    temperature=temperature,
                    top_p=top_p,
                    presence_penalty=presence_penalty,
                    stream=True,
                    timeout=self.timeout
                )

                stream_start_time = time.time()
                last_chunk_time = time.time()

                try:
                    async for chunk in stream:
                        # Check for a timeout since the last received chunk
                        # Check the global timeout
                        current_time = time.time()
                        if current_time - last_chunk_time > self.stream_chunk_timeout:
                            raise Timeout(f"No chunk received for {self.stream_chunk_timeout} seconds.")
                        if current_time - stream_start_time > global_timeout:
                            logging.warning(f"Async global timeout reached after {global_timeout} seconds.")
                            if full_response:
                                timeout_msg = "\n\n[注意: 由于全局超时,内容可能不完整]"
                                yield timeout_msg
                                return  # End the generator
                            else:
                                logging.error("Async global timeout with no content received.")
                                raise Timeout(f"Async global timeout after {global_timeout} seconds with no content.")

                        last_chunk_time = current_time  # Reset the timer after a chunk is received
                        # Check the per-chunk stream timeout
                        if current_time - last_chunk_time > self.stream_chunk_timeout:
                            logging.warning(f"Async stream chunk timeout: No chunk received for {self.stream_chunk_timeout} seconds.")
                            if full_response:
                                timeout_msg = "\n\n[注意: 由于流式传输超时,内容可能不完整]"
                                yield timeout_msg
                                return  # End the generator
                            raise Timeout(f"Async: No chunk received for {self.stream_chunk_timeout} seconds.")

                        last_chunk_time = current_time  # Update the time of the last received chunk

                        if chunk.choices and chunk.choices[0].delta and chunk.choices[0].delta.content:
                            content = chunk.choices[0].delta.content
                            full_response += content
                            yield content

                    logging.info("Async stream finished normally.")
                    return  # Completed successfully
                    logging.info("Async stream completed normally.")
                    return

                except AsyncTimeoutError as e:
                    logging.warning(f"Async stream timeout: {e}")
                    last_exception = e
                except Timeout as e:
                    logging.warning(f"Async stream chunk timeout: {e}")
                    logging.warning(f"Async stream chunk timeout: {e}. Retrying if possible ({retries + 1}/{self.max_retries}).")
                    last_exception = e
                except (APITimeoutError, APIConnectionError, RateLimitError) as e:
                    logging.warning(f"Async API error during streaming: {type(e).__name__} - {e}. Retrying if possible.")
                    last_exception = e
                except Exception as e:
                    logging.error(f"Error during async streaming: {traceback.format_exc()}")
                    logging.error(f"Async unexpected error during streaming: {traceback.format_exc()}")
                    # If a partial response was collected, yield an error notice and finish
                    if full_response:
                        error_msg = f"\n\n[注意: 由于错误中断,内容可能不完整: {str(e)}]"
                        logging.info(f"Yielding error message after partial response of length {len(full_response)}")
                        yield error_msg
                        return  # End the generator
                    last_exception = e

                # Retry logic
                retries += 1
                if retries < self.max_retries:
                    logging.info(f"Retrying async stream in {backoff_time} seconds...")
                    await asyncio.sleep(backoff_time + random.uniform(0, 1))  # Use async sleep
                    logging.info(f"Async retrying stream in {backoff_time} seconds...")
                    await asyncio.sleep(backoff_time + random.uniform(0, 1))  # Add random jitter
                    backoff_time = min(backoff_time * 2, MAX_BACKOFF)
                else:
                    logging.error(f"Async API call failed after {self.max_retries} retries.")
                    raise last_exception or Exception("Async stream generation failed after max retries.")

            except (Timeout, AsyncTimeoutError, APITimeoutError, APIConnectionError, RateLimitError) as e:
                    logging.error(f"Async stream generation failed after {self.max_retries} retries.")
                    # If a partial response was collected, yield an error notice and finish
                    if full_response:
                        error_msg = "\n\n[注意: 达到最大重试次数,内容可能不完整]"
                        logging.info(f"Yielding error message after partial response of length {len(full_response)}")
                        yield error_msg
                        return  # End the generator
                    # If no content was received at all, raise an exception
                    raise last_exception or Exception("Async stream generation failed after max retries with no content.")

            except (Timeout, APITimeoutError, APIConnectionError, RateLimitError) as e:
                retries += 1
                last_exception = e
                logging.warning(f"Async attempt {retries}/{self.max_retries} failed: {type(e).__name__} - {e}")
                if retries < self.max_retries:
                    logging.info(f"Retrying async stream in {backoff_time} seconds...")
                    await asyncio.sleep(backoff_time + random.uniform(0, 1))  # Use async sleep
                    logging.info(f"Async retrying in {backoff_time} seconds...")
                    await asyncio.sleep(backoff_time + random.uniform(0, 1))
                    backoff_time = min(backoff_time * 2, MAX_BACKOFF)
                else:
                    logging.error(f"Async API call failed after {self.max_retries} retries.")
                    raise last_exception
                    # If a partial response was collected, yield an error notice and finish
                    if full_response:
                        error_msg = f"\n\n[注意: API调用失败,内容可能不完整: {str(e)}]"
                        yield error_msg
                        return  # End the generator
                    error_text = f"[生成内容失败: {str(last_exception)}]"
                    yield error_text
                    return  # End the generator
            except Exception as e:
                logging.error(f"Unexpected error setting up async stream: {traceback.format_exc()}")
                raise e  # Re-raise unexpected errors immediately

        # As a safety net
        logging.error("Exited async stream generation loop unexpectedly.")
        raise last_exception or Exception("Async stream generation failed.")
                logging.error(f"Async unexpected error setting up stream: {traceback.format_exc()}")
                # If a partial response was collected, yield an error notice and finish
                if full_response:
                    error_msg = f"\n\n[注意: 意外错误,内容可能不完整: {str(e)}]"
                    yield error_msg
                    return  # End the generator
                error_text = f"[生成内容失败: {str(e)}]"
                yield error_text
                return  # End the generator

        # Safety net; this point should normally never be reached
        logging.error("Async exited stream generation loop unexpectedly.")
        # If a partial response was collected, yield an error notice and finish
        if full_response:
            error_msg = "\n\n[注意: 意外退出流处理,内容可能不完整]"
            yield error_msg
            return  # End the generator
        error_text = "[生成内容失败: 超过最大重试次数]"
        yield error_text

    async def async_work_stream(self, system_prompt, user_prompt, file_folder, temperature, top_p, presence_penalty):
        """Asynchronous full workflow: read the file folder (if provided), then generate a text stream."""

@@ -203,24 +203,60 @@ class ContentGenerator:

        if not system_prompt:
            # Use the default system prompt
            system_prompt = f"""
你是一个专业的文案处理专家,擅长从文章中提取关键信息并生成吸引人的标题和简短描述。
现在,我需要你根据提供的文章内容,生成{poster_num}个海报的文案配置。

每个配置包含:
1. main_title:主标题,简短有力,突出景点特点
2. texts:两句简短文本,每句不超过15字,描述景点特色或游玩体验

以JSON数组格式返回配置,示例:
[
  {{
    "main_title": "泰宁古城",
    "texts": ["千年古韵","匠心独运"]
  }},
  ...
]

仅返回JSON数据,不需要任何额外解释。确保生成的标题和文本能够准确反映文章提到的景点特色。
            system_prompt = """
你是一名资深海报设计师,有丰富的爆款海报设计经验,你现在要为旅游景点做宣传,在小红书上发布大量宣传海报。你的主要工作目标有2个:
1、你要根据我给你的图片描述和笔记推文内容,设计图文匹配的海报。
2、为海报设计文案,文案的<第一个小标题>和<第二个小标题>之间你需要检查是否逻辑关系合理,你将通过先去生成<第二个小标题>关于景区亮点的部分,再去综合判断<第一个小标题>应该如何搭配组合更符合两个小标题的逻辑再生成<第一个小标题>。

其中,生成三类标题文案的通用性要求如下:
1、生成的<大标题>字数必须小于8个字符
2、生成的<第一个小标题>字数和<第二个小标题>字数,两者都必须小于8个字符
3、标题和文案都应符合中国社会主义核心价值观

接下来先开始生成<大标题>部分,由于海报是用来宣传旅游景点,生成的海报<大标题>必须使用以下8种格式之一:
①地名+景点名(例如福建厦门鼓浪屿/厦门鼓浪屿);
②地名+景点名+plog;
③拿捏+地名+景点名;
④地名+景点名+攻略;
⑤速通+地名+景点名
⑥推荐!+地名+景点名
⑦勇闯!+地名+景点名
⑧收藏!+地名+景点名
你需要随机挑选一种格式生成对应景点的文案,但是格式除了上面8种不可以有其他任何格式;同时尽量保证每一种格式出现的频率均衡。
接下来先去生成<第二个小标题>,<第二个小标题>文案的创作必须遵循以下原则:
请根据笔记内容和图片识别,用极简的文字概括这篇笔记和图片中景点的特色亮点,其中你可以参考以下词汇进行创作,这段文案字数控制6-8字符以内;

特色亮点可能会出现的词汇不完全举例:非遗、古建、绝佳山水、祈福圣地、研学圣地、解压天堂、中国小瑞士、秘境竹筏游等等类型词汇

接下来再去生成<第一个小标题>,<第一个小标题>文案的创作必须遵循以下原则:
这部分文案创作公式有5种,分别为:
①<受众人群画像>+<痛点词>
②<受众人群画像>
③<痛点词>
④<受众人群画像>+ | +<痛点词>
⑤<痛点词>+ | +<受众人群画像>
请你根据实际笔记内容,结合这部分文案创作公式,需要结合<受众人群画像>和<痛点词>时,必须根据<第二个小标题>的景点特征和所对应的完整笔记推文内容主旨,特征挑选对应<受众人群画像>和<痛点词>。

我给你提供受众人群画像库和痛点词库如下:
1、受众人群画像库:情侣党、亲子游、合家游、银发族、亲子研学、学生党、打工人、周边游、本地人、穷游党、性价比、户外人、美食党、出片
2、痛点词库:3天2夜、必去、看了都哭了、不能错过、一定要来、问爆了、超全攻略、必打卡、强推、懒人攻略、必游榜、小众打卡、狂喜等等。

你需要为每个请求至少生成{poster_num}个海报设计。请使用JSON格式输出结果,结构如下:
[
  {
    "index": 1,
    "main_title": "主标题内容",
    "texts": ["第一个小标题", "第二个小标题"]
  },
  {
    "index": 2,
    "main_title": "主标题内容",
    "texts": ["第一个小标题", "第二个小标题"]
  },
  // ... 更多海报
]
确保生成的数量与用户要求的数量一致。只生成上述JSON格式内容,不要有其他任何额外内容。

            """

        if self.add_description:
@@ -235,7 +271,6 @@ class ContentGenerator:
{tweet_content}

请根据这些信息,生成{poster_num}个海报文案配置,以JSON数组格式返回。
确保主标题(main_title)简短有力,每个text不超过15字,并能准确反映景点特色。
            """
        else:
            # Use tweet_content only
@@ -244,7 +279,6 @@ class ContentGenerator:
{tweet_content}

请根据这些信息,生成{poster_num}个海报文案配置,以JSON数组格式返回。
确保主标题(main_title)简短有力,每个text不超过15字,并能准确反映景点特色。
            """

        self.logger.info(f"正在生成{poster_num}个海报文案配置")

@@ -198,4 +198,83 @@ python examples/generate_poster.py --input_image /path/to/image.jpg --output_pat
- These examples rely on the main project's configuration and resources; make sure `poster_gen_config.json` is set up correctly
- The test scripts automatically tune certain parameters (such as the generation count) to speed up testing
- In real use, you may need to adjust parameters for better results
- Some test scripts need to connect to an AI model API; make sure your API configuration is correct

### Full Workflow Test Script
The `test_workflow.py` file demonstrates a complete content generation process, from topic generation to poster creation.

#### Running the full workflow test:
```bash
python examples/test_workflow.py
```

This script performs the following steps:
1. Generate travel destination topics
2. Generate content from the topics
3. Generate matching posters
4. Save the results

#### Step-by-step testing:
To test the system stage by stage, run the following scripts separately:

##### Stage 1: Topic generation
```bash
python examples/run_step1_topics.py
```
This command generates travel destination topics and saves them in the `outputs/topics` directory.

##### Stage 2: Content and poster generation
```bash
python examples/run_step2_content_posters.py
```
This command reads the topics from the `outputs/topics` directory, generates content, and creates posters; results are saved in the `outputs/content` and `outputs/posters` directories.

### Poster-Only Example
If you only want to test the poster generation feature, use the following example:

```bash
python examples/generate_poster.py
```

By default, this script generates a poster from sample content and saves it in the `outputs/posters` directory.

You can also specify the input content and output path:
```bash
python examples/generate_poster.py --input "your travel content" --output "your_output_path.jpg"
```

### Streaming Examples

#### Basic streaming output test
```bash
python examples/test_stream.py
```
This script shows how to use three different streaming methods (a minimal usage sketch follows the list):
1. Synchronous streaming (`generate_text_stream`)
2. Callback-based streaming (`generate_text_stream_with_callback`)
3. Asynchronous streaming (`async_generate_text_stream`)
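
For orientation, here is a minimal usage sketch of the three methods, based on the signatures introduced in this commit; the endpoint, model name, and API key are placeholders you will need to adjust for your setup.

```python
import asyncio
from core.ai_agent import AI_Agent

agent = AI_Agent(
    base_url="https://api.openai.com/v1",  # placeholder endpoint
    model_name="gpt-3.5-turbo",            # placeholder model
    api="your_api_key_here",
    timeout=30,
    max_retries=2,
    stream_chunk_timeout=10,
)

# 1. Synchronous: blocks until the stream finishes, returns the full text
text = agent.generate_text_stream("你是一个有用的助手。", "简要介绍泰宁古城。")

# 2. Callback-based: callback(chunk, accumulated) fires for each chunk
agent.generate_text_stream_with_callback(
    "你是一个有用的助手。", "简要介绍泰宁古城。",
    temperature=0.7, top_p=0.9, presence_penalty=0.0,
    callback=lambda chunk, acc: print(chunk, end="", flush=True),
    accumulate=True,
)

# 3. Asynchronous: iterate over chunks as they arrive
async def main():
    async for chunk in agent.async_generate_text_stream(
        "你是一个有用的助手。", "简要介绍泰宁古城。"
    ):
        print(chunk, end="", flush=True)

asyncio.run(main())
agent.close()
```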

#### Timeout handling test
```bash
python examples/test_stream_with_timeout_handling.py
```
This script demonstrates how timeouts during streaming generation are handled, including how to configure the global request timeout and the per-chunk stream timeout.
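
The knobs involved, as used by that script: `timeout` bounds a single API request, the streaming methods derive a global limit of `timeout * 2` for the whole stream, and `stream_chunk_timeout` bounds the gap between consecutive chunks. A rough sketch of detecting a truncated result (the `[注意:` prefix is the marker these methods append to incomplete content):

```python
agent = AI_Agent(
    base_url="https://api.openai.com/v1",  # placeholder endpoint
    model_name="gpt-3.5-turbo",            # placeholder model
    api="your_api_key_here",
    timeout=10,              # per-request timeout in seconds; global limit is twice this
    max_retries=2,           # retry attempts with exponential backoff
    stream_chunk_timeout=5,  # maximum seconds between consecutive chunks
)

response = agent.generate_text_stream("你是一个有用的助手。", "请详细描述中国的长城。")
if "[注意:" in response:
    # The stream timed out or errored; the text before the marker is still usable
    print("Partial result:", response[:response.find("[注意:")])
```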

#### Concurrent stream processing
```bash
python examples/concurrent_stream_processing.py
```
This script shows how to use asynchronous streaming to run several different AI generation tasks at the same time (a minimal sketch of the pattern follows the list), covering:
1. Concurrent processing of several tasks with different prompts
2. A performance comparison of sequential versus concurrent processing
3. Task status monitoring and exception handling

This example is especially useful when you need to generate several pieces of travel content at once, such as introductions, food recommendations, and travel tips for multiple destinations at the same time.
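
At its core, the concurrency pattern in that script is `asyncio.gather` over tasks that each drain the async generator; a minimal sketch under the same assumptions (the prompts and agent setup are illustrative):

```python
import asyncio

async def collect(agent, system_prompt, user_prompt):
    # Drain one stream into a single string
    return "".join([c async for c in agent.async_generate_text_stream(system_prompt, user_prompt)])

async def main(agent):
    results = await asyncio.gather(
        collect(agent, "你是一个旅游顾问。", "简要介绍北京。"),
        collect(agent, "你是一个旅游顾问。", "简要介绍上海。"),
        return_exceptions=True,  # one failed task should not cancel the others
    )
    for r in results:
        print(r if isinstance(r, str) else f"failed: {r!r}")
```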

### Future Examples
We will keep adding more examples that show how to use each component of the system independently.

### Important Notes
- These examples depend on the project's main configuration and resources; make sure they are configured correctly
- You may need to tune parameters to get the best results
- If you run into any problems, see the main documentation or open an issue
229
examples/concurrent_stream_processing.py
Normal file
@@ -0,0 +1,229 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os
import sys
import logging
import asyncio
import time
from datetime import datetime

# Add the project root to the path so the core modules can be imported
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from core.ai_agent import AI_Agent, Timeout

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler()]
)

# Read the API key from the environment, or use the default value
API_KEY = os.environ.get("OPENAI_API_KEY", "your_api_key_here")
# Base URL of the API
BASE_URL = os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1")
# Model name to use
MODEL_NAME = os.environ.get("OPENAI_MODEL", "gpt-3.5-turbo")

def print_with_timestamp(message, end='\n'):
    """Print a message with a timestamp and task ID."""
    timestamp = datetime.now().strftime("%H:%M:%S")
    task_id = asyncio.current_task().get_name() if asyncio.current_task() else "主线程"
    print(f"[{timestamp}][{task_id}] {message}", end=end, flush=True)

async def process_stream_task(agent, system_prompt, user_prompt, task_id):
    """Process a single streaming generation task."""
    print_with_timestamp(f"任务 {task_id} 开始")
    start_time = time.time()

    try:
        full_response = ""
        async for chunk in agent.async_generate_text_stream(
            system_prompt=system_prompt,
            user_prompt=user_prompt
        ):
            full_response += chunk
            print_with_timestamp(f"任务 {task_id} 收到块: 「{chunk}」", end="")

        elapsed = time.time() - start_time
        print_with_timestamp(f"\n任务 {task_id} 完成!耗时: {elapsed:.2f}秒")
        return {"task_id": task_id, "response": full_response, "success": True, "elapsed": elapsed}

    except Timeout as e:
        elapsed = time.time() - start_time
        print_with_timestamp(f"任务 {task_id} 超时: {e}")
        return {"task_id": task_id, "error": str(e), "success": False, "elapsed": elapsed}

    except Exception as e:
        elapsed = time.time() - start_time
        print_with_timestamp(f"任务 {task_id} 异常: {type(e).__name__} - {e}")
        return {"task_id": task_id, "error": f"{type(e).__name__}: {str(e)}", "success": False, "elapsed": elapsed}

async def run_concurrent_streams():
    """Run several streaming generation tasks concurrently."""
    print_with_timestamp("开始并发流式处理测试...")

    # Create the AI_Agent instance
    agent = AI_Agent(
        base_url=BASE_URL,
        model_name=MODEL_NAME,
        api=API_KEY,
        timeout=30,  # Overall request timeout
        max_retries=2,
        stream_chunk_timeout=10  # Per-chunk stream timeout
    )

    try:
        # Define the different tasks
        tasks = [
            {
                "id": "城市介绍",
                "system": "你是一个专业的旅游指南。",
                "user": "请简要介绍北京这座城市的历史和主要景点。"
            },
            {
                "id": "美食推荐",
                "system": "你是一个美食专家。",
                "user": "推荐5种上海的特色小吃,并简要说明其特点。"
            },
            {
                "id": "旅行建议",
                "system": "你是一个旅行规划顾问。",
                "user": "我计划去云南旅行一周,请给我一个简要的行程安排。"
            }
        ]

        start_time = time.time()

        # Create the concurrent tasks
        coroutines = []
        for task in tasks:
            # Give each task its own name so log lines are easy to tell apart
            coro = process_stream_task(
                agent=agent,
                system_prompt=task["system"],
                user_prompt=task["user"],
                task_id=task["id"]
            )
            # Set the task name
            coroutines.append(asyncio.create_task(coro, name=f"Task-{task['id']}"))

        # Wait for all tasks to complete
        results = await asyncio.gather(*coroutines, return_exceptions=True)

        # Process and display the results
        total_elapsed = time.time() - start_time
        print_with_timestamp(f"所有任务完成,总耗时: {total_elapsed:.2f}秒")

        success_count = sum(1 for r in results if isinstance(r, dict) and r.get("success", False))
        error_count = len(results) - success_count

        print_with_timestamp(f"成功任务: {success_count}, 失败任务: {error_count}")

        # Show the detailed result of each task
        for i, result in enumerate(results):
            if isinstance(result, dict):
                task_id = result.get("task_id", f"未知任务-{i}")
                if result.get("success", False):
                    print_with_timestamp(f"任务 {task_id} 成功,耗时: {result.get('elapsed', 0):.2f}秒")
                else:
                    print_with_timestamp(f"任务 {task_id} 失败: {result.get('error', '未知错误')}")
            else:
                print_with_timestamp(f"任务 {i} 返回了异常: {result}")

    except Exception as e:
        print_with_timestamp(f"并发处理主程序异常: {type(e).__name__} - {e}")

    finally:
        # Close the AI_Agent
        agent.close()

async def run_sequential_vs_concurrent():
    """Compare the performance of sequential and concurrent processing."""
    print_with_timestamp("开始顺序处理与并发处理性能对比测试...")

    # Create the AI_Agent instance
    agent = AI_Agent(
        base_url=BASE_URL,
        model_name=MODEL_NAME,
        api=API_KEY,
        timeout=30,
        max_retries=2,
        stream_chunk_timeout=10
    )

    # Define the test tasks
    tasks = [
        {"id": "任务1", "prompt": "列出三个世界著名的旅游景点及其特色。"},
        {"id": "任务2", "prompt": "简述三种不同的旅行方式的优缺点。"},
        {"id": "任务3", "prompt": "推荐三个适合冬季旅行的目的地。"}
    ]

    try:
        # Sequential processing
        print_with_timestamp("开始顺序处理...")
        sequential_start = time.time()

        for task in tasks:
            print_with_timestamp(f"开始处理 {task['id']}...")
            task_start = time.time()

            try:
                response = ""
                async for chunk in agent.async_generate_text_stream(
                    system_prompt="你是一个旅游顾问。",
                    user_prompt=task["prompt"]
                ):
                    response += chunk

                task_elapsed = time.time() - task_start
                print_with_timestamp(f"{task['id']} 完成,耗时: {task_elapsed:.2f}秒")

            except Exception as e:
                print_with_timestamp(f"{task['id']} 处理失败: {e}")

        sequential_elapsed = time.time() - sequential_start
        print_with_timestamp(f"顺序处理总耗时: {sequential_elapsed:.2f}秒")

        # Concurrent processing
        print_with_timestamp("\n开始并发处理...")
        concurrent_start = time.time()

        coroutines = []
        for task in tasks:
            coro = process_stream_task(
                agent=agent,
                system_prompt="你是一个旅游顾问。",
                user_prompt=task["prompt"],
                task_id=task["id"]
            )
            coroutines.append(asyncio.create_task(coro, name=f"Task-{task['id']}"))

        await asyncio.gather(*coroutines)

        concurrent_elapsed = time.time() - concurrent_start
        print_with_timestamp(f"并发处理总耗时: {concurrent_elapsed:.2f}秒")

        # Performance comparison
        speedup = sequential_elapsed / concurrent_elapsed if concurrent_elapsed > 0 else float('inf')
        print_with_timestamp(f"\n性能对比:")
        print_with_timestamp(f"顺序处理耗时: {sequential_elapsed:.2f}秒")
        print_with_timestamp(f"并发处理耗时: {concurrent_elapsed:.2f}秒")
        print_with_timestamp(f"加速比: {speedup:.2f}x")

    except Exception as e:
        print_with_timestamp(f"对比测试异常: {type(e).__name__} - {e}")

    finally:
        agent.close()

if __name__ == "__main__":
    # Run the concurrency example
    asyncio.run(run_concurrent_streams())

    print("\n" + "="*70 + "\n")

    # Run the performance comparison
    asyncio.run(run_sequential_vs_concurrent())
198
examples/test_stream_with_timeout_handling.py
Normal file
@@ -0,0 +1,198 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os
import sys
import logging
import asyncio
import time
from datetime import datetime

# Add the project root to the path so the core modules can be imported
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from core.ai_agent import AI_Agent, Timeout

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler()]
)

# Read the API key from the environment, or use the default value
API_KEY = os.environ.get("OPENAI_API_KEY", "your_api_key_here")
# Base URL of the API
BASE_URL = os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1")
# Model name to use
MODEL_NAME = os.environ.get("OPENAI_MODEL", "gpt-3.5-turbo")

def print_with_timestamp(message, end='\n'):
    """Print a message with a timestamp."""
    timestamp = datetime.now().strftime("%H:%M:%S")
    print(f"[{timestamp}] {message}", end=end, flush=True)

def test_sync_stream_with_timeouts():
    """Test timeout handling in the synchronous streaming mode."""
    print_with_timestamp("开始测试同步流式响应的超时处理...")

    # Create an AI_Agent instance with short timeouts to make testing easier
    agent = AI_Agent(
        base_url=BASE_URL,
        model_name=MODEL_NAME,
        api=API_KEY,
        timeout=10,  # Overall API request timeout (seconds)
        max_retries=2,  # Maximum number of retries
        stream_chunk_timeout=5  # Per-chunk stream timeout (seconds)
    )

    system_prompt = "你是一个有用的助手。"
    user_prompt = "请详细描述中国的长城,至少500字。"

    try:
        print_with_timestamp("正在生成内容...")
        start_time = time.time()

        # Use the synchronous streaming method
        response = agent.generate_text_stream(
            system_prompt=system_prompt,
            user_prompt=user_prompt
        )

        # Print the complete response and elapsed time
        print_with_timestamp(f"完成! 耗时: {time.time() - start_time:.2f}秒")

        # Check whether the response contains a timeout or error notice
        if "[注意:" in response:
            print_with_timestamp("检测到警告信息:")
            warning_start = response.find("[注意:")
            warning_end = response.find("]", warning_start)
            if warning_end != -1:
                print_with_timestamp(f"警告内容: {response[warning_start:warning_end+1]}")

    except Timeout as e:
        print_with_timestamp(f"捕获到超时异常: {e}")
    except Exception as e:
        print_with_timestamp(f"捕获到异常: {type(e).__name__} - {e}")
    finally:
        agent.close()

def test_callback_stream_with_timeouts():
    """Test timeout handling in the callback streaming mode."""
    print_with_timestamp("开始测试回调流式响应的超时处理...")

    # Create an AI_Agent instance with short timeouts to make testing easier
    agent = AI_Agent(
        base_url=BASE_URL,
        model_name=MODEL_NAME,
        api=API_KEY,
        timeout=10,  # Overall API request timeout (seconds)
        max_retries=2,  # Maximum number of retries
        stream_chunk_timeout=5  # Per-chunk stream timeout (seconds)
    )

    system_prompt = "你是一个有用的助手。"
    user_prompt = "请详细描述中国的长城,至少500字。"

    # Define the callback function
    def callback(chunk, accumulated=None):
        print_with_timestamp(f"收到块: 「{chunk}」", end="")

    try:
        print_with_timestamp("正在通过回调生成内容...")
        start_time = time.time()

        # Use the callback streaming method
        response = agent.generate_text_stream_with_callback(
            system_prompt=system_prompt,
            user_prompt=user_prompt,
            temperature=0.7,
            top_p=0.9,
            presence_penalty=0.0,
            callback=callback,
            accumulate=True  # Enable accumulation mode
        )

        # Print the complete response and elapsed time
        print_with_timestamp(f"\n完成! 耗时: {time.time() - start_time:.2f}秒")
        print_with_timestamp("回调累积的响应:")

        # Check whether the response contains a timeout or error notice
        if "[注意:" in response:
            print_with_timestamp("检测到警告信息:")
            warning_start = response.find("[注意:")
            warning_end = response.find("]", warning_start)
            if warning_end != -1:
                print_with_timestamp(f"警告内容: {response[warning_start:warning_end+1]}")

    except Timeout as e:
        print_with_timestamp(f"捕获到超时异常: {e}")
    except Exception as e:
        print_with_timestamp(f"捕获到异常: {type(e).__name__} - {e}")
    finally:
        agent.close()

async def test_async_stream_with_timeouts():
    """Test timeout handling in the asynchronous streaming mode."""
    print_with_timestamp("开始测试异步流式响应的超时处理...")

    # Create an AI_Agent instance with short timeouts to make testing easier
    agent = AI_Agent(
        base_url=BASE_URL,
        model_name=MODEL_NAME,
        api=API_KEY,
        timeout=10,  # Overall API request timeout (seconds)
        max_retries=2,  # Maximum number of retries
        stream_chunk_timeout=5  # Per-chunk stream timeout (seconds)
    )

    system_prompt = "你是一个有用的助手。"
    user_prompt = "请详细描述中国的长城,至少500字。"

    try:
        print_with_timestamp("正在异步生成内容...")
        start_time = time.time()

        # Use the asynchronous streaming method
        full_response = ""
        async for chunk in agent.async_generate_text_stream(
            system_prompt=system_prompt,
            user_prompt=user_prompt
        ):
            full_response += chunk
            print_with_timestamp(f"收到块: 「{chunk}」", end="")

        # Print the complete response and elapsed time
        print_with_timestamp(f"\n完成! 耗时: {time.time() - start_time:.2f}秒")

        # Check whether the response contains a timeout or error notice
        if "[注意:" in full_response:
            print_with_timestamp("检测到警告信息:")
            warning_start = full_response.find("[注意:")
            warning_end = full_response.find("]", warning_start)
            if warning_end != -1:
                print_with_timestamp(f"警告内容: {full_response[warning_start:warning_end+1]}")

    except Timeout as e:
        print_with_timestamp(f"捕获到超时异常: {e}")
    except Exception as e:
        print_with_timestamp(f"捕获到异常: {type(e).__name__} - {e}")
    finally:
        agent.close()

async def run_all_tests():
    """Run all of the tests."""
    # Test the synchronous mode
    test_sync_stream_with_timeouts()
    print("\n" + "-"*50 + "\n")

    # Test the callback mode
    test_callback_stream_with_timeouts()
    print("\n" + "-"*50 + "\n")

    # Test the asynchronous mode
    await test_async_stream_with_timeouts()

if __name__ == "__main__":
    # Run all tests
    asyncio.run(run_all_tests())
Binary file not shown.
@@ -57,7 +57,7 @@ class tweetContent:
        self.variant_index = variant_index

        try:
            self.title, self.content = self.split_content(result)
            self.title, self.content = self.split_content(result)
            self.json_data = self.gen_result_json()
        except Exception as e:
            logging.error(f"Failed to parse AI result for {article_index}_{variant_index}: {e}")