496 lines
19 KiB
Python
496 lines
19 KiB
Python
|
#!/usr/bin/env python3
|
|||
|
# -*- coding: utf-8 -*-
|
|||
|
"""
|
|||
|
专业视频字幕OCR提取器
|
|||
|
支持PaddleOCR和EasyOCR两种引擎
|
|||
|
"""
|
|||
|
|
|||
|
import cv2
|
|||
|
import os
|
|||
|
import time
|
|||
|
import json
|
|||
|
import argparse
|
|||
|
from pathlib import Path
|
|||
|
from datetime import datetime
|
|||
|
import logging
|
|||
|
import numpy as np
|
|||
|
from typing import List, Dict, Tuple
|
|||
|
|
|||
|
# Configure root logging once at import time (INFO level, timestamped
# records) and create the logger used throughout this module.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
|
|||
|
|
|||
|
def convert_numpy_types(obj):
    """Recursively convert NumPy values into plain Python objects for JSON.

    Handled inputs:
      * ``np.bool_`` / ``np.integer`` / ``np.floating`` scalars become
        ``bool`` / ``int`` / ``float``.
      * ``np.ndarray`` becomes a (nested) list via ``tolist()``.
      * ``list`` / ``tuple`` / ``set`` / ``frozenset`` become a list whose
        items are converted recursively.
      * ``dict`` values are converted recursively; keys that are NumPy
        scalars are converted too, because ``json`` rejects them as keys.

    Anything else is returned unchanged.

    Args:
        obj: Arbitrary, possibly nested, value.

    Returns:
        An equivalent structure free of the NumPy types listed above.
    """
    # np.bool_ is not a subclass of np.integer, so it needs its own branch.
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, dict):
        # Only scalar NumPy keys are rewritten; other keys (e.g. tuples) are
        # left alone so they stay hashable.
        return {
            (convert_numpy_types(key) if isinstance(key, np.generic) else key):
                convert_numpy_types(value)
            for key, value in obj.items()
        }
    if isinstance(obj, (list, tuple, set, frozenset)):
        return [convert_numpy_types(item) for item in obj]
    return obj
