#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
AI 代理模块
负责与大模型进行交互生成文本内容
"""
import asyncio
import os
import time
import logging
import traceback

from openai import AsyncOpenAI, APITimeoutError, APIConnectionError, RateLimitError, APIStatusError
import tiktoken
from typing import Optional, Tuple
from core.config import AIModelConfig
from core.exception import AIModelError, RetryableError, NonRetryableError

logger = logging.getLogger(__name__)


class AIAgent:
    """
    AI agent class, responsible for interacting with the AI model to generate text content.
    """

    def __init__(self, config: AIModelConfig):
        """
        Initialize the AI agent.

        Args:
            config: AI model configuration.
        """
        self.config = config
        self.client = AsyncOpenAI(
            api_key=self.config.api_key,
            base_url=self.config.api_url,
            timeout=self.config.timeout
        )
        # try:
        #     self.tokenizer = tiktoken.encoding_for_model(self.config.model)
        # except KeyError:
        #     logger.warning(f"No tokenizer found for model '{self.config.model}'; falling back to 'cl100k_base'")
        #     self.tokenizer = tiktoken.get_encoding("cl100k_base")

    async def generate_text(
        self, system_prompt: str, user_prompt: str, use_stream: bool = False,
        temperature: Optional[float] = None, top_p: Optional[float] = None,
        presence_penalty: Optional[float] = None, stage: str = ""
    ) -> Tuple[str, int, int, float]:
        """
        Generate text (supports both streaming and non-streaming modes).

        Args:
            system_prompt: System prompt.
            user_prompt: User prompt.
            use_stream: Whether to stream the response.
            temperature: Temperature parameter controlling randomness.
            top_p: Top-p sampling parameter.
            presence_penalty: Presence penalty parameter.
            stage: Current pipeline stage, used for logging.

        Returns:
            A tuple of (generated_text, input_tokens, output_tokens, time_cost).
        """
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ]

        # Use the passed-in parameters, falling back to the configured defaults
        temp = temperature if temperature is not None else self.config.temperature
        tp = top_p if top_p is not None else self.config.top_p
        pp = presence_penalty if presence_penalty is not None else self.config.presence_penalty

        # Log the model parameters actually in use
        stage_info = f"[{stage}]" if stage else ""
        logger.info(f"{stage_info} Model parameters: temperature={temp:.2f}, top_p={tp:.2f}, presence_penalty={pp:.2f}")

        input_tokens = self.count_tokens(system_prompt + user_prompt)
        logger.info(f"{stage_info} Starting generation task... input tokens: {input_tokens}")

        last_exception = None
        backoff_time = 1.0  # Start with 1 second
        start_time = time.time()
        for attempt in range(self.config.max_retries):
            try:
                response = await self.client.chat.completions.create(
                    model=self.config.model,
                    messages=messages,
                    temperature=temp,
                    top_p=tp,
                    presence_penalty=pp,
                    stream=use_stream
                )
                time_cost = time.time() - start_time
                if use_stream:
                    # Simplified streaming: rather than exposing the generator,
                    # drain _process_stream and return the concatenated full text
                    # so the return type stays uniform with the non-streaming branch.
                    full_text = "".join([chunk async for chunk in self._process_stream(response)])
                    output_tokens = self.count_tokens(full_text)
                    logger.info(f"{stage_info} Task finished in {time_cost:.2f}s. Output tokens: {output_tokens}")
                    return full_text, input_tokens, output_tokens, time_cost
                else:
                    output_text = response.choices[0].message.content.strip()
                    output_tokens = self.count_tokens(output_text)
                    logger.info(f"{stage_info} Task finished in {time_cost:.2f}s. Output tokens: {output_tokens}")
                    return output_text, input_tokens, output_tokens, time_cost
            except (APITimeoutError, APIConnectionError) as e:
                last_exception = RetryableError(f"AI model connection or timeout error: {e}")
                logger.warning(f"{stage_info} Attempt {attempt + 1}/{self.config.max_retries} failed: {last_exception}. "
                               f"Retrying in {backoff_time:.1f} seconds...")
                # time.sleep would block the event loop; use asyncio.sleep in async code
                await asyncio.sleep(backoff_time)
                backoff_time *= 2  # Exponential backoff
            except (RateLimitError, APIStatusError) as e:
                last_exception = NonRetryableError(f"AI model API error (non-retryable): {e}")
                logger.error(f"{stage_info} Non-retryable API error: {last_exception}")
                break  # Do not retry on these errors
            except Exception as e:
                last_exception = AIModelError(f"Unknown error while calling the AI model: {e}")
                logger.error(f"{stage_info} Unknown error: {last_exception}\n{traceback.format_exc()}")
                break

        raise AIModelError(f"AI model call failed after {self.config.max_retries} attempt(s)") from last_exception

    async def _process_stream(self, response):
        """Process a streaming response chunk by chunk (async generator)."""
        full_response = []
        # AsyncOpenAI returns an async stream, so it must be iterated with `async for`
        async for chunk in response:
            content = chunk.choices[0].delta.content
            if content:
                full_response.append(content)
                yield content
        logger.info(f"Streaming response complete, total length: {len(''.join(full_response))}")

    def count_tokens(self, text: str) -> int:
        """
        Count the number of tokens in a text.

        Args:
            text: Input text.

        Returns:
            Token count.
        """
        # Rough heuristic (~1.5 characters per token). Cast to int: the original
        # `len(text) // 1.5` returned a float, contradicting the declared return type.
        return int(len(text) / 1.5)
        # return len(self.tokenizer.encode(text))
    @staticmethod
    def read_folder_content(folder_path: str) -> str:
        """
        Read all files under the given folder and merge their contents.

        Args:
            folder_path: Folder path.

        Returns:
            Merged file contents.
        """
        if not os.path.exists(folder_path):
            logger.warning(f"Referenced folder does not exist: {folder_path}")
            return ""
        context = ""
        try:
            for file_name in sorted(os.listdir(folder_path)):
                file_path = os.path.join(folder_path, file_name)
                if os.path.isfile(file_path):
                    try:
                        with open(file_path, "r", encoding="utf-8") as f:
                            context += f"--- File: {file_name} ---\n"
                            context += f.read().strip()
                            context += "\n\n"
                    except Exception as read_err:
                        logger.error(f"Failed to read file {file_path}: {read_err}")
        except Exception as list_err:
            logger.error(f"Failed to list directory {folder_path}: {list_err}")
        return context.strip()
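

# Minimal usage sketch, not part of the module's public surface. It assumes
# AIModelConfig accepts the fields referenced above (api_key, api_url, model,
# timeout, temperature, top_p, presence_penalty, max_retries) as constructor
# keyword arguments; the actual signature lives in core.config and may differ.
if __name__ == "__main__":
    async def _demo():
        config = AIModelConfig(
            api_key=os.environ.get("OPENAI_API_KEY", ""),  # hypothetical env var
            api_url="https://api.openai.com/v1",           # example endpoint
            model="gpt-4o-mini",                           # example model name
            timeout=60,
            temperature=0.7,
            top_p=1.0,
            presence_penalty=0.0,
            max_retries=3,
        )
        agent = AIAgent(config)
        text, in_tok, out_tok, cost = await agent.generate_text(
            system_prompt="You are a helpful assistant.",
            user_prompt="Say hello in one sentence.",
            stage="demo",
        )
        print(f"{text} (in={in_tok}, out={out_tok}, {cost:.2f}s)")

    asyncio.run(_demo())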