187 lines
6.6 KiB
Python
187 lines
6.6 KiB
Python
|
|
#!/usr/bin/env python3
|
|||
|
|
# -*- coding: utf-8 -*-
|
|||
|
|
|
|||
|
|
"""
AI agent module.

Handles interaction with large language models to generate text content.
"""
|
|||
|
|
|
|||
|
|
import os
|
|||
|
|
import time
|
|||
|
|
import logging
|
|||
|
|
import traceback
|
|||
|
|
from openai import OpenAI, APITimeoutError, APIConnectionError, RateLimitError, APIStatusError
|
|||
|
|
import tiktoken
|
|||
|
|
from typing import Optional
|
|||
|
|
|
|||
|
|
from core.config import AIModelConfig
|
|||
|
|
from core.exception import AIModelError, RetryableError, NonRetryableError
|
|||
|
|
|
|||
|
|
logger = logging.getLogger(__name__)
|
|||
|
|
|
|||
|
|
|
|||
|
|
class AIAgent:
    """AI agent that talks to a large language model to generate text."""

    def __init__(self, config: AIModelConfig):
        """Initialize the AI agent.

        Builds the OpenAI-compatible client from the configuration and
        resolves a tokenizer for the configured model, falling back to
        the generic ``cl100k_base`` encoding when the model is unknown
        to tiktoken.

        Args:
            config: AI model configuration (API key/URL, model name,
                sampling parameters, timeout, retry policy).
        """
        self.config = config
        # Client targets whatever OpenAI-compatible endpoint the config names.
        self.client = OpenAI(
            api_key=config.api_key,
            base_url=config.api_url,
            timeout=config.timeout,
        )
        try:
            self.tokenizer = tiktoken.encoding_for_model(config.model)
        except KeyError:
            # Unknown model name: degrade gracefully to the generic encoding.
            logger.warning(f"模型 '{self.config.model}' 没有找到对应的tokenizer,将使用 'cl100k_base'")
            self.tokenizer = tiktoken.get_encoding("cl100k_base")
