hot_video_analyse/code/batch_subtitle_extractor.py

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
批量视频字幕提取器
支持批量处理多个视频文件提取字幕
支持PaddleOCREasyOCR和CnOCR三种引擎
"""
import os
import sys
import time
import json
import argparse
from pathlib import Path
from datetime import datetime
import logging
from concurrent.futures import ThreadPoolExecutor, as_completed
# Add the current directory to the path so the OCR module can be imported
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from ocr_subtitle_extractor import VideoSubtitleExtractor
# Set the OCR model path environment variables
os.environ['EASYOCR_MODULE_PATH'] = '/root/autodl-tmp/llm/easyocr'
os.environ['CNOCR_HOME'] = '/root/autodl-tmp/llm/cnocr'
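# Note: these variables are set after ocr_subtitle_extractor is imported above; if
# the OCR backends read them at import time, they may need to be set before that
# import to take effect.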
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
class BatchSubtitleExtractor:
    """Batch video subtitle extractor."""

    def __init__(self, ocr_engine="paddleocr", language="ch", max_workers=2):
        """
        Initialize the batch extractor.

        Args:
            ocr_engine: OCR engine ("paddleocr", "easyocr", "cnocr", "all")
            language: language setting ("ch", "en", "ch_en")
            max_workers: maximum number of parallel workers
        """
        self.ocr_engine = ocr_engine
        self.language = language
        self.max_workers = max_workers
        self.extractor = VideoSubtitleExtractor(ocr_engine=ocr_engine, language=language)
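        # Note: this single VideoSubtitleExtractor instance is reused for every
        # video, including across worker threads in extract_batch(); whether that
        # is safe depends on the underlying OCR backend.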

    def find_video_files(self, input_dir):
        """Find all video files under input_dir; a single video file path is also accepted."""
        video_extensions = ['.mp4', '.avi', '.mov', '.mkv', '.flv', '.wmv', '.webm']
        video_files = []
        input_path = Path(input_dir)
        if input_path.is_file():
            # Single file
            if input_path.suffix.lower() in video_extensions:
                video_files.append(input_path)
        elif input_path.is_dir():
            # All video files in the directory
            for ext in video_extensions:
                video_files.extend(input_path.glob(f"*{ext}"))
                video_files.extend(input_path.glob(f"*{ext.upper()}"))
        return sorted(video_files)

    def extract_single_video(self, video_path, output_dir, **kwargs):
        """
        Process a single video file.

        Args:
            video_path: path to the video file
            output_dir: output directory
            **kwargs: additional parameters

        Returns:
            dict: processing result
        """
        video_path = Path(video_path)
        video_name = video_path.stem
        logger.info(f"Processing video: {video_path}")
        start_time = time.time()
        try:
            # Extract subtitles
            results = self.extractor.extract_subtitles_from_video(
                str(video_path),
                sample_interval=kwargs.get('interval', 30),
                confidence_threshold=kwargs.get('confidence', 0.5),
                subtitle_position=kwargs.get('position', 'bottom')
            )
            # Save results
            output_path = Path(output_dir)
            output_path.mkdir(parents=True, exist_ok=True)
            for format_type in kwargs.get('formats', ['json']):
                output_file = output_path / f"{video_name}_subtitles.{format_type}"
                self.extractor.save_results(results, output_file, format_type)
            process_time = time.time() - start_time
            results['process_time'] = process_time
            results['video_path'] = str(video_path)
            results['success'] = True
            # Position (bounding-box) statistics
            subtitles_with_bbox = [s for s in results['subtitles'] if s.get('bbox')]
            bbox_coverage = len(subtitles_with_bbox) / len(results['subtitles']) * 100 if results['subtitles'] else 0
            logger.info(f"Finished video: {video_path} (elapsed: {process_time:.2f}s)")
            logger.info(f"  Total subtitles: {len(results['subtitles'])}")
            logger.info(f"  With position info: {len(subtitles_with_bbox)}")
            logger.info(f"  Position info coverage: {bbox_coverage:.1f}%")
            return {
                'video_path': str(video_path),
                'success': True,
                'process_time': process_time,
                'subtitle_count': results['stats']['filtered_detections'],
                'text_length': results['stats']['text_length'],
                'total_subtitles': len(results['subtitles']),
                'subtitles_with_bbox': len(subtitles_with_bbox),
                'bbox_coverage': bbox_coverage,
                'output_files': [str(output_path / f"{video_name}_subtitles.{fmt}") for fmt in kwargs.get('formats', ['json'])]
            }
        except Exception as e:
            error_msg = f"Error while processing video {video_path}: {str(e)}"
            logger.error(error_msg)
            return {
                'video_path': str(video_path),
                'success': False,
                'error': error_msg,
                'process_time': time.time() - start_time
            }

    def extract_batch(self, input_dir, output_dir, parallel=True, **kwargs):
        """
        Extract subtitles in batch.

        Args:
            input_dir: input directory or file
            output_dir: output directory
            parallel: whether to process videos in parallel
            **kwargs: additional parameters

        Returns:
            dict: batch processing result
        """
        logger.info("Starting batch subtitle extraction")
        logger.info(f"Input: {input_dir}")
        logger.info(f"Output directory: {output_dir}")
        logger.info(f"OCR engine: {self.ocr_engine}")
        logger.info(f"Subtitle position: {kwargs.get('position', 'bottom')}")
        logger.info(f"Parallel processing: {parallel}")
        start_time = time.time()
        # Find video files
        video_files = self.find_video_files(input_dir)
        if not video_files:
            logger.warning(f"No video files found in {input_dir}")
            return {
                'success': False,
                'message': 'No video files found',
                'total_files': 0,
                'results': []
            }
        logger.info(f"Found {len(video_files)} video file(s)")
        results = []
        if parallel and len(video_files) > 1:
            # Parallel processing
            logger.info(f"Using {self.max_workers} parallel worker threads")
            with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
                # Submit tasks
                future_to_video = {
                    executor.submit(self.extract_single_video, video_file, output_dir, **kwargs): video_file
                    for video_file in video_files
                }
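                # as_completed() yields futures in completion order, so results are
                # appended as videos finish, not in the original input order.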
                # Collect results
                for future in as_completed(future_to_video):
                    video_file = future_to_video[future]
                    try:
                        result = future.result()
                        results.append(result)
                        # Report progress
                        progress = len(results) / len(video_files) * 100
                        logger.info(f"Batch progress: {progress:.1f}% ({len(results)}/{len(video_files)})")
                    except Exception as e:
                        logger.error(f"Exception while processing video {video_file}: {str(e)}")
                        results.append({
                            'video_path': str(video_file),
                            'success': False,
                            'error': str(e)
                        })
        else:
            # Sequential processing
            for i, video_file in enumerate(video_files, 1):
                logger.info(f"Processing video {i}/{len(video_files)}")
                result = self.extract_single_video(video_file, output_dir, **kwargs)
                results.append(result)
                # Report progress
                progress = i / len(video_files) * 100
                logger.info(f"Batch progress: {progress:.1f}%")
        total_time = time.time() - start_time
        # Aggregate results
        success_count = sum(1 for r in results if r['success'])
        failed_count = len(results) - success_count
        total_subtitles = sum(r.get('subtitle_count', 0) for r in results if r['success'])
        total_text_length = sum(r.get('text_length', 0) for r in results if r['success'])
        # Aggregate position (bounding-box) statistics
        total_subtitles_raw = sum(r.get('total_subtitles', 0) for r in results if r['success'])
        total_subtitles_with_bbox = sum(r.get('subtitles_with_bbox', 0) for r in results if r['success'])
        overall_bbox_coverage = total_subtitles_with_bbox / total_subtitles_raw * 100 if total_subtitles_raw > 0 else 0
        batch_result = {
            'success': True,
            'total_time': total_time,
            'total_files': len(video_files),
            'success_count': success_count,
            'failed_count': failed_count,
            'total_subtitles': total_subtitles,
            'total_text_length': total_text_length,
            'total_subtitles_raw': total_subtitles_raw,
            'total_subtitles_with_bbox': total_subtitles_with_bbox,
            'overall_bbox_coverage': overall_bbox_coverage,
            'output_directory': output_dir,
            'ocr_engine': self.ocr_engine,
            'timestamp': datetime.now().isoformat(),
            'results': results
        }
        # Save the batch processing report (make sure the output directory exists,
        # since it is only created per video on success)
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)
        report_file = output_path / "batch_report.json"
        with open(report_file, 'w', encoding='utf-8') as f:
            json.dump(batch_result, f, ensure_ascii=False, indent=2)
        logger.info("Batch processing complete!")
        logger.info(f"Total files: {len(video_files)}")
        logger.info(f"Succeeded: {success_count}, failed: {failed_count}")
        logger.info(f"Total time: {total_time:.2f}s")
        logger.info(f"Extracted subtitles: {total_subtitles}")
        logger.info(f"Text length: {total_text_length} characters")
        logger.info("Position info statistics:")
        logger.info(f"  Total subtitles: {total_subtitles_raw}")
        logger.info(f"  With position info: {total_subtitles_with_bbox}")
        logger.info(f"  Position info coverage: {overall_bbox_coverage:.1f}%")
        logger.info(f"Report: {report_file}")
        return batch_result
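
# Programmatic usage (a minimal sketch; the paths and parameter values below are
# illustrative, not defaults mandated by this module):
#
#     extractor = BatchSubtitleExtractor(ocr_engine="cnocr", language="ch", max_workers=2)
#     report = extractor.extract_batch(
#         input_dir="videos/",
#         output_dir="batch_subtitles",
#         parallel=True,
#         interval=30,
#         confidence=0.5,
#         formats=["json", "srt"],
#         position="bottom",
#     )
#     print(f"{report['success_count']}/{report['total_files']} videos processed")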

def main():
    """Main entry point."""
    parser = argparse.ArgumentParser(description="Batch video subtitle extractor")
    parser.add_argument("input", help="Input video file or directory")
    parser.add_argument("-e", "--engine", default="cnocr",
                        choices=["paddleocr", "easyocr", "cnocr", "all"],
                        help="OCR engine (default: cnocr)")
    parser.add_argument("-l", "--language", default="ch",
                        choices=["ch", "en", "ch_en"],
                        help="Language setting (default: ch)")
    parser.add_argument("-i", "--interval", type=int, default=30,
                        help="Frame sampling interval (default: 30)")
    parser.add_argument("-c", "--confidence", type=float, default=0.5,
                        help="Confidence threshold (default: 0.5)")
    parser.add_argument("-o", "--output", default="batch_subtitles",
                        help="Output directory (default: batch_subtitles)")
    parser.add_argument("-f", "--formats", nargs='+', default=["json"],
                        choices=["json", "txt", "srt"],
                        help="Output formats (default: json)")
    parser.add_argument("--position", default="full",
                        choices=["full", "center", "bottom"],
                        help="Subtitle region (full=whole frame, center=0.5-0.8 of frame height, bottom=0.7-1.0)")
    parser.add_argument("--workers", type=int, default=2,
                        help="Number of parallel workers (default: 2)")
    parser.add_argument("--no-parallel", action="store_true",
                        help="Disable parallel processing")
    args = parser.parse_args()
    # Create the batch extractor
    batch_extractor = BatchSubtitleExtractor(
        ocr_engine=args.engine,
        language=args.language,
        max_workers=args.workers
    )
    try:
        # Run the batch extraction
        result = batch_extractor.extract_batch(
            input_dir=args.input,
            output_dir=args.output,
            parallel=not args.no_parallel,
            interval=args.interval,
            confidence=args.confidence,
            formats=args.formats,
            position=args.position
        )
        if result['success']:
            print("\n✅ Batch subtitle extraction complete!")
            print(f"📁 Output directory: {args.output}")
            print(f"📊 Successfully processed: {result['success_count']}/{result['total_files']} videos")
            if result['failed_count'] > 0:
                print(f"❌ Failed: {result['failed_count']}")
            print(f"⏱️ Total time: {result['total_time']:.2f}s")
            print(f"📝 Subtitle segments: {result['total_subtitles']}")
            print(f"📏 Text length: {result['total_text_length']} characters")
            print("📍 Position info statistics:")
            print(f"  Total subtitles: {result['total_subtitles_raw']}")
            print(f"  With position info: {result['total_subtitles_with_bbox']}")
            print(f"  Position info coverage: {result['overall_bbox_coverage']:.1f}%")
        else:
            print(f"\n❌ Batch processing failed: {result.get('message', 'unknown error')}")
    except Exception as e:
        logger.error(f"Batch processing error: {str(e)}")
        print(f"\n❌ Batch processing error: {str(e)}")


if __name__ == "__main__":
    main()