207 lines
8.0 KiB
Python

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
提示词API路由
"""
import logging
from fastapi import APIRouter, Depends, HTTPException
from typing import List, Dict, Any
from core.config import ConfigManager
from api.services.prompt_service import PromptService
from api.services.prompt_builder import PromptBuilderService
from api.models.prompt import (
StyleRequest, StyleResponse, StyleListResponse,
AudienceRequest, AudienceResponse, AudienceListResponse,
ScenicSpotRequest, ScenicSpotResponse, ScenicSpotListResponse,
PromptBuilderRequest, PromptBuilderResponse
)
# 从依赖注入模块导入依赖
from api.dependencies import get_config
logger = logging.getLogger(__name__)
# Router collecting the prompt-related endpoints defined below (presumably included into the app elsewhere — confirm).
router = APIRouter()
# 依赖注入函数
def get_prompt_service(
    config_manager: ConfigManager = Depends(get_config)
) -> PromptService:
    """FastAPI dependency: build a PromptService from the config supplied by ``get_config``."""
    service = PromptService(config_manager)
    return service
def get_prompt_builder(
    config_manager: ConfigManager = Depends(get_config),
    prompt_service: PromptService = Depends(get_prompt_service)
) -> PromptBuilderService:
    """FastAPI dependency: build a PromptBuilderService.

    FastAPI resolves both arguments: the config through ``get_config`` and
    the prompt service through ``get_prompt_service``.
    """
    builder = PromptBuilderService(config_manager, prompt_service)
    return builder
@router.get("/styles", response_model=StyleListResponse)
async def get_all_styles(
    prompt_service: PromptService = Depends(get_prompt_service)
):
    """Return every available content style.

    Each item from ``prompt_service.get_all_styles()`` is indexed by its
    ``name`` and ``description`` keys and converted to a ``StyleResponse``.

    Raises:
        HTTPException: 500 if loading or converting the styles fails.
    """
    try:
        styles_dict = prompt_service.get_all_styles()
        # Convert the raw mappings into response models.
        styles = [
            StyleResponse(name=style["name"], description=style["description"])
            for style in styles_dict
        ]
        return StyleListResponse(styles=styles)
    except Exception as e:
        # logger.exception records the traceback, which logger.error dropped.
        logger.exception("获取所有风格失败: %s", e)
        raise HTTPException(status_code=500, detail=f"获取风格列表失败: {str(e)}") from e
@router.get("/styles/{style_name}", response_model=StyleResponse)
async def get_style(
    style_name: str,
    prompt_service: PromptService = Depends(get_prompt_service)
):
    """Return a single content style by name.

    Any error while loading or wrapping the style is logged and reported to
    the client as 404 "not found".
    """
    try:
        return StyleResponse(
            name=style_name,
            description=prompt_service.get_style_content(style_name),
        )
    except Exception as err:
        logger.error(f"获取风格 '{style_name}' 失败: {err}")
        raise HTTPException(status_code=404, detail=f"未找到风格: {style_name}")
@router.post("/styles", response_model=StyleResponse)
async def create_or_update_style(
    style: StyleRequest,
    prompt_service: PromptService = Depends(get_prompt_service)
):
    """Create or update a content style.

    If the request has an empty/missing description, nothing is saved and
    the currently stored content for ``style.name`` is returned instead.

    Raises:
        HTTPException: 500 if ``save_style`` reports failure or any other
            error occurs.
    """
    try:
        if not style.description:
            # No description supplied: return the existing content unchanged.
            content = prompt_service.get_style_content(style.name)
            return StyleResponse(name=style.name, description=content)
        success = prompt_service.save_style(style.name, style.description)
        if not success:
            logger.error("保存风格 '%s' 失败", style.name)
            raise HTTPException(status_code=500, detail=f"保存风格 '{style.name}' 失败")
        return StyleResponse(name=style.name, description=style.description)
    except HTTPException:
        # Let our own HTTP errors through untouched; the generic handler
        # below would otherwise re-wrap them and garble the detail text.
        raise
    except Exception as e:
        logger.exception("保存风格 '%s' 失败: %s", style.name, e)
        raise HTTPException(status_code=500, detail=f"操作失败: {str(e)}") from e
@router.get("/audiences", response_model=AudienceListResponse)
async def get_all_audiences(
    prompt_service: PromptService = Depends(get_prompt_service)
):
    """Return every available target audience.

    Each item from ``prompt_service.get_all_audiences()`` is indexed by its
    ``name`` and ``description`` keys and converted to an ``AudienceResponse``.

    Raises:
        HTTPException: 500 if loading or converting the audiences fails.
    """
    try:
        audiences_dict = prompt_service.get_all_audiences()
        # Convert the raw mappings into response models.
        audiences = [
            AudienceResponse(name=audience["name"], description=audience["description"])
            for audience in audiences_dict
        ]
        return AudienceListResponse(audiences=audiences)
    except Exception as e:
        # logger.exception records the traceback, which logger.error dropped.
        logger.exception("获取所有受众失败: %s", e)
        raise HTTPException(status_code=500, detail=f"获取受众列表失败: {str(e)}") from e
@router.get("/audiences/{audience_name}", response_model=AudienceResponse)
async def get_audience(
    audience_name: str,
    prompt_service: PromptService = Depends(get_prompt_service)
):
    """Return a single target audience by name.

    Any failure is logged and surfaced as 404 "not found".
    """
    try:
        text = prompt_service.get_audience_content(audience_name)
        result = AudienceResponse(name=audience_name, description=text)
    except Exception as err:
        logger.error(f"获取受众 '{audience_name}' 失败: {err}")
        raise HTTPException(status_code=404, detail=f"未找到受众: {audience_name}")
    return result
@router.post("/audiences", response_model=AudienceResponse)
async def create_or_update_audience(
    audience: AudienceRequest,
    prompt_service: PromptService = Depends(get_prompt_service)
):
    """Create or update a target audience.

    If the request has an empty/missing description, nothing is saved and
    the currently stored content for ``audience.name`` is returned instead.

    Raises:
        HTTPException: 500 if ``save_audience`` reports failure or any other
            error occurs.
    """
    try:
        if not audience.description:
            # No description supplied: return the existing content unchanged.
            content = prompt_service.get_audience_content(audience.name)
            return AudienceResponse(name=audience.name, description=content)
        success = prompt_service.save_audience(audience.name, audience.description)
        if not success:
            logger.error("保存受众 '%s' 失败", audience.name)
            raise HTTPException(status_code=500, detail=f"保存受众 '{audience.name}' 失败")
        return AudienceResponse(name=audience.name, description=audience.description)
    except HTTPException:
        # Let our own HTTP errors through untouched; the generic handler
        # below would otherwise re-wrap them and garble the detail text.
        raise
    except Exception as e:
        logger.exception("保存受众 '%s' 失败: %s", audience.name, e)
        raise HTTPException(status_code=500, detail=f"操作失败: {str(e)}") from e
@router.get("/scenic-spots", response_model=ScenicSpotListResponse)
async def get_all_scenic_spots(
    prompt_service: PromptService = Depends(get_prompt_service)
):
    """Return every available scenic spot.

    Each item from ``prompt_service.get_all_scenic_spots()`` is indexed by its
    ``name`` and ``description`` keys and converted to a ``ScenicSpotResponse``.

    Raises:
        HTTPException: 500 if loading or converting the spots fails.
    """
    try:
        spots_dict = prompt_service.get_all_scenic_spots()
        # Convert the raw mappings into response models.
        spots = [
            ScenicSpotResponse(name=spot["name"], description=spot["description"])
            for spot in spots_dict
        ]
        return ScenicSpotListResponse(spots=spots)
    except Exception as e:
        # logger.exception records the traceback, which logger.error dropped.
        logger.exception("获取所有景区失败: %s", e)
        raise HTTPException(status_code=500, detail=f"获取景区列表失败: {str(e)}") from e
@router.get("/scenic-spots/{spot_name}", response_model=ScenicSpotResponse)
async def get_scenic_spot(
    spot_name: str,
    prompt_service: PromptService = Depends(get_prompt_service)
):
    """Return information about a single scenic spot by name.

    Any failure is logged and surfaced as 404 "not found".
    """
    try:
        info = prompt_service.get_scenic_spot_info(spot_name)
        response = ScenicSpotResponse(name=spot_name, description=info)
    except Exception as err:
        logger.error(f"获取景区 '{spot_name}' 失败: {err}")
        raise HTTPException(status_code=404, detail=f"未找到景区: {spot_name}")
    return response
@router.post("/build-prompt", response_model=PromptBuilderResponse)
async def build_prompt(
    request: PromptBuilderRequest,
    prompt_builder: PromptBuilderService = Depends(get_prompt_builder)
):
    """Build a (system, user) prompt pair for the requested step.

    ``request.step`` selects the builder (defaults to ``"content"`` when
    empty):

    * ``"topic"``  -> ``build_topic_prompt(num_topics, month)``, with
      ``num_topics`` (default 5) and ``month`` (default "7") read from
      ``request.topic``.
    * ``"judge"``  -> ``build_judge_prompt(request.topic, content)``, with
      ``content`` read from ``request.topic["content"]`` (default ``{}``).
    * anything else -> ``build_content_prompt(request.topic, step)``.

    Raises:
        HTTPException: 500 if building the prompt fails.
    """
    try:
        step = request.step or "content"
        # Only used to read parameters; tolerates a missing/None topic so the
        # "topic" step can fall back to its defaults instead of crashing.
        topic_params = request.topic or {}
        if step == "topic":
            num_topics = topic_params.get("num_topics", 5)
            month = topic_params.get("month", "7")
            system_prompt, user_prompt = prompt_builder.build_topic_prompt(num_topics, month)
        elif step == "judge":
            # The judge step needs the generated content to review.
            content = topic_params.get("content", {})
            system_prompt, user_prompt = prompt_builder.build_judge_prompt(request.topic, content)
        else:
            # Default: content-generation prompt.
            system_prompt, user_prompt = prompt_builder.build_content_prompt(request.topic, step)
        return PromptBuilderResponse(
            system_prompt=system_prompt,
            user_prompt=user_prompt
        )
    except Exception as e:
        # logger.exception records the traceback, which logger.error dropped.
        logger.exception("构建提示词失败: %s", e)
        raise HTTPException(status_code=500, detail=f"构建提示词失败: {str(e)}") from e