import os
import time
import random

from openai import OpenAI


class AI_Agent:
    """AI agent class responsible for interacting with an AI model to generate text."""

    def __init__(self, base_url, model_name, api, timeout=10, max_retries=3):
        # Map short provider aliases to their OpenAI-compatible endpoints.
        self.url_list = {
            "ali": "https://dashscope.aliyuncs.com/compatible-mode/v1",
            "kimi": "https://api.moonshot.cn/v1",
            "doubao": "https://ark.cn-beijing.volces.com/api/v3/",
            "deepseek": "https://api.deepseek.com",
            "vllm": "http://localhost:8000/v1",
        }
        # Accept either a known provider alias or a full base URL.
        self.base_url = self.url_list.get(base_url, base_url)
        self.api = api
        self.model_name = model_name
        self.timeout = timeout          # request timeout in seconds
        self.max_retries = max_retries  # maximum number of retries
        self.client = OpenAI(
            api_key=self.api,
            base_url=self.base_url,
            timeout=self.timeout,  # client-level timeout for the OpenAI SDK
        )

    def generate_text(self, system_prompt, user_prompt, temperature, top_p, presence_penalty):
        """Generate text and return the full response plus an estimated token count."""
        print("System prompt:")
        print(system_prompt)
        print("\nUser prompt:")
        print(user_prompt)
        print(f"\nAPI Key: {self.api}")
        print(f"Base URL: {self.base_url}")
        print(f"Model: {self.model_name}")

        # Small random delay to avoid bursts of simultaneous requests.
        time.sleep(random.random())

        retry_count = 0
        max_retry_wait = 10  # cap on the retry back-off wait (seconds)
        full_response = ""

        while retry_count <= self.max_retries:
            try:
                response = self.client.chat.completions.create(
                    model=self.model_name,
                    messages=[
                        {"role": "system", "content": system_prompt},
                        {"role": "user", "content": user_prompt},
                    ],
                    temperature=temperature,
                    top_p=top_p,
                    presence_penalty=presence_penalty,
                    stream=True,
                    max_tokens=8192,
                    timeout=self.timeout,  # per-request timeout
                    extra_body={
                        "repetition_penalty": 1.05,
                    },
                )

                # Collect the streamed output.
                full_response = ""
                try:
                    for chunk in response:
                        if chunk.choices and chunk.choices[0].delta.content is not None:
                            content = chunk.choices[0].delta.content
                            full_response += content
                            print(content, end="", flush=True)
                        if chunk.choices and chunk.choices[0].finish_reason == "stop":
                            break
                    # Finished successfully; leave the retry loop.
                    break
                except Exception as stream_err:
                    # Handle a timeout (or other failure) while streaming the response.
                    print(f"\nError while receiving the streamed response: {stream_err}")
                    if full_response:
                        print(f"Received a partial response ({len(full_response)} characters)")
                        # If enough content has arrived, keep the partial result.
                        if len(full_response) > 100:  # assume at least 100 characters are needed to be useful
                            print("Continuing with the partial content already received")
                            break
                    # Otherwise prepare to retry.
                    retry_count += 1
                    if retry_count <= self.max_retries:
                        wait_time = min(2 ** retry_count + random.random(), max_retry_wait)  # exponential back-off
                        print(f"\nWaiting {wait_time:.2f} s before retrying ({retry_count}/{self.max_retries})...")
                        time.sleep(wait_time)
                        continue

            except Exception as e:
                print(f"\nRequest failed: {e}")
                retry_count += 1
                if retry_count <= self.max_retries:
                    wait_time = min(2 ** retry_count + random.random(), max_retry_wait)  # exponential back-off
                    print(f"\nWaiting {wait_time:.2f} s before retrying ({retry_count}/{self.max_retries})...")
                    time.sleep(wait_time)
                else:
                    print(f"Maximum number of retries reached ({self.max_retries}); giving up")
                    return "Request failed; no content was generated.", 0

        if not full_response:
            # Retries were exhausted without receiving any usable content.
            print("\nNo usable content was received; giving up")
            return "Request failed; no content was generated.", 0

        print("\nGeneration finished, processing the result...")
        # With streaming output the real token count is unavailable, so return an estimate.
        estimated_tokens = len(full_response.split()) * 1.3  # rough token estimate
        return full_response, estimated_tokens

    def read_folder(self, file_folder):
        """Read the contents of every file in the given folder."""
        if not os.path.exists(file_folder):
            return ""
        context = ""
        for file in os.listdir(file_folder):
            file_path = os.path.join(file_folder, file)
            if os.path.isfile(file_path):
                with open(file_path, "r", encoding="utf-8") as f:
                    context += f"文件名: {file}\n"
                    context += f.read()
                    context += "\n"
        return context

    def work(self, system_prompt, user_prompt, file_folder, temperature, top_p, presence_penalty):
        """Full workflow: generate text and return the result."""
        # Timestamp used for the result file name.
        date_time = time.strftime("%Y-%m-%d_%H-%M-%S")
        result_file = f"/root/autodl-tmp/xhsTweetGene/result/{date_time}.md"

        # If a reference folder was provided, append its contents to the user prompt.
        if file_folder:
            context = self.read_folder(file_folder)
            if context:
                user_prompt = f"{user_prompt}\n参考资料:\n{context}"

        # Time the generation.
        time_start = time.time()
        result, tokens = self.generate_text(system_prompt, user_prompt, temperature, top_p, presence_penalty)
        time_cost = time.time() - time_start

        return result, system_prompt, user_prompt, file_folder, result_file, tokens, time_cost

    def close(self):
        self.client.close()