149 lines
6.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
from openai import OpenAI
import time
import random
class AI_Agent:
    """Client for an OpenAI-compatible chat-completions endpoint.

    The agent streams a completion for a system/user prompt pair, retries
    with capped exponential backoff on failure, and can prepend the contents
    of a reference folder to the user prompt.
    """

    # Short provider aliases accepted as ``base_url`` in ``__init__``.
    URL_LIST = {
        "ali": "https://dashscope.aliyuncs.com/compatible-mode/v1",
        "kimi": "https://api.moonshot.cn/v1",
        "doubao": "https://ark.cn-beijing.volces.com/api/v3/",
        "deepseek": "https://api.deepseek.com",
        "vllm": "http://localhost:8000/v1",
    }
    # Directory used for the suggested result path when ``work`` gets none.
    DEFAULT_RESULT_DIR = "/root/autodl-tmp/xhsTweetGene/result"
    # Upper bound (seconds) on the backoff delay between two attempts.
    MAX_RETRY_WAIT = 10
    # An interrupted stream is only kept if it is longer than this many chars.
    MIN_PARTIAL_CHARS = 100
    # Text returned in place of a completion once every attempt has failed.
    FAILURE_MESSAGE = "请求失败,无法生成内容。"

    def __init__(self, base_url, model_name, api, timeout=10, max_retries=3):
        """Create the agent and its underlying OpenAI client.

        Args:
            base_url: Either a key of ``URL_LIST`` (e.g. ``"kimi"``) or a
                full endpoint URL, which is used as given.
            model_name: Model identifier sent with every request.
            api: API key for the endpoint.
            timeout: Timeout in seconds, applied to the client and to each
                request.
            max_retries: Number of retries after the first failed attempt.
        """
        # Kept as an instance attribute because callers may read it.
        self.url_list = dict(self.URL_LIST)
        self.base_url = self.url_list.get(base_url, base_url)
        self.api = api
        self.model_name = model_name
        self.timeout = timeout
        self.max_retries = max_retries
        self.client = OpenAI(
            api_key=self.api,
            base_url=self.base_url,
            timeout=self.timeout,
        )

    @staticmethod
    def _mask_secret(secret):
        """Return a printable form of *secret* that does not reveal it."""
        if not secret:
            return "(未设置)"
        secret = str(secret)
        if len(secret) <= 12:
            # Too short to show any part of it safely.
            return "*" * len(secret)
        return f"{secret[:4]}...{secret[-4:]}"

    @staticmethod
    def _estimate_tokens(text):
        """Roughly estimate the token count of *text*.

        Streaming responses are consumed without usage data, so this is a
        heuristic: each CJK ideograph counts as one token and every remaining
        whitespace-separated word as 1.3 tokens. Text without ideographs gets
        exactly ``len(text.split()) * 1.3``. The real count depends on the
        model's tokenizer.
        """
        ideographs = 0
        remainder = []
        for char in text:
            if "\u4e00" <= char <= "\u9fff":
                ideographs += 1
                # Ideographs also act as word boundaries for the rest.
                remainder.append(" ")
            else:
                remainder.append(char)
        return ideographs + len("".join(remainder).split()) * 1.3

    def _collect_stream(self, response):
        """Echo and accumulate a streamed completion.

        Returns the accumulated text, or ``None`` when the stream broke
        before more than ``MIN_PARTIAL_CHARS`` characters had arrived.
        """
        pieces = []
        try:
            for chunk in response:
                if not chunk.choices:
                    continue
                choice = chunk.choices[0]
                content = choice.delta.content
                if content is not None:
                    pieces.append(content)
                    print(content, end="", flush=True)
                if choice.finish_reason == "stop":
                    break
        except Exception as exc:
            # ``Exception`` rather than a bare ``except`` so Ctrl-C and
            # interpreter shutdown are not swallowed by the retry loop.
            partial = "".join(pieces)
            print(f"\n接收响应时出错: {exc}")
            if partial:
                print(f"已接收部分响应({len(partial)}字符)")
            if len(partial) > self.MIN_PARTIAL_CHARS:
                print("使用已接收的部分内容继续处理")
                return partial
            return None
        return "".join(pieces)

    def generate_text(self, system_prompt, user_prompt, temperature, top_p, presence_penalty):
        """Generate a completion and return it with a token estimate.

        Args:
            system_prompt: Content of the system message.
            user_prompt: Content of the user message.
            temperature: Sampling temperature forwarded to the API.
            top_p: Nucleus-sampling value forwarded to the API.
            presence_penalty: Accepted for interface compatibility; it is
                not sent to the API.

        Returns:
            ``(text, estimated_tokens)`` on success. ``(FAILURE_MESSAGE, 0)``
            once every attempt has failed, whether the request itself or the
            streaming of its response failed.
        """
        print("系统提示词:")
        print(system_prompt)
        print("\n用户提示词:")
        print(user_prompt)
        # Only a masked key is printed: stdout commonly ends up in logs.
        print(f"\nAPI Key: {self._mask_secret(self.api)}")
        print(f"Base URL: {self.base_url}")
        print(f"Model: {self.model_name}")

        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]
        retry_count = 0
        while retry_count <= self.max_retries:
            try:
                response = self.client.chat.completions.create(
                    model=self.model_name,
                    messages=messages,
                    temperature=temperature,
                    top_p=top_p,
                    stream=True,
                    max_tokens=8192,
                    timeout=self.timeout,
                    extra_body={
                        "repetition_penalty": 1.05,
                    },
                )
            except Exception as exc:
                print(f"\n请求发生错误: {exc}")
            else:
                text = self._collect_stream(response)
                if text is not None:
                    print("\n完成生成,正在处理结果...")
                    return text, self._estimate_tokens(text)

            retry_count += 1
            if retry_count <= self.max_retries:
                # Exponential backoff with jitter, capped at MAX_RETRY_WAIT.
                wait_time = min(2 ** retry_count + random.random(), self.MAX_RETRY_WAIT)
                print(f"\n等待 {wait_time:.2f} 秒后重试({retry_count}/{self.max_retries})...")
                time.sleep(wait_time)

        print(f"已达到最大重试次数({self.max_retries}),放弃请求")
        return self.FAILURE_MESSAGE, 0

    def read_folder(self, file_folder):
        """Concatenate the UTF-8 text files directly inside *file_folder*.

        Each file contributes a ``文件名: <name>`` header line, its content
        and a trailing newline. Files are taken in sorted name order so the
        result is deterministic. Sub-directories are ignored, and files that
        cannot be read or decoded are skipped with a notice.

        Returns:
            The combined text, or ``""`` when *file_folder* is not an
            existing directory.
        """
        if not os.path.isdir(file_folder):
            return ""

        sections = []
        for name in sorted(os.listdir(file_folder)):
            file_path = os.path.join(file_folder, name)
            if not os.path.isfile(file_path):
                continue
            try:
                with open(file_path, "r", encoding="utf-8") as handle:
                    body = handle.read()
            except (UnicodeDecodeError, OSError) as exc:
                # One unreadable file should not discard the whole folder.
                print(f"跳过无法读取的文件 {name}: {exc}")
                continue
            sections.append(f"文件名: {name}\n{body}\n")
        return "".join(sections)

    def work(self, system_prompt, user_prompt, file_folder, temperature, top_p,
             presence_penalty, result_dir=None):
        """Run the full workflow: gather references, generate, and time it.

        Args:
            system_prompt: Content of the system message.
            user_prompt: Content of the user message; reference material
                read from *file_folder* is appended to it.
            file_folder: Optional folder of reference files; falsy to skip.
            temperature: Forwarded to ``generate_text``.
            top_p: Forwarded to ``generate_text``.
            presence_penalty: Forwarded to ``generate_text``.
            result_dir: Directory for the suggested result path. Defaults
                to ``DEFAULT_RESULT_DIR``.

        Returns:
            ``(result, system_prompt, user_prompt, file_folder, result_file,
            tokens, time_cost)``. ``user_prompt`` is the prompt actually
            sent. ``result_file`` is only a timestamped path; this method
            does not write it.
        """
        date_time = time.strftime("%Y-%m-%d_%H-%M-%S")
        if result_dir is None:
            result_file = f"{self.DEFAULT_RESULT_DIR}/{date_time}.md"
        else:
            result_file = os.path.join(result_dir, f"{date_time}.md")

        if file_folder:
            context = self.read_folder(file_folder)
            if context:
                user_prompt = f"{user_prompt}\n参考资料:\n{context}"

        time_start = time.time()
        result, tokens = self.generate_text(
            system_prompt, user_prompt, temperature, top_p, presence_penalty
        )
        time_cost = time.time() - time_start

        return result, system_prompt, user_prompt, file_folder, result_file, tokens, time_cost