258 lines
7.4 KiB
Python

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
提示词 API 路由
使用 PromptRegistry 管理所有 prompt
"""
import logging
from typing import Dict, Any, Optional, List
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field
from domain.prompt import PromptRegistry
# Module-level logger; the endpoint error handlers below report failures through it.
logger = logging.getLogger(__name__)
# ========== 依赖注入 ==========
def get_prompt_registry():
    """Return a PromptRegistry rooted at the ``prompts`` directory.

    Used as a FastAPI dependency (``Depends(get_prompt_registry)``) by every
    endpoint in this module; a brand-new registry object is built on each call.

    NOTE(review): 'prompts' is a relative path, so it resolves against the
    process working directory -- confirm the service is always launched from
    the expected location. If PromptRegistry reads files eagerly, caching the
    instance may be worthwhile; not done here because it could change how
    edited prompt files are picked up.
    """
    prompts_dir = 'prompts'
    registry = PromptRegistry(prompts_dir)
    return registry
# ========== 路由 ==========
# Router carrying every prompt endpoint defined below: tagged "prompt" and
# documenting a default 404 response for all of its routes.
router = APIRouter(
    responses={404: {"description": "Not found"}},
    tags=["prompt"],
)
# ========== 请求/响应模型 ==========
class PromptListResponse(BaseModel):
    """Prompt 列表响应"""
    # Response body of GET /list ("Prompt list response"). The docstring is
    # left untranslated because pydantic publishes it as the schema description.

    # Prompt names as returned by registry.list_prompts().
    prompts: List[str]
    # Number of entries in `prompts` (list_prompts sets it to len(prompts)).
    count: int
class PromptVersionsResponse(BaseModel):
    """Prompt 版本列表响应"""
    # Response body of GET /{name}/versions ("Prompt version list response").
    # The docstring is left untranslated because pydantic publishes it as the
    # schema description.

    # Prompt name taken from the request path.
    name: str
    # Version identifiers as returned by registry.list_versions(name).
    versions: List[str]
    # Number of entries in `versions`.
    count: int
class PromptInfoResponse(BaseModel):
    """Prompt 信息响应"""
    # Response body of GET /{name} ("Prompt info response"). The docstring is
    # kept verbatim because pydantic publishes it as the schema description.

    # `model_params` starts with "model_", which pydantic v2 (< 2.10) treats as
    # a protected namespace and warns about when the class is created. Opting
    # out silences that warning without renaming the public field.
    model_config = {"protected_namespaces": ()}

    name: str
    version: str
    description: str
    # Variable definitions as exposed by the prompt config's `variables`.
    variables: Dict[str, Any]
    # Values produced by the prompt config's get_model_params().
    model_params: Dict[str, float]
class PromptRenderRequest(BaseModel):
    """Prompt 渲染请求"""
    # Request body of POST /render ("Prompt render request"). The docstring
    # and the Field descriptions are kept verbatim because they appear in the
    # generated schema.

    # Prompt name ("Prompt 名称").
    name: str = Field(..., description="Prompt 名称")
    # Version to render ("版本号"); defaults to "latest".
    version: str = Field(default="latest", description="版本号")
    # Variables substituted into the templates ("变量上下文").
    context: Dict[str, Any] = Field(..., description="变量上下文")

    # Pydantic v2 configuration. The nested `class Config` form used before is
    # deprecated in v2 and emits a deprecation warning. `json_schema_extra`
    # only attaches an example payload to the OpenAPI schema.
    model_config = {
        "json_schema_extra": {
            "example": {
                "name": "content_generate",
                "version": "latest",
                "context": {
                    "scenic_spot": {"name": "天津冒险湾"},
                    "product": {"name": "家庭套票"},
                    "topic": {"title": "寒假遛娃好去处"}
                }
            }
        }
    }
class PromptRenderResponse(BaseModel):
    """Prompt 渲染响应"""
    # Response body of POST /render ("Prompt render response"). The docstring
    # is kept verbatim because pydantic publishes it as the schema description.

    # `model_params` starts with "model_", which pydantic v2 (< 2.10) treats as
    # a protected namespace and warns about when the class is created. Opting
    # out silences that warning without renaming the public field.
    model_config = {"protected_namespaces": ()}

    name: str
    version: str
    # Rendered system / user prompts as returned by registry.render().
    system_prompt: str
    user_prompt: str
    # Values produced by the prompt config's get_model_params().
    model_params: Dict[str, float]
class StyleAudienceRequest(BaseModel):
    """风格/人群请求"""
    # Request body of POST /style-audience ("Style/audience request"). The
    # docstring and Field descriptions are kept verbatim because they appear
    # in the generated schema.

    # Style identifier, e.g. gonglue / tuijian; used as the version of the
    # "style" prompt.
    style: str = Field(..., description="风格名称 (gonglue/tuijian)")
    # Audience identifier, e.g. qinzi / zhoubianyou / gaoshe; used as the
    # version of the "audience" prompt.
    audience: str = Field(..., description="人群名称 (qinzi/zhoubianyou/gaoshe)")

    # Pydantic v2 configuration. The nested `class Config` form used before is
    # deprecated in v2 and emits a deprecation warning. `json_schema_extra`
    # only attaches an example payload to the OpenAPI schema.
    model_config = {
        "json_schema_extra": {
            "example": {
                "style": "gonglue",
                "audience": "qinzi"
            }
        }
    }
class StyleAudienceResponse(BaseModel):
    """风格/人群响应"""
    # Response body of POST /style-audience ("Style/audience response"). The
    # docstring is left untranslated because pydantic publishes it as the
    # schema description.

    # Display name from the style prompt's meta "style_name" (falls back to the
    # requested style identifier).
    style_name: str
    # Text from the style prompt's meta "content" (falls back to "").
    style_content: str
    # Display name from the audience prompt's meta "audience_name" (falls back
    # to the requested audience identifier).
    audience_name: str
    # Text from the audience prompt's meta "content" (falls back to "").
    audience_content: str
# ========== API 端点 ==========
@router.get("/list", response_model=PromptListResponse)
async def list_prompts(
    registry: PromptRegistry = Depends(get_prompt_registry)
):
    """列出所有可用的 Prompt"""
    # GET /list -- return every prompt name known to the registry plus a count.
    # The docstring ("List all available prompts") is kept verbatim because
    # FastAPI shows it as the endpoint description.
    #
    # Raises HTTPException(500) on any failure.
    try:
        prompts = registry.list_prompts()
        return PromptListResponse(
            prompts=prompts,
            count=len(prompts)
        )
    except Exception as e:
        # logger.exception records the traceback, which logger.error dropped.
        # NOTE(review): str(e) is returned to the client; consider a generic
        # message if internal error details should not be exposed.
        logger.exception("列出 Prompt 失败: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.get("/{name}/versions", response_model=PromptVersionsResponse)
async def list_prompt_versions(
    name: str,
    registry: PromptRegistry = Depends(get_prompt_registry)
):
    """列出指定 Prompt 的所有版本"""
    # GET /{name}/versions -- list the versions of one prompt.
    # The docstring ("List all versions of the given prompt") is kept verbatim
    # because FastAPI shows it as the endpoint description.
    #
    # Raises HTTPException(404) when the registry reports no versions, and
    # HTTPException(500) when the registry lookup itself fails.
    try:
        versions = registry.list_versions(name)
    except Exception as e:
        # logger.exception records the traceback, which logger.error dropped.
        # NOTE(review): str(e) is returned to the client; consider a generic
        # message if internal error details should not be exposed.
        logger.exception("列出版本失败: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e

    # Checked outside the try so the 404 is not swallowed by the generic
    # handler (the original needed an `except HTTPException: raise` for this).
    if not versions:
        raise HTTPException(status_code=404, detail=f"Prompt '{name}' 不存在")

    return PromptVersionsResponse(
        name=name,
        versions=versions,
        count=len(versions)
    )
@router.get("/{name}", response_model=PromptInfoResponse)
async def get_prompt_info(
    name: str,
    version: str = "latest",
    registry: PromptRegistry = Depends(get_prompt_registry)
):
    """获取 Prompt 详细信息"""
    # GET /{name}?version=... -- return metadata of one prompt version
    # (defaults to "latest"). The docstring ("Get prompt details") is kept
    # verbatim because FastAPI shows it as the endpoint description.
    #
    # Raises HTTPException(404) when registry.get raises FileNotFoundError,
    # and HTTPException(500) for any other failure.
    try:
        config = registry.get(name, version)
        return PromptInfoResponse(
            name=config.name,
            version=config.version,
            description=config.description,
            variables=config.variables,
            model_params=config.get_model_params()
        )
    except FileNotFoundError as e:
        raise HTTPException(status_code=404, detail=str(e)) from e
    except Exception as e:
        # logger.exception records the traceback, which logger.error dropped.
        # NOTE(review): str(e) is returned to the client; consider a generic
        # message if internal error details should not be exposed.
        logger.exception("获取 Prompt 信息失败: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.post("/render", response_model=PromptRenderResponse)
async def render_prompt(
    request: PromptRenderRequest,
    registry: PromptRegistry = Depends(get_prompt_registry)
):
    """渲染 Prompt"""
    # POST /render -- render the system and user prompts of `request.name` at
    # `request.version` using `request.context`. The docstring ("Render
    # prompt") is kept verbatim because FastAPI shows it as the endpoint
    # description.
    #
    # Raises HTTPException(404) on FileNotFoundError, HTTPException(400) on
    # ValueError, and HTTPException(500) on any other registry failure.
    try:
        system_prompt, user_prompt = registry.render(
            name=request.name,
            context=request.context,
            version=request.version
        )
        config = registry.get(request.name, request.version)
    except FileNotFoundError as e:
        raise HTTPException(status_code=404, detail=str(e)) from e
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        # logger.exception records the traceback, which logger.error dropped.
        # NOTE(review): str(e) is returned to the client; consider a generic
        # message if internal error details should not be exposed.
        logger.exception("渲染 Prompt 失败: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e

    # Built outside the try: pydantic's ValidationError subclasses ValueError,
    # so a bad server-side response used to be misreported as a 400 client
    # error. A failure here is now a plain server error.
    return PromptRenderResponse(
        name=config.name,
        version=config.version,
        system_prompt=system_prompt,
        user_prompt=user_prompt,
        model_params=config.get_model_params()
    )
@router.get("/{name}/preview")
async def preview_prompt(
    name: str,
    version: str = "latest",
    registry: PromptRegistry = Depends(get_prompt_registry)
):
    """预览 Prompt 原始模板"""
    # GET /{name}/preview?version=... -- return the raw, unrendered system and
    # user templates plus metadata as a plain dict. The docstring ("Preview the
    # raw prompt template") is kept verbatim because FastAPI shows it as the
    # endpoint description.
    #
    # Raises HTTPException(404) when registry.get raises FileNotFoundError,
    # and HTTPException(500) for any other failure.
    try:
        config = registry.get(name, version)
        return {
            "name": config.name,
            "version": config.version,
            "description": config.description,
            "system_template": config.system,
            "user_template": config.user,
            "variables": config.variables,
            "model_params": config.get_model_params()
        }
    except FileNotFoundError as e:
        raise HTTPException(status_code=404, detail=str(e)) from e
    except Exception as e:
        # logger.exception records the traceback, which logger.error dropped.
        # NOTE(review): str(e) is returned to the client; consider a generic
        # message if internal error details should not be exposed.
        logger.exception("预览 Prompt 失败: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.post("/style-audience", response_model=StyleAudienceResponse)
async def get_style_audience(
    request: StyleAudienceRequest,
    registry: PromptRegistry = Depends(get_prompt_registry)
):
    """获取风格和人群提示词内容"""
    # POST /style-audience -- return the style and audience prompt snippets.
    # The docstring ("Get style and audience prompt content") is kept verbatim
    # because FastAPI shows it as the endpoint description.
    #
    # Styles and audiences are stored as *versions* of the "style" and
    # "audience" prompts, so the requested identifiers are passed as the
    # version argument of registry.get.
    #
    # Raises HTTPException(404) on FileNotFoundError, and HTTPException(500)
    # for any other failure.
    try:
        # Style: meta "content" (default "") and "style_name" (default: id).
        style_config = registry.get("style", request.style)
        style_content = style_config.meta.get("content", "")
        style_name = style_config.meta.get("style_name", request.style)

        # Audience: meta "content" (default "") and "audience_name" (default: id).
        audience_config = registry.get("audience", request.audience)
        audience_content = audience_config.meta.get("content", "")
        audience_name = audience_config.meta.get("audience_name", request.audience)

        return StyleAudienceResponse(
            style_name=style_name,
            style_content=style_content,
            audience_name=audience_name,
            audience_content=audience_content
        )
    except FileNotFoundError as e:
        raise HTTPException(status_code=404, detail=str(e)) from e
    except Exception as e:
        # logger.exception records the traceback, which logger.error dropped.
        # NOTE(review): str(e) is returned to the client; consider a generic
        # message if internal error details should not be exposed.
        logger.exception("获取风格/人群失败: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e