194 lines
6.1 KiB
Python
194 lines
6.1 KiB
Python
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Topic ID mapping manager.

Handles mapping object names to IDs for topics, including fuzzy matching.
"""

import logging
import re  # NOTE(review): not used by the code visible in this file; confirm before removing.
from typing import Dict, Any, Optional, List

# Module-level logger used by TopicIDMappingManager.
logger = logging.getLogger(__name__)
class TopicIDMappingManager:
    """
    Topic ID mapping manager.

    Responsibilities:
    - Maintain name -> ID mappings for styles, audiences, scenic spots and
      products.
    - Support fuzzy matching via name variants (lower-cased, spaces removed,
      a common trailing suffix such as "风格" or "景区" removed).
    - Extract the IDs of objects mentioned in a topic's text.

    Attributes:
        name_to_id: Lower-cased name variant -> ``{'id', 'category',
            'original_name'}``. This table is shared by all categories, so
            when two categories register the same variant the most recently
            added entry is kept. An object's exact name is never shadowed by
            a variant derived from a different name.
        mapping_data: Category -> {original name -> ID (as ``str``)}.
    """

    # Trailing suffixes removed from a name to build an extra fuzzy variant.
    _STRIP_SUFFIXES = ('风格', '类型', '人群', '景区', '景点', '产品')

    # Topic fields whose text is searched for object names.
    _TEXT_FIELDS = ('title', 'topic', 'description', 'content', 'keywords', 'tags')

    def __init__(self):
        """Create an empty mapping manager."""
        self.name_to_id: Dict[str, Dict[str, Any]] = {}
        self.mapping_data: Dict[str, Dict[str, str]] = {
            'styles': {},
            'audiences': {},
            'scenic_spots': {},
            'products': {},
        }
        # Per-category variant index: category -> variant -> mapping info.
        # Unlike ``name_to_id``, the same variant in two categories does not
        # overwrite, so category lookups and topic matching never lose an
        # object to a cross-category name clash.
        self._variants_by_category: Dict[str, Dict[str, Dict[str, Any]]] = {
            category: {} for category in self.mapping_data
        }

    def add_objects_mapping(
        self,
        style_objects: Optional[List[Dict[str, Any]]] = None,
        audience_objects: Optional[List[Dict[str, Any]]] = None,
        scenic_spot_objects: Optional[List[Dict[str, Any]]] = None,
        product_objects: Optional[List[Dict[str, Any]]] = None
    ):
        """
        Register name -> ID mappings for the given objects.

        Each object is read with ``obj.get('name')`` and ``obj.get('id')``.
        An object is skipped when its name is missing or blank, or when its
        ID is missing, ``None`` or empty. IDs are stored as strings.

        Args:
            style_objects: Style objects.
            audience_objects: Audience objects.
            scenic_spot_objects: Scenic-spot objects.
            product_objects: Product objects.
        """
        sources = (
            ('styles', style_objects),
            ('audiences', audience_objects),
            ('scenic_spots', scenic_spot_objects),
            ('products', product_objects),
        )

        mapping_count = 0
        for category, objects in sources:
            for obj in objects or ():
                entry = self._normalize_object(obj)
                if entry is None:
                    continue
                name, obj_id = entry
                self.mapping_data[category][name] = obj_id
                self._add_name_variants(category, name, obj_id)
                mapping_count += 1

        logger.info("ID 映射关系建立完成: %d 个对象", mapping_count)

    @staticmethod
    def _normalize_object(obj: Dict[str, Any]) -> Optional[tuple]:
        """Return ``(name, id)`` as strings for a usable object, else ``None``."""
        raw_name = obj.get('name')
        raw_id = obj.get('id')
        # str(None) would otherwise register the bogus ID "None".
        if raw_name is None or raw_id is None:
            return None
        # Non-string names (e.g. numbers) would crash later on .lower().
        name = raw_name if isinstance(raw_name, str) else str(raw_name)
        obj_id = str(raw_id)
        # A whitespace-only name would produce a key matching any spaced text.
        if not name.strip() or not obj_id:
            return None
        return name, obj_id

    def _add_name_variants(self, category: str, name: str, obj_id: str):
        """Register every fuzzy-match variant of *name* for *obj_id*.

        Variants are stored lower-cased in ``name_to_id`` and in the
        per-category index.
        """
        category_index = self._variants_by_category.setdefault(category, {})
        for key in self._name_variants(name):
            info = {'id': obj_id, 'category': category, 'original_name': name}
            self._store_variant(self.name_to_id, key, info)
            self._store_variant(category_index, key, dict(info))

    @classmethod
    def _name_variants(cls, name: str) -> List[str]:
        """Return the distinct, non-empty, lower-cased match keys for *name*.

        The keys are the name itself, the name with spaces removed, and the
        name with a trailing suffix from ``_STRIP_SUFFIXES`` removed
        (whitespace-trimmed). Order is deterministic, exact name first.
        """
        lowered = name.lower()
        candidates = [lowered, lowered.replace(' ', '')]
        for suffix in cls._STRIP_SUFFIXES:
            if name.endswith(suffix):
                candidates.append(name[:-len(suffix)].strip().lower())
        # dict.fromkeys de-duplicates while keeping insertion order.
        return [key for key in dict.fromkeys(candidates) if key]

    @staticmethod
    def _store_variant(index: Dict[str, Dict[str, Any]], key: str, info: Dict[str, Any]):
        """Store *info* under *key* unless that would shadow an exact name.

        A key equal to an object's own lower-cased name belongs to that
        object. A variant derived from a different name (e.g. "自然" from
        "自然风格") must not replace it. Otherwise the latest entry wins.
        """
        existing = index.get(key)
        if existing is not None:
            existing_is_exact = str(existing.get('original_name', '')).lower() == key
            new_is_exact = info['original_name'].lower() == key
            if existing_is_exact and not new_is_exact:
                return
        index[key] = info

    def find_ids_in_topic(self, topic: Dict[str, Any]) -> Dict[str, List[str]]:
        """
        Find the IDs of known objects mentioned in a topic.

        A known variant matches when it occurs as a substring of the
        lower-cased topic text. NOTE(review): plain substring matching means
        very short variants can match unrelated text.

        Args:
            topic: Topic dict. ``None`` or an empty dict yields no matches.

        Returns:
            Category -> list of matched IDs (no duplicates, in registration
            order).
        """
        found_ids: Dict[str, List[str]] = {
            'styles': [],
            'audiences': [],
            'scenic_spots': [],
            'products': [],
        }
        if not topic:
            return found_ids

        topic_text_lower = self._extract_topic_text(topic).lower()
        if not topic_text_lower:
            return found_ids

        for category, variants in self._variants_by_category.items():
            bucket = found_ids.setdefault(category, [])
            for variant, mapping_info in variants.items():
                obj_id = mapping_info['id']
                if variant in topic_text_lower and obj_id not in bucket:
                    bucket.append(obj_id)
                    logger.debug(
                        "在选题中找到 %s: %s -> %s",
                        category, mapping_info['original_name'], obj_id,
                    )

        return found_ids

    def _extract_topic_text(self, topic: Dict[str, Any]) -> str:
        """Join the text of the topic's known text fields with single spaces.

        String fields are used as-is. Truthy items of list or tuple fields
        are converted with ``str()``. Other values are ignored.
        """
        if not topic:
            return ''

        text_parts: List[str] = []
        for field in self._TEXT_FIELDS:
            value = topic.get(field)
            if isinstance(value, str):
                text_parts.append(value)
            elif isinstance(value, (list, tuple)):
                text_parts.extend(str(v) for v in value if v)

        return ' '.join(text_parts)

    def get_id_by_name(self, name: str, category: Optional[str] = None) -> Optional[str]:
        """
        Look up an ID by name (case-insensitive) or by any registered variant.

        Args:
            name: Name or name variant. ``None`` returns ``None``.
            category: When given, only that category is searched. This lookup
                is unaffected by identical names in other categories.

        Returns:
            The ID, or ``None`` if there is no match.
        """
        if name is None:
            return None
        key = str(name).lower()

        if category is None:
            mapping = self.name_to_id.get(key)
        else:
            mapping = self._variants_by_category.get(category, {}).get(key)

        return mapping['id'] if mapping is not None else None

    def clear(self):
        """Remove all mappings (name variants, per-category data and indexes)."""
        self.name_to_id.clear()
        for names in self.mapping_data.values():
            names.clear()
        for index in self._variants_by_category.values():
            index.clear()