2025-12-08 14:58:35 +08:00
|
|
|
#!/usr/bin/env python3
|
|
|
|
|
# -*- coding: utf-8 -*-
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
Prompt 注册表
|
|
|
|
|
提供版本化的 Prompt 管理、加载和渲染功能
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
import logging
|
|
|
|
|
from pathlib import Path
|
|
|
|
|
from typing import Dict, Any, Optional, List, Tuple
|
|
|
|
|
from dataclasses import dataclass, field
|
|
|
|
|
import yaml
|
|
|
|
|
from jinja2 import Template, Environment, BaseLoader
|
|
|
|
|
|
|
|
|
|
# Module-level logger, named after this module so handlers can target it.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class PromptConfig:
    """A single prompt definition as loaded from a YAML file.

    Bundles the system/user template text with the variable declarations,
    model sampling parameters, free-form metadata and optional
    style/audience content.
    """

    name: str
    version: str
    description: str
    system: str  # system-prompt template text
    user: str  # user-prompt template text
    # Per-variable settings; the registry reads 'required' and 'default' keys.
    variables: Dict[str, Any] = field(default_factory=dict)
    model: Dict[str, float] = field(default_factory=dict)  # sampling parameters
    meta: Dict[str, Any] = field(default_factory=dict)  # extra metadata
    content: str = ""  # style / audience configuration content

    def get_model_params(self) -> Dict[str, float]:
        """Return the sampling parameters, using built-in fallbacks.

        Only ``temperature``, ``top_p`` and ``presence_penalty`` are returned,
        in that order; any other keys in ``self.model`` are ignored.
        """
        fallbacks = (
            ('temperature', 0.7),
            ('top_p', 0.9),
            ('presence_penalty', 0.0),
        )
        return {key: self.model.get(key, default) for key, default in fallbacks}

    def get_content(self) -> str:
        """Return ``content`` when it is non-empty, otherwise the system template."""
        if self.content:
            return self.content
        return self.system
