#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Video subtitle OCR extractor.

Samples frames from a video (or reads a directory of frame images), crops the
configured subtitle band, and runs one or more OCR engines (PaddleOCR,
EasyOCR, CnOCR) over it. Results can be saved as JSON, plain text or SRT.
"""
import cv2
import os
import time
import json
import argparse
from pathlib import Path
from datetime import datetime
import logging
import numpy as np
from typing import List, Dict, Tuple

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Default on-disk model locations (overridable per VideoSubtitleExtractor instance).
DEFAULT_EASYOCR_MODEL_DIR = '/root/autodl-tmp/llm/easyocr'
DEFAULT_CNOCR_MODEL_DIR = "/root/autodl-tmp/llm/cnocr"

# Image suffixes accepted by extract_subtitles_from_frames (compared case-insensitively).
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp')


def convert_numpy_types(obj):
    """Recursively convert NumPy values into JSON-serializable Python types.

    NumPy booleans/integers/floats become bool/int/float, arrays become nested
    lists, and dicts, lists, tuples and sets are walked recursively (tuples
    and sets become lists). Any other object is returned unchanged.
    """
    # np.bool_ is not an np.integer subclass, so it needs its own branch.
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, (list, tuple, set, frozenset)):
        return [convert_numpy_types(item) for item in obj]
    if isinstance(obj, dict):
        return {key: convert_numpy_types(value) for key, value in obj.items()}
    return obj


class VideoSubtitleExtractor:
    """Extract burned-in subtitles from a video or a directory of frames using OCR.

    Attributes:
        ocr_engine: requested engine name ("paddleocr", "easyocr", "cnocr", "all").
        language: language setting ("ch", "en", "ch_en").
        paddle_ocr / easy_ocr / cn_ocr: loaded engine instances, or None.
    """

    # (top, bottom) of the OCR band as fractions of the frame height, per position.
    REGION_RATIOS = {
        "full": (0.0, 1.0),
        "center": (0.5, 0.8),
        "bottom": (0.5, 1.0),
    }
    # The same text detected again within this many seconds is treated as a duplicate.
    DUPLICATE_WINDOW_SECONDS = 5.0
    # Fixed display duration (seconds) assigned to every SRT cue.
    SRT_CUE_DURATION = 3.0

    def __init__(self, ocr_engine="paddleocr", language="ch",
                 easyocr_model_dir=DEFAULT_EASYOCR_MODEL_DIR,
                 cnocr_model_dir=DEFAULT_CNOCR_MODEL_DIR):
        """
        Initialize the extractor and load the requested OCR engine(s).

        Args:
            ocr_engine: OCR engine ("paddleocr", "easyocr", "cnocr", "all")
            language: language setting ("ch", "en", "ch_en")
            easyocr_model_dir: model_storage_directory passed to easyocr.Reader
            cnocr_model_dir: directory created and exported as CNOCR_HOME

        Raises:
            Exception: whatever loading raised, when a single engine was
                requested and it failed to load.
            RuntimeError: when no engine could be loaded at all.
        """
        self.ocr_engine = ocr_engine
        self.language = language
        self.easyocr_model_dir = easyocr_model_dir
        self.cnocr_model_dir = cnocr_model_dir
        self.paddle_ocr = None
        self.easy_ocr = None
        self.cn_ocr = None
        self.load_ocr_engines()

    def load_ocr_engines(self):
        """Load every engine selected by ``self.ocr_engine``.

        A failure is re-raised only when that engine was requested on its own;
        in "all" mode it is logged and the remaining engines are still tried.

        Raises:
            RuntimeError: if no engine ends up loaded (including an unknown
                ``ocr_engine`` value).
        """
        logger.info(f"正在加载OCR引擎: {self.ocr_engine}")

        if self.ocr_engine in ["paddleocr", "all"]:
            self._load_paddleocr()
        if self.ocr_engine in ["easyocr", "all"]:
            self._load_easyocr()
        if self.ocr_engine in ["cnocr", "all"]:
            self._load_cnocr()

        if all(engine is None for engine in (self.paddle_ocr, self.easy_ocr, self.cn_ocr)):
            raise RuntimeError(f"没有可用的OCR引擎: {self.ocr_engine}")

    def _load_paddleocr(self):
        """Load PaddleOCR, supporting both the 2.x and 3.x constructors."""
        try:
            from paddleocr import PaddleOCR

            # PaddleOCR's "ch" model also covers mixed Chinese/English text.
            paddle_lang_map = {"ch": "ch", "en": "en", "ch_en": "ch"}
            paddle_lang = paddle_lang_map.get(self.language, "ch")
            try:
                self.paddle_ocr = PaddleOCR(
                    use_textline_orientation=True,
                    lang=paddle_lang,
                    show_log=False,  # quieter logs; only understood by PaddleOCR 2.x
                )
            except (TypeError, ValueError):
                # PaddleOCR 3.x rejects the removed `show_log` argument; retry without it.
                self.paddle_ocr = PaddleOCR(
                    use_textline_orientation=True,
                    lang=paddle_lang,
                )
            logger.info("PaddleOCR加载完成")
        except ImportError:
            logger.error("请安装PaddleOCR: pip install paddleocr")
            if self.ocr_engine == "paddleocr":
                raise
        except Exception as e:
            logger.error(f"PaddleOCR初始化失败: {e}")
            if self.ocr_engine == "paddleocr":
                raise

    def _load_easyocr(self):
        """Load EasyOCR with the language list mapped from ``self.language``."""
        try:
            import easyocr

            lang_map = {
                "ch": ['ch_sim'],            # Simplified Chinese
                "en": ['en'],
                "ch_en": ['ch_sim', 'en'],   # mixed Chinese/English
            }
            lang_list = lang_map.get(self.language, ['ch_sim'])

            logger.info(f"EasyOCR模型路径: {self.easyocr_model_dir}")
            self.easy_ocr = easyocr.Reader(
                lang_list,
                model_storage_directory=self.easyocr_model_dir
            )
            logger.info("EasyOCR加载完成")
        except ImportError:
            logger.error("请安装EasyOCR: pip install easyocr")
            if self.ocr_engine == "easyocr":
                raise
        except Exception as e:
            logger.error(f"EasyOCR初始化失败: {e}")
            if self.ocr_engine == "easyocr":
                raise

    def _load_cnocr(self):
        """Load CnOCR with its default model configuration."""
        try:
            cn_ocr_model_dir = self.cnocr_model_dir
            os.makedirs(cn_ocr_model_dir, exist_ok=True)
            # Export CNOCR_HOME *before* importing cnocr, so the setting is seen
            # even if the package resolves its data directory at import time.
            os.environ['CNOCR_HOME'] = cn_ocr_model_dir
            from cnocr import CnOcr

            logger.info(f"CnOCR模型路径: {cn_ocr_model_dir}")
            logger.info("使用CnOCR默认模型配置")
            self.cn_ocr = CnOcr()
            logger.info("CnOCR加载完成")
        except ImportError:
            logger.error("请安装CnOCR: pip install cnocr")
            if self.ocr_engine == "cnocr":
                raise
        except Exception as e:
            logger.error(f"CnOCR初始化失败: {e}")
            if self.ocr_engine == "cnocr":
                raise

    def calculate_subtitle_region(self, frame_width: int, frame_height: int,
                                  position: str = "bottom") -> Tuple[int, int, int, int]:
        """
        Compute the subtitle band to OCR for a given frame size and position.

        Args:
            frame_width: frame width in pixels
            frame_height: frame height in pixels
            position: "full" (whole frame), "center" (rows 0.5-0.8 of the
                height) or "bottom" (rows 0.5-1.0); unknown values fall back
                to "full".

        Returns:
            Tuple: (x1, y1, x2, y2) region coordinates; always full width.
        """
        top, bottom = self.REGION_RATIOS.get(position, self.REGION_RATIOS["full"])
        region = (0, int(frame_height * top), frame_width, int(frame_height * bottom))

        logger.info(f"字幕区域 ({position}): {region}, 视频尺寸: {frame_width}x{frame_height}")
        return region

    def extract_subtitles_from_video(self, video_path: str, sample_interval: int = 30,
                                     confidence_threshold: float = 0.5,
                                     subtitle_position: str = "bottom") -> Dict:
        """
        Extract subtitles by OCR-ing every ``sample_interval``-th frame of a video.

        Args:
            video_path: path to the video file
            sample_interval: sampling interval in frames (must be >= 1)
            confidence_threshold: detections below this score are dropped
            subtitle_position: subtitle band ("full", "center", "bottom")

        Returns:
            Dict: video metadata, post-processed "subtitles", "unique_texts",
            "continuous_text", "stats" and "extract_time" (seconds).

        Raises:
            ValueError: sample_interval < 1, or the video reports no usable FPS.
            FileNotFoundError: video_path does not exist.
            IOError: OpenCV cannot open the video.
        """
        if sample_interval < 1:
            raise ValueError(f"采样间隔必须是正整数: {sample_interval}")

        video_path = Path(video_path)
        if not video_path.exists():
            raise FileNotFoundError(f"视频文件不存在: {video_path}")

        logger.info(f"开始提取视频字幕: {video_path}")
        logger.info(f"采样间隔: {sample_interval}帧, 置信度阈值: {confidence_threshold}")

        start_time = time.time()

        cap = cv2.VideoCapture(str(video_path))
        try:
            if not cap.isOpened():
                raise IOError(f"无法打开视频文件: {video_path}")

            fps = cap.get(cv2.CAP_PROP_FPS)
            # Timestamps are frame_index / fps, so a zero/NaN FPS cannot be used.
            if not (fps > 0):
                raise ValueError(f"无法获取有效的视频帧率: {fps}")
            # Some containers report an unknown (0 or negative) frame count.
            total_frames = max(0, int(cap.get(cv2.CAP_PROP_FRAME_COUNT)))
            duration = total_frames / fps
            frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

            subtitle_region = self.calculate_subtitle_region(frame_width, frame_height, subtitle_position)

            logger.info(f"视频信息: {total_frames}帧, {fps:.2f}FPS, {duration:.2f}秒, 分辨率: {frame_width}x{frame_height}")

            results = {
                "video_path": str(video_path),
                "duration": duration,
                "fps": fps,
                "total_frames": total_frames,
                "frame_width": frame_width,
                "frame_height": frame_height,
                "sample_interval": sample_interval,
                "confidence_threshold": confidence_threshold,
                "subtitle_position": subtitle_position,
                "subtitle_region": subtitle_region,
                "subtitles": [],
                "unique_texts": [],
                "ocr_engine": self.ocr_engine,
                "timestamp": datetime.now().isoformat()
            }

            x1, y1, x2, y2 = subtitle_region
            frame_count = 0
            processed_frames = 0
            # grab() only fetches the next frame; retrieve() (which decodes it)
            # is called just for sampled frames.
            while cap.grab():
                if frame_count % sample_interval == 0:
                    ret, frame = cap.retrieve()
                    if ret:
                        timestamp = frame_count / fps
                        frame_crop = frame[y1:y2, x1:x2]
                        results["subtitles"].extend(
                            self._ocr_frame(frame_crop, timestamp, confidence_threshold)
                        )
                        processed_frames += 1

                        if processed_frames % 10 == 0:
                            progress = (frame_count / total_frames) * 100 if total_frames > 0 else 0.0
                            logger.info(f"处理进度: {progress:.1f}% ({processed_frames}帧)")
                frame_count += 1
        finally:
            cap.release()

        # Unique texts in first-seen order (deterministic, unlike a set).
        results["unique_texts"] = list(dict.fromkeys(s["text"] for s in results["subtitles"]))

        results = self._postprocess_subtitles(results)

        extract_time = time.time() - start_time
        results["extract_time"] = extract_time

        logger.info(f"字幕提取完成,耗时: {extract_time:.2f}秒")
        logger.info(f"提取到 {len(results['subtitles'])} 个字幕片段")
        logger.info(f"唯一文本 {len(results['unique_texts'])} 条")

        return results

    def _ocr_frame(self, frame: np.ndarray, timestamp: float,
                   confidence_threshold: float) -> List[Dict]:
        """Run every loaded engine on one BGR frame and merge their detections.

        A failure in one engine is logged as a warning and does not stop the
        other engines.
        """
        ocr_results = []
        engines = (
            (self.paddle_ocr, "PaddleOCR", self._ocr_with_paddle),
            (self.easy_ocr, "EasyOCR", self._ocr_with_easyocr),
            (self.cn_ocr, "CnOCR", self._ocr_with_cnocr),
        )
        for engine, label, run in engines:
            if engine is None:
                continue
            try:
                ocr_results.extend(run(frame, timestamp, confidence_threshold))
            except Exception as e:
                logger.warning(f"{label}识别失败 (时间戳:{timestamp:.2f}): {e}")
        return ocr_results

    @staticmethod
    def _accept(text, confidence, threshold) -> bool:
        """Return True for non-blank string text whose score is >= threshold."""
        if not isinstance(text, str) or not text.strip():
            return False
        try:
            return float(confidence) >= threshold
        except (TypeError, ValueError):
            return False

    @staticmethod
    def _make_entry(timestamp, text, confidence, bbox, engine) -> Dict:
        """Build one detection record in the format stored under results["subtitles"]."""
        return {
            "timestamp": timestamp,
            "text": text.strip(),
            "confidence": float(confidence),
            "bbox": bbox,
            "engine": engine
        }

    @staticmethod
    def _iter_paddle_detections(page):
        """Yield (bbox, text, confidence) triples from one PaddleOCR page result."""
        # NOTE(review): PaddleOCR 3.x is assumed to return a dict-like result with
        # parallel "rec_polys" / "rec_texts" / "rec_scores" lists — confirm per version.
        if isinstance(page, dict) and "rec_texts" in page:
            texts = page["rec_texts"]
            polys = page.get("rec_polys")
            if polys is None:
                polys = [[] for _ in texts]
            yield from zip(polys, texts, page.get("rec_scores", []))
            return

        # PaddleOCR 2.x: a list of per-line detections.
        for detection in page:
            if not detection:
                continue
            try:
                if len(detection) == 2:
                    # Older format: [bbox, (text, confidence)]
                    bbox, (text, confidence) = detection
                elif len(detection) == 3:
                    # Alternative format: [bbox, text, confidence]
                    bbox, text, confidence = detection
                else:
                    continue
            except Exception as parse_error:
                logger.debug(f"解析单个检测结果失败: {parse_error}, 数据: {detection}")
                continue
            yield bbox, text, confidence

    def _ocr_with_paddle(self, frame, timestamp, confidence_threshold) -> List[Dict]:
        """Run PaddleOCR on a BGR frame and return the accepted detections."""
        paddle_result = self.paddle_ocr.ocr(frame)
        if not paddle_result or not paddle_result[0]:
            return []
        return [
            self._make_entry(timestamp, text, confidence, bbox, "PaddleOCR")
            for bbox, text, confidence in self._iter_paddle_detections(paddle_result[0])
            if self._accept(text, confidence, confidence_threshold)
        ]

    def _ocr_with_easyocr(self, frame, timestamp, confidence_threshold) -> List[Dict]:
        """Run EasyOCR (readtext yields (bbox, text, confidence)) on a frame."""
        return [
            self._make_entry(timestamp, text, confidence, bbox, "EasyOCR")
            for bbox, text, confidence in self.easy_ocr.readtext(frame)
            if self._accept(text, confidence, confidence_threshold)
        ]

    @staticmethod
    def _position_to_bbox(position):
        """Convert a CnOCR 'position' (ndarray or list) to a plain list; [] if absent."""
        if position is None:
            return []
        try:
            return position.tolist() if hasattr(position, 'tolist') else position
        except Exception as e:
            logger.debug(f"转换position失败: {e}")
            return []

    def _ocr_with_cnocr(self, frame, timestamp, confidence_threshold) -> List[Dict]:
        """Run CnOCR on a frame; results are dicts with 'text', 'score', 'position'."""
        # CnOcr.ocr documents ndarray input as RGB, while OpenCV frames are BGR.
        if frame.ndim == 3 and frame.shape[2] == 3:
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

        cn_result = self.cn_ocr.ocr(frame)
        if not isinstance(cn_result, list):
            return []

        entries = []
        for detection in cn_result:
            if not isinstance(detection, dict):
                continue
            text = detection.get('text', '')
            confidence = detection.get('score', 0.0)
            if self._accept(text, confidence, confidence_threshold):
                bbox = self._position_to_bbox(detection.get('position', None))
                entries.append(self._make_entry(timestamp, text, confidence, bbox, "CnOCR"))
        return entries

    def _postprocess_subtitles(self, results: Dict) -> Dict:
        """Sort detections, drop repeats, and add continuous text and stats.

        A detection is dropped when it repeats the last kept text within
        DUPLICATE_WINDOW_SECONDS. Mutates and returns ``results``; expects
        ``results["unique_texts"]`` to already be a sized collection.
        """
        subtitles = results["subtitles"]
        subtitles.sort(key=lambda x: x["timestamp"])

        filtered_subtitles = []
        last_text = ""
        last_timestamp = -1
        for subtitle in subtitles:
            text = subtitle["text"]
            timestamp = subtitle["timestamp"]
            if text != last_text or (timestamp - last_timestamp) > self.DUPLICATE_WINDOW_SECONDS:
                filtered_subtitles.append(subtitle)
                last_text = text
                last_timestamp = timestamp

        results["subtitles"] = filtered_subtitles
        results["filtered_count"] = len(subtitles) - len(filtered_subtitles)
        results["continuous_text"] = " ".join(s["text"] for s in filtered_subtitles)
        results["stats"] = {
            "total_detections": len(subtitles),
            "filtered_detections": len(filtered_subtitles),
            "unique_texts": len(results["unique_texts"]),
            "text_length": len(results["continuous_text"]),
            "average_confidence": (
                float(np.mean([s["confidence"] for s in filtered_subtitles]))
                if filtered_subtitles else 0
            )
        }
        return results

    def extract_subtitles_from_frames(self, frame_dir: str, confidence_threshold: float = 0.5) -> Dict:
        """Extract subtitles from a directory of frame images.

        Images are processed in sorted path order; each image's index is used
        as its "timestamp".

        Returns:
            Dict: same shape as extract_subtitles_from_video (without video
            metadata or extract_time), or {} if the directory has no images.

        Raises:
            FileNotFoundError: frame_dir does not exist.
        """
        frame_dir = Path(frame_dir)
        if not frame_dir.exists():
            raise FileNotFoundError(f"帧目录不存在: {frame_dir}")

        logger.info(f"开始从帧目录提取字幕: {frame_dir}")

        # A single case-insensitive suffix check avoids collecting each file
        # twice on case-insensitive filesystems (the old *.jpg + *.JPG globs did).
        image_files = sorted(
            path for path in frame_dir.iterdir()
            if path.is_file() and path.suffix.lower() in IMAGE_EXTENSIONS
        )

        if not image_files:
            logger.warning(f"在目录 {frame_dir} 中未找到图片文件")
            return {}

        logger.info(f"找到 {len(image_files)} 个图片文件")

        results = {
            "frame_dir": str(frame_dir),
            "total_images": len(image_files),
            "confidence_threshold": confidence_threshold,
            "subtitles": [],
            "unique_texts": [],
            "ocr_engine": self.ocr_engine,
            "timestamp": datetime.now().isoformat()
        }

        for i, image_file in enumerate(image_files):
            try:
                frame = cv2.imread(str(image_file))
                if frame is not None:
                    timestamp = float(i)  # the image index stands in for a timestamp
                    results["subtitles"].extend(
                        self._ocr_frame(frame, timestamp, confidence_threshold)
                    )

                if (i + 1) % 10 == 0:
                    progress = ((i + 1) / len(image_files)) * 100
                    logger.info(f"处理进度: {progress:.1f}%")
            except Exception as e:
                logger.warning(f"处理图片 {image_file} 时出错: {e}")

        results["unique_texts"] = list(dict.fromkeys(s["text"] for s in results["subtitles"]))

        results = self._postprocess_subtitles(results)

        logger.info(f"字幕提取完成,提取到 {len(results['subtitles'])} 个字幕片段")

        return results

    def save_results(self, results: Dict, output_path: str, format: str = "json"):
        """Save ``results`` as "json", "txt" (continuous text only) or "srt".

        Raises:
            ValueError: unsupported ``format``.
        """
        output_path = Path(output_path)

        if format == "json":
            # Converts NumPy values, and any leftover set, into JSON-friendly types.
            serializable = convert_numpy_types(results)
            with open(output_path, 'w', encoding='utf-8') as f:
                json.dump(serializable, f, ensure_ascii=False, indent=2)
        elif format == "txt":
            with open(output_path, 'w', encoding='utf-8') as f:
                f.write(results.get("continuous_text", ""))
        elif format == "srt":
            self._save_srt(results, output_path)
        else:
            raise ValueError(f"不支持的格式: {format}")

        logger.info(f"结果已保存: {output_path}")

    def _save_srt(self, results: Dict, output_path: Path):
        """Write subtitles as SRT cues, each lasting SRT_CUE_DURATION seconds."""
        cues = []
        for index, subtitle in enumerate(results["subtitles"], 1):
            start = subtitle["timestamp"]
            start_time = self._format_time(start)
            end_time = self._format_time(start + self.SRT_CUE_DURATION)
            cues.append(f"{index}\n{start_time} --> {end_time}\n{subtitle['text']}\n\n")
        with open(output_path, 'w', encoding='utf-8') as f:
            f.write("".join(cues))

    def _format_time(self, seconds):
        """Format seconds as an SRT timestamp "HH:MM:SS,mmm" (negatives clamp to 0).

        Works in whole milliseconds so rounding can never produce "60.000" seconds.
        """
        total_ms = max(0, int(round(seconds * 1000)))
        hours, remainder = divmod(total_ms, 3_600_000)
        minutes, remainder = divmod(remainder, 60_000)
        secs, millis = divmod(remainder, 1000)
        return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}"


def main():
    """Command-line entry point.

    Returns:
        int: process exit status (0 on success, 1 on failure).
    """
    parser = argparse.ArgumentParser(description="专业视频字幕OCR提取器")
    parser.add_argument("input", help="输入视频文件路径或帧图片目录")
    parser.add_argument("-e", "--engine", default="cnocr",
                        choices=["paddleocr", "easyocr", "cnocr", "all"],
                        help="OCR引擎 (默认: cnocr)")
    parser.add_argument("-l", "--language", default="ch",
                        choices=["ch", "en", "ch_en"],
                        help="语言设置 (默认: ch)")
    parser.add_argument("-i", "--interval", type=int, default=30,
                        help="帧采样间隔 (默认: 30)")
    parser.add_argument("-c", "--confidence", type=float, default=0.5,
                        help="置信度阈值 (默认: 0.5)")
    parser.add_argument("-o", "--output", default="subtitle_results",
                        help="输出目录 (默认: subtitle_results)")
    parser.add_argument("-f", "--format", default="json",
                        choices=["json", "txt", "srt"],
                        help="输出格式 (默认: json)")
    parser.add_argument("--position", default="bottom",
                        choices=["full", "center", "bottom"],
                        help="字幕区域位置 (full=全屏, center=居中0.5-0.8, bottom=居下0.5-1.0)")

    args = parser.parse_args()

    extractor = VideoSubtitleExtractor(
        ocr_engine=args.engine,
        language=args.language
    )

    input_path = Path(args.input)

    try:
        if input_path.is_file():
            logger.info("处理视频文件")
            results = extractor.extract_subtitles_from_video(
                args.input,
                sample_interval=args.interval,
                confidence_threshold=args.confidence,
                subtitle_position=args.position
            )
        elif input_path.is_dir():
            logger.info("处理帧图片目录")
            results = extractor.extract_subtitles_from_frames(
                args.input,
                confidence_threshold=args.confidence
            )
        else:
            logger.error(f"输入路径无效: {args.input}")
            return 1

        # extract_subtitles_from_frames returns {} when the directory has no images.
        if not results:
            logger.error(f"未能从输入中提取任何结果: {args.input}")
            return 1

        output_dir = Path(args.output)
        output_dir.mkdir(parents=True, exist_ok=True)
        output_file = output_dir / f"{input_path.stem}_subtitles.{args.format}"
        extractor.save_results(results, output_file, args.format)

        stats = results["stats"]
        print("\n字幕提取完成!")
        print(f"输入: {args.input}")
        print(f"输出: {output_file}")
        print(f"提取时长: {results.get('extract_time', 0):.2f} 秒")
        print(f"字幕片段: {stats['filtered_detections']} 个")
        print(f"唯一文本: {stats['unique_texts']} 条")
        print(f"平均置信度: {stats['average_confidence']:.3f}")
        print(f"文本总长度: {stats['text_length']} 字符")

        continuous_text = results.get("continuous_text")
        if continuous_text:
            print("\n提取的连续文本:")
            print("-" * 50)
            if len(continuous_text) > 500:
                print(continuous_text[:500] + "...")
            else:
                print(continuous_text)
    except Exception as e:
        logger.exception(f"程序执行出错: {e}")
        return 1

    return 0


if __name__ == "__main__":
    raise SystemExit(main())