# video_template_gen/code/batch_predata.py
#
# 135 lines
# 3.9 KiB
# Python

import os
import glob
from pathlib import Path
from pre_data_1 import read_json_file, format_ocr_json, merge_and_filter_subtitles
def find_ocr_json_files(base_dir):
    """Collect the OCR JSON files located below ``base_dir``.

    Two kinds of files are gathered:

    * every ``*.json`` file sitting directly inside a directory named
      ``ocr`` found anywhere below ``base_dir``;
    * every file whose name ends with ``subtitles.json`` anywhere below
      ``base_dir``.

    Args:
        base_dir: Root directory to search (string or path-like).

    Returns:
        list: Unique ``pathlib.Path`` objects in sorted order, so repeated
        runs enumerate and process the files in the same sequence.
    """
    base_path = Path(base_dir)
    # A set removes files matched by both rules (e.g. ocr/x_subtitles.json).
    found = set()

    # Rule 1: JSON files directly inside any directory named "ocr".
    for ocr_dir in base_path.rglob("ocr"):
        if ocr_dir.is_dir():
            found.update(p for p in ocr_dir.glob("*.json") if p.is_file())

    # Rule 2: files named "*subtitles.json" at any depth.
    found.update(
        p for p in base_path.rglob("*subtitles.json") if p.is_file()
    )

    # Sorting gives a deterministic order; bare set iteration does not.
    return sorted(found)
def process_ocr_file(ocr_json_path, iou_threshold=0.7, text_similarity_threshold=0.7):
    """Run the subtitle clean-up pipeline on a single OCR JSON file.

    The file is loaded with ``read_json_file``, converted with
    ``format_ocr_json``, and the resulting subtitle list is handed to
    ``merge_and_filter_subtitles``. The text returned by that last step is
    written to ``<stem>_processed.txt`` in the same directory as the input.

    Args:
        ocr_json_path: Path of the OCR JSON file to process.
        iou_threshold: IoU threshold forwarded to
            ``merge_and_filter_subtitles``.
        text_similarity_threshold: Text-similarity threshold forwarded to
            ``merge_and_filter_subtitles``.

    Returns:
        bool: ``True`` when the output file was written; ``False`` when the
        file could not be read, produced no subtitles, or any exception was
        raised (the error is reported on stdout).
    """
    try:
        print(f"\n正在处理文件: {ocr_json_path}")

        # Step 1: load the raw OCR data.
        raw = read_json_file(ocr_json_path)
        if raw is None:
            print(f"跳过文件 {ocr_json_path} - 读取失败")
            return False

        # Step 2: format it; only the subtitle list is used from here on.
        _formatted, subtitles = format_ocr_json(raw)
        if not subtitles:
            print(f"跳过文件 {ocr_json_path} - 没有有效的字幕数据")
            return False

        # Step 3: merge and filter the subtitles.
        merged_text, merged_items = merge_and_filter_subtitles(
            subtitles, iou_threshold, text_similarity_threshold
        )

        # Step 4: write the result next to the input file.
        source = Path(ocr_json_path)
        target = source.parent / (source.stem + "_processed.txt")
        with open(target, 'w', encoding='utf-8') as out:
            out.write(merged_text)

        print(f"处理完成: {target}")
        print(f"原始字幕数量: {len(subtitles)}")
        print(f"处理后字幕数量: {len(merged_items)}")
        return True
    except Exception as e:
        # Per-file boundary: one bad file must not abort the whole batch.
        print(f"处理文件 {ocr_json_path} 时出错: {str(e)}")
        return False
def main(base_dir="/root/autodl-tmp/video_processed2", iou_threshold=0.7,
         text_similarity_threshold=0.7):
    """Batch-process every OCR JSON file found below ``base_dir``.

    Files are discovered with ``find_ocr_json_files`` and each one is passed
    to ``process_ocr_file``; progress and a final summary are printed to
    stdout.

    Args:
        base_dir: Root directory to scan. Defaults to the location that was
            previously hard-coded, so ``main()`` behaves as before.
        iou_threshold: IoU threshold forwarded to ``process_ocr_file``.
        text_similarity_threshold: Text-similarity threshold forwarded to
            ``process_ocr_file``.

    Returns:
        None.
    """
    print(f"开始在目录 {base_dir} 中查找OCR JSON文件...")

    # Discover the files to process.
    ocr_files = find_ocr_json_files(base_dir)
    if not ocr_files:
        print("未找到任何OCR JSON文件")
        return

    total = len(ocr_files)
    print(f"找到 {total} 个OCR JSON文件:")
    for i, file_path in enumerate(ocr_files, 1):
        print(f"  {i}. {file_path}")

    print("\n开始批量处理...")
    print(f"IoU阈值: {iou_threshold}")
    print(f"文本相似度阈值: {text_similarity_threshold}")

    # Process the files one by one, counting successes.
    success_count = 0
    for i, ocr_file in enumerate(ocr_files, 1):
        print(f"\n进度: {i}/{total}")
        if process_ocr_file(ocr_file, iou_threshold, text_similarity_threshold):
            success_count += 1
    failed_count = total - success_count

    # Summary; total is non-zero here because of the early return above.
    print("\n批量处理完成!")
    print(f"总文件数: {total}")
    print(f"成功处理: {success_count}")
    print(f"处理失败: {failed_count}")
    print(f"成功率: {success_count/total*100:.1f}%")
# Script entry point: run the batch job only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()