import contextlib
import os
import random
import time
import traceback

from openai import OpenAI, APITimeoutError, APIConnectionError, RateLimitError, APIStatusError


class AI_Agent():
    """AI agent that talks to an OpenAI-compatible chat-completions endpoint to generate text."""

    # Upper bound (seconds) for the exponential back-off between retries.
    max_retry_wait = 10
    # Sent with every request; override on an instance if a model needs different values.
    max_tokens = 8192
    # Passed through ``extra_body`` (not a named argument of the OpenAI SDK's ``create``);
    # presumably only honoured by backends that support it, e.g. vLLM — verify per provider.
    repetition_penalty = 1.05

    def __init__(self, base_url, model_name, api, timeout=30, max_retries=3):
        """Create the agent and its OpenAI client.

        Args:
            base_url: Either a provider alias present in ``url_list`` ("ali", "kimi", "doubao",
                "deepseek", "vllm") or a full base URL, which is used verbatim.
            model_name: Model identifier passed to the chat-completions API.
            api: API key for the endpoint.
            timeout: Timeout in seconds, applied to the client and to every request.
            max_retries: Number of retries allowed after the first failed attempt.
        """
        self.url_list = {
            "ali": "https://dashscope.aliyuncs.com/compatible-mode/v1",
            "kimi": "https://api.moonshot.cn/v1",
            "doubao": "https://ark.cn-beijing.volces.com/api/v3/",
            "deepseek": "https://api.deepseek.com",
            "vllm": "http://localhost:8000/v1",
        }
        # Unknown aliases are treated as literal URLs.
        self.base_url = self.url_list.get(base_url, base_url)
        self.api = api
        self.model_name = model_name
        self.timeout = timeout
        self.max_retries = max_retries
        print(f"Initializing AI Agent with timeout={self.timeout}s, max_retries={self.max_retries}")
        self.client = OpenAI(api_key=self.api, base_url=self.base_url, timeout=self.timeout)

    # ------------------------------------------------------------------ helpers

    def _masked_api_key(self):
        """Return the API key with all but its first and last 4 characters hidden."""
        key = str(self.api or "")
        if len(key) <= 8:
            return "*" * len(key)
        return f"{key[:4]}...{key[-4:]}"

    def _backoff_wait(self, attempt):
        """Exponential back-off (2**attempt plus up to 1 s of jitter), capped at max_retry_wait."""
        return min(2 ** attempt + random.random(), self.max_retry_wait)

    @staticmethod
    def _truncate(text, limit=100):
        """Shorten ``text`` to ``limit`` characters plus "..." for log output."""
        return text[:limit] + "..." if len(text) > limit else text

    @staticmethod
    def _chunk_content(chunk):
        """Return the text delta of a streamed chunk, or None if it carries no text."""
        if not chunk.choices:
            return None
        delta = chunk.choices[0].delta
        return None if delta is None else delta.content

    @staticmethod
    def _close_stream(response):
        """Best-effort release of the HTTP connection behind a streaming response."""
        close = getattr(response, "close", None)
        if callable(close):
            # Cleanup must never mask the real outcome of the request.
            with contextlib.suppress(Exception):
                close()

    def _create_stream(self, system_prompt, user_prompt, temperature, top_p, presence_penalty):
        """Open a streaming chat-completion request with the agent's fixed settings."""
        return self.client.chat.completions.create(
            model=self.model_name,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            temperature=temperature,
            top_p=top_p,
            presence_penalty=presence_penalty,
            stream=True,
            max_tokens=self.max_tokens,
            timeout=self.timeout,
            extra_body={"repetition_penalty": self.repetition_penalty},
        )

    # --------------------------------------------------------- blocking API

    def generate_text(self, system_prompt, user_prompt, temperature, top_p, presence_penalty):
        """Generate a completion, echoing it to stdout as it streams in.

        Errors while opening the request or while reading the stream are retried with
        exponential back-off, up to ``self.max_retries`` retries. If the stream breaks after
        more than 100 characters have arrived, that partial text is accepted as the result.

        Returns:
            ``(text, estimated_tokens)``. On failure ``text`` is "请求失败,无法生成内容。" and
            the estimate is ``0``. The estimate is ``whitespace-separated words * 1.3``; it is
            only a rough guess (streaming gives no real usage) and undercounts text without
            spaces such as Chinese.
        """
        print("系统提示词:")
        print(system_prompt)
        print("\n用户提示词:")
        print(user_prompt)
        # Never print the full key: consoles and logs are routinely shared.
        print(f"\nAPI Key: {self._masked_api_key()}")
        print(f"Base URL: {self.base_url}")
        print(f"Model: {self.model_name}")

        # Up to 1 s of random jitter so concurrent callers don't hit the API in lockstep.
        time.sleep(random.random())

        retry_count = 0
        full_response = None  # Stays None unless an attempt produced a usable result.
        while retry_count <= self.max_retries:
            try:
                response = self._create_stream(system_prompt, user_prompt, temperature, top_p, presence_penalty)
            except Exception as e:
                print(f"\n请求发生错误: {e}")
                retry_count += 1
                if retry_count > self.max_retries:
                    print(f"已达到最大重试次数({self.max_retries}),放弃请求")
                    return "请求失败,无法生成内容。", 0
                wait_time = self._backoff_wait(retry_count)
                print(f"\n等待 {wait_time:.2f} 秒后重试({retry_count}/{self.max_retries})...")
                time.sleep(wait_time)
                continue

            chunks = []
            try:
                for chunk in response:
                    content = self._chunk_content(chunk)
                    if content is not None:
                        chunks.append(content)
                        print(content, end="", flush=True)
                    if chunk.choices and chunk.choices[0].finish_reason == "stop":
                        break
            except Exception as stream_err:
                # Narrowed from a bare ``except:`` so Ctrl-C / SystemExit still propagate.
                partial = "".join(chunks)
                print(f"\n接收响应时超时: {stream_err}")
                if partial:
                    print(f"已接收部分响应({len(partial)}字符)")
                    # Assume at least 100 characters are needed for partial text to be useful.
                    if len(partial) > 100:
                        print("使用已接收的部分内容继续处理")
                        full_response = partial
                        break
                retry_count += 1
                if retry_count > self.max_retries:
                    # Same failure signal as the request-error path above.
                    print(f"已达到最大重试次数({self.max_retries}),放弃请求")
                    return "请求失败,无法生成内容。", 0
                wait_time = self._backoff_wait(retry_count)
                print(f"\n等待 {wait_time:.2f} 秒后重试({retry_count}/{self.max_retries})...")
                time.sleep(wait_time)
                continue
            else:
                full_response = "".join(chunks)
                break
            finally:
                self._close_stream(response)

        if full_response is None:
            # Only reachable when max_retries < 0, i.e. no attempt was ever made.
            return "请求失败,无法生成内容。", 0

        print("\n完成生成,正在处理结果...")
        estimated_tokens = len(full_response.split()) * 1.3
        return full_response, estimated_tokens

    def read_folder(self, file_folder):
        """Concatenate every regular file directly inside ``file_folder`` (UTF-8).

        Each file contributes ``"文件名: <name>\\n" + <content> + "\\n"``; files are read in
        sorted name order so the resulting prompt is deterministic. Returns "" when the
        folder does not exist or is not a directory. Subdirectories are skipped.
        """
        if not os.path.isdir(file_folder):
            return ""
        parts = []
        for file in sorted(os.listdir(file_folder)):
            file_path = os.path.join(file_folder, file)
            if os.path.isfile(file_path):
                with open(file_path, "r", encoding="utf-8") as f:
                    parts.append(f"文件名: {file}\n{f.read()}\n")
        return "".join(parts)

    def work(self, system_prompt, user_prompt, file_folder, temperature, top_p, presence_penalty,
             result_dir="/root/autodl-tmp/xhsTweetGene/result"):
        """Run the full non-streaming workflow.

        If ``file_folder`` is given and contains readable files, their contents are appended
        to ``user_prompt`` as reference material before generation.

        Args:
            result_dir: Directory used to build ``result_file``. This method only builds the
                path; it does not create or write the file.

        Returns:
            ``(result, system_prompt, user_prompt, file_folder, result_file, tokens, time_cost)``
            where ``user_prompt`` includes any appended reference material, ``result_file`` is
            a timestamped ``.md`` path and ``time_cost`` is the generation time in seconds.
        """
        date_time = time.strftime("%Y-%m-%d_%H-%M-%S")
        result_file = os.path.join(result_dir, f"{date_time}.md")

        if file_folder:
            context = self.read_folder(file_folder)
            if context:
                user_prompt = f"{user_prompt}\n参考资料:\n{context}"

        # perf_counter is monotonic, so the duration is immune to wall-clock adjustments.
        time_start = time.perf_counter()
        result, tokens = self.generate_text(system_prompt, user_prompt, temperature, top_p, presence_penalty)
        time_cost = time.perf_counter() - time_start

        return result, system_prompt, user_prompt, file_folder, result_file, tokens, time_cost

    def close(self):
        """Close the underlying OpenAI client and release its HTTP resources."""
        self.client.close()

    # -------------------------------------------------------- streaming API

    def generate_text_stream(self, system_prompt, user_prompt, temperature, top_p, presence_penalty):
        """Generate a completion and yield its text chunks as they arrive.

        Failures are never raised to the consumer. Instead a final bracketed marker such as
        ``"[API_ERROR: ...]"``, ``"[STREAM_ERROR: ...]"`` or ``"[FATAL_ERROR: ...]"`` is yielded
        and the generator stops. Timeouts, connection errors, rate limits and 5xx responses are
        retried with back-off. A connection error mid-stream is retried only if nothing has been
        yielded yet: a retry restarts the answer from the beginning, which would otherwise hand
        the consumer duplicated text.
        """
        print("Streaming Generation Started...")
        print("System Prompt:", self._truncate(system_prompt))
        print("User Prompt:", self._truncate(user_prompt))
        print(f"Params: temp={temperature}, top_p={top_p}, presence_penalty={presence_penalty}")

        retry_count = 0
        while retry_count <= self.max_retries:
            # --- open the request -------------------------------------------------
            try:
                response = self._create_stream(system_prompt, user_prompt, temperature, top_p, presence_penalty)
            except (APITimeoutError, APIConnectionError, RateLimitError) as e:
                # Listed before APIStatusError because RateLimitError is one of its subclasses.
                print(f"\nRetriable API error occurred: {e}")
                retry_count += 1
                if retry_count > self.max_retries:
                    print("Max retries reached for API errors.")
                    yield "[API_ERROR: Max retries reached]"
                    return
                wait_time = self._backoff_wait(retry_count)
                print(f"Retrying API call ({retry_count}/{self.max_retries}), waiting {wait_time:.2f}s...")
                time.sleep(wait_time)
                continue
            except APIStatusError as e:
                print(f"\nAPI Status Error: {e.status_code} - {e.response}")
                if e.status_code < 500:
                    # 4xx errors (bad request, auth, ...) will not fix themselves on retry.
                    print("Non-retriable API status error.")
                    yield f"[API_ERROR: Non-retriable status {e.status_code}]"
                    return
                retry_count += 1
                if retry_count > self.max_retries:
                    print("Max retries reached after server error.")
                    yield f"[API_ERROR: Max retries reached after server error {e.status_code}]"
                    return
                wait_time = self._backoff_wait(retry_count)
                print(f"Retrying API call ({retry_count}/{self.max_retries}) after server error, waiting {wait_time:.2f}s...")
                time.sleep(wait_time)
                continue
            except Exception as e:
                print(f"\nNon-retriable error occurred: {e}")
                traceback.print_exc()
                yield f"[FATAL_ERROR: {e}]"
                return

            # --- consume the stream ---------------------------------------------
            yielded_any = False
            try:
                print("Stream connected, receiving content...")
                for chunk in response:
                    content = self._chunk_content(chunk)
                    if content is not None:
                        yielded_any = True
                        yield content
                print("\nStream finished successfully.")
                return
            except APIConnectionError as stream_err:
                print(f"\nStream connection error occurred: {stream_err}")
                if yielded_any:
                    print("Partial output already delivered; not retrying to avoid duplicated text.")
                    yield f"[STREAM_ERROR: Connection lost after partial output: {stream_err}]"
                    return
                retry_count += 1
                if retry_count > self.max_retries:
                    print("Max retries reached after stream connection error.")
                    yield f"[STREAM_ERROR: Max retries reached after connection error: {stream_err}]"
                    return
                wait_time = self._backoff_wait(retry_count)
                print(f"Retrying connection ({retry_count}/{self.max_retries}), waiting {wait_time:.2f}s...")
                time.sleep(wait_time)
                continue
            except Exception as stream_err:
                print(f"\nError occurred during stream processing: {stream_err}")
                traceback.print_exc()
                yield f"[STREAM_ERROR: {stream_err}]"
                return
            finally:
                # Also runs when the consumer abandons the generator (GeneratorExit at a yield).
                self._close_stream(response)

        # Only reachable when max_retries < 0; every other exhaustion path returns above.
        print("\nStream generation failed after exhausting all retries.")
        yield "[ERROR: Failed after all retries]"

    def work_stream(self, system_prompt, user_prompt, file_folder, temperature, top_p, presence_penalty):
        """Streaming variant of ``work``: return the generator from ``generate_text_stream``.

        The reference folder (if any) is read eagerly, before the generator is created, and
        its contents are appended to ``user_prompt`` under a "参考资料" separator.
        """
        if file_folder:
            print(f"Reading context from folder: {file_folder}")
            context = self.read_folder(file_folder)
            if context:
                user_prompt = f"{user_prompt.strip()}\n\n--- 参考资料 ---\n{context.strip()}"
            else:
                print(f"Warning: Folder {file_folder} provided but no content read.")

        print("Calling generate_text_stream...")
        return self.generate_text_stream(system_prompt, user_prompt, temperature, top_p, presence_penalty)