import os
import time
import random
import traceback
import logging
import asyncio

import tiktoken
from openai import (
    OpenAI,
    AsyncOpenAI,
    APITimeoutError,
    APIConnectionError,
    RateLimitError,
    APIStatusError,
)

# Configure basic logging for this module (or rely on root logger config)
# logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# logger = logging.getLogger(__name__)  # Alternative: use a named logger

# Constants
MAX_RETRIES = 3            # Maximum number of retries for API calls
INITIAL_BACKOFF = 1        # Initial backoff time in seconds
MAX_BACKOFF = 16           # Maximum backoff time in seconds
STREAM_CHUNK_TIMEOUT = 10  # Timeout in seconds for receiving a chunk in a stream

# Custom timeout exception
class Timeout(Exception):
    """Raised when a stream chunk timeout occurs."""

    def __init__(self, message="Stream chunk timeout occurred."):
        self.message = message
        super().__init__(self.message)

    def __str__(self):
        return self.message


class AI_Agent():
    """Agent class responsible for interacting with an AI model to generate text."""

    def __init__(self, base_url, model_name, api, timeout=30, max_retries=3, stream_chunk_timeout=10):
        """
        Initialize the AI agent.

        Args:
            base_url (str): Base URL of the model service, or a preset name ('ali', 'kimi', 'doubao', 'deepseek', 'vllm').
            model_name (str): Name of the model to use.
            api (str): API key.
            timeout (int, optional): Timeout in seconds for a single API request. Defaults to 30.
            max_retries (int, optional): Maximum number of retries for failed API requests. Defaults to 3.
            stream_chunk_timeout (int, optional): Maximum wait in seconds between two chunks of a streaming response. Defaults to 10.
        """
        logging.info("Initializing AI Agent")
        self.url_list = {
            "ali": "https://dashscope.aliyuncs.com/compatible-mode/v1",
            "kimi": "https://api.moonshot.cn/v1",
            "doubao": "https://ark.cn-beijing.volces.com/api/v3/",
            "deepseek": "https://api.deepseek.com",
            "vllm": "http://localhost:8000/v1",
        }

        self.base_url = self.url_list.get(base_url, base_url)
        self.api = api
        self.model_name = model_name
        self.timeout = timeout
        self.max_retries = max_retries
        self.stream_chunk_timeout = stream_chunk_timeout

        logging.info(f"AI Agent Settings: base_url={self.base_url}, model={self.model_name}, timeout={self.timeout}s, max_retries={self.max_retries}, stream_chunk_timeout={self.stream_chunk_timeout}s")
        self.client = OpenAI(
            api_key=self.api,
            base_url=self.base_url,
            timeout=self.timeout
        )

        # try:
        #     self.encoding = tiktoken.encoding_for_model(self.model_name)
        # except KeyError:
        #     logging.warning(f"Encoding for model '{self.model_name}' not found. Using 'cl100k_base' encoding.")
        #     self.encoding = tiktoken.get_encoding("cl100k_base")
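
    # Base-URL resolution sketch: AI_Agent("deepseek", ...) resolves via url_list to
    # https://api.deepseek.com, while AI_Agent("https://example.com/v1", ...) keeps
    # the URL as given, because url_list.get falls back to the raw base_url.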

    def generate_text(self, system_prompt, user_prompt, temperature, top_p, presence_penalty):
        """Generate text (non-streaming) and return the full response plus an estimated token count."""
        logging.info("Starting text generation process...")
        # logging.debug(f"System Prompt (first 100): {system_prompt[:100]}...")
        # logging.debug(f"User Prompt (first 100): {user_prompt[:100]}...")  # Avoid logging potentially huge prompts
        logging.info(f"Generation Params: temp={temperature}, top_p={top_p}, presence_penalty={presence_penalty}")

        retry_count = 0
        max_retry_wait = 10  # Max wait time between retries
        full_response = ""

        while retry_count <= self.max_retries:
            call_start_time = None  # Initialize start time
            try:
                user_prompt_size = len(user_prompt)
                logging.info(f"Attempt {retry_count + 1}/{self.max_retries + 1}: Preparing API request. User prompt size: {user_prompt_size} chars.")
                call_start_time = time.time()

                response = self.client.chat.completions.create(
                    model=self.model_name,
                    messages=[{"role": "system", "content": system_prompt},
                              {"role": "user", "content": user_prompt}],
                    temperature=temperature,
                    top_p=top_p,
                    presence_penalty=presence_penalty,
                    stream=False,  # Must be False for this non-streaming method
                    max_tokens=8192,
                    timeout=self.timeout,
                    extra_body={
                        "repetition_penalty": 1.05,
                    },
                )

                call_end_time = time.time()
                logging.info(f"Attempt {retry_count + 1}/{self.max_retries + 1}: API request returned successfully after {call_end_time - call_start_time:.2f} seconds.")

                if response.choices and response.choices[0].message:
                    full_response = response.choices[0].message.content
                    logging.info(f"Received successful response. Content length: {len(full_response)} chars.")
                    break  # Success, exit retry loop
                else:
                    logging.warning("API response structure unexpected or empty content.")
                    full_response = "[Error: Empty or invalid response structure]"
                    retry_count += 1  # Treat an empty/invalid structure as retryable
                    if retry_count <= self.max_retries:
                        wait_time = min(2 ** retry_count + random.random(), max_retry_wait)
                        logging.warning(f"Retrying due to unexpected response structure ({retry_count}/{self.max_retries}), waiting {wait_time:.2f}s...")
                        time.sleep(wait_time)
                    continue

            except (APITimeoutError, APIConnectionError, RateLimitError, APIStatusError) as e:
                if call_start_time:
                    call_fail_time = time.time()
                    logging.warning(f"Attempt {retry_count + 1}/{self.max_retries + 1}: API call failed/timed out after {call_fail_time - call_start_time:.2f} seconds.")
                else:
                    logging.warning(f"Attempt {retry_count + 1}/{self.max_retries + 1}: API call failed before or during initiation.")

                logging.warning(f"API Error occurred: {e}")
                should_retry = False
                if isinstance(e, (APITimeoutError, APIConnectionError, RateLimitError)):
                    should_retry = True
                elif isinstance(e, APIStatusError) and e.status_code >= 500:
                    should_retry = True

                if should_retry:
                    retry_count += 1
                    if retry_count <= self.max_retries:
                        wait_time = min(2 ** retry_count + random.random(), max_retry_wait)
                        logging.warning(f"Retrying API call ({retry_count}/{self.max_retries}) after error, waiting {wait_time:.2f}s...")
                        time.sleep(wait_time)
                    else:
                        logging.error(f"Max retries ({self.max_retries}) reached for API errors. Aborting.")
                        return "请求失败,无法生成内容。", 0
                else:
                    logging.error(f"Non-retriable API error: {e}. Aborting.")
                    return "请求失败,发生不可重试错误。", 0

            except Exception:
                logging.exception("Unexpected error during API call setup/execution:")
                retry_count += 1
                if retry_count <= self.max_retries:
                    wait_time = min(2 ** retry_count + random.random(), max_retry_wait)
                    logging.warning(f"Retrying API call ({retry_count}/{self.max_retries}) after unexpected error, waiting {wait_time:.2f}s...")
                    time.sleep(wait_time)
                else:
                    logging.error(f"Max retries ({self.max_retries}) reached after unexpected errors. Aborting.")
                    return "请求失败,发生未知错误。", 0

        logging.info("Text generation completed.")
        # Rough estimate: whitespace-delimited word count scaled by 1.3; not a real tokenizer count.
        estimated_tokens = len(full_response.split()) * 1.3
        return full_response, estimated_tokens

    def read_folder(self, file_folder):
        """Read and concatenate the contents of every file in the given folder."""
        if not os.path.exists(file_folder):
            logging.warning(f"Referenced folder does not exist: {file_folder}")
            return ""
        context = ""
        try:
            for file in os.listdir(file_folder):
                file_path = os.path.join(file_folder, file)
                if os.path.isfile(file_path):
                    try:
                        with open(file_path, "r", encoding="utf-8") as f:
                            context += f"文件名: {file}\n"
                            context += f.read()
                            context += "\n\n"
                    except Exception as read_err:
                        logging.error(f"Failed to read file {file_path}: {read_err}")
        except Exception as list_err:
            logging.error(f"Failed to list directory {file_folder}: {list_err}")
        return context

    def work(self, system_prompt, user_prompt, file_folder, temperature, top_p, presence_penalty):
        """Full workflow: read folder context (if provided), generate text, and return the result."""
        logging.info(f"Starting 'work' process. File folder: {file_folder}")
        if file_folder:
            logging.info(f"Reading context from folder: {file_folder}")
            context = self.read_folder(file_folder)
            if context:
                user_prompt = f"{user_prompt.strip()}\n\n--- 参考资料 ---\n{context.strip()}"
            else:
                logging.warning(f"Folder {file_folder} provided but no content read.")
        time_start = time.time()
        result = self.generate_text_stream(system_prompt, user_prompt, temperature, top_p, presence_penalty)
        # result, tokens = self.generate_text(system_prompt, user_prompt, temperature, top_p, presence_penalty)
        time_end = time.time()
        time_cost = time_end - time_start
        # Rough estimate: character count scaled by 1.3, not a real tokenizer count.
        tokens = len(result) * 1.3
        logging.info(f"'work' completed in {time_cost:.2f}s. Estimated tokens: {tokens}")
        return result, tokens, time_cost
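
    # Usage sketch (hypothetical values): with the 'vllm' preset the agent talks to
    # http://localhost:8000/v1; the model name and API key below are placeholders.
    #
    #   agent = AI_Agent(base_url="vllm", model_name="my-local-model", api="EMPTY")
    #   text, est_tokens, seconds = agent.work(
    #       "You are a helpful assistant.", "Summarize the reference material.",
    #       file_folder="./refs", temperature=0.7, top_p=0.9, presence_penalty=0.0)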

    def close(self):
        """Release the client reference; underlying resources are garbage collected."""
        try:
            logging.info("Closing AI Agent (client resources will be garbage collected).")
            self.client = None
        except Exception as e:
            logging.error(f"Error during AI Agent close: {e}")

    # --- Streaming Methods ---

    def generate_text_stream(self, system_prompt, user_prompt, temperature=0.7, top_p=0.9, presence_penalty=0.0):
        """
        Consume a text stream but return the complete response rather than a generator.

        Suitable for simple API calls where the response does not need to be
        processed in real time.

        Args:
            system_prompt: The system prompt.
            user_prompt: The user prompt.
            temperature: Sampling temperature.
            top_p: Nucleus sampling parameter.
            presence_penalty: Presence penalty parameter.

        Returns:
            str: The complete generated text.

        Raises:
            Exception: If the API call still fails after all retries.
        """
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]
        logging.info(f"Generating text stream using model: {self.model_name}")
        retries = 0
        backoff_time = INITIAL_BACKOFF
        last_exception = None
        global_timeout = self.timeout * 2  # Global timeout is twice the per-request API timeout
        partial_response = ""  # Holds any partial response accumulated before a failure

        while retries < self.max_retries:
            try:
                logging.debug(f"Attempt {retries + 1}/{self.max_retries} to generate text stream.")
                client = OpenAI(
                    api_key=self.api,
                    base_url=self.base_url,
                    timeout=self.timeout,
                    max_retries=0  # Disable built-in retries; we implement our own retry logic
                )
                stream = client.chat.completions.create(
                    model=self.model_name,
                    messages=messages,
                    temperature=temperature,
                    top_p=top_p,
                    presence_penalty=presence_penalty,
                    stream=True,
                )

                response_text = ""
                stream_start_time = time.time()
                last_chunk_time = time.time()
                received_any_chunk = False  # Whether any chunk has been received yet

                try:
                    # Note: these timeout checks only run when the iterator yields a
                    # chunk; a fully stalled connection is ultimately bounded by the
                    # client-level `timeout`, not by stream_chunk_timeout alone.
                    for chunk in stream:
                        current_time = time.time()
                        # Chunk timeout: only checked once data has started arriving
                        if received_any_chunk and current_time - last_chunk_time > self.stream_chunk_timeout:
                            logging.warning(f"Stream chunk timeout: No new chunk received for {self.stream_chunk_timeout} seconds after previous chunk.")
                            if response_text:
                                response_text += "\n\n[注意: 流式传输中断,内容可能不完整]"
                                partial_response = response_text
                                return partial_response
                            raise Timeout(f"Stream stalled: No new chunk received for {self.stream_chunk_timeout} seconds after previous chunk.")
                        # Initial-response timeout: first chunk took too long
                        if not received_any_chunk and current_time - stream_start_time > self.timeout:
                            logging.warning(f"Initial response timeout: No chunk received in {self.timeout} seconds since stream started.")
                            raise Timeout(f"No initial response received for {self.timeout} seconds.")
                        # Global timeout (kept as a safety measure)
                        if current_time - stream_start_time > global_timeout:
                            logging.warning(f"Global timeout reached after {global_timeout} seconds.")
                            if response_text:
                                response_text += "\n\n[注意: 由于全局超时,内容可能不完整]"
                                partial_response = response_text
                                return partial_response
                            else:
                                logging.error("Global timeout with no content received.")
                                raise Timeout(f"Global timeout after {global_timeout} seconds with no content.")
                        # Process the received chunk
                        if chunk.choices and chunk.choices[0].delta and chunk.choices[0].delta.content:
                            content = chunk.choices[0].delta.content
                            response_text += content
                            received_any_chunk = True  # Mark that a chunk has been received
                            last_chunk_time = current_time  # Update the last-chunk timestamp

                    logging.info("Stream completed successfully.")
                    return response_text  # Completed successfully
                except Timeout as e:
                    logging.warning(f"Stream timeout: {e}. Retrying if possible ({retries + 1}/{self.max_retries}).")
                    if response_text:
                        partial_response = response_text
                    last_exception = e
                except (APITimeoutError, APIConnectionError, RateLimitError) as e:
                    logging.warning(f"API error during streaming: {type(e).__name__} - {e}. Retrying if possible.")
                    if response_text:
                        partial_response = response_text
                    last_exception = e
                except Exception as e:
                    logging.error(f"Unexpected error during streaming: {traceback.format_exc()}")
                    if response_text:
                        partial_response = response_text
                        response_text += f"\n\n[注意: 由于错误中断,内容可能不完整: {str(e)}]"
                        return response_text
                    last_exception = e

                # Retry logic
                retries += 1
                if retries < self.max_retries:
                    logging.info(f"Retrying stream in {backoff_time} seconds...")
                    time.sleep(backoff_time + random.uniform(0, 1))  # Add random jitter
                    backoff_time = min(backoff_time * 2, MAX_BACKOFF)
                else:
                    logging.error(f"Stream generation failed after {self.max_retries} retries.")
                    # Return any partial response, marked as incomplete
                    if partial_response:
                        logging.info(f"Returning partial response of length {len(partial_response)}")
                        return partial_response + "\n\n[注意: 达到最大重试次数,内容可能不完整]"
                    # Nothing received at all: raise
                    raise last_exception or Exception("Stream generation failed after max retries with no content.")

            except (Timeout, APITimeoutError, APIConnectionError, RateLimitError) as e:
                retries += 1
                last_exception = e
                logging.warning(f"Attempt {retries}/{self.max_retries} failed: {type(e).__name__} - {e}")
                if retries < self.max_retries:
                    logging.info(f"Retrying in {backoff_time} seconds...")
                    time.sleep(backoff_time + random.uniform(0, 1))
                    backoff_time = min(backoff_time * 2, MAX_BACKOFF)
                else:
                    logging.error(f"API call failed after {self.max_retries} retries.")
                    # Return any partial response, marked as incomplete
                    if partial_response:
                        return partial_response + f"\n\n[注意: API调用失败内容可能不完整: {str(e)}]"
                    return f"[生成内容失败: {str(last_exception)}]"
            except Exception as e:
                logging.error(f"Unexpected error setting up stream: {traceback.format_exc()}")
                # Return any partial response, marked as incomplete
                if partial_response:
                    return partial_response + f"\n\n[注意: 意外错误,内容可能不完整: {str(e)}]"
                return f"[生成内容失败: {str(e)}]"

        # Safety net; normally unreachable
        logging.error("Exited stream generation loop unexpectedly.")
        if partial_response:
            return partial_response + "\n\n[注意: 意外退出流处理,内容可能不完整]"
        return "[生成内容失败: 超过最大重试次数]"

    def generate_text_stream_with_callback(self, system_prompt, user_prompt, callback,
                                           temperature=0.7, top_p=0.9, presence_penalty=0.0, accumulate=False):
        """
        Generate text over a streaming connection, invoking a callback for each chunk.

        Args:
            system_prompt: The system prompt for the AI.
            user_prompt: The user prompt for the AI.
            callback: Function called with each chunk of text as callback(chunk, accumulated_response).
            temperature: Sampling temperature. Defaults to 0.7.
            top_p: Nucleus sampling parameter. Defaults to 0.9.
            presence_penalty: Presence penalty parameter. Defaults to 0.0.
            accumulate: If True, pass the accumulated response as the callback's second
                argument (otherwise None is passed). Defaults to False.

        Returns:
            str: The complete generated text.

        Raises:
            Exception: If the API call fails after all retries.
        """
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]
        logging.info(f"Generating text stream with callback, model: {self.model_name}")
        retries = 0
        backoff_time = INITIAL_BACKOFF
        last_exception = None
        full_response = ""
        global_timeout = self.timeout * 2  # Global timeout is twice the per-request API timeout

        while retries < self.max_retries:
            try:
                logging.debug(f"Attempt {retries + 1}/{self.max_retries} to generate text stream with callback.")
                stream = self.client.chat.completions.create(
                    model=self.model_name,
                    messages=messages,
                    temperature=temperature,
                    top_p=top_p,
                    presence_penalty=presence_penalty,
                    stream=True,
                    timeout=self.timeout
                )
                chunk_iterator = iter(stream)
                last_chunk_time = time.time()
                start_time = time.time()  # Start time for the global timeout check
                received_any_chunk = False  # Whether any chunk has been received yet

                while True:
                    try:
                        current_time = time.time()
                        # Chunk timeout: only checked once data has started arriving.
                        # (The check runs between next() calls, so a blocked next() is
                        # ultimately bounded by the client-level timeout.)
                        if received_any_chunk and current_time - last_chunk_time > self.stream_chunk_timeout:
                            logging.warning(f"Stream chunk timeout: No new chunk received for {self.stream_chunk_timeout} seconds after previous chunk.")
                            if full_response:
                                logging.info(f"Returning partial response of length {len(full_response)} due to chunk timeout.")
                                # Notify the callback about the timeout
                                timeout_msg = "\n\n[注意: 流式传输中断,内容可能不完整]"
                                callback(timeout_msg, full_response + timeout_msg if accumulate else None)
                                return full_response + timeout_msg
                            raise Timeout(f"Stream stalled: No new chunk received for {self.stream_chunk_timeout} seconds after previous chunk.")
                        # Initial-response timeout: first chunk took too long
                        if not received_any_chunk and current_time - start_time > self.timeout:
                            logging.warning(f"Initial response timeout: No chunk received in {self.timeout} seconds since stream started.")
                            raise Timeout(f"No initial response received for {self.timeout} seconds.")
                        # Global timeout (kept as a safety measure)
                        if current_time - start_time > global_timeout:
                            logging.warning(f"Global timeout reached after {global_timeout} seconds.")
                            if full_response:
                                logging.info(f"Returning partial response of length {len(full_response)} due to global timeout.")
                                # Notify the callback about the timeout
                                timeout_msg = "\n\n[注意: 由于全局超时,内容可能不完整]"
                                callback(timeout_msg, full_response + timeout_msg if accumulate else None)
                                return full_response + timeout_msg
                            else:
                                logging.error("Global timeout with no content received.")
                                raise Timeout(f"Global timeout after {global_timeout} seconds with no content.")

                        # Fetch the next chunk
                        chunk = next(chunk_iterator)
                        # Process the received chunk
                        if chunk.choices and chunk.choices[0].delta and chunk.choices[0].delta.content:
                            content = chunk.choices[0].delta.content
                            full_response += content
                            callback(content, full_response if accumulate else None)
                            received_any_chunk = True  # Mark that a chunk has been received
                            last_chunk_time = current_time  # Update the last-chunk timestamp
                        elif chunk.choices and chunk.choices[0].finish_reason == 'stop':
                            logging.info("Stream with callback finished normally.")
                            return full_response
                    except StopIteration:
                        logging.info("Stream iterator with callback exhausted normally.")
                        return full_response
                    except Timeout as e:
                        logging.warning(f"Stream timeout: {e}. Retrying if possible ({retries + 1}/{self.max_retries}).")
                        last_exception = e
                        break
                    except (APITimeoutError, APIConnectionError, RateLimitError) as e:
                        logging.warning(f"API error during streaming with callback: {type(e).__name__} - {e}. Retrying if possible.")
                        last_exception = e
                        break
                    except Exception as e:
                        logging.error(f"Unexpected error during streaming with callback: {traceback.format_exc()}")
                        # Return any partial response; otherwise retry
                        if full_response:
                            error_msg = f"\n\n[注意: 由于错误中断,内容可能不完整: {str(e)}]"
                            logging.info(f"Returning partial response of length {len(full_response)} due to error: {type(e).__name__}")
                            callback(error_msg, full_response + error_msg if accumulate else None)
                            return full_response + error_msg
                        last_exception = e
                        break

                # We broke out of the inner loop due to an error that needs a retry
                retries += 1
                if retries < self.max_retries:
                    logging.info(f"Retrying stream with callback in {backoff_time} seconds...")
                    time.sleep(backoff_time + random.uniform(0, 1))
                    backoff_time = min(backoff_time * 2, MAX_BACKOFF)
                else:
                    logging.error(f"Stream generation with callback failed after {self.max_retries} retries.")
                    # Return any partial response; otherwise raise
                    if full_response:
                        error_msg = "\n\n[注意: 达到最大重试次数,内容可能不完整]"
                        logging.info(f"Returning partial response of length {len(full_response)} after max retries.")
                        callback(error_msg, full_response + error_msg if accumulate else None)
                        return full_response + error_msg
                    raise last_exception or Exception("Stream generation with callback failed after max retries.")

            except (Timeout, APITimeoutError, APIConnectionError, RateLimitError) as e:
                retries += 1
                last_exception = e
                logging.warning(f"Attempt {retries}/{self.max_retries} failed: {type(e).__name__} - {e}")
                if retries < self.max_retries:
                    logging.info(f"Retrying in {backoff_time} seconds...")
                    time.sleep(backoff_time + random.uniform(0, 1))
                    backoff_time = min(backoff_time * 2, MAX_BACKOFF)
                else:
                    logging.error(f"API call with callback failed after {self.max_retries} retries.")
                    # Return any partial response; otherwise return an error message
                    if full_response:
                        error_msg = f"\n\n[注意: API调用失败内容可能不完整: {str(e)}]"
                        callback(error_msg, full_response + error_msg if accumulate else None)
                        return full_response + error_msg
                    error_text = f"[生成内容失败: {str(last_exception)}]"
                    callback(error_text, error_text if accumulate else None)
                    return error_text
            except Exception as e:
                logging.error(f"Unexpected error setting up stream with callback: {traceback.format_exc()}")
                # Return any partial response; otherwise return an error message
                if full_response:
                    error_msg = f"\n\n[注意: 意外错误,内容可能不完整: {str(e)}]"
                    callback(error_msg, full_response + error_msg if accumulate else None)
                    return full_response + error_msg
                error_text = f"[生成内容失败: {str(e)}]"
                callback(error_text, error_text if accumulate else None)
                return error_text

        # Should not be reached if the logic is correct, but as a safeguard:
        logging.error("Exited stream with callback generation loop unexpectedly.")
        if full_response:
            error_msg = "\n\n[注意: 意外退出流处理,内容可能不完整]"
            callback(error_msg, full_response + error_msg if accumulate else None)
            return full_response + error_msg
        error_text = "[生成内容失败: 超过最大重试次数]"
        callback(error_text, error_text if accumulate else None)
        return error_text
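
    # Callback usage sketch (hypothetical): print chunks as they arrive. With
    # accumulate=True the second argument carries the text accumulated so far;
    # otherwise it is None.
    #
    #   def on_chunk(chunk, accumulated):
    #       print(chunk, end="", flush=True)
    #
    #   final_text = agent.generate_text_stream_with_callback(
    #       system_prompt, user_prompt, on_chunk, accumulate=False)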

    async def async_generate_text_stream(self, system_prompt, user_prompt, temperature=0.7, top_p=0.9, presence_penalty=0.0):
        """
        Asynchronously generate a text stream, yielding chunks as they arrive.

        Args:
            system_prompt: The system prompt.
            user_prompt: The user prompt.
            temperature: Sampling temperature.
            top_p: Nucleus sampling parameter.
            presence_penalty: Presence penalty parameter.

        Returns:
            AsyncGenerator: An async generator that yields text chunks.

        Raises:
            Exception: If the API call still fails after all retries.
        """
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]
        logging.info(f"Async generating text stream using model: {self.model_name}")
        retries = 0
        backoff_time = INITIAL_BACKOFF
        last_exception = None
        full_response = ""
        global_timeout = self.timeout * 2  # Global timeout is twice the per-request API timeout

        while retries < self.max_retries:
            try:
                logging.debug(f"Attempt {retries + 1}/{self.max_retries} to async generate text stream.")
                # An async client is required here: a synchronous OpenAI stream
                # does not support `async for`.
                aclient = AsyncOpenAI(
                    api_key=self.api,
                    base_url=self.base_url,
                    timeout=self.timeout,
                    max_retries=0  # Disable built-in retries; we implement our own retry logic
                )
                stream = await aclient.chat.completions.create(
                    model=self.model_name,
                    messages=messages,
                    temperature=temperature,
                    top_p=top_p,
                    presence_penalty=presence_penalty,
                    stream=True,
                )

                stream_start_time = time.time()
                last_chunk_time = time.time()
                received_any_chunk = False  # Whether any chunk has been received yet

                try:
                    async for chunk in stream:
                        current_time = time.time()
                        # Chunk timeout: only checked once data has started arriving
                        if received_any_chunk and current_time - last_chunk_time > self.stream_chunk_timeout:
                            logging.warning(f"Async stream chunk timeout: No new chunk received for {self.stream_chunk_timeout} seconds after previous chunk.")
                            if full_response:
                                timeout_msg = "\n\n[注意: 流式传输中断,内容可能不完整]"
                                yield timeout_msg
                                return  # End the generator
                            raise Timeout(f"Async stream stalled: No new chunk received for {self.stream_chunk_timeout} seconds after previous chunk.")
                        # Initial-response timeout: first chunk took too long
                        if not received_any_chunk and current_time - stream_start_time > self.timeout:
                            logging.warning(f"Async initial response timeout: No chunk received in {self.timeout} seconds since stream started.")
                            raise Timeout(f"No initial response received for {self.timeout} seconds.")
                        # Global timeout (kept as a safety measure)
                        if current_time - stream_start_time > global_timeout:
                            logging.warning(f"Async global timeout reached after {global_timeout} seconds.")
                            if full_response:
                                timeout_msg = "\n\n[注意: 由于全局超时,内容可能不完整]"
                                yield timeout_msg
                                return  # End the generator
                            else:
                                logging.error("Async global timeout with no content received.")
                                raise Timeout(f"Async global timeout after {global_timeout} seconds with no content.")

                        # Process the received chunk
                        if chunk.choices and chunk.choices[0].delta and chunk.choices[0].delta.content:
                            content = chunk.choices[0].delta.content
                            full_response += content
                            yield content
                            received_any_chunk = True  # Mark that a chunk has been received
                            last_chunk_time = current_time  # Update the last-chunk timestamp

                    logging.info("Async stream completed normally.")
                    return
                except Timeout as e:
                    logging.warning(f"Async stream timeout: {e}. Retrying if possible ({retries + 1}/{self.max_retries}).")
                    last_exception = e
                except (APITimeoutError, APIConnectionError, RateLimitError) as e:
                    logging.warning(f"Async API error during streaming: {type(e).__name__} - {e}. Retrying if possible.")
                    last_exception = e
                except Exception as e:
                    logging.error(f"Async unexpected error during streaming: {traceback.format_exc()}")
                    # If a partial response exists, yield an error marker and finish
                    if full_response:
                        error_msg = f"\n\n[注意: 由于错误中断,内容可能不完整: {str(e)}]"
                        logging.info(f"Yielding error message after partial response of length {len(full_response)}")
                        yield error_msg
                        return  # End the generator
                    last_exception = e

                # Retry logic
                retries += 1
                if retries < self.max_retries:
                    logging.info(f"Async retrying stream in {backoff_time} seconds...")
                    await asyncio.sleep(backoff_time + random.uniform(0, 1))  # Add random jitter
                    backoff_time = min(backoff_time * 2, MAX_BACKOFF)
                else:
                    logging.error(f"Async stream generation failed after {self.max_retries} retries.")
                    # If a partial response exists, yield an error marker and finish
                    if full_response:
                        error_msg = "\n\n[注意: 达到最大重试次数,内容可能不完整]"
                        logging.info(f"Yielding error message after partial response of length {len(full_response)}")
                        yield error_msg
                        return  # End the generator
                    # Nothing received at all: raise
                    raise last_exception or Exception("Async stream generation failed after max retries with no content.")
            except (Timeout, APITimeoutError, APIConnectionError, RateLimitError) as e:
                retries += 1
                last_exception = e
                logging.warning(f"Async attempt {retries}/{self.max_retries} failed: {type(e).__name__} - {e}")
                if retries < self.max_retries:
                    logging.info(f"Async retrying in {backoff_time} seconds...")
                    await asyncio.sleep(backoff_time + random.uniform(0, 1))
                    backoff_time = min(backoff_time * 2, MAX_BACKOFF)
                else:
                    logging.error(f"Async API call failed after {self.max_retries} retries.")
                    # If a partial response exists, yield an error marker and finish
                    if full_response:
                        error_msg = f"\n\n[注意: API调用失败内容可能不完整: {str(e)}]"
                        yield error_msg
                        return  # End the generator
                    error_text = f"[生成内容失败: {str(last_exception)}]"
                    yield error_text
                    return  # End the generator
            except Exception as e:
                logging.error(f"Async unexpected error setting up stream: {traceback.format_exc()}")
                # If a partial response exists, yield an error marker and finish
                if full_response:
                    error_msg = f"\n\n[注意: 意外错误,内容可能不完整: {str(e)}]"
                    yield error_msg
                    return  # End the generator
                error_text = f"[生成内容失败: {str(e)}]"
                yield error_text
                return  # End the generator

        # Safety net; normally unreachable
        logging.error("Async exited stream generation loop unexpectedly.")
        if full_response:
            error_msg = "\n\n[注意: 意外退出流处理,内容可能不完整]"
            yield error_msg
            return  # End the generator
        error_text = "[生成内容失败: 超过最大重试次数]"
        yield error_text

    async def async_work_stream(self, system_prompt, user_prompt, file_folder, temperature, top_p, presence_penalty):
        """Async full workflow: read folder context (if provided), then generate a text stream."""
        logging.info(f"Starting async 'work_stream' process. File folder: {file_folder}")
        if file_folder:
            logging.info(f"Reading context from folder: {file_folder}")
            context = self.read_folder(file_folder)
            if context:
                user_prompt = f"{user_prompt.strip()}\n\n--- 参考资料 ---\n{context.strip()}"
            else:
                logging.warning(f"Folder {file_folder} provided but no content read.")
        logging.info("Calling async_generate_text_stream...")
        return self.async_generate_text_stream(system_prompt, user_prompt, temperature, top_p, presence_penalty)

    def work_stream(self, system_prompt, user_prompt, file_folder, temperature, top_p, presence_penalty):
        """Full workflow (streaming): read folder context (if provided), then generate a text stream."""
        logging.info(f"Starting 'work_stream' process. File folder: {file_folder}")
        if file_folder:
            logging.info(f"Reading context from folder: {file_folder}")
            context = self.read_folder(file_folder)
            if context:
                user_prompt = f"{user_prompt.strip()}\n\n--- 参考资料 ---\n{context.strip()}"
            else:
                logging.warning(f"Folder {file_folder} provided but no content read.")
        full_response = self.generate_text_stream(system_prompt, user_prompt, temperature, top_p, presence_penalty)
        return full_response

    # --- End Streaming Methods ---
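

# Minimal smoke-test sketch, assuming a local OpenAI-compatible server behind the
# 'vllm' preset; the model name and API key are placeholders, not values from this repo.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO,
                        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
    agent = AI_Agent(base_url="vllm", model_name="placeholder-model", api="EMPTY")

    # Blocking call that streams internally and returns the assembled text.
    text, est_tokens, seconds = agent.work(
        "You are a helpful assistant.", "Say hello in one sentence.",
        file_folder=None, temperature=0.7, top_p=0.9, presence_penalty=0.0)
    print(f"\n--- work() finished in {seconds:.2f}s (~{est_tokens:.0f} tokens) ---\n{text}")

    # Async variant: async_work_stream returns an async generator of text chunks.
    async def demo_async():
        stream = await agent.async_work_stream(
            "You are a helpful assistant.", "Count from 1 to 5.",
            None, 0.7, 0.9, 0.0)
        async for chunk in stream:
            print(chunk, end="", flush=True)
        print()

    asyncio.run(demo_async())
    agent.close()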