323 lines
11 KiB
Python
Raw Normal View History

2025-07-10 17:51:37 +08:00
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
文字内容API路由
"""
import logging
from typing import List, Dict, Any, Optional
2025-07-10 17:51:37 +08:00
from fastapi import APIRouter, Depends, HTTPException
from api.models.tweet import (
TopicRequest, TopicResponse,
ContentRequest, ContentResponse,
JudgeRequest, JudgeResponse,
2025-07-10 17:51:37 +08:00
PipelineRequest, PipelineResponse
)
from api.services.tweet import TweetService
from api.services.database_service import DatabaseService
from api.dependencies import get_tweet_service, get_database_service
2025-07-10 17:51:37 +08:00
# 创建一个新的模型用于接收预构建提示词的请求
from pydantic import BaseModel, Field
2025-07-10 17:51:37 +08:00
class ContentWithPromptRequest(BaseModel):
    """Request model for generating content from pre-built prompts."""
    # Topic information, accepted as a free-form string-keyed mapping.
    topic: Dict[str, Any] = Field(..., description="选题信息")
    # System prompt text; required.
    system_prompt: str = Field(..., description="系统提示词")
    # User prompt text; required.
    user_prompt: str = Field(..., description="用户提示词")

    class Config:
        # Example request payload.
        # NOTE(review): `schema_extra` is presumably picked up by pydantic when
        # it builds the model schema -- confirm against the pydantic version in use.
        schema_extra = {
            "example": {
                "topic": {
                    "index": "1",
                    "date": "2023-07-15",
                    "object": "北京故宫",
                    "product": "故宫门票",
                    "style": "旅游攻略",
                    "target_audience": "年轻人",
                    "logic": "暑期旅游热门景点推荐"
                },
                "system_prompt": "你是一位专业的旅游内容创作者...",
                "user_prompt": "请为以下景点创作一篇旅游文章..."
            }
        }
2025-07-10 17:51:37 +08:00
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
2025-07-10 17:51:37 +08:00
# Router that the endpoints below register on; every route is tagged "tweet"
# and documents a 404 "Not found" response.
router = APIRouter(
    tags=["tweet"],
    responses={404: {"description": "Not found"}},
)
2025-07-10 17:51:37 +08:00
def _extract_names(records: Optional[List[Dict[str, Any]]], key: str) -> List[Any]:
    """Collect ``record[key]`` from every record of a lookup result.

    A lookup that yields ``None`` instead of a list is treated as "no matches",
    so the caller receives an empty list rather than a ``TypeError``.
    """
    return [record[key] for record in records or []]


def _resolve_ids_to_names(db_service: DatabaseService,
                          styleIds: Optional[List[int]] = None,
                          audienceIds: Optional[List[int]] = None,
                          scenicSpotIds: Optional[List[int]] = None,
                          productIds: Optional[List[int]] = None) -> tuple:
    """
    Convert ID lists into name lists.

    Args:
        db_service: database service used for the lookups
        styleIds: style IDs
        audienceIds: audience IDs
        scenicSpotIds: scenic-spot IDs
        productIds: product IDs

    Returns:
        A ``(styles, audiences, scenic_spots, products)`` tuple of name lists.
        A list is empty when its ID list is ``None``/empty, when the lookup
        finds nothing, or when the database service is missing or unavailable.
    """
    # Without a usable database nothing can be resolved; degrade to empty
    # lists instead of failing the request.
    if not db_service or not db_service.is_available():
        logger.warning("数据库服务不可用无法解析ID")
        return [], [], [], []

    # Each lookup runs only when its ID list is non-empty; the key is the
    # field of the returned record that holds the display name.
    styles = (
        _extract_names(db_service.get_styles_by_ids(styleIds), 'styleName')
        if styleIds else []
    )
    audiences = (
        _extract_names(db_service.get_audiences_by_ids(audienceIds), 'audienceName')
        if audienceIds else []
    )
    scenic_spots = (
        _extract_names(db_service.get_scenic_spots_by_ids(scenicSpotIds), 'name')
        if scenicSpotIds else []
    )
    products = (
        _extract_names(db_service.get_products_by_ids(productIds), 'name')
        if productIds else []
    )

    return styles, audiences, scenic_spots, products
2025-07-10 17:51:37 +08:00
@router.post("/topics", response_model=TopicResponse, summary="生成选题")
async def generate_topics(
    request: TopicRequest,
    tweet_service: TweetService = Depends(get_tweet_service),
    db_service: DatabaseService = Depends(get_database_service)
):
    """
    Generate topics.

    - **dates**: date string; may be a single date, several comma-separated dates, or a range
    - **numTopics**: number of topics to generate
    - **styleIds**: list of style IDs
    - **audienceIds**: list of audience IDs
    - **scenicSpotIds**: list of scenic-spot IDs
    - **productIds**: list of product IDs

    Returns a `TopicResponse` with the request ID and the generated topics.
    Unexpected failures are reported as HTTP 500.
    """
    try:
        # The ID lists from the request are resolved to names before being
        # handed to the tweet service.
        styles, audiences, scenic_spots, products = _resolve_ids_to_names(
            db_service,
            request.styleIds,
            request.audienceIds,
            request.scenicSpotIds,
            request.productIds
        )

        request_id, topics = await tweet_service.generate_topics(
            dates=request.dates,
            numTopics=request.numTopics,
            styles=styles,
            audiences=audiences,
            scenic_spots=scenic_spots,
            products=products
        )

        return TopicResponse(
            requestId=request_id,
            topics=topics
        )
    except HTTPException:
        # An HTTPException raised further down is already a deliberate HTTP
        # answer; pass it through instead of rewriting it as a 500.
        raise
    except Exception as e:
        logger.error(f"生成选题失败: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"生成选题失败: {str(e)}") from e
@router.post("/content", response_model=ContentResponse, summary="生成内容")
async def generate_content(
    request: ContentRequest,
    tweet_service: TweetService = Depends(get_tweet_service),
    db_service: DatabaseService = Depends(get_database_service)
):
    """
    Generate content for a topic.

    - **topic**: topic information
    - **styleIds**: list of style IDs
    - **audienceIds**: list of audience IDs
    - **scenicSpotIds**: list of scenic-spot IDs
    - **productIds**: list of product IDs
    - **autoJudge**: whether to judge the content automatically

    Returns a `ContentResponse` with the request ID, topic index, content and
    judge result. Unexpected failures are reported as HTTP 500.
    """
    try:
        # The ID lists from the request are resolved to names before being
        # handed to the tweet service.
        styles, audiences, scenic_spots, products = _resolve_ids_to_names(
            db_service,
            request.styleIds,
            request.audienceIds,
            request.scenicSpotIds,
            request.productIds
        )

        request_id, topic_index, content = await tweet_service.generate_content(
            topic=request.topic,
            styles=styles,
            audiences=audiences,
            scenic_spots=scenic_spots,
            products=products,
            autoJudge=request.autoJudge
        )

        # When the content is a dict, 'judge_success' is removed from it and
        # reported as the separate judgeSuccess field; otherwise it is None.
        judge_success = content.pop('judge_success', None) if isinstance(content, dict) else None

        return ContentResponse(
            requestId=request_id,
            topicIndex=topic_index,
            content=content,
            judgeSuccess=judge_success
        )
    except HTTPException:
        # An HTTPException raised further down is already a deliberate HTTP
        # answer; pass it through instead of rewriting it as a 500.
        raise
    except Exception as e:
        logger.error(f"生成内容失败: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"生成内容失败: {str(e)}") from e
@router.post("/content-with-prompt", response_model=ContentResponse, summary="使用预构建提示词生成内容")
async def generate_content_with_prompt(
    request: ContentWithPromptRequest,
    tweet_service: TweetService = Depends(get_tweet_service)
):
    """
    Generate content for a topic using pre-built prompts.

    - **topic**: topic information
    - **system_prompt**: system prompt
    - **user_prompt**: user prompt

    Returns a `ContentResponse` with the request ID, topic index and content.
    Unexpected failures are reported as HTTP 500.
    """
    try:
        request_id, topic_index, content = await tweet_service.generate_content_with_prompt(
            topic=request.topic,
            system_prompt=request.system_prompt,
            user_prompt=request.user_prompt
        )

        return ContentResponse(
            requestId=request_id,
            topicIndex=topic_index,
            content=content,
            # This endpoint always reports judgeSuccess as None.
            judgeSuccess=None
        )
    except HTTPException:
        # An HTTPException raised further down is already a deliberate HTTP
        # answer; pass it through instead of rewriting it as a 500.
        raise
    except Exception as e:
        logger.error(f"使用预构建提示词生成内容失败: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"使用预构建提示词生成内容失败: {str(e)}") from e
2025-07-10 17:51:37 +08:00
@router.post("/judge", response_model=JudgeResponse, summary="审核内容")
async def judge_content(
    request: JudgeRequest,
    tweet_service: TweetService = Depends(get_tweet_service),
    db_service: DatabaseService = Depends(get_database_service)
):
    """
    Judge (review) content.

    - **topic**: topic information
    - **content**: the content to judge
    - **styleIds**: list of style IDs
    - **audienceIds**: list of audience IDs
    - **scenicSpotIds**: list of scenic-spot IDs
    - **productIds**: list of product IDs

    Returns a `JudgeResponse` with the request ID, topic index, judged content
    and judge result. Unexpected failures are reported as HTTP 500.
    """
    try:
        # The ID lists from the request are resolved to names before being
        # handed to the tweet service.
        styles, audiences, scenic_spots, products = _resolve_ids_to_names(
            db_service,
            request.styleIds,
            request.audienceIds,
            request.scenicSpotIds,
            request.productIds
        )

        request_id, topic_index, judged_content, judge_success = await tweet_service.judge_content(
            topic=request.topic,
            content=request.content,
            styles=styles,
            audiences=audiences,
            scenic_spots=scenic_spots,
            products=products
        )

        return JudgeResponse(
            requestId=request_id,
            topicIndex=topic_index,
            content=judged_content,
            judgeSuccess=judge_success
        )
    except HTTPException:
        # An HTTPException raised further down is already a deliberate HTTP
        # answer; pass it through instead of rewriting it as a 500.
        raise
    except Exception as e:
        logger.error(f"审核内容失败: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"审核内容失败: {str(e)}") from e
@router.post("/pipeline", response_model=PipelineResponse, summary="运行完整流水线")
async def run_pipeline(
    request: PipelineRequest,
    tweet_service: TweetService = Depends(get_tweet_service),
    db_service: DatabaseService = Depends(get_database_service)
):
    """
    Run the full pipeline: generate topics, generate content, judge content.

    - **dates**: date range
    - **numTopics**: number of topics to generate
    - **styleIds**: list of style IDs
    - **audienceIds**: list of audience IDs
    - **scenicSpotIds**: list of scenic-spot IDs
    - **productIds**: list of product IDs
    - **skipJudge**: whether to skip the content-judging step
    - **autoJudge**: whether to judge inline while generating content

    Returns a `PipelineResponse` with the request ID, topics, contents and
    judged contents. Unexpected failures are reported as HTTP 500.
    """
    try:
        # The ID lists from the request are resolved to names before being
        # handed to the tweet service.
        styles, audiences, scenic_spots, products = _resolve_ids_to_names(
            db_service,
            request.styleIds,
            request.audienceIds,
            request.scenicSpotIds,
            request.productIds
        )

        request_id, topics, contents, judged_contents = await tweet_service.run_pipeline(
            dates=request.dates,
            numTopics=request.numTopics,
            styles=styles,
            audiences=audiences,
            scenic_spots=scenic_spots,
            products=products,
            # The request field is camelCase; the service keyword is snake_case.
            skip_judge=request.skipJudge,
            autoJudge=request.autoJudge
        )

        return PipelineResponse(
            requestId=request_id,
            topics=topics,
            contents=contents,
            judgedContents=judged_contents
        )
    except HTTPException:
        # An HTTPException raised further down is already a deliberate HTTP
        # answer; pass it through instead of rewriting it as a 500.
        raise
    except Exception as e:
        logger.error(f"运行流水线失败: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"运行流水线失败: {str(e)}") from e