import os
from openai import OpenAI, AsyncOpenAI, APITimeoutError, APIConnectionError, RateLimitError, APIStatusError
import time
import random
import traceback
import logging
import tiktoken
import asyncio
from asyncio import TimeoutError as AsyncTimeoutError
# Configure basic logging for this module (or rely on root logger config)
# logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# logger = logging.getLogger(__name__) # Alternative: use named logger
# Constants
MAX_RETRIES = 3 # Maximum number of retries for API calls
INITIAL_BACKOFF = 1 # Initial backoff time in seconds
MAX_BACKOFF = 16 # Maximum backoff time in seconds
STREAM_CHUNK_TIMEOUT = 10 # Timeout in seconds for receiving a chunk in stream
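# The streaming methods back off exponentially between retries: the wait doubles
# from INITIAL_BACKOFF up to MAX_BACKOFF, with up to 1s of random jitter added.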
# Custom timeout exception for stalled streams
class Timeout(Exception):
"""Raised when a stream chunk timeout occurs."""
def __init__(self, message="Stream chunk timeout occurred."):
self.message = message
super().__init__(self.message)
def __str__(self):
return self.message
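# Note: this custom Timeout is distinct from openai's APITimeoutError; it is
# raised by the chunk-gap watchdog in the streaming methods below.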
class AI_Agent():
"""AI代理类负责与AI模型交互生成文本内容"""
def __init__(self, base_url, model_name, api, timeout=30, max_retries=3, stream_chunk_timeout=10):
"""
初始化 AI 代理。
Args:
base_url (str): 模型服务的基础 URL 或预设名称 ('deepseek', 'vllm')。
model_name (str): 要使用的模型名称。
api (str): API 密钥。
timeout (int, optional): 单次 API 请求的超时时间 (秒)。 Defaults to 30.
max_retries (int, optional): API 请求失败时的最大重试次数。 Defaults to 3.
stream_chunk_timeout (int, optional): 流式响应中两个数据块之间的最大等待时间 (秒)。 Defaults to 10.
"""
logging.info("Initializing AI Agent")
self.url_list = {
"ali": "https://dashscope.aliyuncs.com/compatible-mode/v1",
"kimi": "https://api.moonshot.cn/v1",
"doubao": "https://ark.cn-beijing.volces.com/api/v3/",
"deepseek": "https://api.deepseek.com",
"vllm": "http://localhost:8000/v1",
}
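        # Preset names map to known endpoints; any other value is passed through
        # verbatim as the base URL.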
self.base_url = self.url_list.get(base_url, base_url)
self.api = api
self.model_name = model_name
self.timeout = timeout
self.max_retries = max_retries
self.stream_chunk_timeout = stream_chunk_timeout
logging.info(f"AI Agent Settings: base_url={self.base_url}, model={self.model_name}, timeout={self.timeout}s, max_retries={self.max_retries}, stream_chunk_timeout={self.stream_chunk_timeout}s")
self.client = OpenAI(
api_key=self.api,
base_url=self.base_url,
timeout=self.timeout
)
# try:
# self.encoding = tiktoken.encoding_for_model(self.model_name)
# except KeyError:
# logging.warning(f"Encoding for model '{self.model_name}' not found. Using 'cl100k_base' encoding.")
# self.encoding = tiktoken.get_encoding("cl100k_base")
def generate_text(self, system_prompt, user_prompt, temperature, top_p, presence_penalty):
"""生成文本内容并返回完整响应和token估计值"""
logging.info("Starting text generation process...")
# logging.debug(f"System Prompt (first 100): {system_prompt[:100]}...")
# logging.debug(f"User Prompt (first 100): {user_prompt[:100]}...") # Avoid logging potentially huge prompts
logging.info(f"Generation Params: temp={temperature}, top_p={top_p}, presence_penalty={presence_penalty}")
retry_count = 0
max_retry_wait = 10 # Max wait time between retries
full_response = ""
while retry_count <= self.max_retries:
call_start_time = None # Initialize start time
try:
# --- Added Logging ---
user_prompt_size = len(user_prompt)
logging.info(f"Attempt {retry_count + 1}/{self.max_retries + 1}: Preparing API request. User prompt size: {user_prompt_size} chars.")
call_start_time = time.time()
# --- End Added Logging ---
response = self.client.chat.completions.create(
model=self.model_name,
messages=[{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt}],
temperature=temperature,
top_p=top_p,
presence_penalty=presence_penalty,
stream=False, # Ensure this is False for non-streaming method
max_tokens=8192,
timeout=self.timeout,
extra_body={
"repetition_penalty": 1.05,
},
)
# --- Added Logging ---
call_end_time = time.time()
logging.info(f"Attempt {retry_count + 1}/{self.max_retries + 1}: API request function returned successfully after {call_end_time - call_start_time:.2f} seconds.")
# --- End Added Logging ---
if response.choices and response.choices[0].message:
full_response = response.choices[0].message.content
logging.info(f"Received successful response. Content length: {len(full_response)} chars.")
break # Success, exit retry loop
else:
logging.warning("API response structure unexpected or empty content.")
full_response = "[Error: Empty or invalid response structure]"
# Decide if this specific case should retry or fail immediately
retry_count += 1 # Example: Treat as retryable
if retry_count <= self.max_retries:
wait_time = min(2 ** retry_count + random.random(), max_retry_wait)
logging.warning(f"Retrying due to unexpected response structure ({retry_count}/{self.max_retries}), waiting {wait_time:.2f}s...")
time.sleep(wait_time)
continue
except (APITimeoutError, APIConnectionError, RateLimitError, APIStatusError) as e:
# --- Added Logging ---
if call_start_time:
call_fail_time = time.time()
logging.warning(f"Attempt {retry_count + 1}/{self.max_retries + 1}: API call failed/timed out after {call_fail_time - call_start_time:.2f} seconds.")
else:
logging.warning(f"Attempt {retry_count + 1}/{self.max_retries + 1}: API call failed before or during initiation.")
# --- End Added Logging ---
logging.warning(f"API Error occurred: {e}")
should_retry = False
if isinstance(e, (APITimeoutError, APIConnectionError, RateLimitError)):
should_retry = True
elif isinstance(e, APIStatusError) and e.status_code >= 500:
should_retry = True
if should_retry:
retry_count += 1
if retry_count <= self.max_retries:
wait_time = min(2 ** retry_count + random.random(), max_retry_wait)
logging.warning(f"Retrying API call ({retry_count}/{self.max_retries}) after error, waiting {wait_time:.2f}s...")
time.sleep(wait_time)
else:
logging.error(f"Max retries ({self.max_retries}) reached for API errors. Aborting.")
return "请求失败,无法生成内容。", 0
else:
logging.error(f"Non-retriable API error: {e}. Aborting.")
return "请求失败,发生不可重试错误。", 0
except Exception as e:
logging.exception(f"Unexpected error during API call setup/execution:")
retry_count += 1
if retry_count <= self.max_retries:
wait_time = min(2 ** retry_count + random.random(), max_retry_wait)
logging.warning(f"Retrying API call ({retry_count}/{self.max_retries}) after unexpected error, waiting {wait_time:.2f}s...")
time.sleep(wait_time)
else:
logging.error(f"Max retries ({self.max_retries}) reached after unexpected errors. Aborting.")
return "请求失败,发生未知错误。", 0
logging.info("Text generation completed.")
        estimated_tokens = len(full_response.split()) * 1.3  # Rough whitespace-based estimate; undercounts unspaced CJK text
return full_response, estimated_tokens
def read_folder(self, file_folder):
"""读取指定文件夹下的所有文件内容"""
if not os.path.exists(file_folder):
logging.warning(f"Referenced folder does not exist: {file_folder}")
return ""
context = ""
try:
for file in os.listdir(file_folder):
file_path = os.path.join(file_folder, file)
if os.path.isfile(file_path):
try:
with open(file_path, "r", encoding="utf-8") as f:
context += f"文件名: {file}\n"
context += f.read()
context += "\n\n"
except Exception as read_err:
logging.error(f"Failed to read file {file_path}: {read_err}")
except Exception as list_err:
logging.error(f"Failed to list directory {file_folder}: {list_err}")
return context
def work(self, system_prompt, user_prompt, file_folder, temperature, top_p, presence_penalty):
"""完整的工作流程:生成文本并返回结果"""
logging.info(f"Starting 'work' process. File folder: {file_folder}")
if file_folder:
logging.info(f"Reading context from folder: {file_folder}")
context = self.read_folder(file_folder)
if context:
user_prompt = f"{user_prompt.strip()}\n\n--- 参考资料 ---\n{context.strip()}"
else:
logging.warning(f"Folder {file_folder} provided but no content read.")
time_start = time.time()
result = self.generate_text_stream(system_prompt, user_prompt, temperature, top_p, presence_penalty)
# result, tokens = self.generate_text(system_prompt, user_prompt, temperature, top_p, presence_penalty)
time_end = time.time()
time_cost = time_end - time_start
        tokens = len(result) * 1.3  # Rough character-based estimate
        logging.info(f"'work' completed in {time_cost:.2f}s. Estimated tokens: {tokens:.0f}")
return result, tokens, time_cost
def close(self):
try:
logging.info("Closing AI Agent (client resources will be garbage collected).")
self.client = None
except Exception as e:
logging.error(f"Error during AI Agent close: {e}")
# --- Streaming Methods ---
def generate_text_stream(self, system_prompt, user_prompt, temperature, top_p, presence_penalty):
"""
Generates text based on prompts using a streaming connection. Handles retries with exponential backoff.
Args:
system_prompt: The system prompt for the AI.
user_prompt: The user prompt for the AI.
temperature: Sampling temperature.
top_p: Nucleus sampling parameter.
presence_penalty: Presence penalty parameter.
Returns:
str: The complete generated text.
Raises:
Exception: If the API call fails after all retries.
"""
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
]
logging.info(f"Generating text stream with model: {self.model_name}")
retries = 0
backoff_time = INITIAL_BACKOFF
last_exception = None
full_response = ""
while retries < self.max_retries:
try:
logging.debug(f"Attempt {retries + 1}/{self.max_retries} to generate text stream.")
stream = self.client.chat.completions.create(
model=self.model_name,
messages=messages,
temperature=temperature,
top_p=top_p,
presence_penalty=presence_penalty,
stream=True,
timeout=self.timeout # Overall request timeout
)
chunk_iterator = iter(stream)
last_chunk_time = time.time()
while True:
try:
                        # Check elapsed time since the last chunk. Note that next() below
                        # blocks, so this check can only fire once a chunk (or an exception)
                        # arrives; it cannot interrupt a stalled read mid-wait.
                        if time.time() - last_chunk_time > self.stream_chunk_timeout:
                            raise Timeout(f"No chunk received for {self.stream_chunk_timeout} seconds.")
chunk = next(chunk_iterator)
last_chunk_time = time.time() # Reset timer on successful chunk receipt
if chunk.choices and chunk.choices[0].delta and chunk.choices[0].delta.content:
content = chunk.choices[0].delta.content
full_response += content
# logging.debug(f"Received chunk: {content}") # Potentially very verbose
elif chunk.choices and chunk.choices[0].finish_reason == 'stop':
logging.info("Stream finished.")
return full_response # Return complete response
# Handle other finish reasons if needed, e.g., 'length'
except StopIteration:
logging.info("Stream iterator exhausted.")
return full_response # Return complete response
except Timeout as e:
logging.warning(f"Stream chunk timeout: {e}. Retrying if possible ({retries + 1}/{self.max_retries}).")
last_exception = e
break # Break inner loop to retry the stream creation
except (APITimeoutError, APIConnectionError, RateLimitError) as e:
logging.warning(f"API error during streaming: {type(e).__name__} - {e}. Retrying if possible ({retries + 1}/{self.max_retries}).")
last_exception = e
break # Break inner loop to retry the stream creation
except Exception as e:
logging.error(f"Unexpected error during streaming: {traceback.format_exc()}")
# Decide if this unexpected error should be retried or raised immediately
last_exception = e
# Option 1: Raise immediately
# raise e
# Option 2: Treat as retryable (use with caution)
break # Break inner loop to retry
# If we broke from the inner loop due to an error that needs retry
retries += 1
if retries < self.max_retries:
logging.info(f"Retrying stream in {backoff_time} seconds...")
time.sleep(backoff_time + random.uniform(0, 1)) # Add jitter
backoff_time = min(backoff_time * 2, MAX_BACKOFF)
else:
logging.error(f"Stream generation failed after {self.max_retries} retries.")
raise last_exception or Exception("Stream generation failed after max retries.")
except (Timeout, APITimeoutError, APIConnectionError, RateLimitError) as e:
retries += 1
last_exception = e
logging.warning(f"Attempt {retries}/{self.max_retries} failed: {type(e).__name__} - {e}")
if retries < self.max_retries:
logging.info(f"Retrying in {backoff_time} seconds...")
time.sleep(backoff_time + random.uniform(0, 1)) # Add jitter
backoff_time = min(backoff_time * 2, MAX_BACKOFF)
                else:
                    logging.error(f"API call failed after {self.max_retries} retries.")
                    return f"[Generation failed: {str(last_exception)}]"  # Return the error message as a string
            except Exception as e:
                # Catch unexpected errors during stream setup
                logging.error(f"Unexpected error setting up stream: {traceback.format_exc()}")
                return f"[Generation failed: {str(e)}]"  # Return the error message as a string
        # Should not be reached if the logic is correct, but as a safeguard:
        logging.error("Exited stream generation loop unexpectedly.")
        return "[Generation failed: maximum retries exceeded]"  # Return the error message as a string
def generate_text_stream_with_callback(self, system_prompt, user_prompt,
callback_fn, temperature=0.7, top_p=0.9,
presence_penalty=0.0):
"""生成文本流并通过回调函数处理每个块
Args:
system_prompt: 系统提示词
user_prompt: 用户提示词
callback_fn: 处理每个文本块的回调函数,接收(content, is_last, is_timeout, is_error, error)参数
temperature: 温度参数
top_p: 核采样参数
presence_penalty: 存在惩罚参数
Returns:
str: 完整的响应文本
"""
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
]
logging.info(f"Generating text stream with callback using model: {self.model_name}")
retries = 0
backoff_time = INITIAL_BACKOFF
last_exception = None
full_response = ""
while retries < self.max_retries:
try:
logging.debug(f"Attempt {retries + 1}/{self.max_retries} to generate text stream with callback.")
stream = self.client.chat.completions.create(
model=self.model_name,
messages=messages,
temperature=temperature,
top_p=top_p,
presence_penalty=presence_penalty,
stream=True,
timeout=self.timeout
)
chunk_iterator = iter(stream)
last_chunk_time = time.time()
try:
while True:
try:
                            # Check for timeout since the last received chunk
                            if time.time() - last_chunk_time > self.stream_chunk_timeout:
                                callback_fn("", is_last=True, is_timeout=True, is_error=False, error=None)
                                raise Timeout(f"No chunk received for {self.stream_chunk_timeout} seconds.")
                            chunk = next(chunk_iterator)
                            last_chunk_time = time.time()  # Reset the timer after a chunk is received
content = ""
is_last = False
if chunk.choices and chunk.choices[0].delta and chunk.choices[0].delta.content:
content = chunk.choices[0].delta.content
full_response += content
if chunk.choices and chunk.choices[0].finish_reason == 'stop':
is_last = True
                            # Invoke the callback to process the chunk
                            callback_fn(content, is_last=is_last, is_timeout=False, is_error=False, error=None)
                            if is_last:
                                logging.info("Stream with callback finished normally.")
                                return full_response  # Finished successfully; return the full response
                        except StopIteration:
                            # Stream ended normally
                            callback_fn("", is_last=True, is_timeout=False, is_error=False, error=None)
                            logging.info("Stream iterator exhausted normally.")
                            return full_response
                except Timeout as e:
                    logging.warning(f"Stream chunk timeout: {e}")
                    last_exception = e
                    # The timeout was already reported via the callback above; no need to call it again here
except (APITimeoutError, APIConnectionError, RateLimitError) as e:
logging.warning(f"API error during streaming with callback: {type(e).__name__} - {e}")
callback_fn("", is_last=True, is_timeout=False, is_error=True, error=str(e))
last_exception = e
except Exception as e:
logging.error(f"Unexpected error during streaming with callback: {traceback.format_exc()}")
callback_fn("", is_last=True, is_timeout=False, is_error=True, error=str(e))
last_exception = e
                # Retry logic
                retries += 1
                if retries < self.max_retries:
                    retry_msg = f"Retrying in {backoff_time:.2f} seconds..."
                    logging.info(f"Retrying stream in {backoff_time} seconds...")
                    callback_fn(f"\n[{retry_msg}]\n", is_last=False, is_timeout=False, is_error=False, error=None)
                    time.sleep(backoff_time + random.uniform(0, 1))  # Add random jitter
                    backoff_time = min(backoff_time * 2, MAX_BACKOFF)
                else:
                    error_msg = f"API call failed after {self.max_retries} retries: {str(last_exception)}"
                    logging.error(error_msg)
                    callback_fn(f"\n[{error_msg}]\n", is_last=True, is_timeout=False, is_error=True, error=str(last_exception))
                    return full_response  # Return the partial response collected so far
            except Exception as e:
                logging.error(f"Error setting up stream with callback: {traceback.format_exc()}")
                callback_fn("", is_last=True, is_timeout=False, is_error=True, error=str(e))
                return f"[Generation failed: {str(e)}]"
        # Safeguard: should not normally be reached
        error_msg = "Maximum retries exceeded"
        logging.error(error_msg)
        callback_fn(f"\n[{error_msg}]\n", is_last=True, is_timeout=False, is_error=True, error=error_msg)
        return full_response
async def async_generate_text_stream(self, system_prompt, user_prompt, temperature=0.7, top_p=0.9, presence_penalty=0.0):
"""异步生成文本流
Args:
system_prompt: 系统提示词
user_prompt: 用户提示词
temperature: 温度参数
top_p: 核采样参数
presence_penalty: 存在惩罚参数
Yields:
str: 生成的文本块
Raises:
Exception: 如果API调用在所有重试后失败
"""
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
]
logging.info(f"Asynchronously generating text stream with model: {self.model_name}")
retries = 0
backoff_time = INITIAL_BACKOFF
last_exception = None
while retries < self.max_retries:
try:
logging.debug(f"Async attempt {retries + 1}/{self.max_retries} to generate text stream.")
# 创建新的客户端用于异步操作
async_client = OpenAI(
api_key=self.api,
base_url=self.base_url,
timeout=self.timeout
)
stream = await async_client.chat.completions.create(
model=self.model_name,
messages=messages,
temperature=temperature,
top_p=top_p,
presence_penalty=presence_penalty,
stream=True,
timeout=self.timeout
)
last_chunk_time = time.time()
try:
async for chunk in stream:
                        # Check for timeout since the last received chunk
                        current_time = time.time()
                        if current_time - last_chunk_time > self.stream_chunk_timeout:
                            raise Timeout(f"No chunk received for {self.stream_chunk_timeout} seconds.")
                        last_chunk_time = current_time  # Reset the timer after a chunk is received
                        if chunk.choices and chunk.choices[0].delta and chunk.choices[0].delta.content:
                            content = chunk.choices[0].delta.content
                            yield content
                    logging.info("Async stream finished normally.")
                    return  # Finished successfully
except AsyncTimeoutError as e:
logging.warning(f"Async stream timeout: {e}")
last_exception = e
except Timeout as e:
logging.warning(f"Async stream chunk timeout: {e}")
last_exception = e
except Exception as e:
logging.error(f"Error during async streaming: {traceback.format_exc()}")
last_exception = e
                # Retry logic
                retries += 1
                if retries < self.max_retries:
                    logging.info(f"Retrying async stream in {backoff_time} seconds...")
                    await asyncio.sleep(backoff_time + random.uniform(0, 1))  # Non-blocking async sleep
                    backoff_time = min(backoff_time * 2, MAX_BACKOFF)
else:
logging.error(f"Async API call failed after {self.max_retries} retries.")
raise last_exception or Exception("Async stream generation failed after max retries.")
except (Timeout, AsyncTimeoutError, APITimeoutError, APIConnectionError, RateLimitError) as e:
retries += 1
last_exception = e
logging.warning(f"Async attempt {retries}/{self.max_retries} failed: {type(e).__name__} - {e}")
if retries < self.max_retries:
logging.info(f"Retrying async stream in {backoff_time} seconds...")
                    await asyncio.sleep(backoff_time + random.uniform(0, 1))  # Non-blocking async sleep
backoff_time = min(backoff_time * 2, MAX_BACKOFF)
else:
logging.error(f"Async API call failed after {self.max_retries} retries.")
raise last_exception
except Exception as e:
logging.error(f"Unexpected error setting up async stream: {traceback.format_exc()}")
                raise  # Re-raise unexpected errors immediately, preserving the traceback
        # Safeguard: should not normally be reached
logging.error("Exited async stream generation loop unexpectedly.")
raise last_exception or Exception("Async stream generation failed.")
async def async_work_stream(self, system_prompt, user_prompt, file_folder, temperature, top_p, presence_penalty):
"""异步完整工作流程:读取文件夹(如果提供),然后生成文本流"""
logging.info(f"Starting async 'work_stream' process. File folder: {file_folder}")
if file_folder:
logging.info(f"Reading context from folder: {file_folder}")
context = self.read_folder(file_folder)
if context:
user_prompt = f"{user_prompt.strip()}\n\n--- 参考资料 ---\n{context.strip()}"
else:
logging.warning(f"Folder {file_folder} provided but no content read.")
logging.info("Calling async_generate_text_stream...")
return self.async_generate_text_stream(system_prompt, user_prompt, temperature, top_p, presence_penalty)
def work_stream(self, system_prompt, user_prompt, file_folder, temperature, top_p, presence_penalty):
"""完整的工作流程(流式):读取文件夹(如果提供),然后生成文本流"""
logging.info(f"Starting 'work_stream' process. File folder: {file_folder}")
if file_folder:
logging.info(f"Reading context from folder: {file_folder}")
context = self.read_folder(file_folder)
if context:
user_prompt = f"{user_prompt.strip()}\n\n--- 参考资料 ---\n{context.strip()}"
else:
logging.warning(f"Folder {file_folder} provided but no content read.")
full_response = self.generate_text_stream(system_prompt, user_prompt, temperature, top_p, presence_penalty)
return full_response
# --- End Streaming Methods ---
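
# --- Usage sketch (illustrative only) ---
# A minimal sketch of how this agent might be driven end to end. The preset,
# model name, and API key below are placeholders, not values from any real
# deployment; substitute your own endpoint and credentials.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO,
                        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
    agent = AI_Agent(base_url="vllm", model_name="example-model", api="EMPTY")

    # Non-streaming call: returns (text, estimated_tokens).
    text, tokens = agent.generate_text(
        system_prompt="You are a helpful assistant.",
        user_prompt="Say hello in one sentence.",
        temperature=0.7, top_p=0.9, presence_penalty=0.0,
    )
    print(text)

    # Streaming call with a callback that prints chunks as they arrive.
    def on_chunk(content, is_last, is_timeout, is_error, error):
        if content:
            print(content, end="", flush=True)
        if is_error:
            print(f"\n[stream error: {error}]")

    agent.generate_text_stream_with_callback(
        "You are a helpful assistant.", "Count to five.", on_chunk)

    # Async streaming: async_work_stream returns an async generator to iterate.
    async def demo_async():
        gen = await agent.async_work_stream(
            "You are a helpful assistant.", "Name three colors.",
            file_folder=None, temperature=0.7, top_p=0.9, presence_penalty=0.0)
        async for chunk in gen:
            print(chunk, end="", flush=True)

    asyncio.run(demo_async())
    agent.close()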