394 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
文字内容API路由
"""
import logging
from typing import List, Dict, Any, Optional
from fastapi import APIRouter, Depends, HTTPException
from api.models.tweet import (
TopicRequest, TopicResponse,
ContentRequest, ContentResponse,
JudgeRequest, JudgeResponse,
PipelineRequest, PipelineResponse
)
from api.services.tweet import TweetService
from api.services.database_service import DatabaseService
from api.dependencies import get_tweet_service, get_database_service
# 创建一个新的模型用于接收预构建提示词的请求
from pydantic import BaseModel, Field
class ContentWithPromptRequest(BaseModel):
    """Request model for content generation driven by pre-built prompts."""
    # Topic information as a free-form dict; the schema example below uses the
    # keys index, date, object, product, style, targetAudience and logic.
    topic: Dict[str, Any] = Field(..., description="选题信息")
    # System prompt text; `...` as the default marks the field as required.
    system_prompt: str = Field(..., description="系统提示词")
    # User prompt text; `...` as the default marks the field as required.
    user_prompt: str = Field(..., description="用户提示词")
    class Config:
        # Example payload for the generated schema (pydantic v1-style
        # `Config.schema_extra`).
        # NOTE(review): pydantic v2 names this `json_schema_extra` - confirm
        # which major version the project pins.
        schema_extra = {
            "example": {
                "topic": {
                    "index": "1",
                    "date": "2023-07-15",
                    "object": "北京故宫",
                    "product": "故宫门票",
                    "style": "旅游攻略",
                    "targetAudience": "年轻人",
                    "logic": "暑期旅游热门景点推荐"
                },
                "system_prompt": "你是一位专业的旅游内容创作者...",
                "user_prompt": "请为以下景点创作一篇旅游文章..."
            }
        }
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Router that the endpoints below register on; every route is tagged "tweet"
# and documents a 404 "Not found" response.
router = APIRouter(
    tags=["tweet"],
    responses={404: {"description": "Not found"}},
)
def _resolve_ids_to_objects(db_service: DatabaseService,
                            styleIds: Optional[List[int]] = None,
                            audienceIds: Optional[List[int]] = None,
                            scenicSpotIds: Optional[List[int]] = None,
                            productIds: Optional[List[int]] = None) -> tuple:
    """
    Resolve ID lists into full records plus name -> ID lookup tables.

    Args:
        db_service: Database service used for the lookups.
        styleIds: Style IDs to resolve.
        audienceIds: Audience IDs to resolve.
        scenicSpotIds: Scenic spot IDs to resolve.
        productIds: Product IDs to resolve.

    Returns:
        A 5-tuple ``(styles, audiences, scenic_spots, products,
        id_name_mappings)``. The first four items are the record lists
        returned by the database service. ``id_name_mappings`` has the keys
        ``style_mapping``, ``audience_mapping``, ``scenic_spot_mapping`` and
        ``product_mapping``, each mapping a record's name to its ``id``.
        Categories whose ID list is empty or None are not queried and stay
        empty. If the database service is missing or reports itself
        unavailable, a warning is logged and everything is returned empty.
    """
    # One row per category: requested IDs, the DatabaseService getter to call,
    # the record key holding the display name, and the key in the mappings.
    lookup_plan = (
        (styleIds, 'get_styles_by_ids', 'styleName', 'style_mapping'),
        (audienceIds, 'get_audiences_by_ids', 'audienceName', 'audience_mapping'),
        (scenicSpotIds, 'get_scenic_spots_by_ids', 'name', 'scenic_spot_mapping'),
        (productIds, 'get_products_by_ids', 'productName', 'product_mapping'),
    )
    resolved = [[] for _ in lookup_plan]
    id_name_mappings = {
        mapping_key: {} for _ids, _getter, _name_key, mapping_key in lookup_plan
    }

    if not db_service or not db_service.is_available():
        logger.warning("数据库服务不可用无法解析ID")
        return (*resolved, id_name_mappings)

    for slot, (ids, getter_name, name_key, mapping_key) in enumerate(lookup_plan):
        if not ids:
            continue
        # The getter is only looked up for categories that were requested.
        records = getattr(db_service, getter_name)(ids)
        resolved[slot] = records
        id_name_mappings[mapping_key] = {row[name_key]: row['id'] for row in records}

    return (*resolved, id_name_mappings)
def _resolve_ids_to_names(db_service: DatabaseService,
                          styleIds: Optional[List[int]] = None,
                          audienceIds: Optional[List[int]] = None,
                          scenicSpotIds: Optional[List[int]] = None,
                          productIds: Optional[List[int]] = None) -> tuple:
    """
    Convert ID lists into name lists (kept for backward compatibility).

    The lookup is delegated to ``_resolve_ids_to_objects``; the names are then
    extracted from the returned records so that callers really receive names,
    as this function's name and contract promise, rather than record dicts.

    Args:
        db_service: Database service used for the lookups.
        styleIds: Style IDs to resolve.
        audienceIds: Audience IDs to resolve.
        scenicSpotIds: Scenic spot IDs to resolve.
        productIds: Product IDs to resolve.

    Returns:
        A 4-tuple ``(styles, audiences, scenic_spots, products)`` of name
        lists. A category is empty when no IDs were given for it or when
        ``_resolve_ids_to_objects`` returned no records for it.
    """
    styles, audiences, scenic_spots, products, _ = _resolve_ids_to_objects(
        db_service, styleIds, audienceIds, scenicSpotIds, productIds
    )
    # Same record keys that generate_topics and the name -> ID mappings use.
    return (
        [record['styleName'] for record in styles],
        [record['audienceName'] for record in audiences],
        [record['name'] for record in scenic_spots],
        [record['productName'] for record in products],
    )
def _add_ids_to_topics(topics: List[Dict[str, Any]], id_name_mappings: Dict[str, Dict[str, int]]) -> List[Dict[str, Any]]:
    """
    Attach ID list fields to every topic by looking its names up in the mappings.

    Args:
        topics: Generated topics.
        id_name_mappings: Name -> ID tables keyed by ``style_mapping``,
            ``audience_mapping``, ``scenic_spot_mapping`` and
            ``product_mapping``.

    Returns:
        New list of shallow topic copies. Each copy carries ``styleIds``,
        ``audienceIds``, ``scenicSpotIds`` and ``productIds``: a one-element
        list when the topic's name is found in the matching table, otherwise
        an empty list. The input topics are not modified.
    """
    # (topic field holding a name, output ID field, key of the name -> ID table)
    field_plan = (
        ('style', 'styleIds', 'style_mapping'),
        ('targetAudience', 'audienceIds', 'audience_mapping'),
        ('object', 'scenicSpotIds', 'scenic_spot_mapping'),
        ('product', 'productIds', 'product_mapping'),
    )

    def _with_ids(topic):
        """Return a shallow copy of one topic with the four ID fields filled in."""
        result = topic.copy()
        # Every ID field exists on the output, defaulting to an empty list.
        for _name_field, id_field, _mapping_key in field_plan:
            result[id_field] = []
        for name_field, id_field, mapping_key in field_plan:
            name = topic.get(name_field)
            if not name:
                continue
            # The table is only touched when the topic actually carries a name.
            table = id_name_mappings[mapping_key]
            if name in table:
                result[id_field] = [table[name]]
        return result

    return [_with_ids(topic) for topic in topics]
@router.post("/topics", response_model=TopicResponse, summary="生成选题")
async def generate_topics(
    request: TopicRequest,
    tweet_service: TweetService = Depends(get_tweet_service),
    db_service: DatabaseService = Depends(get_database_service)
):
    """
    Generate topics.

    - **dates**: date string; may be a single date, several comma-separated dates, or a range
    - **numTopics**: number of topics to generate
    - **styleIds**: list of style IDs
    - **audienceIds**: list of audience IDs
    - **scenicSpotIds**: list of scenic spot IDs
    - **productIds**: list of product IDs
    """
    try:
        # Look up the full records behind the requested IDs, together with the
        # name -> ID tables used to tag the generated topics afterwards.
        resolved = _resolve_ids_to_objects(
            db_service,
            request.styleIds,
            request.audienceIds,
            request.scenicSpotIds,
            request.productIds
        )
        style_records, audience_records, spot_records, product_records, name_to_id = resolved

        # Plain name lists, passed alongside the records for backward compatibility.
        style_names = [record['styleName'] for record in style_records]
        audience_names = [record['audienceName'] for record in audience_records]
        spot_names = [record['name'] for record in spot_records]
        product_names = [record['productName'] for record in product_records]

        request_id, topics = await tweet_service.generate_topics(
            dates=request.dates,
            numTopics=request.numTopics,
            styles=style_names,
            audiences=audience_names,
            scenic_spots=spot_names,
            products=product_names,
            # Full records for downstream use.
            style_objects=style_records,
            audience_objects=audience_records,
            scenic_spot_objects=spot_records,
            product_objects=product_records
        )

        # Tag each topic with the IDs matching the names it carries.
        topics_with_ids = _add_ids_to_topics(topics, name_to_id)
        response = TopicResponse(requestId=request_id, topics=topics_with_ids)
        return response
    except Exception as exc:
        # Any failure is logged with its traceback and reported as HTTP 500.
        logger.error(f"生成选题失败: {exc}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"生成选题失败: {exc!s}")
@router.post("/content", response_model=ContentResponse, summary="生成内容")
async def generate_content(
    request: ContentRequest,
    tweet_service: TweetService = Depends(get_tweet_service),
    db_service: DatabaseService = Depends(get_database_service)
):
    """
    Generate content for a topic.

    - **topic**: topic information
    - **styleIds**: list of style IDs
    - **audienceIds**: list of audience IDs
    - **scenicSpotIds**: list of scenic spot IDs
    - **productIds**: list of product IDs
    - **autoJudge**: whether to judge the content automatically
    """
    try:
        # Look up the full records behind the requested IDs; the name -> ID
        # tables are not needed here.
        style_records, audience_records, spot_records, product_records, _ = _resolve_ids_to_objects(
            db_service,
            request.styleIds,
            request.audienceIds,
            request.scenicSpotIds,
            request.productIds
        )

        generation = await tweet_service.generate_content(
            topic=request.topic,
            style_objects=style_records,
            audience_objects=audience_records,
            scenic_spot_objects=spot_records,
            product_objects=product_records,
            autoJudge=request.autoJudge
        )
        request_id, topic_index, content = generation

        # When the content is a dict, its 'judgeSuccess' entry (if any) is
        # removed from it and returned as a separate response field.
        if isinstance(content, dict):
            judge_success = content.pop('judgeSuccess', None)
        else:
            judge_success = None

        response = ContentResponse(
            requestId=request_id,
            topicIndex=topic_index,
            content=content,
            judgeSuccess=judge_success
        )
        return response
    except Exception as exc:
        # Any failure is logged with its traceback and reported as HTTP 500.
        logger.error(f"生成内容失败: {exc}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"生成内容失败: {exc!s}")
@router.post("/content-with-prompt", response_model=ContentResponse, summary="使用预构建提示词生成内容")
async def generate_content_with_prompt(
    request: ContentWithPromptRequest,
    tweet_service: TweetService = Depends(get_tweet_service)
):
    """
    Generate content for a topic from pre-built prompts.

    - **topic**: topic information
    - **system_prompt**: system prompt
    - **user_prompt**: user prompt
    """
    try:
        generation = await tweet_service.generate_content_with_prompt(
            topic=request.topic,
            system_prompt=request.system_prompt,
            user_prompt=request.user_prompt
        )
        request_id, topic_index, content = generation

        # This endpoint always reports judgeSuccess as None.
        response = ContentResponse(
            requestId=request_id,
            topicIndex=topic_index,
            content=content,
            judgeSuccess=None
        )
        return response
    except Exception as exc:
        # Any failure is logged with its traceback and reported as HTTP 500.
        logger.error(f"使用预构建提示词生成内容失败: {exc}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"使用预构建提示词生成内容失败: {exc!s}")
@router.post("/judge", response_model=JudgeResponse, summary="审核内容")
async def judge_content(
    request: JudgeRequest,
    tweet_service: TweetService = Depends(get_tweet_service),
    db_service: DatabaseService = Depends(get_database_service)
):
    """
    Judge (review) content.

    - **topic**: topic information
    - **content**: content to judge
    - **styleIds**: list of style IDs
    - **audienceIds**: list of audience IDs
    - **scenicSpotIds**: list of scenic spot IDs
    - **productIds**: list of product IDs
    """
    try:
        # Look up the full records behind the requested IDs; the name -> ID
        # tables are not needed here.
        style_records, audience_records, spot_records, product_records, _ = _resolve_ids_to_objects(
            db_service,
            request.styleIds,
            request.audienceIds,
            request.scenicSpotIds,
            request.productIds
        )

        verdict = await tweet_service.judge_content(
            topic=request.topic,
            content=request.content,
            style_objects=style_records,
            audience_objects=audience_records,
            scenic_spot_objects=spot_records,
            product_objects=product_records
        )
        request_id, topic_index, judged_content, judge_success = verdict

        response = JudgeResponse(
            requestId=request_id,
            topicIndex=topic_index,
            content=judged_content,
            judgeSuccess=judge_success
        )
        return response
    except Exception as exc:
        # Any failure is logged with its traceback and reported as HTTP 500.
        logger.error(f"审核内容失败: {exc}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"审核内容失败: {exc!s}")
@router.post("/pipeline", response_model=PipelineResponse, summary="运行完整流水线")
async def run_pipeline(
    request: PipelineRequest,
    tweet_service: TweetService = Depends(get_tweet_service),
    db_service: DatabaseService = Depends(get_database_service)
):
    """
    Run the full pipeline: generate topics -> generate content -> judge content.

    - **dates**: date range
    - **numTopics**: number of topics to generate
    - **styleIds**: list of style IDs
    - **audienceIds**: list of audience IDs
    - **scenicSpotIds**: list of scenic spot IDs
    - **productIds**: list of product IDs
    - **skip_judge**: whether to skip the content judging step
    - **autoJudge**: whether to judge inline while generating content
    """
    try:
        # Resolve the requested IDs through the names helper.
        style_values, audience_values, spot_values, product_values = _resolve_ids_to_names(
            db_service,
            request.styleIds,
            request.audienceIds,
            request.scenicSpotIds,
            request.productIds
        )

        # NOTE(review): the docstring bullet says skip_judge, but the attribute
        # read below is request.skipJudge - confirm the public field name.
        pipeline_result = await tweet_service.run_pipeline(
            dates=request.dates,
            numTopics=request.numTopics,
            styles=style_values,
            audiences=audience_values,
            scenic_spots=spot_values,
            products=product_values,
            skipJudge=request.skipJudge,
            autoJudge=request.autoJudge
        )
        request_id, topics, contents, judged_contents = pipeline_result

        response = PipelineResponse(
            requestId=request_id,
            topics=topics,
            contents=contents,
            judgedContents=judged_contents
        )
        return response
    except Exception as exc:
        # Any failure is logged with its traceback and reported as HTTP 500.
        logger.error(f"运行流水线失败: {exc}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"运行流水线失败: {exc!s}")