video_template_gen/code/batch_subtitle_extractor.py

401 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
批量视频字幕提取器
支持批量处理多个视频文件,提取字幕
支持PaddleOCR、EasyOCR和CnOCR三种引擎
"""
import os
import sys
import time
import json
import argparse
import re
from pathlib import Path
from datetime import datetime
import logging
# Make sibling modules (ocr_subtitle_extractor) importable regardless of the
# working directory the script is launched from.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

# OCR model locations. These are set BEFORE importing the extractor module:
# an OCR library may read its model-path variable once at import time
# (EasyOCR's config module reads EASYOCR_MODULE_PATH when imported), so if
# ocr_subtitle_extractor imports the OCR libraries at module level, setting
# the variables afterwards would have no effect.
os.environ['EASYOCR_MODULE_PATH'] = '/root/autodl-tmp/llm/easyocr'
os.environ['CNOCR_HOME'] = '/root/autodl-tmp/llm/cnocr'

from ocr_subtitle_extractor import VideoSubtitleExtractor  # noqa: E402  (needs sys.path + env above)

# Logging configuration for the whole script.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
class BatchSubtitleExtractor:
    """Serially extract subtitles from one video file or a directory of videos.

    Wraps a single ``VideoSubtitleExtractor`` and adds video discovery (a
    flat directory, or the ``video_processed`` layout in which each video
    folder holds a ``video_split`` sub-folder of segment files), per-segment
    timestamp offsetting, per-video output files and a JSON batch report.
    """

    # Suffixes treated as video files. They are compared against the
    # lower-cased file suffix, so "clip.MP4" and "clip.Mp4" both match.
    VIDEO_EXTENSIONS = ('.mp4', '.avi', '.mov', '.mkv', '.flv', '.wmv', '.webm')

    # Assumed length in seconds of one split segment; turns a segment number
    # into an absolute time offset. NOTE(review): must match the duration
    # used by whatever tool produced the segments.
    DEFAULT_SEGMENT_DURATION = 30

    # Trailing digits of a file stem, optionally followed by "_segment".
    _SEGMENT_NUMBER_RE = re.compile(r'(\d+)(?:_segment)?$')

    def __init__(self, ocr_engine="paddleocr", language="ch"):
        """
        Create the batch extractor and its underlying OCR extractor.

        Args:
            ocr_engine: OCR engine ("paddleocr", "easyocr", "cnocr", "all").
            language: language setting ("ch", "en", "ch_en").
        """
        self.ocr_engine = ocr_engine
        self.language = language
        self.extractor = VideoSubtitleExtractor(ocr_engine=ocr_engine, language=language)

    def _extract_segment_number(self, video_name):
        """
        Return the segment number encoded at the end of a video name.

        Only digits at the very end of the name (optionally followed by
        ``_segment``) count, e.g. ``demo_segment_003`` -> 3, ``clip12`` -> 12.

        Args:
            video_name: video file name without its extension.

        Returns:
            int: the segment number, or 1 when the name does not end in digits.
        """
        match = self._SEGMENT_NUMBER_RE.search(video_name)
        return int(match.group(1)) if match else 1

    def _adjust_timestamps(self, results, segment_number,
                           segment_duration=DEFAULT_SEGMENT_DURATION):
        """
        Shift subtitle timestamps from segment-relative to whole-video time.

        Segment N is assumed to start at ``(N - 1) * segment_duration``
        seconds. Segment numbers <= 1 are left untouched, so a segment 0 can
        never produce negative timestamps.

        Args:
            results: OCR result dict; its ``subtitles`` entries are modified
                in place when they carry a ``timestamp`` key.
            segment_number: segment number of the video.
            segment_duration: length of one segment in seconds (default 30).

        Returns:
            dict: the same ``results`` object.
        """
        if segment_number <= 1:
            # The first segment (or a non-positive number) starts at t=0.
            return results

        time_offset = (segment_number - 1) * segment_duration
        logger.info(f"片段号: {segment_number}, 时间偏移: +{time_offset}")

        for subtitle in results.get('subtitles', []):
            if 'timestamp' in subtitle:
                subtitle['timestamp'] += time_offset
        return results

    def _list_videos(self, directory):
        """Return the video files directly inside *directory* (non-recursive)."""
        return [p for p in Path(directory).iterdir()
                if p.is_file() and p.suffix.lower() in self.VIDEO_EXTENSIONS]

    def find_video_files(self, input_dir):
        """
        Collect the video files to process.

        Args:
            input_dir: a single video file, a plain directory of videos, or a
                ``video_processed`` root (videos are taken from every
                ``<video folder>/video_split``).

        Returns:
            list[Path]: sorted video paths; empty if nothing matched or the
            path does not exist.
        """
        input_path = Path(input_dir)

        if input_path.is_file():
            is_video = input_path.suffix.lower() in self.VIDEO_EXTENSIONS
            video_files = [input_path] if is_video else []
        elif input_path.is_dir():
            if self._is_video_processed_structure(input_path):
                video_files = self._find_videos_from_processed_structure(input_path)
            else:
                video_files = self._list_videos(input_path)
        else:
            video_files = []

        return sorted(video_files)

    def _is_video_processed_structure(self, input_path):
        """
        Return True if *input_path* looks like a ``video_processed`` root.

        That is, it is an existing directory with at least one sub-directory
        containing a ``video_split`` folder. A missing path returns False
        instead of raising, because this is also called on output
        directories that may not have been created yet.
        """
        input_path = Path(input_path)
        if not input_path.is_dir():
            return False
        # any() stops at the first hit and, unlike sampling a few entries of
        # iterdir()'s arbitrary order, gives a deterministic answer.
        return any((subdir / "video_split").is_dir()
                   for subdir in input_path.iterdir() if subdir.is_dir())

    def _find_videos_from_processed_structure(self, input_path):
        """Return the video files in every ``<input_path>/<video folder>/video_split``."""
        video_files = []
        for video_dir in Path(input_path).iterdir():
            video_split_dir = video_dir / "video_split"
            if video_dir.is_dir() and video_split_dir.is_dir():
                video_files.extend(self._list_videos(video_split_dir))
        return video_files

    def _resolve_ocr_location(self, video_path, output_dir):
        """Return ``(video_name, ocr_dir)`` for where a video's OCR output goes.

        A segment inside ``.../<video folder>/video_split/`` whose output root
        uses the ``video_processed`` layout writes to
        ``<output_dir>/<video folder>/ocr``. Every other video writes to
        ``<output_dir>/<video stem>/ocr``.
        """
        output_root = Path(output_dir)
        video_name = video_path.stem
        if (video_path.parent.name == "video_split"
                and self._is_video_processed_structure(output_root)):
            return video_name, output_root / video_path.parent.parent.name / "ocr"
        return video_name, output_root / video_name / "ocr"

    def extract_single_video(self, video_path, output_dir, **kwargs):
        """
        Extract subtitles from one video and save them in the requested formats.

        Args:
            video_path: path of the video file.
            output_dir: root output directory (layout described in
                ``_resolve_ocr_location``).
            **kwargs: optional settings. ``interval`` is the frame sampling
                interval (default 30). ``confidence`` is the confidence
                threshold (default 0.5). ``position`` is the subtitle region
                (default 'bottom'). ``formats`` is the list of output formats
                (default ['json']). ``segment_duration`` is the segment
                length in seconds (default 30).

        Returns:
            dict: per-video summary with a ``success`` flag. Errors raised
            while extracting or saving are logged (with traceback) and
            reported through ``error`` instead of being propagated.
        """
        video_path = Path(video_path)
        formats = kwargs.get('formats', ['json'])
        video_name, video_ocr_dir = self._resolve_ocr_location(video_path, output_dir)
        segment_number = self._extract_segment_number(video_name)

        logger.info(f"开始处理视频: {video_path}")
        logger.info(f"视频名: {video_name}, 片段号: {segment_number}")
        logger.info(f"OCR输出目录: {video_ocr_dir}")

        start_time = time.time()
        try:
            results = self.extractor.extract_subtitles_from_video(
                str(video_path),
                sample_interval=kwargs.get('interval', 30),
                confidence_threshold=kwargs.get('confidence', 0.5),
                subtitle_position=kwargs.get('position', 'bottom'),
            )

            # Convert segment-relative timestamps to whole-video time.
            results = self._adjust_timestamps(
                results, segment_number,
                kwargs.get('segment_duration', self.DEFAULT_SEGMENT_DURATION))

            video_ocr_dir.mkdir(parents=True, exist_ok=True)
            output_files = [video_ocr_dir / f"{video_name}_subtitles.{fmt}" for fmt in formats]
            for fmt, output_file in zip(formats, output_files):
                self.extractor.save_results(results, str(output_file), fmt)

            process_time = time.time() - start_time
            results['process_time'] = process_time
            results['video_path'] = str(video_path)
            results['success'] = True
            results['segment_number'] = segment_number

            # How many subtitles carry position (bounding-box) information.
            subtitles = results.get('subtitles', [])
            subtitles_with_bbox = [s for s in subtitles if s.get('bbox')]
            bbox_coverage = len(subtitles_with_bbox) / len(subtitles) * 100 if subtitles else 0

            logger.info(f"完成处理视频: {video_path} (耗时: {process_time:.2f}秒)")
            logger.info(f" 片段号: {segment_number}")
            logger.info(f" 字幕总数: {len(subtitles)}")
            logger.info(f" 有位置信息: {len(subtitles_with_bbox)}")
            logger.info(f" 位置信息覆盖率: {bbox_coverage:.1f}%")

            # Missing stats must not flag a video as failed after its output
            # files were already written.
            stats = results.get('stats', {})
            return {
                'video_path': str(video_path),
                'success': True,
                'process_time': process_time,
                'segment_number': segment_number,
                'subtitle_count': stats.get('filtered_detections', 0),
                'text_length': stats.get('text_length', 0),
                'total_subtitles': len(subtitles),
                'subtitles_with_bbox': len(subtitles_with_bbox),
                'bbox_coverage': bbox_coverage,
                'output_files': [str(p) for p in output_files],
            }
        except Exception as e:
            error_msg = f"处理视频 {video_path} 时出错: {str(e)}"
            # logger.exception keeps the traceback, which logger.error dropped.
            logger.exception(error_msg)
            return {
                'video_path': str(video_path),
                'success': False,
                'error': error_msg,
                'process_time': time.time() - start_time,
                'segment_number': segment_number,
            }

    def extract_batch(self, input_dir, output_dir, **kwargs):
        """
        Extract subtitles from every video found under *input_dir*, one at a time.

        Args:
            input_dir: a video file or directory (see ``find_video_files``).
            output_dir: root output directory. It is created if missing, and
                ``batch_report.json`` is written into it.
            **kwargs: forwarded unchanged to ``extract_single_video``.

        Returns:
            dict: batch summary. ``success`` is False only when no video was
            found. Per-video failures are counted in ``failed_count`` and
            detailed in ``results``.
        """
        logger.info(f"开始批量字幕提取(串行处理)")
        logger.info(f"输入: {input_dir}")
        logger.info(f"输出目录: {output_dir}")
        logger.info(f"OCR引擎: {self.ocr_engine}")
        logger.info(f"字幕位置: {kwargs.get('position', 'bottom')}")

        start_time = time.time()

        video_files = self.find_video_files(input_dir)
        if not video_files:
            logger.warning(f"{input_dir} 中未找到视频文件")
            return {
                'success': False,
                'message': '未找到视频文件',
                'total_files': 0,
                'results': []
            }

        total = len(video_files)
        logger.info(f"找到 {total} 个视频文件")

        results = []
        for i, video_file in enumerate(video_files, 1):
            logger.info(f"处理第 {i}/{total} 个视频")
            results.append(self.extract_single_video(video_file, output_dir, **kwargs))
            logger.info(f"批量处理进度: {i / total * 100:.1f}% ({i}/{total})")

        total_time = time.time() - start_time

        succeeded = [r for r in results if r['success']]
        success_count = len(succeeded)
        failed_count = len(results) - success_count
        total_subtitles = sum(r.get('subtitle_count', 0) for r in succeeded)
        total_text_length = sum(r.get('text_length', 0) for r in succeeded)
        total_subtitles_raw = sum(r.get('total_subtitles', 0) for r in succeeded)
        total_subtitles_with_bbox = sum(r.get('subtitles_with_bbox', 0) for r in succeeded)
        overall_bbox_coverage = (total_subtitles_with_bbox / total_subtitles_raw * 100
                                 if total_subtitles_raw > 0 else 0)

        batch_result = {
            'success': True,
            'total_time': total_time,
            'total_files': total,
            'success_count': success_count,
            'failed_count': failed_count,
            'total_subtitles': total_subtitles,
            'total_text_length': total_text_length,
            'total_subtitles_raw': total_subtitles_raw,
            'total_subtitles_with_bbox': total_subtitles_with_bbox,
            'overall_bbox_coverage': overall_bbox_coverage,
            # str() keeps the report JSON-serializable when a Path is passed.
            'output_directory': str(output_dir),
            'ocr_engine': self.ocr_engine,
            'timestamp': datetime.now().isoformat(),
            'results': results
        }

        # The output root may not exist yet, e.g. when every video failed
        # before any per-video OCR directory was created.
        output_root = Path(output_dir)
        output_root.mkdir(parents=True, exist_ok=True)
        report_file = output_root / "batch_report.json"
        with open(report_file, 'w', encoding='utf-8') as f:
            json.dump(batch_result, f, ensure_ascii=False, indent=2)

        logger.info(f"批量处理完成!")
        logger.info(f"总文件数: {total}")
        logger.info(f"成功: {success_count}, 失败: {failed_count}")
        logger.info(f"总耗时: {total_time:.2f}")
        logger.info(f"提取字幕: {total_subtitles}")
        logger.info(f"文本长度: {total_text_length} 字符")
        logger.info(f"位置信息统计:")
        logger.info(f" 总字幕数: {total_subtitles_raw}")
        logger.info(f" 有位置信息: {total_subtitles_with_bbox}")
        logger.info(f" 位置信息覆盖率: {overall_bbox_coverage:.1f}%")
        logger.info(f"处理报告: {report_file}")
        return batch_result
def main():
    """Command-line entry point for serial batch subtitle extraction.

    Parses the CLI arguments, runs ``BatchSubtitleExtractor.extract_batch``
    and prints a human-readable summary.

    Returns:
        int: process exit status. It is 0 when the batch ran and every video
        succeeded. It is 1 when no videos were found, at least one video
        failed, or an unexpected error aborted the run, so calling scripts
        and pipelines can detect failures.
    """
    parser = argparse.ArgumentParser(description="批量视频字幕提取器(串行处理)")
    parser.add_argument("input", help="输入视频文件或目录")
    parser.add_argument("-e", "--engine", default="cnocr",
                        choices=["paddleocr", "easyocr", "cnocr", "all"],
                        help="OCR引擎 (默认: cnocr)")
    parser.add_argument("-l", "--language", default="ch",
                        choices=["ch", "en", "ch_en"],
                        help="语言设置 (默认: ch)")
    parser.add_argument("-i", "--interval", type=int, default=30,
                        help="帧采样间隔 (默认: 30)")
    parser.add_argument("-c", "--confidence", type=float, default=0.5,
                        help="置信度阈值 (默认: 0.5)")
    parser.add_argument("-o", "--output", default="/root/autodl-tmp/video_processed",
                        help="输出目录 (默认: /root/autodl-tmp/video_processed)")
    parser.add_argument("-f", "--formats", nargs='+', default=["json"],
                        choices=["json", "txt", "srt"],
                        help="输出格式 (默认: json)")
    parser.add_argument("--position", default="full",
                        choices=["full", "center", "bottom"],
                        help="字幕区域位置 (full=全屏, center=居中0.5-0.8, bottom=居下0.7-1.0)")
    args = parser.parse_args()

    batch_extractor = BatchSubtitleExtractor(
        ocr_engine=args.engine,
        language=args.language
    )

    try:
        result = batch_extractor.extract_batch(
            input_dir=args.input,
            output_dir=args.output,
            interval=args.interval,
            confidence=args.confidence,
            formats=args.formats,
            position=args.position
        )

        if result['success']:
            print(f"\n✅ 批量字幕提取完成!")
            print(f"📁 输出目录: {args.output}")
            print(f"📊 成功处理: {result['success_count']}/{result['total_files']} 个视频")
            if result['failed_count'] > 0:
                print(f"❌ 失败: {result['failed_count']}")
            print(f"⏱️ 总耗时: {result['total_time']:.2f}")
            print(f"📝 字幕片段: {result['total_subtitles']}")
            print(f"📏 文本长度: {result['total_text_length']} 字符")
            print(f"📍 位置信息统计:")
            print(f" 总字幕数: {result['total_subtitles_raw']}")
            print(f" 有位置信息: {result['total_subtitles_with_bbox']}")
            print(f" 位置信息覆盖率: {result['overall_bbox_coverage']:.1f}%")
            # Partial failure is still a failure for the caller.
            exit_code = 0 if result['failed_count'] == 0 else 1
        else:
            print(f"\n❌ 批量处理失败: {result.get('message', '未知错误')}")
            exit_code = 1
    except Exception as e:
        # Top-level boundary: keep the traceback in the log and report failure.
        logger.exception(f"批量处理出错: {str(e)}")
        print(f"\n❌ 批量处理出错: {str(e)}")
        exit_code = 1

    return exit_code


if __name__ == "__main__":
    sys.exit(main())