#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
LLM 客户端
封装 LLM 调用,提供统一接口
"""
import asyncio
import logging
from typing import Any, AsyncIterator, Callable, Dict, Optional
logger = logging.getLogger(__name__)


class LLMClient:
    """LLM client.

    Wraps a lower-level AI agent object and exposes a simplified, uniform
    interface for text generation (plain, streaming, and with retry).
    """

    def __init__(self, ai_agent=None):
        """Initialize the LLM client.

        Args:
            ai_agent: Underlying AI agent instance. May be omitted here and
                supplied later via :meth:`set_agent`.
        """
        self._ai_agent = ai_agent

    def set_agent(self, ai_agent):
        """Attach (or replace) the underlying AI agent."""
        self._ai_agent = ai_agent

    async def generate(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        **kwargs
    ) -> str:
        """Generate text for *prompt* and return the full result.

        Args:
            prompt: User prompt.
            system_prompt: Optional system prompt; an empty string is sent
                when omitted.
            temperature: Sampling temperature, forwarded only when given so
                the agent's own default applies otherwise.
            max_tokens: Maximum token budget, forwarded only when given.
            **kwargs: Extra keyword parameters passed through to the agent.

        Returns:
            The generated text.

        Raises:
            RuntimeError: If no AI agent has been configured.
        """
        if not self._ai_agent:
            raise RuntimeError("AI Agent 未初始化")
        # Forward only explicitly-provided tuning parameters.
        params: Dict[str, Any] = {}
        if temperature is not None:
            params['temperature'] = temperature
        if max_tokens is not None:
            params['max_tokens'] = max_tokens
        params.update(kwargs)
        try:
            # NOTE: use_stream=True selects the agent's streaming transport,
            # but the complete result is still returned in one piece here.
            result, input_tokens, output_tokens, time_cost = await self._ai_agent.generate_text(
                system_prompt=system_prompt or "",
                user_prompt=prompt,
                use_stream=True,
                stage="aigc_engine",
                **params
            )
            logger.debug(f"LLM 调用完成: input={input_tokens}, output={output_tokens}, time={time_cost:.2f}s")
            return result
        except Exception as e:
            logger.error(f"LLM 调用失败: {e}")
            raise

    async def generate_stream(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        callback: Optional[Callable[[str], None]] = None,
        **kwargs
    ) -> str:
        """Generate text via the streaming path and return the full result.

        Args:
            prompt: User prompt.
            system_prompt: Optional system prompt; an empty string is sent
                when omitted.
            callback: Per-chunk streaming callback.
                NOTE(review): this parameter is currently accepted but never
                forwarded to the agent, so it is silently ignored — confirm
                the agent API for chunk callbacks and wire it through.
            **kwargs: Extra keyword parameters passed through to the agent.

        Returns:
            The complete generated text.

        Raises:
            RuntimeError: If no AI agent has been configured.
        """
        if not self._ai_agent:
            raise RuntimeError("AI Agent 未初始化")
        try:
            # Token counts and timing are discarded on the streaming path.
            result, _, _, _ = await self._ai_agent.generate_text(
                system_prompt=system_prompt or "",
                user_prompt=prompt,
                use_stream=True,
                stage="aigc_engine_stream",
                **kwargs
            )
            return result
        except Exception as e:
            logger.error(f"LLM 流式调用失败: {e}")
            raise

    async def generate_with_retry(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        max_retries: int = 3,
        **kwargs
    ) -> str:
        """Generate text, retrying on failure with linear backoff.

        Args:
            prompt: User prompt.
            system_prompt: Optional system prompt.
            max_retries: Total number of attempts; must be at least 1.
            **kwargs: Extra keyword parameters forwarded to :meth:`generate`.

        Returns:
            The generated text from the first successful attempt.

        Raises:
            ValueError: If ``max_retries`` is less than 1.
            Exception: The last error raised by :meth:`generate` once all
                attempts are exhausted.
        """
        # Guard: with max_retries < 1 the loop body never runs and the
        # original code would `raise None` (a TypeError) — fail clearly.
        if max_retries < 1:
            raise ValueError(f"max_retries must be >= 1, got {max_retries}")
        last_error = None
        for attempt in range(max_retries):
            try:
                return await self.generate(prompt, system_prompt, **kwargs)
            except Exception as e:
                last_error = e
                logger.warning(f"LLM 调用失败 (尝试 {attempt + 1}/{max_retries}): {e}")
                if attempt < max_retries - 1:
                    # Linear backoff: wait 1s, 2s, 3s, ... between attempts.
                    await asyncio.sleep(1 * (attempt + 1))
        raise last_error