282 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
from openai import OpenAI, APITimeoutError, APIConnectionError, RateLimitError, APIStatusError
import time
import random
import traceback
import logging
# Configure basic logging for this module (or rely on root logger config)
# logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# logger = logging.getLogger(__name__) # Alternative: use named logger
class AI_Agent():
    """Thin wrapper around an OpenAI-compatible chat-completions client.

    The agent sends a system prompt and a user prompt to the configured model
    (always with ``stream=True``), retries failed calls with a capped
    exponential back-off, and can optionally append the contents of a folder
    of UTF-8 text files to the user prompt as reference material.
    """

    # Short provider aliases accepted for ``base_url`` and the endpoint each expands to.
    KNOWN_ENDPOINTS = {
        "ali": "https://dashscope.aliyuncs.com/compatible-mode/v1",
        "kimi": "https://api.moonshot.cn/v1",
        "doubao": "https://ark.cn-beijing.volces.com/api/v3/",
        "deepseek": "https://api.deepseek.com",
        "vllm": "http://localhost:8000/v1",
    }
    # Request parameters sent with every completion call.
    MAX_TOKENS = 8192
    REPETITION_PENALTY = 1.05
    # Upper bound (seconds) for a single back-off sleep between retries.
    MAX_RETRY_WAIT = 10
    # An interrupted stream is accepted as the result only when it already
    # produced more than this many characters; otherwise the call is retried.
    PARTIAL_RESPONSE_MIN_CHARS = 100

    def __init__(self, base_url, model_name, api, timeout=30, max_retries=3):
        """Create the agent and its underlying client.

        Args:
            base_url: Either one of the aliases in ``KNOWN_ENDPOINTS`` or a
                full endpoint URL (unknown values are used verbatim).
            model_name: Model identifier passed to the completion endpoint.
            api: API key handed to the client.
            timeout: Timeout in seconds, given both to the client and to each
                individual request.
            max_retries: Number of retries after the first attempt.
        """
        # Per-instance copy so callers that read or modify ``url_list`` keep working.
        self.url_list = dict(self.KNOWN_ENDPOINTS)
        self.base_url = self.url_list.get(base_url, base_url)
        self.api = api
        self.model_name = model_name
        self.timeout = timeout
        self.max_retries = max_retries
        # Logged (not printed) to stay consistent with the rest of the class.
        logging.info(f"Initializing AI Agent with base_url={self.base_url}, model={self.model_name}, timeout={self.timeout}s, max_retries={self.max_retries}")
        self.client = OpenAI(
            api_key=self.api,
            base_url=self.base_url,
            timeout=self.timeout
        )

    # --- Internal helpers ---
    def _open_stream(self, system_prompt, user_prompt, temperature, top_p, presence_penalty):
        """Issue the streaming chat-completion request and return the chunk iterator."""
        return self.client.chat.completions.create(
            model=self.model_name,
            messages=[{"role": "system", "content": system_prompt},
                      {"role": "user", "content": user_prompt}],
            temperature=temperature,
            top_p=top_p,
            presence_penalty=presence_penalty,
            stream=True,
            max_tokens=self.MAX_TOKENS,
            timeout=self.timeout,
            extra_body={"repetition_penalty": self.REPETITION_PENALTY},
        )

    def _backoff_delay(self, retry_count):
        """Return the sleep time before retry number ``retry_count``.

        Exponential (2 ** retry_count) plus up to one second of random jitter,
        capped at ``MAX_RETRY_WAIT`` seconds.
        """
        return min(2 ** retry_count + random.random(), self.MAX_RETRY_WAIT)

    @staticmethod
    def _is_retriable(error):
        """Return True for timeouts, connection errors, rate limits and 5xx statuses."""
        if isinstance(error, (APITimeoutError, APIConnectionError, RateLimitError)):
            return True
        return isinstance(error, APIStatusError) and error.status_code >= 500

    def _build_user_prompt(self, user_prompt, file_folder):
        """Append the folder's file contents to ``user_prompt`` as reference material.

        Returns ``user_prompt`` unchanged when no folder is given or when
        nothing could be read from it.
        """
        if not file_folder:
            return user_prompt
        logging.info(f"Reading context from folder: {file_folder}")
        context = self.read_folder(file_folder)
        if not context:
            logging.warning(f"Folder {file_folder} provided but no content read.")
            return user_prompt
        return f"{user_prompt.strip()}\n\n--- 参考资料 ---\n{context.strip()}"

    # --- Blocking generation ---
    def generate_text(self, system_prompt, user_prompt, temperature, top_p, presence_penalty):
        """Generate a complete response and a rough token estimate.

        The response is consumed as a stream and concatenated. Failed attempts
        are retried up to ``max_retries`` times with capped exponential
        back-off. If a stream breaks after more than
        ``PARTIAL_RESPONSE_MIN_CHARS`` characters were received, the partial
        text is returned instead of retrying.

        Returns:
            ``(text, estimated_tokens)``. On failure ``text`` is a fixed
            error message and ``estimated_tokens`` is ``0``.
        """
        logging.info(f"Generating text with model: {self.model_name}, temp={temperature}, top_p={top_p}, presence_penalty={presence_penalty}")
        logging.debug(f"System Prompt (first 100 chars): {system_prompt[:100]}...")
        logging.debug(f"User Prompt (first 100 chars): {user_prompt[:100]}...")
        # Random start delay of up to one second (presumably to spread out
        # concurrent callers - confirm before removing).
        time.sleep(random.random())

        retry_count = 0
        full_response = None  # stays None until an attempt yields usable text
        while retry_count <= self.max_retries:
            pieces = []
            try:
                logging.info(f"Attempting API call (try {retry_count + 1}/{self.max_retries + 1})")
                response = self._open_stream(system_prompt, user_prompt, temperature, top_p, presence_penalty)
                try:
                    for chunk in response:
                        if not chunk.choices:
                            continue
                        choice = chunk.choices[0]
                        if choice.delta.content is not None:
                            pieces.append(choice.delta.content)
                        if choice.finish_reason == "stop":
                            break
                except Exception as stream_err:
                    logging.warning(f"Exception during stream processing: {stream_err}")
                    partial = "".join(pieces)
                    if len(partial) > self.PARTIAL_RESPONSE_MIN_CHARS:
                        logging.warning(f"Stream interrupted, but received {len(partial)} characters. Using partial content.")
                        full_response = partial
                        break
                    retry_count += 1
                    if retry_count <= self.max_retries:
                        wait_time = self._backoff_delay(retry_count)
                        logging.warning(f"Stream error/timeout. Waiting {wait_time:.2f}s before retry ({retry_count}/{self.max_retries})...")
                        time.sleep(wait_time)
                else:
                    # Stream consumed without error.
                    full_response = "".join(pieces)
                    break
            except (APITimeoutError, APIConnectionError, RateLimitError, APIStatusError) as e:
                logging.warning(f"API Error occurred: {e}")
                if not self._is_retriable(e):
                    logging.error(f"Non-retriable API error: {e}. Aborting.")
                    return "请求失败,发生不可重试错误。", 0
                retry_count += 1
                if retry_count > self.max_retries:
                    logging.error(f"Max retries ({self.max_retries}) reached for API errors. Aborting.")
                    return "请求失败,无法生成内容。", 0
                wait_time = self._backoff_delay(retry_count)
                logging.warning(f"Retrying API call ({retry_count}/{self.max_retries}) after error, waiting {wait_time:.2f}s...")
                time.sleep(wait_time)
            except Exception:
                logging.exception("Unexpected error during API call setup/execution:")
                retry_count += 1
                if retry_count > self.max_retries:
                    logging.error(f"Max retries ({self.max_retries}) reached after unexpected errors. Aborting.")
                    return "请求失败,发生未知错误。", 0
                wait_time = self._backoff_delay(retry_count)
                logging.warning(f"Retrying API call ({retry_count}/{self.max_retries}) after unexpected error, waiting {wait_time:.2f}s...")
                time.sleep(wait_time)

        if full_response is None:
            # Every attempt ended in a stream error with too little text to use.
            # Report failure like the other exhausted-retry paths instead of
            # returning a truncated/empty string as if generation succeeded.
            logging.error(f"Max retries ({self.max_retries}) reached after stream errors. Aborting.")
            return "请求失败,无法生成内容。", 0

        logging.info("Text generation completed.")
        # NOTE: whitespace-based estimate; text written without spaces
        # (e.g. Chinese) counts as very few "words" and is undercounted.
        estimated_tokens = len(full_response.split()) * 1.3
        return full_response, estimated_tokens

    def read_folder(self, file_folder):
        """Return the concatenated contents of all regular files in a folder.

        Files are visited in sorted name order so the result is reproducible.
        Each file contributes a ``文件名: <name>`` header line, its UTF-8 text
        and a blank line. Sub-directories are skipped; files that cannot be
        read are logged and skipped entirely. Returns ``""`` when the folder
        does not exist or cannot be listed.
        """
        if not os.path.exists(file_folder):
            logging.warning(f"Referenced folder does not exist: {file_folder}")
            return ""
        try:
            names = sorted(os.listdir(file_folder))
        except Exception as list_err:
            logging.error(f"Failed to list directory {file_folder}: {list_err}")
            return ""

        sections = []
        for name in names:
            file_path = os.path.join(file_folder, name)
            if not os.path.isfile(file_path):
                continue
            try:
                with open(file_path, "r", encoding="utf-8") as f:
                    # Read before emitting the header so a failed read leaves
                    # no dangling header in the context.
                    body = f.read()
            except Exception as read_err:
                logging.error(f"Failed to read file {file_path}: {read_err}")
                continue
            sections.append(f"文件名: {name}\n{body}\n\n")
        return "".join(sections)

    def work(self, system_prompt, user_prompt, file_folder, temperature, top_p, presence_penalty):
        """Run the full blocking workflow.

        Optionally enriches the user prompt with the contents of
        ``file_folder``, generates the text and measures the elapsed time.

        Returns:
            ``(result, tokens, time_cost)`` where ``time_cost`` is in seconds.
        """
        logging.info(f"Starting 'work' process. File folder: {file_folder}")
        user_prompt = self._build_user_prompt(user_prompt, file_folder)
        time_start = time.time()
        result, tokens = self.generate_text(system_prompt, user_prompt, temperature, top_p, presence_penalty)
        time_cost = time.time() - time_start
        logging.info(f"'work' completed in {time_cost:.2f}s. Estimated tokens: {tokens}")
        return result, tokens, time_cost

    def close(self):
        """Release the client; the agent must not be used for requests afterwards.

        Calls the client's ``close()`` method when it has one, then drops the
        reference. Errors while closing are logged, never raised.
        """
        try:
            logging.info("Closing AI Agent and releasing client resources.")
            client = getattr(self, "client", None)
            self.client = None
            closer = getattr(client, "close", None)
            if callable(closer):
                closer()
        except Exception as e:
            logging.error(f"Error during AI Agent close: {e}")

    # --- Streaming Methods ---
    def generate_text_stream(self, system_prompt, user_prompt, temperature, top_p, presence_penalty):
        """Generate text and yield it chunk by chunk.

        Failures are reported in-band: the generator yields a bracketed marker
        string (``[STREAM_ERROR: ...]``, ``[API_ERROR: ...]``,
        ``[FATAL_ERROR: ...]`` or ``[ERROR: ...]``) and then stops. A broken
        connection is retried only while nothing has been yielded yet, so a
        consumer never receives a partial answer followed by a restarted one.
        """
        logging.info("Streaming Generation Started...")
        logging.debug(f"Streaming System Prompt (first 100 chars): {system_prompt[:100]}...")
        logging.debug(f"Streaming User Prompt (first 100 chars): {user_prompt[:100]}...")
        logging.info(f"Streaming Params: temp={temperature}, top_p={top_p}, presence_penalty={presence_penalty}")

        retry_count = 0
        while retry_count <= self.max_retries:
            try:
                logging.info(f"Attempting API stream call (try {retry_count + 1}/{self.max_retries + 1})")
                response = self._open_stream(system_prompt, user_prompt, temperature, top_p, presence_penalty)
                yielded_something = False
                try:
                    logging.info("Stream connected, receiving content...")
                    for chunk in response:
                        if chunk.choices and chunk.choices[0].delta.content is not None:
                            yield chunk.choices[0].delta.content
                            yielded_something = True
                    if yielded_something:
                        logging.info("Stream finished successfully.")
                    else:
                        logging.warning("Stream finished, but no content was yielded.")
                    return
                except APIConnectionError as stream_err:
                    logging.warning(f"Stream connection error occurred: {stream_err}")
                    if yielded_something:
                        # Restarting now would replay the answer from the
                        # beginning after text the consumer already has.
                        logging.error("Connection lost after partial output was delivered; not retrying.")
                        yield f"[STREAM_ERROR: {stream_err}]"
                        return
                    retry_count += 1
                    if retry_count <= self.max_retries:
                        wait_time = self._backoff_delay(retry_count)
                        logging.warning(f"Retrying connection ({retry_count}/{self.max_retries}), waiting {wait_time:.2f}s...")
                        time.sleep(wait_time)
                        continue
                    logging.error("Max retries reached after stream connection error.")
                    yield f"[STREAM_ERROR: Max retries reached after connection error: {stream_err}]"
                    return
                except Exception as stream_err:
                    logging.exception("Error occurred during stream processing:")
                    yield f"[STREAM_ERROR: {stream_err}]"
                    return
            except (APITimeoutError, APIConnectionError, RateLimitError, APIStatusError) as e:
                logging.warning(f"API Error occurred: {e}")
                if not self._is_retriable(e):
                    logging.error(f"Non-retriable API error: {e}. Aborting stream.")
                    yield f"[API_ERROR: Non-retriable status {e.status_code if isinstance(e, APIStatusError) else 'Unknown'}]"
                    return
                retry_count += 1
                if retry_count <= self.max_retries:
                    wait_time = self._backoff_delay(retry_count)
                    logging.warning(f"Retrying API call ({retry_count}/{self.max_retries}) after error, waiting {wait_time:.2f}s...")
                    time.sleep(wait_time)
                    continue
                logging.error(f"Max retries ({self.max_retries}) reached for API errors. Aborting stream.")
                yield "[API_ERROR: Max retries reached]"
                return
            except Exception as e:
                logging.exception("Non-retriable error occurred during API call setup:")
                yield f"[FATAL_ERROR: {e}]"
                return

        # Only reached when the loop body never ran (e.g. negative max_retries).
        logging.error("Stream generation failed after exhausting all retries.")
        yield "[ERROR: Failed after all retries]"

    def work_stream(self, system_prompt, user_prompt, file_folder, temperature, top_p, presence_penalty):
        """Streaming variant of ``work``: return a generator of text chunks.

        The folder (if any) is read immediately; the API call itself starts
        only once the returned generator is iterated.
        """
        logging.info(f"Starting 'work_stream' process. File folder: {file_folder}")
        user_prompt = self._build_user_prompt(user_prompt, file_folder)
        logging.info("Calling generate_text_stream...")
        return self.generate_text_stream(system_prompt, user_prompt, temperature, top_p, presence_penalty)
# --- End Added Streaming Methods ---