207 lines
8.0 KiB
Python
207 lines
8.0 KiB
Python
#!/usr/bin/env python3
|
|
# -*- coding: utf-8 -*-
|
|
|
|
"""
|
|
提示词API路由
|
|
"""
|
|
|
|
import logging
|
|
from fastapi import APIRouter, Depends, HTTPException
|
|
from typing import List, Dict, Any
|
|
|
|
from core.config import ConfigManager
|
|
from api.services.prompt_service import PromptService
|
|
from api.services.prompt_builder import PromptBuilderService
|
|
from api.models.prompt import (
|
|
StyleRequest, StyleResponse, StyleListResponse,
|
|
AudienceRequest, AudienceResponse, AudienceListResponse,
|
|
ScenicSpotRequest, ScenicSpotResponse, ScenicSpotListResponse,
|
|
PromptBuilderRequest, PromptBuilderResponse
|
|
)
|
|
|
|
# 从依赖注入模块导入依赖
|
|
from api.dependencies import get_config
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Router on which the endpoint functions of this module are registered.
router = APIRouter()
|
|
|
|
# 依赖注入函数
|
|
def get_prompt_service(
    config_manager: ConfigManager = Depends(get_config)
) -> PromptService:
    """Dependency provider that builds the prompt service.

    Args:
        config_manager: Configuration manager, resolved by FastAPI through
            the ``get_config`` dependency.

    Returns:
        A new ``PromptService`` constructed from ``config_manager``.
    """
    service = PromptService(config_manager)
    return service
|
|
|
|
def get_prompt_builder(
    config_manager: ConfigManager = Depends(get_config),
    prompt_service: PromptService = Depends(get_prompt_service)
) -> PromptBuilderService:
    """Dependency provider that builds the prompt builder service.

    Args:
        config_manager: Configuration manager, resolved through the
            ``get_config`` dependency.
        prompt_service: Prompt service, resolved through the
            ``get_prompt_service`` dependency.

    Returns:
        A new ``PromptBuilderService`` constructed from both arguments.
    """
    builder = PromptBuilderService(config_manager, prompt_service)
    return builder
|
|
|
|
|
|
@router.get("/styles", response_model=StyleListResponse)
async def get_all_styles(
    prompt_service: PromptService = Depends(get_prompt_service)
):
    """List all content styles.

    Each entry returned by ``prompt_service.get_all_styles()`` is read via
    its "name" and "description" keys and wrapped in a ``StyleResponse``.

    Raises:
        HTTPException: 500 if fetching or converting the styles fails.
    """
    try:
        # Convert the service's dict entries into response models.
        style_models = [
            StyleResponse(name=entry["name"], description=entry["description"])
            for entry in prompt_service.get_all_styles()
        ]
        return StyleListResponse(styles=style_models)
    except Exception as exc:
        logger.error(f"获取所有风格失败: {exc}")
        raise HTTPException(status_code=500, detail=f"获取风格列表失败: {str(exc)}")
|
|
|
|
|
|
@router.get("/styles/{style_name}", response_model=StyleResponse)
async def get_style(
    style_name: str,
    prompt_service: PromptService = Depends(get_prompt_service)
):
    """Fetch a single content style by name.

    Args:
        style_name: Name of the style, taken from the URL path.
        prompt_service: Service used to look up the style content.

    Raises:
        HTTPException: 404 if the lookup raises any exception.
    """
    try:
        return StyleResponse(
            name=style_name,
            description=prompt_service.get_style_content(style_name),
        )
    except Exception as exc:
        logger.error(f"获取风格 '{style_name}' 失败: {exc}")
        raise HTTPException(status_code=404, detail=f"未找到风格: {style_name}")
