import os
import time

from openai import OpenAI


class AI_Agent:
    """AI agent class responsible for interacting with an AI model to generate text."""

    def __init__(self, base_url, model_name, api):
        # Map short provider aliases to their OpenAI-compatible endpoints.
        self.url_list = {
            "ali": "https://dashscope.aliyuncs.com/compatible-mode/v1",
            "kimi": "https://api.moonshot.cn/v1",
            "doubao": "https://ark.cn-beijing.volces.com/api/v3/",
            "deepseek": "https://api.deepseek.com",
            "vllm": "http://localhost:8000/v1",
        }
        # Accept either a known provider alias or a full custom base URL.
        self.base_url = self.url_list.get(base_url, base_url)
        self.api = api
        self.model_name = model_name
        self.client = OpenAI(
            api_key=self.api,
            base_url=self.base_url,
            # timeout=10
        )
    def generate_text(self, system_prompt, user_prompt, temperature, top_p, presence_penalty):
        """Generate text and return the full response plus an estimated token count."""
        print("System prompt:")
        print(system_prompt)
        print("\nUser prompt:")
        print(user_prompt)
        print(f"\nAPI Key: {self.api}")
        print(f"Base URL: {self.base_url}")
        print(f"Model: {self.model_name}")
        response = self.client.chat.completions.create(
            model=self.model_name,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            temperature=temperature,
            top_p=top_p,
            # presence_penalty=presence_penalty,
            stream=True,
            max_tokens=8192,
            extra_body={
                "repetition_penalty": 1.05,
            },
        )
        # Collect the complete streamed output.
        full_response = ""
        for chunk in response:
            if chunk.choices and chunk.choices[0].delta.content is not None:
                content = chunk.choices[0].delta.content
                full_response += content
                print(content, end="", flush=True)
            if chunk.choices and chunk.choices[0].finish_reason == "stop":
                break
        print("\nGeneration finished, processing results...")
        # Streaming responses do not expose the real token count, so return an estimate.
        estimated_tokens = len(full_response.split()) * 1.3  # rough token-count estimate
        return full_response, estimated_tokens
    def read_folder(self, file_folder):
        """Read the contents of every file in the given folder."""
        if not os.path.exists(file_folder):
            return ""
        context = ""
        for file in os.listdir(file_folder):
            file_path = os.path.join(file_folder, file)
            if os.path.isfile(file_path):
                with open(file_path, "r", encoding="utf-8") as f:
                    context += f"File name: {file}\n"
                    context += f.read()
                    context += "\n"
        return context
    def work(self, system_prompt, user_prompt, file_folder, temperature, top_p, presence_penalty):
        """Full workflow: build the prompt, generate text, and return the results."""
        # Timestamp used to name the result file.
        date_time = time.strftime("%Y-%m-%d_%H-%M-%S")
        result_file = f"/root/autodl-tmp/xhsTweetGene/result/{date_time}.md"
        # If a reference folder is provided, append its contents to the user prompt.
        if file_folder:
            context = self.read_folder(file_folder)
            if context:
                user_prompt = f"{user_prompt}\nReference material:\n{context}"
        # Time the generation step.
        time_start = time.time()
        result, tokens = self.generate_text(system_prompt, user_prompt, temperature, top_p, presence_penalty)
        time_end = time.time()
        time_cost = time_end - time_start
        return result, system_prompt, user_prompt, file_folder, result_file, tokens, time_cost
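

# Example usage: a minimal sketch of how AI_Agent might be driven. The base-URL alias,
# model name, API key, prompts, and sampling parameters below are illustrative
# placeholders, not values taken from this project's configuration.
if __name__ == "__main__":
    agent = AI_Agent(
        base_url="vllm",               # alias resolved via self.url_list
        model_name="your-model-name",  # placeholder model identifier
        api="sk-your-api-key",         # placeholder API key
    )
    text, _system, _user, _folder, result_file, tokens, time_cost = agent.work(
        system_prompt="You are a helpful writing assistant.",
        user_prompt="Write a short post about morning coffee.",
        file_folder="",                # empty string skips the reference-folder step
        temperature=0.7,
        top_p=0.9,
        presence_penalty=0.0,
    )
    print(f"\nEstimated tokens: {tokens:.0f}, time: {time_cost:.1f}s, result file: {result_file}")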