diff --git a/api/models/__pycache__/tweet.cpython-312.pyc b/api/models/__pycache__/tweet.cpython-312.pyc index ec8ebc4..45e950e 100644 Binary files a/api/models/__pycache__/tweet.cpython-312.pyc and b/api/models/__pycache__/tweet.cpython-312.pyc differ diff --git a/api/models/tweet.py b/api/models/tweet.py index 92c02f5..fbda200 100644 --- a/api/models/tweet.py +++ b/api/models/tweet.py @@ -51,7 +51,11 @@ class TopicResponse(BaseModel): "styleLogic": "重点解析如何避开高温时段并高效游玩各园区", "targetAudience": "年轻人", "targetAudienceLogic": "解决家长担心孩子中暑的问题,提供科学游玩方案", - "logic": "暑期旅游热门景点推荐" + "logic": "暑期旅游热门景点推荐", + "styleIds": [1], + "audienceIds": [2], + "scenicSpotIds": [1], + "productIds": [1] } ] } @@ -106,7 +110,7 @@ class ContentResponse(BaseModel): "content": { "title": "天津冒险湾亲子游攻略", "content": "详细的游玩攻略内容...", - "tags": ["亲子游", "水上乐园", "天津"] + "tags": ["#亲子游", "#水上乐园", "#天津"] }, "judgeSuccess": True } diff --git a/api/routers/__pycache__/tweet.cpython-312.pyc b/api/routers/__pycache__/tweet.cpython-312.pyc index 1c8158b..e3c7635 100644 Binary files a/api/routers/__pycache__/tweet.cpython-312.pyc and b/api/routers/__pycache__/tweet.cpython-312.pyc differ diff --git a/api/routers/tweet.py b/api/routers/tweet.py index 919c4f3..40d6fee 100644 --- a/api/routers/tweet.py +++ b/api/routers/tweet.py @@ -53,13 +53,83 @@ router = APIRouter( ) +def _resolve_ids_to_names_with_mapping(db_service: DatabaseService, + styleIds: Optional[List[int]] = None, + audienceIds: Optional[List[int]] = None, + scenicSpotIds: Optional[List[int]] = None, + productIds: Optional[List[int]] = None) -> tuple: + """ + 将ID列表转换为名称列表,并返回ID到名称的映射关系 + + Args: + db_service: 数据库服务 + styleIds: 风格ID列表 + audienceIds: 受众ID列表 + scenicSpotIds: 景区ID列表 + productIds: 产品ID列表 + + Returns: + (styles, audiences, scenic_spots, products, id_name_mappings) 名称列表和映射关系元组 + """ + styles = [] + audiences = [] + scenic_spots = [] + products = [] + + # 建立ID到名称的映射字典 + id_name_mappings = { + 'style_mapping': {}, # {name: id} + 'audience_mapping': {}, # {name: id} + 'scenic_spot_mapping': {}, # {name: id} + 'product_mapping': {} # {name: id} + } + + # 如果数据库服务不可用,返回空列表 + if not db_service or not db_service.is_available(): + logger.warning("数据库服务不可用,无法解析ID") + return styles, audiences, scenic_spots, products, id_name_mappings + + # 解析风格ID + if styleIds: + style_records = db_service.get_styles_by_ids(styleIds) + for record in style_records: + style_name = record['styleName'] + styles.append(style_name) + id_name_mappings['style_mapping'][style_name] = record['id'] + + # 解析受众ID + if audienceIds: + audience_records = db_service.get_audiences_by_ids(audienceIds) + for record in audience_records: + audience_name = record['audienceName'] + audiences.append(audience_name) + id_name_mappings['audience_mapping'][audience_name] = record['id'] + + # 解析景区ID + if scenicSpotIds: + spot_records = db_service.get_scenic_spots_by_ids(scenicSpotIds) + for record in spot_records: + spot_name = record['name'] + scenic_spots.append(spot_name) + id_name_mappings['scenic_spot_mapping'][spot_name] = record['id'] + + # 解析产品ID + if productIds: + product_records = db_service.get_products_by_ids(productIds) + for record in product_records: + product_name = record['name'] + products.append(product_name) + id_name_mappings['product_mapping'][product_name] = record['id'] + + return styles, audiences, scenic_spots, products, id_name_mappings + def _resolve_ids_to_names(db_service: DatabaseService, styleIds: Optional[List[int]] = None, audienceIds: Optional[List[int]] = None, scenicSpotIds: Optional[List[int]] = None, productIds: Optional[List[int]] = None) -> tuple: """ - 将ID列表转换为名称列表 + 将ID列表转换为名称列表(保持向后兼容) Args: db_service: 数据库服务 @@ -71,39 +141,61 @@ def _resolve_ids_to_names(db_service: DatabaseService, Returns: (styles, audiences, scenic_spots, products) 名称列表元组 """ - styles = [] - audiences = [] - scenic_spots = [] - products = [] - - # 如果数据库服务不可用,返回空列表 - if not db_service or not db_service.is_available(): - logger.warning("数据库服务不可用,无法解析ID") - return styles, audiences, scenic_spots, products - - # 解析风格ID - if styleIds: - style_records = db_service.get_styles_by_ids(styleIds) - styles = [record['styleName'] for record in style_records] - - # 解析受众ID - if audienceIds: - audience_records = db_service.get_audiences_by_ids(audienceIds) - audiences = [record['audienceName'] for record in audience_records] - - # 解析景区ID - if scenicSpotIds: - spot_records = db_service.get_scenic_spots_by_ids(scenicSpotIds) - scenic_spots = [record['name'] for record in spot_records] - - # 解析产品ID - if productIds: - product_records = db_service.get_products_by_ids(productIds) - products = [record['name'] for record in product_records] - + styles, audiences, scenic_spots, products, _ = _resolve_ids_to_names_with_mapping( + db_service, styleIds, audienceIds, scenicSpotIds, productIds + ) return styles, audiences, scenic_spots, products +def _add_ids_to_topics(topics: List[Dict[str, Any]], id_name_mappings: Dict[str, Dict[str, int]]) -> List[Dict[str, Any]]: + """ + 为每个topic添加对应的ID字段 + + Args: + topics: 生成的选题列表 + id_name_mappings: 名称到ID的映射字典 + + Returns: + 包含ID字段的选题列表 + """ + enriched_topics = [] + + for topic in topics: + # 复制原topic + enriched_topic = topic.copy() + + # 添加ID字段 + enriched_topic['styleIds'] = [] + enriched_topic['audienceIds'] = [] + enriched_topic['scenicSpotIds'] = [] + enriched_topic['productIds'] = [] + + # 根据topic中的name查找对应的ID + if 'style' in topic and topic['style']: + style_name = topic['style'] + if style_name in id_name_mappings['style_mapping']: + enriched_topic['styleIds'] = [id_name_mappings['style_mapping'][style_name]] + + if 'targetAudience' in topic and topic['targetAudience']: + audience_name = topic['targetAudience'] + if audience_name in id_name_mappings['audience_mapping']: + enriched_topic['audienceIds'] = [id_name_mappings['audience_mapping'][audience_name]] + + if 'object' in topic and topic['object']: + spot_name = topic['object'] + if spot_name in id_name_mappings['scenic_spot_mapping']: + enriched_topic['scenicSpotIds'] = [id_name_mappings['scenic_spot_mapping'][spot_name]] + + if 'product' in topic and topic['product']: + product_name = topic['product'] + if product_name in id_name_mappings['product_mapping']: + enriched_topic['productIds'] = [id_name_mappings['product_mapping'][product_name]] + + enriched_topics.append(enriched_topic) + + return enriched_topics + + @router.post("/topics", response_model=TopicResponse, summary="生成选题") async def generate_topics( request: TopicRequest, @@ -121,8 +213,8 @@ async def generate_topics( - **productIds**: 产品ID列表 """ try: - # 将ID转换为名称 - styles, audiences, scenic_spots, products = _resolve_ids_to_names( + # 将ID转换为名称,并获取映射关系 + styles, audiences, scenic_spots, products, id_name_mappings = _resolve_ids_to_names_with_mapping( db_service, request.styleIds, request.audienceIds, @@ -139,9 +231,12 @@ async def generate_topics( products=products ) + # 为topics添加ID字段 + enriched_topics = _add_ids_to_topics(topics, id_name_mappings) + return TopicResponse( requestId=request_id, - topics=topics + topics=enriched_topics ) except Exception as e: logger.error(f"生成选题失败: {e}", exc_info=True) diff --git a/api/services/__pycache__/database_service.cpython-312.pyc b/api/services/__pycache__/database_service.cpython-312.pyc index 207d00b..7f6e3a4 100644 Binary files a/api/services/__pycache__/database_service.cpython-312.pyc and b/api/services/__pycache__/database_service.cpython-312.pyc differ diff --git a/api/services/database_service.py b/api/services/database_service.py index 16039d2..c13301d 100644 --- a/api/services/database_service.py +++ b/api/services/database_service.py @@ -505,4 +505,146 @@ class DatabaseService: Returns: 数据库是否可用 """ - return self.db_pool is not None \ No newline at end of file + return self.db_pool is not None + + # 名称到ID的反向查询方法 + + def get_style_id_by_name(self, style_name: str) -> Optional[int]: + """ + 根据风格名称获取风格ID + + Args: + style_name: 风格名称 + + Returns: + 风格ID,如果未找到则返回None + """ + if not self.db_pool: + logger.error("数据库连接池未初始化") + return None + + try: + conn = self.db_pool.get_connection() + cursor = conn.cursor(dictionary=True) + cursor.execute( + "SELECT id FROM style WHERE styleName = %s AND isDelete = 0", + (style_name,) + ) + result = cursor.fetchone() + cursor.close() + conn.close() + + if result: + return result['id'] + else: + logger.warning(f"未找到风格: {style_name}") + return None + + except Exception as e: + logger.error(f"查询风格ID失败: {e}") + return None + + def get_audience_id_by_name(self, audience_name: str) -> Optional[int]: + """ + 根据受众名称获取受众ID + + Args: + audience_name: 受众名称 + + Returns: + 受众ID,如果未找到则返回None + """ + if not self.db_pool: + logger.error("数据库连接池未初始化") + return None + + try: + conn = self.db_pool.get_connection() + cursor = conn.cursor(dictionary=True) + cursor.execute( + "SELECT id FROM targetAudience WHERE audienceName = %s AND isDelete = 0", + (audience_name,) + ) + result = cursor.fetchone() + cursor.close() + conn.close() + + if result: + return result['id'] + else: + logger.warning(f"未找到受众: {audience_name}") + return None + + except Exception as e: + logger.error(f"查询受众ID失败: {e}") + return None + + def get_scenic_spot_id_by_name(self, spot_name: str) -> Optional[int]: + """ + 根据景区名称获取景区ID + + Args: + spot_name: 景区名称 + + Returns: + 景区ID,如果未找到则返回None + """ + if not self.db_pool: + logger.error("数据库连接池未初始化") + return None + + try: + conn = self.db_pool.get_connection() + cursor = conn.cursor(dictionary=True) + cursor.execute( + "SELECT id FROM scenicSpot WHERE name = %s AND isDelete = 0", + (spot_name,) + ) + result = cursor.fetchone() + cursor.close() + conn.close() + + if result: + return result['id'] + else: + logger.warning(f"未找到景区: {spot_name}") + return None + + except Exception as e: + logger.error(f"查询景区ID失败: {e}") + return None + + def get_product_id_by_name(self, product_name: str) -> Optional[int]: + """ + 根据产品名称获取产品ID + + Args: + product_name: 产品名称 + + Returns: + 产品ID,如果未找到则返回None + """ + if not self.db_pool: + logger.error("数据库连接池未初始化") + return None + + try: + conn = self.db_pool.get_connection() + cursor = conn.cursor(dictionary=True) + cursor.execute( + "SELECT id FROM product WHERE productName = %s AND isDelete = 0", + (product_name,) + ) + result = cursor.fetchone() + cursor.close() + conn.close() + + if result: + return result['id'] + else: + logger.warning(f"未找到产品: {product_name}") + return None + + except Exception as e: + logger.error(f"查询产品ID失败: {e}") + return None \ No newline at end of file