更改了topic的返回信息
This commit is contained in:
parent
f6cff4e5c0
commit
015570b685
Binary file not shown.
@ -51,7 +51,11 @@ class TopicResponse(BaseModel):
|
||||
"styleLogic": "重点解析如何避开高温时段并高效游玩各园区",
|
||||
"targetAudience": "年轻人",
|
||||
"targetAudienceLogic": "解决家长担心孩子中暑的问题,提供科学游玩方案",
|
||||
"logic": "暑期旅游热门景点推荐"
|
||||
"logic": "暑期旅游热门景点推荐",
|
||||
"styleIds": [1],
|
||||
"audienceIds": [2],
|
||||
"scenicSpotIds": [1],
|
||||
"productIds": [1]
|
||||
}
|
||||
]
|
||||
}
|
||||
@ -106,7 +110,7 @@ class ContentResponse(BaseModel):
|
||||
"content": {
|
||||
"title": "天津冒险湾亲子游攻略",
|
||||
"content": "详细的游玩攻略内容...",
|
||||
"tags": ["亲子游", "水上乐园", "天津"]
|
||||
"tags": ["#亲子游", "#水上乐园", "#天津"]
|
||||
},
|
||||
"judgeSuccess": True
|
||||
}
|
||||
|
||||
Binary file not shown.
@ -53,13 +53,83 @@ router = APIRouter(
|
||||
)
|
||||
|
||||
|
||||
def _resolve_ids_to_names_with_mapping(db_service: DatabaseService,
|
||||
styleIds: Optional[List[int]] = None,
|
||||
audienceIds: Optional[List[int]] = None,
|
||||
scenicSpotIds: Optional[List[int]] = None,
|
||||
productIds: Optional[List[int]] = None) -> tuple:
|
||||
"""
|
||||
将ID列表转换为名称列表,并返回ID到名称的映射关系
|
||||
|
||||
Args:
|
||||
db_service: 数据库服务
|
||||
styleIds: 风格ID列表
|
||||
audienceIds: 受众ID列表
|
||||
scenicSpotIds: 景区ID列表
|
||||
productIds: 产品ID列表
|
||||
|
||||
Returns:
|
||||
(styles, audiences, scenic_spots, products, id_name_mappings) 名称列表和映射关系元组
|
||||
"""
|
||||
styles = []
|
||||
audiences = []
|
||||
scenic_spots = []
|
||||
products = []
|
||||
|
||||
# 建立ID到名称的映射字典
|
||||
id_name_mappings = {
|
||||
'style_mapping': {}, # {name: id}
|
||||
'audience_mapping': {}, # {name: id}
|
||||
'scenic_spot_mapping': {}, # {name: id}
|
||||
'product_mapping': {} # {name: id}
|
||||
}
|
||||
|
||||
# 如果数据库服务不可用,返回空列表
|
||||
if not db_service or not db_service.is_available():
|
||||
logger.warning("数据库服务不可用,无法解析ID")
|
||||
return styles, audiences, scenic_spots, products, id_name_mappings
|
||||
|
||||
# 解析风格ID
|
||||
if styleIds:
|
||||
style_records = db_service.get_styles_by_ids(styleIds)
|
||||
for record in style_records:
|
||||
style_name = record['styleName']
|
||||
styles.append(style_name)
|
||||
id_name_mappings['style_mapping'][style_name] = record['id']
|
||||
|
||||
# 解析受众ID
|
||||
if audienceIds:
|
||||
audience_records = db_service.get_audiences_by_ids(audienceIds)
|
||||
for record in audience_records:
|
||||
audience_name = record['audienceName']
|
||||
audiences.append(audience_name)
|
||||
id_name_mappings['audience_mapping'][audience_name] = record['id']
|
||||
|
||||
# 解析景区ID
|
||||
if scenicSpotIds:
|
||||
spot_records = db_service.get_scenic_spots_by_ids(scenicSpotIds)
|
||||
for record in spot_records:
|
||||
spot_name = record['name']
|
||||
scenic_spots.append(spot_name)
|
||||
id_name_mappings['scenic_spot_mapping'][spot_name] = record['id']
|
||||
|
||||
# 解析产品ID
|
||||
if productIds:
|
||||
product_records = db_service.get_products_by_ids(productIds)
|
||||
for record in product_records:
|
||||
product_name = record['name']
|
||||
products.append(product_name)
|
||||
id_name_mappings['product_mapping'][product_name] = record['id']
|
||||
|
||||
return styles, audiences, scenic_spots, products, id_name_mappings
|
||||
|
||||
def _resolve_ids_to_names(db_service: DatabaseService,
|
||||
styleIds: Optional[List[int]] = None,
|
||||
audienceIds: Optional[List[int]] = None,
|
||||
scenicSpotIds: Optional[List[int]] = None,
|
||||
productIds: Optional[List[int]] = None) -> tuple:
|
||||
"""
|
||||
将ID列表转换为名称列表
|
||||
将ID列表转换为名称列表(保持向后兼容)
|
||||
|
||||
Args:
|
||||
db_service: 数据库服务
|
||||
@ -71,39 +141,61 @@ def _resolve_ids_to_names(db_service: DatabaseService,
|
||||
Returns:
|
||||
(styles, audiences, scenic_spots, products) 名称列表元组
|
||||
"""
|
||||
styles = []
|
||||
audiences = []
|
||||
scenic_spots = []
|
||||
products = []
|
||||
|
||||
# 如果数据库服务不可用,返回空列表
|
||||
if not db_service or not db_service.is_available():
|
||||
logger.warning("数据库服务不可用,无法解析ID")
|
||||
return styles, audiences, scenic_spots, products
|
||||
|
||||
# 解析风格ID
|
||||
if styleIds:
|
||||
style_records = db_service.get_styles_by_ids(styleIds)
|
||||
styles = [record['styleName'] for record in style_records]
|
||||
|
||||
# 解析受众ID
|
||||
if audienceIds:
|
||||
audience_records = db_service.get_audiences_by_ids(audienceIds)
|
||||
audiences = [record['audienceName'] for record in audience_records]
|
||||
|
||||
# 解析景区ID
|
||||
if scenicSpotIds:
|
||||
spot_records = db_service.get_scenic_spots_by_ids(scenicSpotIds)
|
||||
scenic_spots = [record['name'] for record in spot_records]
|
||||
|
||||
# 解析产品ID
|
||||
if productIds:
|
||||
product_records = db_service.get_products_by_ids(productIds)
|
||||
products = [record['name'] for record in product_records]
|
||||
|
||||
styles, audiences, scenic_spots, products, _ = _resolve_ids_to_names_with_mapping(
|
||||
db_service, styleIds, audienceIds, scenicSpotIds, productIds
|
||||
)
|
||||
return styles, audiences, scenic_spots, products
|
||||
|
||||
|
||||
def _add_ids_to_topics(topics: List[Dict[str, Any]], id_name_mappings: Dict[str, Dict[str, int]]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
为每个topic添加对应的ID字段
|
||||
|
||||
Args:
|
||||
topics: 生成的选题列表
|
||||
id_name_mappings: 名称到ID的映射字典
|
||||
|
||||
Returns:
|
||||
包含ID字段的选题列表
|
||||
"""
|
||||
enriched_topics = []
|
||||
|
||||
for topic in topics:
|
||||
# 复制原topic
|
||||
enriched_topic = topic.copy()
|
||||
|
||||
# 添加ID字段
|
||||
enriched_topic['styleIds'] = []
|
||||
enriched_topic['audienceIds'] = []
|
||||
enriched_topic['scenicSpotIds'] = []
|
||||
enriched_topic['productIds'] = []
|
||||
|
||||
# 根据topic中的name查找对应的ID
|
||||
if 'style' in topic and topic['style']:
|
||||
style_name = topic['style']
|
||||
if style_name in id_name_mappings['style_mapping']:
|
||||
enriched_topic['styleIds'] = [id_name_mappings['style_mapping'][style_name]]
|
||||
|
||||
if 'targetAudience' in topic and topic['targetAudience']:
|
||||
audience_name = topic['targetAudience']
|
||||
if audience_name in id_name_mappings['audience_mapping']:
|
||||
enriched_topic['audienceIds'] = [id_name_mappings['audience_mapping'][audience_name]]
|
||||
|
||||
if 'object' in topic and topic['object']:
|
||||
spot_name = topic['object']
|
||||
if spot_name in id_name_mappings['scenic_spot_mapping']:
|
||||
enriched_topic['scenicSpotIds'] = [id_name_mappings['scenic_spot_mapping'][spot_name]]
|
||||
|
||||
if 'product' in topic and topic['product']:
|
||||
product_name = topic['product']
|
||||
if product_name in id_name_mappings['product_mapping']:
|
||||
enriched_topic['productIds'] = [id_name_mappings['product_mapping'][product_name]]
|
||||
|
||||
enriched_topics.append(enriched_topic)
|
||||
|
||||
return enriched_topics
|
||||
|
||||
|
||||
@router.post("/topics", response_model=TopicResponse, summary="生成选题")
|
||||
async def generate_topics(
|
||||
request: TopicRequest,
|
||||
@ -121,8 +213,8 @@ async def generate_topics(
|
||||
- **productIds**: 产品ID列表
|
||||
"""
|
||||
try:
|
||||
# 将ID转换为名称
|
||||
styles, audiences, scenic_spots, products = _resolve_ids_to_names(
|
||||
# 将ID转换为名称,并获取映射关系
|
||||
styles, audiences, scenic_spots, products, id_name_mappings = _resolve_ids_to_names_with_mapping(
|
||||
db_service,
|
||||
request.styleIds,
|
||||
request.audienceIds,
|
||||
@ -139,9 +231,12 @@ async def generate_topics(
|
||||
products=products
|
||||
)
|
||||
|
||||
# 为topics添加ID字段
|
||||
enriched_topics = _add_ids_to_topics(topics, id_name_mappings)
|
||||
|
||||
return TopicResponse(
|
||||
requestId=request_id,
|
||||
topics=topics
|
||||
topics=enriched_topics
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"生成选题失败: {e}", exc_info=True)
|
||||
|
||||
Binary file not shown.
@ -505,4 +505,146 @@ class DatabaseService:
|
||||
Returns:
|
||||
数据库是否可用
|
||||
"""
|
||||
return self.db_pool is not None
|
||||
return self.db_pool is not None
|
||||
|
||||
# 名称到ID的反向查询方法
|
||||
|
||||
def get_style_id_by_name(self, style_name: str) -> Optional[int]:
|
||||
"""
|
||||
根据风格名称获取风格ID
|
||||
|
||||
Args:
|
||||
style_name: 风格名称
|
||||
|
||||
Returns:
|
||||
风格ID,如果未找到则返回None
|
||||
"""
|
||||
if not self.db_pool:
|
||||
logger.error("数据库连接池未初始化")
|
||||
return None
|
||||
|
||||
try:
|
||||
conn = self.db_pool.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
cursor.execute(
|
||||
"SELECT id FROM style WHERE styleName = %s AND isDelete = 0",
|
||||
(style_name,)
|
||||
)
|
||||
result = cursor.fetchone()
|
||||
cursor.close()
|
||||
conn.close()
|
||||
|
||||
if result:
|
||||
return result['id']
|
||||
else:
|
||||
logger.warning(f"未找到风格: {style_name}")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"查询风格ID失败: {e}")
|
||||
return None
|
||||
|
||||
def get_audience_id_by_name(self, audience_name: str) -> Optional[int]:
|
||||
"""
|
||||
根据受众名称获取受众ID
|
||||
|
||||
Args:
|
||||
audience_name: 受众名称
|
||||
|
||||
Returns:
|
||||
受众ID,如果未找到则返回None
|
||||
"""
|
||||
if not self.db_pool:
|
||||
logger.error("数据库连接池未初始化")
|
||||
return None
|
||||
|
||||
try:
|
||||
conn = self.db_pool.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
cursor.execute(
|
||||
"SELECT id FROM targetAudience WHERE audienceName = %s AND isDelete = 0",
|
||||
(audience_name,)
|
||||
)
|
||||
result = cursor.fetchone()
|
||||
cursor.close()
|
||||
conn.close()
|
||||
|
||||
if result:
|
||||
return result['id']
|
||||
else:
|
||||
logger.warning(f"未找到受众: {audience_name}")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"查询受众ID失败: {e}")
|
||||
return None
|
||||
|
||||
def get_scenic_spot_id_by_name(self, spot_name: str) -> Optional[int]:
|
||||
"""
|
||||
根据景区名称获取景区ID
|
||||
|
||||
Args:
|
||||
spot_name: 景区名称
|
||||
|
||||
Returns:
|
||||
景区ID,如果未找到则返回None
|
||||
"""
|
||||
if not self.db_pool:
|
||||
logger.error("数据库连接池未初始化")
|
||||
return None
|
||||
|
||||
try:
|
||||
conn = self.db_pool.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
cursor.execute(
|
||||
"SELECT id FROM scenicSpot WHERE name = %s AND isDelete = 0",
|
||||
(spot_name,)
|
||||
)
|
||||
result = cursor.fetchone()
|
||||
cursor.close()
|
||||
conn.close()
|
||||
|
||||
if result:
|
||||
return result['id']
|
||||
else:
|
||||
logger.warning(f"未找到景区: {spot_name}")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"查询景区ID失败: {e}")
|
||||
return None
|
||||
|
||||
def get_product_id_by_name(self, product_name: str) -> Optional[int]:
|
||||
"""
|
||||
根据产品名称获取产品ID
|
||||
|
||||
Args:
|
||||
product_name: 产品名称
|
||||
|
||||
Returns:
|
||||
产品ID,如果未找到则返回None
|
||||
"""
|
||||
if not self.db_pool:
|
||||
logger.error("数据库连接池未初始化")
|
||||
return None
|
||||
|
||||
try:
|
||||
conn = self.db_pool.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
cursor.execute(
|
||||
"SELECT id FROM product WHERE productName = %s AND isDelete = 0",
|
||||
(product_name,)
|
||||
)
|
||||
result = cursor.fetchone()
|
||||
cursor.close()
|
||||
conn.close()
|
||||
|
||||
if result:
|
||||
return result['id']
|
||||
else:
|
||||
logger.warning(f"未找到产品: {product_name}")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"查询产品ID失败: {e}")
|
||||
return None
|
||||
Loading…
x
Reference in New Issue
Block a user