Added a streaming-output work mode to ai_agent
parent b6899f87f3
commit 59312ed23a
Binary file not shown.
core/ai_agent.py | 125
@@ -1,7 +1,8 @@
 import os
-from openai import OpenAI
+from openai import OpenAI, APITimeoutError, APIConnectionError, RateLimitError, APIStatusError
 import time
 import random
+import traceback


 class AI_Agent():
@@ -153,4 +154,124 @@ class AI_Agent():
     def close(self):
         self.client.close()
         ## del self.client
         del self
+
+    # --- Added Streaming Methods ---
+    def generate_text_stream(self, system_prompt, user_prompt, temperature, top_p, presence_penalty):
+        """Generate text content, yielding text chunks as a generator."""
+        print("Streaming Generation Started...")
+        print("System Prompt:", system_prompt[:100] + "..." if len(system_prompt) > 100 else system_prompt)  # Print truncated prompts for logs
+        print("User Prompt:", user_prompt[:100] + "..." if len(user_prompt) > 100 else user_prompt)
+        print(f"Params: temp={temperature}, top_p={top_p}, presence_penalty={presence_penalty}")
+
+        retry_count = 0
+        max_retry_wait = 10  # Max backoff wait time
+
+        while retry_count <= self.max_retries:
+            try:
+                response = self.client.chat.completions.create(
+                    model=self.model_name,
+                    messages=[{"role": "system", "content": system_prompt},
+                              {"role": "user", "content": user_prompt}],
+                    temperature=temperature,
+                    top_p=top_p,
+                    presence_penalty=presence_penalty,
+                    stream=True,
+                    max_tokens=8192,  # Or make configurable?
+                    timeout=self.timeout,  # Use configured timeout for initial connect/response
+                    extra_body={"repetition_penalty": 1.05},  # Keep if needed
+                )
+
+                # Inner try-except specifically for handling errors during the stream
+                try:
+                    print("Stream connected, receiving content...")
+                    for chunk in response:
+                        if chunk.choices and len(chunk.choices) > 0 and chunk.choices[0].delta.content is not None:
+                            content = chunk.choices[0].delta.content
+                            yield content  # Yield the content chunk
+                        # Check for finish reason if needed, but loop termination handles it
+                        # if chunk.choices and len(chunk.choices) > 0 and chunk.choices[0].finish_reason == "stop":
+                        #     print("\nStream finished (stop reason).")
+                        #     break  # Exit inner loop
+
+                    print("\nStream finished successfully.")
+                    return  # Generator successfully exhausted
+
+                except APIConnectionError as stream_err:  # Catch connection errors during the stream
+                    print(f"\nStream connection error occurred: {stream_err}")
+                    # Decide if retryable based on type or context
+                    retry_count += 1
+                    if retry_count <= self.max_retries:
+                        wait_time = min(2 ** retry_count + random.random(), max_retry_wait)
+                        print(f"Retrying connection ({retry_count}/{self.max_retries}), waiting {wait_time:.2f}s...")
+                        time.sleep(wait_time)
+                        continue  # Continue outer loop to retry the whole API call
+                    else:
+                        print("Max retries reached after stream connection error.")
+                        yield f"[STREAM_ERROR: Max retries reached after connection error: {stream_err}]"  # Yield error info
+                        return  # Stop generator
+
+                except Exception as stream_err:  # Catch other errors during stream processing
+                    print(f"\nError occurred during stream processing: {stream_err}")
+                    traceback.print_exc()
+                    yield f"[STREAM_ERROR: {stream_err}]"  # Yield error info
+                    return  # Stop generator
+
+            except (APITimeoutError, APIConnectionError, RateLimitError) as e:  # Catch specific retriable API errors
+                print(f"\nRetriable API error occurred: {e}")
+                retry_count += 1
+                if retry_count <= self.max_retries:
+                    wait_time = min(2 ** retry_count + random.random(), max_retry_wait)
+                    print(f"Retrying API call ({retry_count}/{self.max_retries}), waiting {wait_time:.2f}s...")
+                    time.sleep(wait_time)
+                    continue  # Continue outer loop
+                else:
+                    print("Max retries reached for API errors.")
+                    yield "[API_ERROR: Max retries reached]"
+                    return  # Stop generator
+
+            except APIStatusError as e:  # Handle 5xx server errors specifically if possible
+                print(f"\nAPI Status Error: {e.status_code} - {e.response}")
+                if e.status_code >= 500:  # Typically retry on 5xx
+                    retry_count += 1
+                    if retry_count <= self.max_retries:
+                        wait_time = min(2 ** retry_count + random.random(), max_retry_wait)
+                        print(f"Retrying API call ({retry_count}/{self.max_retries}) after server error, waiting {wait_time:.2f}s...")
+                        time.sleep(wait_time)
+                        continue
+                    else:
+                        print("Max retries reached after server error.")
+                        yield f"[API_ERROR: Max retries reached after server error {e.status_code}]"
+                        return
+                else:  # Don't retry on non-5xx status errors (like 4xx)
+                    print("Non-retriable API status error.")
+                    yield f"[API_ERROR: Non-retriable status {e.status_code}]"
+                    return
+
+            except Exception as e:  # Catch other non-retriable errors during setup/call
+                print(f"\nNon-retriable error occurred: {e}")
+                traceback.print_exc()
+                yield f"[FATAL_ERROR: {e}]"
+                return  # Stop generator
+
+        # Reached only if all retries failed without returning or yielding an error
+        print("\nStream generation failed after exhausting all retries.")
+        yield "[ERROR: Failed after all retries]"
+
+
+    def work_stream(self, system_prompt, user_prompt, file_folder, temperature, top_p, presence_penalty):
+        """Streaming version of the work flow: returns a text generator."""
+        # If a reference folder is provided, read its contents
+        if file_folder:
+            print(f"Reading context from folder: {file_folder}")
+            context = self.read_folder(file_folder)
+            if context:
+                # Append context carefully
+                user_prompt = f"{user_prompt.strip()}\n\n--- 参考资料 ---\n{context.strip()}"
+            else:
+                print(f"Warning: Folder {file_folder} provided but no content read.")
+
+        # Return the generator from generate_text_stream directly
+        print("Calling generate_text_stream...")
+        return self.generate_text_stream(system_prompt, user_prompt, temperature, top_p, presence_penalty)
+    # --- End Added Streaming Methods ---
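
Every retry path above computes its wait with the same capped exponential backoff plus jitter, min(2 ** retry_count + random.random(), max_retry_wait); the random component keeps concurrent clients from retrying in lockstep. A standalone sketch of the resulting schedule (the max_retries default of 5 here is an illustrative assumption, since the diff reads it from self.max_retries):

import random

def backoff_schedule(max_retries=5, max_retry_wait=10.0):
    """Print the wait times produced by the backoff formula used in the diff."""
    for retry_count in range(1, max_retries + 1):
        # 2, 4, 8, 16, ... seconds plus up to 1s of jitter, capped at max_retry_wait
        wait_time = min(2 ** retry_count + random.random(), max_retry_wait)
        print(f"retry {retry_count}: wait {wait_time:.2f}s")

backoff_schedule()
# Roughly: retry 1 waits ~2-3s, retry 2 ~4-5s, retry 3 ~8-9s, retries 4 and 5 are capped at 10s.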
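
Because work_stream hands back the generator unconsumed, the caller drives the stream and must watch for the bracketed error markers that generate_text_stream yields in place of raising exceptions. A minimal consumption sketch; the AI_Agent constructor call and the prompt values are hypothetical, since the constructor's signature is outside this diff:

from core.ai_agent import AI_Agent

agent = AI_Agent()  # hypothetical constructor call; the real signature is not shown in this diff

chunks = []
for chunk in agent.work_stream(
    system_prompt="You are a helpful writing assistant.",
    user_prompt="Summarize the chapter in three sentences.",
    file_folder=None,   # or a folder path whose files become reference material
    temperature=0.7,
    top_p=0.9,
    presence_penalty=0.0,
):
    # Errors arrive inline as bracketed marker strings, not exceptions
    if chunk.startswith(("[STREAM_ERROR", "[API_ERROR", "[FATAL_ERROR", "[ERROR")):
        print(f"\nGeneration failed: {chunk}")
        break
    print(chunk, end="", flush=True)  # render incrementally as the model streams
    chunks.append(chunk)

agent.close()
result = "".join(chunks)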