|
|||
|
|
|||
|
class VideoSubtitleExtractor:
    """Extract burned-in subtitles from videos or frame images via OCR.

    PaddleOCR and EasyOCR are supported and may be used individually or
    together (``ocr_engine="both"``).
    """

    # Identical text seen again within this many seconds of its previous
    # sighting is treated as the same on-screen subtitle.
    DEDUP_WINDOW_SECONDS = 5.0

    # Estimated on-screen duration (seconds) of a cue when writing SRT.
    SRT_DEFAULT_DURATION = 3.0

    # Lower-cased file suffixes accepted as frame images.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp')

    def __init__(self, ocr_engine="paddleocr", language="ch"):
        """Create the extractor and load the requested OCR engine(s).

        Args:
            ocr_engine: OCR engine to use ("paddleocr", "easyocr", "both").
            language: Language setting ("ch", "en", "ch_en").
        """
        self.ocr_engine = ocr_engine
        self.language = language
        self.paddle_ocr = None
        self.easy_ocr = None
        self.load_ocr_engines()

    def load_ocr_engines(self):
        """Import and instantiate the OCR engines selected by ``ocr_engine``.

        Raises:
            ImportError: If a requested engine cannot be imported (the
                install hint is logged before re-raising).
        """
        logger.info(f"正在加载OCR引擎: {self.ocr_engine}")

        if self.ocr_engine not in ("paddleocr", "easyocr", "both"):
            # Not fatal (keeps the previous behaviour), but without a warning
            # every extraction would silently return nothing.
            logger.warning(f"未知的OCR引擎: {self.ocr_engine},未加载任何引擎")

        if self.ocr_engine in ("paddleocr", "both"):
            try:
                from paddleocr import PaddleOCR

                # PaddleOCR's "ch" model covers mixed Chinese/English text.
                paddle_lang_map = {
                    "ch": "ch",
                    "en": "en",
                    "ch_en": "ch",
                }
                paddle_lang = paddle_lang_map.get(self.language, "ch")

                self.paddle_ocr = PaddleOCR(
                    use_textline_orientation=True,
                    lang=paddle_lang
                )
                logger.info("PaddleOCR加载完成")
            except ImportError:
                logger.error("请安装PaddleOCR: pip install paddleocr")
                raise

        if self.ocr_engine in ("easyocr", "both"):
            try:
                import easyocr

                # EasyOCR language codes; unknown settings fall back to
                # Simplified Chinese.
                easy_lang_map = {
                    "ch": ['ch_sim'],
                    "en": ['en'],
                    "ch_en": ['ch_sim', 'en'],
                }
                lang_list = easy_lang_map.get(self.language, ['ch_sim'])

                self.easy_ocr = easyocr.Reader(lang_list)
                logger.info("EasyOCR加载完成")
            except ImportError:
                logger.error("请安装EasyOCR: pip install easyocr")
                raise

    def extract_subtitles_from_video(self, video_path: str,
                                     sample_interval: int = 30,
                                     confidence_threshold: float = 0.5,
                                     subtitle_region: Tuple = None) -> Dict:
        """Run OCR on every ``sample_interval``-th frame of a video.

        Args:
            video_path: Path to the video file.
            sample_interval: Sample one frame out of this many (>= 1).
            confidence_threshold: Minimum OCR confidence for a detection.
            subtitle_region: Crop box ``(x1, y1, x2, y2)`` applied to each
                frame before OCR; ``None`` means the full frame.

        Returns:
            Dict with video metadata, the de-duplicated ``subtitles`` list,
            ``unique_texts``, ``continuous_text``, ``stats`` and
            ``extract_time``.

        Raises:
            FileNotFoundError: If ``video_path`` does not exist.
            ValueError: If ``sample_interval`` < 1 or the video reports no
                usable frame rate.
            IOError: If OpenCV cannot open the video.
        """
        video_path = Path(video_path)
        if not video_path.exists():
            raise FileNotFoundError(f"视频文件不存在: {video_path}")
        if sample_interval < 1:
            # Would otherwise surface as a modulo-by-zero inside the loop.
            raise ValueError(f"采样间隔必须为正整数: {sample_interval}")

        logger.info(f"开始提取视频字幕: {video_path}")
        logger.info(f"采样间隔: {sample_interval}帧, 置信度阈值: {confidence_threshold}")

        start_time = time.time()
        unique_texts = set()

        cap = cv2.VideoCapture(str(video_path))
        # Everything after opening sits inside try/finally so the capture is
        # released even when validation below fails.
        try:
            if not cap.isOpened():
                raise IOError(f"无法打开视频文件: {video_path}")

            fps = cap.get(cv2.CAP_PROP_FPS)
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            # `not fps > 0` also rejects NaN, which some containers report.
            if not fps > 0:
                raise ValueError(f"无法获取有效的视频帧率: {fps}")
            duration = total_frames / fps

            logger.info(f"视频信息: {total_frames}帧, {fps:.2f}FPS, {duration:.2f}秒")

            results = {
                "video_path": str(video_path),
                "duration": duration,
                "fps": fps,
                "total_frames": total_frames,
                "sample_interval": sample_interval,
                "confidence_threshold": confidence_threshold,
                "subtitle_region": subtitle_region,
                "subtitles": [],
                "unique_texts": [],
                "ocr_engine": self.ocr_engine,
                "timestamp": datetime.now().isoformat()
            }

            frame_count = 0
            processed_frames = 0

            # grab() advances the stream without decoding; only the sampled
            # frames are decoded with retrieve().
            while cap.grab():
                if frame_count % sample_interval == 0:
                    ret, frame = cap.retrieve()
                    if not ret:
                        break

                    timestamp = frame_count / fps

                    # Crop to the subtitle region when one was given.
                    if subtitle_region:
                        x1, y1, x2, y2 = subtitle_region
                        frame_crop = frame[y1:y2, x1:x2]
                    else:
                        frame_crop = frame

                    if frame_crop.size:
                        ocr_results = self._ocr_frame(frame_crop, timestamp, confidence_threshold)
                    else:
                        # Region lies outside the frame: nothing to recognise.
                        logger.debug(f"字幕区域裁剪结果为空,跳过该帧 (时间戳:{timestamp:.2f})")
                        ocr_results = []

                    if ocr_results:
                        results["subtitles"].extend(ocr_results)
                        unique_texts.update(item["text"] for item in ocr_results)

                    processed_frames += 1

                    # Progress report every 10 processed frames; the frame
                    # count can be 0 for some streams, so guard the division.
                    if processed_frames % 10 == 0 and total_frames > 0:
                        progress = (frame_count / total_frames) * 100
                        logger.info(f"处理进度: {progress:.1f}% ({processed_frames}帧)")

                frame_count += 1
        finally:
            cap.release()

        # Sorted so the output is deterministic across runs.
        results["unique_texts"] = sorted(unique_texts)

        # Post-process: sort, de-duplicate and summarise.
        results = self._postprocess_subtitles(results)

        extract_time = time.time() - start_time
        results["extract_time"] = extract_time

        logger.info(f"字幕提取完成,耗时: {extract_time:.2f}秒")
        logger.info(f"提取到 {len(results['subtitles'])} 个字幕片段")
        logger.info(f"唯一文本 {len(results['unique_texts'])} 条")

        return results

    @staticmethod
    def _parse_paddle_result(paddle_result):
        """Normalise a PaddleOCR result into ``(bbox, text, confidence)`` tuples.

        Three layouts of the first (and only) page are understood:
          * dict-like page with ``rec_texts`` / ``rec_scores`` and
            ``rec_polys`` (or ``dt_polys``) lists;
          * list of ``[bbox, (text, confidence)]`` entries;
          * list of ``[bbox, text, confidence]`` entries.

        Entries in any other shape are skipped.
        """
        if not paddle_result:
            return []
        page = paddle_result[0]

        if isinstance(page, dict):
            texts = page.get("rec_texts")
            scores = page.get("rec_scores")
            if texts is None or scores is None:
                return []
            boxes = page.get("rec_polys")
            if boxes is None:
                boxes = page.get("dt_polys")
            if boxes is None or len(boxes) != len(texts):
                boxes = [None] * len(texts)
            return list(zip(boxes, texts, scores))

        if not page:
            return []

        parsed = []
        for detection in page:
            if not detection:
                continue
            try:
                if len(detection) == 2:
                    # Layout: [bbox, (text, confidence)]
                    bbox, (text, confidence) = detection
                elif len(detection) == 3:
                    # Layout: [bbox, text, confidence]
                    bbox, text, confidence = detection
                else:
                    continue
            except (TypeError, ValueError) as parse_error:
                logger.debug(f"解析单个检测结果失败: {parse_error}, 数据: {detection}")
                continue
            parsed.append((bbox, text, confidence))
        return parsed

    @staticmethod
    def _build_detection(bbox, text, confidence, timestamp, confidence_threshold, engine):
        """Return a result record, or ``None`` if the detection is filtered out.

        A detection is kept when its confidence reaches the threshold and its
        text is non-empty after stripping whitespace.
        """
        text = text.strip()
        if confidence >= confidence_threshold and text:
            return {
                "timestamp": timestamp,
                "text": text,
                # Plain float keeps the record JSON-friendly.
                "confidence": float(confidence),
                "bbox": bbox,
                "engine": engine
            }
        return None

    def _ocr_frame(self, frame: np.ndarray, timestamp: float, confidence_threshold: float) -> List[Dict]:
        """Run every loaded OCR engine on one frame.

        Engine failures are logged as warnings and do not abort extraction.

        Args:
            frame: Image passed to the OCR engines.
            timestamp: Timestamp recorded on each detection.
            confidence_threshold: Minimum confidence for a detection.

        Returns:
            List of detection dicts with keys ``timestamp``, ``text``,
            ``confidence``, ``bbox`` and ``engine``.
        """
        ocr_results = []

        # PaddleOCR
        if self.paddle_ocr:
            try:
                paddle_result = self.paddle_ocr.ocr(frame)
                for bbox, text, confidence in self._parse_paddle_result(paddle_result):
                    try:
                        record = self._build_detection(
                            bbox, text, confidence, timestamp,
                            confidence_threshold, "PaddleOCR")
                    except Exception as parse_error:
                        # Best effort: one malformed detection must not
                        # discard the rest of the frame.
                        logger.debug(f"解析单个检测结果失败: {parse_error}, 数据: {(bbox, text, confidence)}")
                        continue
                    if record is not None:
                        ocr_results.append(record)
            except Exception as e:
                logger.warning(f"PaddleOCR识别失败 (时间戳:{timestamp:.2f}): {e}")

        # EasyOCR
        if self.easy_ocr:
            try:
                easy_result = self.easy_ocr.readtext(frame)
                for bbox, text, confidence in easy_result:
                    record = self._build_detection(
                        bbox, text, confidence, timestamp,
                        confidence_threshold, "EasyOCR")
                    if record is not None:
                        ocr_results.append(record)
            except Exception as e:
                logger.warning(f"EasyOCR识别失败 (时间戳:{timestamp:.2f}): {e}")

        return ocr_results

    def _postprocess_subtitles(self, results: Dict) -> Dict:
        """Sort detections, drop consecutive repeats and add summary fields.

        A detection is dropped when it has the same text as the detection
        immediately before it (in timestamp order) and follows it by at most
        ``DEDUP_WINDOW_SECONDS``. The comparison is against the most recent
        sighting, so a subtitle that stays on screen longer than the window
        is still emitted only once.

        Adds ``filtered_count``, ``continuous_text`` and ``stats`` to
        ``results`` and replaces ``results["subtitles"]`` with the filtered
        list.
        """
        subtitles = results["subtitles"]

        # Stable sort, so detections from the same frame keep engine order.
        subtitles.sort(key=lambda x: x["timestamp"])

        filtered_subtitles = []
        last_text = None
        last_seen = None

        for subtitle in subtitles:
            text = subtitle["text"]
            timestamp = subtitle["timestamp"]

            is_repeat = (
                last_seen is not None
                and text == last_text
                and (timestamp - last_seen) <= self.DEDUP_WINDOW_SECONDS
            )
            if not is_repeat:
                filtered_subtitles.append(subtitle)

            # Track the latest sighting even when the item was dropped.
            last_text = text
            last_seen = timestamp

        results["subtitles"] = filtered_subtitles
        results["filtered_count"] = len(subtitles) - len(filtered_subtitles)

        # Space-joined text of all kept subtitles.
        results["continuous_text"] = " ".join([s["text"] for s in filtered_subtitles])

        if filtered_subtitles:
            average_confidence = float(np.mean([s["confidence"] for s in filtered_subtitles]))
        else:
            average_confidence = 0.0

        results["stats"] = {
            "total_detections": len(subtitles),
            "filtered_detections": len(filtered_subtitles),
            "unique_texts": len(results.get("unique_texts", ())),
            "text_length": len(results["continuous_text"]),
            "average_confidence": average_confidence
        }

        return results

    def extract_subtitles_from_frames(self, frame_dir: str, confidence_threshold: float = 0.5) -> Dict:
        """Run OCR on every image in a directory of extracted frames.

        Images are processed in sorted path order and the image index is
        used as the timestamp.

        Args:
            frame_dir: Directory containing frame images.
            confidence_threshold: Minimum OCR confidence for a detection.

        Returns:
            Result dict (same post-processed fields as video extraction), or
            an empty dict when the directory contains no images.

        Raises:
            FileNotFoundError: If ``frame_dir`` does not exist.
        """
        frame_dir = Path(frame_dir)
        if not frame_dir.exists():
            raise FileNotFoundError(f"帧目录不存在: {frame_dir}")

        logger.info(f"开始从帧目录提取字幕: {frame_dir}")

        # One directory scan with a case-insensitive suffix test. Globbing
        # "*.jpg" and "*.JPG" separately lists each file twice on
        # case-insensitive filesystems.
        candidates = frame_dir.iterdir() if frame_dir.is_dir() else ()
        image_files = sorted(
            path for path in candidates
            if path.suffix.lower() in self.IMAGE_EXTENSIONS
        )

        if not image_files:
            logger.warning(f"在目录 {frame_dir} 中未找到图片文件")
            return {}

        logger.info(f"找到 {len(image_files)} 个图片文件")

        results = {
            "frame_dir": str(frame_dir),
            "total_images": len(image_files),
            "confidence_threshold": confidence_threshold,
            "subtitles": [],
            "unique_texts": [],
            "ocr_engine": self.ocr_engine,
            "timestamp": datetime.now().isoformat()
        }
        unique_texts = set()

        for i, image_file in enumerate(image_files):
            try:
                frame = cv2.imread(str(image_file))
                if frame is None:
                    # imread signals an unreadable file by returning None.
                    logger.warning(f"无法读取图片: {image_file}")
                else:
                    # The image index stands in for a timestamp.
                    timestamp = float(i)

                    ocr_results = self._ocr_frame(frame, timestamp, confidence_threshold)
                    if ocr_results:
                        results["subtitles"].extend(ocr_results)
                        unique_texts.update(item["text"] for item in ocr_results)

                # Progress report every 10 images.
                if (i + 1) % 10 == 0:
                    progress = ((i + 1) / len(image_files)) * 100
                    logger.info(f"处理进度: {progress:.1f}%")

            except Exception as e:
                logger.warning(f"处理图片 {image_file} 时出错: {e}")

        # Sorted so the output is deterministic across runs.
        results["unique_texts"] = sorted(unique_texts)

        results = self._postprocess_subtitles(results)

        logger.info(f"字幕提取完成,提取到 {len(results['subtitles'])} 个字幕片段")

        return results

    def save_results(self, results: Dict, output_path: str, format: str = "json"):
        """Write extraction results to ``output_path``.

        Args:
            results: Result dict produced by one of the extract methods.
            output_path: Destination file; missing parent directories are
                created.
            format: "json" (full results), "txt" (continuous text only) or
                "srt" (subtitle cues).

        Raises:
            ValueError: If ``format`` is not one of the supported values.
        """
        output_path = Path(output_path)

        if format not in ("json", "txt", "srt"):
            # Validate first so an unsupported format leaves no empty file.
            raise ValueError(f"不支持的格式: {format}")

        output_path.parent.mkdir(parents=True, exist_ok=True)

        if format == "json":
            # Shallow copy so the caller's dict is not modified.
            results_copy = results.copy()
            if isinstance(results_copy.get("unique_texts"), set):
                results_copy["unique_texts"] = list(results_copy["unique_texts"])

            # NumPy scalars/arrays (e.g. bounding boxes) are not JSON
            # serialisable as-is.
            results_copy = convert_numpy_types(results_copy)

            with open(output_path, 'w', encoding='utf-8') as f:
                json.dump(results_copy, f, ensure_ascii=False, indent=2)

        elif format == "txt":
            with open(output_path, 'w', encoding='utf-8') as f:
                f.write(results.get("continuous_text", ""))

        else:
            self._save_srt(results, output_path)

        logger.info(f"结果已保存: {output_path}")

    def _save_srt(self, results: Dict, output_path: Path):
        """Write ``results["subtitles"]`` as SubRip (SRT) cues.

        Each cue lasts ``SRT_DEFAULT_DURATION`` seconds, shortened when a
        later cue starts sooner so consecutive cues do not overlap. Cues that
        share a start time keep the same end time.
        """
        subtitles = results["subtitles"]

        # For every cue, find the first strictly later start time after it.
        # Walking backwards makes this a single pass.
        next_starts = []
        upcoming = None
        following = None
        for subtitle in reversed(subtitles):
            start = subtitle["timestamp"]
            if following is not None and following != start:
                upcoming = following if following > start else None
            next_starts.append(upcoming)
            following = start
        next_starts.reverse()

        with open(output_path, 'w', encoding='utf-8') as f:
            for index, (subtitle, next_start) in enumerate(zip(subtitles, next_starts), 1):
                start = subtitle["timestamp"]
                end = start + self.SRT_DEFAULT_DURATION
                if next_start is not None:
                    end = min(end, next_start)

                f.write(f"{index}\n")
                f.write(f"{self._format_time(start)} --> {self._format_time(end)}\n")
                f.write(f"{subtitle['text']}\n\n")

    def _format_time(self, seconds):
        """Format a time in seconds as an SRT timestamp ``HH:MM:SS,mmm``.

        The value is rounded to whole milliseconds before being split into
        fields, so rounding carries correctly (59.9996 s becomes
        ``00:01:00,000``). Negative values are clamped to zero.
        """
        total_ms = max(0, int(round(float(seconds) * 1000)))
        hours, remainder = divmod(total_ms, 3_600_000)
        minutes, remainder = divmod(remainder, 60_000)
        secs, millis = divmod(remainder, 1000)
        return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}"
