diff --git a/core/__pycache__/ai_agent.cpython-312.pyc b/core/__pycache__/ai_agent.cpython-312.pyc
index 858b742..8b29a68 100644
Binary files a/core/__pycache__/ai_agent.cpython-312.pyc and b/core/__pycache__/ai_agent.cpython-312.pyc differ
diff --git a/core/ai_agent.py b/core/ai_agent.py
index 2b16b24..d34beaf 100644
--- a/core/ai_agent.py
+++ b/core/ai_agent.py
@@ -1,7 +1,8 @@
 import os
-from openai import OpenAI
+from openai import OpenAI, APITimeoutError, APIConnectionError, RateLimitError, APIStatusError
 import time
 import random
+import traceback


 class AI_Agent():
@@ -153,4 +154,124 @@ class AI_Agent():
     def close(self):
         self.client.close()
         ## del self.client
-        del self
\ No newline at end of file
+        del self
+
+    # --- Added Streaming Methods ---
+    def generate_text_stream(self, system_prompt, user_prompt, temperature, top_p, presence_penalty):
+        """Generate text content, yielding it chunk by chunk as a generator."""
+        print("Streaming Generation Started...")
+        print("System Prompt:", system_prompt[:100] + "..." if len(system_prompt) > 100 else system_prompt)  # Print truncated prompts for logs
+        print("User Prompt:", user_prompt[:100] + "..." if len(user_prompt) > 100 else user_prompt)
+        print(f"Params: temp={temperature}, top_p={top_p}, presence_penalty={presence_penalty}")
+
+        retry_count = 0
+        max_retry_wait = 10  # Max backoff wait time (seconds)
+
+        while retry_count <= self.max_retries:
+            try:
+                response = self.client.chat.completions.create(
+                    model=self.model_name,
+                    messages=[{"role": "system", "content": system_prompt},
+                              {"role": "user", "content": user_prompt}],
+                    temperature=temperature,
+                    top_p=top_p,
+                    presence_penalty=presence_penalty,
+                    stream=True,
+                    max_tokens=8192,  # Or make configurable?
+                    timeout=self.timeout,  # Use configured timeout for initial connect/response
+                    extra_body={"repetition_penalty": 1.05},  # Keep if needed
+                )
+
+                # Inner try-except specifically for handling errors during the stream
+                try:
+                    print("Stream connected, receiving content...")
+                    for chunk in response:
+                        if chunk.choices and len(chunk.choices) > 0 and chunk.choices[0].delta.content is not None:
+                            content = chunk.choices[0].delta.content
+                            yield content  # Yield the content chunk
+                        # Check for finish reason if needed, but loop termination handles it
+                        # if chunk.choices and len(chunk.choices) > 0 and chunk.choices[0].finish_reason == "stop":
+                        #     print("\nStream finished (stop reason).")
+                        #     break  # Exit inner loop
+
+                    print("\nStream finished successfully.")
+                    return  # Generator successfully exhausted
+
+                except APIConnectionError as stream_err:  # Catch connection errors during the stream
+                    print(f"\nStream connection error occurred: {stream_err}")
+                    # Decide if retriable based on type or context
+                    retry_count += 1
+                    if retry_count <= self.max_retries:
+                        wait_time = min(2 ** retry_count + random.random(), max_retry_wait)
+                        print(f"Retrying connection ({retry_count}/{self.max_retries}), waiting {wait_time:.2f}s...")
+                        time.sleep(wait_time)
+                        continue  # Continue outer loop to retry the whole API call
+                    else:
+                        print("Max retries reached after stream connection error.")
+                        yield f"[STREAM_ERROR: Max retries reached after connection error: {stream_err}]"  # Yield error info
+                        return  # Stop generator
+
+                except Exception as stream_err:  # Catch other errors during stream processing
+                    print(f"\nError occurred during stream processing: {stream_err}")
+                    traceback.print_exc()
+                    yield f"[STREAM_ERROR: {stream_err}]"  # Yield error info
+                    return  # Stop generator
+
+            except (APITimeoutError, APIConnectionError, RateLimitError) as e:  # Catch specific retriable API errors
+                print(f"\nRetriable API error occurred: {e}")
+                retry_count += 1
+                if retry_count <= self.max_retries:
+                    wait_time = min(2 ** retry_count + random.random(), max_retry_wait)
+                    print(f"Retrying API call ({retry_count}/{self.max_retries}), waiting {wait_time:.2f}s...")
+                    time.sleep(wait_time)
+                    continue  # Continue outer loop
+                else:
+                    print("Max retries reached for API errors.")
+                    yield "[API_ERROR: Max retries reached]"
+                    return  # Stop generator
+
+            except APIStatusError as e:  # Handle 5xx server errors specifically if possible
+                print(f"\nAPI Status Error: {e.status_code} - {e.response}")
+                if e.status_code >= 500:  # Typically retry on 5xx
+                    retry_count += 1
+                    if retry_count <= self.max_retries:
+                        wait_time = min(2 ** retry_count + random.random(), max_retry_wait)
+                        print(f"Retrying API call ({retry_count}/{self.max_retries}) after server error, waiting {wait_time:.2f}s...")
+                        time.sleep(wait_time)
+                        continue
+                    else:
+                        print("Max retries reached after server error.")
+                        yield f"[API_ERROR: Max retries reached after server error {e.status_code}]"
+                        return
+                else:  # Don't retry on non-5xx status errors (like 4xx)
+                    print("Non-retriable API status error.")
+                    yield f"[API_ERROR: Non-retriable status {e.status_code}]"
+                    return
+
+            except Exception as e:  # Catch other non-retriable errors during setup/call
+                print(f"\nNon-retriable error occurred: {e}")
+                traceback.print_exc()
+                yield f"[FATAL_ERROR: {e}]"
+                return  # Stop generator
+
+        # Reached only if all retries failed without returning or yielding an error
+        print("\nStream generation failed after exhausting all retries.")
+        yield "[ERROR: Failed after all retries]"
+
+
+    def work_stream(self, system_prompt, user_prompt, file_folder, temperature, top_p, presence_penalty):
+        """Streaming version of the workflow: returns a text generator."""
+        # If a reference folder was provided, read its contents
+        if file_folder:
+            print(f"Reading context from folder: {file_folder}")
+            context = self.read_folder(file_folder)
+            if context:
+                # Append context carefully
+                user_prompt = f"{user_prompt.strip()}\n\n--- 参考资料 ---\n{context.strip()}"
+            else:
+                print(f"Warning: Folder {file_folder} provided but no content read.")
+
+        # Return the generator from generate_text_stream directly
+        print("Calling generate_text_stream...")
+        return self.generate_text_stream(system_prompt, user_prompt, temperature, top_p, presence_penalty)
+    # --- End Added Streaming Methods ---
\ No newline at end of file
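
Note for reviewers: a minimal consumer sketch of the new streaming API. It assumes an AI_Agent instance configured elsewhere with the model_name, timeout, and max_retries attributes the class already uses; the agent construction, prompts, and sentinel-prefix check below are illustrative, not part of the diff. The key point is that the generator reports failures in-band via bracketed sentinel strings ([STREAM_ERROR: ...], [API_ERROR: ...], [FATAL_ERROR: ...], [ERROR: ...]), so callers must inspect each chunk rather than rely on exceptions.

    from core.ai_agent import AI_Agent

    agent = AI_Agent()  # hypothetical: assumes the existing constructor sets up the client

    stream = agent.work_stream(
        system_prompt="You are a helpful writing assistant.",  # placeholder prompt
        user_prompt="Summarize the plot of Chapter 1.",        # placeholder prompt
        file_folder=None,  # or a folder path to append reference material
        temperature=0.7,
        top_p=0.9,
        presence_penalty=0.0,
    )

    chunks = []
    for chunk in stream:
        # Failures arrive as in-band sentinel strings, not exceptions.
        if chunk.startswith(("[STREAM_ERROR", "[API_ERROR", "[FATAL_ERROR", "[ERROR")):
            print(f"\nGeneration failed: {chunk}")
            break
        chunks.append(chunk)
        print(chunk, end="", flush=True)  # incremental display

    full_text = "".join(chunks)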
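The retry policy in both error paths is exponential backoff with jitter, capped at max_retry_wait. A standalone sketch of the resulting wait schedule, assuming max_retries = 3 (the actual value comes from the agent's configuration):

    import random

    max_retry_wait = 10  # seconds, matching the cap in generate_text_stream
    for retry_count in range(1, 4):  # assuming max_retries = 3
        wait_time = min(2 ** retry_count + random.random(), max_retry_wait)
        print(f"retry {retry_count}: wait ~{wait_time:.2f}s")
    # retry 1: ~2-3s, retry 2: ~4-5s, retry 3: ~8-9s; a 4th retry would clamp at 10s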