206 lines
6.7 KiB
Python
206 lines
6.7 KiB
Python
#!/usr/bin/env python3
|
|
# -*- coding: utf-8 -*-
|
|
|
|
"""
|
|
文字内容API路由
|
|
"""
|
|
|
|
import logging
|
|
from typing import List, Dict, Any, Optional
|
|
from fastapi import APIRouter, Depends, HTTPException
|
|
|
|
from api.models.tweet import (
|
|
TopicRequest, TopicResponse,
|
|
ContentRequest, ContentResponse,
|
|
JudgeRequest, JudgeResponse,
|
|
PipelineRequest, PipelineResponse
|
|
)
|
|
from api.services.tweet import TweetService
|
|
from api.dependencies import get_tweet_service
|
|
|
|
# 创建一个新的模型用于接收预构建提示词的请求
|
|
from pydantic import BaseModel, Field
|
|
|
|
class ContentWithPromptRequest(BaseModel):
|
|
"""带有预构建提示词的内容生成请求模型"""
|
|
topic: Dict[str, Any] = Field(..., description="选题信息")
|
|
system_prompt: str = Field(..., description="系统提示词")
|
|
user_prompt: str = Field(..., description="用户提示词")
|
|
|
|
class Config:
|
|
schema_extra = {
|
|
"example": {
|
|
"topic": {
|
|
"index": "1",
|
|
"date": "2023-07-15",
|
|
"object": "北京故宫",
|
|
"product": "故宫门票",
|
|
"style": "旅游攻略",
|
|
"target_audience": "年轻人",
|
|
"logic": "暑期旅游热门景点推荐"
|
|
},
|
|
"system_prompt": "你是一位专业的旅游内容创作者...",
|
|
"user_prompt": "请为以下景点创作一篇旅游文章..."
|
|
}
|
|
}
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
router = APIRouter(
|
|
tags=["tweet"],
|
|
responses={404: {"description": "Not found"}},
|
|
)
|
|
|
|
|
|
@router.post("/topics", response_model=TopicResponse, summary="生成选题")
|
|
async def generate_topics(
|
|
request: TopicRequest,
|
|
tweet_service: TweetService = Depends(get_tweet_service)
|
|
):
|
|
"""
|
|
生成选题
|
|
|
|
- **dates**: 日期字符串,可能为单个日期、多个日期用逗号分隔或范围
|
|
- **num_topics**: 要生成的选题数量
|
|
- **styles**: 风格列表
|
|
- **audiences**: 受众列表
|
|
- **scenic_spots**: 景区列表
|
|
- **products**: 产品列表
|
|
"""
|
|
try:
|
|
request_id, topics = await tweet_service.generate_topics(
|
|
dates=request.dates,
|
|
num_topics=request.num_topics,
|
|
styles=request.styles,
|
|
audiences=request.audiences,
|
|
scenic_spots=request.scenic_spots,
|
|
products=request.products
|
|
)
|
|
|
|
return TopicResponse(
|
|
request_id=request_id,
|
|
topics=topics
|
|
)
|
|
except Exception as e:
|
|
logger.error(f"生成选题失败: {e}", exc_info=True)
|
|
raise HTTPException(status_code=500, detail=f"生成选题失败: {str(e)}")
|
|
|
|
|
|
@router.post("/content", response_model=ContentResponse, summary="生成内容")
|
|
async def generate_content(
|
|
request: ContentRequest,
|
|
tweet_service: TweetService = Depends(get_tweet_service)
|
|
):
|
|
"""
|
|
为选题生成内容
|
|
|
|
- **topic**: 选题信息
|
|
"""
|
|
try:
|
|
request_id, topic_index, content = await tweet_service.generate_content(
|
|
topic=request.topic
|
|
)
|
|
|
|
return ContentResponse(
|
|
request_id=request_id,
|
|
topic_index=topic_index,
|
|
content=content
|
|
)
|
|
except Exception as e:
|
|
logger.error(f"生成内容失败: {e}", exc_info=True)
|
|
raise HTTPException(status_code=500, detail=f"生成内容失败: {str(e)}")
|
|
|
|
|
|
@router.post("/content-with-prompt", response_model=ContentResponse, summary="使用预构建提示词生成内容")
|
|
async def generate_content_with_prompt(
|
|
request: ContentWithPromptRequest,
|
|
tweet_service: TweetService = Depends(get_tweet_service)
|
|
):
|
|
"""
|
|
使用预构建的提示词为选题生成内容
|
|
|
|
- **topic**: 选题信息
|
|
- **system_prompt**: 系统提示词
|
|
- **user_prompt**: 用户提示词
|
|
"""
|
|
try:
|
|
request_id, topic_index, content = await tweet_service.generate_content_with_prompt(
|
|
topic=request.topic,
|
|
system_prompt=request.system_prompt,
|
|
user_prompt=request.user_prompt
|
|
)
|
|
|
|
return ContentResponse(
|
|
request_id=request_id,
|
|
topic_index=topic_index,
|
|
content=content
|
|
)
|
|
except Exception as e:
|
|
logger.error(f"使用预构建提示词生成内容失败: {e}", exc_info=True)
|
|
raise HTTPException(status_code=500, detail=f"使用预构建提示词生成内容失败: {str(e)}")
|
|
|
|
|
|
@router.post("/judge", response_model=JudgeResponse, summary="审核内容")
|
|
async def judge_content(
|
|
request: JudgeRequest,
|
|
tweet_service: TweetService = Depends(get_tweet_service)
|
|
):
|
|
"""
|
|
审核内容
|
|
|
|
- **topic**: 选题信息
|
|
- **content**: 要审核的内容
|
|
"""
|
|
try:
|
|
request_id, topic_index, judged_content, judge_success = await tweet_service.judge_content(
|
|
topic=request.topic,
|
|
content=request.content
|
|
)
|
|
|
|
return JudgeResponse(
|
|
request_id=request_id,
|
|
topic_index=topic_index,
|
|
judged_content=judged_content,
|
|
judge_success=judge_success
|
|
)
|
|
except Exception as e:
|
|
logger.error(f"审核内容失败: {e}", exc_info=True)
|
|
raise HTTPException(status_code=500, detail=f"审核内容失败: {str(e)}")
|
|
|
|
|
|
@router.post("/pipeline", response_model=PipelineResponse, summary="运行完整流水线")
|
|
async def run_pipeline(
|
|
request: PipelineRequest,
|
|
tweet_service: TweetService = Depends(get_tweet_service)
|
|
):
|
|
"""
|
|
运行完整流水线,包括生成选题、生成内容和审核内容
|
|
|
|
- **dates**: 日期字符串,可能为单个日期、多个日期用逗号分隔或范围
|
|
- **num_topics**: 要生成的选题数量
|
|
- **styles**: 风格列表
|
|
- **audiences**: 受众列表
|
|
- **scenic_spots**: 景区列表
|
|
- **products**: 产品列表
|
|
- **skip_judge**: 是否跳过内容审核步骤
|
|
"""
|
|
try:
|
|
request_id, topics, contents, judged_contents = await tweet_service.run_pipeline(
|
|
dates=request.dates,
|
|
num_topics=request.num_topics,
|
|
styles=request.styles,
|
|
audiences=request.audiences,
|
|
scenic_spots=request.scenic_spots,
|
|
products=request.products,
|
|
skip_judge=request.skip_judge
|
|
)
|
|
|
|
return PipelineResponse(
|
|
request_id=request_id,
|
|
topics=topics,
|
|
contents=contents,
|
|
judged_contents=judged_contents
|
|
)
|
|
except Exception as e:
|
|
logger.error(f"运行流水线失败: {e}", exc_info=True)
|
|
raise HTTPException(status_code=500, detail=f"运行流水线失败: {str(e)}") |