167 lines
4.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
LLM 客户端
封装 LLM 调用,提供统一接口
"""
import asyncio
import logging
from typing import Any, AsyncIterator, Callable, Dict, Optional

logger = logging.getLogger(__name__)


class LLMClient:
    """LLM client.

    Wraps an underlying AI agent and exposes a simplified, unified
    interface for text generation. The agent may be supplied at
    construction time or injected later via :meth:`set_agent`.
    """

    def __init__(self, ai_agent=None):
        """Initialize the LLM client.

        Args:
            ai_agent: Underlying AI agent instance (optional; may be set
                later with :meth:`set_agent`).
        """
        self._ai_agent = ai_agent

    def set_agent(self, ai_agent):
        """Set (or replace) the underlying AI agent."""
        self._ai_agent = ai_agent

    async def generate(
        self,
        prompt: Optional[str] = None,
        system_prompt: Optional[str] = None,
        user_prompt: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        top_p: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        **kwargs
    ) -> tuple:
        """Generate text.

        Args:
            prompt: User prompt (legacy interface, kept for compatibility).
            system_prompt: System prompt.
            user_prompt: User prompt (new interface; takes precedence over
                ``prompt`` when both are supplied).
            temperature: Sampling temperature.
            max_tokens: Maximum number of tokens to generate.
            top_p: Nucleus-sampling parameter.
            presence_penalty: Presence penalty.
            **kwargs: Extra parameters forwarded verbatim to the agent.

        Returns:
            tuple: ``(result, input_tokens, output_tokens, time_cost)``.

        Raises:
            ValueError: If neither ``prompt`` nor ``user_prompt`` is given.
            RuntimeError: If no AI agent has been configured.
        """
        # Support both the legacy (prompt) and the new (user_prompt) call styles.
        actual_user_prompt = user_prompt or prompt
        if not actual_user_prompt:
            raise ValueError("必须提供 prompt 或 user_prompt")
        if not self._ai_agent:
            raise RuntimeError("AI Agent 未初始化")

        # Forward only the sampling parameters that were explicitly provided,
        # so the agent's own defaults apply otherwise.
        params = {}
        if temperature is not None:
            params['temperature'] = temperature
        if max_tokens is not None:
            params['max_tokens'] = max_tokens
        if top_p is not None:
            params['top_p'] = top_p
        if presence_penalty is not None:
            params['presence_penalty'] = presence_penalty
        params.update(kwargs)

        try:
            result, input_tokens, output_tokens, time_cost = await self._ai_agent.generate_text(
                system_prompt=system_prompt or "",
                user_prompt=actual_user_prompt,
                use_stream=True,
                stage="aigc_engine",
                **params
            )
            # Lazy %-formatting: the message is only rendered if DEBUG is enabled.
            logger.debug(
                "LLM 调用完成: input=%s, output=%s, time=%.2fs",
                input_tokens, output_tokens, time_cost
            )
            return result, input_tokens, output_tokens, time_cost
        except Exception as e:
            logger.error("LLM 调用失败: %s", e)
            raise

    async def generate_stream(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        callback: Optional[Callable[[str], None]] = None,
        **kwargs
    ) -> str:
        """Generate text in streaming mode.

        Args:
            prompt: User prompt.
            system_prompt: System prompt.
            callback: Streaming callback. NOTE(review): currently accepted
                but never invoked — the agent call does not expose chunk
                events here; confirm intended wiring with the agent API.
            **kwargs: Extra parameters forwarded to the agent.

        Returns:
            The complete generated text.

        Raises:
            RuntimeError: If no AI agent has been configured.
        """
        if not self._ai_agent:
            raise RuntimeError("AI Agent 未初始化")
        try:
            # Token counts and timing are discarded; only the text is returned.
            result, _, _, _ = await self._ai_agent.generate_text(
                system_prompt=system_prompt or "",
                user_prompt=prompt,
                use_stream=True,
                stage="aigc_engine_stream",
                **kwargs
            )
            return result
        except Exception as e:
            logger.error("LLM 流式调用失败: %s", e)
            raise

    async def generate_with_retry(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        max_retries: int = 3,
        **kwargs
    ) -> tuple:
        """Generate with automatic retries.

        Args:
            prompt: User prompt.
            system_prompt: System prompt.
            max_retries: Maximum number of attempts (must be >= 1).
            **kwargs: Extra parameters forwarded to :meth:`generate`.

        Returns:
            tuple: ``(result, input_tokens, output_tokens, time_cost)``,
            exactly as returned by :meth:`generate`. (The previous ``-> str``
            annotation was incorrect; the tuple was always propagated.)

        Raises:
            ValueError: If ``max_retries`` is less than 1.
            Exception: The last error raised by :meth:`generate` once all
                attempts are exhausted.
        """
        # Guard: with max_retries < 1 the loop body never runs and the
        # original code would have executed `raise None` (a TypeError).
        if max_retries < 1:
            raise ValueError("max_retries 必须大于等于 1")

        last_error = None
        for attempt in range(max_retries):
            try:
                return await self.generate(prompt, system_prompt, **kwargs)
            except Exception as e:
                last_error = e
                logger.warning(
                    "LLM 调用失败 (尝试 %d/%d): %s",
                    attempt + 1, max_retries, e
                )
                if attempt < max_retries - 1:
                    # Linearly increasing back-off before the next attempt.
                    await asyncio.sleep(1 * (attempt + 1))
        raise last_error