258 lines
7.4 KiB
Python
258 lines
7.4 KiB
Python
#!/usr/bin/env python3
|
|
# -*- coding: utf-8 -*-
|
|
|
|
"""
|
|
提示词 API 路由
|
|
|
|
使用 PromptRegistry 管理所有 prompt
|
|
"""
|
|
|
|
import logging
from typing import Dict, Any, Optional, List

from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, ConfigDict, Field

from domain.prompt import PromptRegistry
|
|
|
|
# Module-level logger, named after this module, used by the route handlers.
logger: logging.Logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
# ========== 依赖注入 ==========
|
|
|
|
def get_prompt_registry():
    """FastAPI dependency that supplies a ``PromptRegistry`` to the routes.

    Returns:
        A new ``PromptRegistry`` constructed with the relative ``'prompts'``
        argument. A fresh instance is built on every call (so once per request
        that depends on it); nothing is cached here.
    """
    # NOTE(review): 'prompts' is presumably a directory resolved against the
    # process working directory — confirm against PromptRegistry. If loading is
    # expensive, caching the instance may be worth it, but only if prompt-file
    # hot reload is not relied upon.
    prompts_root = 'prompts'
    return PromptRegistry(prompts_root)
|
|
|
|
|
|
# ========== 路由 ==========
|
|
|
|
# Router for the prompt endpoints. No path prefix is set here, so the final URLs
# depend on how the application includes this router.
router = APIRouter(
    tags=["prompt"],
    responses={404: dict(description="Not found")},
)
|
|
|
|
|
|
# ========== 请求/响应模型 ==========
|
|
|
|
class PromptListResponse(BaseModel):
    """Prompt 列表响应"""

    # Response body for the prompt listing endpoint. The docstring above is kept
    # verbatim because pydantic publishes it as the JSON-schema description.

    # Names of the prompts reported by the registry.
    prompts: List[str] = Field(...)
    # Number of entries in `prompts`.
    count: int = Field(...)
|
|
|
|
|
|
class PromptVersionsResponse(BaseModel):
    """Prompt 版本列表响应"""

    # Response body listing the versions of a single prompt. The docstring above
    # is kept verbatim because pydantic publishes it as the schema description.

    # Prompt name the versions belong to.
    name: str = Field(...)
    # Version identifiers reported by the registry.
    versions: List[str] = Field(...)
    # Number of entries in `versions`.
    count: int = Field(...)
|
|
|
|
|
|
class PromptInfoResponse(BaseModel):
    """Prompt 信息响应"""

    # Response body describing one prompt version. The docstring above is kept
    # verbatim because pydantic publishes it as the JSON-schema description.

    # `model_params` starts with pydantic v2's protected "model_" prefix, which
    # makes pydantic (< 2.10) emit a UserWarning when this class is defined.
    # Clearing the protected namespaces silences it without renaming the field
    # (renaming would break the API contract).
    model_config = ConfigDict(protected_namespaces=())

    name: str
    version: str
    description: str
    # Prompt variable definitions; the exact schema is not visible here.
    variables: Dict[str, Any]
    # Model parameters for the prompt; every value must be coercible to float.
    model_params: Dict[str, float]
|
|
|
|
|
|
class PromptRenderRequest(BaseModel):
    """Prompt 渲染请求"""

    # Request body for rendering a prompt. The docstring above is kept verbatim
    # because pydantic publishes it as the JSON-schema description.

    # Pydantic v2 configuration. This replaces the inner `class Config`, which is
    # deprecated in v2. The example is surfaced in the generated JSON schema /
    # OpenAPI docs.
    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "name": "content_generate",
                "version": "latest",
                "context": {
                    "scenic_spot": {"name": "天津冒险湾"},
                    "product": {"name": "家庭套票"},
                    "topic": {"title": "寒假遛娃好去处"}
                }
            }
        }
    )

    # Prompt name (required).
    name: str = Field(..., description="Prompt 名称")
    # Prompt version; defaults to "latest".
    version: str = Field(default="latest", description="版本号")
    # Template variables passed to the renderer (required).
    context: Dict[str, Any] = Field(..., description="变量上下文")
