152 lines
4.0 KiB
Python
Raw Normal View History

2025-12-08 14:58:35 +08:00
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
LLM 客户端
封装 LLM 调用提供统一接口
"""
import asyncio
import logging
from typing import Any, AsyncIterator, Callable, Dict, Optional
# Module-level logger named after this module, per stdlib logging convention.
logger = logging.getLogger(__name__)
class LLMClient:
    """
    LLM client.

    Wraps an underlying AI agent object and exposes a simplified async call
    interface (plain, streaming, and retrying generation).
    """

    def __init__(self, ai_agent=None):
        """
        Initialize the LLM client.

        Args:
            ai_agent: Underlying AI agent instance; must expose an async
                ``generate_text(system_prompt, user_prompt, use_stream,
                stage, **params)`` returning
                ``(text, input_tokens, output_tokens, time_cost)``.
        """
        self._ai_agent = ai_agent

    def set_agent(self, ai_agent):
        """Set (or replace) the underlying AI agent."""
        self._ai_agent = ai_agent

    async def generate(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        **kwargs
    ) -> str:
        """
        Generate text.

        Args:
            prompt: User prompt.
            system_prompt: System prompt (empty string is sent when omitted).
            temperature: Sampling temperature.
            max_tokens: Maximum number of tokens to generate.
            **kwargs: Extra parameters forwarded to the agent.

        Returns:
            The generated text.

        Raises:
            RuntimeError: If no AI agent has been configured.
            Exception: Whatever the underlying agent raises (re-raised after
                logging).
        """
        if not self._ai_agent:
            raise RuntimeError("AI Agent 未初始化")

        # Only forward sampling parameters the caller explicitly provided so
        # the agent's own defaults apply otherwise.
        params = dict(kwargs)
        if temperature is not None:
            params['temperature'] = temperature
        if max_tokens is not None:
            params['max_tokens'] = max_tokens

        try:
            result, input_tokens, output_tokens, time_cost = await self._ai_agent.generate_text(
                system_prompt=system_prompt or "",
                user_prompt=prompt,
                use_stream=True,
                stage="aigc_engine",
                **params
            )
            # Lazy %-args: the message is only built when DEBUG is enabled.
            logger.debug(
                "LLM 调用完成: input=%s, output=%s, time=%.2fs",
                input_tokens, output_tokens, time_cost,
            )
            return result
        except Exception as e:
            logger.error("LLM 调用失败: %s", e)
            raise

    async def generate_stream(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        callback: Optional[Callable[[str], None]] = None,
        **kwargs
    ) -> str:
        """
        Generate text in streaming mode.

        Args:
            prompt: User prompt.
            system_prompt: System prompt (empty string is sent when omitted).
            callback: Per-chunk streaming callback.
                NOTE(review): currently ignored — it is never passed to the
                underlying agent call; confirm the agent API before wiring
                it up.
            **kwargs: Extra parameters forwarded to the agent.

        Returns:
            The complete generated text.

        Raises:
            RuntimeError: If no AI agent has been configured.
            Exception: Whatever the underlying agent raises (re-raised after
                logging).
        """
        if not self._ai_agent:
            raise RuntimeError("AI Agent 未初始化")
        try:
            # Token counts and timing are discarded here; only the text is
            # returned to the caller.
            result, _, _, _ = await self._ai_agent.generate_text(
                system_prompt=system_prompt or "",
                user_prompt=prompt,
                use_stream=True,
                stage="aigc_engine_stream",
                **kwargs
            )
            return result
        except Exception as e:
            logger.error("LLM 流式调用失败: %s", e)
            raise

    async def generate_with_retry(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        max_retries: int = 3,
        **kwargs
    ) -> str:
        """
        Generate text with automatic retries.

        Args:
            prompt: User prompt.
            system_prompt: System prompt.
            max_retries: Maximum number of attempts; values below 1 are
                treated as 1, so at least one call is always made.
            **kwargs: Extra parameters forwarded to ``generate``.

        Returns:
            The generated text from the first successful attempt.

        Raises:
            Exception: The last error raised by ``generate`` once every
                attempt has failed.
        """
        # Clamp to >= 1: previously max_retries < 1 skipped the loop entirely
        # and executed `raise None`, which surfaces as a confusing TypeError.
        attempts = max(1, max_retries)
        last_error = None
        for attempt in range(attempts):
            try:
                return await self.generate(prompt, system_prompt, **kwargs)
            except Exception as e:
                last_error = e
                logger.warning(
                    "LLM 调用失败 (尝试 %d/%d): %s", attempt + 1, attempts, e
                )
                if attempt < attempts - 1:
                    # Linear backoff: wait 1s, 2s, ... between attempts.
                    await asyncio.sleep(1 * (attempt + 1))
        raise last_error