157 lines
4.6 KiB
Python
157 lines
4.6 KiB
Python
|
|
#!/usr/bin/env python3
|
|||
|
|
# -*- coding: utf-8 -*-
|
|||
|
|
|
|||
|
|
"""
|
|||
|
|
提示词构建器
|
|||
|
|
统一管理提示词模板的加载和构建
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
import logging
|
|||
|
|
from typing import Dict, Any, Optional
|
|||
|
|
from pathlib import Path
|
|||
|
|
from string import Template
|
|||
|
|
|
|||
|
|
logger = logging.getLogger(__name__)
|
|||
|
|
|
|||
|
|
|
|||
|
|
class PromptBuilder:
|
|||
|
|
"""
|
|||
|
|
提示词构建器
|
|||
|
|
|
|||
|
|
支持:
|
|||
|
|
- 从文件加载模板
|
|||
|
|
- 变量替换
|
|||
|
|
- 缓存机制
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
def __init__(self, prompt_base_path: str = "resource/prompt"):
|
|||
|
|
"""
|
|||
|
|
初始化提示词构建器
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
prompt_base_path: 提示词模板基础路径(相对于项目根目录)
|
|||
|
|
"""
|
|||
|
|
self._base_path = Path(prompt_base_path)
|
|||
|
|
self._project_root: Optional[Path] = None
|
|||
|
|
self._cache: Dict[str, str] = {}
|
|||
|
|
|
|||
|
|
def set_project_root(self, project_root: str):
|
|||
|
|
"""设置项目根目录"""
|
|||
|
|
self._project_root = Path(project_root)
|
|||
|
|
|
|||
|
|
def _get_full_path(self, relative_path: str) -> Path:
|
|||
|
|
"""获取完整路径"""
|
|||
|
|
if self._project_root:
|
|||
|
|
return self._project_root / self._base_path / relative_path
|
|||
|
|
return self._base_path / relative_path
|
|||
|
|
|
|||
|
|
def load_template(self, template_name: str, template_type: str = "user") -> str:
|
|||
|
|
"""
|
|||
|
|
加载模板文件
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
template_name: 模板名称(对应目录名)
|
|||
|
|
template_type: 模板类型 (system, user)
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
模板内容
|
|||
|
|
"""
|
|||
|
|
cache_key = f"{template_name}/{template_type}"
|
|||
|
|
|
|||
|
|
if cache_key in self._cache:
|
|||
|
|
return self._cache[cache_key]
|
|||
|
|
|
|||
|
|
# 尝试多种文件名
|
|||
|
|
possible_files = [
|
|||
|
|
f"{template_type}.txt",
|
|||
|
|
f"{template_type}_prompt.txt",
|
|||
|
|
f"{template_type}.md",
|
|||
|
|
]
|
|||
|
|
|
|||
|
|
for filename in possible_files:
|
|||
|
|
file_path = self._get_full_path(f"{template_name}/{filename}")
|
|||
|
|
if file_path.exists():
|
|||
|
|
try:
|
|||
|
|
content = file_path.read_text(encoding='utf-8')
|
|||
|
|
self._cache[cache_key] = content
|
|||
|
|
return content
|
|||
|
|
except Exception as e:
|
|||
|
|
logger.error(f"读取模板文件失败: {file_path}, {e}")
|
|||
|
|
|
|||
|
|
logger.warning(f"模板文件不存在: {template_name}/{template_type}")
|
|||
|
|
return ""
|
|||
|
|
|
|||
|
|
def build(
|
|||
|
|
self,
|
|||
|
|
template_name: str,
|
|||
|
|
template_type: str = "user",
|
|||
|
|
**variables
|
|||
|
|
) -> str:
|
|||
|
|
"""
|
|||
|
|
构建提示词
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
template_name: 模板名称
|
|||
|
|
template_type: 模板类型 (system, user)
|
|||
|
|
**variables: 模板变量
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
构建后的提示词
|
|||
|
|
"""
|
|||
|
|
template_content = self.load_template(template_name, template_type)
|
|||
|
|
|
|||
|
|
if not template_content:
|
|||
|
|
return ""
|
|||
|
|
|
|||
|
|
try:
|
|||
|
|
# 使用 Python Template 进行变量替换
|
|||
|
|
# 支持 $variable 和 ${variable} 格式
|
|||
|
|
template = Template(template_content)
|
|||
|
|
return template.safe_substitute(**variables)
|
|||
|
|
except Exception as e:
|
|||
|
|
logger.error(f"构建提示词失败: {e}")
|
|||
|
|
return template_content
|
|||
|
|
|
|||
|
|
def build_full(
|
|||
|
|
self,
|
|||
|
|
template_name: str,
|
|||
|
|
**variables
|
|||
|
|
) -> tuple[str, str]:
|
|||
|
|
"""
|
|||
|
|
构建完整提示词(系统 + 用户)
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
template_name: 模板名称
|
|||
|
|
**variables: 模板变量
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
(系统提示词, 用户提示词)
|
|||
|
|
"""
|
|||
|
|
system_prompt = self.build(template_name, "system", **variables)
|
|||
|
|
user_prompt = self.build(template_name, "user", **variables)
|
|||
|
|
return system_prompt, user_prompt
|
|||
|
|
|
|||
|
|
def get_system_prompt(self, template_name: str, **variables) -> str:
|
|||
|
|
"""获取系统提示词"""
|
|||
|
|
return self.build(template_name, "system", **variables)
|
|||
|
|
|
|||
|
|
def get_user_prompt(self, template_name: str, **variables) -> str:
|
|||
|
|
"""获取用户提示词"""
|
|||
|
|
return self.build(template_name, "user", **variables)
|
|||
|
|
|
|||
|
|
def clear_cache(self):
|
|||
|
|
"""清空缓存"""
|
|||
|
|
self._cache.clear()
|
|||
|
|
|
|||
|
|
def list_templates(self) -> list[str]:
|
|||
|
|
"""列出所有可用模板"""
|
|||
|
|
templates = []
|
|||
|
|
base_path = self._get_full_path("")
|
|||
|
|
|
|||
|
|
if base_path.exists():
|
|||
|
|
for item in base_path.iterdir():
|
|||
|
|
if item.is_dir() and not item.name.startswith('_'):
|
|||
|
|
templates.append(item.name)
|
|||
|
|
|
|||
|
|
return templates
|