471 lines
18 KiB
Python
471 lines
18 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
|
||
"""
|
||
文字内容API路由
|
||
"""
|
||
|
||
import logging
|
||
from typing import List, Dict, Any, Optional, Union
|
||
from fastapi import APIRouter, Depends, HTTPException
|
||
|
||
from api.models.tweet import (
|
||
TopicRequest, TopicResponse,
|
||
ContentRequest, ContentResponse,
|
||
JudgeRequest, JudgeResponse,
|
||
PipelineRequest, PipelineResponse
|
||
)
|
||
from api.services.tweet import TweetService
|
||
from api.services.database_service import DatabaseService
|
||
from api.dependencies import get_tweet_service, get_database_service
|
||
|
||
# 创建一个新的模型用于接收预构建提示词的请求
|
||
from pydantic import BaseModel, Field
|
||
|
||
class ContentWithPromptRequest(BaseModel):
    """Request body for generating content from pre-built prompts.

    Declared in this router module instead of being imported from
    ``api.models.tweet``. All three fields are required.
    """

    # Topic the content is generated for; a free-form key/value mapping.
    topic: Dict[str, Any] = Field(..., description="选题信息")
    # Forwarded to the tweet service as ``system_prompt``.
    system_prompt: str = Field(..., description="系统提示词")
    # Forwarded to the tweet service as ``user_prompt``.
    user_prompt: str = Field(..., description="用户提示词")

    class Config:
        # Sample request payload exposed through the model's schema settings.
        schema_extra = {
            "example": {
                "topic": {
                    "index": "1",
                    "date": "2023-07-15",
                    "object": "北京故宫",
                    "product": "故宫门票",
                    "style": "旅游攻略",
                    "targetAudience": "年轻人",
                    "logic": "暑期旅游热门景点推荐"
                },
                "system_prompt": "你是一位专业的旅游内容创作者...",
                "user_prompt": "请为以下景点创作一篇旅游文章..."
            }
        }
|
||
|
||
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Router that the tweet endpoints below are registered on.
router = APIRouter(tags=["tweet"], responses={404: {"description": "Not found"}})
|
||
|
||
|
||
def _resolve_ids_to_objects(db_service: DatabaseService,
                            styleIds: Optional[List[Union[int, str]]] = None,
                            audienceIds: Optional[List[Union[int, str]]] = None,
                            scenicSpotIds: Optional[List[Union[int, str]]] = None,
                            productIds: Optional[List[Union[int, str]]] = None) -> tuple:
    """
    Resolve ID lists into full records plus name-to-ID lookup tables.

    IDs may be given as integers or as strings; every value is parsed with
    ``int()`` so large identifiers sent as strings keep full precision.
    Values that cannot be parsed are logged and dropped.

    Args:
        db_service: Database service used to fetch the records.
        styleIds: Style IDs to resolve.
        audienceIds: Audience IDs to resolve.
        scenicSpotIds: Scenic-spot IDs to resolve.
        productIds: Product IDs to resolve.

    Returns:
        A 5-tuple ``(styles, audiences, scenic_spots, products,
        id_name_mappings)``. The first four items are the record lists
        returned by the database service, or empty lists when no usable IDs
        were supplied or the service is unavailable. ``id_name_mappings``
        always contains the keys ``style_mapping``, ``audience_mapping``,
        ``scenic_spot_mapping`` and ``product_mapping``; each maps a record's
        name field to its ``id``.
    """
    def safe_convert_ids(ids: Optional[List[Union[int, str]]]) -> List[int]:
        """Convert raw IDs to ``int``, skipping (and logging) invalid ones."""
        if not ids:
            return []
        converted_ids = []
        for id_val in ids:
            try:
                # int() parses arbitrarily large integer strings exactly.
                converted_ids.append(int(id_val))
            except (ValueError, TypeError) as e:
                logger.warning(f"无法转换ID {id_val} 为整数: {e}")
        return converted_ids

    # One row per category, in the order the results are returned:
    # (raw IDs, db_service method name, mapping key, record name field).
    lookups = (
        (styleIds, 'get_styles_by_ids', 'style_mapping', 'styleName'),
        (audienceIds, 'get_audiences_by_ids', 'audience_mapping', 'audienceName'),
        (scenicSpotIds, 'get_scenic_spots_by_ids', 'scenic_spot_mapping', 'name'),
        (productIds, 'get_products_by_ids', 'product_mapping', 'productName'),
    )

    resolved = [[] for _ in lookups]
    id_name_mappings = {mapping_key: {} for _, _, mapping_key, _ in lookups}

    if not db_service or not db_service.is_available():
        logger.warning("数据库服务不可用,无法解析ID")
        return (*resolved, id_name_mappings)

    for position, (raw_ids, fetch_name, mapping_key, name_field) in enumerate(lookups):
        if not raw_ids:
            continue

        converted = safe_convert_ids(raw_ids)
        if not converted:
            # Every supplied ID was rejected above; do not send an empty ID
            # list to the database.
            continue

        # The method is looked up only when it is actually needed.
        records = getattr(db_service, fetch_name)(converted)
        resolved[position] = records
        id_name_mappings[mapping_key] = {
            record[name_field]: record['id'] for record in records
        }

    return (*resolved, id_name_mappings)
