337 lines
14 KiB
Python
337 lines
14 KiB
Python
|
|
#!/usr/bin/env python3
|
|||
|
|
# -*- coding: utf-8 -*-
|
|||
|
|
"""
|
|||
|
|
批量视频字幕提取器
|
|||
|
|
支持批量处理多个视频文件,提取字幕
|
|||
|
|
支持PaddleOCR、EasyOCR和CnOCR三种引擎
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
import os
|
|||
|
|
import sys
|
|||
|
|
import time
|
|||
|
|
import json
|
|||
|
|
import argparse
|
|||
|
|
from pathlib import Path
|
|||
|
|
from datetime import datetime
|
|||
|
|
import logging
|
|||
|
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|||
|
|
|
|||
|
|
# OCR model cache locations. They are set BEFORE importing
# ocr_subtitle_extractor in case that module (or an OCR library it imports)
# reads them at import time. setdefault keeps a value the user already
# exported instead of overwriting it with these machine-specific defaults.
os.environ.setdefault('EASYOCR_MODULE_PATH', '/root/autodl-tmp/llm/easyocr')
os.environ.setdefault('CNOCR_HOME', '/root/autodl-tmp/llm/cnocr')

# Make modules located next to this script (ocr_subtitle_extractor) importable.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from ocr_subtitle_extractor import VideoSubtitleExtractor  # noqa: E402

# Module-wide logging: INFO level, timestamped messages.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
|
|||
|
|
|
|||
|
|
class BatchSubtitleExtractor:
    """Run a VideoSubtitleExtractor over a single video or a directory of videos.

    One ``VideoSubtitleExtractor`` is created in ``__init__`` and reused for
    every video, including when ``extract_batch`` runs videos concurrently in
    worker threads.
    NOTE(review): thread-safety of the shared extractor is not established
    here — confirm before relying on ``max_workers > 1``.
    """

    # File suffixes (compared case-insensitively) that are treated as videos.
    VIDEO_EXTENSIONS = frozenset({'.mp4', '.avi', '.mov', '.mkv', '.flv', '.wmv', '.webm'})

    def __init__(self, ocr_engine="paddleocr", language="ch", max_workers=2):
        """Create the batch extractor.

        Args:
            ocr_engine: OCR engine ("paddleocr", "easyocr", "cnocr", "all").
            language: Language setting ("ch", "en", "ch_en").
            max_workers: Thread-pool size used by ``extract_batch`` when
                ``parallel`` is true and more than one video is found.
        """
        self.ocr_engine = ocr_engine
        self.language = language
        self.max_workers = max_workers
        self.extractor = VideoSubtitleExtractor(ocr_engine=ocr_engine, language=language)

    def find_video_files(self, input_dir):
        """Return the video files designated by *input_dir*, sorted by path.

        - If *input_dir* is a file, return it in a one-element list when its
          suffix is a known video extension.
        - If it is a directory, return its immediate (non-recursive) regular
          files that have a video extension.
        - Otherwise (missing path, non-video file) return an empty list.

        Extensions match case-insensitively (``.mp4``, ``.MP4``, ``.Mp4``).
        """
        input_path = Path(input_dir)

        if input_path.is_file():
            candidates = [input_path]
        elif input_path.is_dir():
            # iterdir + lower-cased suffix replaces globbing each extension in
            # two cases. That approach returned every file twice on
            # case-insensitive filesystems, missed mixed-case suffixes, and
            # matched directories named like videos.
            candidates = [p for p in input_path.iterdir() if p.is_file()]
        else:
            candidates = []

        return sorted(p for p in candidates if p.suffix.lower() in self.VIDEO_EXTENSIONS)

    def extract_single_video(self, video_path, output_dir, **kwargs):
        """Extract subtitles from one video and save them under *output_dir*.

        Args:
            video_path: Path of the video file.
            output_dir: Output directory (created if missing).
            **kwargs: Optional settings:
                interval (default 30): passed as ``sample_interval``.
                confidence (default 0.5): passed as ``confidence_threshold``.
                position (default 'bottom'): passed as ``subtitle_position``.
                formats (default ['json']): one file
                    ``<video stem>_subtitles.<format>`` is saved per entry.

        Returns:
            dict: On success:
                - ``video_path``, ``success=True``, ``process_time``
                - ``subtitle_count`` and ``text_length`` (from the extractor's
                  ``stats``)
                - ``total_subtitles``, ``subtitles_with_bbox``,
                  ``bbox_coverage`` (percent)
                - ``output_files``

            On any exception:
                - ``video_path``, ``success=False``, ``error``,
                  ``process_time``

            Exceptions are logged with traceback and never propagate.
        """
        video_path = Path(video_path)
        video_name = video_path.stem
        # Materialize once: the list is used both for saving and in the return
        # value, so a one-shot iterable must not be consumed twice.
        formats = list(kwargs.get('formats', ['json']))

        logger.info(f"开始处理视频: {video_path}")
        start_time = time.time()

        try:
            results = self.extractor.extract_subtitles_from_video(
                str(video_path),
                sample_interval=kwargs.get('interval', 30),
                confidence_threshold=kwargs.get('confidence', 0.5),
                subtitle_position=kwargs.get('position', 'bottom'),
            )

            output_path = Path(output_dir)
            output_path.mkdir(parents=True, exist_ok=True)

            output_files = []
            for format_type in formats:
                output_file = output_path / f"{video_name}_subtitles.{format_type}"
                self.extractor.save_results(results, output_file, format_type)
                output_files.append(str(output_file))

            process_time = time.time() - start_time
            # Annotate the extractor's result dict. This happens after saving,
            # so these keys are not part of the written files.
            results['process_time'] = process_time
            results['video_path'] = str(video_path)
            results['success'] = True

            # Position (bounding-box) coverage statistics.
            subtitles = results['subtitles']
            subtitles_with_bbox = [s for s in subtitles if s.get('bbox')]
            bbox_coverage = len(subtitles_with_bbox) / len(subtitles) * 100 if subtitles else 0

            logger.info(f"完成处理视频: {video_path} (耗时: {process_time:.2f}秒)")
            logger.info(f" 字幕总数: {len(subtitles)}")
            logger.info(f" 有位置信息: {len(subtitles_with_bbox)}")
            logger.info(f" 位置信息覆盖率: {bbox_coverage:.1f}%")

            return {
                'video_path': str(video_path),
                'success': True,
                'process_time': process_time,
                'subtitle_count': results['stats']['filtered_detections'],
                'text_length': results['stats']['text_length'],
                'total_subtitles': len(subtitles),
                'subtitles_with_bbox': len(subtitles_with_bbox),
                'bbox_coverage': bbox_coverage,
                'output_files': output_files,
            }

        except Exception as e:
            error_msg = f"处理视频 {video_path} 时出错: {str(e)}"
            # logger.exception keeps the same message but adds the traceback.
            logger.exception(error_msg)

            return {
                'video_path': str(video_path),
                'success': False,
                'error': error_msg,
                'process_time': time.time() - start_time,
            }

    def extract_batch(self, input_dir, output_dir, parallel=True, **kwargs):
        """Extract subtitles from every video found by ``find_video_files``.

        Args:
            input_dir: Video file or directory (see ``find_video_files``).
            output_dir: Directory for per-video outputs and ``batch_report.json``.
            parallel: When true and more than one video is found, process
                videos concurrently in a thread pool of ``self.max_workers``.
            **kwargs: Forwarded unchanged to ``extract_single_video``.

        Returns:
            dict: When no video is found:
                ``{'success': False, 'message', 'total_files': 0, 'results': []}``.

            Otherwise ``success=True`` plus:
                - timing: ``total_time``, ``timestamp``
                - counts: ``total_files``, ``success_count``, ``failed_count``
                - subtitle totals: ``total_subtitles``,
                  ``total_text_length``, ``total_subtitles_raw``,
                  ``total_subtitles_with_bbox``, ``overall_bbox_coverage``
                - ``output_directory``, ``ocr_engine``
                - ``results``: per-video dicts in input order

            The same dict is written to ``<output_dir>/batch_report.json``.
        """
        logger.info(f"开始批量字幕提取")
        logger.info(f"输入: {input_dir}")
        logger.info(f"输出目录: {output_dir}")
        logger.info(f"OCR引擎: {self.ocr_engine}")
        logger.info(f"字幕位置: {kwargs.get('position', 'bottom')}")
        logger.info(f"并行处理: {parallel}")

        start_time = time.time()

        video_files = self.find_video_files(input_dir)

        if not video_files:
            logger.warning(f"在 {input_dir} 中未找到视频文件")
            return {
                'success': False,
                'message': '未找到视频文件',
                'total_files': 0,
                'results': [],
            }

        logger.info(f"找到 {len(video_files)} 个视频文件")
        total = len(video_files)

        if parallel and total > 1:
            logger.info(f"使用 {self.max_workers} 个并行工作进程")

            # Store each result at its input index so the report lists videos
            # in the same (sorted) order as serial mode, not completion order.
            results = [None] * total
            completed = 0
            with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
                future_to_index = {
                    executor.submit(self.extract_single_video, video_file, output_dir, **kwargs): index
                    for index, video_file in enumerate(video_files)
                }

                for future in as_completed(future_to_index):
                    index = future_to_index[future]
                    video_file = video_files[index]
                    try:
                        results[index] = future.result()
                    except Exception as e:
                        # extract_single_video handles its own errors; this
                        # is a last-resort guard so one task cannot abort the batch.
                        logger.error(f"处理视频 {video_file} 时发生异常: {str(e)}")
                        results[index] = {
                            'video_path': str(video_file),
                            'success': False,
                            'error': str(e),
                        }

                    completed += 1
                    progress = completed / total * 100
                    logger.info(f"批量处理进度: {progress:.1f}% ({completed}/{total})")
        else:
            results = []
            for i, video_file in enumerate(video_files, 1):
                logger.info(f"处理第 {i}/{total} 个视频")
                results.append(self.extract_single_video(video_file, output_dir, **kwargs))

                progress = i / total * 100
                logger.info(f"批量处理进度: {progress:.1f}%")

        total_time = time.time() - start_time

        # Aggregate statistics over successful videos only.
        succeeded = [r for r in results if r['success']]
        success_count = len(succeeded)
        failed_count = len(results) - success_count

        total_subtitles = sum(r.get('subtitle_count', 0) for r in succeeded)
        total_text_length = sum(r.get('text_length', 0) for r in succeeded)

        total_subtitles_raw = sum(r.get('total_subtitles', 0) for r in succeeded)
        total_subtitles_with_bbox = sum(r.get('subtitles_with_bbox', 0) for r in succeeded)
        overall_bbox_coverage = total_subtitles_with_bbox / total_subtitles_raw * 100 if total_subtitles_raw > 0 else 0

        batch_result = {
            'success': True,
            'total_time': total_time,
            'total_files': total,
            'success_count': success_count,
            'failed_count': failed_count,
            'total_subtitles': total_subtitles,
            'total_text_length': total_text_length,
            'total_subtitles_raw': total_subtitles_raw,
            'total_subtitles_with_bbox': total_subtitles_with_bbox,
            'overall_bbox_coverage': overall_bbox_coverage,
            'output_directory': output_dir,
            'ocr_engine': self.ocr_engine,
            'timestamp': datetime.now().isoformat(),
            'results': results,
        }

        # Ensure the directory exists. If every video failed before saving,
        # extract_single_video never created it and open() would raise.
        report_dir = Path(output_dir)
        report_dir.mkdir(parents=True, exist_ok=True)
        report_file = report_dir / "batch_report.json"
        with open(report_file, 'w', encoding='utf-8') as f:
            json.dump(batch_result, f, ensure_ascii=False, indent=2)

        logger.info(f"批量处理完成!")
        logger.info(f"总文件数: {total}")
        logger.info(f"成功: {success_count}, 失败: {failed_count}")
        logger.info(f"总耗时: {total_time:.2f} 秒")
        logger.info(f"提取字幕: {total_subtitles} 个")
        logger.info(f"文本长度: {total_text_length} 字符")
        logger.info(f"位置信息统计:")
        logger.info(f" 总字幕数: {total_subtitles_raw}")
        logger.info(f" 有位置信息: {total_subtitles_with_bbox}")
        logger.info(f" 位置信息覆盖率: {overall_bbox_coverage:.1f}%")
        logger.info(f"处理报告: {report_file}")

        return batch_result
|
|||
|
|
|
|||
|
|
def main():
    """Command-line entry point: parse arguments and run a batch extraction.

    Returns:
        int: Process exit status. 0 when every video was processed
        successfully; 1 when no videos were found, any video failed, or an
        unexpected error occurred.
    """
    parser = argparse.ArgumentParser(description="批量视频字幕提取器")
    parser.add_argument("input", help="输入视频文件或目录")
    parser.add_argument("-e", "--engine", default="cnocr",
                        choices=["paddleocr", "easyocr", "cnocr", "all"],
                        help="OCR引擎 (默认: cnocr)")
    parser.add_argument("-l", "--language", default="ch",
                        choices=["ch", "en", "ch_en"],
                        help="语言设置 (默认: ch)")
    parser.add_argument("-i", "--interval", type=int, default=30,
                        help="帧采样间隔 (默认: 30)")
    parser.add_argument("-c", "--confidence", type=float, default=0.5,
                        help="置信度阈值 (默认: 0.5)")
    parser.add_argument("-o", "--output", default="batch_subtitles",
                        help="输出目录 (默认: batch_subtitles)")
    parser.add_argument("-f", "--formats", nargs='+', default=["json"],
                        choices=["json", "txt", "srt"],
                        help="输出格式 (默认: json)")
    parser.add_argument("--position", default="full",
                        choices=["full", "center", "bottom"],
                        help="字幕区域位置 (full=全屏, center=居中0.5-0.8, bottom=居下0.7-1.0)")
    parser.add_argument("--workers", type=int, default=2,
                        help="并行工作进程数 (默认: 2)")
    parser.add_argument("--no-parallel", action="store_true",
                        help="禁用并行处理")

    args = parser.parse_args()

    # ThreadPoolExecutor rejects max_workers < 1. Report it as a usage error
    # up front instead of failing mid-run. Irrelevant when parallelism is off.
    if not args.no_parallel and args.workers < 1:
        parser.error("--workers 必须 >= 1")

    batch_extractor = BatchSubtitleExtractor(
        ocr_engine=args.engine,
        language=args.language,
        max_workers=args.workers,
    )

    try:
        result = batch_extractor.extract_batch(
            input_dir=args.input,
            output_dir=args.output,
            parallel=not args.no_parallel,
            interval=args.interval,
            confidence=args.confidence,
            formats=args.formats,
            position=args.position,
        )

        if not result['success']:
            print(f"\n❌ 批量处理失败: {result.get('message', '未知错误')}")
            return 1

        print(f"\n✅ 批量字幕提取完成!")
        print(f"📁 输出目录: {args.output}")
        print(f"📊 成功处理: {result['success_count']}/{result['total_files']} 个视频")
        if result['failed_count'] > 0:
            print(f"❌ 失败: {result['failed_count']} 个")
        print(f"⏱️ 总耗时: {result['total_time']:.2f} 秒")
        print(f"📝 字幕片段: {result['total_subtitles']} 个")
        print(f"📏 文本长度: {result['total_text_length']} 字符")
        print(f"📍 位置信息统计:")
        print(f" 总字幕数: {result['total_subtitles_raw']}")
        print(f" 有位置信息: {result['total_subtitles_with_bbox']}")
        print(f" 位置信息覆盖率: {result['overall_bbox_coverage']:.1f}%")

        # Partial failure still yields a non-zero status for scripts/CI.
        return 0 if result['failed_count'] == 0 else 1

    except Exception as e:
        # logger.exception keeps the message and adds the traceback.
        logger.exception(f"批量处理出错: {str(e)}")
        print(f"\n❌ 批量处理出错: {str(e)}")
        return 1


if __name__ == "__main__":
    sys.exit(main())
|