107 lines
4.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
from openai import OpenAI
import time
class AI_Agent:
    """Agent that sends prompts to an OpenAI-compatible chat endpoint and collects the generated text.

    Attributes:
        url_list: Mapping of provider alias -> base URL (a per-instance copy of ``URL_LIST``).
        base_url: The resolved endpoint URL handed to the client.
        api: API key handed to the client.
        model_name: Model identifier sent with every request.
        client: The ``OpenAI`` client used for requests.
    """

    # Known provider aliases; any ``base_url`` not listed here is used verbatim as a URL.
    URL_LIST = {
        "ali": "https://dashscope.aliyuncs.com/compatible-mode/v1",
        "kimi": "https://api.moonshot.cn/v1",
        "doubao": "https://ark.cn-beijing.volces.com/api/v3/",
        "deepseek": "https://api.deepseek.com",
        "vllm": "http://localhost:8000/v1",
    }

    # Code-point ranges counted as roughly one token per character by _estimate_tokens:
    # CJK punctuation, CJK Ext-A, CJK Unified Ideographs, CJK Compatibility Ideographs,
    # and half-/full-width forms (full-width punctuation).
    _CJK_RANGES = (
        (0x3001, 0x303F),
        (0x3400, 0x4DBF),
        (0x4E00, 0x9FFF),
        (0xF900, 0xFAFF),
        (0xFF00, 0xFFEF),
    )

    def __init__(self, base_url, model_name, api):
        """Create the agent and its API client.

        Args:
            base_url: Either a provider alias from ``URL_LIST`` or a full endpoint URL.
            model_name: Model identifier to request.
            api: API key for the endpoint.
        """
        # Copy so that mutating one instance's mapping never leaks into the class constant.
        self.url_list = dict(self.URL_LIST)
        self.base_url = self.url_list.get(base_url, base_url)
        self.api = api
        self.model_name = model_name
        # No explicit timeout is passed, so the OpenAI client's own default timeout applies.
        self.client = OpenAI(
            api_key=self.api,
            base_url=self.base_url,
        )

    @staticmethod
    def _mask_secret(secret):
        """Return a display-safe form of *secret* that reveals at most its last 4 characters."""
        if not secret:
            return "<unset>"
        secret = str(secret)
        if len(secret) <= 8:
            # Too short to reveal any part safely.
            return "*" * len(secret)
        return "****" + secret[-4:]

    @classmethod
    def _is_cjk(cls, ch):
        """Return True if *ch* falls in one of the ``_CJK_RANGES``."""
        code = ord(ch)
        return any(lo <= code <= hi for lo, hi in cls._CJK_RANGES)

    @classmethod
    def _estimate_tokens(cls, text):
        """Roughly estimate the token count of *text*.

        Heuristic only: each CJK character counts as one token, and every remaining
        whitespace-delimited word counts as 1.3 tokens. For text without CJK characters
        this equals ``len(text.split()) * 1.3``.
        """
        cjk_count = 0
        rest = []
        for ch in text:
            if cls._is_cjk(ch):
                cjk_count += 1
                # Replace with a space so neighbouring non-CJK words stay separated.
                rest.append(" ")
            else:
                rest.append(ch)
        return cjk_count + len("".join(rest).split()) * 1.3

    def generate_text(self, system_prompt, user_prompt, temperature, top_p, presence_penalty, max_tokens=8192):
        """Stream a chat completion, echo it to stdout, and return the full text plus a token estimate.

        Args:
            system_prompt: Content of the system message.
            user_prompt: Content of the user message.
            temperature: Sampling temperature forwarded to the API.
            top_p: Nucleus-sampling parameter forwarded to the API.
            presence_penalty: Accepted for interface compatibility but NOT sent to the API.
            max_tokens: Upper bound on generated tokens (default 8192).

        Returns:
            ``(full_response, estimated_tokens)``, where ``estimated_tokens`` is a float
            heuristic from ``_estimate_tokens`` rather than a real usage count.
        """
        print("系统提示词:")
        print(system_prompt)
        print("\n用户提示词:")
        print(user_prompt)
        # Never echo the raw credential; only a masked form is printed.
        print(f"\nAPI Key: {self._mask_secret(self.api)}")
        print(f"Base URL: {self.base_url}")
        print(f"Model: {self.model_name}")
        response = self.client.chat.completions.create(
            model=self.model_name,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            temperature=temperature,
            top_p=top_p,
            # presence_penalty is intentionally not forwarded; a repetition penalty is sent
            # via extra_body instead (a non-standard field — presumably for vLLM-style
            # servers; confirm other providers accept it).
            stream=True,
            max_tokens=max_tokens,
            extra_body={
                "repetition_penalty": 1.05,
            },
        )
        # Collect the streamed pieces and join once at the end.
        pieces = []
        for chunk in response:
            if not chunk.choices:
                continue
            choice = chunk.choices[0]
            delta = choice.delta
            if delta is not None and delta.content is not None:
                pieces.append(delta.content)
                # Flush each piece so the output appears live while streaming.
                print(delta.content, end="", flush=True)
            if choice.finish_reason == "stop":
                break
        full_response = "".join(pieces)
        print("\n完成生成,正在处理结果...")
        # Usage is not requested from the stream, so an estimate is returned instead.
        return full_response, self._estimate_tokens(full_response)

    def read_folder(self, file_folder):
        """Concatenate the text of every regular file directly inside *file_folder*.

        Files are visited in sorted name order so the result is deterministic. Each file
        contributes ``"文件名: <name>\\n" + <content> + "\\n"``. Subdirectories are not
        descended into, and files that are not valid UTF-8 are skipped with a notice.

        Returns:
            The concatenated text, or ``""`` if *file_folder* is not an existing directory.
        """
        if not os.path.isdir(file_folder):
            return ""
        parts = []
        for file in sorted(os.listdir(file_folder)):
            file_path = os.path.join(file_folder, file)
            if not os.path.isfile(file_path):
                continue
            try:
                with open(file_path, "r", encoding="utf-8") as f:
                    content = f.read()
            except UnicodeDecodeError:
                # Binary or non-UTF-8 files would otherwise abort the whole read.
                print(f"跳过无法以 UTF-8 解码的文件: {file}")
                continue
            parts.append(f"文件名: {file}\n{content}\n")
        return "".join(parts)

    def work(self, system_prompt, user_prompt, file_folder, temperature, top_p, presence_penalty,
             result_dir="/root/autodl-tmp/xhsTweetGene/result"):
        """Run one generation: optionally attach reference files, generate, and time it.

        Args:
            system_prompt, user_prompt, temperature, top_p, presence_penalty:
                Forwarded to ``generate_text``.
            file_folder: Optional folder whose files are appended to the user prompt.
            result_dir: Directory in which the timestamped result path is built.

        Returns:
            ``(result, system_prompt, user_prompt, file_folder, result_file, tokens, time_cost)``,
            where ``user_prompt`` includes any appended reference material, ``result_file`` is
            only a suggested path (this method does not create or write it), and ``time_cost``
            is the generation time in seconds.
        """
        date_time = time.strftime("%Y-%m-%d_%H-%M-%S")
        result_file = os.path.join(result_dir, f"{date_time}.md")
        # Append reference material, if any was provided and found.
        if file_folder:
            context = self.read_folder(file_folder)
            if context:
                user_prompt = f"{user_prompt}\n参考资料:\n{context}"
        # perf_counter is monotonic, so wall-clock adjustments cannot skew the duration.
        time_start = time.perf_counter()
        result, tokens = self.generate_text(system_prompt, user_prompt, temperature, top_p, presence_penalty)
        time_cost = time.perf_counter() - time_start
        return result, system_prompt, user_prompt, file_folder, result_file, tokens, time_cost