316 lines
12 KiB
Python
316 lines
12 KiB
Python
import json
import os
|
||
|
||
def read_json_file(json_path):
    """Read and parse a UTF-8 JSON file.

    Args:
        json_path: Path to the JSON file to read.

    Returns:
        The parsed JSON object, or None when the file is missing,
        contains invalid JSON, or any other read error occurs.
    """
    # NOTE: `import json` used to live inside the `try` body; it is hoisted
    # to the top of the file so the `except json.JSONDecodeError` clause can
    # never itself fail on an unbound name.
    try:
        with open(json_path, 'r', encoding='utf-8') as file:
            data = json.load(file)
        print(f"成功读取JSON文件: {json_path}")
        return data
    except FileNotFoundError:
        print(f"错误: 找不到文件 {json_path}")
        return None
    except json.JSONDecodeError as e:
        print(f"JSON解析错误: {e}")
        return None
    except Exception as e:
        print(f"读取JSON文件时出错: {e}")
        return None
|
||
|
||
def calculate_text_similarity(text1, text2):
    """Compute the Jaccard similarity between the character sets of two texts.

    Args:
        text1: First text.
        text2: Second text.

    Returns:
        float: Similarity in [0, 1]; 0.0 for empty/whitespace-only input,
        1.0 when both stripped texts are identical.
    """
    # Guard: either text missing or empty
    if not text1 or not text2:
        return 0.0

    # Trim surrounding whitespace before comparing
    text1, text2 = text1.strip(), text2.strip()
    if not text1 or not text2:
        return 0.0

    # Identical texts are trivially a perfect match
    if text1 == text2:
        return 1.0

    # Jaccard similarity over character sets: |A ∩ B| / |A ∪ B|
    chars_a, chars_b = set(text1), set(text2)
    overlap = len(chars_a & chars_b)
    total = len(chars_a | chars_b)

    return overlap / total if total > 0 else 0.0
|
||
|
||
def calculate_iou(box1, box2):
    """
    Compute the IoU (Intersection over Union) of two bounding boxes.

    Args:
        box1: Either an axis-aligned box [x1, y1, x2, y2] or a 4-point
              polygon [[x1,y1], [x2,y2], [x3,y3], [x4,y4]]; polygon points
              may be lists or tuples (tuples are a backward-compatible
              generalization — they previously raised ValueError).
        box2: Same accepted formats as box1.

    Returns:
        float: IoU value in [0, 1].

    Raises:
        ValueError: when either box matches neither supported format.
    """
    # Normalize box1 to axis-aligned corners (x1_1, y1_1, x2_1, y2_1)
    if len(box1) == 4 and isinstance(box1[0], (int, float)):
        # Format: [x1, y1, x2, y2]
        x1_1, y1_1, x2_1, y2_1 = box1
    elif len(box1) == 4 and isinstance(box1[0], (list, tuple)):
        # 4-point polygon: use the min/max coordinates as the bounding box
        x_coords = [point[0] for point in box1]
        y_coords = [point[1] for point in box1]
        x1_1, x2_1 = min(x_coords), max(x_coords)
        y1_1, y2_1 = min(y_coords), max(y_coords)
    else:
        raise ValueError("box1格式错误,应为[x1,y1,x2,y2]或[[x1,y1],[x2,y2],[x3,y3],[x4,y4]]")

    # Normalize box2 the same way
    if len(box2) == 4 and isinstance(box2[0], (int, float)):
        # Format: [x1, y1, x2, y2]
        x1_2, y1_2, x2_2, y2_2 = box2
    elif len(box2) == 4 and isinstance(box2[0], (list, tuple)):
        # 4-point polygon: use the min/max coordinates as the bounding box
        x_coords = [point[0] for point in box2]
        y_coords = [point[1] for point in box2]
        x1_2, x2_2 = min(x_coords), max(x_coords)
        y1_2, y2_2 = min(y_coords), max(y_coords)
    else:
        raise ValueError("box2格式错误,应为[x1,y1,x2,y2]或[[x1,y1],[x2,y2],[x3,y3],[x4,y4]]")

    # Intersection rectangle corners
    x_left = max(x1_1, x1_2)
    y_top = max(y1_1, y1_2)
    x_right = min(x2_1, x2_2)
    y_bottom = min(y2_1, y2_2)

    # No overlap at all
    if x_right < x_left or y_bottom < y_top:
        return 0.0

    # Intersection area
    intersection_area = (x_right - x_left) * (y_bottom - y_top)

    # Union area = sum of the two areas minus the shared part
    box1_area = (x2_1 - x1_1) * (y2_1 - y1_1)
    box2_area = (x2_2 - x1_2) * (y2_2 - y1_2)
    union_area = box1_area + box2_area - intersection_area

    # Guard against degenerate zero-area boxes
    iou = intersection_area / union_area if union_area > 0 else 0.0

    return iou