|
|||
|
|
|
|||
|
|
def generate_text(self, system_prompt: str, user_prompt: str, stream: bool = False):
|
|||
|
|
"""
|
|||
|
|
生成文本 (支持流式和非流式)
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
system_prompt: 系统提示
|
|||
|
|
user_prompt: 用户提示
|
|||
|
|
stream: 是否流式返回
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
生成的文本字符串 或 流式响应的生成器
|
|||
|
|
"""
|
|||
|
|
messages = [
|
|||
|
|
{"role": "system", "content": system_prompt},
|
|||
|
|
{"role": "user", "content": user_prompt}
|
|||
|
|
]
|
|||
|
|
|
|||
|
|
last_exception = None
|
|||
|
|
backoff_time = 1.0 # Start with 1 second
|
|||
|
|
|
|||
|
|
for attempt in range(self.config.max_retries):
|
|||
|
|
try:
|
|||
|
|
response = self.client.chat.completions.create(
|
|||
|
|
model=self.config.model,
|
|||
|
|
messages=messages,
|
|||
|
|
temperature=self.config.temperature,
|
|||
|
|
top_p=self.config.top_p,
|
|||
|
|
presence_penalty=self.config.presence_penalty,
|
|||
|
|
stream=stream
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
if stream:
|
|||
|
|
return self._process_stream(response)
|
|||
|
|
else:
|
|||
|
|
return response.choices[0].message.content.strip()
|
|||
|
|
|
|||
|
|
except (APITimeoutError, APIConnectionError) as e:
|
|||
|
|
last_exception = RetryableError(f"AI模型连接或超时错误: {e}")
|
|||
|
|
logger.warning(f"尝试 {attempt + 1}/{self.config.max_retries} 失败: {last_exception}. "
|
|||
|
|
f"将在 {backoff_time:.1f} 秒后重试...")
|
|||
|
|
time.sleep(backoff_time)
|
|||
|
|
backoff_time *= 2 # Exponential backoff
|
|||
|
|
except (RateLimitError, APIStatusError) as e:
|
|||
|
|
last_exception = NonRetryableError(f"AI模型API错误 (不可重试): {e}")
|
|||
|
|
logger.error(f"发生不可重试的API错误: {last_exception}")
|
|||
|
|
break # Do not retry on these errors
|
|||
|
|
except Exception as e:
|
|||
|
|
last_exception = AIModelError(f"调用AI模型时发生未知错误: {e}")
|
|||
|
|
logger.error(f"发生未知错误: {last_exception}\n{traceback.format_exc()}")
|
|||
|
|
break
|
|||
|
|
|
|||
|
|
raise AIModelError(f"AI模型调用在 {self.config.max_retries} 次重试后失败") from last_exception
|
|||
|
|
|
|||
|
|
def _process_stream(self, response):
|
|||
|
|
"""处理流式响应"""
|
|||
|
|
full_response = []
|
|||
|
|
for chunk in response:
|
|||
|
|
content = chunk.choices[0].delta.content
|
|||
|
|
if content:
|
|||
|
|
full_response.append(content)
|
|||
|
|
yield content
|
|||
|
|
|
|||
|
|
logger.info(f"流式响应接收完成,总长度: {len(''.join(full_response))}")
|
|||
|
|
|
|||
|
|
def count_tokens(self, text: str) -> int:
|
|||
|
|
"""
|
|||
|
|
计算文本的token数量
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
text: 输入文本
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
token数量
|
|||
|
|
"""
|
|||
|
|
return len(self.tokenizer.encode(text))
|
|||
|
|
|
|||
|
|
@staticmethod
|
|||
|
|
def read_folder_content(folder_path: str) -> str:
|
|||
|
|
"""
|
|||
|
|
读取指定文件夹下的所有文件内容并合并
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
folder_path: 文件夹路径
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
合并后的文件内容
|
|||
|
|
"""
|
|||
|
|
if not os.path.exists(folder_path):
|
|||
|
|
logger.warning(f"引用的文件夹不存在: {folder_path}")
|
|||
|
|
return ""
|
|||
|
|
|
|||
|
|
context = ""
|
|||
|
|
try:
|
|||
|
|
for file_name in sorted(os.listdir(folder_path)):
|
|||
|
|
file_path = os.path.join(folder_path, file_name)
|
|||
|
|
if os.path.isfile(file_path):
|
|||
|
|
try:
|
|||
|
|
with open(file_path, "r", encoding="utf-8") as f:
|
|||
|
|
context += f"--- 文件名: {file_name} ---\n"
|
|||
|
|
context += f.read().strip()
|
|||
|
|
context += "\n\n"
|
|||
|
|
except Exception as read_err:
|
|||
|
|
logger.error(f"读取文件失败 {file_path}: {read_err}")
|
|||
|
|
except Exception as list_err:
|
|||
|
|
logger.error(f"列出目录失败 {folder_path}: {list_err}")
|
|||
|
|
return context.strip()
|
|||
|
|
|
|||
|
|
def work(self, system_prompt: str, user_prompt: str, context_folder: Optional[str] = None):
|
|||
|
|
"""
|
|||
|
|
完整的工作流程:准备提示、生成文本并返回结果
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
system_prompt: 系统提示
|
|||
|
|
user_prompt: 用户提示
|
|||
|
|
context_folder: 包含上下文文件的文件夹路径 (可选)
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
一个元组 (result, input_tokens, output_tokens, time_cost)
|
|||
|
|
"""
|
|||
|
|
if context_folder:
|
|||
|
|
logger.info(f"从文件夹读取上下文: {context_folder}")
|
|||
|
|
context = self.read_folder_content(context_folder)
|
|||
|
|
if context:
|
|||
|
|
user_prompt = f"{user_prompt.strip()}\n\n--- 参考资料 ---\n{context}"
|
|||
|
|
|
|||
|
|
input_tokens = self.count_tokens(system_prompt + user_prompt)
|
|||
|
|
logger.info(f"开始生成任务... 输入token数: {input_tokens}")
|
|||
|
|
|
|||
|
|
start_time = time.time()
|
|||
|
|
|
|||
|
|
# 使用非流式生成获取完整结果
|
|||
|
|
result = self.generate_text(system_prompt, user_prompt, stream=False)
|
|||
|
|
|
|||
|
|
time_cost = time.time() - start_time
|
|||
|
|
|
|||
|
|
output_tokens = self.count_tokens(result)
|
|||
|
|
|
|||
|
|
logger.info(f"任务完成,耗时 {time_cost:.2f} 秒. 输出token数: {output_tokens}")
|
|||
|
|
|
|||
|
|
return result, input_tokens, output_tokens, time_cost
|