# 335 lines
# 17 KiB
# Python
# Raw Blame History
#
# This file contains ambiguous Unicode characters
#
# This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
from openai import OpenAI, APITimeoutError, APIConnectionError, RateLimitError, APIStatusError
import time
import random
import traceback
import logging
import tiktoken
# Configure basic logging for this module (or rely on root logger config)
# logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# logger = logging.getLogger(__name__) # Alternative: use named logger
# Module-level retry/backoff constants.
# INITIAL_BACKOFF and MAX_BACKOFF are read by AI_Agent.generate_text_stream.
# NOTE(review): MAX_RETRIES and STREAM_CHUNK_TIMEOUT are not referenced in the
# visible part of this file; AI_Agent.__init__ carries its own literal defaults
# (3 and 10) -- confirm whether they are meant to stay in sync.
MAX_RETRIES = 3 # Maximum number of retries for API calls
INITIAL_BACKOFF = 1 # Initial backoff time in seconds (doubled after each failed stream attempt)
MAX_BACKOFF = 16 # Maximum backoff time in seconds (cap for the doubling above)
STREAM_CHUNK_TIMEOUT = 10 # Timeout in seconds for receiving a chunk in stream
class AI_Agent():
    """Client wrapper that sends prompts to an OpenAI-compatible chat endpoint.

    Provides a blocking call (``generate_text`` / ``work``) and a streaming
    generator (``generate_text_stream`` / ``work_stream``). Both ``work``
    variants can append the contents of a folder of text files to the user
    prompt as reference material.
    """

    def __init__(self, base_url, model_name, api, timeout=30, max_retries=3, stream_chunk_timeout=10):
        """Initialise the agent and build the underlying OpenAI client.

        Args:
            base_url (str): Either a full base URL or one of the preset names
                defined in ``self.url_list`` ('ali', 'kimi', 'doubao',
                'deepseek', 'vllm'). Unknown values are used verbatim.
            model_name (str): Name of the model to request.
            api (str): API key.
            timeout (int, optional): Timeout in seconds for a single API
                request. Defaults to 30.
            max_retries (int, optional): Retry budget for failed API
                requests. Defaults to 3.
            stream_chunk_timeout (int, optional): Longest acceptable wait, in
                seconds, for one chunk of a streamed response. Defaults to 10.
        """
        logging.info("Initializing AI Agent")
        # Preset name -> base URL; anything not listed is treated as a URL.
        self.url_list = {
            "ali": "https://dashscope.aliyuncs.com/compatible-mode/v1",
            "kimi": "https://api.moonshot.cn/v1",
            "doubao": "https://ark.cn-beijing.volces.com/api/v3/",
            "deepseek": "https://api.deepseek.com",
            "vllm": "http://localhost:8000/v1",
        }
        self.base_url = self.url_list.get(base_url, base_url)
        self.api = api
        self.model_name = model_name
        self.timeout = timeout
        self.max_retries = max_retries
        self.stream_chunk_timeout = stream_chunk_timeout
        logging.info(f"AI Agent Settings: base_url={self.base_url}, model={self.model_name}, timeout={self.timeout}s, max_retries={self.max_retries}, stream_chunk_timeout={self.stream_chunk_timeout}s")
        self.client = OpenAI(
            api_key=self.api,
            base_url=self.base_url,
            timeout=self.timeout
        )
        # NOTE: tiktoken-based token counting is not set up here; token counts
        # returned by generate_text are a whitespace-based estimate.

    def generate_text(self, system_prompt, user_prompt, temperature, top_p, presence_penalty):
        """Request one non-streaming completion, retrying transient failures.

        Timeouts, connection errors, rate limits, 5xx status errors, invalid
        (empty) responses and unexpected exceptions are retried up to
        ``self.max_retries`` times with a capped exponential wait. Other API
        status errors abort immediately.

        Args:
            system_prompt: Content of the system message.
            user_prompt: Content of the user message.
            temperature: Sampling temperature.
            top_p: Nucleus sampling parameter.
            presence_penalty: Presence penalty forwarded to the API.

        Returns:
            tuple: ``(text, estimated_tokens)``. On success the estimate is
            the whitespace-separated word count times 1.3 (a rough figure that
            under-counts text written without spaces). On any failure ``text``
            is an error message and the token count is 0.
        """
        logging.info("Starting text generation process...")
        logging.info(f"Generation Params: temp={temperature}, top_p={top_p}, presence_penalty={presence_penalty}")
        retry_count = 0
        max_retry_wait = 10  # Upper bound (seconds) for the wait between retries
        full_response = ""
        succeeded = False
        while retry_count <= self.max_retries:
            call_start_time = None  # Stays None if the failure happens before the call starts
            try:
                user_prompt_size = len(user_prompt)
                logging.info(f"Attempt {retry_count + 1}/{self.max_retries + 1}: Preparing API request. User prompt size: {user_prompt_size} chars.")
                call_start_time = time.time()
                response = self.client.chat.completions.create(
                    model=self.model_name,
                    messages=[{"role": "system", "content": system_prompt},
                              {"role": "user", "content": user_prompt}],
                    temperature=temperature,
                    top_p=top_p,
                    presence_penalty=presence_penalty,
                    stream=False,  # This method is the non-streaming path
                    max_tokens=8192,
                    timeout=self.timeout,
                    extra_body={
                        "repetition_penalty": 1.05,
                    },
                )
                call_end_time = time.time()
                logging.info(f"Attempt {retry_count + 1}/{self.max_retries + 1}: API request function returned successfully after {call_end_time - call_start_time:.2f} seconds.")
                choices = response.choices
                message = choices[0].message if choices else None
                # A message whose content is None is treated like a missing
                # message instead of crashing on len(None) further down.
                content = message.content if message else None
                if content is not None:
                    full_response = content
                    logging.info(f"Received successful response. Content length: {len(full_response)} chars.")
                    succeeded = True
                    break  # Success, exit retry loop
                logging.warning("API response structure unexpected or empty content.")
                full_response = "[Error: Empty or invalid response structure]"
                retry_count += 1  # An invalid response counts as a retryable failure
                if retry_count <= self.max_retries:
                    wait_time = min(2 ** retry_count + random.random(), max_retry_wait)
                    logging.warning(f"Retrying due to unexpected response structure ({retry_count}/{self.max_retries}), waiting {wait_time:.2f}s...")
                    time.sleep(wait_time)
            except (APITimeoutError, APIConnectionError, RateLimitError, APIStatusError) as e:
                if call_start_time:
                    call_fail_time = time.time()
                    logging.warning(f"Attempt {retry_count + 1}/{self.max_retries + 1}: API call failed/timed out after {call_fail_time - call_start_time:.2f} seconds.")
                else:
                    logging.warning(f"Attempt {retry_count + 1}/{self.max_retries + 1}: API call failed before or during initiation.")
                logging.warning(f"API Error occurred: {e}")
                # Transient network/rate-limit problems and server-side (5xx)
                # errors are retried; other status errors are not.
                should_retry = False
                if isinstance(e, (APITimeoutError, APIConnectionError, RateLimitError)):
                    should_retry = True
                elif isinstance(e, APIStatusError) and e.status_code >= 500:
                    should_retry = True
                if should_retry:
                    retry_count += 1
                    if retry_count <= self.max_retries:
                        wait_time = min(2 ** retry_count + random.random(), max_retry_wait)
                        logging.warning(f"Retrying API call ({retry_count}/{self.max_retries}) after error, waiting {wait_time:.2f}s...")
                        time.sleep(wait_time)
                    else:
                        logging.error(f"Max retries ({self.max_retries}) reached for API errors. Aborting.")
                        return "请求失败,无法生成内容。", 0
                else:
                    logging.error(f"Non-retriable API error: {e}. Aborting.")
                    return "请求失败,发生不可重试错误。", 0
            except Exception:
                logging.exception("Unexpected error during API call setup/execution:")
                retry_count += 1
                if retry_count <= self.max_retries:
                    wait_time = min(2 ** retry_count + random.random(), max_retry_wait)
                    logging.warning(f"Retrying API call ({retry_count}/{self.max_retries}) after unexpected error, waiting {wait_time:.2f}s...")
                    time.sleep(wait_time)
                else:
                    logging.error(f"Max retries ({self.max_retries}) reached after unexpected errors. Aborting.")
                    return "请求失败,发生未知错误。", 0
        if not succeeded:
            # Every attempt produced an invalid response (or no attempt was
            # possible): report failure with 0 tokens like the other error paths.
            logging.error("Text generation failed: no valid response received.")
            return full_response, 0
        logging.info("Text generation completed.")
        estimated_tokens = len(full_response.split()) * 1.3
        return full_response, estimated_tokens

    def read_folder(self, file_folder):
        """Concatenate the text of every regular file directly inside a folder.

        Files are visited in sorted name order so the result is deterministic.
        Each readable file contributes a "文件名: <name>" header line, its
        content, and a blank line. Files that cannot be read as UTF-8 are
        logged and skipped entirely; sub-directories are ignored.

        Args:
            file_folder: Path of the folder to read.

        Returns:
            str: The concatenated text, or "" if the folder does not exist,
            cannot be listed, or yields no readable file.
        """
        if not os.path.exists(file_folder):
            logging.warning(f"Referenced folder does not exist: {file_folder}")
            return ""
        try:
            file_names = sorted(os.listdir(file_folder))
        except Exception as list_err:
            logging.error(f"Failed to list directory {file_folder}: {list_err}")
            return ""
        sections = []
        for file in file_names:
            file_path = os.path.join(file_folder, file)
            if not os.path.isfile(file_path):
                continue
            try:
                with open(file_path, "r", encoding="utf-8") as f:
                    text = f.read()
            except Exception as read_err:
                logging.error(f"Failed to read file {file_path}: {read_err}")
                continue
            # The header is added only after a successful read, so an
            # unreadable file never leaves a dangling header behind.
            sections.append(f"文件名: {file}\n{text}\n\n")
        return "".join(sections)

    def _with_folder_context(self, user_prompt, file_folder):
        """Return user_prompt with the folder's text appended as reference material."""
        if not file_folder:
            return user_prompt
        logging.info(f"Reading context from folder: {file_folder}")
        context = self.read_folder(file_folder)
        if context:
            return f"{user_prompt.strip()}\n\n--- 参考资料 ---\n{context.strip()}"
        logging.warning(f"Folder {file_folder} provided but no content read.")
        return user_prompt

    def work(self, system_prompt, user_prompt, file_folder, temperature, top_p, presence_penalty):
        """Run the blocking workflow: add folder context, then generate text.

        Args:
            system_prompt: Content of the system message.
            user_prompt: Content of the user message.
            file_folder: Optional folder whose files are appended to the user
                prompt; falsy values skip this step.
            temperature: Sampling temperature.
            top_p: Nucleus sampling parameter.
            presence_penalty: Presence penalty forwarded to the API.

        Returns:
            tuple: ``(result, tokens, time_cost)`` where ``result`` and
            ``tokens`` come from ``generate_text`` and ``time_cost`` is the
            wall-clock duration of that call in seconds.
        """
        logging.info(f"Starting 'work' process. File folder: {file_folder}")
        user_prompt = self._with_folder_context(user_prompt, file_folder)
        time_start = time.time()
        result, tokens = self.generate_text(system_prompt, user_prompt, temperature, top_p, presence_penalty)
        time_cost = time.time() - time_start
        logging.info(f"'work' completed in {time_cost:.2f}s. Estimated tokens: {tokens}")
        return result, tokens, time_cost

    def close(self):
        """Close the underlying client (when it supports it) and drop the reference.

        Safe to call more than once; errors raised while closing are logged
        and not propagated.
        """
        client = getattr(self, "client", None)
        try:
            logging.info("Closing AI Agent client.")
            close_client = getattr(client, "close", None)
            if callable(close_client):
                close_client()
        except Exception as e:
            logging.error(f"Error during AI Agent close: {e}")
        finally:
            self.client = None

    # --- Streaming Methods ---
    def generate_text_stream(self, system_prompt, user_prompt, temperature, top_p, presence_penalty):
        """Generate text over a streaming connection, with retry and backoff.

        At least one attempt is always made; ``self.max_retries`` is the total
        number of attempts when it is 1 or more. Between attempts the wait
        starts at ``INITIAL_BACKOFF`` seconds (plus jitter) and doubles up to
        ``MAX_BACKOFF``.

        NOTE: a retry restarts the request from scratch, so chunks that were
        already yielded before a mid-stream failure will be yielded again.

        Args:
            system_prompt: The system prompt for the AI.
            user_prompt: The user prompt for the AI.
            temperature: Sampling temperature.
            top_p: Nucleus sampling parameter.
            presence_penalty: Presence penalty forwarded to the API.

        Yields:
            str: Chunks of the generated text.

        Raises:
            Exception: The last error seen once all attempts are used up, or
                immediately for an unexpected error while opening the stream.
        """
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]
        logging.info(f"Generating text stream with model: {self.model_name}")
        max_attempts = max(1, self.max_retries)
        retries = 0
        backoff_time = INITIAL_BACKOFF
        last_exception = None
        while True:
            logging.debug(f"Attempt {retries + 1}/{max_attempts} to generate text stream.")
            try:
                stream = self.client.chat.completions.create(
                    model=self.model_name,
                    messages=messages,
                    temperature=temperature,
                    top_p=top_p,
                    presence_penalty=presence_penalty,
                    stream=True,
                    timeout=self.timeout  # Overall request timeout
                )
                chunk_iterator = iter(stream)
            except (APITimeoutError, APIConnectionError, RateLimitError) as e:
                last_exception = e
                logging.warning(f"Attempt {retries + 1}/{max_attempts} failed: {type(e).__name__} - {e}")
            except Exception:
                # Unexpected errors while opening the stream are not retried.
                logging.error(f"Unexpected error setting up stream: {traceback.format_exc()}")
                raise
            else:
                while True:
                    try:
                        # Time only the blocking wait for the next chunk, so
                        # time the consumer spends between pulls is not
                        # mistaken for a stalled stream.
                        wait_started = time.monotonic()
                        chunk = next(chunk_iterator)
                        if time.monotonic() - wait_started > self.stream_chunk_timeout:
                            raise TimeoutError(f"No chunk received for {self.stream_chunk_timeout} seconds.")
                        first_choice = chunk.choices[0] if chunk.choices else None
                        if first_choice is None:
                            continue
                        if first_choice.delta and first_choice.delta.content:
                            yield first_choice.delta.content
                        elif first_choice.finish_reason == 'stop':
                            logging.info("Stream finished.")
                            return  # Successful completion
                        # Other finish reasons (e.g. 'length') are ignored here;
                        # the stream then ends through StopIteration.
                    except StopIteration:
                        logging.info("Stream iterator exhausted.")
                        return  # End of stream normally
                    except TimeoutError as e:
                        logging.warning(f"Stream chunk timeout: {e}. Retrying if possible ({retries + 1}/{max_attempts}).")
                        last_exception = e
                        break  # Leave the chunk loop and re-create the stream
                    except (APITimeoutError, APIConnectionError, RateLimitError) as e:
                        logging.warning(f"API error during streaming: {type(e).__name__} - {e}. Retrying if possible ({retries + 1}/{max_attempts}).")
                        last_exception = e
                        break  # Leave the chunk loop and re-create the stream
                    except Exception as e:
                        # Unexpected mid-stream errors are treated as retryable.
                        logging.error(f"Unexpected error during streaming: {traceback.format_exc()}")
                        last_exception = e
                        break
            # Reached only after a failed attempt. The bookkeeping lives outside
            # the try block so the final raise is not caught and counted twice.
            retries += 1
            if retries >= max_attempts:
                logging.error(f"Stream generation failed after {max_attempts} attempts.")
                raise last_exception or Exception("Stream generation failed after max retries.")
            logging.info(f"Retrying stream in {backoff_time} seconds...")
            time.sleep(backoff_time + random.uniform(0, 1))  # Add jitter
            backoff_time = min(backoff_time * 2, MAX_BACKOFF)

    def work_stream(self, system_prompt, user_prompt, file_folder, temperature, top_p, presence_penalty):
        """Run the streaming workflow: add folder context, then stream text.

        Args:
            system_prompt: Content of the system message.
            user_prompt: Content of the user message.
            file_folder: Optional folder whose files are appended to the user
                prompt; falsy values skip this step.
            temperature: Sampling temperature.
            top_p: Nucleus sampling parameter.
            presence_penalty: Presence penalty forwarded to the API.

        Returns:
            The generator produced by ``generate_text_stream``; no request is
            sent until it is iterated.
        """
        logging.info(f"Starting 'work_stream' process. File folder: {file_folder}")
        user_prompt = self._with_folder_context(user_prompt, file_folder)
        logging.info("Calling generate_text_stream...")
        return self.generate_text_stream(system_prompt, user_prompt, temperature, top_p, presence_penalty)
    # --- End Added Streaming Methods ---