167 lines
4.8 KiB
Python
167 lines
4.8 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
|
||
"""
|
||
LLM 客户端
|
||
封装 LLM 调用,提供统一接口
|
||
"""
|
||
|
||
import logging
|
||
from typing import Dict, Any, Optional, Callable, AsyncIterator
|
||
|
||
logger = logging.getLogger(__name__)


class LLMClient:
    """
    LLM client.

    Thin wrapper around an underlying AI agent that exposes a simplified
    async interface. Every call is delegated to the agent's awaitable
    ``generate_text(system_prompt=..., user_prompt=..., **params)``, whose
    result is unpacked as a 4-tuple
    ``(result, input_tokens, output_tokens, time_cost)``.
    """

    def __init__(self, ai_agent=None):
        """
        Initialize the LLM client.

        Args:
            ai_agent: Underlying AI agent instance. May be ``None`` and set
                later via ``set_agent``; any generate call made while it is
                unset (or falsy) raises ``RuntimeError``.
        """
        self._ai_agent = ai_agent

    def set_agent(self, ai_agent):
        """Set (or replace) the underlying AI agent."""
        self._ai_agent = ai_agent

    async def generate(
        self,
        prompt: Optional[str] = None,
        system_prompt: Optional[str] = None,
        user_prompt: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        top_p: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        **kwargs
    ) -> tuple:
        """
        Generate text.

        Args:
            prompt: User prompt (legacy interface).
            system_prompt: System prompt; ``None`` is sent as ``""``.
            user_prompt: User prompt (new interface); takes precedence over
                ``prompt`` when truthy.
            temperature: Sampling temperature; forwarded only if not None.
            max_tokens: Maximum number of tokens; forwarded only if not None.
            top_p: Nucleus-sampling parameter; forwarded only if not None.
            presence_penalty: Presence penalty; forwarded only if not None.
            **kwargs: Extra parameters forwarded to the agent. May also
                override the defaults ``use_stream=True`` and
                ``stage="aigc_engine"``.

        Returns:
            tuple: (result, input_tokens, output_tokens, time_cost)

        Raises:
            ValueError: If neither ``prompt`` nor ``user_prompt`` is given.
            RuntimeError: If no agent has been set.
            Exception: Whatever the agent raises (logged, then re-raised).
        """
        # Support both calling conventions; the new keyword wins.
        actual_user_prompt = user_prompt or prompt
        if not actual_user_prompt:
            raise ValueError("必须提供 prompt 或 user_prompt")

        if not self._ai_agent:
            raise RuntimeError("AI Agent 未初始化")

        # Forward only explicitly-set sampling params so the agent's own
        # defaults apply otherwise.
        sampling = {
            'temperature': temperature,
            'max_tokens': max_tokens,
            'top_p': top_p,
            'presence_penalty': presence_penalty,
        }
        params = {key: value for key, value in sampling.items() if value is not None}
        params.update(kwargs)
        # Defaults that callers may override through kwargs; passing them
        # as separate explicit keywords would raise "multiple values for
        # keyword argument" when a caller supplied the same name.
        params.setdefault('use_stream', True)
        params.setdefault('stage', "aigc_engine")

        try:
            result, input_tokens, output_tokens, time_cost = await self._ai_agent.generate_text(
                system_prompt=system_prompt or "",
                user_prompt=actual_user_prompt,
                **params
            )
        except Exception as e:
            logger.error(f"LLM 调用失败: {e}")
            raise

        # Lazy %-formatting outside the try: a formatting problem (e.g. a
        # non-float time_cost) is reported by logging itself instead of
        # turning a successful call into a raised error.
        logger.debug(
            "LLM 调用完成: input=%s, output=%s, time=%.2fs",
            input_tokens, output_tokens, time_cost,
        )
        return result, input_tokens, output_tokens, time_cost

    async def generate_stream(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        callback: Optional[Callable[[str], None]] = None,
        **kwargs
    ) -> str:
        """
        Generate text with streaming enabled on the agent side.

        This awaits the agent's complete result and returns only the text;
        token counts and timing are discarded.

        NOTE(review): ``callback`` is accepted but currently never invoked
        and not forwarded to the agent. Confirm whether the agent supports
        a streaming callback before relying on it.

        Args:
            prompt: User prompt.
            system_prompt: System prompt; ``None`` is sent as ``""``.
            callback: Streaming callback (currently unused, see note above).
            **kwargs: Extra parameters forwarded to the agent. May also
                override the defaults ``use_stream=True`` and
                ``stage="aigc_engine_stream"``.

        Returns:
            The complete generated text.

        Raises:
            RuntimeError: If no agent has been set.
            Exception: Whatever the agent raises (logged, then re-raised).
        """
        if not self._ai_agent:
            raise RuntimeError("AI Agent 未初始化")

        params = dict(kwargs)
        params.setdefault('use_stream', True)
        params.setdefault('stage', "aigc_engine_stream")

        try:
            result, _, _, _ = await self._ai_agent.generate_text(
                system_prompt=system_prompt or "",
                user_prompt=prompt,
                **params
            )
        except Exception as e:
            logger.error(f"LLM 流式调用失败: {e}")
            raise

        return result

    async def generate_with_retry(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        max_retries: int = 3,
        retry_delay: float = 1.0,
        **kwargs
    ) -> tuple:
        """
        Call ``generate`` with retries and linearly increasing back-off.

        Args:
            prompt: User prompt.
            system_prompt: System prompt.
            max_retries: Maximum number of attempts; values below 1 are
                treated as a single attempt.
            retry_delay: Base back-off in seconds; attempt ``k`` (1-based)
                waits ``retry_delay * k`` before the next try.
            **kwargs: Extra parameters forwarded to ``generate``.

        Returns:
            tuple: The ``(result, input_tokens, output_tokens, time_cost)``
            tuple returned by ``generate``.

        Raises:
            RuntimeError: Immediately, if no agent has been set (retrying
                cannot help).
            Exception: The exception from the final failed attempt.
        """
        import asyncio  # local import: the module header does not import asyncio

        # A missing agent is deterministic; fail fast instead of sleeping.
        if not self._ai_agent:
            raise RuntimeError("AI Agent 未初始化")

        # Always make at least one attempt; with zero attempts the original
        # loop fell through to `raise None`.
        attempts = max(1, max_retries)
        for attempt in range(attempts):
            try:
                return await self.generate(prompt, system_prompt, **kwargs)
            except Exception as e:
                logger.warning(f"LLM 调用失败 (尝试 {attempt + 1}/{attempts}): {e}")
                if attempt == attempts - 1:
                    raise
                await asyncio.sleep(retry_delay * (attempt + 1))  # increasing back-off