|
|
|
|
|
|
@router.post("/styles", response_model=StyleResponse)
async def create_or_update_style(
    style: StyleRequest,
    prompt_service: PromptService = Depends(get_prompt_service)
):
    """Create or update a content style.

    When the request carries no description, nothing is saved: the
    description currently stored for ``style.name`` is looked up and
    returned. Otherwise the description is saved through the prompt service
    and echoed back.

    Args:
        style: Request body with the style name and optional description.
        prompt_service: Service used to read and save styles.

    Raises:
        HTTPException: 500 if the save reports failure, or if any other
            exception occurs while handling the request.
    """
    try:
        if not style.description:
            # No description supplied: return the existing one instead.
            content = prompt_service.get_style_content(style.name)
            return StyleResponse(name=style.name, description=content)

        success = prompt_service.save_style(style.name, style.description)
        if not success:
            logger.error(f"保存风格 '{style.name}' 失败")
            raise HTTPException(status_code=500, detail=f"保存风格 '{style.name}' 失败")

        return StyleResponse(name=style.name, description=style.description)
    except HTTPException:
        # HTTPException is itself an Exception; without this clause the
        # handler below would re-wrap it and garble the detail message.
        raise
    except Exception as e:
        logger.error(f"保存风格 '{style.name}' 失败: {e}")
        raise HTTPException(status_code=500, detail=f"操作失败: {str(e)}") from e
|
|
|
|
|
|
@router.get("/audiences", response_model=AudienceListResponse)
async def get_all_audiences(
    prompt_service: PromptService = Depends(get_prompt_service)
):
    """List all target audiences.

    Each entry returned by ``prompt_service.get_all_audiences()`` is read
    via its "name" and "description" keys and wrapped in an
    ``AudienceResponse``.

    Raises:
        HTTPException: 500 if fetching or converting the audiences fails.
    """
    try:
        # Convert the service's dict entries into response models.
        audience_models = [
            AudienceResponse(name=entry["name"], description=entry["description"])
            for entry in prompt_service.get_all_audiences()
        ]
        return AudienceListResponse(audiences=audience_models)
    except Exception as exc:
        logger.error(f"获取所有受众失败: {exc}")
        raise HTTPException(status_code=500, detail=f"获取受众列表失败: {str(exc)}")
|
|
|
|
|
|
@router.get("/audiences/{audience_name}", response_model=AudienceResponse)
async def get_audience(
    audience_name: str,
    prompt_service: PromptService = Depends(get_prompt_service)
):
    """Fetch a single target audience by name.

    Args:
        audience_name: Name of the audience, taken from the URL path.
        prompt_service: Service used to look up the audience content.

    Raises:
        HTTPException: 404 if the lookup raises any exception.
    """
    try:
        return AudienceResponse(
            name=audience_name,
            description=prompt_service.get_audience_content(audience_name),
        )
    except Exception as exc:
        logger.error(f"获取受众 '{audience_name}' 失败: {exc}")
        raise HTTPException(status_code=404, detail=f"未找到受众: {audience_name}")
|
|
|
|
|
|
@router.post("/audiences", response_model=AudienceResponse)
async def create_or_update_audience(
    audience: AudienceRequest,
    prompt_service: PromptService = Depends(get_prompt_service)
):
    """Create or update a target audience.

    When the request carries no description, nothing is saved: the
    description currently stored for ``audience.name`` is looked up and
    returned. Otherwise the description is saved through the prompt service
    and echoed back.

    Args:
        audience: Request body with the audience name and optional
            description.
        prompt_service: Service used to read and save audiences.

    Raises:
        HTTPException: 500 if the save reports failure, or if any other
            exception occurs while handling the request.
    """
    try:
        if not audience.description:
            # No description supplied: return the existing one instead.
            content = prompt_service.get_audience_content(audience.name)
            return AudienceResponse(name=audience.name, description=content)

        success = prompt_service.save_audience(audience.name, audience.description)
        if not success:
            logger.error(f"保存受众 '{audience.name}' 失败")
            raise HTTPException(status_code=500, detail=f"保存受众 '{audience.name}' 失败")

        return AudienceResponse(name=audience.name, description=audience.description)
    except HTTPException:
        # HTTPException is itself an Exception; without this clause the
        # handler below would re-wrap it and garble the detail message.
        raise
    except Exception as e:
        logger.error(f"保存受众 '{audience.name}' 失败: {e}")
        raise HTTPException(status_code=500, detail=f"操作失败: {str(e)}") from e
