import os
from openai import OpenAI, APITimeoutError, APIConnectionError, RateLimitError, APIStatusError
import time
import random
import traceback


class AI_Agent():
    """AI agent class responsible for interacting with an AI model to generate text content."""

    def __init__(self, base_url, model_name, api, timeout=30, max_retries=3):
        self.url_list = {
            "ali": "https://dashscope.aliyuncs.com/compatible-mode/v1",
            "kimi": "https://api.moonshot.cn/v1",
            "doubao": "https://ark.cn-beijing.volces.com/api/v3/",
            "deepseek": "https://api.deepseek.com",
            "vllm": "http://localhost:8000/v1",
        }
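        # `base_url` accepts either a provider alias from `url_list` (e.g.
        # "deepseek") or a full endpoint URL; unrecognized values are passed
        # through to the client unchanged.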
        self.base_url = self.url_list.get(base_url, base_url)
        self.api = api
        self.model_name = model_name
        self.timeout = timeout
        self.max_retries = max_retries

        print(f"Initializing AI Agent with timeout={self.timeout}s, max_retries={self.max_retries}")
        self.client = OpenAI(
            api_key=self.api,
            base_url=self.base_url,
            timeout=self.timeout
        )
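
    # Usage sketch (hypothetical values): resolve "deepseek" through url_list
    # and talk to deepseek-chat with a 60s request timeout:
    #   agent = AI_Agent("deepseek", "deepseek-chat", api="sk-...", timeout=60)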

    def generate_text(self, system_prompt, user_prompt, temperature, top_p, presence_penalty):
        """Generate text content; return the full response and an estimated token count."""
        print("System prompt:")
        print(system_prompt)
        print("\nUser prompt:")
        print(user_prompt)
        # Avoid logging the full API key.
        print(f"\nAPI Key: {self.api[:8]}***")
        print(f"Base URL: {self.base_url}")
        print(f"Model: {self.model_name}")

        # Random jitter (0-1s) to stagger concurrent requests.
        time.sleep(random.random())
|
2025-04-18 11:08:54 +08:00
|
|
|
|
retry_count = 0
|
|
|
|
|
|
max_retry_wait = 10 # 最大重试等待时间(秒)
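        # Backoff schedule (illustrative): with max_retries=3 the waits are
        # roughly 2s, 4s, 8s plus up to 1s of jitter, each capped at
        # max_retry_wait, i.e. min(2 ** retry_count + random.random(), 10).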

        while retry_count <= self.max_retries:
            try:
                response = self.client.chat.completions.create(
                    model=self.model_name,
                    messages=[{"role": "system", "content": system_prompt},
                              {"role": "user", "content": user_prompt}],
                    temperature=temperature,
                    top_p=top_p,
                    presence_penalty=presence_penalty,
                    stream=True,
                    max_tokens=8192,
                    timeout=self.timeout,  # per-request timeout
                    extra_body={
                        "repetition_penalty": 1.05,
                    },
                )

                # Collect the full streamed output.
                full_response = ""
                try:
                    for chunk in response:
                        if chunk.choices and len(chunk.choices) > 0 and chunk.choices[0].delta.content is not None:
                            content = chunk.choices[0].delta.content
                            full_response += content
                            print(content, end="", flush=True)
                        if chunk.choices and len(chunk.choices) > 0 and chunk.choices[0].finish_reason == "stop":
                            break

                    # Finished successfully; exit the retry loop.
                    break
                except Exception:
                    # Handle timeouts and other errors while consuming the stream.
                    print("\nError while receiving the streamed response")
                    if len(full_response) > 0:
                        print(f"Partial response received ({len(full_response)} characters)")
                        # If enough content has already arrived, keep it instead of retrying.
                        if len(full_response) > 100:  # assume at least 100 characters are needed to be useful
                            print("Continuing with the partial content already received")
                            break

                    # Otherwise prepare to retry.
                    retry_count += 1
                    if retry_count <= self.max_retries:
                        wait_time = min(2 ** retry_count + random.random(), max_retry_wait)  # exponential backoff
                        print(f"\nWaiting {wait_time:.2f}s before retry ({retry_count}/{self.max_retries})...")
                        time.sleep(wait_time)
                        continue

            except Exception as e:
                print(f"\nRequest failed with error: {e}")
                retry_count += 1
                if retry_count <= self.max_retries:
                    wait_time = min(2 ** retry_count + random.random(), max_retry_wait)  # exponential backoff
                    print(f"\nWaiting {wait_time:.2f}s before retry ({retry_count}/{self.max_retries})...")
                    time.sleep(wait_time)
                else:
                    print(f"Maximum retries reached ({self.max_retries}); giving up")
                    return "Request failed; no content could be generated.", 0

        print("\nGeneration finished; processing the result...")

        # With streaming output the real token count is unavailable, so return an estimate.
        estimated_tokens = len(full_response.split()) * 1.3  # rough heuristic; undercounts languages without spaces, e.g. Chinese

        return full_response, estimated_tokens

    def read_folder(self, file_folder):
        """Read and concatenate the contents of every file in the given folder."""
        if not os.path.exists(file_folder):
            return ""

        context = ""
        for file in os.listdir(file_folder):
            file_path = os.path.join(file_folder, file)
            if os.path.isfile(file_path):
                with open(file_path, "r", encoding="utf-8") as f:
                    context += f"File name: {file}\n"
                    context += f.read()
                    context += "\n"
        return context
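
    # Illustrative output shape (hypothetical files a.txt, b.txt):
    #   File name: a.txt
    #   <contents of a.txt>
    #
    #   File name: b.txt
    #   <contents of b.txt>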

    def work(self, system_prompt, user_prompt, file_folder, temperature, top_p, presence_penalty):
        """Full workflow: generate text and return the results."""
        # Timestamp for the result file path.
        date_time = time.strftime("%Y-%m-%d_%H-%M-%S")
        result_file = f"/root/autodl-tmp/xhsTweetGene/result/{date_time}.md"

        # If a reference folder is provided, read its contents into the prompt.
        if file_folder:
            context = self.read_folder(file_folder)
            if context:
                user_prompt = f"{user_prompt}\nReference material:\n{context}"

        # Start timing.
        time_start = time.time()

        # Generate the text.
        result, tokens = self.generate_text(system_prompt, user_prompt, temperature, top_p, presence_penalty)

        # Compute elapsed time.
        time_end = time.time()
        time_cost = time_end - time_start

        return result, system_prompt, user_prompt, file_folder, result_file, tokens, time_cost
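
    # Callers unpack the 7-tuple, e.g. (hypothetical call):
    #   (result, sys_p, usr_p, folder, result_file, tokens, cost) = agent.work(...)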

    def close(self):
        self.client.close()

    # --- Added Streaming Methods ---

    def generate_text_stream(self, system_prompt, user_prompt, temperature, top_p, presence_penalty):
        """Generate text content, yielding it chunk by chunk as a generator."""
        print("Streaming Generation Started...")
        # Print truncated prompts for logs.
        print("System Prompt:", (system_prompt[:100] + "...") if len(system_prompt) > 100 else system_prompt)
        print("User Prompt:", (user_prompt[:100] + "...") if len(user_prompt) > 100 else user_prompt)
        print(f"Params: temp={temperature}, top_p={top_p}, presence_penalty={presence_penalty}")

        retry_count = 0
        max_retry_wait = 10  # max backoff wait time (seconds)

        while retry_count <= self.max_retries:
            try:
                response = self.client.chat.completions.create(
                    model=self.model_name,
                    messages=[{"role": "system", "content": system_prompt},
                              {"role": "user", "content": user_prompt}],
                    temperature=temperature,
                    top_p=top_p,
                    presence_penalty=presence_penalty,
                    stream=True,
                    max_tokens=8192,  # or make configurable?
                    timeout=self.timeout,  # use configured timeout for initial connect/response
                    extra_body={"repetition_penalty": 1.05},  # keep if needed
                )

                # Inner try/except specifically for errors raised while consuming the stream.
                try:
                    print("Stream connected, receiving content...")
                    for chunk in response:
                        if chunk.choices and len(chunk.choices) > 0 and chunk.choices[0].delta.content is not None:
                            content = chunk.choices[0].delta.content
                            yield content  # yield the content chunk
                        # No explicit finish_reason check needed; stream exhaustion ends the loop.

                    print("\nStream finished successfully.")
                    return  # generator successfully exhausted

                except APIConnectionError as stream_err:  # connection errors during the stream
                    print(f"\nStream connection error occurred: {stream_err}")
                    # Decide if retryable based on type or context.
                    retry_count += 1
                    if retry_count <= self.max_retries:
                        wait_time = min(2 ** retry_count + random.random(), max_retry_wait)
                        print(f"Retrying connection ({retry_count}/{self.max_retries}), waiting {wait_time:.2f}s...")
                        time.sleep(wait_time)
                        continue  # retry the whole API call via the outer loop
                    else:
                        print("Max retries reached after stream connection error.")
                        yield f"[STREAM_ERROR: Max retries reached after connection error: {stream_err}]"  # yield error info
                        return  # stop generator

                except Exception as stream_err:  # other errors during stream processing
                    print(f"\nError occurred during stream processing: {stream_err}")
                    traceback.print_exc()
                    yield f"[STREAM_ERROR: {stream_err}]"  # yield error info
                    return  # stop generator

            except (APITimeoutError, APIConnectionError, RateLimitError) as e:  # specific retriable API errors
                print(f"\nRetriable API error occurred: {e}")
                retry_count += 1
                if retry_count <= self.max_retries:
                    wait_time = min(2 ** retry_count + random.random(), max_retry_wait)
                    print(f"Retrying API call ({retry_count}/{self.max_retries}), waiting {wait_time:.2f}s...")
                    time.sleep(wait_time)
                    continue  # continue outer loop
                else:
                    print("Max retries reached for API errors.")
                    yield "[API_ERROR: Max retries reached]"
                    return  # stop generator

            except APIStatusError as e:  # handle HTTP status errors; retry only on 5xx
                print(f"\nAPI Status Error: {e.status_code} - {e.response}")
                if e.status_code >= 500:  # typically retry on 5xx
                    retry_count += 1
                    if retry_count <= self.max_retries:
                        wait_time = min(2 ** retry_count + random.random(), max_retry_wait)
                        print(f"Retrying API call ({retry_count}/{self.max_retries}) after server error, waiting {wait_time:.2f}s...")
                        time.sleep(wait_time)
                        continue
                    else:
                        print("Max retries reached after server error.")
                        yield f"[API_ERROR: Max retries reached after server error {e.status_code}]"
                        return
                else:  # don't retry non-5xx status errors (e.g. 4xx)
                    print("Non-retriable API status error.")
                    yield f"[API_ERROR: Non-retriable status {e.status_code}]"
                    return

            except Exception as e:  # other non-retriable errors during setup/call
                print(f"\nNon-retriable error occurred: {e}")
                traceback.print_exc()
                yield f"[FATAL_ERROR: {e}]"
                return  # stop generator

        # Reached only if all retries failed without returning/yielding an error above.
        print("\nStream generation failed after exhausting all retries.")
        yield "[ERROR: Failed after all retries]"

    def work_stream(self, system_prompt, user_prompt, file_folder, temperature, top_p, presence_penalty):
        """Streaming version of the workflow: returns a generator of text chunks."""
        # If a reference folder is provided, read its contents into the prompt.
        if file_folder:
            print(f"Reading context from folder: {file_folder}")
            context = self.read_folder(file_folder)
            if context:
                # Append the context carefully.
                user_prompt = f"{user_prompt.strip()}\n\n--- Reference material ---\n{context.strip()}"
            else:
                print(f"Warning: Folder {file_folder} provided but no content read.")

        # Return the generator from generate_text_stream directly.
        print("Calling generate_text_stream...")
        return self.generate_text_stream(system_prompt, user_prompt, temperature, top_p, presence_penalty)

    # --- End Added Streaming Methods ---
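

# Minimal usage sketch (assumptions: a vLLM-compatible server on
# http://localhost:8000/v1 serving a model named "my-model", and a placeholder
# API key; adjust all three for a real deployment).
if __name__ == "__main__":
    agent = AI_Agent("vllm", "my-model", api="EMPTY", timeout=60)

    # Blocking call: returns the full text plus metadata.
    text, _, _, _, _, tokens, cost = agent.work(
        "You are a helpful assistant.", "Say hello.", None,
        0.7, 0.9, 0.0)
    print(f"\n[~{tokens:.0f} tokens in {cost:.1f}s]")

    # Streaming call: consume chunks as they arrive.
    for piece in agent.work_stream(
            "You are a helpful assistant.", "Count to three.", None,
            0.7, 0.9, 0.0):
        print(piece, end="", flush=True)

    agent.close()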