|
|
|
|
|
|
class PromptRenderResponse(BaseModel):
    """Prompt 渲染响应"""

    # Response body carrying a rendered system/user prompt pair. The docstring
    # above is kept verbatim because pydantic publishes it as the schema
    # description.

    # `model_params` starts with pydantic v2's protected "model_" prefix, which
    # makes pydantic (< 2.10) emit a UserWarning when this class is defined.
    # Clearing the protected namespaces silences it without renaming the field.
    model_config = ConfigDict(protected_namespaces=())

    name: str
    version: str
    # Rendered system prompt text.
    system_prompt: str
    # Rendered user prompt text.
    user_prompt: str
    # Model parameters for the prompt; every value must be coercible to float.
    model_params: Dict[str, float]
|
|
|
|
|
|
class StyleAudienceRequest(BaseModel):
    """风格/人群请求"""

    # Request body selecting a writing style and a target audience. The docstring
    # above is kept verbatim because pydantic publishes it as the schema
    # description.

    # Pydantic v2 configuration. This replaces the inner `class Config`, which is
    # deprecated in v2. The example is surfaced in the generated JSON schema /
    # OpenAPI docs.
    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "style": "gonglue",
                "audience": "qinzi"
            }
        }
    )

    # Style identifier (the description lists gonglue / tuijian).
    style: str = Field(..., description="风格名称 (gonglue/tuijian)")
    # Audience identifier (the description lists qinzi / zhoubianyou / gaoshe).
    audience: str = Field(..., description="人群名称 (qinzi/zhoubianyou/gaoshe)")
|
|
|
|
|
|
class StyleAudienceResponse(BaseModel):
    """风格/人群响应"""

    # Response body with the resolved style and audience prompt snippets. The
    # docstring above is kept verbatim because pydantic publishes it as the
    # schema description.

    # Display name of the style.
    style_name: str = Field(...)
    # Prompt text for the style.
    style_content: str = Field(...)
    # Display name of the audience.
    audience_name: str = Field(...)
    # Prompt text for the audience.
    audience_content: str = Field(...)