|
||
|
||
def format_ocr_json(ocr_data):
|
||
"""格式化OCR字幕转文字JSON数据"""
|
||
if not ocr_data:
|
||
return "", []
|
||
|
||
formatted_text = "【OCR字幕识别内容】\n"
|
||
|
||
# 如果是字幕提取器的格式
|
||
if isinstance(ocr_data, dict):
|
||
# 基本信息
|
||
if 'ocr_engine' in ocr_data:
|
||
formatted_text += f"OCR引擎: {ocr_data['ocr_engine']}\n"
|
||
|
||
if 'video_path' in ocr_data:
|
||
formatted_text += f"视频文件: {ocr_data['video_path']}\n"
|
||
|
||
if 'duration' in ocr_data:
|
||
formatted_text += f"视频时长: {ocr_data['duration']:.2f}秒\n"
|
||
|
||
if 'fps' in ocr_data:
|
||
formatted_text += f"视频帧率: {ocr_data['fps']:.2f}FPS\n"
|
||
|
||
if 'frame_width' in ocr_data and 'frame_height' in ocr_data:
|
||
formatted_text += f"视频分辨率: {ocr_data['frame_width']}x{ocr_data['frame_height']}\n"
|
||
|
||
# 字幕区域信息
|
||
if 'subtitle_position' in ocr_data:
|
||
formatted_text += f"字幕区域: {ocr_data['subtitle_position']}\n"
|
||
|
||
if 'subtitle_region' in ocr_data:
|
||
region = ocr_data['subtitle_region']
|
||
formatted_text += f"字幕区域坐标: {region}\n"
|
||
|
||
# 处理参数
|
||
if 'sample_interval' in ocr_data:
|
||
formatted_text += f"采样间隔: {ocr_data['sample_interval']}帧\n"
|
||
|
||
if 'confidence_threshold' in ocr_data:
|
||
formatted_text += f"置信度阈值: {ocr_data['confidence_threshold']}\n"
|
||
|
||
# 完整字幕文本
|
||
if 'continuous_text' in ocr_data:
|
||
formatted_text += f"\n📄 完整字幕文本:\n"
|
||
formatted_text += f"{ocr_data['continuous_text']}\n"
|
||
|
||
# 详细字幕时间轴 - 按三层嵌套数组结构组织
|
||
if 'subtitles' in ocr_data and len(ocr_data['subtitles']) > 0:
|
||
subtitles = ocr_data['subtitles']
|
||
|
||
# 按时间戳分组存储
|
||
timestamp_groups = {}
|
||
for subtitle in subtitles:
|
||
timestamp = subtitle.get('timestamp', 0)
|
||
text = subtitle.get('text', '')
|
||
confidence = subtitle.get('confidence', 0)
|
||
engine = subtitle.get('engine', 'Unknown')
|
||
bbox = subtitle.get('bbox', [])
|
||
|
||
if timestamp not in timestamp_groups:
|
||
timestamp_groups[timestamp] = []
|
||
|
||
# 第三层:内容和位置
|
||
subtitle_content = {
|
||
'text': text,
|
||
'bbox': bbox,
|
||
"timestamp": timestamp
|
||
}
|
||
|
||
timestamp_groups[timestamp].append(subtitle_content)
|
||
|
||
# 转换为三层嵌套数组结构
|
||
subtitle_array = []
|
||
sorted_timestamps = sorted(timestamp_groups.keys())
|
||
|
||
for timestamp in sorted_timestamps:
|
||
# 第一层:时间戳
|
||
timestamp_entry = {
|
||
'timestamp': timestamp,
|
||
'contents': timestamp_groups[timestamp] # 第二层:同一时间戳内的各个内容
|
||
}
|
||
subtitle_array.append(timestamp_entry)
|
||
|
||
# 显示三层嵌套数组结构
|
||
formatted_text += f"\n⏰ 详细字幕时间轴 (三层嵌套数组结构):\n"
|
||
|
||
# 只显示前10个时间戳,避免过长
|
||
display_count = min(10, len(subtitle_array))
|
||
for i, timestamp_entry in enumerate(subtitle_array[:display_count], 1):
|
||
timestamp = timestamp_entry['timestamp']
|
||
contents = timestamp_entry['contents']
|
||
|
||
formatted_text += f" {i}. {timestamp:.2f}s:\n"
|
||
|
||
# 显示该时间戳下的所有字幕(第二层)
|
||
for j, content in enumerate(contents, 1):
|
||
text = content['text']
|
||
bbox = content['bbox']
|
||
|
||
formatted_text += f" {j}. [{timestamp:.2f}s|{confidence:.3f}]: {text}\n"
|
||
|
||
# 如果有位置信息,显示bbox(第三层)
|
||
if bbox:
|
||
formatted_text += f" 位置: {bbox}\n"
|
||
|
||
formatted_text += "\n"
|
||
|
||
if len(subtitle_array) > display_count:
|
||
formatted_text += f" ... (还有{len(subtitle_array) - display_count}个时间戳)\n"
|
||
|
||
# 返回三层嵌套数组结构
|
||
return formatted_text, subtitle_array
|
||
|
||
return formatted_text, []
|
||
|
||
def merge_and_filter_subtitles(subtitle_array, iou_threshold=0.7, text_similarity_threshold=0.7):
    """Merge duplicate subtitle contents and drop the resulting empties.

    A later content is treated as a duplicate of an earlier one when both
    the bounding-box IoU and the text similarity exceed their thresholds.
    The duplicate's text is blanked, and the earlier content's 'timestamp'
    field (used as an end time) is advanced by one per duplicate found.

    Args:
        subtitle_array: three-level nested structure from format_ocr_json().
        iou_threshold: minimum IoU for two boxes to count as the same region.
        text_similarity_threshold: minimum Jaccard text similarity.

    Returns:
        tuple[str, list]: (human-readable summary, filtered subtitle array).
    """
    # Deep-copy so the caller's structure is never mutated in place.
    import copy
    subtitle_array = copy.deepcopy(subtitle_array)
    formatted_text = []

    for i in range(len(subtitle_array)):
        for j in range(len(subtitle_array[i]["contents"])):
            # The base content's text/bbox are never changed inside the
            # inner loops (only its 'timestamp' is), so look them up once.
            text = subtitle_array[i]["contents"][j]["text"]
            bbox = subtitle_array[i]["contents"][j]["bbox"]
            # Compare only against strictly later timestamp entries; the
            # range bound already keeps i + k inside the array, so no
            # extra bounds check is needed.
            for k in range(1, len(subtitle_array) - i):
                for l in range(len(subtitle_array[i + k]["contents"])):
                    text_1 = subtitle_array[i + k]["contents"][l]["text"]
                    bbox_1 = subtitle_array[i + k]["contents"][l]["bbox"]

                    iou = calculate_iou(bbox, bbox_1)
                    text_similarity = calculate_text_similarity(text, text_1)

                    if iou > iou_threshold and text_similarity > text_similarity_threshold:
                        # Blank the duplicate; extend the original's end time.
                        subtitle_array[i + k]["contents"][l]["text"] = ''
                        subtitle_array[i]["contents"][j]["timestamp"] += 1

    # Drop contents whose text was blanked above.
    for i in range(len(subtitle_array)):
        subtitle_array[i]["contents"] = [content for content in subtitle_array[i]["contents"] if content["text"] != '']

    # Drop timestamp entries left with no contents.
    subtitle_array = [entry for entry in subtitle_array if len(entry["contents"]) > 0]

    # Render the surviving entries as a readable timeline.
    for i, timestamp_entry in enumerate(subtitle_array, 1):
        formatted_text.append(f"\n开始时间 {timestamp_entry['timestamp']:.2f}s:")
        for j, content in enumerate(timestamp_entry['contents'], 1):
            formatted_text.append(f" {j}. 文本: '{content['text']}'")
            if content['bbox']:
                formatted_text.append(f" 位置: {content['bbox']}")
            if 'timestamp' in content and content['timestamp']:
                formatted_text.append(f" 结束时间: {content['timestamp']:.2f}s")

    return '\n'.join(formatted_text), subtitle_array
