#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Text-content ("tweet") API routes: topic generation, content generation,
content judging, and the end-to-end pipeline.
"""

import logging
from typing import List, Dict, Any, Optional, Union

from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field

from api.models.tweet import (
    TopicRequest, TopicResponse,
    ContentRequest, ContentResponse,
    JudgeRequest, JudgeResponse,
    PipelineRequest, PipelineResponse
)
from api.services.tweet import TweetService
from api.services.database_service import DatabaseService
from api.dependencies import get_tweet_service, get_database_service


# Request model for generating content from caller-supplied (pre-built) prompts.
# The class docstring and Field descriptions are kept verbatim because pydantic
# publishes them in the OpenAPI schema.
class ContentWithPromptRequest(BaseModel):
    """带有预构建提示词的内容生成请求模型"""
    topic: Dict[str, Any] = Field(..., description="选题信息")
    system_prompt: str = Field(..., description="系统提示词")
    user_prompt: str = Field(..., description="用户提示词")

    class Config:
        # Example payload shown in the generated API docs.
        # NOTE(review): `schema_extra` is the pydantic v1 key; pydantic v2 uses
        # `json_schema_extra` — confirm which version the project runs on.
        schema_extra = {
            "example": {
                "topic": {
                    "index": "1",
                    "date": "2023-07-15",
                    "object": "北京故宫",
                    "product": "故宫门票",
                    "style": "旅游攻略",
                    "targetAudience": "年轻人",
                    "logic": "暑期旅游热门景点推荐"
                },
                "system_prompt": "你是一位专业的旅游内容创作者...",
                "user_prompt": "请为以下景点创作一篇旅游文章..."
            }
        }


logger = logging.getLogger(__name__)

router = APIRouter(
    tags=["tweet"],
    responses={404: {"description": "Not found"}},
)


# For each topic field produced by generation:
# (key of the name in the topic, key of the name->ID map in id_name_mappings,
#  key of the ID list written onto the enriched topic).
# Order matters: it fixes both the key insertion order and the log order.
_TOPIC_ID_FIELDS = (
    ('style', 'style_mapping', 'styleIds'),
    ('targetAudience', 'audience_mapping', 'audienceIds'),
    ('object', 'scenic_spot_mapping', 'scenicSpotIds'),
    ('product', 'product_mapping', 'productIds'),
)


def _safe_convert_ids(ids: Optional[List[Union[int, str]]]) -> List[int]:
    """Convert int/str IDs to ints, skipping (and logging) values that fail.

    ``int()`` parses arbitrarily long integer strings exactly, so large IDs
    that callers send as strings keep full precision.
    """
    if not ids:
        return []
    converted_ids: List[int] = []
    for id_val in ids:
        try:
            converted_ids.append(int(id_val))
        except (ValueError, TypeError, OverflowError) as e:
            logger.warning(f"无法转换ID {id_val} 为整数: {e}")
    return converted_ids


def _extract_names(styles: List[Dict[str, Any]],
                   audiences: List[Dict[str, Any]],
                   scenic_spots: List[Dict[str, Any]],
                   products: List[Dict[str, Any]]) -> tuple:
    """Pull the display-name column out of each record list.

    Returns:
        (style_names, audience_names, scenic_spot_names, product_names)
    """
    return (
        [s['styleName'] for s in styles],
        [a['audienceName'] for a in audiences],
        [s['name'] for s in scenic_spots],
        [p['productName'] for p in products],
    )


def _resolve_ids_to_objects(db_service: DatabaseService,
                            styleIds: Optional[List[Union[int, str]]] = None,
                            audienceIds: Optional[List[Union[int, str]]] = None,
                            scenicSpotIds: Optional[List[Union[int, str]]] = None,
                            productIds: Optional[List[Union[int, str]]] = None) -> tuple:
    """Resolve ID lists into full record lists plus name->ID mappings.

    IDs may be ints or strings (strings avoid precision loss for large IDs).

    Args:
        db_service: database service; if missing or unavailable, nothing is
            resolved and empty results are returned (a warning is logged).
        styleIds / audienceIds / scenicSpotIds / productIds: IDs to look up.

    Returns:
        (styles, audiences, scenic_spots, products, id_name_mappings), where
        id_name_mappings has keys 'style_mapping', 'audience_mapping',
        'scenic_spot_mapping' and 'product_mapping', each mapping a record
        name to its 'id'.
    """
    styles, audiences, scenic_spots, products = [], [], [], []
    id_name_mappings = {
        'style_mapping': {},
        'audience_mapping': {},
        'scenic_spot_mapping': {},
        'product_mapping': {}
    }

    if not db_service or not db_service.is_available():
        logger.warning("数据库服务不可用,无法解析ID")
        return styles, audiences, scenic_spots, products, id_name_mappings

    if styleIds:
        styles = db_service.get_styles_by_ids(_safe_convert_ids(styleIds))
        id_name_mappings['style_mapping'] = {record['styleName']: record['id'] for record in styles}

    if audienceIds:
        audiences = db_service.get_audiences_by_ids(_safe_convert_ids(audienceIds))
        id_name_mappings['audience_mapping'] = {record['audienceName']: record['id'] for record in audiences}

    if scenicSpotIds:
        scenic_spots = db_service.get_scenic_spots_by_ids(_safe_convert_ids(scenicSpotIds))
        id_name_mappings['scenic_spot_mapping'] = {record['name']: record['id'] for record in scenic_spots}

    if productIds:
        products = db_service.get_products_by_ids(_safe_convert_ids(productIds))
        id_name_mappings['product_mapping'] = {record['productName']: record['id'] for record in products}

    return styles, audiences, scenic_spots, products, id_name_mappings


def _resolve_ids_to_names(db_service: DatabaseService,
                          styleIds: Optional[List[Union[int, str]]] = None,
                          audienceIds: Optional[List[Union[int, str]]] = None,
                          scenicSpotIds: Optional[List[Union[int, str]]] = None,
                          productIds: Optional[List[Union[int, str]]] = None) -> tuple:
    """Resolve ID lists into lists of display names (backward-compatible API).

    Args:
        db_service: database service.
        styleIds: style IDs.
        audienceIds: audience IDs.
        scenicSpotIds: scenic-spot IDs.
        productIds: product IDs.

    Returns:
        (style_names, audience_names, scenic_spot_names, product_names).
        Each element is a list of name strings, not the full records.
    """
    styles, audiences, scenic_spots, products, _ = _resolve_ids_to_objects(
        db_service, styleIds, audienceIds, scenicSpotIds, productIds
    )
    return _extract_names(styles, audiences, scenic_spots, products)


def _find_best_match(target_name: str, mapping: Dict[str, int]) -> Optional[int]:
    """Return the ID whose name best matches ``target_name``, or None.

    Tries, in order: exact match; equality ignoring spaces; containment in
    either direction ignoring spaces. Blank (whitespace-only) names never
    match, because "" is a substring of every string.
    """
    if not target_name or not mapping:
        return None

    # 1. Exact match.
    if target_name in mapping:
        return mapping[target_name]

    target_clean = target_name.replace(" ", "").strip()
    if target_clean:
        # Normalise candidate names once for both fuzzy passes below.
        candidates = [(name, name.replace(" ", "").strip(), id_val)
                      for name, id_val in mapping.items()]

        # 2. Fuzzy match: equal once spaces are removed.
        for name, name_clean, id_val in candidates:
            if name_clean == target_clean:
                logger.info(f"模糊匹配成功: '{target_name}' -> '{name}' (ID: {id_val})")
                return id_val

        # 3. Containment match in either direction (skip blank candidate names).
        for name, name_clean, id_val in candidates:
            if name_clean and (target_clean in name_clean or name_clean in target_clean):
                logger.info(f"包含匹配成功: '{target_name}' -> '{name}' (ID: {id_val})")
                return id_val

    # 4. No match.
    logger.warning(f"未找到匹配的ID: '{target_name}', 可用选项: {list(mapping.keys())}")
    return None


def _add_ids_to_topics(topics: List[Dict[str, Any]],
                       id_name_mappings: Dict[str, Dict[str, int]]) -> List[Dict[str, Any]]:
    """Attach ID lists to each generated topic by matching names back to IDs.

    Each returned topic is a shallow copy of the input with 'styleIds',
    'audienceIds', 'scenicSpotIds' and 'productIds' set. Each list holds the
    single matched ID, or is empty when there was no match.

    Args:
        topics: generated topics (``None`` is treated as empty).
        id_name_mappings: name->ID maps as built by ``_resolve_ids_to_objects``.

    Returns:
        The enriched topics, in input order.
    """
    enriched_topics = []
    for topic in topics or []:
        enriched_topic = topic.copy()
        total_fields = 0
        matched_fields = 0

        for name_key, mapping_key, ids_key in _TOPIC_ID_FIELDS:
            enriched_topic[ids_key] = []
            target_name = topic.get(name_key)
            if not target_name:
                continue
            total_fields += 1
            matched_id = _find_best_match(target_name, id_name_mappings.get(mapping_key, {}))
            # `is not None` so that a legitimate ID of 0 is not discarded.
            if matched_id is not None:
                enriched_topic[ids_key] = [matched_id]
                matched_fields += 1

        # Log how many of the populated name fields could be mapped to IDs.
        if total_fields > 0:
            match_rate = matched_fields / total_fields * 100
            logger.info(f"选题 {topic.get('index', 'N/A')} ID匹配率: {match_rate:.1f}% ({matched_fields}/{total_fields})")
            if match_rate < 50:
                logger.warning(f"选题 {topic.get('index', 'N/A')} ID匹配率较低: {match_rate:.1f}%")

        enriched_topics.append(enriched_topic)

    return enriched_topics


@router.post("/topics", response_model=TopicResponse, summary="生成选题")
async def generate_topics(
    request: TopicRequest,
    tweet_service: TweetService = Depends(get_tweet_service),
    db_service: DatabaseService = Depends(get_database_service)
):
    """
    生成选题

    - **dates**: 日期字符串,可能为单个日期、多个日期用逗号分隔或范围
    - **numTopics**: 要生成的选题数量
    - **styleIds**: 风格ID列表
    - **audienceIds**: 受众ID列表
    - **scenicSpotIds**: 景区ID列表
    - **productIds**: 产品ID列表
    """
    # (Docstring above is the OpenAPI description; kept in the original language.)
    try:
        # Resolve IDs to full records, plus name->ID maps used to re-attach IDs.
        styles, audiences, scenic_spots, products, id_name_mappings = _resolve_ids_to_objects(
            db_service, request.styleIds, request.audienceIds,
            request.scenicSpotIds, request.productIds
        )
        # The service still takes name lists for backward compatibility.
        style_names, audience_names, scenic_spot_names, product_names = _extract_names(
            styles, audiences, scenic_spots, products
        )

        request_id, topics = await tweet_service.generate_topics(
            dates=request.dates,
            numTopics=request.numTopics,
            styles=style_names,
            audiences=audience_names,
            scenic_spots=scenic_spot_names,
            products=product_names,
            # Full records for downstream use.
            style_objects=styles,
            audience_objects=audiences,
            scenic_spot_objects=scenic_spots,
            product_objects=products
        )

        # Map the names inside each generated topic back to IDs.
        enriched_topics = _add_ids_to_topics(topics, id_name_mappings)

        return TopicResponse(
            requestId=request_id,
            topics=enriched_topics
        )
    except HTTPException:
        # Preserve deliberate HTTP errors (status code and detail) as raised.
        raise
    except Exception as e:
        logger.error(f"生成选题失败: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"生成选题失败: {str(e)}") from e


@router.post("/content", response_model=ContentResponse, summary="生成内容")
async def generate_content(
    request: ContentRequest,
    tweet_service: TweetService = Depends(get_tweet_service),
    db_service: DatabaseService = Depends(get_database_service)
):
    """
    生成内容

    - **topic**: 选题信息
    - **styleIds**: 风格ID列表
    - **audienceIds**: 受众ID列表
    - **scenicSpotIds**: 景区ID列表
    - **productIds**: 产品ID列表
    - **autoJudge**: 是否自动进行内容审核
    """
    try:
        # Resolve IDs to full records for the content generator.
        styles, audiences, scenic_spots, products, _ = _resolve_ids_to_objects(
            db_service, request.styleIds, request.audienceIds,
            request.scenicSpotIds, request.productIds
        )

        request_id, topic_index, content = await tweet_service.generate_content(
            topic=request.topic,
            style_objects=styles,
            audience_objects=audiences,
            scenic_spot_objects=scenic_spots,
            product_objects=products,
            autoJudge=request.autoJudge
        )

        # The inline-judge flag travels inside the content dict; lift it to
        # the top-level response field.
        judge_success = content.pop('judgeSuccess', None) if isinstance(content, dict) else None

        return ContentResponse(
            requestId=request_id,
            topicIndex=topic_index,
            content=content,
            judgeSuccess=judge_success
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"生成内容失败: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"生成内容失败: {str(e)}") from e


@router.post("/content-with-prompt", response_model=ContentResponse, summary="使用预构建提示词生成内容")
async def generate_content_with_prompt(
    request: ContentWithPromptRequest,
    tweet_service: TweetService = Depends(get_tweet_service)
):
    """
    使用预构建的提示词为选题生成内容

    - **topic**: 选题信息
    - **system_prompt**: 系统提示词
    - **user_prompt**: 用户提示词
    """
    try:
        request_id, topic_index, content = await tweet_service.generate_content_with_prompt(
            topic=request.topic,
            system_prompt=request.system_prompt,
            user_prompt=request.user_prompt
        )

        # No judging happens on this path.
        return ContentResponse(
            requestId=request_id,
            topicIndex=topic_index,
            content=content,
            judgeSuccess=None
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"使用预构建提示词生成内容失败: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"使用预构建提示词生成内容失败: {str(e)}") from e


@router.post("/judge", response_model=JudgeResponse, summary="审核内容")
async def judge_content(
    request: JudgeRequest,
    tweet_service: TweetService = Depends(get_tweet_service),
    db_service: DatabaseService = Depends(get_database_service)
):
    """
    审核内容

    - **topic**: 选题信息
    - **content**: 要审核的内容
    - **styleIds**: 风格ID列表
    - **audienceIds**: 受众ID列表
    - **scenicSpotIds**: 景区ID列表
    - **productIds**: 产品ID列表
    """
    try:
        # Resolve IDs to full records for the judge.
        styles, audiences, scenic_spots, products, _ = _resolve_ids_to_objects(
            db_service, request.styleIds, request.audienceIds,
            request.scenicSpotIds, request.productIds
        )

        request_id, topic_index, judged_content, judge_success = await tweet_service.judge_content(
            topic=request.topic,
            content=request.content,
            style_objects=styles,
            audience_objects=audiences,
            scenic_spot_objects=scenic_spots,
            product_objects=products
        )

        return JudgeResponse(
            requestId=request_id,
            topicIndex=topic_index,
            content=judged_content,
            judgeSuccess=judge_success
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"审核内容失败: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"审核内容失败: {str(e)}") from e


@router.post("/pipeline", response_model=PipelineResponse, summary="运行完整流水线")
async def run_pipeline(
    request: PipelineRequest,
    tweet_service: TweetService = Depends(get_tweet_service),
    db_service: DatabaseService = Depends(get_database_service)
):
    """
    运行完整流水线:生成选题 → 生成内容 → 审核内容

    - **dates**: 日期范围
    - **numTopics**: 要生成的选题数量
    - **styleIds**: 风格ID列表
    - **audienceIds**: 受众ID列表
    - **scenicSpotIds**: 景区ID列表
    - **productIds**: 产品ID列表
    - **skipJudge**: 是否跳过内容审核步骤
    - **autoJudge**: 是否在内容生成时进行内嵌审核
    """
    try:
        # Resolve IDs to name lists; the service takes names for these parameters,
        # the same as the name-based parameters of generate_topics.
        styles, audiences, scenic_spots, products = _resolve_ids_to_names(
            db_service, request.styleIds, request.audienceIds,
            request.scenicSpotIds, request.productIds
        )

        request_id, topics, contents, judged_contents = await tweet_service.run_pipeline(
            dates=request.dates,
            numTopics=request.numTopics,
            styles=styles,
            audiences=audiences,
            scenic_spots=scenic_spots,
            products=products,
            skipJudge=request.skipJudge,
            autoJudge=request.autoJudge
        )

        return PipelineResponse(
            requestId=request_id,
            topics=topics,
            contents=contents,
            judgedContents=judged_contents
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"运行流水线失败: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"运行流水线失败: {str(e)}") from e