
import os
from openai import OpenAI, APITimeoutError, APIConnectionError, RateLimitError, APIStatusError
import time
import random
import traceback
import logging

# Configure basic logging for this module (or rely on root logger config)
# logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# logger = logging.getLogger(__name__) # Alternative: use named logger
class AI_Agent():
    """AI agent wrapping an OpenAI-compatible chat-completions endpoint.

    Supports blocking (`generate_text`/`work`) and streaming
    (`generate_text_stream`/`work_stream`) generation with jittered
    exponential backoff on transient API failures, plus optional
    folder-based reference-context injection into the user prompt.
    """

    def __init__(self, base_url, model_name, api, timeout=30, max_retries=3):
        """
        Args:
            base_url: Provider alias (a key of ``self.url_list``) or a full
                base URL; unknown aliases are used verbatim as the URL.
            model_name: Model identifier sent with every request.
            api: API key for the endpoint.
            timeout: Per-request timeout in seconds.
            max_retries: Maximum retry attempts for transient failures.
        """
        # Known provider aliases -> their OpenAI-compatible base URLs.
        self.url_list = {
            "ali": "https://dashscope.aliyuncs.com/compatible-mode/v1",
            "kimi": "https://api.moonshot.cn/v1",
            "doubao": "https://ark.cn-beijing.volces.com/api/v3/",
            "deepseek": "https://api.deepseek.com",
            "vllm": "http://localhost:8000/v1",
        }
        # Fall back to treating `base_url` as a literal URL if it is not an alias.
        self.base_url = self.url_list.get(base_url, base_url)
        self.api = api
        self.model_name = model_name
        self.timeout = timeout
        self.max_retries = max_retries
        # Consistency fix: use logging like the rest of the class (was print()).
        logging.info(
            "Initializing AI Agent with base_url=%s, model=%s, timeout=%ss, max_retries=%s",
            self.base_url, self.model_name, self.timeout, self.max_retries,
        )
        self.client = OpenAI(
            api_key=self.api,
            base_url=self.base_url,
            timeout=self.timeout,
        )

    # ------------------------------------------------------------------ #
    # Internal helpers (shared by blocking and streaming paths)
    # ------------------------------------------------------------------ #
    @staticmethod
    def _is_retriable(error):
        """Return True if `error` is a transient API error worth retrying."""
        if isinstance(error, (APITimeoutError, APIConnectionError, RateLimitError)):
            return True
        # 5xx responses are server-side and may succeed on retry.
        return isinstance(error, APIStatusError) and error.status_code >= 500

    def _backoff(self, retry_count, max_wait=10):
        """Sleep with jittered exponential backoff, capped at `max_wait` seconds."""
        wait_time = min(2 ** retry_count + random.random(), max_wait)
        logging.warning(
            "Waiting %.2fs before retry (%d/%d)...",
            wait_time, retry_count, self.max_retries,
        )
        time.sleep(wait_time)

    def _create_completion(self, system_prompt, user_prompt, temperature, top_p, presence_penalty):
        """Issue one streaming chat-completions request and return the stream."""
        return self.client.chat.completions.create(
            model=self.model_name,
            messages=[{"role": "system", "content": system_prompt},
                      {"role": "user", "content": user_prompt}],
            temperature=temperature,
            top_p=top_p,
            presence_penalty=presence_penalty,
            stream=True,
            max_tokens=8192,
            timeout=self.timeout,
            extra_body={
                "repetition_penalty": 1.05,
            },
        )

    def generate_text(self, system_prompt, user_prompt, temperature, top_p, presence_penalty):
        """Generate text and return ``(full_response, estimated_tokens)``.

        The response is consumed as a stream and accumulated. On
        unrecoverable failure a Chinese error message and 0 are returned.
        The token count is a rough whitespace-split estimate, not exact.
        """
        logging.info(
            "Generating text with model: %s, temp=%s, top_p=%s, presence_penalty=%s",
            self.model_name, temperature, top_p, presence_penalty,
        )
        logging.debug("System Prompt (first 100 chars): %s...", system_prompt[:100])
        logging.debug("User Prompt (first 100 chars): %s...", user_prompt[:100])

        # Small random delay to desynchronize concurrent callers.
        time.sleep(random.random())

        retry_count = 0
        full_response = ""
        succeeded = False
        while retry_count <= self.max_retries:
            try:
                logging.info("Attempting API call (try %d/%d)", retry_count + 1, self.max_retries + 1)
                response = self._create_completion(
                    system_prompt, user_prompt, temperature, top_p, presence_penalty)
                full_response = ""
                stream_failed = False
                try:
                    for chunk in response:
                        if chunk.choices and len(chunk.choices) > 0 and chunk.choices[0].delta.content is not None:
                            full_response += chunk.choices[0].delta.content
                        if chunk.choices and len(chunk.choices) > 0 and chunk.choices[0].finish_reason == "stop":
                            break
                    succeeded = True  # stream consumed without error
                    break
                except Exception as stream_err:
                    logging.warning(f"Exception during stream processing: {stream_err}")
                    stream_failed = True

                if stream_failed:
                    if len(full_response) > 100:
                        # A sizeable partial answer is better than nothing.
                        logging.warning(
                            "Stream interrupted, but received %d characters. Using partial content.",
                            len(full_response),
                        )
                        succeeded = True
                        break
                    retry_count += 1
                    if retry_count <= self.max_retries:
                        self._backoff(retry_count)
                        continue
            except (APITimeoutError, APIConnectionError, RateLimitError, APIStatusError) as e:
                logging.warning(f"API Error occurred: {e}")
                if self._is_retriable(e):
                    retry_count += 1
                    if retry_count <= self.max_retries:
                        self._backoff(retry_count)
                    else:
                        logging.error(f"Max retries ({self.max_retries}) reached for API errors. Aborting.")
                        return "请求失败,无法生成内容。", 0
                else:
                    logging.error(f"Non-retriable API error: {e}. Aborting.")
                    return "请求失败,发生不可重试错误。", 0
            except Exception:
                logging.exception("Unexpected error during API call setup/execution:")
                retry_count += 1
                if retry_count <= self.max_retries:
                    self._backoff(retry_count)
                else:
                    logging.error(f"Max retries ({self.max_retries}) reached after unexpected errors. Aborting.")
                    return "请求失败,发生未知错误。", 0

        if not succeeded and not full_response:
            # Bug fix: the original fell through here after exhausting retries on
            # stream errors and returned an empty string as if it had succeeded.
            logging.error("All retries exhausted without receiving any content.")
            return "请求失败,无法生成内容。", 0

        logging.info("Text generation completed.")
        # Rough estimate only; CJK text has few whitespace-separated tokens.
        estimated_tokens = len(full_response.split()) * 1.3
        return full_response, estimated_tokens

    def read_folder(self, file_folder):
        """Concatenate the contents of every regular file in `file_folder`.

        Each file's content is prefixed with a ``文件名:`` header line and
        followed by a blank line. Unreadable files are skipped with an error
        log. Returns "" when the folder does not exist.
        """
        if not os.path.exists(file_folder):
            logging.warning(f"Referenced folder does not exist: {file_folder}")
            return ""
        context = ""
        try:
            for file in os.listdir(file_folder):
                file_path = os.path.join(file_folder, file)
                if os.path.isfile(file_path):
                    try:
                        with open(file_path, "r", encoding="utf-8") as f:
                            content = f.read()
                        # Bug fix: read before appending, so a failed read cannot
                        # leave a dangling "文件名:" header in the context.
                        context += f"文件名: {file}\n"
                        context += content
                        context += "\n\n"
                    except Exception as read_err:
                        logging.error(f"Failed to read file {file_path}: {read_err}")
        except Exception as list_err:
            logging.error(f"Failed to list directory {file_folder}: {list_err}")
        return context

    def work(self, system_prompt, user_prompt, file_folder, temperature, top_p, presence_penalty):
        """Full blocking pipeline: optional folder context, then generation.

        Returns ``(result, estimated_tokens, elapsed_seconds)``.
        """
        logging.info(f"Starting 'work' process. File folder: {file_folder}")
        if file_folder:
            logging.info(f"Reading context from folder: {file_folder}")
            context = self.read_folder(file_folder)
            if context:
                # Append reference material under a Chinese "参考资料" divider.
                user_prompt = f"{user_prompt.strip()}\n\n--- 参考资料 ---\n{context.strip()}"
            else:
                logging.warning(f"Folder {file_folder} provided but no content read.")
        time_start = time.time()
        result, tokens = self.generate_text(system_prompt, user_prompt, temperature, top_p, presence_penalty)
        time_cost = time.time() - time_start
        logging.info(f"'work' completed in {time_cost:.2f}s. Estimated tokens: {tokens}")
        return result, tokens, time_cost

    def close(self):
        """Release the underlying HTTP client's resources."""
        try:
            logging.info("Closing AI Agent.")
            # Bug fix: actually close the client's connection pool instead of
            # only dropping the reference and hoping for garbage collection.
            if self.client is not None and hasattr(self.client, "close"):
                self.client.close()
        except Exception as e:
            logging.error(f"Error during AI Agent close: {e}")
        finally:
            self.client = None

    # --- Streaming Methods ---
    def generate_text_stream(self, system_prompt, user_prompt, temperature, top_p, presence_penalty):
        """Yield text chunks as they arrive from the model.

        On failure, a bracketed sentinel string (e.g. ``[API_ERROR: ...]``) is
        yielded as the final item so consumers can detect the error in-band.
        """
        logging.info("Streaming Generation Started...")
        logging.debug(f"Streaming System Prompt (first 100 chars): {system_prompt[:100]}...")
        logging.debug(f"Streaming User Prompt (first 100 chars): {user_prompt[:100]}...")
        logging.info(f"Streaming Params: temp={temperature}, top_p={top_p}, presence_penalty={presence_penalty}")
        retry_count = 0
        while retry_count <= self.max_retries:
            try:
                logging.info(f"Attempting API stream call (try {retry_count + 1}/{self.max_retries + 1})")
                response = self._create_completion(
                    system_prompt, user_prompt, temperature, top_p, presence_penalty)
                try:
                    logging.info("Stream connected, receiving content...")
                    yielded_something = False
                    for chunk in response:
                        if chunk.choices and len(chunk.choices) > 0 and chunk.choices[0].delta.content is not None:
                            yield chunk.choices[0].delta.content
                            yielded_something = True
                    if yielded_something:
                        logging.info("Stream finished successfully.")
                    else:
                        logging.warning("Stream finished, but no content was yielded.")
                    return
                except APIConnectionError as stream_err:
                    logging.warning(f"Stream connection error occurred: {stream_err}")
                    retry_count += 1
                    if retry_count <= self.max_retries:
                        self._backoff(retry_count)
                        continue
                    logging.error("Max retries reached after stream connection error.")
                    yield f"[STREAM_ERROR: Max retries reached after connection error: {stream_err}]"
                    return
                except Exception as stream_err:
                    # Chunks may already have been yielded, so the stream cannot
                    # be restarted cleanly; surface the error in-band instead.
                    logging.exception("Error occurred during stream processing:")
                    yield f"[STREAM_ERROR: {stream_err}]"
                    return
            except (APITimeoutError, APIConnectionError, RateLimitError, APIStatusError) as e:
                logging.warning(f"API Error occurred: {e}")
                if self._is_retriable(e):
                    retry_count += 1
                    if retry_count <= self.max_retries:
                        self._backoff(retry_count)
                        continue
                    logging.error(f"Max retries ({self.max_retries}) reached for API errors. Aborting stream.")
                    yield "[API_ERROR: Max retries reached]"
                    return
                logging.error(f"Non-retriable API error: {e}. Aborting stream.")
                yield f"[API_ERROR: Non-retriable status {e.status_code if isinstance(e, APIStatusError) else 'Unknown'}]"
                return
            except Exception as e:
                logging.exception("Non-retriable error occurred during API call setup:")
                yield f"[FATAL_ERROR: {e}]"
                return
        logging.error("Stream generation failed after exhausting all retries.")
        yield "[ERROR: Failed after all retries]"

    def work_stream(self, system_prompt, user_prompt, file_folder, temperature, top_p, presence_penalty):
        """Streaming variant of `work`: returns a generator of text chunks."""
        logging.info(f"Starting 'work_stream' process. File folder: {file_folder}")
        if file_folder:
            logging.info(f"Reading context from folder: {file_folder}")
            context = self.read_folder(file_folder)
            if context:
                user_prompt = f"{user_prompt.strip()}\n\n--- 参考资料 ---\n{context.strip()}"
            else:
                logging.warning(f"Folder {file_folder} provided but no content read.")
        logging.info("Calling generate_text_stream...")
        return self.generate_text_stream(system_prompt, user_prompt, temperature, top_p, presence_penalty)
    # --- End Added Streaming Methods ---