#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
提示词 API 路由
使用 PromptRegistry 管理所有 prompt
2025-07-11 13:50:08 +08:00
"""
import logging
from typing import Dict, Any, Optional, List
2025-07-11 13:50:08 +08:00
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field
2025-07-11 13:50:08 +08:00
from domain.prompt import PromptRegistry
2025-07-11 13:50:08 +08:00
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
# ========== 依赖注入 ==========
def get_prompt_registry():
    """Create the PromptRegistry instance used by the route handlers.

    The handlers in this module receive it through ``Depends``. A new
    registry is constructed with the argument 'prompts' on every call.
    """
    registry = PromptRegistry('prompts')
    return registry
# ========== 路由 ==========
# Router that the endpoints below register on: tagged "prompt", with a
# documented 404 response declared at router level.
router = APIRouter(
    tags=["prompt"],
    responses={404: {"description": "Not found"}},
)
2025-07-11 13:50:08 +08:00
# ========== 请求/响应模型 ==========
class PromptListResponse(BaseModel):
    """Response body for the prompt listing endpoint."""
    # Entries returned by the registry's list_prompts() (presumably prompt names).
    prompts: List[str]
    # Number of entries in ``prompts``.
    count: int
class PromptVersionsResponse(BaseModel):
    """Response body listing the versions of a single prompt."""
    # Prompt name the versions belong to.
    name: str
    # Version identifiers returned by the registry's list_versions().
    versions: List[str]
    # Number of entries in ``versions``.
    count: int
class PromptInfoResponse(BaseModel):
    """Response body describing one prompt at a specific version."""

    # ``model_params`` begins with ``model_``, which Pydantic v2 treats as a
    # protected namespace and warns about when the class is created. Clearing
    # the protected namespaces keeps the public field name without the warning.
    model_config = {"protected_namespaces": ()}

    # Prompt name taken from the resolved prompt config.
    name: str
    # Version taken from the resolved prompt config.
    version: str
    # Description taken from the resolved prompt config.
    description: str
    # Variable declarations of the prompt (schema not visible in this module).
    variables: Dict[str, Any]
    # Model parameters as returned by the config's get_model_params().
    model_params: Dict[str, float]
class PromptRenderRequest(BaseModel):
    """Request body for rendering a prompt."""

    # Pydantic v2 deprecates the nested ``class Config``; ``model_config`` is
    # the supported way to attach the OpenAPI example.
    model_config = {
        "json_schema_extra": {
            "example": {
                "name": "content_generate",
                "version": "latest",
                "context": {
                    "scenic_spot": {"name": "天津冒险湾"},
                    "product": {"name": "家庭套票"},
                    "topic": {"title": "寒假遛娃好去处"}
                }
            }
        }
    }

    # Name of the prompt to render (required).
    name: str = Field(..., description="Prompt 名称")
    # Version to render; defaults to "latest".
    version: str = Field(default="latest", description="版本号")
    # Variable context passed to the registry when rendering (required).
    context: Dict[str, Any] = Field(..., description="变量上下文")
class PromptRenderResponse(BaseModel):
    """Response body carrying a rendered prompt."""

    # ``model_params`` begins with ``model_``, which Pydantic v2 treats as a
    # protected namespace and warns about when the class is created. Clearing
    # the protected namespaces keeps the public field name without the warning.
    model_config = {"protected_namespaces": ()}

    # Prompt name taken from the resolved prompt config.
    name: str
    # Version taken from the resolved prompt config.
    version: str
    # Rendered system prompt text.
    system_prompt: str
    # Rendered user prompt text.
    user_prompt: str
    # Model parameters as returned by the config's get_model_params().
    model_params: Dict[str, float]
class StyleAudienceRequest(BaseModel):
    """Request body selecting a style and an audience."""

    # Pydantic v2 deprecates the nested ``class Config``; ``model_config`` is
    # the supported way to attach the OpenAPI example.
    model_config = {
        "json_schema_extra": {
            "example": {
                "style": "gonglue",
                "audience": "qinzi"
            }
        }
    }

    # Style identifier (required); the field description lists gonglue/tuijian.
    style: str = Field(..., description="风格名称 (gonglue/tuijian)")
    # Audience identifier (required); the field description lists
    # qinzi/zhoubianyou/gaoshe.
    audience: str = Field(..., description="人群名称 (qinzi/zhoubianyou/gaoshe)")
class StyleAudienceResponse(BaseModel):
    """Response body carrying the style and audience prompt content."""
    # Display name of the style.
    style_name: str
    # Prompt text associated with the style.
    style_content: str
    # Display name of the audience.
    audience_name: str
    # Prompt text associated with the audience.
    audience_content: str
# ========== API 端点 ==========
@router.get("/list", response_model=PromptListResponse)
async def list_prompts(
    registry: PromptRegistry = Depends(get_prompt_registry)
):
    """List all available prompts.

    Args:
        registry: Prompt registry supplied through dependency injection.

    Returns:
        PromptListResponse holding the registry listing and its length.

    Raises:
        HTTPException: 500 if anything fails; the detail is the error text.
    """
    try:
        prompts = registry.list_prompts()
        return PromptListResponse(
            prompts=prompts,
            count=len(prompts)
        )
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error dropped.
        logger.exception(f"列出 Prompt 失败: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.get("/{name}/versions", response_model=PromptVersionsResponse)
async def list_prompt_versions(
    name: str,
    registry: PromptRegistry = Depends(get_prompt_registry)
):
    """List every version of the named prompt.

    Args:
        name: Prompt name taken from the URL path.
        registry: Prompt registry supplied through dependency injection.

    Returns:
        PromptVersionsResponse with the prompt name, its versions and their count.

    Raises:
        HTTPException: 404 if the registry reports no versions for ``name``;
            500 on any other failure, with the error text as detail.
    """
    try:
        versions = registry.list_versions(name)
        if not versions:
            # An empty result is treated as "prompt does not exist".
            raise HTTPException(status_code=404, detail=f"Prompt '{name}' 不存在")

        return PromptVersionsResponse(
            name=name,
            versions=versions,
            count=len(versions)
        )
    except HTTPException:
        # Let the 404 raised above pass through instead of becoming a 500.
        raise
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error dropped.
        logger.exception(f"列出版本失败: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.get("/{name}", response_model=PromptInfoResponse)
async def get_prompt_info(
    name: str,
    version: str = "latest",
    registry: PromptRegistry = Depends(get_prompt_registry)
):
    """Return the metadata of a prompt.

    Args:
        name: Prompt name taken from the URL path.
        version: Version to look up; defaults to "latest".
        registry: Prompt registry supplied through dependency injection.

    Returns:
        PromptInfoResponse built from the resolved prompt config.

    Raises:
        HTTPException: 404 if the registry raises FileNotFoundError;
            500 on any other failure, with the error text as detail.
    """
    try:
        config = registry.get(name, version)
        return PromptInfoResponse(
            name=config.name,
            version=config.version,
            description=config.description,
            variables=config.variables,
            model_params=config.get_model_params()
        )
    except FileNotFoundError as e:
        raise HTTPException(status_code=404, detail=str(e)) from e
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error dropped.
        logger.exception(f"获取 Prompt 信息失败: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
2025-07-11 13:50:08 +08:00
@router.post("/render", response_model=PromptRenderResponse)
async def render_prompt(
    request: PromptRenderRequest,
    registry: PromptRegistry = Depends(get_prompt_registry)
):
    """Render a prompt with the supplied variable context.

    Args:
        request: Prompt name, version and variable context.
        registry: Prompt registry supplied through dependency injection.

    Returns:
        PromptRenderResponse with the rendered system and user prompts plus
        the name, version and model parameters of the resolved config.

    Raises:
        HTTPException: 404 if the registry raises FileNotFoundError;
            400 if it raises ValueError; 500 on any other failure.
            The detail is the error text in every case.
    """
    try:
        system_prompt, user_prompt = registry.render(
            name=request.name,
            context=request.context,
            version=request.version
        )

        # Resolve the config separately to report its name, version and params.
        config = registry.get(request.name, request.version)

        return PromptRenderResponse(
            name=config.name,
            version=config.version,
            system_prompt=system_prompt,
            user_prompt=user_prompt,
            model_params=config.get_model_params()
        )
    except FileNotFoundError as e:
        raise HTTPException(status_code=404, detail=str(e)) from e
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error dropped.
        logger.exception(f"渲染 Prompt 失败: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.get("/{name}/preview")
async def preview_prompt(
    name: str,
    version: str = "latest",
    registry: PromptRegistry = Depends(get_prompt_registry)
):
    """Return the raw, unrendered templates of a prompt.

    Args:
        name: Prompt name taken from the URL path.
        version: Version to look up; defaults to "latest".
        registry: Prompt registry supplied through dependency injection.

    Returns:
        A dict with the config's name, version, description, system and user
        templates, variables and model parameters.

    Raises:
        HTTPException: 404 if the registry raises FileNotFoundError;
            500 on any other failure, with the error text as detail.
    """
    try:
        config = registry.get(name, version)
        return {
            "name": config.name,
            "version": config.version,
            "description": config.description,
            "system_template": config.system,
            "user_template": config.user,
            "variables": config.variables,
            "model_params": config.get_model_params()
        }
    except FileNotFoundError as e:
        raise HTTPException(status_code=404, detail=str(e)) from e
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error dropped.
        logger.exception(f"预览 Prompt 失败: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.post("/style-audience", response_model=StyleAudienceResponse)
async def get_style_audience(
    request: StyleAudienceRequest,
    registry: PromptRegistry = Depends(get_prompt_registry)
):
    """Return the prompt content for a style and an audience.

    Both entries are fetched with ``registry.get``, passing "style" or
    "audience" as the first argument and the requested identifier as the
    second. Content and display name are read from each config's ``meta``.

    Args:
        request: Style and audience identifiers.
        registry: Prompt registry supplied through dependency injection.

    Returns:
        StyleAudienceResponse with the name and content of both entries.
        Content falls back to "" and the name falls back to the requested
        identifier when the meta key is absent.

    Raises:
        HTTPException: 404 if the registry raises FileNotFoundError;
            500 on any other failure, with the error text as detail.
    """
    try:
        # Look up the style entry.
        style_config = registry.get("style", request.style)
        style_content = style_config.meta.get("content", "")
        style_name = style_config.meta.get("style_name", request.style)

        # Look up the audience entry.
        audience_config = registry.get("audience", request.audience)
        audience_content = audience_config.meta.get("content", "")
        audience_name = audience_config.meta.get("audience_name", request.audience)

        return StyleAudienceResponse(
            style_name=style_name,
            style_content=style_content,
            audience_name=audience_name,
            audience_content=audience_content
        )
    except FileNotFoundError as e:
        raise HTTPException(status_code=404, detail=str(e)) from e
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error dropped.
        logger.exception(f"获取风格/人群失败: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e