#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ AI 代理模块 负责与大模型进行交互,生成文本内容 """ import os import time import logging import traceback from openai import AsyncOpenAI, APITimeoutError, APIConnectionError, RateLimitError, APIStatusError import tiktoken from typing import Optional, Tuple from core.config import AIModelConfig from core.exception import AIModelError, RetryableError, NonRetryableError logger = logging.getLogger(__name__) class AIAgent: """ AI代理类,负责与AI模型交互生成文本内容 """ def __init__(self, config: AIModelConfig): """ 初始化 AI 代理 Args: config: AI模型配置 """ self.config = config self.client = AsyncOpenAI( api_key=self.config.api_key, base_url=self.config.api_url, timeout=self.config.timeout ) # try: # self.tokenizer = tiktoken.encoding_for_model(self.config.model) # except KeyError: # logger.warning(f"模型 '{self.config.model}' 没有找到对应的tokenizer,将使用 'cl100k_base'") # self.tokenizer = tiktoken.get_encoding("cl100k_base") async def generate_text( self, system_prompt: str, user_prompt: str, use_stream: bool = False, temperature: Optional[float] = None, top_p: Optional[float] = None, presence_penalty: Optional[float] = None, stage: str = "" ) -> Tuple[str, int, int, float]: """ 生成文本 (支持流式和非流式) Args: system_prompt: 系统提示 user_prompt: 用户提示 use_stream: 是否流式返回 temperature: 温度参数,控制随机性 top_p: Top-p采样参数 presence_penalty: 存在惩罚参数 stage: 当前所处阶段,用于日志记录 Returns: 一个元组 (generated_text, input_tokens, output_tokens, time_cost) """ messages = [ {"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt} ] # 使用传入的参数或默认配置 temp = temperature if temperature is not None else self.config.temperature tp = top_p if top_p is not None else self.config.top_p pp = presence_penalty if presence_penalty is not None else self.config.presence_penalty # 记录使用的模型参数 stage_info = f"[{stage}]" if stage else "" logger.info(f"{stage_info} 使用模型参数: temperature={temp:.2f}, top_p={tp:.2f}, presence_penalty={pp:.2f}") input_tokens = self.count_tokens(system_prompt + user_prompt) logger.info(f"{stage_info} 开始生成任务... 输入token数: {input_tokens}") last_exception = None backoff_time = 1.0 # Start with 1 second start_time = time.time() for attempt in range(self.config.max_retries): try: response = await self.client.chat.completions.create( model=self.config.model, messages=messages, temperature=temp, top_p=tp, presence_penalty=pp, stream=use_stream ) time_cost = time.time() - start_time if use_stream: # 流式处理暂时不返回token计数和时间,需要更复杂的实现 # 这里返回一个空的生成器,但实际逻辑在 _process_stream 中 # 为了统一返回类型,我们可能需要重新设计这里 # 简化处理:流式模式下,我们返回拼接后的完整文本 full_text = "".join([chunk for chunk in self._process_stream(response)]) output_tokens = self.count_tokens(full_text) logger.info(f"{stage_info} 任务完成,耗时 {time_cost:.2f} 秒. 输出token数: {output_tokens}") return full_text, input_tokens, output_tokens, time_cost else: output_text = response.choices[0].message.content.strip() output_tokens = self.count_tokens(output_text) logger.info(f"{stage_info} 任务完成,耗时 {time_cost:.2f} 秒. 输出token数: {output_tokens}") return output_text, input_tokens, output_tokens, time_cost except (APITimeoutError, APIConnectionError) as e: last_exception = RetryableError(f"AI模型连接或超时错误: {e}") logger.warning(f"{stage_info} 尝试 {attempt + 1}/{self.config.max_retries} 失败: {last_exception}. " f"将在 {backoff_time:.1f} 秒后重试...") time.sleep(backoff_time) backoff_time *= 2 # Exponential backoff except (RateLimitError, APIStatusError) as e: last_exception = NonRetryableError(f"AI模型API错误 (不可重试): {e}") logger.error(f"{stage_info} 发生不可重试的API错误: {last_exception}") break # Do not retry on these errors except Exception as e: last_exception = AIModelError(f"调用AI模型时发生未知错误: {e}") logger.error(f"{stage_info} 发生未知错误: {last_exception}\n{traceback.format_exc()}") break raise AIModelError(f"AI模型调用在 {self.config.max_retries} 次重试后失败") from last_exception def _process_stream(self, response): """处理流式响应""" full_response = [] for chunk in response: content = chunk.choices[0].delta.content if content: full_response.append(content) yield content logger.info(f"流式响应接收完成,总长度: {len(''.join(full_response))}") def count_tokens(self, text: str) -> int: """ 计算文本的token数量 Args: text: 输入文本 Returns: token数量 """ return len(text) // 1.5 # return len(self.tokenizer.encode(text)) @staticmethod def read_folder_content(folder_path: str) -> str: """ 读取指定文件夹下的所有文件内容并合并 Args: folder_path: 文件夹路径 Returns: 合并后的文件内容 """ if not os.path.exists(folder_path): logger.warning(f"引用的文件夹不存在: {folder_path}") return "" context = "" try: for file_name in sorted(os.listdir(folder_path)): file_path = os.path.join(folder_path, file_name) if os.path.isfile(file_path): try: with open(file_path, "r", encoding="utf-8") as f: context += f"--- 文件名: {file_name} ---\n" context += f.read().strip() context += "\n\n" except Exception as read_err: logger.error(f"读取文件失败 {file_path}: {read_err}") except Exception as list_err: logger.error(f"列出目录失败 {folder_path}: {list_err}") return context.strip()