#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
LLM client.

Wraps LLM calls behind a unified interface.
"""
import asyncio
import logging
from typing import Callable, Optional

logger = logging.getLogger(__name__)


class LLMClient:
    """
    LLM client.

    Wraps the underlying AI Agent and exposes a simplified call interface.
    """

    def __init__(self, ai_agent=None):
        """
        Initialize the LLM client.

        Args:
            ai_agent: Underlying AI Agent instance.
        """
        self._ai_agent = ai_agent

    def set_agent(self, ai_agent):
        """Set the AI Agent."""
        self._ai_agent = ai_agent

    async def generate(
        self,
        prompt: Optional[str] = None,
        system_prompt: Optional[str] = None,
        user_prompt: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        top_p: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        **kwargs
    ) -> tuple:
        """
        Generate text.

        Args:
            prompt: User prompt (legacy interface, kept for compatibility).
            system_prompt: System prompt.
            user_prompt: User prompt (new interface).
            temperature: Sampling temperature.
            max_tokens: Maximum number of tokens to generate.
            top_p: Nucleus sampling (top_p) parameter.
            presence_penalty: Presence penalty.
            **kwargs: Additional parameters passed through to the agent.

        Returns:
            tuple: (result, input_tokens, output_tokens, time_cost)
        """
        # Support both calling conventions (prompt vs. user_prompt).
        actual_user_prompt = user_prompt or prompt
        if not actual_user_prompt:
            raise ValueError("Either prompt or user_prompt must be provided")
        if not self._ai_agent:
            raise RuntimeError("AI Agent is not initialized")

        # Build the sampling parameters, including only those explicitly set.
        params = {}
        if temperature is not None:
            params['temperature'] = temperature
        if max_tokens is not None:
            params['max_tokens'] = max_tokens
        if top_p is not None:
            params['top_p'] = top_p
        if presence_penalty is not None:
            params['presence_penalty'] = presence_penalty
        params.update(kwargs)

        try:
            result, input_tokens, output_tokens, time_cost = await self._ai_agent.generate_text(
                system_prompt=system_prompt or "",
                user_prompt=actual_user_prompt,
                use_stream=True,
                stage="aigc_engine",
                **params
            )
            logger.debug(
                f"LLM call finished: input={input_tokens}, output={output_tokens}, time={time_cost:.2f}s"
            )
            return result, input_tokens, output_tokens, time_cost
        except Exception as e:
            logger.error(f"LLM call failed: {e}")
            raise

    async def generate_stream(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        callback: Optional[Callable[[str], None]] = None,
        **kwargs
    ) -> str:
        """
        Generate text in streaming mode.

        Args:
            prompt: User prompt.
            system_prompt: System prompt.
            callback: Streaming callback. Note: not currently forwarded to the
                underlying agent; streaming is handled inside the agent itself.
            **kwargs: Additional parameters.

        Returns:
            The complete generated text.
        """
        if not self._ai_agent:
            raise RuntimeError("AI Agent is not initialized")
        try:
            result, _, _, _ = await self._ai_agent.generate_text(
                system_prompt=system_prompt or "",
                user_prompt=prompt,
                use_stream=True,
                stage="aigc_engine_stream",
                **kwargs
            )
            return result
        except Exception as e:
            logger.error(f"LLM streaming call failed: {e}")
            raise

    async def generate_with_retry(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        max_retries: int = 3,
        **kwargs
    ) -> str:
        """
        Generate with retries.

        Args:
            prompt: User prompt.
            system_prompt: System prompt.
            max_retries: Maximum number of attempts.
            **kwargs: Additional parameters.

        Returns:
            The generated text.
        """
        last_error = None
        for attempt in range(max_retries):
            try:
                # generate() returns (result, input_tokens, output_tokens, time_cost);
                # return only the text here, matching the -> str annotation.
                result, _, _, _ = await self.generate(prompt, system_prompt, **kwargs)
                return result
            except Exception as e:
                last_error = e
                logger.warning(f"LLM call failed (attempt {attempt + 1}/{max_retries}): {e}")
                if attempt < max_retries - 1:
                    await asyncio.sleep(1 * (attempt + 1))  # Linear backoff: 1s, 2s, ...
        raise last_error
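

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only, not part of the module's API).
# `StubAgent` is a hypothetical stand-in whose generate_text() mirrors the
# (system_prompt, user_prompt, use_stream, stage, **params) -> 4-tuple contract
# that LLMClient relies on above; a real AI Agent would be injected instead.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    class StubAgent:
        """Hypothetical agent that echoes the prompt back as the 'generation'."""

        async def generate_text(self, system_prompt, user_prompt, use_stream, stage, **params):
            result = f"[{stage}] echo: {user_prompt}"
            # (result, input_tokens, output_tokens, time_cost)
            return result, len(user_prompt), len(result), 0.01

    async def _demo():
        client = LLMClient(StubAgent())
        # Full interface: returns the 4-tuple documented in generate().
        result, in_tokens, out_tokens, time_cost = await client.generate(
            user_prompt="Hello",
            system_prompt="You are a helpful assistant.",
            temperature=0.7,
        )
        print(result, in_tokens, out_tokens, time_cost)
        # Retry wrapper: returns only the generated text.
        text = await client.generate_with_retry("Hello again", max_retries=2)
        print(text)

    asyncio.run(_demo())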