255 lines
8.3 KiB
Python

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Prompt 注册表
提供版本化的 Prompt 管理、加载和渲染功能
"""
import logging
from pathlib import Path
from typing import Dict, Any, Optional, List, Tuple
from dataclasses import dataclass, field
import yaml
from jinja2 import Template, Environment, BaseLoader
# Module-level logger; PromptRegistry reports initialisation, prompt loads,
# render failures and cache clears through it.
logger = logging.getLogger(__name__)
@dataclass
class PromptConfig:
    """One loaded prompt definition.

    Carries the raw ``system`` / ``user`` template strings together with the
    optional variable spec, model parameters, metadata and free-form
    ``content`` read from a single prompt file.
    """

    name: str
    version: str
    description: str
    system: str
    user: str
    # Variable spec: variable name -> settings (PromptRegistry reads the
    # "required" and "default" keys of dict-valued entries).
    variables: Dict[str, Any] = field(default_factory=dict)
    # Model parameters; get_model_params() reads temperature, top_p and
    # presence_penalty from here.
    model: Dict[str, float] = field(default_factory=dict)
    # Extra metadata.
    meta: Dict[str, Any] = field(default_factory=dict)
    # Configuration body for style / audience style prompts.
    content: str = ""

    def get_model_params(self) -> Dict[str, float]:
        """Return the model parameters, substituting a default for each missing key.

        Returns:
            A dict with exactly the keys ``temperature`` (default 0.7),
            ``top_p`` (default 0.9) and ``presence_penalty`` (default 0.0),
            in that order. Any other keys in ``model`` are not included.
        """
        fallbacks = (
            ('temperature', 0.7),
            ('top_p', 0.9),
            ('presence_penalty', 0.0),
        )
        return {key: self.model.get(key, default) for key, default in fallbacks}

    def get_content(self) -> str:
        """Return ``content`` when it is truthy, otherwise fall back to ``system``."""
        return self.content if self.content else self.system
class PromptRegistry:
    """
    Prompt registry.

    Responsibilities:
    1. Load and cache prompt configurations.
    2. Version management (latest, v1.0.0, v1.1.0).
    3. Required-variable validation.
    4. Template rendering (Jinja2).

    Usage example:
    ```python
    registry = PromptRegistry("prompts")
    # Render a prompt
    system, user = registry.render(
        name="content_generate",
        context={"style_content": "攻略风", ...},
        version="latest"
    )
    # Get model parameters
    config = registry.get("content_generate")
    params = config.get_model_params()
    ```
    """

    def __init__(self, prompts_dir: str = "prompts"):
        """
        Initialise the registry.

        Args:
            prompts_dir: Root directory of prompt files. Each prompt lives in
                ``<prompts_dir>/<name>/`` as ``vX.Y.Z.yaml`` files and/or a
                ``latest.yaml`` file.
        """
        self.prompts_dir = Path(prompts_dir)
        # Loaded configs keyed by "name:version" (version as passed in, so
        # "latest" is cached under its own key until reload() is called).
        self._cache: Dict[str, PromptConfig] = {}
        # Templates are only built from strings via from_string(), so a
        # BaseLoader (no file lookup) is sufficient.
        self._jinja_env = Environment(loader=BaseLoader())
        logger.info("PromptRegistry 初始化,目录: %s", self.prompts_dir)

    def get(self, name: str, version: str = "latest") -> PromptConfig:
        """
        Return a prompt configuration, loading and caching it on first access.

        Args:
            name: Prompt name (e.g. "content_generate"); may contain "/" for
                prompts stored in nested directories.
            version: Version tag (e.g. "v1.0.0") or "latest".

        Returns:
            The (cached) PromptConfig.

        Raises:
            FileNotFoundError: If no matching prompt file exists.
            ValueError: If the prompt file is not a YAML mapping.
        """
        cache_key = f"{name}:{version}"
        if cache_key not in self._cache:
            self._cache[cache_key] = self._load(name, version)
            logger.info("加载 prompt: %s:%s", name, version)
        return self._cache[cache_key]

    def render(self, name: str, context: Dict[str, Any],
               version: str = "latest") -> Tuple[str, str]:
        """
        Render a prompt's system and user templates.

        Args:
            name: Prompt name.
            context: Template variables.
            version: Version tag or "latest".

        Returns:
            A ``(system_prompt, user_prompt)`` tuple.

        Raises:
            ValueError: If a required variable is missing/None, or rendering
                fails (the original error is chained as ``__cause__``).
        """
        config = self.get(name, version)
        # Check required variables before touching the templates.
        self._validate_variables(config, context)
        full_context = self._fill_defaults(config, context)
        try:
            system = self._render_template(config.system, full_context)
            user = self._render_template(config.user, full_context)
        except Exception as e:
            logger.error("渲染 prompt 失败: %s:%s, 错误: %s", name, version, e)
            raise ValueError(f"渲染 prompt 失败: {e}") from e
        return system, user

    def _render_template(self, template_str: str, context: Dict[str, Any]) -> str:
        """Compile ``template_str`` as a Jinja2 template and render it with ``context``."""
        template = self._jinja_env.from_string(template_str)
        return template.render(**context)

    def _load(self, name: str, version: str) -> PromptConfig:
        """
        Read and parse a prompt YAML file into a PromptConfig.

        Raises:
            FileNotFoundError: If the resolved file does not exist.
            ValueError: If the file's top level is not a mapping (this
                includes an empty file, which parses to None).
        """
        prompt_path = self._resolve_path(name, version)
        if not prompt_path.exists():
            raise FileNotFoundError(f"Prompt 不存在: {name}:{version} (路径: {prompt_path})")
        with open(prompt_path, 'r', encoding='utf-8') as f:
            data = yaml.safe_load(f)
        # Reject empty/list/scalar documents up front instead of failing later
        # with an AttributeError on data.get().
        if not isinstance(data, dict):
            raise ValueError(
                f"Prompt 配置格式错误 (应为 YAML 映射): {name}:{version} (路径: {prompt_path})"
            )
        # A key written as `meta:` with no value parses to None, which a
        # .get() default does not cover -- hence the `or` fallbacks below.
        meta = data.get('meta') or {}
        return PromptConfig(
            name=meta.get('name', name),
            version=meta.get('version', version),
            description=meta.get('description', ''),
            system=data.get('system') or '',
            user=data.get('user') or '',
            variables=data.get('variables') or {},
            model=data.get('model') or {},
            meta=meta,
            content=data.get('content') or '',
        )

    def _resolve_path(self, name: str, version: str) -> Path:
        """
        Map ``name``/``version`` to a file path.

        For "latest", prefer ``latest.yaml`` (following a symlink to its
        target), otherwise the highest ``vX.Y.Z.yaml``. Other versions map
        directly to ``<version>.yaml``; existence is checked by the caller.

        Raises:
            FileNotFoundError: If "latest" is requested and neither
                ``latest.yaml`` nor any version file exists.
        """
        prompt_dir = self.prompts_dir / name
        if version != "latest":
            return prompt_dir / f"{version}.yaml"
        latest_path = prompt_dir / "latest.yaml"
        if latest_path.exists():
            # Return the real file behind a symlink.
            if latest_path.is_symlink():
                return latest_path.resolve()
            return latest_path
        versions = self.list_versions(name)
        if versions:
            return prompt_dir / f"{versions[0]}.yaml"
        raise FileNotFoundError(f"找不到 prompt: {name}")

    def _validate_variables(self, config: PromptConfig, context: Dict[str, Any]) -> None:
        """
        Ensure every variable marked ``required`` is present and not None.

        Raises:
            ValueError: Listing all missing required variables.
        """
        missing = [
            var_name
            for var_name, var_config in (config.variables or {}).items()
            if isinstance(var_config, dict)
            and var_config.get('required', False)
            and context.get(var_name) is None
        ]
        if missing:
            raise ValueError(f"缺少必填变量: {', '.join(missing)}")

    def _fill_defaults(self, config: PromptConfig, context: Dict[str, Any]) -> Dict[str, Any]:
        """
        Return a copy of ``context`` with declared defaults added.

        A default is applied only when the key is absent; an explicit None in
        ``context`` is kept as-is. ``context`` itself is not modified.
        """
        full_context = dict(context)
        for var_name, var_config in (config.variables or {}).items():
            if isinstance(var_config, dict) and 'default' in var_config:
                full_context.setdefault(var_name, var_config['default'])
        return full_context

    @staticmethod
    def _parse_version(stem: str) -> Optional[Tuple[int, ...]]:
        """Parse a stem like "v1.2.0" into (1, 2, 0); return None if it is not of that form."""
        if not stem.startswith('v'):
            return None
        parts = stem[1:].split('.')
        # isdecimal() only accepts characters int() can parse, and rejects
        # empty pieces (e.g. "v", "v1..0") and suffixes like "-beta".
        if not all(part.isdecimal() for part in parts):
            return None
        return tuple(int(part) for part in parts)

    def list_versions(self, name: str) -> List[str]:
        """
        List the version tags of a prompt, newest first.

        Only files named ``v<N>[.<N>...].yaml`` count; other ``v*.yaml``
        files (e.g. ``variables.yaml``, ``v1.0-beta.yaml``) are ignored
        instead of crashing the numeric sort.
        """
        prompt_dir = self.prompts_dir / name
        if not prompt_dir.exists():
            return []
        versions = [
            path.stem
            for path in prompt_dir.glob("v*.yaml")
            if self._parse_version(path.stem) is not None
        ]
        # Numeric ordering so that v1.10.0 sorts above v1.9.0.
        return sorted(versions, key=self._parse_version, reverse=True)

    def _is_prompt_dir(self, directory: Path) -> bool:
        """True if ``directory`` holds a prompt: a ``latest.yaml`` or a well-formed version file."""
        # latest.yaml alone is enough, since _resolve_path can serve it.
        if (directory / "latest.yaml").exists():
            return True
        return any(
            self._parse_version(path.stem) is not None
            for path in directory.glob("v*.yaml")
        )

    def list_prompts(self) -> List[str]:
        """
        List all prompt names, sorted.

        Directories that are not prompts themselves are scanned recursively,
        yielding nested names such as ``"category/prompt"``. Hidden
        directories (leading ".") are skipped.
        """
        if not self.prompts_dir.exists():
            return []
        prompts: List[str] = []

        def scan_dir(base_dir: Path, prefix: str = "") -> None:
            for d in base_dir.iterdir():
                if not d.is_dir() or d.name.startswith('.'):
                    continue
                name = f"{prefix}{d.name}"
                if self._is_prompt_dir(d):
                    prompts.append(name)
                else:
                    scan_dir(d, f"{name}/")

        scan_dir(self.prompts_dir)
        return sorted(prompts)

    def reload(self, name: Optional[str] = None) -> None:
        """
        Drop cached configs so the next get() re-reads from disk.

        Args:
            name: Clear only this prompt's entries (all versions); when
                falsy, clear the whole cache.
        """
        if name:
            # Keys are "name:version"; the ":" keeps "a" from matching "a/b".
            keys_to_remove = [k for k in self._cache if k.startswith(f"{name}:")]
            for k in keys_to_remove:
                del self._cache[k]
            logger.info("已清除 %s 的缓存", name)
        else:
            self._cache.clear()
            logger.info("已清除所有 prompt 缓存")