|
||
|
||
|
||
ocr_json_path = "/root/autodl-tmp/video_cnocr/中国国旅_subtitles.json"
|
||
|
||
ocr_data = read_json_file(ocr_json_path)
|
||
pre_data , subtitle_array= format_ocr_json(ocr_data)
|
||
|
||
iou_threshold = 0.7
|
||
text_similarity_threshold = 0.7
|
||
a , b = merge_and_filter_subtitles(subtitle_array, iou_threshold, text_similarity_threshold)
|
||
#print("\n完整数组结构:")
|
||
print(a)
|
||
print(b)
|
||
|
||
# 保存输出结果到txt文件
|
||
|
||
output_dir = os.path.dirname(ocr_json_path)
|
||
output_filename = os.path.splitext(os.path.basename(ocr_json_path))[0] + "_processed.txt"
|
||
output_path = os.path.join(output_dir, output_filename)
|
||
|
||
try:
|
||
with open(output_path, 'w', encoding='utf-8') as f:
|
||
f.write(a)
|
||
print(f"\n处理结果已保存到: {output_path}")
|
||
except Exception as e:
|
||
print(f"保存文件时出错: {e}")
|
||
|
||
#验证 "/root/autodl-tmp/douyin_ocr/兰州_subtitles.json" 里面的重复的两个内容,确实是bbox不重叠
|
||
# a = [[303, 243], [442, 243], [442, 303], [303, 303]]
|
||
# b = [[339, 231], [495, 241], [490, 304], [335, 294]]
|
||
# c = [[482, 273], [660, 276], [660, 303], [481, 300]]
|
||
# d = [[536, 268], [732, 273], [731, 300], [535, 295]]
|
||
|
||
# iou = calculate_iou(a,b) # 0.47
|
||
# d = calculate_iou(c,d) # 0.40 |