|
|||
|
|
|||
|
def main():
    """Command-line entry point.

    Parses arguments, runs extraction on a video file or a directory of frame
    images, saves the result in the requested format and prints a summary.
    Errors during extraction or saving are logged rather than raised.
    """
    parser = argparse.ArgumentParser(description="专业视频字幕OCR提取器")
    parser.add_argument("input", help="输入视频文件路径或帧图片目录")
    parser.add_argument("-e", "--engine", default="paddleocr",
                        choices=["paddleocr", "easyocr", "both"],
                        help="OCR引擎 (默认: paddleocr)")
    parser.add_argument("-l", "--language", default="ch",
                        choices=["ch", "en", "ch_en"],
                        help="语言设置 (默认: ch)")
    parser.add_argument("-i", "--interval", type=int, default=30,
                        help="帧采样间隔 (默认: 30)")
    parser.add_argument("-c", "--confidence", type=float, default=0.5,
                        help="置信度阈值 (默认: 0.5)")
    parser.add_argument("-o", "--output", default="subtitle_results",
                        help="输出目录 (默认: subtitle_results)")
    parser.add_argument("-f", "--format", default="json",
                        choices=["json", "txt", "srt"],
                        help="输出格式 (默认: json)")
    parser.add_argument("--region", nargs=4, type=int, metavar=('X1', 'Y1', 'X2', 'Y2'),
                        help="字幕区域坐标 (x1 y1 x2 y2)")

    args = parser.parse_args()

    # Validate the input before loading OCR models, which is slow.
    input_path = Path(args.input)
    if not (input_path.is_file() or input_path.is_dir()):
        logger.error(f"输入路径无效: {args.input}")
        return

    extractor = VideoSubtitleExtractor(
        ocr_engine=args.engine,
        language=args.language
    )

    try:
        if input_path.is_file():
            # Input is a video file.
            logger.info("处理视频文件")
            results = extractor.extract_subtitles_from_video(
                args.input,
                sample_interval=args.interval,
                confidence_threshold=args.confidence,
                subtitle_region=tuple(args.region) if args.region else None
            )
        else:
            # Input is a directory of frame images.
            logger.info("处理帧图片目录")
            results = extractor.extract_subtitles_from_frames(
                args.input,
                confidence_threshold=args.confidence
            )

        # The frame-directory path returns {} when it finds no images; stop
        # here instead of failing on the missing 'stats' key below.
        if not results:
            logger.error(f"未提取到任何结果: {args.input}")
            return

        # Save the results.
        output_dir = Path(args.output)
        output_dir.mkdir(parents=True, exist_ok=True)

        output_file = output_dir / f"{input_path.stem}_subtitles.{args.format}"
        extractor.save_results(results, output_file, args.format)

        # Print a summary.
        stats = results['stats']
        print("\n字幕提取完成!")
        print(f"输入: {args.input}")
        print(f"输出: {output_file}")
        print(f"提取时长: {results.get('extract_time', 0):.2f} 秒")
        print(f"字幕片段: {stats['filtered_detections']} 个")
        print(f"唯一文本: {stats['unique_texts']} 条")
        print(f"平均置信度: {stats['average_confidence']:.3f}")
        print(f"文本总长度: {stats['text_length']} 字符")

        continuous_text = results.get("continuous_text")
        if continuous_text:
            print("\n提取的连续文本:")
            print("-" * 50)
            # Preview only the first 500 characters of long transcripts.
            if len(continuous_text) > 500:
                print(continuous_text[:500] + "...")
            else:
                print(continuous_text)

    except Exception as e:
        logger.error(f"程序执行出错: {str(e)}")
        return


if __name__ == "__main__":
    main()
|