Added a streaming-output work mode to ai_agent
parent b6899f87f3
commit 59312ed23a
(A binary file in this commit is not shown.)
core/ai_agent.py (125 changed lines)

@@ -1,7 +1,8 @@
 import os
-from openai import OpenAI
+from openai import OpenAI, APITimeoutError, APIConnectionError, RateLimitError, APIStatusError
 import time
 import random
+import traceback
 
 
 class AI_Agent():
@@ -153,4 +154,124 @@ class AI_Agent():
     def close(self):
         self.client.close()
         ## del self.client
         del self
+
+    # --- Added Streaming Methods ---
+    def generate_text_stream(self, system_prompt, user_prompt, temperature, top_p, presence_penalty):
+        """Generate text and yield it chunk by chunk as a generator."""
+        print("Streaming Generation Started...")
+        print("System Prompt:", system_prompt[:100] + "..." if len(system_prompt) > 100 else system_prompt)  # Print truncated prompts for logs
+        print("User Prompt:", user_prompt[:100] + "..." if len(user_prompt) > 100 else user_prompt)
+        print(f"Params: temp={temperature}, top_p={top_p}, presence_penalty={presence_penalty}")
+
+        retry_count = 0
+        max_retry_wait = 10  # Max backoff wait time
+
+        while retry_count <= self.max_retries:
+            try:
+                response = self.client.chat.completions.create(
+                    model=self.model_name,
+                    messages=[{"role": "system", "content": system_prompt},
+                              {"role": "user", "content": user_prompt}],
+                    temperature=temperature,
+                    top_p=top_p,
+                    presence_penalty=presence_penalty,
+                    stream=True,
+                    max_tokens=8192,  # Or make configurable?
+                    timeout=self.timeout,  # Use configured timeout for initial connect/response
+                    extra_body={"repetition_penalty": 1.05},  # Keep if needed
+                )
+
+                # Inner try-except specifically for handling errors during the stream
+                try:
+                    print("Stream connected, receiving content...")
+                    for chunk in response:
+                        if chunk.choices and len(chunk.choices) > 0 and chunk.choices[0].delta.content is not None:
+                            content = chunk.choices[0].delta.content
+                            yield content  # Yield the content chunk
+                        # Check for finish reason if needed, but loop termination handles it
+                        # if chunk.choices and len(chunk.choices) > 0 and chunk.choices[0].finish_reason == "stop":
+                        #     print("\nStream finished (stop reason).")
+                        #     break  # Exit inner loop
+
+                    print("\nStream finished successfully.")
+                    return  # Generator successfully exhausted
+
+                except APIConnectionError as stream_err:  # Catch connection errors during the stream
+                    print(f"\nStream connection error occurred: {stream_err}")
+                    # Decide if retryable based on type or context
+                    retry_count += 1
+                    if retry_count <= self.max_retries:
+                        wait_time = min(2 ** retry_count + random.random(), max_retry_wait)
+                        print(f"Retrying connection ({retry_count}/{self.max_retries}), waiting {wait_time:.2f}s...")
+                        time.sleep(wait_time)
+                        continue  # Continue outer loop to retry the whole API call
+                    else:
+                        print("Max retries reached after stream connection error.")
+                        yield f"[STREAM_ERROR: Max retries reached after connection error: {stream_err}]"  # Yield error info
+                        return  # Stop generator
+
+                except Exception as stream_err:  # Catch other errors during stream processing
+                    print(f"\nError occurred during stream processing: {stream_err}")
+                    traceback.print_exc()
+                    yield f"[STREAM_ERROR: {stream_err}]"  # Yield error info
+                    return  # Stop generator
+
+            except (APITimeoutError, APIConnectionError, RateLimitError) as e:  # Catch specific retriable API errors
+                print(f"\nRetriable API error occurred: {e}")
+                retry_count += 1
+                if retry_count <= self.max_retries:
+                    wait_time = min(2 ** retry_count + random.random(), max_retry_wait)
+                    print(f"Retrying API call ({retry_count}/{self.max_retries}), waiting {wait_time:.2f}s...")
+                    time.sleep(wait_time)
+                    continue  # Continue outer loop
+                else:
+                    print("Max retries reached for API errors.")
+                    yield "[API_ERROR: Max retries reached]"
+                    return  # Stop generator
+
+            except APIStatusError as e:  # Handle 5xx server errors specifically if possible
+                print(f"\nAPI Status Error: {e.status_code} - {e.response}")
+                if e.status_code >= 500:  # Typically retry on 5xx
+                    retry_count += 1
+                    if retry_count <= self.max_retries:
+                        wait_time = min(2 ** retry_count + random.random(), max_retry_wait)
+                        print(f"Retrying API call ({retry_count}/{self.max_retries}) after server error, waiting {wait_time:.2f}s...")
+                        time.sleep(wait_time)
+                        continue
+                    else:
+                        print("Max retries reached after server error.")
+                        yield f"[API_ERROR: Max retries reached after server error {e.status_code}]"
+                        return
+                else:  # Don't retry on non-5xx status errors (like 4xx)
+                    print("Non-retriable API status error.")
+                    yield f"[API_ERROR: Non-retriable status {e.status_code}]"
+                    return
+
+            except Exception as e:  # Catch other non-retriable errors during setup/call
+                print(f"\nNon-retriable error occurred: {e}")
+                traceback.print_exc()
+                yield f"[FATAL_ERROR: {e}]"
+                return  # Stop generator
+
+        # This part is reached only if all retries failed without returning/yielding an error
+        print("\nStream generation failed after exhausting all retries.")
+        yield "[ERROR: Failed after all retries]"
+
+    def work_stream(self, system_prompt, user_prompt, file_folder, temperature, top_p, presence_penalty):
+        """Streaming version of the work flow: returns a text generator."""
+        # If a reference folder is provided, read its contents
+        if file_folder:
+            print(f"Reading context from folder: {file_folder}")
+            context = self.read_folder(file_folder)
+            if context:
+                # Append context carefully ("参考资料" = reference material)
+                user_prompt = f"{user_prompt.strip()}\n\n--- 参考资料 ---\n{context.strip()}"
+            else:
+                print(f"Warning: Folder {file_folder} provided but no content read.")
+
+        # Return the generator from generate_text_stream directly
+        print("Calling generate_text_stream...")
+        return self.generate_text_stream(system_prompt, user_prompt, temperature, top_p, presence_penalty)
+    # --- End Added Streaming Methods ---
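
Not part of the commit, for context only: a minimal consumer sketch for the new streaming mode. It assumes an AI_Agent instance already configured (client, model_name, max_retries, timeout) the way the rest of core/ai_agent.py sets it up; the constructor call and prompts below are hypothetical placeholders.

from core.ai_agent import AI_Agent

agent = AI_Agent()  # hypothetical: construct/configure as the project does elsewhere

full_text = []
for chunk in agent.work_stream(
    system_prompt="You are a helpful writing assistant.",
    user_prompt="Write a short paragraph about autumn.",
    file_folder=None,  # or a folder path whose contents get appended as reference material
    temperature=0.7,
    top_p=0.9,
    presence_penalty=0.0,
):
    # Failures arrive in-band as bracketed sentinel chunks such as
    # "[STREAM_ERROR: ...]", "[API_ERROR: ...]", "[FATAL_ERROR: ...]".
    # A prefix check is enough for a sketch (real content could also
    # start with "[", so production code may want stricter markers).
    if chunk.startswith(("[STREAM_ERROR", "[API_ERROR", "[FATAL_ERROR", "[ERROR")):
        print(f"\nstream reported: {chunk}")
        break
    full_text.append(chunk)
    print(chunk, end="", flush=True)  # incremental display

result = "".join(full_text)  # the complete generated text on success

Two design notes. Errors are yielded rather than raised, so a display loop can show partial output plus the failure without exception handling across the generator boundary, at the cost of consumers having to recognize the sentinels as above. The retry backoff is min(2 ** retry_count + random.random(), max_retry_wait): waits of roughly 2 s, 4 s, 8 s, then capped at 10 s.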