import logging
import os
import random
import time
import traceback

from openai import OpenAI, APITimeoutError, APIConnectionError, RateLimitError, APIStatusError

# This module logs through the root logger (logging.info / logging.warning / ...).
# Configure logging (e.g. logging.basicConfig) in the application entry point.


class AI_Agent:
    """AI agent that talks to an OpenAI-compatible chat-completions endpoint to generate text.

    Offers a blocking API (``generate_text`` / ``work``) that returns the whole response,
    and a streaming API (``generate_text_stream`` / ``work_stream``) that yields text chunks.
    """

    def __init__(self, base_url, model_name, api, timeout=30, max_retries=3):
        """Create the agent and its OpenAI client.

        Args:
            base_url: A provider alias from ``url_list`` ("ali", "kimi", "doubao",
                "deepseek", "vllm") or a literal base URL, used as-is when not an alias.
            model_name: Model identifier sent with every request.
            api: API key for the endpoint.
            timeout: Request timeout in seconds, set on the client and on every call.
            max_retries: Number of retries this class performs after the first attempt.
        """
        self.url_list = {
            "ali": "https://dashscope.aliyuncs.com/compatible-mode/v1",
            "kimi": "https://api.moonshot.cn/v1",
            "doubao": "https://ark.cn-beijing.volces.com/api/v3/",
            "deepseek": "https://api.deepseek.com",
            "vllm": "http://localhost:8000/v1",
        }
        self.base_url = self.url_list.get(base_url, base_url)
        self.api = api
        self.model_name = model_name
        self.timeout = timeout
        self.max_retries = max_retries
        print(f"Initializing AI Agent with base_url={self.base_url}, model={self.model_name}, timeout={self.timeout}s, max_retries={self.max_retries}")
        # The SDK retries some failures internally by default. This class runs its own
        # retry loop, so SDK retries are disabled to keep the total number of attempts
        # equal to max_retries + 1 (as reported in the "try X/Y" log lines).
        self.client = OpenAI(
            api_key=self.api,
            base_url=self.base_url,
            timeout=self.timeout,
            max_retries=0,
        )

    # --- Internal helpers ---

    @staticmethod
    def _backoff_wait(retry_count, max_retry_wait):
        """Return the exponential backoff delay (plus up to 1 s of jitter), capped at max_retry_wait."""
        return min(2 ** retry_count + random.random(), max_retry_wait)

    @staticmethod
    def _is_retriable(error):
        """Return True for transient API errors: timeouts, connection errors, rate limits and HTTP 5xx."""
        if isinstance(error, (APITimeoutError, APIConnectionError, RateLimitError)):
            return True
        return isinstance(error, APIStatusError) and error.status_code >= 500

    @staticmethod
    def _close_stream(response):
        """Best-effort release of the HTTP connection behind a streaming response."""
        close = getattr(response, "close", None)
        if close is None:
            return
        try:
            close()
        except Exception as close_err:
            logging.debug(f"Ignoring error while closing stream: {close_err}")

    @staticmethod
    def _chunk_text(chunk):
        """Return the delta text carried by a stream chunk, or None if it has none."""
        if chunk.choices and chunk.choices[0].delta.content is not None:
            return chunk.choices[0].delta.content
        return None

    def _create_stream(self, system_prompt, user_prompt, temperature, top_p, presence_penalty):
        """Start a streaming chat completion with the agent's fixed generation settings."""
        return self.client.chat.completions.create(
            model=self.model_name,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            temperature=temperature,
            top_p=top_p,
            presence_penalty=presence_penalty,
            stream=True,
            max_tokens=8192,
            timeout=self.timeout,
            # Non-standard sampling parameter forwarded verbatim in the request body;
            # presumably honoured only by some backends (e.g. vLLM) — verify per provider.
            extra_body={"repetition_penalty": 1.05},
        )

    def _collect_stream(self, response):
        """Concatenate the text of a streaming response.

        Returns:
            tuple: ``(text, interrupted)``. ``interrupted`` is True when an exception was
            raised while iterating; ``text`` then holds whatever arrived before it.
        """
        parts = []
        interrupted = False
        try:
            for chunk in response:
                content = self._chunk_text(chunk)
                if content is not None:
                    parts.append(content)
                if chunk.choices and chunk.choices[0].finish_reason == "stop":
                    break
        except Exception as stream_err:
            logging.warning(f"Exception during stream processing: {stream_err}")
            interrupted = True
        finally:
            # Release the connection even when we stop early (finish_reason "stop" or error).
            self._close_stream(response)
        return "".join(parts), interrupted

    def _with_reference_context(self, user_prompt, file_folder):
        """Return user_prompt with the contents of file_folder appended as reference material.

        The prompt is returned unchanged when no folder is given or nothing could be read.
        """
        if not file_folder:
            return user_prompt
        logging.info(f"Reading context from folder: {file_folder}")
        context = self.read_folder(file_folder)
        if not context:
            logging.warning(f"Folder {file_folder} provided but no content read.")
            return user_prompt
        return f"{user_prompt.strip()}\n\n--- 参考资料 ---\n{context.strip()}"

    # --- Blocking API ---

    def generate_text(self, system_prompt, user_prompt, temperature, top_p, presence_penalty):
        """Generate a complete response and return it with a rough token estimate.

        The request is streamed internally and the chunks are concatenated. Transient
        failures are retried with exponential backoff, up to ``self.max_retries`` times.

        Returns:
            tuple: ``(text, estimated_tokens)``. On success ``text`` is the model output.
            It may be a partial output when the stream broke after more than 100
            characters had arrived. On failure ``text`` is one of the fixed messages
            starting with "请求失败" and ``estimated_tokens`` is 0.
        """
        logging.info(f"Generating text with model: {self.model_name}, temp={temperature}, top_p={top_p}, presence_penalty={presence_penalty}")
        logging.debug(f"System Prompt (first 100 chars): {system_prompt[:100]}...")
        logging.debug(f"User Prompt (first 100 chars): {user_prompt[:100]}...")
        # Random 0-1 s delay before the first call, presumably to spread out concurrent callers.
        time.sleep(random.random())
        retry_count = 0
        max_retry_wait = 10
        while retry_count <= self.max_retries:
            logging.info(f"Attempting API call (try {retry_count + 1}/{self.max_retries + 1})")
            try:
                response = self._create_stream(system_prompt, user_prompt, temperature, top_p, presence_penalty)
            except (APITimeoutError, APIConnectionError, RateLimitError, APIStatusError) as e:
                logging.warning(f"API Error occurred: {e}")
                if not self._is_retriable(e):
                    logging.error(f"Non-retriable API error: {e}. Aborting.")
                    return "请求失败,发生不可重试错误。", 0
                retry_count += 1
                if retry_count > self.max_retries:
                    logging.error(f"Max retries ({self.max_retries}) reached for API errors. Aborting.")
                    return "请求失败,无法生成内容。", 0
                wait_time = self._backoff_wait(retry_count, max_retry_wait)
                logging.warning(f"Retrying API call ({retry_count}/{self.max_retries}) after error, waiting {wait_time:.2f}s...")
                time.sleep(wait_time)
                continue
            except Exception:
                logging.exception("Unexpected error during API call setup/execution:")
                retry_count += 1
                if retry_count > self.max_retries:
                    logging.error(f"Max retries ({self.max_retries}) reached after unexpected errors. Aborting.")
                    return "请求失败,发生未知错误。", 0
                wait_time = self._backoff_wait(retry_count, max_retry_wait)
                logging.warning(f"Retrying API call ({retry_count}/{self.max_retries}) after unexpected error, waiting {wait_time:.2f}s...")
                time.sleep(wait_time)
                continue

            full_response, interrupted = self._collect_stream(response)
            if interrupted and len(full_response) > 100:
                logging.warning(f"Stream interrupted, but received {len(full_response)} characters. Using partial content.")
                interrupted = False
            if not interrupted:
                logging.info("Text generation completed.")
                # NOTE(review): whitespace word count * 1.3 is a rough heuristic; text
                # without spaces (e.g. Chinese) yields very few "words" and underestimates.
                estimated_tokens = len(full_response.split()) * 1.3
                return full_response, estimated_tokens

            retry_count += 1
            if retry_count <= self.max_retries:
                wait_time = self._backoff_wait(retry_count, max_retry_wait)
                logging.warning(f"Stream error/timeout. Waiting {wait_time:.2f}s before retry ({retry_count}/{self.max_retries})...")
                time.sleep(wait_time)

        # Reached only when stream errors used up every attempt (or max_retries < 0):
        # report failure instead of passing a short/empty partial off as a result.
        logging.error(f"Max retries ({self.max_retries}) reached after stream errors. Aborting.")
        return "请求失败,无法生成内容。", 0

    def read_folder(self, file_folder):
        """Concatenate every regular file directly inside file_folder, read as UTF-8 text.

        Each file contributes a "文件名: <name>" header line, then its content, then two
        newlines. Files are visited in sorted name order so the prompt is deterministic.
        A file that cannot be read or decoded is logged and skipped entirely.
        Returns "" when the folder does not exist or cannot be listed.
        """
        if not os.path.exists(file_folder):
            logging.warning(f"Referenced folder does not exist: {file_folder}")
            return ""
        try:
            names = sorted(os.listdir(file_folder))
        except Exception as list_err:
            logging.error(f"Failed to list directory {file_folder}: {list_err}")
            return ""
        sections = []
        for file in names:
            file_path = os.path.join(file_folder, file)
            if not os.path.isfile(file_path):
                continue
            try:
                with open(file_path, "r", encoding="utf-8") as f:
                    # Read before emitting the header so a decode failure leaves no orphan header.
                    text = f.read()
            except Exception as read_err:
                logging.error(f"Failed to read file {file_path}: {read_err}")
                continue
            sections.append(f"文件名: {file}\n{text}\n\n")
        return "".join(sections)

    def work(self, system_prompt, user_prompt, file_folder, temperature, top_p, presence_penalty):
        """Run the full workflow: add folder context to the prompt, generate, and time it.

        Returns:
            tuple: ``(result, estimated_tokens, elapsed_seconds)``.
        """
        logging.info(f"Starting 'work' process. File folder: {file_folder}")
        user_prompt = self._with_reference_context(user_prompt, file_folder)
        time_start = time.time()
        result, tokens = self.generate_text(system_prompt, user_prompt, temperature, top_p, presence_penalty)
        time_cost = time.time() - time_start
        logging.info(f"'work' completed in {time_cost:.2f}s. Estimated tokens: {tokens}")
        return result, tokens, time_cost

    def close(self):
        """Close the underlying OpenAI client and drop the reference. Safe to call repeatedly."""
        client = getattr(self, "client", None)
        self.client = None
        if client is None:
            return
        logging.info("Closing AI Agent and releasing client resources.")
        try:
            client.close()
        except Exception as e:
            logging.error(f"Error during AI Agent close: {e}")

    # --- Streaming API ---

    def generate_text_stream(self, system_prompt, user_prompt, temperature, top_p, presence_penalty):
        """Generate text and yield it chunk by chunk.

        Failures before any content has been yielded are retried with exponential
        backoff. Once content has been yielded, a connection error is not retried,
        because a new request would restart the answer and duplicate text. Errors are
        reported in-band as a final bracketed marker string (e.g. "[STREAM_ERROR: ...]",
        "[API_ERROR: ...]", "[FATAL_ERROR: ...]") rather than raised.
        """
        logging.info("Streaming Generation Started...")
        logging.debug(f"Streaming System Prompt (first 100 chars): {system_prompt[:100]}...")
        logging.debug(f"Streaming User Prompt (first 100 chars): {user_prompt[:100]}...")
        logging.info(f"Streaming Params: temp={temperature}, top_p={top_p}, presence_penalty={presence_penalty}")
        retry_count = 0
        max_retry_wait = 10
        while retry_count <= self.max_retries:
            logging.info(f"Attempting API stream call (try {retry_count + 1}/{self.max_retries + 1})")
            try:
                response = self._create_stream(system_prompt, user_prompt, temperature, top_p, presence_penalty)
            except (APITimeoutError, APIConnectionError, RateLimitError, APIStatusError) as e:
                logging.warning(f"API Error occurred: {e}")
                if not self._is_retriable(e):
                    logging.error(f"Non-retriable API error: {e}. Aborting stream.")
                    yield f"[API_ERROR: Non-retriable status {e.status_code if isinstance(e, APIStatusError) else 'Unknown'}]"
                    return
                retry_count += 1
                if retry_count > self.max_retries:
                    logging.error(f"Max retries ({self.max_retries}) reached for API errors. Aborting stream.")
                    yield "[API_ERROR: Max retries reached]"
                    return
                wait_time = self._backoff_wait(retry_count, max_retry_wait)
                logging.warning(f"Retrying API call ({retry_count}/{self.max_retries}) after error, waiting {wait_time:.2f}s...")
                time.sleep(wait_time)
                continue
            except Exception as e:
                logging.exception("Non-retriable error occurred during API call setup:")
                yield f"[FATAL_ERROR: {e}]"
                return

            yielded_something = False
            try:
                logging.info("Stream connected, receiving content...")
                for chunk in response:
                    content = self._chunk_text(chunk)
                    if content is not None:
                        yield content
                        yielded_something = True
            except APIConnectionError as stream_err:
                logging.warning(f"Stream connection error occurred: {stream_err}")
                if yielded_something:
                    # A retry would stream the answer again from the start and the
                    # consumer would receive duplicated text, so stop here instead.
                    logging.error("Stream connection lost after content was yielded; not retrying.")
                    yield f"[STREAM_ERROR: {stream_err}]"
                    return
                retry_count += 1
                if retry_count > self.max_retries:
                    logging.error("Max retries reached after stream connection error.")
                    yield f"[STREAM_ERROR: Max retries reached after connection error: {stream_err}]"
                    return
                wait_time = self._backoff_wait(retry_count, max_retry_wait)
                logging.warning(f"Retrying connection ({retry_count}/{self.max_retries}), waiting {wait_time:.2f}s...")
                time.sleep(wait_time)
                continue
            except Exception as stream_err:
                logging.exception("Error occurred during stream processing:")
                yield f"[STREAM_ERROR: {stream_err}]"
                return
            finally:
                # Also runs when the consumer stops iterating early (generator close).
                self._close_stream(response)

            if yielded_something:
                logging.info("Stream finished successfully.")
            else:
                logging.warning("Stream finished, but no content was yielded.")
            return

        # Only reachable when no attempt is made (max_retries < 0).
        logging.error("Stream generation failed after exhausting all retries.")
        yield "[ERROR: Failed after all retries]"

    def work_stream(self, system_prompt, user_prompt, file_folder, temperature, top_p, presence_penalty):
        """Streaming workflow: add folder context to the prompt and return the text-chunk generator."""
        logging.info(f"Starting 'work_stream' process. File folder: {file_folder}")
        user_prompt = self._with_reference_context(user_prompt, file_folder)
        logging.info("Calling generate_text_stream...")
        return self.generate_text_stream(system_prompt, user_prompt, temperature, top_p, presence_penalty)