277 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
from openai import OpenAI, APITimeoutError, APIConnectionError, RateLimitError, APIStatusError
import time
import random
import traceback
class AI_Agent:
    """Agent wrapping an OpenAI-compatible chat API for text generation."""

    def __init__(self, base_url, model_name, api, timeout=30, max_retries=3):
        """Configure the endpoint and build the underlying OpenAI client.

        Args:
            base_url: a known provider alias (see `url_list`) or a literal
                endpoint URL; unknown aliases are used verbatim as the URL.
            model_name: model identifier passed on every request.
            api: API key for the endpoint.
            timeout: per-request timeout in seconds.
            max_retries: retry budget used by the generation methods.
        """
        # Provider alias -> endpoint URL. An unrecognized `base_url` is
        # treated as the endpoint itself via dict.get()'s fallback.
        self.url_list = {
            "ali": "https://dashscope.aliyuncs.com/compatible-mode/v1",
            "kimi": "https://api.moonshot.cn/v1",
            "doubao": "https://ark.cn-beijing.volces.com/api/v3/",
            "deepseek": "https://api.deepseek.com",
            "vllm": "http://localhost:8000/v1",
        }
        self.base_url = self.url_list.get(base_url, base_url)
        self.api = api
        self.model_name = model_name
        self.timeout = timeout
        self.max_retries = max_retries
        print(f"Initializing AI Agent with timeout={self.timeout}s, max_retries={self.max_retries}")
        self.client = OpenAI(api_key=self.api, base_url=self.base_url, timeout=self.timeout)
