Added a streaming-output work mode to ai_agent

This commit is contained in:
jinye_huang 2025-04-22 16:29:51 +08:00
parent b6899f87f3
commit 59312ed23a
2 changed files with 123 additions and 2 deletions


@@ -1,7 +1,8 @@
import os
from openai import OpenAI, APITimeoutError, APIConnectionError, RateLimitError, APIStatusError
import time
import random
import traceback

class AI_Agent():
@@ -153,4 +154,124 @@ class AI_Agent():
    def close(self):
        self.client.close()
        ## del self.client
        del self

    # --- Added Streaming Methods ---
    def generate_text_stream(self, system_prompt, user_prompt, temperature, top_p, presence_penalty):
        """Generate text content and yield it chunk by chunk as a generator."""
        print("Streaming Generation Started...")
        # Print truncated prompts for logs
        print("System Prompt:", system_prompt[:100] + "..." if len(system_prompt) > 100 else system_prompt)
        print("User Prompt:", user_prompt[:100] + "..." if len(user_prompt) > 100 else user_prompt)
        print(f"Params: temp={temperature}, top_p={top_p}, presence_penalty={presence_penalty}")
        retry_count = 0
        max_retry_wait = 10  # Max backoff wait time in seconds
        while retry_count <= self.max_retries:
            try:
                response = self.client.chat.completions.create(
                    model=self.model_name,
                    messages=[{"role": "system", "content": system_prompt},
                              {"role": "user", "content": user_prompt}],
                    temperature=temperature,
                    top_p=top_p,
                    presence_penalty=presence_penalty,
                    stream=True,
                    max_tokens=8192,  # Or make configurable?
                    timeout=self.timeout,  # Use configured timeout for initial connect/response
                    extra_body={"repetition_penalty": 1.05},  # Keep if needed
                )
                # Inner try-except specifically for handling errors during the stream
                try:
                    print("Stream connected, receiving content...")
                    for chunk in response:
                        if chunk.choices and len(chunk.choices) > 0 and chunk.choices[0].delta.content is not None:
                            content = chunk.choices[0].delta.content
                            yield content  # Yield the content chunk
                        # Check for finish reason if needed, but loop termination handles it
                        # if chunk.choices and len(chunk.choices) > 0 and chunk.choices[0].finish_reason == "stop":
                        #     print("\nStream finished (stop reason).")
                        #     break  # Exit inner loop
                    print("\nStream finished successfully.")
                    return  # Generator successfully exhausted
                except APIConnectionError as stream_err:  # Catch connection errors during the stream
                    print(f"\nStream connection error occurred: {stream_err}")
                    # Decide if retryable based on type or context
                    retry_count += 1
                    if retry_count <= self.max_retries:
                        wait_time = min(2 ** retry_count + random.random(), max_retry_wait)
                        print(f"Retrying connection ({retry_count}/{self.max_retries}), waiting {wait_time:.2f}s...")
                        time.sleep(wait_time)
                        continue  # Continue outer loop to retry the whole API call
                    else:
                        print("Max retries reached after stream connection error.")
                        yield f"[STREAM_ERROR: Max retries reached after connection error: {stream_err}]"  # Yield error info
                        return  # Stop generator
                except Exception as stream_err:  # Catch other errors during stream processing
                    print(f"\nError occurred during stream processing: {stream_err}")
                    traceback.print_exc()
                    yield f"[STREAM_ERROR: {stream_err}]"  # Yield error info
                    return  # Stop generator
            except (APITimeoutError, APIConnectionError, RateLimitError) as e:  # Catch specific retriable API errors
                print(f"\nRetriable API error occurred: {e}")
                retry_count += 1
                if retry_count <= self.max_retries:
                    wait_time = min(2 ** retry_count + random.random(), max_retry_wait)
                    print(f"Retrying API call ({retry_count}/{self.max_retries}), waiting {wait_time:.2f}s...")
                    time.sleep(wait_time)
                    continue  # Continue outer loop
                else:
                    print("Max retries reached for API errors.")
                    yield "[API_ERROR: Max retries reached]"
                    return  # Stop generator
            except APIStatusError as e:  # Handle 5xx server errors specifically if possible
                print(f"\nAPI Status Error: {e.status_code} - {e.response}")
                if e.status_code >= 500:  # Typically retry on 5xx
                    retry_count += 1
                    if retry_count <= self.max_retries:
                        wait_time = min(2 ** retry_count + random.random(), max_retry_wait)
                        print(f"Retrying API call ({retry_count}/{self.max_retries}) after server error, waiting {wait_time:.2f}s...")
                        time.sleep(wait_time)
                        continue
                    else:
                        print("Max retries reached after server error.")
                        yield f"[API_ERROR: Max retries reached after server error {e.status_code}]"
                        return
                else:  # Don't retry on non-5xx status errors (like 4xx)
                    print("Non-retriable API status error.")
                    yield f"[API_ERROR: Non-retriable status {e.status_code}]"
                    return
            except Exception as e:  # Catch other non-retriable errors during setup/call
                print(f"\nNon-retriable error occurred: {e}")
                traceback.print_exc()
                yield f"[FATAL_ERROR: {e}]"
                return  # Stop generator
        # Safety net: reached only if the loop exits without returning or yielding an error
        print("\nStream generation failed after exhausting all retries.")
        yield "[ERROR: Failed after all retries]"

    def work_stream(self, system_prompt, user_prompt, file_folder, temperature, top_p, presence_penalty):
        """Streaming version of the work flow: returns a text generator."""
        # If a reference folder was provided, read its contents into the prompt
        if file_folder:
            print(f"Reading context from folder: {file_folder}")
            context = self.read_folder(file_folder)
            if context:
                # Append the context under a "参考资料" (reference material) header
                user_prompt = f"{user_prompt.strip()}\n\n--- 参考资料 ---\n{context.strip()}"
            else:
                print(f"Warning: Folder {file_folder} provided but no content read.")
        # Return the generator from generate_text_stream directly
        print("Calling generate_text_stream...")
        return self.generate_text_stream(system_prompt, user_prompt, temperature, top_p, presence_penalty)
    # --- End Added Streaming Methods ---
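
Both retry paths in generate_text_stream use the same capped exponential backoff with jitter. As a quick illustration of the schedule that formula produces (a standalone sketch, not part of the commit; max_retries = 5 is a hypothetical configuration, and only the standard random module is needed):

import random

max_retry_wait = 10  # same cap as in generate_text_stream
for retry_count in range(1, 6):
    # 2**n seconds plus up to 1s of jitter, capped at max_retry_wait
    wait_time = min(2 ** retry_count + random.random(), max_retry_wait)
    print(f"retry {retry_count}: wait ~{wait_time:.2f}s")
# Expected schedule: ~2-3s, ~4-5s, ~8-9s, then capped at 10s for later retries.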
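
Because errors are reported in-band as bracketed sentinel strings rather than raised, a caller has to inspect each yielded chunk. A minimal consumer-side sketch follows (also not part of this commit; it assumes an already-constructed AI_Agent instance, and the prompts and folder path are hypothetical):

def run_streamed_job(agent):
    pieces = []
    stream = agent.work_stream(
        system_prompt="You are a helpful writing assistant.",
        user_prompt="请介绍一下这个产品。",
        file_folder="./reference_docs",  # hypothetical path; pass None or "" to skip context
        temperature=0.7,
        top_p=0.9,
        presence_penalty=0.5,
    )
    for piece in stream:
        # Sentinels look like "[STREAM_ERROR: ...]", "[API_ERROR: ...]",
        # "[FATAL_ERROR: ...]" or "[ERROR: ...]". Normal content could in
        # principle also start with "[", so this check is heuristic by design.
        if piece.startswith("[") and "ERROR" in piece:
            raise RuntimeError(f"Streaming failed: {piece}")
        print(piece, end="", flush=True)  # render incrementally as chunks arrive
        pieces.append(piece)
    return "".join(pieces)

Yielding sentinel strings keeps the generator interface simple for UI code that just concatenates chunks, at the cost of in-band signaling; raising an exception from inside the generator would be the stricter alternative.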