323 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
文字内容API路由
"""
import logging
from typing import List, Dict, Any, Optional
from fastapi import APIRouter, Depends, HTTPException
from api.models.tweet import (
TopicRequest, TopicResponse,
ContentRequest, ContentResponse,
JudgeRequest, JudgeResponse,
PipelineRequest, PipelineResponse
)
from api.services.tweet import TweetService
from api.services.database_service import DatabaseService
from api.dependencies import get_tweet_service, get_database_service
# 创建一个新的模型用于接收预构建提示词的请求
from pydantic import BaseModel, Field
class ContentWithPromptRequest(BaseModel):
    """Request body for generating content from caller-supplied prompts.

    The caller sends the topic dictionary together with already-built system
    and user prompts. All three fields are declared with ``Field(...)``, so
    every one of them is required.
    """

    # Topic information as a free-form dictionary (see the example below).
    topic: Dict[str, Any] = Field(..., description="选题信息")
    # Pre-built system prompt.
    system_prompt: str = Field(..., description="系统提示词")
    # Pre-built user prompt.
    user_prompt: str = Field(..., description="用户提示词")

    class Config:
        # Extra JSON-schema data (pydantic v1-style config key); supplies an
        # example request payload.
        schema_extra = dict(
            example=dict(
                topic=dict(
                    index="1",
                    date="2023-07-15",
                    object="北京故宫",
                    product="故宫门票",
                    style="旅游攻略",
                    target_audience="年轻人",
                    logic="暑期旅游热门景点推荐",
                ),
                system_prompt="你是一位专业的旅游内容创作者...",
                user_prompt="请为以下景点创作一篇旅游文章...",
            )
        )
# Logger scoped to this module's import name.
logger = logging.getLogger(__name__)

# Router for the tweet (text content) endpoints: every route carries the
# "tweet" tag and documents a generic 404 response.
router = APIRouter(tags=["tweet"], responses={404: {"description": "Not found"}})
def _lookup_names(ids, fetch, name_key, label):
    """Fetch the records for ``ids`` and return the value of ``name_key`` from each.

    Args:
        ids: IDs to resolve. ``None`` or an empty list returns ``[]`` without
            calling ``fetch``.
        fetch: Callable that takes the ID list and returns an iterable of
            mapping-like records (e.g. ``db_service.get_styles_by_ids``).
        name_key: Record key that holds the display name.
        label: Category name used only in the warning log message.

    Returns:
        The names, in the order ``fetch`` returned the records.
    """
    if not ids:
        return []
    names = [record[name_key] for record in fetch(ids)]
    # IDs with no matching record are simply absent from the result. Warn so a
    # request that silently loses part of its constraints shows up in the logs.
    requested = len(set(ids))
    if len(names) < requested:
        logger.warning("部分%sID未能解析: 请求%d个, 命中%d个", label, requested, len(names))
    return names


def _resolve_ids_to_names(db_service: DatabaseService,
                          styleIds: Optional[List[int]] = None,
                          audienceIds: Optional[List[int]] = None,
                          scenicSpotIds: Optional[List[int]] = None,
                          productIds: Optional[List[int]] = None) -> tuple:
    """Convert style/audience/scenic-spot/product ID lists into name lists.

    Args:
        db_service: Database service used for the lookups.
        styleIds: Style IDs (resolved via ``get_styles_by_ids``, key ``styleName``).
        audienceIds: Audience IDs (resolved via ``get_audiences_by_ids``, key ``audienceName``).
        scenicSpotIds: Scenic-spot IDs (resolved via ``get_scenic_spots_by_ids``, key ``name``).
        productIds: Product IDs (resolved via ``get_products_by_ids``, key ``name``).

    Returns:
        A 4-tuple ``(styles, audiences, scenic_spots, products)`` of name lists.
        If the database service is missing or reports itself unavailable, all
        four lists are empty and a warning is logged. A missing or empty ID
        list yields an empty name list for that category.
    """
    if not db_service or not db_service.is_available():
        logger.warning("数据库服务不可用无法解析ID")
        return [], [], [], []

    styles = _lookup_names(styleIds, db_service.get_styles_by_ids, 'styleName', '风格')
    audiences = _lookup_names(audienceIds, db_service.get_audiences_by_ids, 'audienceName', '受众')
    scenic_spots = _lookup_names(scenicSpotIds, db_service.get_scenic_spots_by_ids, 'name', '景区')
    products = _lookup_names(productIds, db_service.get_products_by_ids, 'name', '产品')
    return styles, audiences, scenic_spots, products
@router.post("/topics", response_model=TopicResponse, summary="生成选题")
async def generate_topics(
    request: TopicRequest,
    tweet_service: TweetService = Depends(get_tweet_service),
    db_service: DatabaseService = Depends(get_database_service)
):
    """
    Generate topics.

    - **dates**: date string; a single date, several comma-separated dates, or a range
    - **numTopics**: number of topics to generate
    - **styleIds**: list of style IDs
    - **audienceIds**: list of audience IDs
    - **scenicSpotIds**: list of scenic-spot IDs
    - **productIds**: list of product IDs

    The IDs are resolved to names through the database before the tweet
    service is called. An HTTPException raised downstream is propagated
    unchanged. Any other error is logged and reported as HTTP 500.
    """
    try:
        # Translate the request's ID lists into human-readable names.
        styles, audiences, scenic_spots, products = _resolve_ids_to_names(
            db_service,
            request.styleIds,
            request.audienceIds,
            request.scenicSpotIds,
            request.productIds
        )
        request_id, topics = await tweet_service.generate_topics(
            dates=request.dates,
            numTopics=request.numTopics,
            styles=styles,
            audiences=audiences,
            scenic_spots=scenic_spots,
            products=products
        )
        return TopicResponse(
            requestId=request_id,
            topics=topics
        )
    except HTTPException:
        # Already carries its own status code; re-wrapping would turn e.g. a 4xx into a 500.
        raise
    except Exception as e:
        logger.error(f"生成选题失败: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"生成选题失败: {str(e)}") from e
@router.post("/content", response_model=ContentResponse, summary="生成内容")
async def generate_content(
    request: ContentRequest,
    tweet_service: TweetService = Depends(get_tweet_service),
    db_service: DatabaseService = Depends(get_database_service)
):
    """
    Generate content for a topic.

    - **topic**: topic information
    - **styleIds**: list of style IDs
    - **audienceIds**: list of audience IDs
    - **scenicSpotIds**: list of scenic-spot IDs
    - **productIds**: list of product IDs
    - **autoJudge**: whether to run the content judge automatically

    If the generated content is a dict containing ``judge_success``, that value
    is moved out of ``content`` and returned as ``judgeSuccess``. An
    HTTPException raised downstream is propagated unchanged. Any other error
    is logged and reported as HTTP 500.
    """
    try:
        # Translate the request's ID lists into human-readable names.
        styles, audiences, scenic_spots, products = _resolve_ids_to_names(
            db_service,
            request.styleIds,
            request.audienceIds,
            request.scenicSpotIds,
            request.productIds
        )
        request_id, topic_index, content = await tweet_service.generate_content(
            topic=request.topic,
            styles=styles,
            audiences=audiences,
            scenic_spots=scenic_spots,
            products=products,
            auto_judge=request.autoJudge
        )
        judge_success = None
        if isinstance(content, dict):
            # Split the judge verdict into its own response field. Build a new
            # dict instead of pop() so the service's returned object is not mutated.
            judge_success = content.get('judge_success')
            content = {key: value for key, value in content.items() if key != 'judge_success'}
        return ContentResponse(
            requestId=request_id,
            topicIndex=topic_index,
            content=content,
            judgeSuccess=judge_success
        )
    except HTTPException:
        # Already carries its own status code; re-wrapping would turn e.g. a 4xx into a 500.
        raise
    except Exception as e:
        logger.error(f"生成内容失败: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"生成内容失败: {str(e)}") from e
@router.post("/content-with-prompt", response_model=ContentResponse, summary="使用预构建提示词生成内容")
async def generate_content_with_prompt(
    request: ContentWithPromptRequest,
    tweet_service: TweetService = Depends(get_tweet_service)
):
    """
    Generate content for a topic from pre-built prompts.

    - **topic**: topic information
    - **system_prompt**: system prompt
    - **user_prompt**: user prompt

    No ID resolution happens here; the prompts are passed to the tweet service
    as given. ``judgeSuccess`` is always ``None`` in the response. An
    HTTPException raised downstream is propagated unchanged. Any other error
    is logged and reported as HTTP 500.
    """
    try:
        request_id, topic_index, content = await tweet_service.generate_content_with_prompt(
            topic=request.topic,
            system_prompt=request.system_prompt,
            user_prompt=request.user_prompt
        )
        return ContentResponse(
            requestId=request_id,
            topicIndex=topic_index,
            content=content,
            judgeSuccess=None
        )
    except HTTPException:
        # Already carries its own status code; re-wrapping would turn e.g. a 4xx into a 500.
        raise
    except Exception as e:
        logger.error(f"使用预构建提示词生成内容失败: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"使用预构建提示词生成内容失败: {str(e)}") from e
@router.post("/judge", response_model=JudgeResponse, summary="审核内容")
async def judge_content(
    request: JudgeRequest,
    tweet_service: TweetService = Depends(get_tweet_service),
    db_service: DatabaseService = Depends(get_database_service)
):
    """
    Judge (review) content.

    - **topic**: topic information
    - **content**: the content to judge
    - **styleIds**: list of style IDs
    - **audienceIds**: list of audience IDs
    - **scenicSpotIds**: list of scenic-spot IDs
    - **productIds**: list of product IDs

    The IDs are resolved to names through the database before the tweet
    service is called. An HTTPException raised downstream is propagated
    unchanged. Any other error is logged and reported as HTTP 500.
    """
    try:
        # Translate the request's ID lists into human-readable names.
        styles, audiences, scenic_spots, products = _resolve_ids_to_names(
            db_service,
            request.styleIds,
            request.audienceIds,
            request.scenicSpotIds,
            request.productIds
        )
        request_id, topic_index, judged_content, judge_success = await tweet_service.judge_content(
            topic=request.topic,
            content=request.content,
            styles=styles,
            audiences=audiences,
            scenic_spots=scenic_spots,
            products=products
        )
        return JudgeResponse(
            requestId=request_id,
            topicIndex=topic_index,
            content=judged_content,
            judgeSuccess=judge_success
        )
    except HTTPException:
        # Already carries its own status code; re-wrapping would turn e.g. a 4xx into a 500.
        raise
    except Exception as e:
        logger.error(f"审核内容失败: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"审核内容失败: {str(e)}") from e
@router.post("/pipeline", response_model=PipelineResponse, summary="运行完整流水线")
async def run_pipeline(
    request: PipelineRequest,
    tweet_service: TweetService = Depends(get_tweet_service),
    db_service: DatabaseService = Depends(get_database_service)
):
    """
    Run the full pipeline: generate topics → generate content → judge content.

    - **dates**: date range
    - **numTopics**: number of topics to generate
    - **styleIds**: list of style IDs
    - **audienceIds**: list of audience IDs
    - **scenicSpotIds**: list of scenic-spot IDs
    - **productIds**: list of product IDs
    - **skipJudge**: whether to skip the separate judge step
    - **autoJudge**: whether to judge inline during content generation

    The IDs are resolved to names through the database before the tweet
    service is called. An HTTPException raised downstream is propagated
    unchanged. Any other error is logged and reported as HTTP 500.
    """
    try:
        # Translate the request's ID lists into human-readable names.
        styles, audiences, scenic_spots, products = _resolve_ids_to_names(
            db_service,
            request.styleIds,
            request.audienceIds,
            request.scenicSpotIds,
            request.productIds
        )
        request_id, topics, contents, judged_contents = await tweet_service.run_pipeline(
            dates=request.dates,
            numTopics=request.numTopics,
            styles=styles,
            audiences=audiences,
            scenic_spots=scenic_spots,
            products=products,
            skip_judge=request.skipJudge,
            auto_judge=request.autoJudge
        )
        return PipelineResponse(
            requestId=request_id,
            topics=topics,
            contents=contents,
            judgedContents=judged_contents
        )
    except HTTPException:
        # Already carries its own status code; re-wrapping would turn e.g. a 4xx into a 500.
        raise
    except Exception as e:
        logger.error(f"运行流水线失败: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"运行流水线失败: {str(e)}") from e