401 lines
16 KiB
Python
401 lines
16 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
批量视频字幕提取器
|
||
支持批量处理多个视频文件,提取字幕
|
||
支持PaddleOCR、EasyOCR和CnOCR三种引擎
|
||
"""
|
||
|
||
import os
|
||
import sys
|
||
import time
|
||
import json
|
||
import argparse
|
||
import re
|
||
from pathlib import Path
|
||
from datetime import datetime
|
||
import logging
|
||
|
||
# 添加当前目录到路径以导入OCR模块
|
||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||
|
||
from ocr_subtitle_extractor import VideoSubtitleExtractor
|
||
|
||
# 设置OCR模型路径环境变量
|
||
os.environ['EASYOCR_MODULE_PATH'] = '/root/autodl-tmp/llm/easyocr'
|
||
os.environ['CNOCR_HOME'] = '/root/autodl-tmp/llm/cnocr'
|
||
|
||
# 设置日志
|
||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||
logger = logging.getLogger(__name__)
|
||
|
||
class BatchSubtitleExtractor:
|
||
"""批量视频字幕提取器"""
|
||
|
||
def __init__(self, ocr_engine="paddleocr", language="ch"):
|
||
"""
|
||
初始化批量提取器
|
||
|
||
Args:
|
||
ocr_engine: OCR引擎 ("paddleocr", "easyocr", "cnocr", "all")
|
||
language: 语言设置 ("ch", "en", "ch_en")
|
||
"""
|
||
self.ocr_engine = ocr_engine
|
||
self.language = language
|
||
self.extractor = VideoSubtitleExtractor(ocr_engine=ocr_engine, language=language)
|
||
|
||
def _extract_segment_number(self, video_name):
|
||
"""
|
||
从视频名中提取最后一个数字作为片段号
|
||
|
||
Args:
|
||
video_name: 视频文件名(不含扩展名)
|
||
|
||
Returns:
|
||
int: 片段号,如果未找到则返回1
|
||
"""
|
||
# 使用正则表达式匹配最后一个数字
|
||
match = re.search(r'(\d+)(?:_segment)?$', video_name)
|
||
if match:
|
||
return int(match.group(1))
|
||
return 1
|
||
|
||
def _adjust_timestamps(self, results, segment_number):
|
||
"""
|
||
根据片段号调整时间戳
|
||
|
||
Args:
|
||
results: OCR结果字典
|
||
segment_number: 片段号
|
||
|
||
Returns:
|
||
dict: 调整后的结果
|
||
"""
|
||
if segment_number == 1:
|
||
# 片段号为1,不调整时间
|
||
return results
|
||
|
||
# 计算时间偏移量
|
||
time_offset = (segment_number - 1) * 30
|
||
|
||
logger.info(f"片段号: {segment_number}, 时间偏移: +{time_offset}秒")
|
||
|
||
# 调整所有字幕的时间戳
|
||
for subtitle in results['subtitles']:
|
||
if 'timestamp' in subtitle:
|
||
subtitle['timestamp'] += time_offset
|
||
|
||
return results
|
||
|
||
def find_video_files(self, input_dir):
|
||
"""查找目录中的所有视频文件"""
|
||
video_extensions = ['.mp4', '.avi', '.mov', '.mkv', '.flv', '.wmv', '.webm']
|
||
video_files = []
|
||
|
||
input_path = Path(input_dir)
|
||
|
||
if input_path.is_file():
|
||
# 单个文件
|
||
if input_path.suffix.lower() in video_extensions:
|
||
video_files.append(input_path)
|
||
elif input_path.is_dir():
|
||
# 检查是否是video_processed目录结构
|
||
if self._is_video_processed_structure(input_path):
|
||
# 从video_split目录中查找视频文件
|
||
video_files = self._find_videos_from_processed_structure(input_path)
|
||
else:
|
||
# 普通目录中的所有视频文件
|
||
for ext in video_extensions:
|
||
video_files.extend(input_path.glob(f"*{ext}"))
|
||
video_files.extend(input_path.glob(f"*{ext.upper()}"))
|
||
|
||
return sorted(video_files)
|
||
|
||
def _is_video_processed_structure(self, input_path):
|
||
"""检查是否是video_processed目录结构"""
|
||
# 检查是否有子目录,且子目录中有video_split
|
||
subdirs = [d for d in input_path.iterdir() if d.is_dir()]
|
||
if not subdirs:
|
||
return False
|
||
|
||
# 检查前几个子目录是否包含video_split
|
||
for subdir in subdirs[:3]: # 只检查前3个
|
||
if (subdir / "video_split").exists():
|
||
return True
|
||
return False
|
||
|
||
def _find_videos_from_processed_structure(self, input_path):
|
||
"""从video_processed结构中查找视频文件"""
|
||
video_files = []
|
||
|
||
for video_dir in input_path.iterdir():
|
||
if not video_dir.is_dir():
|
||
continue
|
||
|
||
video_split_dir = video_dir / "video_split"
|
||
if not video_split_dir.exists():
|
||
continue
|
||
|
||
# 从video_split目录中查找视频文件
|
||
video_extensions = ['.mp4', '.avi', '.mov', '.mkv', '.flv', '.wmv', '.webm']
|
||
for ext in video_extensions:
|
||
video_files.extend(video_split_dir.glob(f"*{ext}"))
|
||
video_files.extend(video_split_dir.glob(f"*{ext.upper()}"))
|
||
|
||
return video_files
|
||
|
||
def extract_single_video(self, video_path, output_dir, **kwargs):
|
||
"""
|
||
处理单个视频文件
|
||
|
||
Args:
|
||
video_path: 视频文件路径
|
||
output_dir: 输出目录
|
||
**kwargs: 其他参数
|
||
|
||
Returns:
|
||
dict: 处理结果
|
||
"""
|
||
video_path = Path(video_path)
|
||
|
||
# 确定视频名称和OCR输出目录
|
||
if self._is_video_processed_structure(Path(output_dir)):
|
||
# 从video_processed结构中提取视频名称
|
||
# video_path格式: /path/to/video_processed/视频名/video_split/视频名_segment_001.mp4
|
||
video_processed_dir = Path(output_dir)
|
||
video_dir_name = video_path.parent.parent.name # 获取视频文件夹名
|
||
video_name = video_path.stem # 获取片段名
|
||
video_ocr_dir = video_processed_dir / video_dir_name / "ocr"
|
||
else:
|
||
# 普通目录结构
|
||
video_name = video_path.stem
|
||
video_ocr_dir = Path(output_dir) / video_name / "ocr"
|
||
|
||
# 提取片段号
|
||
segment_number = self._extract_segment_number(video_name)
|
||
|
||
logger.info(f"开始处理视频: {video_path}")
|
||
logger.info(f"视频名: {video_name}, 片段号: {segment_number}")
|
||
logger.info(f"OCR输出目录: {video_ocr_dir}")
|
||
start_time = time.time()
|
||
|
||
try:
|
||
# 提取字幕
|
||
results = self.extractor.extract_subtitles_from_video(
|
||
str(video_path),
|
||
sample_interval=kwargs.get('interval', 30),
|
||
confidence_threshold=kwargs.get('confidence', 0.5),
|
||
subtitle_position=kwargs.get('position', 'bottom')
|
||
)
|
||
|
||
# 根据片段号调整时间戳
|
||
results = self._adjust_timestamps(results, segment_number)
|
||
|
||
# 创建OCR目录
|
||
video_ocr_dir.mkdir(parents=True, exist_ok=True)
|
||
|
||
for format_type in kwargs.get('formats', ['json']):
|
||
output_file = video_ocr_dir / f"{video_name}_subtitles.{format_type}"
|
||
self.extractor.save_results(results, str(output_file), format_type)
|
||
|
||
process_time = time.time() - start_time
|
||
results['process_time'] = process_time
|
||
results['video_path'] = str(video_path)
|
||
results['success'] = True
|
||
results['segment_number'] = segment_number
|
||
|
||
# 统计位置信息
|
||
subtitles_with_bbox = [s for s in results['subtitles'] if s.get('bbox')]
|
||
bbox_coverage = len(subtitles_with_bbox) / len(results['subtitles']) * 100 if results['subtitles'] else 0
|
||
|
||
logger.info(f"完成处理视频: {video_path} (耗时: {process_time:.2f}秒)")
|
||
logger.info(f" 片段号: {segment_number}")
|
||
logger.info(f" 字幕总数: {len(results['subtitles'])}")
|
||
logger.info(f" 有位置信息: {len(subtitles_with_bbox)}")
|
||
logger.info(f" 位置信息覆盖率: {bbox_coverage:.1f}%")
|
||
|
||
return {
|
||
'video_path': str(video_path),
|
||
'success': True,
|
||
'process_time': process_time,
|
||
'segment_number': segment_number,
|
||
'subtitle_count': results['stats']['filtered_detections'],
|
||
'text_length': results['stats']['text_length'],
|
||
'total_subtitles': len(results['subtitles']),
|
||
'subtitles_with_bbox': len(subtitles_with_bbox),
|
||
'bbox_coverage': bbox_coverage,
|
||
'output_files': [str(video_ocr_dir / f"{video_name}_subtitles.{fmt}") for fmt in kwargs.get('formats', ['json'])]
|
||
}
|
||
|
||
except Exception as e:
|
||
error_msg = f"处理视频 {video_path} 时出错: {str(e)}"
|
||
logger.error(error_msg)
|
||
|
||
return {
|
||
'video_path': str(video_path),
|
||
'success': False,
|
||
'error': error_msg,
|
||
'process_time': time.time() - start_time,
|
||
'segment_number': segment_number
|
||
}
|
||
|
||
def extract_batch(self, input_dir, output_dir, **kwargs):
|
||
"""
|
||
批量提取字幕(串行处理)
|
||
|
||
Args:
|
||
input_dir: 输入目录或文件
|
||
output_dir: 输出目录
|
||
**kwargs: 其他参数
|
||
|
||
Returns:
|
||
dict: 批量处理结果
|
||
"""
|
||
logger.info(f"开始批量字幕提取(串行处理)")
|
||
logger.info(f"输入: {input_dir}")
|
||
logger.info(f"输出目录: {output_dir}")
|
||
logger.info(f"OCR引擎: {self.ocr_engine}")
|
||
logger.info(f"字幕位置: {kwargs.get('position', 'bottom')}")
|
||
|
||
start_time = time.time()
|
||
|
||
# 查找视频文件
|
||
video_files = self.find_video_files(input_dir)
|
||
|
||
if not video_files:
|
||
logger.warning(f"在 {input_dir} 中未找到视频文件")
|
||
return {
|
||
'success': False,
|
||
'message': '未找到视频文件',
|
||
'total_files': 0,
|
||
'results': []
|
||
}
|
||
|
||
logger.info(f"找到 {len(video_files)} 个视频文件")
|
||
|
||
results = []
|
||
|
||
# 串行处理所有视频文件
|
||
for i, video_file in enumerate(video_files, 1):
|
||
logger.info(f"处理第 {i}/{len(video_files)} 个视频")
|
||
result = self.extract_single_video(video_file, output_dir, **kwargs)
|
||
results.append(result)
|
||
|
||
# 显示进度
|
||
progress = i / len(video_files) * 100
|
||
logger.info(f"批量处理进度: {progress:.1f}% ({i}/{len(video_files)})")
|
||
|
||
total_time = time.time() - start_time
|
||
|
||
# 统计结果
|
||
success_count = sum(1 for r in results if r['success'])
|
||
failed_count = len(results) - success_count
|
||
|
||
total_subtitles = sum(r.get('subtitle_count', 0) for r in results if r['success'])
|
||
total_text_length = sum(r.get('text_length', 0) for r in results if r['success'])
|
||
|
||
# 统计位置信息
|
||
total_subtitles_raw = sum(r.get('total_subtitles', 0) for r in results if r['success'])
|
||
total_subtitles_with_bbox = sum(r.get('subtitles_with_bbox', 0) for r in results if r['success'])
|
||
overall_bbox_coverage = total_subtitles_with_bbox / total_subtitles_raw * 100 if total_subtitles_raw > 0 else 0
|
||
|
||
batch_result = {
|
||
'success': True,
|
||
'total_time': total_time,
|
||
'total_files': len(video_files),
|
||
'success_count': success_count,
|
||
'failed_count': failed_count,
|
||
'total_subtitles': total_subtitles,
|
||
'total_text_length': total_text_length,
|
||
'total_subtitles_raw': total_subtitles_raw,
|
||
'total_subtitles_with_bbox': total_subtitles_with_bbox,
|
||
'overall_bbox_coverage': overall_bbox_coverage,
|
||
'output_directory': output_dir,
|
||
'ocr_engine': self.ocr_engine,
|
||
'timestamp': datetime.now().isoformat(),
|
||
'results': results
|
||
}
|
||
|
||
# 保存批量处理报告
|
||
report_file = Path(output_dir) / "batch_report.json"
|
||
with open(report_file, 'w', encoding='utf-8') as f:
|
||
json.dump(batch_result, f, ensure_ascii=False, indent=2)
|
||
|
||
logger.info(f"批量处理完成!")
|
||
logger.info(f"总文件数: {len(video_files)}")
|
||
logger.info(f"成功: {success_count}, 失败: {failed_count}")
|
||
logger.info(f"总耗时: {total_time:.2f} 秒")
|
||
logger.info(f"提取字幕: {total_subtitles} 个")
|
||
logger.info(f"文本长度: {total_text_length} 字符")
|
||
logger.info(f"位置信息统计:")
|
||
logger.info(f" 总字幕数: {total_subtitles_raw}")
|
||
logger.info(f" 有位置信息: {total_subtitles_with_bbox}")
|
||
logger.info(f" 位置信息覆盖率: {overall_bbox_coverage:.1f}%")
|
||
logger.info(f"处理报告: {report_file}")
|
||
|
||
return batch_result
|
||
|
||
def main():
|
||
"""主函数"""
|
||
parser = argparse.ArgumentParser(description="批量视频字幕提取器(串行处理)")
|
||
parser.add_argument("input", help="输入视频文件或目录")
|
||
parser.add_argument("-e", "--engine", default="cnocr",
|
||
choices=["paddleocr", "easyocr", "cnocr", "all"],
|
||
help="OCR引擎 (默认: cnocr)")
|
||
parser.add_argument("-l", "--language", default="ch",
|
||
choices=["ch", "en", "ch_en"],
|
||
help="语言设置 (默认: ch)")
|
||
parser.add_argument("-i", "--interval", type=int, default=30,
|
||
help="帧采样间隔 (默认: 30)")
|
||
parser.add_argument("-c", "--confidence", type=float, default=0.5,
|
||
help="置信度阈值 (默认: 0.5)")
|
||
parser.add_argument("-o", "--output", default="/root/autodl-tmp/video_processed",
|
||
help="输出目录 (默认: /root/autodl-tmp/video_processed)")
|
||
parser.add_argument("-f", "--formats", nargs='+', default=["json"],
|
||
choices=["json", "txt", "srt"],
|
||
help="输出格式 (默认: json)")
|
||
parser.add_argument("--position", default="full",
|
||
choices=["full", "center", "bottom"],
|
||
help="字幕区域位置 (full=全屏, center=居中0.5-0.8, bottom=居下0.7-1.0)")
|
||
|
||
args = parser.parse_args()
|
||
|
||
# 创建批量提取器
|
||
batch_extractor = BatchSubtitleExtractor(
|
||
ocr_engine=args.engine,
|
||
language=args.language
|
||
)
|
||
|
||
try:
|
||
# 执行批量提取
|
||
result = batch_extractor.extract_batch(
|
||
input_dir=args.input,
|
||
output_dir=args.output,
|
||
interval=args.interval,
|
||
confidence=args.confidence,
|
||
formats=args.formats,
|
||
position=args.position
|
||
)
|
||
|
||
if result['success']:
|
||
print(f"\n✅ 批量字幕提取完成!")
|
||
print(f"📁 输出目录: {args.output}")
|
||
print(f"📊 成功处理: {result['success_count']}/{result['total_files']} 个视频")
|
||
if result['failed_count'] > 0:
|
||
print(f"❌ 失败: {result['failed_count']} 个")
|
||
print(f"⏱️ 总耗时: {result['total_time']:.2f} 秒")
|
||
print(f"📝 字幕片段: {result['total_subtitles']} 个")
|
||
print(f"📏 文本长度: {result['total_text_length']} 字符")
|
||
print(f"📍 位置信息统计:")
|
||
print(f" 总字幕数: {result['total_subtitles_raw']}")
|
||
print(f" 有位置信息: {result['total_subtitles_with_bbox']}")
|
||
print(f" 位置信息覆盖率: {result['overall_bbox_coverage']:.1f}%")
|
||
else:
|
||
print(f"\n❌ 批量处理失败: {result.get('message', '未知错误')}")
|
||
|
||
except Exception as e:
|
||
logger.error(f"批量处理出错: {str(e)}")
|
||
print(f"\n❌ 批量处理出错: {str(e)}")
|
||
|
||
if __name__ == "__main__":
|
||
main() |