#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Batch video subtitle extractor.

Processes a single video file or every video file in a directory and
extracts subtitles with OCR. Supports the PaddleOCR, EasyOCR and CnOCR
engines.
"""

import os
import sys
import time
import json
import argparse
from pathlib import Path
from datetime import datetime
import logging
from concurrent.futures import ThreadPoolExecutor, as_completed

# Point the OCR libraries at the local model directories. This is done BEFORE
# importing the extractor module: OCR libraries may read these variables when
# they are first imported, and ocr_subtitle_extractor presumably imports them
# at module load (TODO confirm), in which case setting them afterwards would
# have no effect.
os.environ['EASYOCR_MODULE_PATH'] = '/root/autodl-tmp/llm/easyocr'
os.environ['CNOCR_HOME'] = '/root/autodl-tmp/llm/cnocr'

# Make the sibling OCR module importable regardless of the working directory.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from ocr_subtitle_extractor import VideoSubtitleExtractor  # noqa: E402

# Logging setup
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


class BatchSubtitleExtractor:
    """Extract subtitles from many videos using one shared OCR extractor."""

    def __init__(self, ocr_engine="paddleocr", language="ch", max_workers=2):
        """
        Initialise the batch extractor.

        Args:
            ocr_engine: OCR engine ("paddleocr", "easyocr", "cnocr", "all").
            language: Language setting ("ch", "en", "ch_en").
            max_workers: Maximum number of worker threads in parallel mode.
        """
        self.ocr_engine = ocr_engine
        self.language = language
        self.max_workers = max_workers
        # NOTE(review): in parallel mode this single instance is shared by all
        # worker threads; confirm VideoSubtitleExtractor (and the OCR engine it
        # wraps) is safe to call concurrently.
        self.extractor = VideoSubtitleExtractor(ocr_engine=ocr_engine, language=language)

    def find_video_files(self, input_dir):
        """
        Return the video files found at *input_dir*, sorted by path.

        Args:
            input_dir: Path to a single video file or to a directory. For a
                directory only its immediate children are scanned (no
                recursion).

        Returns:
            list[Path]: Matching video files. A single file without a video
            extension, or a path that does not exist, yields an empty list.
        """
        video_extensions = {'.mp4', '.avi', '.mov', '.mkv', '.flv', '.wmv', '.webm'}
        input_path = Path(input_dir)

        if input_path.is_file():
            # Single file: accept it only if it has a video extension.
            if input_path.suffix.lower() in video_extensions:
                return [input_path]
            return []

        if input_path.is_dir():
            # Match extensions case-insensitively in a single pass. Globbing
            # "*.mp4" and "*.MP4" separately misses mixed-case names such as
            # ".Mp4" and, on Windows (where Path.glob is case-insensitive),
            # returns every file twice so it would be processed twice.
            return sorted(
                p for p in input_path.iterdir()
                if p.is_file() and p.suffix.lower() in video_extensions
            )

        return []

    def extract_single_video(self, video_path, output_dir, **kwargs):
        """
        Extract subtitles from one video and save them in each requested format.

        Args:
            video_path: Path of the video file.
            output_dir: Directory the subtitle files are written to (created
                if missing).
            **kwargs: Optional settings:
                interval (default 30): passed as ``sample_interval``.
                confidence (default 0.5): passed as ``confidence_threshold``.
                position (default 'bottom'): passed as ``subtitle_position``.
                formats (default ['json']): output formats to save.

        Returns:
            dict: Per-video summary. Always contains 'video_path', 'success'
            and 'process_time'. On success it also has subtitle counts, bbox
            coverage and 'output_files'; on failure it has an 'error'
            message. Exceptions are caught and reported here, not raised.
        """
        video_path = Path(video_path)
        video_name = video_path.stem
        formats = kwargs.get('formats', ['json'])

        logger.info(f"开始处理视频: {video_path}")
        start_time = time.time()

        try:
            # Run OCR over the video.
            results = self.extractor.extract_subtitles_from_video(
                str(video_path),
                sample_interval=kwargs.get('interval', 30),
                confidence_threshold=kwargs.get('confidence', 0.5),
                subtitle_position=kwargs.get('position', 'bottom'),
            )

            # Save one file per requested format.
            output_path = Path(output_dir)
            output_path.mkdir(parents=True, exist_ok=True)
            output_files = []
            for format_type in formats:
                output_file = output_path / f"{video_name}_subtitles.{format_type}"
                self.extractor.save_results(results, output_file, format_type)
                output_files.append(str(output_file))

            process_time = time.time() - start_time
            results['process_time'] = process_time
            results['video_path'] = str(video_path)
            results['success'] = True

            # Bounding-box (position) statistics.
            subtitles = results['subtitles']
            subtitles_with_bbox = [s for s in subtitles if s.get('bbox')]
            bbox_coverage = (len(subtitles_with_bbox) / len(subtitles) * 100
                             if subtitles else 0)

            logger.info(f"完成处理视频: {video_path} (耗时: {process_time:.2f}秒)")
            logger.info(f" 字幕总数: {len(subtitles)}")
            logger.info(f" 有位置信息: {len(subtitles_with_bbox)}")
            logger.info(f" 位置信息覆盖率: {bbox_coverage:.1f}%")

            return {
                'video_path': str(video_path),
                'success': True,
                'process_time': process_time,
                'subtitle_count': results['stats']['filtered_detections'],
                'text_length': results['stats']['text_length'],
                'total_subtitles': len(subtitles),
                'subtitles_with_bbox': len(subtitles_with_bbox),
                'bbox_coverage': bbox_coverage,
                'output_files': output_files,
            }

        except Exception as e:
            error_msg = f"处理视频 {video_path} 时出错: {str(e)}"
            # logger.exception also records the traceback, which is otherwise
            # lost because the error is swallowed into the result dict.
            logger.exception(error_msg)
            return {
                'video_path': str(video_path),
                'success': False,
                'error': error_msg,
                'process_time': time.time() - start_time,
            }

    def _process_parallel(self, video_files, output_dir, **kwargs):
        """Process *video_files* on a thread pool; results are in completion order."""
        results = []
        total = len(video_files)
        logger.info(f"使用 {self.max_workers} 个并行工作进程")
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            future_to_video = {
                executor.submit(self.extract_single_video, video_file, output_dir, **kwargs): video_file
                for video_file in video_files
            }
            for future in as_completed(future_to_video):
                video_file = future_to_video[future]
                try:
                    result = future.result()
                    results.append(result)
                    progress = len(results) / total * 100
                    logger.info(f"批量处理进度: {progress:.1f}% ({len(results)}/{total})")
                except Exception as e:
                    # extract_single_video catches its own exceptions; this is
                    # a last-resort guard so one bad future cannot abort the batch.
                    logger.error(f"处理视频 {video_file} 时发生异常: {str(e)}")
                    results.append({
                        'video_path': str(video_file),
                        'success': False,
                        'error': str(e),
                    })
        return results

    def _process_serial(self, video_files, output_dir, **kwargs):
        """Process *video_files* one after another, in the given order."""
        results = []
        total = len(video_files)
        for i, video_file in enumerate(video_files, 1):
            logger.info(f"处理第 {i}/{total} 个视频")
            results.append(self.extract_single_video(video_file, output_dir, **kwargs))
            progress = i / total * 100
            logger.info(f"批量处理进度: {progress:.1f}%")
        return results

    def extract_batch(self, input_dir, output_dir, parallel=True, **kwargs):
        """
        Extract subtitles from every video at *input_dir*.

        Args:
            input_dir: Input directory or single video file.
            output_dir: Output directory (created if missing).
            parallel: Use a thread pool when there is more than one video.
            **kwargs: Forwarded to ``extract_single_video``.

        Returns:
            dict: Batch summary with per-video 'results'. When no video is
            found, 'success' is False and no report is written. Otherwise
            the summary is also saved to ``<output_dir>/batch_report.json``.
        """
        logger.info("开始批量字幕提取")
        logger.info(f"输入: {input_dir}")
        logger.info(f"输出目录: {output_dir}")
        logger.info(f"OCR引擎: {self.ocr_engine}")
        logger.info(f"字幕位置: {kwargs.get('position', 'bottom')}")
        logger.info(f"并行处理: {parallel}")

        start_time = time.time()

        video_files = self.find_video_files(input_dir)
        if not video_files:
            logger.warning(f"在 {input_dir} 中未找到视频文件")
            return {
                'success': False,
                'message': '未找到视频文件',
                'total_files': 0,
                'results': [],
            }

        logger.info(f"找到 {len(video_files)} 个视频文件")

        if parallel and len(video_files) > 1:
            results = self._process_parallel(video_files, output_dir, **kwargs)
        else:
            results = self._process_serial(video_files, output_dir, **kwargs)

        total_time = time.time() - start_time

        # Aggregate statistics over the successful videos.
        succeeded = [r for r in results if r['success']]
        success_count = len(succeeded)
        failed_count = len(results) - success_count
        total_subtitles = sum(r.get('subtitle_count', 0) for r in succeeded)
        total_text_length = sum(r.get('text_length', 0) for r in succeeded)

        # Position (bounding-box) statistics.
        total_subtitles_raw = sum(r.get('total_subtitles', 0) for r in succeeded)
        total_subtitles_with_bbox = sum(r.get('subtitles_with_bbox', 0) for r in succeeded)
        overall_bbox_coverage = (total_subtitles_with_bbox / total_subtitles_raw * 100
                                 if total_subtitles_raw > 0 else 0)

        batch_result = {
            'success': True,
            'total_time': total_time,
            'total_files': len(video_files),
            'success_count': success_count,
            'failed_count': failed_count,
            'total_subtitles': total_subtitles,
            'total_text_length': total_text_length,
            'total_subtitles_raw': total_subtitles_raw,
            'total_subtitles_with_bbox': total_subtitles_with_bbox,
            'overall_bbox_coverage': overall_bbox_coverage,
            # str() so a Path argument does not break json.dump below.
            'output_directory': str(output_dir),
            'ocr_engine': self.ocr_engine,
            'timestamp': datetime.now().isoformat(),
            'results': results,
        }

        # Save the batch report. The output directory is otherwise only
        # created by a video that gets past OCR, so if every video failed the
        # directory may not exist yet; create it here.
        report_file = Path(output_dir) / "batch_report.json"
        report_file.parent.mkdir(parents=True, exist_ok=True)
        with open(report_file, 'w', encoding='utf-8') as f:
            json.dump(batch_result, f, ensure_ascii=False, indent=2)

        logger.info("批量处理完成!")
        logger.info(f"总文件数: {len(video_files)}")
        logger.info(f"成功: {success_count}, 失败: {failed_count}")
        logger.info(f"总耗时: {total_time:.2f} 秒")
        logger.info(f"提取字幕: {total_subtitles} 个")
        logger.info(f"文本长度: {total_text_length} 字符")
        logger.info("位置信息统计:")
        logger.info(f" 总字幕数: {total_subtitles_raw}")
        logger.info(f" 有位置信息: {total_subtitles_with_bbox}")
        logger.info(f" 位置信息覆盖率: {overall_bbox_coverage:.1f}%")
        logger.info(f"处理报告: {report_file}")

        return batch_result


def main():
    """Parse command-line arguments and run a batch subtitle extraction."""
    parser = argparse.ArgumentParser(description="批量视频字幕提取器")
    parser.add_argument("input", help="输入视频文件或目录")
    # The help text previously claimed the default was paddleocr, which did
    # not match default="cnocr".
    parser.add_argument("-e", "--engine", default="cnocr",
                        choices=["paddleocr", "easyocr", "cnocr", "all"],
                        help="OCR引擎 (默认: cnocr)")
    parser.add_argument("-l", "--language", default="ch",
                        choices=["ch", "en", "ch_en"],
                        help="语言设置 (默认: ch)")
    parser.add_argument("-i", "--interval", type=int, default=30,
                        help="帧采样间隔 (默认: 30)")
    parser.add_argument("-c", "--confidence", type=float, default=0.5,
                        help="置信度阈值 (默认: 0.5)")
    parser.add_argument("-o", "--output", default="batch_subtitles",
                        help="输出目录 (默认: batch_subtitles)")
    parser.add_argument("-f", "--formats", nargs='+', default=["json"],
                        choices=["json", "txt", "srt"],
                        help="输出格式 (默认: json)")
    parser.add_argument("--position", default="full",
                        choices=["full", "center", "bottom"],
                        help="字幕区域位置 (full=全屏, center=居中0.5-0.8, bottom=居下0.7-1.0)")
    parser.add_argument("--workers", type=int, default=2,
                        help="并行工作进程数 (默认: 2)")
    parser.add_argument("--no-parallel", action="store_true",
                        help="禁用并行处理")

    args = parser.parse_args()

    batch_extractor = BatchSubtitleExtractor(
        ocr_engine=args.engine,
        language=args.language,
        max_workers=args.workers,
    )

    try:
        result = batch_extractor.extract_batch(
            input_dir=args.input,
            output_dir=args.output,
            parallel=not args.no_parallel,
            interval=args.interval,
            confidence=args.confidence,
            formats=args.formats,
            position=args.position,
        )

        if result['success']:
            print(f"\n✅ 批量字幕提取完成!")
            print(f"📁 输出目录: {args.output}")
            print(f"📊 成功处理: {result['success_count']}/{result['total_files']} 个视频")
            if result['failed_count'] > 0:
                print(f"❌ 失败: {result['failed_count']} 个")
            print(f"⏱️ 总耗时: {result['total_time']:.2f} 秒")
            print(f"📝 字幕片段: {result['total_subtitles']} 个")
            print(f"📏 文本长度: {result['total_text_length']} 字符")
            print(f"📍 位置信息统计:")
            print(f" 总字幕数: {result['total_subtitles_raw']}")
            print(f" 有位置信息: {result['total_subtitles_with_bbox']}")
            print(f" 位置信息覆盖率: {result['overall_bbox_coverage']:.1f}%")
        else:
            print(f"\n❌ 批量处理失败: {result.get('message', '未知错误')}")

    except Exception as e:
        logger.error(f"批量处理出错: {str(e)}")
        print(f"\n❌ 批量处理出错: {str(e)}")


if __name__ == "__main__":
    main()