152 lines
4.0 KiB
Python
152 lines
4.0 KiB
Python
|
|
#!/usr/bin/env python3
|
|||
|
|
# -*- coding: utf-8 -*-
|
|||
|
|
|
|||
|
|
"""
|
|||
|
|
LLM 客户端
|
|||
|
|
封装 LLM 调用,提供统一接口
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
import asyncio
import logging
from typing import Dict, Any, Optional, Callable, AsyncIterator
|
|||
|
|
|
|||
|
|
# Module-level logger shared by LLMClient for call diagnostics.
logger = logging.getLogger(name=__name__)
|
|||
|
|
|
|||
|
|
|
|||
|
|
class LLMClient:
    """
    LLM client.

    Wraps an underlying AI agent and exposes a simplified call interface.
    The agent must provide an awaitable ``generate_text(...)`` method that
    returns a 4-tuple ``(text, input_tokens, output_tokens, time_cost)``.
    """

    def __init__(self, ai_agent=None):
        """
        Initialize the LLM client.

        Args:
            ai_agent: Underlying AI agent instance. May be ``None`` and be
                supplied later through ``set_agent``.
        """
        self._ai_agent = ai_agent

    def set_agent(self, ai_agent):
        """Set or replace the underlying AI agent."""
        self._ai_agent = ai_agent

    def _require_agent(self):
        """
        Return the configured agent.

        Raises:
            RuntimeError: If no (truthy) agent has been set.
        """
        if not self._ai_agent:
            raise RuntimeError("AI Agent 未初始化")
        return self._ai_agent

    @staticmethod
    def _with_defaults(params, stage):
        """
        Merge caller-supplied agent options over the default ones.

        ``use_stream=True`` and *stage* are defaults only. If *params* also
        contains either key, the caller's value wins. Before this change,
        such a kwarg collided with a hard-coded keyword and raised ``TypeError``.
        """
        merged = {"use_stream": True, "stage": stage}
        merged.update(params)
        return merged

    async def generate(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        **kwargs
    ) -> str:
        """
        Generate text.

        Args:
            prompt: User prompt.
            system_prompt: System prompt; ``None`` is sent as ``""``.
            temperature: Sampling temperature; forwarded only when not ``None``.
            max_tokens: Maximum number of tokens; forwarded only when not ``None``.
            **kwargs: Extra options forwarded to ``generate_text``. They may
                override the ``use_stream``/``stage`` defaults.

        Returns:
            The generated text.

        Raises:
            RuntimeError: If no agent has been set.
            Exception: Whatever the agent call raises (logged, then re-raised).
        """
        agent = self._require_agent()

        # Forward only the options the caller actually set, so the agent's
        # own defaults apply otherwise.
        params = {}
        if temperature is not None:
            params['temperature'] = temperature
        if max_tokens is not None:
            params['max_tokens'] = max_tokens
        params.update(kwargs)

        try:
            result, input_tokens, output_tokens, time_cost = await agent.generate_text(
                system_prompt=system_prompt or "",
                user_prompt=prompt,
                **self._with_defaults(params, stage="aigc_engine"),
            )
        except Exception as e:
            logger.error("LLM 调用失败: %s", e)
            raise

        # Lazy %-formatting, outside the try: a problem formatting this debug
        # line (e.g. a non-numeric time_cost) is handled by the logging
        # module instead of turning a successful call into a failure.
        logger.debug(
            "LLM 调用完成: input=%s, output=%s, time=%.2fs",
            input_tokens, output_tokens, time_cost,
        )
        return result

    async def generate_stream(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        callback: Optional[Callable[[str], None]] = None,
        **kwargs
    ) -> str:
        """
        Generate text through the agent's streaming path.

        Args:
            prompt: User prompt.
            system_prompt: System prompt; ``None`` is sent as ``""``.
            callback: Called with the generated text. The agent call below
                exposes no per-chunk hook, so the callback receives the
                complete text once, not individual chunks. Before this
                change it was accepted but never called.
            **kwargs: Extra options forwarded to ``generate_text``. They may
                override the ``use_stream``/``stage`` defaults.

        Returns:
            The complete generated text.

        Raises:
            RuntimeError: If no agent has been set.
            Exception: Whatever the agent call raises (logged, then re-raised).
        """
        agent = self._require_agent()

        try:
            result, _, _, _ = await agent.generate_text(
                system_prompt=system_prompt or "",
                user_prompt=prompt,
                **self._with_defaults(kwargs, stage="aigc_engine_stream"),
            )
        except Exception as e:
            logger.error("LLM 流式调用失败: %s", e)
            raise

        # Outside the try, so a failing callback is not mislogged as an
        # LLM call failure.
        if callback is not None:
            callback(result)
        return result

    async def generate_with_retry(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        max_retries: int = 3,
        **kwargs
    ) -> str:
        """
        Generate text, retrying failed calls with a linearly increasing delay.

        Args:
            prompt: User prompt.
            system_prompt: System prompt.
            max_retries: Total number of attempts, not additional retries.
                For example, 3 means at most 3 calls. Values below 1 are
                treated as 1. Before this change they skipped the loop and
                ended in ``raise None``, which is a ``TypeError``.
            **kwargs: Extra options forwarded to ``generate``.

        Returns:
            The generated text.

        Raises:
            RuntimeError: Immediately, without retrying, if no agent is set.
            Exception: The exception from the last failed attempt.
        """
        # A missing agent is not transient: fail fast instead of sleeping
        # between attempts that cannot succeed.
        self._require_agent()

        attempts = max(1, max_retries)
        last_error = None

        for attempt in range(attempts):
            try:
                return await self.generate(prompt, system_prompt, **kwargs)
            except Exception as e:
                last_error = e
                logger.warning(
                    "LLM 调用失败 (尝试 %d/%d): %s", attempt + 1, attempts, e
                )
                if attempt < attempts - 1:
                    # Linear backoff: 1s, 2s, ...
                    await asyncio.sleep(attempt + 1)

        # attempts >= 1, so reaching this point means an exception was recorded.
        raise last_error
|