hot_video_analyse/code/pre_data_1.py

316 lines
12 KiB
Python
Raw Normal View History

import os
def read_json_file(json_path):
"""读取JSON文件内容"""
try:
import json
with open(json_path, 'r', encoding='utf-8') as file:
data = json.load(file)
print(f"成功读取JSON文件: {json_path}")
return data
except FileNotFoundError:
print(f"错误: 找不到文件 {json_path}")
return None
except json.JSONDecodeError as e:
print(f"JSON解析错误: {e}")
return None
except Exception as e:
print(f"读取JSON文件时出错: {e}")
return None
def calculate_text_similarity(text1, text2):
"""
计算两个文本的相似度使用Jaccard相似度
Args:
text1: 第一个文本
text2: 第二个文本
Returns:
float: 相似度值 (0-1之间)
"""
# 检查空文本
if not text1 or not text2:
return 0.0
# 清理文本,移除空白字符
text1 = text1.strip()
text2 = text2.strip()
if not text1 or not text2:
return 0.0
# 如果两个文本完全相同
if text1 == text2:
return 1.0
# 将文本转换为字符集合
chars1 = set(text1)
chars2 = set(text2)
# 计算Jaccard相似度
intersection = len(chars1.intersection(chars2))
union = len(chars1.union(chars2))
similarity = intersection / union if union > 0 else 0.0
return similarity
def calculate_iou(box1, box2):
"""
计算两个边界框的IoU (Intersection over Union)
Args:
box1: 第一个边界框 [x1, y1, x2, y2] [[x1,y1], [x2,y2], [x3,y3], [x4,y4]]
box2: 第二个边界框 [x1, y1, x2, y2] [[x1,y1], [x2,y2], [x3,y3], [x4,y4]]
Returns:
float: IoU值 (0-1之间)
"""
# 处理不同的输入格式
if len(box1) == 4 and isinstance(box1[0], (int, float)):
# 格式: [x1, y1, x2, y2]
x1_1, y1_1, x2_1, y2_1 = box1
elif len(box1) == 4 and isinstance(box1[0], list):
# 格式: [[x1,y1], [x2,y2], [x3,y3], [x4,y4]] - 取最小和最大坐标
x_coords = [point[0] for point in box1]
y_coords = [point[1] for point in box1]
x1_1, x2_1 = min(x_coords), max(x_coords)
y1_1, y2_1 = min(y_coords), max(y_coords)
else:
raise ValueError("box1格式错误应为[x1,y1,x2,y2]或[[x1,y1],[x2,y2],[x3,y3],[x4,y4]]")
if len(box2) == 4 and isinstance(box2[0], (int, float)):
# 格式: [x1, y1, x2, y2]
x1_2, y1_2, x2_2, y2_2 = box2
elif len(box2) == 4 and isinstance(box2[0], list):
# 格式: [[x1,y1], [x2,y2], [x3,y3], [x4,y4]] - 取最小和最大坐标
x_coords = [point[0] for point in box2]
y_coords = [point[1] for point in box2]
x1_2, x2_2 = min(x_coords), max(x_coords)
y1_2, y2_2 = min(y_coords), max(y_coords)
else:
raise ValueError("box2格式错误应为[x1,y1,x2,y2]或[[x1,y1],[x2,y2],[x3,y3],[x4,y4]]")
# 计算交集区域
x_left = max(x1_1, x1_2)
y_top = max(y1_1, y1_2)
x_right = min(x2_1, x2_2)
y_bottom = min(y2_1, y2_2)
# 检查是否有交集
if x_right < x_left or y_bottom < y_top:
return 0.0
# 计算交集面积
intersection_area = (x_right - x_left) * (y_bottom - y_top)
# 计算并集面积
box1_area = (x2_1 - x1_1) * (y2_1 - y1_1)
box2_area = (x2_2 - x1_2) * (y2_2 - y1_2)
union_area = box1_area + box2_area - intersection_area
# 计算IoU
iou = intersection_area / union_area if union_area > 0 else 0.0
return iou
def format_ocr_json(ocr_data):
"""格式化OCR字幕转文字JSON数据"""
if not ocr_data:
return "", []
formatted_text = "【OCR字幕识别内容】\n"
# 如果是字幕提取器的格式
if isinstance(ocr_data, dict):
# 基本信息
if 'ocr_engine' in ocr_data:
formatted_text += f"OCR引擎: {ocr_data['ocr_engine']}\n"
if 'video_path' in ocr_data:
formatted_text += f"视频文件: {ocr_data['video_path']}\n"
if 'duration' in ocr_data:
formatted_text += f"视频时长: {ocr_data['duration']:.2f}\n"
if 'fps' in ocr_data:
formatted_text += f"视频帧率: {ocr_data['fps']:.2f}FPS\n"
if 'frame_width' in ocr_data and 'frame_height' in ocr_data:
formatted_text += f"视频分辨率: {ocr_data['frame_width']}x{ocr_data['frame_height']}\n"
# 字幕区域信息
if 'subtitle_position' in ocr_data:
formatted_text += f"字幕区域: {ocr_data['subtitle_position']}\n"
if 'subtitle_region' in ocr_data:
region = ocr_data['subtitle_region']
formatted_text += f"字幕区域坐标: {region}\n"
# 处理参数
if 'sample_interval' in ocr_data:
formatted_text += f"采样间隔: {ocr_data['sample_interval']}\n"
if 'confidence_threshold' in ocr_data:
formatted_text += f"置信度阈值: {ocr_data['confidence_threshold']}\n"
# 完整字幕文本
if 'continuous_text' in ocr_data:
formatted_text += f"\n📄 完整字幕文本:\n"
formatted_text += f"{ocr_data['continuous_text']}\n"
# 详细字幕时间轴 - 按三层嵌套数组结构组织
if 'subtitles' in ocr_data and len(ocr_data['subtitles']) > 0:
subtitles = ocr_data['subtitles']
# 按时间戳分组存储
timestamp_groups = {}
for subtitle in subtitles:
timestamp = subtitle.get('timestamp', 0)
text = subtitle.get('text', '')
confidence = subtitle.get('confidence', 0)
engine = subtitle.get('engine', 'Unknown')
bbox = subtitle.get('bbox', [])
if timestamp not in timestamp_groups:
timestamp_groups[timestamp] = []
# 第三层:内容和位置
subtitle_content = {
'text': text,
'bbox': bbox,
"timestamp": timestamp
}
timestamp_groups[timestamp].append(subtitle_content)
# 转换为三层嵌套数组结构
subtitle_array = []
sorted_timestamps = sorted(timestamp_groups.keys())
for timestamp in sorted_timestamps:
# 第一层:时间戳
timestamp_entry = {
'timestamp': timestamp,
'contents': timestamp_groups[timestamp] # 第二层:同一时间戳内的各个内容
}
subtitle_array.append(timestamp_entry)
# 显示三层嵌套数组结构
formatted_text += f"\n⏰ 详细字幕时间轴 (三层嵌套数组结构):\n"
# 只显示前10个时间戳避免过长
display_count = min(10, len(subtitle_array))
for i, timestamp_entry in enumerate(subtitle_array[:display_count], 1):
timestamp = timestamp_entry['timestamp']
contents = timestamp_entry['contents']
formatted_text += f" {i}. {timestamp:.2f}s:\n"
# 显示该时间戳下的所有字幕(第二层)
for j, content in enumerate(contents, 1):
text = content['text']
bbox = content['bbox']
formatted_text += f" {j}. [{timestamp:.2f}s|{confidence:.3f}]: {text}\n"
# 如果有位置信息显示bbox第三层
if bbox:
formatted_text += f" 位置: {bbox}\n"
formatted_text += "\n"
if len(subtitle_array) > display_count:
formatted_text += f" ... (还有{len(subtitle_array) - display_count}个时间戳)\n"
# 返回三层嵌套数组结构
return formatted_text, subtitle_array
return formatted_text, []
def merge_and_filter_subtitles(subtitle_array, iou_threshold=0.7, text_similarity_threshold=0.7):
"""
合并并过滤字幕内容去除重复和空内容返回格式化字符串和处理后的数组
"""
# 深拷贝,避免原地修改
import copy
subtitle_array = copy.deepcopy(subtitle_array)
formatted_text = []
for i in range(len(subtitle_array)):
for j in range(len(subtitle_array[i]["contents"])):
# 修复确保i+k不会超出数组范围
for k in range(1, len(subtitle_array) - i): # 从1开始避免自己和自己比较
if i + k >= len(subtitle_array): # 安全检查
break
for l in range(len(subtitle_array[i+k]["contents"])):
text = subtitle_array[i]["contents"][j]["text"]
bbox = subtitle_array[i]["contents"][j]["bbox"]
text_1 = subtitle_array[i+k]["contents"][l]["text"]
bbox_1 = subtitle_array[i+k]["contents"][l]["bbox"]
iou = calculate_iou(bbox, bbox_1)
text_similarity = calculate_text_similarity(text, text_1)
if iou > iou_threshold and text_similarity > text_similarity_threshold:
# 记录需要删除的索引
subtitle_array[i+k]["contents"][l]["text"] = ''
subtitle_array[i]["contents"][j]["timestamp"] += 1
# 删除text为空字符串的contents
for i in range(len(subtitle_array)):
subtitle_array[i]["contents"] = [content for content in subtitle_array[i]["contents"] if content["text"] != '']
# 删除contents为空的时间戳条目
subtitle_array = [entry for entry in subtitle_array if len(entry["contents"]) > 0]
#formatted_text.append("处理后的字幕数组:")
for i, timestamp_entry in enumerate(subtitle_array[:], 1):
formatted_text.append(f"\n开始时间 {timestamp_entry['timestamp']:.2f}s:")
#formatted_text.append(f" 包含 {len(timestamp_entry['contents'])} 个字幕内容")
for j, content in enumerate(timestamp_entry['contents'], 1):
formatted_text.append(f" {j}. 文本: '{content['text']}'")
if content['bbox']:
formatted_text.append(f" 位置: {content['bbox']}")
if 'timestamp' in content and content['timestamp']:
formatted_text.append(f" 结束时间: {content['timestamp']:.2f}s")
#formatted_text.append("\n完整数组结构:")
#formatted_text.append(str(subtitle_array))
return '\n'.join(formatted_text), subtitle_array
ocr_json_path = "/root/autodl-tmp/new_cnocr/哈尔滨_subtitles.json"
ocr_data = read_json_file(ocr_json_path)
pre_data , subtitle_array= format_ocr_json(ocr_data)
iou_threshold = 0.8
text_similarity_threshold = 0.8
a , b = merge_and_filter_subtitles(subtitle_array, iou_threshold, text_similarity_threshold)
#print("\n完整数组结构:")
print(a)
print(b)
# 保存输出结果到txt文件
output_dir = os.path.dirname(ocr_json_path)
output_filename = os.path.splitext(os.path.basename(ocr_json_path))[0] + "_processed.txt"
output_path = os.path.join(output_dir, output_filename)
try:
with open(output_path, 'w', encoding='utf-8') as f:
f.write(a)
print(f"\n处理结果已保存到: {output_path}")
except Exception as e:
print(f"保存文件时出错: {e}")
#验证 "/root/autodl-tmp/douyin_ocr/兰州_subtitles.json" 里面的重复的两个内容确实是bbox不重叠
# a = [[303, 243], [442, 243], [442, 303], [303, 303]]
# b = [[339, 231], [495, 241], [490, 304], [335, 294]]
# c = [[482, 273], [660, 276], [660, 303], [481, 300]]
# d = [[536, 268], [732, 273], [731, 300], [535, 295]]
# iou = calculate_iou(a,b) # 0.47
# d = calculate_iou(c,d) # 0.40