#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Professional video subtitle OCR extractor.

Supports two OCR engines, PaddleOCR and EasyOCR (or both at once). Input can be
either a video file (every N-th frame is sampled) or a directory of frame
images. Results can be saved as JSON, plain text or SRT.
"""

import argparse
import json
import logging
import math
import os
import sys
import time
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Iterator, List, Optional, Tuple

import numpy as np

# Module-wide logging configuration (this file is primarily a CLI script).
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


def _import_cv2():
    """Import OpenCV lazily, mirroring how the OCR engines are loaded.

    Returns:
        The ``cv2`` module.

    Raises:
        ImportError: if OpenCV is not installed (an install hint is logged first).
    """
    try:
        import cv2
    except ImportError:
        logger.error("请安装OpenCV: pip install opencv-python")
        raise
    return cv2


def convert_numpy_types(obj: Any) -> Any:
    """Recursively convert NumPy values (and sets) into JSON-serialisable built-ins.

    Args:
        obj: Any value. dicts, lists, tuples, sets and frozensets are walked recursively.

    Returns:
        An equivalent structure made only of built-in types:
        - NumPy bools, integers and floats become ``bool``, ``int`` and ``float``.
        - Arrays become (nested) lists.
        - Tuples and sets become lists.
        - Anything else is returned unchanged.
    """
    # np.bool_ is not a subclass of np.integer, so it needs its own branch.
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, dict):
        return {key: convert_numpy_types(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple, set, frozenset)):
        return [convert_numpy_types(item) for item in obj]
    return obj


class VideoSubtitleExtractor:
    """Extract on-screen subtitle text from videos or frame images using OCR."""

    # File suffixes (compared case-insensitively) picked up by extract_subtitles_from_frames.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp')

    def __init__(self, ocr_engine="paddleocr", language="ch"):
        """Create the extractor and immediately load the requested OCR engine(s).

        Args:
            ocr_engine: "paddleocr", "easyocr" or "both".
            language: "ch", "en" or "ch_en" (mixed Chinese/English).

        Raises:
            ImportError: if a requested OCR engine package cannot be imported.
        """
        self.ocr_engine = ocr_engine
        self.language = language
        self.paddle_ocr = None
        self.easy_ocr = None
        self.load_ocr_engines()

    def load_ocr_engines(self):
        """Instantiate PaddleOCR and/or EasyOCR according to ``self.ocr_engine``.

        Raises:
            ImportError: if the package for a requested engine is missing
                (an install hint is logged before re-raising).
        """
        logger.info(f"正在加载OCR引擎: {self.ocr_engine}")

        if self.ocr_engine in ["paddleocr", "both"]:
            try:
                from paddleocr import PaddleOCR

                # PaddleOCR's "ch" model already recognises mixed Chinese/English text.
                paddle_lang_map = {"ch": "ch", "en": "en", "ch_en": "ch"}
                paddle_lang = paddle_lang_map.get(self.language, "ch")
                self.paddle_ocr = PaddleOCR(
                    use_textline_orientation=True,
                    lang=paddle_lang
                )
                logger.info("PaddleOCR加载完成")
            except ImportError:
                logger.error("请安装PaddleOCR: pip install paddleocr")
                raise

        if self.ocr_engine in ["easyocr", "both"]:
            try:
                import easyocr

                # Unknown language settings fall back to Simplified Chinese.
                easy_lang_map = {
                    "ch": ['ch_sim'],
                    "en": ['en'],
                    "ch_en": ['ch_sim', 'en'],
                }
                lang_list = easy_lang_map.get(self.language, ['ch_sim'])
                self.easy_ocr = easyocr.Reader(lang_list)
                logger.info("EasyOCR加载完成")
            except ImportError:
                logger.error("请安装EasyOCR: pip install easyocr")
                raise

    @staticmethod
    def _normalize_region(subtitle_region) -> Optional[Tuple[int, int, int, int]]:
        """Validate a crop region and return it as an int 4-tuple, or None for "whole frame".

        Args:
            subtitle_region: None, an empty sequence, or (x1, y1, x2, y2) in pixels.

        Raises:
            ValueError: if the region does not have exactly 4 values, has negative
                coordinates, or has x2 <= x1 or y2 <= y1. Negative values would
                otherwise be treated as Python's from-the-end slice indices.
        """
        if subtitle_region is None or len(subtitle_region) == 0:
            return None
        if len(subtitle_region) != 4:
            raise ValueError(f"字幕区域必须包含4个坐标 (x1, y1, x2, y2): {subtitle_region}")
        x1, y1, x2, y2 = (int(value) for value in subtitle_region)
        if x1 < 0 or y1 < 0 or x2 <= x1 or y2 <= y1:
            raise ValueError(f"无效的字幕区域: {subtitle_region}")
        return x1, y1, x2, y2

    def extract_subtitles_from_video(self, video_path: str, sample_interval: int = 30,
                                     confidence_threshold: float = 0.5,
                                     subtitle_region: Optional[Tuple] = None) -> Dict:
        """Run OCR on every ``sample_interval``-th frame of a video.

        Args:
            video_path: Path to the video file.
            sample_interval: Run OCR on one frame out of every N frames (N >= 1).
            confidence_threshold: Discard detections whose score is below this value.
            subtitle_region: Optional crop (x1, y1, x2, y2) in pixels. None means the
                full frame. Coordinates past the frame edge are clipped by slicing.

        Returns:
            A result dict with video metadata, the post-processed "subtitles" list,
            "unique_texts", "continuous_text", "stats" and "extract_time".

        Raises:
            FileNotFoundError: if ``video_path`` does not exist.
            ValueError: for a non-positive ``sample_interval``, an invalid region,
                or a video whose FPS cannot be determined.
            IOError: if OpenCV cannot open the video.
        """
        video_path = Path(video_path)
        if not video_path.exists():
            raise FileNotFoundError(f"视频文件不存在: {video_path}")
        if sample_interval < 1:
            raise ValueError(f"采样间隔必须为正整数: {sample_interval}")
        region = self._normalize_region(subtitle_region)
        cv2 = _import_cv2()

        logger.info(f"开始提取视频字幕: {video_path}")
        logger.info(f"采样间隔: {sample_interval}帧, 置信度阈值: {confidence_threshold}")
        start_time = time.time()

        cap = cv2.VideoCapture(str(video_path))
        try:
            if not cap.isOpened():
                raise IOError(f"无法打开视频文件: {video_path}")

            fps = cap.get(cv2.CAP_PROP_FPS)
            if not fps or not math.isfinite(fps) or fps <= 0:
                # Timestamps are frame_index / fps, so they are meaningless without a valid FPS.
                raise ValueError(f"无法获取有效的视频帧率: {fps}")
            # The frame count reported by the container can be 0, negative or NaN
            # when unknown; treat those cases as 0 (unknown).
            raw_count = cap.get(cv2.CAP_PROP_FRAME_COUNT)
            total_frames = int(raw_count) if math.isfinite(raw_count) and raw_count > 0 else 0
            duration = total_frames / fps

            logger.info(f"视频信息: {total_frames}帧, {fps:.2f}FPS, {duration:.2f}秒")

            results = {
                "video_path": str(video_path),
                "duration": duration,
                "fps": fps,
                "total_frames": total_frames,
                "sample_interval": sample_interval,
                "confidence_threshold": confidence_threshold,
                "subtitle_region": region,
                "subtitles": [],
                "unique_texts": [],
                "ocr_engine": self.ocr_engine,
                "timestamp": datetime.now().isoformat()
            }
            # dict used as an insertion-ordered set, so unique_texts keeps first-seen order.
            unique_texts: Dict[str, None] = {}

            frame_count = 0
            processed_frames = 0
            warned_empty_crop = False
            # read() == grab() + retrieve(). Skipped frames are only grabbed, so the
            # retrieve (decode/convert) step runs only for sampled frames.
            while cap.grab():
                if frame_count % sample_interval == 0:
                    ok, frame = cap.retrieve()
                    if ok and frame is not None:
                        timestamp = frame_count / fps
                        if region is None:
                            frame_crop = frame
                        else:
                            x1, y1, x2, y2 = region
                            frame_crop = frame[y1:y2, x1:x2]

                        if frame_crop.size == 0:
                            if not warned_empty_crop:
                                logger.warning(f"字幕区域超出画面范围,裁剪结果为空: {region}")
                                warned_empty_crop = True
                        else:
                            ocr_results = self._ocr_frame(frame_crop, timestamp, confidence_threshold)
                            results["subtitles"].extend(ocr_results)
                            unique_texts.update(dict.fromkeys(r["text"] for r in ocr_results))

                        processed_frames += 1
                        if processed_frames % 10 == 0:
                            if total_frames > 0:
                                progress = (frame_count / total_frames) * 100
                                logger.info(f"处理进度: {progress:.1f}% ({processed_frames}帧)")
                            else:
                                logger.info(f"已处理 {processed_frames}帧")

                frame_count += 1
        finally:
            cap.release()

        results["unique_texts"] = list(unique_texts)

        # Post-processing: sort, de-duplicate, build text and stats.
        results = self._postprocess_subtitles(results)

        extract_time = time.time() - start_time
        results["extract_time"] = extract_time

        logger.info(f"字幕提取完成,耗时: {extract_time:.2f}秒")
        logger.info(f"提取到 {len(results['subtitles'])} 个字幕片段")
        logger.info(f"唯一文本 {len(results['unique_texts'])} 条")

        return results

    @staticmethod
    def _iter_paddle_detections(paddle_result) -> Iterator[Tuple[Any, Any, Any]]:
        """Yield (bbox, text, confidence) triples from a PaddleOCR ``ocr()`` result.

        Two result layouts are handled:
        - Dict-like results (PaddleOCR 3.x): one per image, with parallel
          'rec_texts', 'rec_scores' and 'rec_polys' sequences.
          NOTE(review): confirm these keys against the installed PaddleOCR version.
        - List results (older PaddleOCR): one list per image of
          [bbox, (text, score)] entries, or [bbox, text, score] entries.
          An image with no text is represented as None.

        Malformed entries are skipped and logged at debug level.
        """
        if not paddle_result:
            return
        for page in paddle_result:
            if page is None:
                continue

            if hasattr(page, "keys"):
                # Iterating a dict-like page would yield its key strings, so read the
                # named fields instead.
                texts = page.get("rec_texts")
                scores = page.get("rec_scores")
                if texts is None or scores is None:
                    continue
                polys = page.get("rec_polys")
                if polys is None:
                    polys = [None] * len(texts)
                yield from zip(polys, texts, scores)
                continue

            for detection in page:
                if detection is None or isinstance(detection, str):
                    continue
                try:
                    size = len(detection)
                    if size == 2:
                        bbox, (text, confidence) = detection
                    elif size == 3:
                        bbox, text, confidence = detection
                    else:
                        continue
                except (TypeError, ValueError) as parse_error:
                    logger.debug(f"解析单个检测结果失败: {parse_error}, 数据: {detection}")
                    continue
                yield bbox, text, confidence

    @staticmethod
    def _make_record(timestamp: float, bbox, text, confidence, confidence_threshold: float,
                     engine: str) -> Optional[Dict]:
        """Build one detection record, or return None if it should be discarded.

        A detection is discarded when any of the following holds:
        - its text is empty after stripping;
        - its score is not numeric;
        - its score is below ``confidence_threshold`` (NaN scores also fail this check).
        """
        if text is None:
            return None
        cleaned = str(text).strip()
        try:
            score = float(confidence)
        except (TypeError, ValueError):
            return None
        if not cleaned or not score >= confidence_threshold:
            return None
        return {
            "timestamp": timestamp,
            "text": cleaned,
            "confidence": score,
            "bbox": bbox,
            "engine": engine
        }

    def _ocr_frame(self, frame: np.ndarray, timestamp: float, confidence_threshold: float) -> List[Dict]:
        """Run every loaded OCR engine on one image and return the filtered detections.

        The call to each engine is best-effort: a failure is logged as a warning
        and that engine contributes no results for this frame.

        Args:
            frame: Image array passed directly to the OCR engines.
            timestamp: Value stored in every record produced from this frame.
            confidence_threshold: Minimum score for a detection to be kept.

        Returns:
            A list of records with keys "timestamp", "text", "confidence", "bbox"
            and "engine". PaddleOCR results come first, then EasyOCR results.
        """
        ocr_results: List[Dict] = []

        if self.paddle_ocr is not None:
            try:
                paddle_result = self.paddle_ocr.ocr(frame)
                for bbox, text, confidence in self._iter_paddle_detections(paddle_result):
                    record = self._make_record(timestamp, bbox, text, confidence,
                                               confidence_threshold, "PaddleOCR")
                    if record is not None:
                        ocr_results.append(record)
            except Exception as e:
                logger.warning(f"PaddleOCR识别失败 (时间戳:{timestamp:.2f}): {e}")

        if self.easy_ocr is not None:
            try:
                for bbox, text, confidence in self.easy_ocr.readtext(frame):
                    record = self._make_record(timestamp, bbox, text, confidence,
                                               confidence_threshold, "EasyOCR")
                    if record is not None:
                        ocr_results.append(record)
            except Exception as e:
                logger.warning(f"EasyOCR识别失败 (时间戳:{timestamp:.2f}): {e}")

        return ocr_results

    def _postprocess_subtitles(self, results: Dict, dedup_window: float = 5.0) -> Dict:
        """Sort and de-duplicate detections, then add derived text and statistics.

        A detection is dropped when the same text was already seen within
        ``dedup_window`` seconds. "Seen" counts every detection of that text, kept
        or not, so a subtitle that stays on screen yields a single entry.
        Detections are tracked per text, which means multi-line subtitles and
        duplicates from a second engine are also removed.

        Args:
            results: Dict containing "subtitles" (records with "timestamp", "text"
                and "confidence") and "unique_texts".
            dedup_window: Gap in seconds after which a repeated text counts as a
                new occurrence.

        Returns:
            The same dict, updated in place. "subtitles" is replaced by the filtered
            list, and "filtered_count", "continuous_text" and "stats" are added.
        """
        # sorted() is stable, so same-timestamp detections keep their engine order.
        subtitles = sorted(results["subtitles"], key=lambda s: s["timestamp"])

        filtered_subtitles = []
        last_seen: Dict[str, float] = {}
        for subtitle in subtitles:
            text = subtitle["text"]
            timestamp = subtitle["timestamp"]
            previous = last_seen.get(text)
            if previous is None or (timestamp - previous) > dedup_window:
                filtered_subtitles.append(subtitle)
            last_seen[text] = timestamp

        results["subtitles"] = filtered_subtitles
        results["filtered_count"] = len(subtitles) - len(filtered_subtitles)
        results["continuous_text"] = " ".join(s["text"] for s in filtered_subtitles)
        results["stats"] = {
            "total_detections": len(subtitles),
            "filtered_detections": len(filtered_subtitles),
            "unique_texts": len(results["unique_texts"]),
            "text_length": len(results["continuous_text"]),
            "average_confidence": (
                float(np.mean([s["confidence"] for s in filtered_subtitles]))
                if filtered_subtitles else 0.0
            )
        }
        return results

    def extract_subtitles_from_frames(self, frame_dir: str, confidence_threshold: float = 0.5) -> Dict:
        """Run OCR on every image file (jpg/jpeg/png/bmp, any case) in a directory.

        Images are processed in sorted path order. Each image's index in that order
        is used as its "timestamp", because loose images carry no timing information.

        Args:
            frame_dir: Directory containing the frame images (not searched recursively).
            confidence_threshold: Discard detections whose score is below this value.

        Returns:
            A result dict with the same derived keys as extract_subtitles_from_video
            ("subtitles", "unique_texts", "continuous_text", "stats", "extract_time").
            This dict is returned, with empty lists and stats, even when the
            directory contains no images.

        Raises:
            FileNotFoundError: if ``frame_dir`` does not exist.
            NotADirectoryError: if ``frame_dir`` is not a directory.
        """
        frame_dir = Path(frame_dir)
        if not frame_dir.exists():
            raise FileNotFoundError(f"帧目录不存在: {frame_dir}")
        if not frame_dir.is_dir():
            raise NotADirectoryError(f"不是目录: {frame_dir}")

        logger.info(f"开始从帧目录提取字幕: {frame_dir}")
        start_time = time.time()

        # A single pass with a lower-cased suffix check avoids globbing "*.jpg" and
        # "*.JPG" separately, which returns every file twice on case-insensitive
        # filesystems.
        image_files = sorted(
            path for path in frame_dir.iterdir()
            if path.is_file() and path.suffix.lower() in self.IMAGE_EXTENSIONS
        )

        results = {
            "frame_dir": str(frame_dir),
            "total_images": len(image_files),
            "confidence_threshold": confidence_threshold,
            "subtitles": [],
            "unique_texts": [],
            "ocr_engine": self.ocr_engine,
            "timestamp": datetime.now().isoformat()
        }

        if not image_files:
            logger.warning(f"在目录 {frame_dir} 中未找到图片文件")
        else:
            logger.info(f"找到 {len(image_files)} 个图片文件")
            cv2 = _import_cv2()
            unique_texts: Dict[str, None] = {}  # insertion-ordered set

            for index, image_file in enumerate(image_files):
                try:
                    frame = cv2.imread(str(image_file))
                    if frame is None:
                        logger.warning(f"无法读取图片: {image_file}")
                    else:
                        timestamp = float(index)
                        ocr_results = self._ocr_frame(frame, timestamp, confidence_threshold)
                        results["subtitles"].extend(ocr_results)
                        unique_texts.update(dict.fromkeys(r["text"] for r in ocr_results))
                except Exception as e:
                    # Best-effort: one bad image should not abort the whole directory.
                    logger.warning(f"处理图片 {image_file} 时出错: {e}")

                if (index + 1) % 10 == 0:
                    progress = ((index + 1) / len(image_files)) * 100
                    logger.info(f"处理进度: {progress:.1f}%")

            results["unique_texts"] = list(unique_texts)

        results = self._postprocess_subtitles(results)
        results["extract_time"] = time.time() - start_time

        logger.info(f"字幕提取完成,提取到 {len(results['subtitles'])} 个字幕片段")
        return results

    def save_results(self, results: Dict, output_path: str, format: str = "json"):
        """Write extraction results to ``output_path``.

        Args:
            results: Dict returned by one of the extract_* methods.
            output_path: Destination file path.
            format: One of the following:
                - "json": the full dict, NumPy values converted to built-ins.
                - "txt": only "continuous_text".
                - "srt": subtitle cues.

        Raises:
            ValueError: for an unsupported format. No file is written in that case.
        """
        output_path = Path(output_path)

        if format == "json":
            # convert_numpy_types builds a new structure (sets/tuples -> lists,
            # NumPy -> built-ins), so ``results`` itself is not modified.
            payload = convert_numpy_types(results)
            with open(output_path, 'w', encoding='utf-8') as f:
                json.dump(payload, f, ensure_ascii=False, indent=2)
        elif format == "txt":
            with open(output_path, 'w', encoding='utf-8') as f:
                f.write(results.get("continuous_text", ""))
        elif format == "srt":
            self._save_srt(results, output_path)
        else:
            raise ValueError(f"不支持的格式: {format}")

        logger.info(f"结果已保存: {output_path}")

    def _save_srt(self, results: Dict, output_path: Path, default_duration: float = 3.0):
        """Write the detections as an SRT subtitle file.

        Detections sharing a timestamp are merged into one cue, with duplicate
        texts removed and one text per line. Each cue lasts ``default_duration``
        seconds, but ends no later than the start of the next cue so cues never
        overlap.
        """
        cues: List[Tuple[float, List[str]]] = []
        for subtitle in sorted(results["subtitles"], key=lambda s: s["timestamp"]):
            start = float(subtitle["timestamp"])
            text = subtitle["text"]
            if cues and cues[-1][0] == start:
                if text not in cues[-1][1]:
                    cues[-1][1].append(text)
            else:
                cues.append((start, [text]))

        with open(output_path, 'w', encoding='utf-8') as f:
            for index, (start, texts) in enumerate(cues):
                end = start + default_duration
                if index + 1 < len(cues):
                    end = min(end, cues[index + 1][0])
                f.write(f"{index + 1}\n")
                f.write(f"{self._format_time(start)} --> {self._format_time(end)}\n")
                f.write("\n".join(texts) + "\n\n")

    def _format_time(self, seconds):
        """Format a time in seconds as an SRT timestamp "HH:MM:SS,mmm".

        The value is first rounded to whole milliseconds, so a value such as
        59.9996 carries over to "00:01:00,000" instead of producing an invalid
        seconds field of 60. Negative inputs are clamped to zero.
        """
        total_ms = max(0, int(round(float(seconds) * 1000)))
        hours, remainder = divmod(total_ms, 3_600_000)
        minutes, remainder = divmod(remainder, 60_000)
        secs, millis = divmod(remainder, 1000)
        return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}"


def main() -> int:
    """Command-line entry point.

    Returns:
        The process exit status:
        - 0 on success;
        - 1 when the input path is invalid or extraction/saving fails.
    """
    parser = argparse.ArgumentParser(description="专业视频字幕OCR提取器")
    parser.add_argument("input", help="输入视频文件路径或帧图片目录")
    parser.add_argument("-e", "--engine", default="paddleocr",
                        choices=["paddleocr", "easyocr", "both"],
                        help="OCR引擎 (默认: paddleocr)")
    parser.add_argument("-l", "--language", default="ch",
                        choices=["ch", "en", "ch_en"],
                        help="语言设置 (默认: ch)")
    parser.add_argument("-i", "--interval", type=int, default=30,
                        help="帧采样间隔 (默认: 30)")
    parser.add_argument("-c", "--confidence", type=float, default=0.5,
                        help="置信度阈值 (默认: 0.5)")
    parser.add_argument("-o", "--output", default="subtitle_results",
                        help="输出目录 (默认: subtitle_results)")
    parser.add_argument("-f", "--format", default="json",
                        choices=["json", "txt", "srt"],
                        help="输出格式 (默认: json)")
    parser.add_argument("--region", nargs=4, type=int, metavar=('X1', 'Y1', 'X2', 'Y2'),
                        help="字幕区域坐标 (x1 y1 x2 y2)")

    args = parser.parse_args()

    input_path = Path(args.input)
    # Validate the input before loading OCR models, which can take a long time.
    if not (input_path.is_file() or input_path.is_dir()):
        logger.error(f"输入路径无效: {args.input}")
        return 1

    extractor = VideoSubtitleExtractor(
        ocr_engine=args.engine,
        language=args.language
    )

    try:
        if input_path.is_file():
            logger.info("处理视频文件")
            results = extractor.extract_subtitles_from_video(
                args.input,
                sample_interval=args.interval,
                confidence_threshold=args.confidence,
                subtitle_region=tuple(args.region) if args.region else None
            )
        else:
            logger.info("处理帧图片目录")
            results = extractor.extract_subtitles_from_frames(
                args.input,
                confidence_threshold=args.confidence
            )

        output_dir = Path(args.output)
        output_dir.mkdir(parents=True, exist_ok=True)
        output_file = output_dir / f"{input_path.stem}_subtitles.{args.format}"
        extractor.save_results(results, output_file, args.format)

        stats = results["stats"]
        print("\n字幕提取完成!")
        print(f"输入: {args.input}")
        print(f"输出: {output_file}")
        print(f"提取时长: {results.get('extract_time', 0):.2f} 秒")
        print(f"字幕片段: {stats['filtered_detections']} 个")
        print(f"唯一文本: {stats['unique_texts']} 条")
        print(f"平均置信度: {stats['average_confidence']:.3f}")
        print(f"文本总长度: {stats['text_length']} 字符")

        continuous_text = results.get("continuous_text")
        if continuous_text:
            print("\n提取的连续文本:")
            print("-" * 50)
            print(continuous_text[:500] + "..." if len(continuous_text) > 500 else continuous_text)

    except Exception as e:
        # Top-level boundary: log with traceback and report failure via exit status.
        logger.exception(f"程序执行出错: {str(e)}")
        return 1

    return 0


if __name__ == "__main__":
    sys.exit(main())