186 lines
7.2 KiB
Python
186 lines
7.2 KiB
Python
|
|
#!/usr/bin/env python3
|
|||
|
|
# -*- coding: utf-8 -*-
|
|||
|
|
|
|||
|
|
"""
|
|||
|
|
AI 代理模块
|
|||
|
|
负责与大模型进行交互,生成文本内容
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
import os
|
|||
|
|
import time
|
|||
|
|
import logging
|
|||
|
|
import traceback
|
|||
|
|
from openai import AsyncOpenAI, APITimeoutError, APIConnectionError, RateLimitError, APIStatusError
|
|||
|
|
import tiktoken
|
|||
|
|
from typing import Optional, Tuple
|
|||
|
|
|
|||
|
|
from core.config import AIModelConfig
|
|||
|
|
from core.exception import AIModelError, RetryableError, NonRetryableError
|
|||
|
|
|
|||
|
|
logger = logging.getLogger(__name__)
|
|||
|
|
|
|||
|
|
|
|||
|
|
class AIAgent:
    """Agent that sends prompts to a chat model and returns the generated text.

    Wraps an ``AsyncOpenAI`` client built from the supplied configuration and
    adds retry handling with exponential backoff, approximate token
    accounting, and a helper for merging the files of a reference folder.
    """

    def __init__(self, config: "AIModelConfig"):
        """Initialise the agent.

        Args:
            config: AI model configuration. This class reads ``api_key``,
                ``api_url``, ``timeout``, ``model``, ``temperature``,
                ``top_p``, ``presence_penalty`` and ``max_retries`` from it.
        """
        self.config = config
        self.client = AsyncOpenAI(
            api_key=self.config.api_key,
            base_url=self.config.api_url,
            timeout=self.config.timeout,
        )
        # NOTE: a tiktoken-based tokenizer (encoding_for_model with a
        # "cl100k_base" fallback) used to be created here but is disabled;
        # count_tokens() uses a character-count heuristic instead.

    async def generate_text(
        self, system_prompt: str, user_prompt: str, use_stream: bool = False,
        temperature: Optional[float] = None, top_p: Optional[float] = None,
        presence_penalty: Optional[float] = None, stage: str = ""
    ) -> Tuple[str, int, int, float]:
        """Generate text from the model (streaming or non-streaming).

        Connection and timeout errors are retried up to
        ``config.max_retries`` attempts with exponential backoff starting at
        one second. Rate-limit / API-status errors and any other exception
        abort immediately.

        Args:
            system_prompt: System prompt.
            user_prompt: User prompt.
            use_stream: Whether to request a streamed response. The stream is
                fully consumed before returning.
            temperature: Sampling temperature; falls back to the config value
                when ``None``.
            top_p: Top-p sampling parameter; falls back to the config value
                when ``None``.
            presence_penalty: Presence penalty; falls back to the config value
                when ``None``.
            stage: Name of the current stage, used only as a log prefix.

        Returns:
            A tuple ``(generated_text, input_tokens, output_tokens, time_cost)``.
            Token counts are estimates from :meth:`count_tokens`; ``time_cost``
            is the wall-clock time in seconds since the first attempt started,
            including any backoff waits and, for streams, the time spent
            receiving the stream.

        Raises:
            AIModelError: When no attempt succeeded. The underlying
                ``RetryableError`` / ``NonRetryableError`` / ``AIModelError``
                is attached as ``__cause__``.
        """
        # Function-scope import: keeps this block self-contained without
        # touching the file-level import list.
        import asyncio

        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]

        # Explicit arguments win; otherwise use the configured defaults.
        temp = temperature if temperature is not None else self.config.temperature
        tp = top_p if top_p is not None else self.config.top_p
        pp = presence_penalty if presence_penalty is not None else self.config.presence_penalty

        stage_info = f"[{stage}]" if stage else ""
        logger.info(f"{stage_info} 使用模型参数: temperature={temp:.2f}, top_p={tp:.2f}, presence_penalty={pp:.2f}")

        input_tokens = self.count_tokens(system_prompt + user_prompt)
        logger.info(f"{stage_info} 开始生成任务... 输入token数: {input_tokens}")

        max_retries = self.config.max_retries
        last_exception = None
        backoff_time = 1.0  # Seconds; doubled after every retryable failure.
        start_time = time.time()

        for attempt in range(max_retries):
            try:
                response = await self.client.chat.completions.create(
                    model=self.config.model,
                    messages=messages,
                    temperature=temp,
                    top_p=tp,
                    presence_penalty=pp,
                    stream=use_stream,
                )

                if use_stream:
                    # The stream must be drained before the call is complete.
                    output_text = await self._process_stream(response)
                else:
                    content = response.choices[0].message.content
                    if content is None:
                        # The API may return no text content; treat it as an
                        # empty answer instead of failing on None.strip().
                        logger.warning(f"{stage_info} 模型返回内容为空")
                        content = ""
                    output_text = content.strip()

                # Measured after the full response (including a stream) has
                # been received, so streaming calls report the real duration.
                time_cost = time.time() - start_time
                output_tokens = self.count_tokens(output_text)
                logger.info(f"{stage_info} 任务完成,耗时 {time_cost:.2f} 秒. 输出token数: {output_tokens}")
                return output_text, input_tokens, output_tokens, time_cost

            except (APITimeoutError, APIConnectionError) as e:
                last_exception = RetryableError(f"AI模型连接或超时错误: {e}")
                if attempt + 1 < max_retries:
                    logger.warning(f"{stage_info} 尝试 {attempt + 1}/{self.config.max_retries} 失败: {last_exception}. "
                                   f"将在 {backoff_time:.1f} 秒后重试...")
                    # Non-blocking wait: time.sleep() here would freeze the
                    # whole event loop and every other coroutine on it.
                    await asyncio.sleep(backoff_time)
                    backoff_time *= 2  # Exponential backoff
                else:
                    # Last attempt: nothing left to retry, so do not wait.
                    logger.warning(f"{stage_info} 尝试 {attempt + 1}/{self.config.max_retries} 失败: {last_exception}.")
            except (RateLimitError, APIStatusError) as e:
                last_exception = NonRetryableError(f"AI模型API错误 (不可重试): {e}")
                logger.error(f"{stage_info} 发生不可重试的API错误: {last_exception}")
                break  # Do not retry on these errors
            except Exception as e:
                last_exception = AIModelError(f"调用AI模型时发生未知错误: {e}")
                logger.error(f"{stage_info} 发生未知错误: {last_exception}\n{traceback.format_exc()}")
                break

        raise AIModelError(f"AI模型调用在 {self.config.max_retries} 次重试后失败") from last_exception

    async def _process_stream(self, response):
        """Consume a streamed response and return the concatenated text.

        Args:
            response: Async iterable of chunks; each chunk exposes
                ``choices[0].delta.content``.

        Returns:
            All non-empty content fragments joined in arrival order.
        """
        fragments = []
        async for chunk in response:
            # Some chunks carry no choices at all; skip them instead of
            # failing on choices[0].
            if not chunk.choices:
                continue
            content = chunk.choices[0].delta.content
            if content:
                fragments.append(content)
                # For true incremental delivery, a callback or similar hook
                # could be invoked here.

        full_text = "".join(fragments)
        logger.info(f"流式响应接收完成,总长度: {len(full_text)}")
        return full_text

    def count_tokens(self, text: str) -> int:
        """Estimate the number of tokens in ``text``.

        This is a heuristic of one token per 1.5 characters (rounded down),
        not a real tokenizer.

        Args:
            text: Input text.

        Returns:
            Estimated token count as an ``int``.
        """
        # Integer form of floor(len(text) / 1.5); the previous float floor
        # division returned a float despite the ``-> int`` annotation.
        return len(text) * 2 // 3

    @staticmethod
    def read_folder_content(folder_path: str) -> str:
        """Read every file directly inside a folder and merge the contents.

        Files are processed in sorted name order and decoded as UTF-8. Each
        file contributes a header line with its name followed by its stripped
        content. Sub-directories are ignored. A file that cannot be read is
        logged and skipped entirely (no header is emitted for it).

        Args:
            folder_path: Path of the folder to read.

        Returns:
            The merged text, or an empty string when the folder does not
            exist, cannot be listed, or yields no readable files.
        """
        if not os.path.exists(folder_path):
            logger.warning(f"引用的文件夹不存在: {folder_path}")
            return ""

        sections = []
        try:
            for file_name in sorted(os.listdir(folder_path)):
                file_path = os.path.join(folder_path, file_name)
                if not os.path.isfile(file_path):
                    continue
                try:
                    with open(file_path, "r", encoding="utf-8") as f:
                        body = f.read().strip()
                except Exception as read_err:
                    logger.error(f"读取文件失败 {file_path}: {read_err}")
                    continue
                # Header is added only after a successful read, so a failing
                # file cannot leave a dangling header in the output.
                sections.append(f"--- 文件名: {file_name} ---\n{body}\n\n")
        except Exception as list_err:
            logger.error(f"列出目录失败 {folder_path}: {list_err}")
        return "".join(sections).strip()
|