def generate_text(self, system_prompt, user_prompt, temperature, top_p, presence_penalty):
"""生成文本内容并返回完整响应和token估计值"""
print("系统提示词:")
print(system_prompt)
print("\n用户提示词:")
print(user_prompt)
print(f"\nAPI Key: {self.api}")
print(f"Base URL: {self.base_url}")
print(f"Model: {self.model_name}")
time.sleep(random.random())
retry_count = 0
max_retry_wait = 10 # 最大重试等待时间(秒)
while retry_count <= self.max_retries:
try:
response = self.client.chat.completions.create(
model=self.model_name,
messages=[{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt}],
temperature=temperature,
top_p=top_p,
presence_penalty=presence_penalty,
stream=True,
max_tokens=8192,
timeout=self.timeout, # 设置请求超时
extra_body={
"repetition_penalty": 1.05,
},
)
# 收集完整的输出内容
full_response = ""
try:
for chunk in response:
if chunk.choices and len(chunk.choices) > 0 and chunk.choices[0].delta.content is not None:
content = chunk.choices[0].delta.content
full_response += content
print(content, end="", flush=True)
if chunk.choices and len(chunk.choices) > 0 and chunk.choices[0].finish_reason == "stop":
break
# 成功完成,跳出重试循环
break
except:
# 处理流式响应中的超时
print(f"\n接收响应时超时")
if len(full_response) > 0:
print(f"已接收部分响应({len(full_response)}字符)")
# 如果已接收足够内容,可以考虑使用已有内容
if len(full_response) > 100: # 假设至少需要100个字符才有意义
print("使用已接收的部分内容继续处理")
break
# 否则准备重试
retry_count += 1
if retry_count <= self.max_retries:
wait_time = min(2 ** retry_count + random.random(), max_retry_wait) # 指数退避
print(f"\n等待 {wait_time:.2f} 秒后重试({retry_count}/{self.max_retries})...")
time.sleep(wait_time)
continue
except Exception as e:
print(f"\n请求发生错误: {e}")
retry_count += 1
if retry_count <= self.max_retries:
wait_time = min(2 ** retry_count + random.random(), max_retry_wait) # 指数退避
print(f"\n等待 {wait_time:.2f} 秒后重试({retry_count}/{self.max_retries})...")
time.sleep(wait_time)
else:
print(f"已达到最大重试次数({self.max_retries}),放弃请求")
return "请求失败,无法生成内容。", 0
print("\n完成生成,正在处理结果...")
# 由于使用流式输出无法获取真实的token计数因此返回估计值
estimated_tokens = len(full_response.split()) * 1.3 # 简单估算token数量
return full_response, estimated_tokens
def read_folder(self, file_folder):
"""读取指定文件夹下的所有文件内容"""
if not os.path.exists(file_folder):
return ""
context = ""
for file in os.listdir(file_folder):
file_path = os.path.join(file_folder, file)
if os.path.isfile(file_path):
with open(file_path, "r", encoding="utf-8") as f:
context += f"文件名: {file}\n"
for line in f.readlines():
context += line
context += "\n"
return context
def work(self, system_prompt, user_prompt, file_folder, temperature, top_p, presence_penalty):
"""完整的工作流程:生成文本并返回结果"""
# 生成时间戳
date_time = time.strftime("%Y-%m-%d_%H-%M-%S")
result_file = f"/root/autodl-tmp/xhsTweetGene/result/{date_time}.md"
# 如果提供了参考文件夹,则读取其内容
if file_folder:
context = self.read_folder(file_folder)
if context:
user_prompt = f"{user_prompt}\n参考资料:\n{context}"
# 计时
time_start = time.time()
# 生成文本
result, tokens = self.generate_text(system_prompt, user_prompt, temperature, top_p, presence_penalty)
# 计算耗时
time_end = time.time()
time_cost = time_end - time_start
return result, system_prompt, user_prompt, file_folder, result_file, tokens, time_cost
def close(self):
self.client.close()
## del self.client
del self
    # --- Added Streaming Methods ---
    def generate_text_stream(self, system_prompt, user_prompt, temperature, top_p, presence_penalty):
        """Generate text and yield it chunk-by-chunk as a generator.

        Yields:
            str: incremental chunks of model output as they arrive. On
            failure a single bracketed marker string ("[STREAM_ERROR: ...]",
            "[API_ERROR: ...]" or "[FATAL_ERROR: ...]") is yielded before
            the generator stops, so consumers must treat such markers as
            errors, not model text.

        Retriable failures (timeout/connection/rate-limit errors and 5xx
        statuses) are retried up to self.max_retries times with capped
        exponential backoff; other status errors (e.g. 4xx) are not retried.
        """
        print("Streaming Generation Started...")
        print("System Prompt:", system_prompt[:100] + "..." if len(system_prompt) > 100 else system_prompt) # Print truncated prompts for logs
        print("User Prompt:", user_prompt[:100] + "..." if len(user_prompt) > 100 else user_prompt)
        print(f"Params: temp={temperature}, top_p={top_p}, presence_penalty={presence_penalty}")
        retry_count = 0
        max_retry_wait = 10 # Max backoff wait time (seconds)
        while retry_count <= self.max_retries:
            try:
                response = self.client.chat.completions.create(
                    model=self.model_name,
                    messages=[{"role": "system", "content": system_prompt},
                              {"role": "user", "content": user_prompt}],
                    temperature=temperature,
                    top_p=top_p,
                    presence_penalty=presence_penalty,
                    stream=True,
                    max_tokens=8192, # Or make configurable?
                    timeout=self.timeout, # Use configured timeout for initial connect/response
                    extra_body={"repetition_penalty": 1.05}, # Keep if needed
                )
                # Inner try-except specifically for handling errors during the stream
                try:
                    print("Stream connected, receiving content...")
                    for chunk in response:
                        if chunk.choices and len(chunk.choices) > 0 and chunk.choices[0].delta.content is not None:
                            content = chunk.choices[0].delta.content
                            yield content # Yield the content chunk
                        # Check for finish reason if needed, but loop termination handles it
                        # if chunk.choices and len(chunk.choices) > 0 and chunk.choices[0].finish_reason == "stop":
                        #     print("\nStream finished (stop reason).")
                        #     break # Exit inner loop
                    print("\nStream finished successfully.")
                    return # Generator successfully exhausted
                except APIConnectionError as stream_err: # Catch connection errors during stream
                    print(f"\nStream connection error occurred: {stream_err}")
                    # NOTE(review): retrying here re-issues the WHOLE request, so any
                    # chunks already yielded before the drop will be produced again by
                    # the fresh stream — consumers may observe duplicated text.
                    retry_count += 1
                    if retry_count <= self.max_retries:
                        wait_time = min(2 ** retry_count + random.random(), max_retry_wait)
                        print(f"Retrying connection ({retry_count}/{self.max_retries}), waiting {wait_time:.2f}s...")
                        time.sleep(wait_time)
                        continue # Continue outer loop to retry the whole API call
                    else:
                        print("Max retries reached after stream connection error.")
                        yield f"[STREAM_ERROR: Max retries reached after connection error: {stream_err}]" # Yield error info
                        return # Stop generator
                except Exception as stream_err: # Catch other errors during stream processing
                    print(f"\nError occurred during stream processing: {stream_err}")
                    traceback.print_exc()
                    yield f"[STREAM_ERROR: {stream_err}]" # Yield error info
                    return # Stop generator
            except (APITimeoutError, APIConnectionError, RateLimitError) as e: # Catch specific retriable API errors
                print(f"\nRetriable API error occurred: {e}")
                retry_count += 1
                if retry_count <= self.max_retries:
                    wait_time = min(2 ** retry_count + random.random(), max_retry_wait)
                    print(f"Retrying API call ({retry_count}/{self.max_retries}), waiting {wait_time:.2f}s...")
                    time.sleep(wait_time)
                    continue # Continue outer loop
                else:
                    print("Max retries reached for API errors.")
                    yield "[API_ERROR: Max retries reached]"
                    return # Stop generator
            except APIStatusError as e: # Non-2xx HTTP status; retry only server-side (5xx) errors
                print(f"\nAPI Status Error: {e.status_code} - {e.response}")
                if e.status_code >= 500: # Typically retry on 5xx
                    retry_count += 1
                    if retry_count <= self.max_retries:
                        wait_time = min(2 ** retry_count + random.random(), max_retry_wait)
                        print(f"Retrying API call ({retry_count}/{self.max_retries}) after server error, waiting {wait_time:.2f}s...")
                        time.sleep(wait_time)
                        continue
                    else:
                        print("Max retries reached after server error.")
                        yield f"[API_ERROR: Max retries reached after server error {e.status_code}]"
                        return
                else: # Don't retry on non-5xx status errors (like 4xx)
                    print("Non-retriable API status error.")
                    yield f"[API_ERROR: Non-retriable status {e.status_code}]"
                    return
            except Exception as e: # Catch other non-retriable errors during setup/call
                print(f"\nNon-retriable error occurred: {e}")
                traceback.print_exc()
                yield f"[FATAL_ERROR: {e}]"
                return # Stop generator
        # NOTE(review): every retry-exhaustion branch above yields an error and
        # returns before the loop condition can fail, so this tail appears
        # unreachable — kept as a defensive fallback.
        print("\nStream generation failed after exhausting all retries.")
        yield "[ERROR: Failed after all retries]"
def work_stream(self, system_prompt, user_prompt, file_folder, temperature, top_p, presence_penalty):
"""工作流程的流式版本:返回文本生成器"""
# 如果提供了参考文件夹,则读取其内容
if file_folder:
print(f"Reading context from folder: {file_folder}")
context = self.read_folder(file_folder)
if context:
# Append context carefully
user_prompt = f"{user_prompt.strip()}\n\n--- 参考资料 ---\n{context.strip()}"
else:
print(f"Warning: Folder {file_folder} provided but no content read.")
# 直接返回 generate_text_stream 的生成器
print("Calling generate_text_stream...")
return self.generate_text_stream(system_prompt, user_prompt, temperature, top_p, presence_penalty)
# --- End Added Streaming Methods ---