746 lines
41 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import time
import random
import traceback
import logging
import asyncio
from asyncio import TimeoutError as AsyncTimeoutError

import tiktoken
from openai import (
    OpenAI,
    AsyncOpenAI,
    APITimeoutError,
    APIConnectionError,
    RateLimitError,
    APIStatusError,
)
# Configure basic logging for this module (or rely on root logger config)
# logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# logger = logging.getLogger(__name__) # Alternative: use named logger
# Constants
MAX_RETRIES = 3 # Maximum number of retries for API calls
INITIAL_BACKOFF = 1 # Initial backoff time in seconds
MAX_BACKOFF = 16 # Maximum backoff time in seconds
STREAM_CHUNK_TIMEOUT = 10 # Timeout in seconds for receiving a chunk in stream
# 自定义超时异常
class Timeout(Exception):
"""Raised when a stream chunk timeout occurs."""
def __init__(self, message="Stream chunk timeout occurred."):
self.message = message
super().__init__(self.message)
def __str__(self):
return self.message
class AI_Agent():
"""AI代理类负责与AI模型交互生成文本内容"""
def __init__(self, base_url, model_name, api, timeout=30, max_retries=3, stream_chunk_timeout=10):
"""
初始化 AI 代理。
Args:
base_url (str): 模型服务的基础 URL 或预设名称 ('deepseek', 'vllm')。
model_name (str): 要使用的模型名称。
api (str): API 密钥。
timeout (int, optional): 单次 API 请求的超时时间 (秒)。 Defaults to 30.
max_retries (int, optional): API 请求失败时的最大重试次数。 Defaults to 3.
stream_chunk_timeout (int, optional): 流式响应中两个数据块之间的最大等待时间 (秒)。 Defaults to 10.
"""
logging.info("Initializing AI Agent")
self.url_list = {
"ali": "https://dashscope.aliyuncs.com/compatible-mode/v1",
"kimi": "https://api.moonshot.cn/v1",
"doubao": "https://ark.cn-beijing.volces.com/api/v3/",
"deepseek": "https://api.deepseek.com",
"vllm": "http://localhost:8000/v1",
}
self.base_url = self.url_list.get(base_url, base_url)
self.api = api
self.model_name = model_name
self.timeout = timeout
self.max_retries = max_retries
self.stream_chunk_timeout = stream_chunk_timeout
logging.info(f"AI Agent Settings: base_url={self.base_url}, model={self.model_name}, timeout={self.timeout}s, max_retries={self.max_retries}, stream_chunk_timeout={self.stream_chunk_timeout}s")
self.client = OpenAI(
api_key=self.api,
base_url=self.base_url,
timeout=self.timeout
)
# try:
# self.encoding = tiktoken.encoding_for_model(self.model_name)
# except KeyError:
# logging.warning(f"Encoding for model '{self.model_name}' not found. Using 'cl100k_base' encoding.")
# self.encoding = tiktoken.get_encoding("cl100k_base")
def generate_text(self, system_prompt, user_prompt, temperature, top_p, presence_penalty):
"""生成文本内容并返回完整响应和token估计值"""
logging.info("Starting text generation process...")
# logging.debug(f"System Prompt (first 100): {system_prompt[:100]}...")
# logging.debug(f"User Prompt (first 100): {user_prompt[:100]}...") # Avoid logging potentially huge prompts
logging.info(f"Generation Params: temp={temperature}, top_p={top_p}, presence_penalty={presence_penalty}")
retry_count = 0
max_retry_wait = 10 # Max wait time between retries
full_response = ""
while retry_count <= self.max_retries:
call_start_time = None # Initialize start time
try:
# --- Added Logging ---
user_prompt_size = len(user_prompt)
logging.info(f"Attempt {retry_count + 1}/{self.max_retries + 1}: Preparing API request. User prompt size: {user_prompt_size} chars.")
call_start_time = time.time()
# --- End Added Logging ---
response = self.client.chat.completions.create(
model=self.model_name,
messages=[{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt}],
temperature=temperature,
top_p=top_p,
presence_penalty=presence_penalty,
stream=False, # Ensure this is False for non-streaming method
max_tokens=8192,
timeout=self.timeout,
extra_body={
"repetition_penalty": 1.05,
},
)
# --- Added Logging ---
call_end_time = time.time()
logging.info(f"Attempt {retry_count + 1}/{self.max_retries + 1}: API request function returned successfully after {call_end_time - call_start_time:.2f} seconds.")
# --- End Added Logging ---
if response.choices and response.choices[0].message:
full_response = response.choices[0].message.content
logging.info(f"Received successful response. Content length: {len(full_response)} chars.")
break # Success, exit retry loop
else:
logging.warning("API response structure unexpected or empty content.")
full_response = "[Error: Empty or invalid response structure]"
# Decide if this specific case should retry or fail immediately
retry_count += 1 # Example: Treat as retryable
if retry_count <= self.max_retries:
wait_time = min(2 ** retry_count + random.random(), max_retry_wait)
logging.warning(f"Retrying due to unexpected response structure ({retry_count}/{self.max_retries}), waiting {wait_time:.2f}s...")
time.sleep(wait_time)
continue
except (APITimeoutError, APIConnectionError, RateLimitError, APIStatusError) as e:
# --- Added Logging ---
if call_start_time:
call_fail_time = time.time()
logging.warning(f"Attempt {retry_count + 1}/{self.max_retries + 1}: API call failed/timed out after {call_fail_time - call_start_time:.2f} seconds.")
else:
logging.warning(f"Attempt {retry_count + 1}/{self.max_retries + 1}: API call failed before or during initiation.")
# --- End Added Logging ---
logging.warning(f"API Error occurred: {e}")
should_retry = False
if isinstance(e, (APITimeoutError, APIConnectionError, RateLimitError)):
should_retry = True
elif isinstance(e, APIStatusError) and e.status_code >= 500:
should_retry = True
if should_retry:
retry_count += 1
if retry_count <= self.max_retries:
wait_time = min(2 ** retry_count + random.random(), max_retry_wait)
logging.warning(f"Retrying API call ({retry_count}/{self.max_retries}) after error, waiting {wait_time:.2f}s...")
time.sleep(wait_time)
else:
logging.error(f"Max retries ({self.max_retries}) reached for API errors. Aborting.")
return "请求失败,无法生成内容。", 0
else:
logging.error(f"Non-retriable API error: {e}. Aborting.")
return "请求失败,发生不可重试错误。", 0
except Exception as e:
logging.exception(f"Unexpected error during API call setup/execution:")
retry_count += 1
if retry_count <= self.max_retries:
wait_time = min(2 ** retry_count + random.random(), max_retry_wait)
logging.warning(f"Retrying API call ({retry_count}/{self.max_retries}) after unexpected error, waiting {wait_time:.2f}s...")
time.sleep(wait_time)
else:
logging.error(f"Max retries ({self.max_retries}) reached after unexpected errors. Aborting.")
return "请求失败,发生未知错误。", 0
logging.info("Text generation completed.")
estimated_tokens = len(full_response.split()) * 1.3
return full_response, estimated_tokens
def read_folder(self, file_folder):
"""读取指定文件夹下的所有文件内容"""
if not os.path.exists(file_folder):
logging.warning(f"Referenced folder does not exist: {file_folder}")
return ""
context = ""
try:
for file in os.listdir(file_folder):
file_path = os.path.join(file_folder, file)
if os.path.isfile(file_path):
try:
with open(file_path, "r", encoding="utf-8") as f:
context += f"文件名: {file}\n"
context += f.read()
context += "\n\n"
except Exception as read_err:
logging.error(f"Failed to read file {file_path}: {read_err}")
except Exception as list_err:
logging.error(f"Failed to list directory {file_folder}: {list_err}")
return context
def work(self, system_prompt, user_prompt, file_folder, temperature, top_p, presence_penalty):
"""完整的工作流程:生成文本并返回结果"""
logging.info(f"Starting 'work' process. File folder: {file_folder}")
if file_folder:
logging.info(f"Reading context from folder: {file_folder}")
context = self.read_folder(file_folder)
if context:
user_prompt = f"{user_prompt.strip()}\n\n--- 参考资料 ---\n{context.strip()}"
else:
logging.warning(f"Folder {file_folder} provided but no content read.")
time_start = time.time()
result = self.generate_text_stream(system_prompt, user_prompt, temperature, top_p, presence_penalty)
# result, tokens = self.generate_text(system_prompt, user_prompt, temperature, top_p, presence_penalty)
time_end = time.time()
time_cost = time_end - time_start
tokens = len(result) * 1.3
logging.info(f"'work' completed in {time_cost:.2f}s. Estimated tokens: {tokens}")
return result, tokens, time_cost
def close(self):
try:
logging.info("Closing AI Agent (client resources will be garbage collected).")
self.client = None
except Exception as e:
logging.error(f"Error during AI Agent close: {e}")
# --- Streaming Methods ---
def generate_text_stream(self, system_prompt, user_prompt, temperature=0.7, top_p=0.9, presence_penalty=0.0):
"""
生成文本流,但返回的是完整的响应而非生成器。
此方法适用于简单的API调用不需要实时处理响应的场景。
Args:
system_prompt: 系统提示词
user_prompt: 用户提示词
temperature: 温度参数
top_p: 核采样参数
presence_penalty: 存在惩罚参数
Returns:
str: 生成的完整文本响应
Raises:
Exception: 如果API调用在所有重试后仍然失败
"""
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
]
logging.info(f"Generating text stream using model: {self.model_name}")
retries = 0
backoff_time = INITIAL_BACKOFF
last_exception = None
global_timeout = self.timeout * 2 # 设置全局超时是API超时的两倍
partial_response = "" # 用于存储部分响应
while retries < self.max_retries:
try:
logging.debug(f"Attempt {retries + 1}/{self.max_retries} to generate text stream.")
client = OpenAI(
api_key=self.api,
base_url=self.base_url,
timeout=self.timeout,
max_retries=0 # 禁用内置重试,使用我们自己的重试逻辑
)
stream = client.chat.completions.create(
model=self.model_name,
messages=messages,
temperature=temperature,
top_p=top_p,
presence_penalty=presence_penalty,
stream=True,
)
response_text = ""
stream_start_time = time.time()
last_chunk_time = time.time()
received_any_chunk = False # 标记是否收到过任何块
try:
for chunk in stream:
current_time = time.time()
# 检查流块超时:只有在已经开始接收数据后才检查
if received_any_chunk and current_time - last_chunk_time > self.stream_chunk_timeout:
logging.warning(f"Stream chunk timeout: No new chunk received for {self.stream_chunk_timeout} seconds after previous chunk.")
if response_text:
response_text += "\n\n[注意: 流式传输中断,内容可能不完整]"
partial_response = response_text
return partial_response
raise Timeout(f"Stream stalled: No new chunk received for {self.stream_chunk_timeout} seconds after previous chunk.")
# 检查首次响应超时:如果很长时间没收到第一个块
if not received_any_chunk and current_time - stream_start_time > self.timeout:
logging.warning(f"Initial response timeout: No chunk received in {self.timeout} seconds since stream started.")
raise Timeout(f"No initial response received for {self.timeout} seconds.")
# 检查全局超时(保留作为安全措施)
if current_time - stream_start_time > global_timeout:
logging.warning(f"Global timeout reached after {global_timeout} seconds.")
if response_text:
response_text += "\n\n[注意: 由于全局超时,内容可能不完整]"
partial_response = response_text
return partial_response
else:
logging.error("Global timeout with no content received.")
raise Timeout(f"Global timeout after {global_timeout} seconds with no content.")
# 处理接收到的数据块
if chunk.choices and chunk.choices[0].delta and chunk.choices[0].delta.content:
content = chunk.choices[0].delta.content
response_text += content
received_any_chunk = True # 标记已收到块
last_chunk_time = current_time # 更新最后块时间戳
logging.info("Stream completed successfully.")
return response_text # 成功完成
except Timeout as e:
logging.warning(f"Stream timeout: {e}. Retrying if possible ({retries + 1}/{self.max_retries}).")
if response_text:
partial_response = response_text
last_exception = e
except (APITimeoutError, APIConnectionError, RateLimitError) as e:
logging.warning(f"API error during streaming: {type(e).__name__} - {e}. Retrying if possible.")
if response_text:
partial_response = response_text
last_exception = e
except Exception as e:
logging.error(f"Unexpected error during streaming: {traceback.format_exc()}")
if response_text:
partial_response = response_text
response_text += f"\n\n[注意: 由于错误中断,内容可能不完整: {str(e)}]"
return response_text
last_exception = e
# 执行重试逻辑
retries += 1
if retries < self.max_retries:
logging.info(f"Retrying stream in {backoff_time} seconds...")
time.sleep(backoff_time + random.uniform(0, 1)) # 添加随机抖动
backoff_time = min(backoff_time * 2, MAX_BACKOFF)
else:
logging.error(f"Stream generation failed after {self.max_retries} retries.")
# 如果已经获取了部分响应,返回并标记不完整
if partial_response:
logging.info(f"Returning partial response of length {len(partial_response)}")
return partial_response + "\n\n[注意: 达到最大重试次数,内容可能不完整]"
# 如果没有获取到任何内容,抛出异常
raise last_exception or Exception("Stream generation failed after max retries with no content.")
except (Timeout, APITimeoutError, APIConnectionError, RateLimitError) as e:
retries += 1
last_exception = e
logging.warning(f"Attempt {retries}/{self.max_retries} failed: {type(e).__name__} - {e}")
if retries < self.max_retries:
logging.info(f"Retrying in {backoff_time} seconds...")
time.sleep(backoff_time + random.uniform(0, 1))
backoff_time = min(backoff_time * 2, MAX_BACKOFF)
else:
logging.error(f"API call failed after {self.max_retries} retries.")
# 如果已经获取了部分响应,返回并标记不完整
if partial_response:
return partial_response + f"\n\n[注意: API调用失败内容可能不完整: {str(e)}]"
return f"[生成内容失败: {str(last_exception)}]"
except Exception as e:
logging.error(f"Unexpected error setting up stream: {traceback.format_exc()}")
# 如果已经获取了部分响应,返回并标记不完整
if partial_response:
return partial_response + f"\n\n[注意: 意外错误,内容可能不完整: {str(e)}]"
return f"[生成内容失败: {str(e)}]"
# 作为安全措施,虽然通常不应该到达此处
logging.error("Exited stream generation loop unexpectedly.")
# 如果已经获取了部分响应,返回并标记不完整
if partial_response:
return partial_response + "\n\n[注意: 意外退出流处理,内容可能不完整]"
return "[生成内容失败: 超过最大重试次数]"
def generate_text_stream_with_callback(self, system_prompt, user_prompt, callback,
temperature=0.7, top_p=0.9, presence_penalty=0.0, accumulate=False):
"""
Generates text based on prompts using a streaming connection with a callback.
Args:
system_prompt: The system prompt for the AI.
user_prompt: The user prompt for the AI.
callback: Function to call with each chunk of text (callback(chunk, accumulated_response))
temperature: Sampling temperature, defaults to 0.7.
top_p: Nucleus sampling parameter, defaults to 0.9.
presence_penalty: Presence penalty parameter, defaults to 0.0.
accumulate: If True, pass the accumulated response to the callback function, defaults to False.
Returns:
str: The complete generated text.
Raises:
Exception: If the API call fails after all retries.
"""
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
]
logging.info(f"Generating text stream with callback, model: {self.model_name}")
retries = 0
backoff_time = INITIAL_BACKOFF
last_exception = None
full_response = ""
global_timeout = self.timeout * 2 # 设置全局超时是API超时的两倍
while retries < self.max_retries:
try:
logging.debug(f"Attempt {retries + 1}/{self.max_retries} to generate text stream with callback.")
stream = self.client.chat.completions.create(
model=self.model_name,
messages=messages,
temperature=temperature,
top_p=top_p,
presence_penalty=presence_penalty,
stream=True,
timeout=self.timeout
)
chunk_iterator = iter(stream)
last_chunk_time = time.time()
start_time = time.time() # 记录开始时间用于全局超时检查
received_any_chunk = False # 标记是否收到过任何块
while True:
try:
current_time = time.time()
# 检查流块超时:只有在已经开始接收数据后才检查
if received_any_chunk and current_time - last_chunk_time > self.stream_chunk_timeout:
logging.warning(f"Stream chunk timeout: No new chunk received for {self.stream_chunk_timeout} seconds after previous chunk.")
if full_response:
logging.info(f"Returning partial response of length {len(full_response)} due to chunk timeout.")
# 通知回调函数超时情况
timeout_msg = "\n\n[注意: 流式传输中断,内容可能不完整]"
callback(timeout_msg, full_response + timeout_msg if accumulate else None)
return full_response + timeout_msg
raise Timeout(f"Stream stalled: No new chunk received for {self.stream_chunk_timeout} seconds after previous chunk.")
# 检查首次响应超时:如果很长时间没收到第一个块
if not received_any_chunk and current_time - start_time > self.timeout:
logging.warning(f"Initial response timeout: No chunk received in {self.timeout} seconds since stream started.")
raise Timeout(f"No initial response received for {self.timeout} seconds.")
# 检查全局超时(保留作为安全措施)
if current_time - start_time > global_timeout:
logging.warning(f"Global timeout reached after {global_timeout} seconds.")
if full_response:
logging.info(f"Returning partial response of length {len(full_response)} due to global timeout.")
# 通知回调函数超时情况
timeout_msg = "\n\n[注意: 由于全局超时,内容可能不完整]"
callback(timeout_msg, full_response + timeout_msg if accumulate else None)
return full_response + timeout_msg
else:
logging.error("Global timeout with no content received.")
raise Timeout(f"Global timeout after {global_timeout} seconds with no content.")
# 获取下一个数据块
chunk = next(chunk_iterator)
# 处理接收到的数据块
if chunk.choices and chunk.choices[0].delta and chunk.choices[0].delta.content:
content = chunk.choices[0].delta.content
full_response += content
callback(content, full_response if accumulate else None)
received_any_chunk = True # 标记已收到块
last_chunk_time = current_time # 更新最后块时间戳
elif chunk.choices and chunk.choices[0].finish_reason == 'stop':
logging.info("Stream with callback finished normally.")
return full_response
except StopIteration:
logging.info("Stream iterator with callback exhausted normally.")
return full_response
except Timeout as e:
logging.warning(f"Stream timeout: {e}. Retrying if possible ({retries + 1}/{self.max_retries}).")
last_exception = e
break
except (APITimeoutError, APIConnectionError, RateLimitError) as e:
logging.warning(f"API error during streaming with callback: {type(e).__name__} - {e}. Retrying if possible.")
last_exception = e
break
except Exception as e:
logging.error(f"Unexpected error during streaming with callback: {traceback.format_exc()}")
# 如果已经获取了部分响应,可以返回,否则重试
if full_response:
error_msg = f"\n\n[注意: 由于错误中断,内容可能不完整: {str(e)}]"
logging.info(f"Returning partial response of length {len(full_response)} due to error: {type(e).__name__}")
callback(error_msg, full_response + error_msg if accumulate else None)
return full_response + error_msg
last_exception = e
break
# If we broke from the inner loop due to an error that needs retry
retries += 1
if retries < self.max_retries:
logging.info(f"Retrying stream with callback in {backoff_time} seconds...")
time.sleep(backoff_time + random.uniform(0, 1))
backoff_time = min(backoff_time * 2, MAX_BACKOFF)
else:
logging.error(f"Stream generation with callback failed after {self.max_retries} retries.")
# 如果已经获取了部分响应,可以返回,否则引发异常
if full_response:
error_msg = "\n\n[注意: 达到最大重试次数,内容可能不完整]"
logging.info(f"Returning partial response of length {len(full_response)} after max retries.")
callback(error_msg, full_response + error_msg if accumulate else None)
return full_response + error_msg
raise last_exception or Exception("Stream generation with callback failed after max retries.")
except (Timeout, APITimeoutError, APIConnectionError, RateLimitError) as e:
retries += 1
last_exception = e
logging.warning(f"Attempt {retries}/{self.max_retries} failed: {type(e).__name__} - {e}")
if retries < self.max_retries:
logging.info(f"Retrying in {backoff_time} seconds...")
time.sleep(backoff_time + random.uniform(0, 1))
backoff_time = min(backoff_time * 2, MAX_BACKOFF)
else:
logging.error(f"API call with callback failed after {self.max_retries} retries.")
# 如果已经获取了部分响应,可以返回,否则返回错误消息
if full_response:
error_msg = f"\n\n[注意: API调用失败内容可能不完整: {str(e)}]"
callback(error_msg, full_response + error_msg if accumulate else None)
return full_response + error_msg
error_text = f"[生成内容失败: {str(last_exception)}]"
callback(error_text, error_text if accumulate else None)
return error_text
except Exception as e:
logging.error(f"Unexpected error setting up stream with callback: {traceback.format_exc()}")
# 如果已经获取了部分响应,可以返回,否则返回错误消息
if full_response:
error_msg = f"\n\n[注意: 意外错误,内容可能不完整: {str(e)}]"
callback(error_msg, full_response + error_msg if accumulate else None)
return full_response + error_msg
error_text = f"[生成内容失败: {str(e)}]"
callback(error_text, error_text if accumulate else None)
return error_text
# Should not be reached if logic is correct, but as a safeguard:
logging.error("Exited stream with callback generation loop unexpectedly.")
# 如果已经获取了部分响应,可以返回,否则返回错误消息
if full_response:
error_msg = "\n\n[注意: 意外退出流处理,内容可能不完整]"
callback(error_msg, full_response + error_msg if accumulate else None)
return full_response + error_msg
error_text = "[生成内容失败: 超过最大重试次数]"
callback(error_text, error_text if accumulate else None)
return error_text
async def async_generate_text_stream(self, system_prompt, user_prompt, temperature=0.7, top_p=0.9, presence_penalty=0.0):
"""
异步生成文本流,返回一个异步迭代器
Args:
system_prompt: 系统提示词
user_prompt: 用户提示词
temperature: 温度参数
top_p: 核采样参数
presence_penalty: 存在惩罚参数
Returns:
AsyncGenerator: 异步生成器,生成文本块
Raises:
Exception: 如果API调用在所有重试后仍然失败
"""
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
]
logging.info(f"Async generating text stream using model: {self.model_name}")
retries = 0
backoff_time = INITIAL_BACKOFF
last_exception = None
full_response = ""
global_timeout = self.timeout * 2 # 设置全局超时是API超时的两倍
while retries < self.max_retries:
try:
logging.debug(f"Attempt {retries + 1}/{self.max_retries} to async generate text stream.")
aclient = OpenAI(
api_key=self.api,
base_url=self.base_url,
timeout=self.timeout,
max_retries=0 # 禁用内置重试,使用我们自己的重试逻辑
)
stream = aclient.chat.completions.create(
model=self.model_name,
messages=messages,
temperature=temperature,
top_p=top_p,
presence_penalty=presence_penalty,
stream=True,
)
stream_start_time = time.time()
last_chunk_time = time.time()
received_any_chunk = False # 标记是否收到过任何块
try:
async for chunk in stream:
current_time = time.time()
# 检查流块超时:只有在已经开始接收数据后才检查
if received_any_chunk and current_time - last_chunk_time > self.stream_chunk_timeout:
logging.warning(f"Async stream chunk timeout: No new chunk received for {self.stream_chunk_timeout} seconds after previous chunk.")
if full_response:
timeout_msg = "\n\n[注意: 流式传输中断,内容可能不完整]"
yield timeout_msg
return # 结束生成器
raise Timeout(f"Async stream stalled: No new chunk received for {self.stream_chunk_timeout} seconds after previous chunk.")
# 检查首次响应超时:如果很长时间没收到第一个块
if not received_any_chunk and current_time - stream_start_time > self.timeout:
logging.warning(f"Async initial response timeout: No chunk received in {self.timeout} seconds since stream started.")
raise Timeout(f"No initial response received for {self.timeout} seconds.")
# 检查全局超时(保留作为安全措施)
if current_time - stream_start_time > global_timeout:
logging.warning(f"Async global timeout reached after {global_timeout} seconds.")
if full_response:
timeout_msg = "\n\n[注意: 由于全局超时,内容可能不完整]"
yield timeout_msg
return # 结束生成器
else:
logging.error("Async global timeout with no content received.")
raise Timeout(f"Async global timeout after {global_timeout} seconds with no content.")
# 处理接收到的数据块
if chunk.choices and chunk.choices[0].delta and chunk.choices[0].delta.content:
content = chunk.choices[0].delta.content
full_response += content
yield content
received_any_chunk = True # 标记已收到块
last_chunk_time = current_time # 更新最后块时间戳
logging.info("Async stream completed normally.")
return
except Timeout as e:
logging.warning(f"Async stream timeout: {e}. Retrying if possible ({retries + 1}/{self.max_retries}).")
last_exception = e
except (APITimeoutError, APIConnectionError, RateLimitError) as e:
logging.warning(f"Async API error during streaming: {type(e).__name__} - {e}. Retrying if possible.")
last_exception = e
except Exception as e:
logging.error(f"Async unexpected error during streaming: {traceback.format_exc()}")
# 如果已经获取了部分响应,可以返回错误消息后结束
if full_response:
error_msg = f"\n\n[注意: 由于错误中断,内容可能不完整: {str(e)}]"
logging.info(f"Yielding error message after partial response of length {len(full_response)}")
yield error_msg
return # 结束生成器
last_exception = e
# 执行重试逻辑
retries += 1
if retries < self.max_retries:
logging.info(f"Async retrying stream in {backoff_time} seconds...")
await asyncio.sleep(backoff_time + random.uniform(0, 1)) # 添加随机抖动
backoff_time = min(backoff_time * 2, MAX_BACKOFF)
else:
logging.error(f"Async stream generation failed after {self.max_retries} retries.")
# 如果已经获取了部分响应,返回错误消息后结束
if full_response:
error_msg = "\n\n[注意: 达到最大重试次数,内容可能不完整]"
logging.info(f"Yielding error message after partial response of length {len(full_response)}")
yield error_msg
return # 结束生成器
# 如果没有获取到任何内容,抛出异常
raise last_exception or Exception("Async stream generation failed after max retries with no content.")
except (Timeout, APITimeoutError, APIConnectionError, RateLimitError) as e:
retries += 1
last_exception = e
logging.warning(f"Async attempt {retries}/{self.max_retries} failed: {type(e).__name__} - {e}")
if retries < self.max_retries:
logging.info(f"Async retrying in {backoff_time} seconds...")
await asyncio.sleep(backoff_time + random.uniform(0, 1))
backoff_time = min(backoff_time * 2, MAX_BACKOFF)
else:
logging.error(f"Async API call failed after {self.max_retries} retries.")
# 如果已经获取了部分响应,返回错误消息后结束
if full_response:
error_msg = f"\n\n[注意: API调用失败内容可能不完整: {str(e)}]"
yield error_msg
return # 结束生成器
error_text = f"[生成内容失败: {str(last_exception)}]"
yield error_text
return # 结束生成器
except Exception as e:
logging.error(f"Async unexpected error setting up stream: {traceback.format_exc()}")
# 如果已经获取了部分响应,返回错误消息后结束
if full_response:
error_msg = f"\n\n[注意: 意外错误,内容可能不完整: {str(e)}]"
yield error_msg
return # 结束生成器
error_text = f"[生成内容失败: {str(e)}]"
yield error_text
return # 结束生成器
# 作为安全措施,虽然通常不应该到达此处
logging.error("Async exited stream generation loop unexpectedly.")
# 如果已经获取了部分响应,返回错误消息后结束
if full_response:
error_msg = "\n\n[注意: 意外退出流处理,内容可能不完整]"
yield error_msg
return # 结束生成器
error_text = "[生成内容失败: 超过最大重试次数]"
yield error_text
async def async_work_stream(self, system_prompt, user_prompt, file_folder, temperature, top_p, presence_penalty):
"""异步完整工作流程:读取文件夹(如果提供),然后生成文本流"""
logging.info(f"Starting async 'work_stream' process. File folder: {file_folder}")
if file_folder:
logging.info(f"Reading context from folder: {file_folder}")
context = self.read_folder(file_folder)
if context:
user_prompt = f"{user_prompt.strip()}\n\n--- 参考资料 ---\n{context.strip()}"
else:
logging.warning(f"Folder {file_folder} provided but no content read.")
logging.info("Calling async_generate_text_stream...")
return self.async_generate_text_stream(system_prompt, user_prompt, temperature, top_p, presence_penalty)
def work_stream(self, system_prompt, user_prompt, file_folder, temperature, top_p, presence_penalty):
"""完整的工作流程(流式):读取文件夹(如果提供),然后生成文本流"""
logging.info(f"Starting 'work_stream' process. File folder: {file_folder}")
if file_folder:
logging.info(f"Reading context from folder: {file_folder}")
context = self.read_folder(file_folder)
if context:
user_prompt = f"{user_prompt.strip()}\n\n--- 参考资料 ---\n{context.strip()}"
else:
logging.warning(f"Folder {file_folder} provided but no content read.")
full_response = self.generate_text_stream(system_prompt, user_prompt, temperature, top_p, presence_penalty)
return full_response
# --- End Streaming Methods ---