|
|
|
|
|
|
# ========== API 端点 ==========
|
|
|
|
@router.get("/list", response_model=PromptListResponse)
async def list_prompts(
    registry: PromptRegistry = Depends(get_prompt_registry)
):
    """列出所有可用的 Prompt"""
    # List every prompt name the registry knows about, together with a count.
    # (The docstring above is kept verbatim: FastAPI uses it as the OpenAPI
    # description.)
    # Any failure, including building the response, becomes HTTP 500 with the
    # exception text as detail. The full traceback is logged server-side.
    try:
        prompts = registry.list_prompts()
        return PromptListResponse(
            prompts=prompts,
            count=len(prompts)
        )
    except Exception as e:
        # logger.exception records the traceback, which logger.error(f"...") dropped.
        logger.exception("列出 Prompt 失败: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
|
@router.get("/{name}/versions", response_model=PromptVersionsResponse)
async def list_prompt_versions(
    name: str,
    registry: PromptRegistry = Depends(get_prompt_registry)
):
    """列出指定 Prompt 的所有版本"""
    # List the versions the registry reports for prompt `name`.
    # (The docstring above is kept verbatim: FastAPI uses it as the OpenAPI
    # description.)
    # Returns 404 when the registry returns no versions or raises
    # FileNotFoundError, matching the other lookup endpoints. Any other failure
    # becomes 500, with the traceback logged.
    try:
        versions = registry.list_versions(name)
        if not versions:
            raise HTTPException(status_code=404, detail=f"Prompt '{name}' 不存在")

        return PromptVersionsResponse(
            name=name,
            versions=versions,
            count=len(versions)
        )
    except HTTPException:
        # Let the deliberate 404 above pass through unchanged.
        raise
    except FileNotFoundError as e:
        raise HTTPException(status_code=404, detail=str(e)) from e
    except Exception as e:
        logger.exception("列出版本失败: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
|
@router.get("/{name}", response_model=PromptInfoResponse)
async def get_prompt_info(
    name: str,
    version: str = "latest",
    registry: PromptRegistry = Depends(get_prompt_registry)
):
    """获取 Prompt 详细信息"""
    # Return metadata for one prompt version. `version` is a query parameter that
    # defaults to "latest".
    # (The docstring above is kept verbatim: FastAPI uses it as the OpenAPI
    # description.)
    # FileNotFoundError from the registry maps to 404. Anything else maps to 500,
    # with the traceback logged.
    try:
        config = registry.get(name, version)
        return PromptInfoResponse(
            name=config.name,
            version=config.version,
            description=config.description,
            variables=config.variables,
            model_params=config.get_model_params()
        )
    except FileNotFoundError as e:
        raise HTTPException(status_code=404, detail=str(e)) from e
    except Exception as e:
        logger.exception("获取 Prompt 信息失败: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
|
@router.post("/render", response_model=PromptRenderResponse)
async def render_prompt(
    request: PromptRenderRequest,
    registry: PromptRegistry = Depends(get_prompt_registry)
):
    """渲染 Prompt"""
    # Render prompt `request.name` at `request.version` using `request.context`.
    # (The docstring above is kept verbatim: FastAPI uses it as the OpenAPI
    # description.)
    # Error mapping: FileNotFoundError -> 404, ValueError -> 400, anything
    # else -> 500 (traceback logged).
    try:
        system_prompt, user_prompt = registry.render(
            name=request.name,
            context=request.context,
            version=request.version
        )

        # NOTE(review): name/version/model_params come from a second lookup made
        # after rendering. If prompt files can change between the two calls
        # (e.g. a new "latest"), the reported version might not match the
        # rendered text. Confirm whether the registry pins this.
        config = registry.get(request.name, request.version)

        return PromptRenderResponse(
            name=config.name,
            version=config.version,
            system_prompt=system_prompt,
            user_prompt=user_prompt,
            model_params=config.get_model_params()
        )
    except FileNotFoundError as e:
        raise HTTPException(status_code=404, detail=str(e)) from e
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        logger.exception("渲染 Prompt 失败: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
|
@router.get("/{name}/preview")
async def preview_prompt(
    name: str,
    version: str = "latest",
    registry: PromptRegistry = Depends(get_prompt_registry)
):
    """预览 Prompt 原始模板"""
    # Return the raw, unrendered system/user templates plus metadata for one
    # prompt version, as a plain dict (no response_model is declared).
    # (The docstring above is kept verbatim: FastAPI uses it as the OpenAPI
    # description.)
    # FileNotFoundError maps to 404. Anything else maps to 500, with the
    # traceback logged.
    try:
        config = registry.get(name, version)
        return {
            "name": config.name,
            "version": config.version,
            "description": config.description,
            "system_template": config.system,
            "user_template": config.user,
            "variables": config.variables,
            "model_params": config.get_model_params()
        }
    except FileNotFoundError as e:
        raise HTTPException(status_code=404, detail=str(e)) from e
    except Exception as e:
        logger.exception("预览 Prompt 失败: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
|
@router.post("/style-audience", response_model=StyleAudienceResponse)
async def get_style_audience(
    request: StyleAudienceRequest,
    registry: PromptRegistry = Depends(get_prompt_registry)
):
    """获取风格和人群提示词内容"""
    # Resolve the "style" and "audience" prompts. The requested style/audience
    # identifiers are passed as the *version* argument of `registry.get`.
    # (The docstring above is kept verbatim: FastAPI uses it as the OpenAPI
    # description.)
    # Each result reads meta["content"] (default "") and a display name from
    # meta["style_name"] / meta["audience_name"], falling back to the requested
    # identifier.
    # FileNotFoundError maps to 404. Anything else maps to 500, with the
    # traceback logged.
    try:
        # Style
        style_config = registry.get("style", request.style)
        style_content = style_config.meta.get("content", "")
        style_name = style_config.meta.get("style_name", request.style)

        # Audience
        audience_config = registry.get("audience", request.audience)
        audience_content = audience_config.meta.get("content", "")
        audience_name = audience_config.meta.get("audience_name", request.audience)

        return StyleAudienceResponse(
            style_name=style_name,
            style_content=style_content,
            audience_name=audience_name,
            audience_content=audience_content
        )
    except FileNotFoundError as e:
        raise HTTPException(status_code=404, detail=str(e)) from e
    except Exception as e:
        logger.exception("获取风格/人群失败: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
|