#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
AI 代理模块
负责与大模型进行交互生成文本内容
"""
import os
import time
import logging
import traceback
from openai import OpenAI, APITimeoutError, APIConnectionError, RateLimitError, APIStatusError
import tiktoken
from typing import Optional
from core.config import AIModelConfig
from core.exception import AIModelError, RetryableError, NonRetryableError
logger = logging.getLogger(__name__)
class AIAgent:
    """AI agent that talks to a chat-completion model to generate text.

    Wraps an OpenAI-compatible client with retry/backoff handling, token
    counting via tiktoken, and optional folder-based context injection.
    """

    def __init__(self, config: AIModelConfig):
        """
        Initialize the AI agent.

        Args:
            config: AI model configuration (API key/URL, model name, sampling
                parameters, timeout and retry settings).
        """
        self.config = config
        self.client = OpenAI(
            api_key=self.config.api_key,
            base_url=self.config.api_url,
            timeout=self.config.timeout
        )
        try:
            self.tokenizer = tiktoken.encoding_for_model(self.config.model)
        except KeyError:
            # Unknown model name for tiktoken: fall back to the generic
            # cl100k_base encoding so token counting still works
            # (counts may then be approximate for this model).
            logger.warning(f"模型 '{self.config.model}' 没有找到对应的tokenizer将使用 'cl100k_base'")
            self.tokenizer = tiktoken.get_encoding("cl100k_base")

    def generate_text(self, system_prompt: str, user_prompt: str, stream: bool = False):
        """
        Generate text (supports streaming and non-streaming modes).

        Args:
            system_prompt: system prompt
            user_prompt: user prompt
            stream: whether to return a streaming response

        Returns:
            The generated text string, or a generator yielding streamed
            chunks when ``stream`` is True.

        Raises:
            AIModelError: when all retries are exhausted or a non-retryable
                error occurred; the underlying error is chained as ``__cause__``.
        """
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ]
        last_exception = None
        backoff_time = 1.0  # initial backoff (seconds), doubled per failed attempt
        for attempt in range(self.config.max_retries):
            try:
                response = self.client.chat.completions.create(
                    model=self.config.model,
                    messages=messages,
                    temperature=self.config.temperature,
                    top_p=self.config.top_p,
                    presence_penalty=self.config.presence_penalty,
                    stream=stream
                )
                if stream:
                    # NOTE(review): only the initial request is covered by the
                    # retry loop; errors raised while the caller consumes the
                    # returned generator are not retried.
                    return self._process_stream(response)
                # BUGFIX: message.content is Optional in the OpenAI SDK (None
                # for tool calls / filtered output) — guard before .strip()
                # to avoid AttributeError.
                return (response.choices[0].message.content or "").strip()
            except (APITimeoutError, APIConnectionError) as e:
                # Transient network/timeout failures: retry with exponential backoff.
                last_exception = RetryableError(f"AI模型连接或超时错误: {e}")
                logger.warning(f"尝试 {attempt + 1}/{self.config.max_retries} 失败: {last_exception}. "
                               f"将在 {backoff_time:.1f} 秒后重试...")
                time.sleep(backoff_time)
                backoff_time *= 2  # exponential backoff
            except (RateLimitError, APIStatusError) as e:
                # API-level failures (rate limit, HTTP status): do not retry.
                last_exception = NonRetryableError(f"AI模型API错误 (不可重试): {e}")
                logger.error(f"发生不可重试的API错误: {last_exception}")
                break
            except Exception as e:
                # Unknown failure: log with traceback and give up.
                last_exception = AIModelError(f"调用AI模型时发生未知错误: {e}")
                logger.error(f"发生未知错误: {last_exception}\n{traceback.format_exc()}")
                break
        raise AIModelError(f"AI模型调用在 {self.config.max_retries} 次重试后失败") from last_exception

    def _process_stream(self, response):
        """Yield content chunks from a streaming response as they arrive."""
        full_response = []
        for chunk in response:
            # Some providers emit keep-alive/metadata chunks with an empty
            # choices list — skip them instead of raising IndexError.
            if not chunk.choices:
                continue
            content = chunk.choices[0].delta.content
            if content:
                full_response.append(content)
                yield content
        logger.info(f"流式响应接收完成,总长度: {len(''.join(full_response))}")

    def count_tokens(self, text: str) -> int:
        """
        Count the number of tokens in a text.

        Args:
            text: input text

        Returns:
            Token count under this agent's tokenizer.
        """
        return len(self.tokenizer.encode(text))

    @staticmethod
    def read_folder_content(folder_path: str) -> str:
        """
        Read and merge the contents of all files directly inside a folder.

        Files are read in sorted name order; each file's content is preceded
        by a "--- 文件名: <name> ---" header. Unreadable files and listing
        errors are logged and skipped (best-effort).

        Args:
            folder_path: path of the folder to read

        Returns:
            The merged content, or "" when the folder does not exist.
        """
        if not os.path.exists(folder_path):
            logger.warning(f"引用的文件夹不存在: {folder_path}")
            return ""
        # Collect pieces in a list and join once — avoids quadratic
        # string concatenation when the folder holds many files.
        parts = []
        try:
            for file_name in sorted(os.listdir(folder_path)):
                file_path = os.path.join(folder_path, file_name)
                if not os.path.isfile(file_path):
                    continue  # skip subdirectories and non-regular entries
                try:
                    with open(file_path, "r", encoding="utf-8") as f:
                        parts.append(f"--- 文件名: {file_name} ---\n")
                        parts.append(f.read().strip())
                        parts.append("\n\n")
                except Exception as read_err:
                    logger.error(f"读取文件失败 {file_path}: {read_err}")
        except Exception as list_err:
            logger.error(f"列出目录失败 {folder_path}: {list_err}")
        return "".join(parts).strip()

    def work(self, system_prompt: str, user_prompt: str, context_folder: Optional[str] = None):
        """
        Full workflow: prepare prompts, generate text and return the result.

        Args:
            system_prompt: system prompt
            user_prompt: user prompt
            context_folder: optional folder whose file contents are appended
                to the user prompt as reference material

        Returns:
            A tuple (result, input_tokens, output_tokens, time_cost).
        """
        if context_folder:
            logger.info(f"从文件夹读取上下文: {context_folder}")
            context = self.read_folder_content(context_folder)
            if context:
                user_prompt = f"{user_prompt.strip()}\n\n--- 参考资料 ---\n{context}"
        input_tokens = self.count_tokens(system_prompt + user_prompt)
        logger.info(f"开始生成任务... 输入token数: {input_tokens}")
        start_time = time.time()
        # Non-streaming generation: we need the complete result to count
        # output tokens and measure elapsed time.
        result = self.generate_text(system_prompt, user_prompt, stream=False)
        time_cost = time.time() - start_time
        output_tokens = self.count_tokens(result)
        logger.info(f"任务完成,耗时 {time_cost:.2f} 秒. 输出token数: {output_tokens}")
        return result, input_tokens, output_tokens, time_cost