|
|
|
|
|
|
@router.get("/scenic-spots", response_model=ScenicSpotListResponse)
async def get_all_scenic_spots(
    prompt_service: PromptService = Depends(get_prompt_service)
):
    """List all scenic spots.

    Each entry returned by ``prompt_service.get_all_scenic_spots()`` is read
    via its "name" and "description" keys and wrapped in a
    ``ScenicSpotResponse``.

    Raises:
        HTTPException: 500 if fetching or converting the spots fails.
    """
    try:
        # Convert the service's dict entries into response models.
        spot_models = [
            ScenicSpotResponse(name=entry["name"], description=entry["description"])
            for entry in prompt_service.get_all_scenic_spots()
        ]
        return ScenicSpotListResponse(spots=spot_models)
    except Exception as exc:
        logger.error(f"获取所有景区失败: {exc}")
        raise HTTPException(status_code=500, detail=f"获取景区列表失败: {str(exc)}")
|
|
|
|
|
|
@router.get("/scenic-spots/{spot_name}", response_model=ScenicSpotResponse)
async def get_scenic_spot(
    spot_name: str,
    prompt_service: PromptService = Depends(get_prompt_service)
):
    """Fetch the information for a single scenic spot by name.

    Args:
        spot_name: Name of the scenic spot, taken from the URL path.
        prompt_service: Service used to look up the scenic spot info.

    Raises:
        HTTPException: 404 if the lookup raises any exception.
    """
    try:
        return ScenicSpotResponse(
            name=spot_name,
            description=prompt_service.get_scenic_spot_info(spot_name),
        )
    except Exception as exc:
        logger.error(f"获取景区 '{spot_name}' 失败: {exc}")
        raise HTTPException(status_code=404, detail=f"未找到景区: {spot_name}")
|
|
|
|
|
|
@router.post("/build-prompt", response_model=PromptBuilderResponse)
async def build_prompt(
    request: PromptBuilderRequest,
    prompt_builder: PromptBuilderService = Depends(get_prompt_builder)
):
    """Build the system/user prompt pair for one step.

    ``request.step`` selects the builder; a falsy step falls back to
    "content":

    * "topic": reads "num_topics" (default 5) and "month" (default "7")
      from ``request.topic`` and calls ``build_topic_prompt``.
    * "judge": reads "content" (default ``{}``) from ``request.topic`` and
      calls ``build_judge_prompt`` with the topic and that content.
    * anything else: calls ``build_content_prompt`` with the topic and step.

    Raises:
        HTTPException: 500 if any exception occurs while building.
    """
    try:
        step = request.step or "content"
        topic = request.topic

        if step == "topic":
            # Topic-selection prompt: parameters are pulled out of the topic.
            prompts = prompt_builder.build_topic_prompt(
                topic.get("num_topics", 5),
                topic.get("month", "7"),
            )
        elif step == "judge":
            # Review prompt: needs the previously generated content.
            generated = topic.get("content", {})
            prompts = prompt_builder.build_judge_prompt(topic, generated)
        else:
            # Default: content-generation prompt.
            prompts = prompt_builder.build_content_prompt(topic, step)

        system_prompt, user_prompt = prompts
        return PromptBuilderResponse(
            system_prompt=system_prompt,
            user_prompt=user_prompt,
        )
    except Exception as exc:
        logger.error(f"构建提示词失败: {exc}")
        raise HTTPException(status_code=500, detail=f"构建提示词失败: {str(exc)}")