194 lines
6.1 KiB
Python

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
选题 ID 映射管理器
处理选题中的对象 ID 映射和模糊匹配
"""
import logging
import re
from typing import Dict, Any, Optional, List
logger = logging.getLogger(__name__)


class TopicIDMappingManager:
    """Map object names to IDs and locate those IDs inside topic text.

    Responsibilities:
    - keep a name -> ID mapping for every object category
    - register lookup variants of each name to support fuzzy matching
    - extract the IDs of the objects mentioned in a topic's text fields

    Attributes:
        name_to_id: Lower-cased name variant -> ``{'id', 'category',
            'original_name'}``. When several objects share a variant, the
            most recently registered one wins.
        mapping_data: Category -> ``{original name: id}``.
    """

    # Categories every manager starts with (the keys of ``mapping_data``).
    _CATEGORIES = ('styles', 'audiences', 'scenic_spots', 'products')

    # Trailing words removed from a name to produce a shorter lookup variant.
    _STRIPPABLE_SUFFIXES = ('风格', '类型', '人群', '景区', '景点', '产品')

    # Topic fields whose values are scanned for object names.
    _TEXT_FIELDS = ('title', 'topic', 'description', 'content', 'keywords', 'tags')

    def __init__(self):
        """Create an empty mapping manager."""
        self.name_to_id: Dict[str, Dict[str, Any]] = {}
        # Category -> variant -> entry. ``name_to_id`` holds a single entry
        # per variant, so an object whose name also exists in another
        # category would otherwise become unreachable.
        self._variants_by_category: Dict[str, Dict[str, Dict[str, Any]]] = {}
        self.mapping_data: Dict[str, Dict[str, str]] = {
            category: {} for category in self._CATEGORIES
        }

    def add_objects_mapping(
        self,
        style_objects: Optional[List[Dict[str, Any]]] = None,
        audience_objects: Optional[List[Dict[str, Any]]] = None,
        scenic_spot_objects: Optional[List[Dict[str, Any]]] = None,
        product_objects: Optional[List[Dict[str, Any]]] = None
    ):
        """Register objects so their names can be resolved to IDs.

        Each object is a dict read through its ``'name'`` and ``'id'`` keys;
        objects lacking either one are skipped.

        Args:
            style_objects: Style objects.
            audience_objects: Audience objects.
            scenic_spot_objects: Scenic spot objects.
            product_objects: Product objects.
        """
        grouped = (
            ('styles', style_objects),
            ('audiences', audience_objects),
            ('scenic_spots', scenic_spot_objects),
            ('products', product_objects),
        )
        mapping_count = sum(
            self._register_objects(category, objects)
            for category, objects in grouped
        )
        logger.info("ID 映射关系建立完成: %s 个对象", mapping_count)

    def _register_objects(
        self,
        category: str,
        objects: Optional[List[Dict[str, Any]]]
    ) -> int:
        """Register every usable object of one category; return how many."""
        registered = 0
        for obj in objects or ():
            raw_name = obj.get('name')
            raw_id = obj.get('id')
            # An explicit ``None`` id must be rejected before stringifying:
            # str(None) == 'None' is truthy and would be stored as a real ID.
            if not raw_name or raw_id is None:
                continue
            name, obj_id = str(raw_name), str(raw_id)
            if not obj_id:
                continue
            self.mapping_data.setdefault(category, {})[name] = obj_id
            self._add_name_variants(category, name, obj_id)
            registered += 1
        return registered

    def _add_name_variants(self, category: str, name: str, obj_id: str):
        """Register ``name`` and its fuzzy-matching variants for ``obj_id``.

        Variants are the name itself, the name without spaces, and - when
        the name ends with one of the known generic suffixes - the name
        without that suffix. Keys are stored lower-cased.

        Args:
            category: Category the object belongs to.
            name: Original object name.
            obj_id: Object ID as a string.
        """
        variants = {name, name.replace(' ', '')}
        for suffix in self._STRIPPABLE_SUFFIXES:
            if name.endswith(suffix):
                raw_stem = name[:-len(suffix)]
                # "Modern 风格" leaves "Modern " behind; without stripping,
                # the stem would never match the word "modern" on its own.
                stem = raw_stem.strip()
                variants.update((raw_stem, stem, stem.replace(' ', '')))
        # NOTE(review): a very short stem (e.g. one character) matches almost
        # any topic text as a substring - consider a minimum length.

        per_category = self._variants_by_category.setdefault(category, {})
        for variant in variants:
            key = variant.lower()
            if not key:
                continue
            entry = {
                'id': obj_id,
                'category': category,
                'original_name': name
            }
            # The same entry object goes into both indexes so shadowed
            # entries can later be recognised by identity.
            self.name_to_id[key] = entry
            per_category[key] = entry

    def _iter_variant_entries(self):
        """Yield ``(variant, entry)`` pairs for every registered variant.

        Entries of ``name_to_id`` come first, followed by entries that were
        overwritten there by a later object registered under the same variant.
        """
        yield from self.name_to_id.items()
        for variants in self._variants_by_category.values():
            for key, entry in variants.items():
                if self.name_to_id.get(key) is not entry:
                    yield key, entry

    def find_ids_in_topic(self, topic: Dict[str, Any]) -> Dict[str, List[str]]:
        """Find the IDs of all registered objects mentioned in a topic.

        Matching is a case-insensitive substring search of every registered
        name variant against the topic's text fields.

        Args:
            topic: Topic dict; ``None`` or an empty dict yields no matches.

        Returns:
            Category -> list of matched IDs, without duplicates.
        """
        found_ids: Dict[str, List[str]] = {
            category: [] for category in self._CATEGORIES
        }

        topic_text_lower = self._extract_topic_text(topic).lower()
        if not topic_text_lower:
            return found_ids

        for name_variant, mapping_info in self._iter_variant_entries():
            if name_variant not in topic_text_lower:
                continue
            category = mapping_info['category']
            obj_id = mapping_info['id']
            # setdefault: a category outside the built-in four may have been
            # registered directly through _add_name_variants.
            ids = found_ids.setdefault(category, [])
            if obj_id not in ids:
                ids.append(obj_id)
                logger.debug(
                    "在选题中找到 %s: %s -> %s",
                    category, mapping_info['original_name'], obj_id
                )
        return found_ids

    def _extract_topic_text(self, topic: Dict[str, Any]) -> str:
        """Join the topic's text fields into one space-separated string.

        String fields are taken as-is; list/tuple fields contribute the
        string form of each truthy item. Other value types are ignored.
        """
        if not topic:
            return ''

        text_parts: List[str] = []
        for field in self._TEXT_FIELDS:
            value = topic.get(field)
            if isinstance(value, str):
                text_parts.append(value)
            elif isinstance(value, (list, tuple)):
                text_parts.extend(str(item) for item in value if item)
        return ' '.join(text_parts)

    def get_id_by_name(self, name: str, category: Optional[str] = None) -> Optional[str]:
        """Resolve a name (or one of its registered variants) to an ID.

        Args:
            name: Name to look up; case-insensitive, and surrounding
                whitespace is ignored when the exact form is unknown.
            category: If given, only an object of this category is returned.

        Returns:
            The ID, or ``None`` when nothing matches.
        """
        if not name:
            return None

        lowered = name.lower()
        for key in (lowered, lowered.strip()):
            mapping = self.name_to_id.get(key)
            if mapping is not None and (
                category is None or mapping['category'] == category
            ):
                return mapping['id']
            if category is not None:
                # The variant may be owned by another category in
                # ``name_to_id``; fall back to the per-category index.
                scoped = self._variants_by_category.get(category, {}).get(key)
                if scoped is not None:
                    return scoped['id']
        return None

    def clear(self):
        """Remove every registered mapping."""
        self.name_to_id.clear()
        self._variants_by_category.clear()
        for ids_by_name in self.mapping_data.values():
            ids_by_name.clear()