|
||
|
||
def _resolve_ids_to_names(db_service: DatabaseService,
                          styleIds: Optional[List[Union[int, str]]] = None,
                          audienceIds: Optional[List[Union[int, str]]] = None,
                          scenicSpotIds: Optional[List[Union[int, str]]] = None,
                          productIds: Optional[List[Union[int, str]]] = None) -> tuple:
    """
    Convert ID lists into name lists (kept for backward compatibility).

    Args:
        db_service: Database service.
        styleIds: Style IDs.
        audienceIds: Audience IDs.
        scenicSpotIds: Scenic-spot IDs.
        productIds: Product IDs.

    Returns:
        ``(styles, audiences, scenic_spots, products)``, each a list of the
        names of the resolved records.
    """
    styles, audiences, scenic_spots, products, _ = _resolve_ids_to_objects(
        db_service, styleIds, audienceIds, scenicSpotIds, productIds
    )
    # _resolve_ids_to_objects hands back full records; reduce each one to its
    # name field so callers really receive names, as documented.
    return (
        [record['styleName'] for record in styles],
        [record['audienceName'] for record in audiences],
        [record['name'] for record in scenic_spots],
        [record['productName'] for record in products],
    )
|
||
|
||
|
||
def _add_ids_to_topics(topics: List[Dict[str, Any]], id_name_mappings: Dict[str, Dict[str, int]]) -> List[Dict[str, Any]]:
|
||
"""
|
||
为每个topic添加对应的ID字段
|
||
|
||
Args:
|
||
topics: 生成的选题列表
|
||
id_name_mappings: 名称到ID的映射字典
|
||
|
||
Returns:
|
||
包含ID字段的选题列表
|
||
"""
|
||
def find_best_match(target_name: str, mapping: Dict[str, int]) -> Optional[int]:
|
||
"""
|
||
寻找最佳匹配的ID,支持模糊匹配
|
||
"""
|
||
if not target_name or not mapping:
|
||
return None
|
||
|
||
# 1. 精确匹配
|
||
if target_name in mapping:
|
||
return mapping[target_name]
|
||
|
||
# 2. 模糊匹配 - 去除空格后匹配
|
||
target_clean = target_name.replace(" ", "").strip()
|
||
for name, id_val in mapping.items():
|
||
if name.replace(" ", "").strip() == target_clean:
|
||
logger.info(f"模糊匹配成功: '{target_name}' -> '{name}' (ID: {id_val})")
|
||
return id_val
|
||
|
||
# 3. 包含匹配 - 检查是否互相包含
|
||
for name, id_val in mapping.items():
|
||
if target_clean in name.replace(" ", "") or name.replace(" ", "") in target_clean:
|
||
logger.info(f"包含匹配成功: '{target_name}' -> '{name}' (ID: {id_val})")
|
||
return id_val
|
||
|
||
# 4. 未找到匹配
|
||
logger.warning(f"未找到匹配的ID: '{target_name}', 可用选项: {list(mapping.keys())}")
|
||
return None
|
||
|
||
enriched_topics = []
|
||
|
||
for topic in topics:
|
||
# 复制原topic
|
||
enriched_topic = topic.copy()
|
||
|
||
# 初始化ID字段
|
||
enriched_topic['styleIds'] = []
|
||
enriched_topic['audienceIds'] = []
|
||
enriched_topic['scenicSpotIds'] = []
|
||
enriched_topic['productIds'] = []
|
||
|
||
# 记录匹配结果
|
||
match_results = {
|
||
'style_matched': False,
|
||
'audience_matched': False,
|
||
'scenic_spot_matched': False,
|
||
'product_matched': False
|
||
}
|
||
|
||
# 根据topic中的name查找对应的ID
|
||
if 'style' in topic and topic['style']:
|
||
style_id = find_best_match(topic['style'], id_name_mappings['style_mapping'])
|
||
if style_id:
|
||
enriched_topic['styleIds'] = [style_id]
|
||
match_results['style_matched'] = True
|
||
|
||
if 'targetAudience' in topic and topic['targetAudience']:
|
||
audience_id = find_best_match(topic['targetAudience'], id_name_mappings['audience_mapping'])
|
||
if audience_id:
|
||
enriched_topic['audienceIds'] = [audience_id]
|
||
match_results['audience_matched'] = True
|
||
|
||
if 'object' in topic and topic['object']:
|
||
spot_id = find_best_match(topic['object'], id_name_mappings['scenic_spot_mapping'])
|
||
if spot_id:
|
||
enriched_topic['scenicSpotIds'] = [spot_id]
|
||
match_results['scenic_spot_matched'] = True
|
||
|
||
if 'product' in topic and topic['product']:
|
||
product_id = find_best_match(topic['product'], id_name_mappings['product_mapping'])
|
||
if product_id:
|
||
enriched_topic['productIds'] = [product_id]
|
||
match_results['product_matched'] = True
|
||
|
||
# 记录匹配情况
|
||
total_fields = sum(1 for key in ['style', 'targetAudience', 'object', 'product'] if key in topic and topic[key])
|
||
matched_fields = sum(match_results.values())
|
||
|
||
if total_fields > 0:
|
||
match_rate = matched_fields / total_fields * 100
|
||
logger.info(f"选题 {topic.get('index', 'N/A')} ID匹配率: {match_rate:.1f}% ({matched_fields}/{total_fields})")
|
||
|
||
# 如果匹配率低于50%,记录警告
|
||
if match_rate < 50:
|
||
logger.warning(f"选题 {topic.get('index', 'N/A')} ID匹配率较低: {match_rate:.1f}%")
|
||
|
||
# 添加匹配元数据
|
||
# enriched_topic['_id_match_metadata'] = {
|
||
# 'match_results': match_results,
|
||
# 'match_rate': matched_fields / max(total_fields, 1) * 100 if total_fields > 0 else 0
|
||
# }
|
||
|
||
enriched_topics.append(enriched_topic)
|
||
|
||
return enriched_topics
|
||
|
||
|
||
@router.post("/topics", response_model=TopicResponse, summary="生成选题")
async def generate_topics(
    request: TopicRequest,
    tweet_service: TweetService = Depends(get_tweet_service),
    db_service: DatabaseService = Depends(get_database_service)
):
    """
    Generate topics.

    - **dates**: date string; may be a single date, several comma-separated dates, or a range
    - **numTopics**: number of topics to generate
    - **styleIds**: list of style IDs
    - **audienceIds**: list of audience IDs
    - **scenicSpotIds**: list of scenic-spot IDs
    - **productIds**: list of product IDs
    """
    try:
        # Resolve the IDs into full records plus name-to-ID lookup tables.
        styles, audiences, scenic_spots, products, id_name_mappings = _resolve_ids_to_objects(
            db_service,
            request.styleIds,
            request.audienceIds,
            request.scenicSpotIds,
            request.productIds
        )

        # The service takes plain name lists as well as the full records.
        style_names = [s['styleName'] for s in styles]
        audience_names = [a['audienceName'] for a in audiences]
        scenic_spot_names = [s['name'] for s in scenic_spots]
        product_names = [p['productName'] for p in products]

        request_id, topics = await tweet_service.generate_topics(
            dates=request.dates,
            numTopics=request.numTopics,
            styles=style_names,
            audiences=audience_names,
            scenic_spots=scenic_spot_names,
            products=product_names,
            # Full records, passed alongside the names for downstream use.
            style_objects=styles,
            audience_objects=audiences,
            scenic_spot_objects=scenic_spots,
            product_objects=products
        )

        # Map the names in the generated topics back to record IDs.
        enriched_topics = _add_ids_to_topics(topics, id_name_mappings)

        return TopicResponse(
            requestId=request_id,
            topics=enriched_topics
        )
    except HTTPException:
        # An HTTP error raised on purpose further down keeps its own status
        # code and detail instead of being rewritten as a 500.
        raise
    except Exception as e:
        logger.error(f"生成选题失败: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"生成选题失败: {str(e)}") from e
|
||
|
||
|
||
@router.post("/content", response_model=ContentResponse, summary="生成内容")
async def generate_content(
    request: ContentRequest,
    tweet_service: TweetService = Depends(get_tweet_service),
    db_service: DatabaseService = Depends(get_database_service)
):
    """
    Generate content.

    - **topic**: topic information
    - **styleIds**: list of style IDs
    - **audienceIds**: list of audience IDs
    - **scenicSpotIds**: list of scenic-spot IDs
    - **productIds**: list of product IDs
    - **autoJudge**: whether to review the content automatically
    """
    try:
        # Resolve the IDs into full records; the name-to-ID tables are not
        # needed here.
        styles, audiences, scenic_spots, products, _ = _resolve_ids_to_objects(
            db_service,
            request.styleIds,
            request.audienceIds,
            request.scenicSpotIds,
            request.productIds
        )

        request_id, topic_index, content = await tweet_service.generate_content(
            topic=request.topic,
            style_objects=styles,
            audience_objects=audiences,
            scenic_spot_objects=scenic_spots,
            product_objects=products,
            autoJudge=request.autoJudge
        )

        # When the content is a dict, move its 'judgeSuccess' entry (if any)
        # out of the content and into the response's own field.
        judge_success = content.pop('judgeSuccess', None) if isinstance(content, dict) else None

        return ContentResponse(
            requestId=request_id,
            topicIndex=topic_index,
            content=content,
            judgeSuccess=judge_success
        )
    except HTTPException:
        # An HTTP error raised on purpose further down keeps its own status
        # code and detail instead of being rewritten as a 500.
        raise
    except Exception as e:
        logger.error(f"生成内容失败: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"生成内容失败: {str(e)}") from e
|
||
|
||
|
||
@router.post("/content-with-prompt", response_model=ContentResponse, summary="使用预构建提示词生成内容")
async def generate_content_with_prompt(
    request: ContentWithPromptRequest,
    tweet_service: TweetService = Depends(get_tweet_service)
):
    """
    Generate content for a topic from pre-built prompts.

    - **topic**: topic information
    - **system_prompt**: system prompt
    - **user_prompt**: user prompt
    """
    try:
        request_id, topic_index, content = await tweet_service.generate_content_with_prompt(
            topic=request.topic,
            system_prompt=request.system_prompt,
            user_prompt=request.user_prompt
        )

        # This endpoint never reports a review result.
        return ContentResponse(
            requestId=request_id,
            topicIndex=topic_index,
            content=content,
            judgeSuccess=None
        )
    except HTTPException:
        # An HTTP error raised on purpose further down keeps its own status
        # code and detail instead of being rewritten as a 500.
        raise
    except Exception as e:
        logger.error(f"使用预构建提示词生成内容失败: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"使用预构建提示词生成内容失败: {str(e)}") from e
|
||
|
||
|
||
@router.post("/judge", response_model=JudgeResponse, summary="审核内容")
async def judge_content(
    request: JudgeRequest,
    tweet_service: TweetService = Depends(get_tweet_service),
    db_service: DatabaseService = Depends(get_database_service)
):
    """
    Review content.

    - **topic**: topic information
    - **content**: content to review
    - **styleIds**: list of style IDs
    - **audienceIds**: list of audience IDs
    - **scenicSpotIds**: list of scenic-spot IDs
    - **productIds**: list of product IDs
    """
    try:
        # Resolve the IDs into full records; the name-to-ID tables are not
        # needed here.
        styles, audiences, scenic_spots, products, _ = _resolve_ids_to_objects(
            db_service,
            request.styleIds,
            request.audienceIds,
            request.scenicSpotIds,
            request.productIds
        )

        request_id, topic_index, judged_content, judge_success = await tweet_service.judge_content(
            topic=request.topic,
            content=request.content,
            style_objects=styles,
            audience_objects=audiences,
            scenic_spot_objects=scenic_spots,
            product_objects=products
        )

        return JudgeResponse(
            requestId=request_id,
            topicIndex=topic_index,
            content=judged_content,
            judgeSuccess=judge_success
        )
    except HTTPException:
        # An HTTP error raised on purpose further down keeps its own status
        # code and detail instead of being rewritten as a 500.
        raise
    except Exception as e:
        logger.error(f"审核内容失败: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"审核内容失败: {str(e)}") from e
|
||
|
||
|
||
@router.post("/pipeline", response_model=PipelineResponse, summary="运行完整流水线")
async def run_pipeline(
    request: PipelineRequest,
    tweet_service: TweetService = Depends(get_tweet_service),
    db_service: DatabaseService = Depends(get_database_service)
):
    """
    Run the full pipeline: generate topics → generate content → review content.

    - **dates**: date range
    - **numTopics**: number of topics to generate
    - **styleIds**: list of style IDs
    - **audienceIds**: list of audience IDs
    - **scenicSpotIds**: list of scenic-spot IDs
    - **productIds**: list of product IDs
    - **skipJudge**: whether to skip the content review step
    - **autoJudge**: whether to review inline while generating content
    """
    try:
        # Resolve the requested IDs through the backward-compatible helper.
        styles, audiences, scenic_spots, products = _resolve_ids_to_names(
            db_service,
            request.styleIds,
            request.audienceIds,
            request.scenicSpotIds,
            request.productIds
        )

        request_id, topics, contents, judged_contents = await tweet_service.run_pipeline(
            dates=request.dates,
            numTopics=request.numTopics,
            styles=styles,
            audiences=audiences,
            scenic_spots=scenic_spots,
            products=products,
            skipJudge=request.skipJudge,
            autoJudge=request.autoJudge
        )

        return PipelineResponse(
            requestId=request_id,
            topics=topics,
            contents=contents,
            judgedContents=judged_contents
        )
    except HTTPException:
        # An HTTP error raised on purpose further down keeps its own status
        # code and detail instead of being rewritten as a 500.
        raise
    except Exception as e:
        logger.error(f"运行流水线失败: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"运行流水线失败: {str(e)}") from e