|
2025-12-08 14:58:35 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
class PromptRegistry:
    """
    Registry of versioned prompt configurations.

    Responsibilities:
    1. Load prompt YAML files and cache the parsed configs.
    2. Resolve versions ("latest", "v1.0.0", "v1.1.0", ...).
    3. Validate required template variables.
    4. Render system/user templates with Jinja2.

    Directory layout read by the registry::

        <prompts_dir>/<name>/v1.0.0.yaml
        <prompts_dir>/<name>/latest.yaml   # optional; may be a symlink

    Example:
        ```python
        registry = PromptRegistry("prompts")

        # Render a prompt
        system, user = registry.render(
            name="content_generate",
            context={"style_content": "攻略风", ...},
            version="latest"
        )

        # Fetch model parameters
        config = registry.get("content_generate")
        params = config.get_model_params()
        ```
    """

    def __init__(self, prompts_dir: str = "prompts"):
        """
        Initialize the registry.

        Args:
            prompts_dir: Directory holding one sub-directory per prompt.
        """
        self.prompts_dir = Path(prompts_dir)
        # Parsed configs keyed by "name:version". "latest" is cached under its
        # literal key, so call reload() after changing files on disk.
        self._cache: Dict[str, "PromptConfig"] = {}
        # Templates are compiled from strings, so no filesystem loader is needed.
        self._jinja_env = Environment(loader=BaseLoader())

        logger.info("PromptRegistry 初始化,目录: %s", self.prompts_dir)

    def get(self, name: str, version: str = "latest") -> "PromptConfig":
        """
        Return the config for ``name`` at ``version``, loading it on first use.

        Args:
            name: Prompt name, i.e. its directory under ``prompts_dir``
                (e.g. "content_generate"; nested prompts use "group/name").
            version: Version tag such as "v1.0.0", or "latest".

        Returns:
            The cached ``PromptConfig``.

        Raises:
            FileNotFoundError: If no matching prompt file exists.
            ValueError: If the prompt file is empty or not a YAML mapping.
        """
        cache_key = f"{name}:{version}"
        config = self._cache.get(cache_key)
        if config is None:
            config = self._load(name, version)
            self._cache[cache_key] = config
            logger.info("加载 prompt: %s:%s", name, version)
        return config

    def render(self, name: str, context: Dict[str, Any],
               version: str = "latest") -> Tuple[str, str]:
        """
        Render the system and user templates of a prompt.

        Args:
            name: Prompt name.
            context: Template variables; declared defaults fill in absent keys.
            version: Version tag or "latest".

        Returns:
            ``(system_prompt, user_prompt)``.

        Raises:
            FileNotFoundError: If the prompt cannot be found.
            ValueError: If a required variable is missing or rendering fails.
        """
        config = self.get(name, version)

        # Validation runs before defaults are applied: a required variable must
        # be supplied by the caller, and a None value counts as missing.
        self._validate_variables(config, context)
        full_context = self._fill_defaults(config, context)

        try:
            system = self._render_template(config.system, full_context)
            user = self._render_template(config.user, full_context)
        except Exception as e:
            logger.error("渲染 prompt 失败: %s:%s, 错误: %s", name, version, e)
            # Chain the original error so the Jinja2 traceback is preserved.
            raise ValueError(f"渲染 prompt 失败: {e}") from e

        return system, user

    def _render_template(self, template_str: str, context: Dict[str, Any]) -> str:
        """
        Compile ``template_str`` with the registry's Jinja2 environment and render it.

        The environment is created without an ``undefined`` argument, so Jinja2's
        default ``Undefined`` applies: names missing from ``context`` render as
        empty strings instead of raising.
        """
        template = self._jinja_env.from_string(template_str)
        return template.render(**context)

    def _load(self, name: str, version: str) -> "PromptConfig":
        """
        Read and parse the YAML file for ``name`` / ``version``.

        Missing or null sections (``meta``, ``system``, ``user``, ``variables``,
        ``model``, ``content``) fall back to empty values.

        Raises:
            FileNotFoundError: If the resolved file does not exist.
            ValueError: If the file is empty or its top level is not a mapping.
        """
        prompt_path = self._resolve_path(name, version)

        if not prompt_path.exists():
            raise FileNotFoundError(f"Prompt 不存在: {name}:{version} (路径: {prompt_path})")

        with open(prompt_path, 'r', encoding='utf-8') as f:
            # safe_load never constructs arbitrary Python objects from the file.
            data = yaml.safe_load(f)

        # An empty file parses to None and a scalar/list file is not a mapping;
        # fail with a clear message instead of an AttributeError on .get().
        if not isinstance(data, dict):
            raise ValueError(
                f"Prompt 格式错误: {name}:{version} (顶层必须是 YAML 映射, 路径: {prompt_path})"
            )

        # "key:" with no value parses to None, so use `or` rather than a .get default.
        meta = data.get('meta') or {}

        return PromptConfig(
            name=meta.get('name', name),
            version=meta.get('version', version),
            description=meta.get('description', ''),
            system=data.get('system') or '',
            user=data.get('user') or '',
            variables=data.get('variables') or {},
            model=data.get('model') or {},
            meta=meta,
            content=data.get('content') or '',
        )

    def _resolve_path(self, name: str, version: str) -> Path:
        """
        Map ``name`` / ``version`` to a YAML file path.

        For "latest", ``latest.yaml`` wins if it exists (a symlink is resolved
        to its target). Otherwise the highest numeric ``v*.yaml`` version is
        used. Explicit versions map directly to ``<name>/<version>.yaml``
        without an existence check.

        Raises:
            FileNotFoundError: If "latest" is requested and no file is found.
        """
        prompt_dir = self.prompts_dir / name

        if version != "latest":
            return prompt_dir / f"{version}.yaml"

        latest_path = prompt_dir / "latest.yaml"
        if latest_path.exists():
            return latest_path.resolve() if latest_path.is_symlink() else latest_path

        versions = self.list_versions(name)
        if versions:
            return prompt_dir / f"{versions[0]}.yaml"

        raise FileNotFoundError(f"找不到 prompt: {name}")

    def _validate_variables(self, config: "PromptConfig", context: Dict[str, Any]) -> None:
        """
        Check that every variable declared with ``required: true`` has a non-None value.

        Raises:
            ValueError: Listing all missing variable names, in declaration order.
        """
        missing = [
            var_name
            for var_name, var_config in config.variables.items()
            if isinstance(var_config, dict)
            and var_config.get('required', False)
            and context.get(var_name) is None
        ]
        if missing:
            raise ValueError(f"缺少必填变量: {', '.join(str(v) for v in missing)}")

    def _fill_defaults(self, config: "PromptConfig", context: Dict[str, Any]) -> Dict[str, Any]:
        """
        Return a copy of ``context`` with declared defaults added for absent keys.

        Keys already present, even with a None value, are left untouched. The
        caller's dict is never modified.
        """
        full_context = dict(context)
        for var_name, var_config in config.variables.items():
            if isinstance(var_config, dict) and 'default' in var_config:
                full_context.setdefault(var_name, var_config['default'])
        return full_context

    @staticmethod
    def _parse_version(stem: str) -> Optional[Tuple[int, ...]]:
        """Parse "v1.2.3" into (1, 2, 3); return None if the stem is not a numeric version."""
        if not stem.startswith('v'):
            return None
        parts = stem[1:].split('.')
        if not all(part.isdecimal() for part in parts):
            return None
        return tuple(int(part) for part in parts)

    def list_versions(self, name: str) -> List[str]:
        """
        List the version tags available for ``name``, highest first.

        Only files named ``v<int>[.<int>...].yaml`` count. Other ``v*.yaml``
        files (e.g. ``v1.0.0-beta.yaml``, ``vendor.yaml``) are skipped rather
        than aborting the numeric sort.
        """
        prompt_dir = self.prompts_dir / name
        if not prompt_dir.is_dir():
            return []

        parsed = []
        for f in prompt_dir.glob("v*.yaml"):
            key = self._parse_version(f.stem)
            if key is not None:
                parsed.append((key, f.stem))

        # Numeric descending; the stem breaks ties so the order is deterministic.
        parsed.sort(reverse=True)
        return [stem for _, stem in parsed]

    def _is_prompt_dir(self, directory: Path) -> bool:
        """Return True if ``directory`` holds a file that "latest" resolution can load."""
        if (directory / "latest.yaml").exists():
            return True
        return any(self._parse_version(f.stem) is not None for f in directory.glob("v*.yaml"))

    def list_prompts(self) -> List[str]:
        """
        List all prompt names, sorted, including nested ones ("group/name").

        A directory counts as a prompt when it holds ``latest.yaml`` or at least
        one numeric ``v*.yaml`` file. Other non-hidden directories are scanned
        recursively.
        """
        if not self.prompts_dir.is_dir():
            return []

        prompts: List[str] = []

        def scan_dir(base_dir: Path, prefix: str = "") -> None:
            for d in base_dir.iterdir():
                if not d.is_dir() or d.name.startswith('.'):
                    continue
                name = f"{prefix}{d.name}"
                if self._is_prompt_dir(d):
                    prompts.append(name)
                else:
                    scan_dir(d, f"{name}/")

        scan_dir(self.prompts_dir)
        return sorted(prompts)

    def reload(self, name: Optional[str] = None) -> None:
        """
        Drop cached configs so the next get()/render() re-reads the YAML files.

        Args:
            name: If given, only this prompt's entries (all versions) are
                dropped; otherwise the whole cache is cleared.
        """
        if name:
            prefix = f"{name}:"
            # Snapshot matching keys first; deleting while iterating a dict fails.
            for key in [k for k in self._cache if k.startswith(prefix)]:
                del self._cache[key]
            logger.info("已清除 %s 的缓存", name)
        else:
            self._cache.clear()
            logger.info("已清除所有 prompt 缓存")
|