#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
AI agent module.
Handles interaction with the large language model to generate text content.
"""

import asyncio
import os
import time
import logging
import traceback
from openai import AsyncOpenAI, APITimeoutError, APIConnectionError, RateLimitError, APIStatusError
import tiktoken
from typing import Optional, Tuple

from core.config import AIModelConfig
from core.exception import AIModelError, RetryableError, NonRetryableError

logger = logging.getLogger(__name__)


class AIAgent:
    """
    AI agent class, responsible for interacting with the AI model to generate text content.
    """

    def __init__(self, config: AIModelConfig):
        """
        Initialize the AI agent.

        Args:
            config: AI model configuration
        """
        self.config = config
        self.client = AsyncOpenAI(
            api_key=self.config.api_key,
            base_url=self.config.api_url,
            timeout=self.config.timeout
        )
        # try:
        #     self.tokenizer = tiktoken.encoding_for_model(self.config.model)
        # except KeyError:
        #     logger.warning(f"No tokenizer found for model '{self.config.model}'; falling back to 'cl100k_base'")
        #     self.tokenizer = tiktoken.get_encoding("cl100k_base")

    async def generate_text(
        self, system_prompt: str, user_prompt: str, use_stream: bool = False,
        temperature: Optional[float] = None, top_p: Optional[float] = None,
        presence_penalty: Optional[float] = None, stage: str = ""
    ) -> Tuple[str, int, int, float]:
        """
        Generate text (supports both streaming and non-streaming modes).

        Args:
            system_prompt: system prompt
            user_prompt: user prompt
            use_stream: whether to stream the response
            temperature: temperature parameter, controls randomness
            top_p: top-p sampling parameter
            presence_penalty: presence penalty parameter
            stage: current pipeline stage, used for logging

        Returns:
            A tuple (generated_text, input_tokens, output_tokens, time_cost)
        """
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ]

        # Use the explicitly passed parameters, falling back to the configured defaults
        temp = temperature if temperature is not None else self.config.temperature
        tp = top_p if top_p is not None else self.config.top_p
        pp = presence_penalty if presence_penalty is not None else self.config.presence_penalty

        # Log the model parameters in use
        stage_info = f"[{stage}]" if stage else ""
        logger.info(f"{stage_info} Model parameters: temperature={temp:.2f}, top_p={tp:.2f}, presence_penalty={pp:.2f}")

        input_tokens = self.count_tokens(system_prompt + user_prompt)
        logger.info(f"{stage_info} Starting generation task... input tokens: {input_tokens}")

        last_exception = None
        backoff_time = 1.0  # Start with 1 second
        start_time = time.time()

        for attempt in range(self.config.max_retries):
            try:
                response = await self.client.chat.completions.create(
                    model=self.config.model,
                    messages=messages,
                    temperature=temp,
                    top_p=tp,
                    presence_penalty=pp,
                    stream=use_stream
                )

                if use_stream:
                    # Streaming responses must be iterated asynchronously; measure
                    # elapsed time after the stream is fully consumed, since generation
                    # happens while the stream is being iterated
                    full_text = await self._process_stream(response)
                    time_cost = time.time() - start_time
                    output_tokens = self.count_tokens(full_text)
                    logger.info(f"{stage_info} Task finished in {time_cost:.2f}s. Output tokens: {output_tokens}")
                    return full_text, input_tokens, output_tokens, time_cost
                else:
                    time_cost = time.time() - start_time
                    output_text = response.choices[0].message.content.strip()
                    output_tokens = self.count_tokens(output_text)
                    logger.info(f"{stage_info} Task finished in {time_cost:.2f}s. Output tokens: {output_tokens}")
                    return output_text, input_tokens, output_tokens, time_cost

            except (APITimeoutError, APIConnectionError) as e:
                last_exception = RetryableError(f"AI model connection or timeout error: {e}")
                logger.warning(f"{stage_info} Attempt {attempt + 1}/{self.config.max_retries} failed: {last_exception}. "
                               f"Retrying in {backoff_time:.1f}s...")
                # Must be asyncio.sleep here: time.sleep would block the event loop
                await asyncio.sleep(backoff_time)
                backoff_time *= 2  # Exponential backoff
            except (RateLimitError, APIStatusError) as e:
                last_exception = NonRetryableError(f"AI model API error (non-retryable): {e}")
                logger.error(f"{stage_info} Non-retryable API error: {last_exception}")
                break  # Do not retry on these errors
            except Exception as e:
                last_exception = AIModelError(f"Unexpected error while calling the AI model: {e}")
                logger.error(f"{stage_info} Unexpected error: {last_exception}\n{traceback.format_exc()}")
                break

        raise AIModelError(f"AI model call failed after {self.config.max_retries} attempt(s)") from last_exception

    async def _process_stream(self, response):
        """Process a streaming response asynchronously."""
        full_response = []
        async for chunk in response:
            if not chunk.choices:
                # Some providers emit keep-alive or usage chunks with no choices
                continue
            content = chunk.choices[0].delta.content
            if content:
                full_response.append(content)
                # For true incremental streaming, a callback (or similar) could be
                # invoked here; see the sketch below

        full_text = "".join(full_response)
        logger.info(f"Stream fully received, total length: {len(full_text)}")
        return full_text

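    # Illustrative sketch only (hypothetical, not called anywhere in this module):
    # the comment above suggests callbacks for true incremental streaming, so this
    # shows one way that could look. The method name and the `on_delta` parameter
    # are assumed, not part of the original design.
    async def _process_stream_with_callback(self, response, on_delta):
        """Like _process_stream, but invokes `on_delta(text)` for every chunk received."""
        full_response = []
        async for chunk in response:
            if not chunk.choices:
                continue
            content = chunk.choices[0].delta.content
            if content:
                full_response.append(content)
                on_delta(content)  # push each delta to the caller as it arrives
        return "".join(full_response)
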
    def count_tokens(self, text: str) -> int:
        """
        Count the number of tokens in a piece of text.

        Args:
            text: input text

        Returns:
            token count
        """
        # Rough heuristic of ~1.5 characters per token; cast explicitly, since
        # `len(text) // 1.5` yields a float, contradicting the declared int return type
        return int(len(text) / 1.5)
        # return len(self.tokenizer.encode(text))

    @staticmethod
    def read_folder_content(folder_path: str) -> str:
        """
        Read every file in the given folder and merge their contents.

        Args:
            folder_path: folder path

        Returns:
            the merged file contents
        """
        if not os.path.exists(folder_path):
            logger.warning(f"Referenced folder does not exist: {folder_path}")
            return ""

        context = ""
        try:
            for file_name in sorted(os.listdir(folder_path)):
                file_path = os.path.join(folder_path, file_name)
                if os.path.isfile(file_path):
                    try:
                        with open(file_path, "r", encoding="utf-8") as f:
                            context += f"--- File: {file_name} ---\n"
                            context += f.read().strip()
                            context += "\n\n"
                    except Exception as read_err:
                        logger.error(f"Failed to read file {file_path}: {read_err}")
        except Exception as list_err:
            logger.error(f"Failed to list directory {folder_path}: {list_err}")
        return context.strip()
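

# Minimal usage sketch. Assumptions: AIModelConfig accepts the keyword arguments
# shown below (all of these fields are read elsewhere in this module, but the
# constructor signature is not shown in this file), and the endpoint/model values
# are placeholders, not values from this project.
if __name__ == "__main__":
    async def _demo():
        config = AIModelConfig(
            api_key=os.environ.get("OPENAI_API_KEY", ""),
            api_url="https://api.openai.com/v1",   # placeholder endpoint
            model="gpt-4o-mini",                   # placeholder model name
            timeout=60,
            max_retries=3,
            temperature=0.7,
            top_p=1.0,
            presence_penalty=0.0,
        )
        agent = AIAgent(config)
        # Optionally fold reference material from a folder into the prompt
        context = AIAgent.read_folder_content("./reference_docs")
        text, in_tok, out_tok, cost = await agent.generate_text(
            system_prompt="You are a helpful writing assistant.",
            user_prompt=f"Summarize the following material:\n{context}",
            stage="demo",
        )
        print(f"{in_tok} tokens in, {out_tok} tokens out, {cost:.2f}s:\n{text}")

    asyncio.run(_demo())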