496 lines
19 KiB
Python
496 lines
19 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
专业视频字幕OCR提取器
|
||
支持PaddleOCR和EasyOCR两种引擎
|
||
"""
|
||
|
||
import cv2
|
||
import os
|
||
import time
|
||
import json
|
||
import argparse
|
||
from pathlib import Path
|
||
from datetime import datetime
|
||
import logging
|
||
import numpy as np
|
||
from typing import List, Dict, Tuple
|
||
|
||
# Configure root logging once at import time: INFO level, timestamped lines.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(name=__name__)
|
||
|
||
def convert_numpy_types(obj):
    """Recursively convert NumPy values into native Python types for JSON.

    Handles NumPy booleans, integers, floats and any other NumPy scalar,
    ndarrays (via ``tolist``), and walks lists, tuples, sets and dicts.
    Tuples and sets become lists. NumPy scalar dict keys are converted too,
    because ``json`` rejects them. Anything else is returned unchanged.

    Args:
        obj: Arbitrary (possibly nested) value.

    Returns:
        An equivalent structure built from native Python types.
    """
    # np.bool_ is not a subclass of np.integer, so it needs its own branch.
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    # Any remaining NumPy scalar (np.str_, np.complex128, ...).
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, (list, tuple, set, frozenset)):
        return [convert_numpy_types(item) for item in obj]
    if isinstance(obj, dict):
        return {
            (key.item() if isinstance(key, np.generic) else key): convert_numpy_types(value)
            for key, value in obj.items()
        }
    return obj
|
||
|
||
class VideoSubtitleExtractor:
    """Extract burned-in subtitles from a video or a directory of frames via OCR.

    Supports PaddleOCR, EasyOCR, or both engines at once.
    """

    # Engine names accepted by ``load_ocr_engines``.
    _SUPPORTED_ENGINES = ("paddleocr", "easyocr", "both")
    # Image suffixes (compared case-insensitively) scanned by
    # ``extract_subtitles_from_frames``.
    _IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp")

    def __init__(self, ocr_engine="paddleocr", language="ch"):
        """Create the extractor and load the requested OCR engine(s).

        Args:
            ocr_engine: OCR engine ("paddleocr", "easyocr", "both").
            language: Language setting ("ch", "en", "ch_en").

        Raises:
            ValueError: ``ocr_engine`` is not one of the supported names.
            ImportError: The required OCR package is not installed.
        """
        self.ocr_engine = ocr_engine
        self.language = language
        self.paddle_ocr = None
        self.easy_ocr = None
        self.load_ocr_engines()

    def load_ocr_engines(self):
        """Instantiate the OCR engine(s) selected by ``self.ocr_engine``.

        Raises:
            ValueError: Unknown engine name. Previously this silently loaded
                nothing, and every frame then produced no text.
            ImportError: The package for a requested engine is missing.
        """
        logger.info(f"正在加载OCR引擎: {self.ocr_engine}")

        if self.ocr_engine not in self._SUPPORTED_ENGINES:
            raise ValueError(f"不支持的OCR引擎: {self.ocr_engine}")

        if self.ocr_engine in ("paddleocr", "both"):
            try:
                from paddleocr import PaddleOCR

                # PaddleOCR's "ch" model already covers mixed Chinese/English text.
                paddle_lang_map = {"ch": "ch", "en": "en", "ch_en": "ch"}
                paddle_lang = paddle_lang_map.get(self.language, "ch")

                self.paddle_ocr = PaddleOCR(
                    use_textline_orientation=True,
                    lang=paddle_lang
                )
                logger.info("PaddleOCR加载完成")
            except ImportError:
                logger.error("请安装PaddleOCR: pip install paddleocr")
                raise

        if self.ocr_engine in ("easyocr", "both"):
            try:
                import easyocr

                # EasyOCR takes a list of language codes; unknown values fall
                # back to Simplified Chinese.
                easy_lang_map = {
                    "ch": ['ch_sim'],
                    "en": ['en'],
                    "ch_en": ['ch_sim', 'en'],
                }
                lang_list = easy_lang_map.get(self.language, ['ch_sim'])

                self.easy_ocr = easyocr.Reader(lang_list)
                logger.info("EasyOCR加载完成")
            except ImportError:
                logger.error("请安装EasyOCR: pip install easyocr")
                raise

    def extract_subtitles_from_video(self, video_path: str,
                                     sample_interval: int = 30,
                                     confidence_threshold: float = 0.5,
                                     subtitle_region: Tuple = None) -> Dict:
        """Extract subtitles by OCR-ing every ``sample_interval``-th frame of a video.

        Args:
            video_path: Path to the video file.
            sample_interval: Sampling interval in frames (must be >= 1).
            confidence_threshold: Detections scoring below this are dropped.
            subtitle_region: Crop box ``(x1, y1, x2, y2)``; None means the full frame.

        Returns:
            Dict with video metadata, ``subtitles``, ``unique_texts`` (sorted
            list), ``continuous_text``, ``stats`` and ``extract_time``.

        Raises:
            ValueError: ``sample_interval`` < 1, or the video reports no usable FPS.
            FileNotFoundError: ``video_path`` does not exist.
            OSError: OpenCV cannot open the video.
        """
        if sample_interval < 1:
            # Guards the `frame_count % sample_interval` below.
            raise ValueError(f"采样间隔必须是正整数: {sample_interval}")

        video_path = Path(video_path)
        if not video_path.exists():
            raise FileNotFoundError(f"视频文件不存在: {video_path}")

        logger.info(f"开始提取视频字幕: {video_path}")
        logger.info(f"采样间隔: {sample_interval}帧, 置信度阈值: {confidence_threshold}")

        start_time = time.time()

        cap = cv2.VideoCapture(str(video_path))
        try:
            if not cap.isOpened():
                raise OSError(f"无法打开视频文件: {video_path}")

            fps = cap.get(cv2.CAP_PROP_FPS)
            # `not fps > 0` also rejects NaN; a zero FPS would divide by zero below.
            if not fps > 0:
                raise ValueError(f"无法获取视频帧率: {video_path}")
            # Some containers/streams report a non-positive frame count.
            total_frames = max(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), 0)
            duration = total_frames / fps

            logger.info(f"视频信息: {total_frames}帧, {fps:.2f}FPS, {duration:.2f}秒")

            results = {
                "video_path": str(video_path),
                "duration": duration,
                "fps": fps,
                "total_frames": total_frames,
                "sample_interval": sample_interval,
                "confidence_threshold": confidence_threshold,
                "subtitle_region": subtitle_region,
                "subtitles": [],
                "unique_texts": set(),
                "ocr_engine": self.ocr_engine,
                "timestamp": datetime.now().isoformat()
            }

            frame_count = 0
            processed_frames = 0
            warned_empty_crop = False

            # grab() advances without decoding; only sampled frames are
            # decoded via retrieve(), which is much cheaper than read() per frame.
            while cap.grab():
                if frame_count % sample_interval == 0:
                    ret, frame = cap.retrieve()
                    if not ret:
                        break
                    timestamp = frame_count / fps

                    frame_crop = self._crop_region(frame, subtitle_region)
                    if frame_crop.size == 0:
                        # An out-of-frame region would make every OCR call fail;
                        # warn once instead of once per sampled frame.
                        if not warned_empty_crop:
                            logger.warning(f"字幕区域为空, 请检查坐标: {subtitle_region}")
                            warned_empty_crop = True
                    else:
                        self._collect(results, self._ocr_frame(frame_crop, timestamp, confidence_threshold))

                    processed_frames += 1

                    if processed_frames % 10 == 0 and total_frames > 0:
                        progress = (frame_count / total_frames) * 100
                        logger.info(f"处理进度: {progress:.1f}% ({processed_frames}帧)")

                frame_count += 1
        finally:
            cap.release()

        # Sorted so the output is stable across runs (set order is not).
        results["unique_texts"] = sorted(results["unique_texts"])

        # Post-processing: sort, de-duplicate, build stats.
        results = self._postprocess_subtitles(results)

        extract_time = time.time() - start_time
        results["extract_time"] = extract_time

        logger.info(f"字幕提取完成,耗时: {extract_time:.2f}秒")
        logger.info(f"提取到 {len(results['subtitles'])} 个字幕片段")
        logger.info(f"唯一文本 {len(results['unique_texts'])} 条")

        return results

    @staticmethod
    def _crop_region(frame, subtitle_region):
        """Return ``frame`` cropped to ``(x1, y1, x2, y2)``, or unchanged if no region."""
        if not subtitle_region:
            return frame
        x1, y1, x2, y2 = subtitle_region
        return frame[y1:y2, x1:x2]

    @staticmethod
    def _collect(results: Dict, ocr_results: List[Dict]):
        """Append OCR results to ``results`` and record their texts in the unique-text set."""
        results["subtitles"].extend(ocr_results)
        results["unique_texts"].update(item["text"] for item in ocr_results)

    def _ocr_frame(self, frame: np.ndarray, timestamp: float, confidence_threshold: float) -> List[Dict]:
        """Run the configured OCR engine(s) on a single frame.

        Args:
            frame: Image array passed straight to the engine(s).
            timestamp: Value stored in each result's ``"timestamp"`` field.
            confidence_threshold: Detections scoring below this are dropped.

        Returns:
            One dict per kept detection with keys ``timestamp``, ``text``,
            ``confidence``, ``bbox`` and ``engine``. Engine failures are
            logged as warnings and contribute no entries.
        """
        ocr_results = []

        if self.paddle_ocr:
            try:
                paddle_result = self.paddle_ocr.ocr(frame)
                for bbox, text, confidence in self._iter_paddle_detections(paddle_result):
                    entry = self._make_entry(timestamp, bbox, text, confidence,
                                             confidence_threshold, "PaddleOCR")
                    if entry is not None:
                        ocr_results.append(entry)
            except Exception as e:
                logger.warning(f"PaddleOCR识别失败 (时间戳:{timestamp:.2f}): {e}")

        if self.easy_ocr:
            try:
                # readtext yields (bbox, text, confidence) triples.
                for bbox, text, confidence in self.easy_ocr.readtext(frame):
                    entry = self._make_entry(timestamp, bbox, text, confidence,
                                             confidence_threshold, "EasyOCR")
                    if entry is not None:
                        ocr_results.append(entry)
            except Exception as e:
                logger.warning(f"EasyOCR识别失败 (时间戳:{timestamp:.2f}): {e}")

        return ocr_results

    @staticmethod
    def _iter_paddle_detections(paddle_result):
        """Yield ``(bbox, text, confidence)`` from a PaddleOCR result of any known layout.

        Only the first page is read, because a single image is passed in. Supported layouts:
          * PaddleOCR 3.x: the page is a dict-like result holding parallel
            ``rec_polys`` / ``rec_texts`` / ``rec_scores`` sequences.
          * PaddleOCR 2.x: the page is a list of ``[bbox, (text, confidence)]``.
          * ``[bbox, text, confidence]`` triples.
        Malformed entries are logged at debug level and skipped.
        """
        if not paddle_result:
            return
        page = paddle_result[0]
        if not page:
            return

        # 3.x layout. Iterating such a page directly would yield its key names.
        if hasattr(page, "keys") and "rec_texts" in page:
            texts = page.get("rec_texts")
            scores = page.get("rec_scores")
            polys = page.get("rec_polys")
            if polys is None:
                polys = page.get("rec_boxes")
            # Explicit None checks: these may be ndarrays, whose truth value is ambiguous.
            if texts is None or scores is None:
                return
            if polys is None:
                polys = [None] * len(texts)
            yield from zip(polys, texts, scores)
            return

        for detection in page:
            if not detection:
                continue
            try:
                if len(detection) == 2:
                    bbox, (text, confidence) = detection
                elif len(detection) == 3:
                    bbox, text, confidence = detection
                else:
                    continue
            except (TypeError, ValueError) as parse_error:
                logger.debug(f"解析单个检测结果失败: {parse_error}, 数据: {detection}")
                continue
            yield bbox, text, confidence

    @staticmethod
    def _make_entry(timestamp, bbox, text, confidence, confidence_threshold, engine):
        """Build one result dict, or return None for blank, low-confidence or malformed detections."""
        if not isinstance(text, str):
            return None
        text = text.strip()
        try:
            # Native float keeps the result JSON-friendly (engines may return NumPy scalars).
            confidence = float(confidence)
        except (TypeError, ValueError):
            return None
        # Written as `not >=` so a NaN score is rejected.
        if not text or not confidence >= confidence_threshold:
            return None
        return {
            "timestamp": timestamp,
            "text": text,
            "confidence": confidence,
            "bbox": bbox,
            "engine": engine
        }

    def _postprocess_subtitles(self, results: Dict, dedup_window: float = 5.0) -> Dict:
        """Sort, de-duplicate and summarize the raw subtitle detections.

        A detection is kept unless the same text was already kept within the
        last ``dedup_window`` seconds. The check is per text, not just against
        the previous entry, so frames yielding several lines (A, B, A, B, ...)
        are still de-duplicated.

        Args:
            results: Dict with at least ``subtitles`` (list of detection dicts)
                and optionally ``unique_texts``. It is updated in place.
            dedup_window: Seconds within which a repeated text is dropped.

        Returns:
            ``results`` with ``subtitles`` filtered, plus ``filtered_count``,
            ``continuous_text`` and ``stats``.
        """
        subtitles = sorted(results["subtitles"], key=lambda x: x["timestamp"])

        last_kept_at = {}
        filtered_subtitles = []
        for subtitle in subtitles:
            text = subtitle["text"]
            timestamp = subtitle["timestamp"]
            previous = last_kept_at.get(text)
            if previous is None or (timestamp - previous) > dedup_window:
                filtered_subtitles.append(subtitle)
                last_kept_at[text] = timestamp

        results["subtitles"] = filtered_subtitles
        results["filtered_count"] = len(subtitles) - len(filtered_subtitles)

        results["continuous_text"] = " ".join(s["text"] for s in filtered_subtitles)

        results["stats"] = {
            "total_detections": len(subtitles),
            "filtered_detections": len(filtered_subtitles),
            "unique_texts": len(results.get("unique_texts", ())),
            "text_length": len(results["continuous_text"]),
            "average_confidence": (
                float(np.mean([s["confidence"] for s in filtered_subtitles]))
                if filtered_subtitles else 0
            )
        }

        return results

    def extract_subtitles_from_frames(self, frame_dir: str, confidence_threshold: float = 0.5) -> Dict:
        """Extract subtitles from a directory of frame images.

        Images are processed in file-name order. Each image's index is used as
        its timestamp. An empty directory yields an empty but complete result,
        with the same keys as a normal run.

        Args:
            frame_dir: Directory containing .jpg/.jpeg/.png/.bmp files (any case).
            confidence_threshold: Detections scoring below this are dropped.

        Returns:
            Dict with ``subtitles``, ``unique_texts`` (sorted), ``continuous_text``,
            ``stats`` and ``extract_time``.

        Raises:
            FileNotFoundError: ``frame_dir`` is not an existing directory.
        """
        frame_dir = Path(frame_dir)
        if not frame_dir.is_dir():
            raise FileNotFoundError(f"帧目录不存在: {frame_dir}")

        logger.info(f"开始从帧目录提取字幕: {frame_dir}")
        start_time = time.time()

        # Filter on the lower-cased suffix. Globbing '*.jpg' and '*.JPG'
        # separately returns every file twice on case-insensitive filesystems.
        image_files = sorted(
            path for path in frame_dir.iterdir()
            if path.is_file() and path.suffix.lower() in self._IMAGE_EXTENSIONS
        )

        if not image_files:
            logger.warning(f"在目录 {frame_dir} 中未找到图片文件")
        else:
            logger.info(f"找到 {len(image_files)} 个图片文件")

        results = {
            "frame_dir": str(frame_dir),
            "total_images": len(image_files),
            "confidence_threshold": confidence_threshold,
            "subtitles": [],
            "unique_texts": set(),
            "ocr_engine": self.ocr_engine,
            "timestamp": datetime.now().isoformat()
        }

        for i, image_file in enumerate(image_files):
            try:
                frame = cv2.imread(str(image_file))
                if frame is not None:
                    # The file index stands in for a real timestamp.
                    timestamp = float(i)
                    self._collect(results, self._ocr_frame(frame, timestamp, confidence_threshold))

                    if (i + 1) % 10 == 0:
                        progress = ((i + 1) / len(image_files)) * 100
                        logger.info(f"处理进度: {progress:.1f}%")
            except Exception as e:
                logger.warning(f"处理图片 {image_file} 时出错: {e}")

        results["unique_texts"] = sorted(results["unique_texts"])

        results = self._postprocess_subtitles(results)
        results["extract_time"] = time.time() - start_time

        logger.info(f"字幕提取完成,提取到 {len(results['subtitles'])} 个字幕片段")

        return results

    def save_results(self, results: Dict, output_path: str, format: str = "json"):
        """Write ``results`` to ``output_path`` as "json", "txt" (continuous text) or "srt".

        Raises:
            ValueError: Unsupported ``format``.
        """
        output_path = Path(output_path)

        if format == "json":
            # Sets and NumPy values are not JSON-serializable; convert them first.
            results_copy = results.copy()
            if "unique_texts" in results_copy and isinstance(results_copy["unique_texts"], set):
                results_copy["unique_texts"] = list(results_copy["unique_texts"])

            results_copy = convert_numpy_types(results_copy)

            with open(output_path, 'w', encoding='utf-8') as f:
                json.dump(results_copy, f, ensure_ascii=False, indent=2)

        elif format == "txt":
            with open(output_path, 'w', encoding='utf-8') as f:
                f.write(results.get("continuous_text", ""))

        elif format == "srt":
            self._save_srt(results, output_path)

        else:
            raise ValueError(f"不支持的格式: {format}")

        logger.info(f"结果已保存: {output_path}")

    def _save_srt(self, results: Dict, output_path: Path, default_duration: float = 3.0):
        """Write ``results["subtitles"]`` as an SRT file, in timestamp order.

        Each cue lasts ``default_duration`` seconds. It is cut short at the
        next later cue's start so consecutive cues do not overlap. Cues that
        share a timestamp (several lines from one frame) keep identical timing.
        """
        import bisect

        subtitles = sorted(results.get("subtitles", []), key=lambda s: s["timestamp"])
        # Distinct start times, used to find the next later cue for each entry.
        starts = sorted({s["timestamp"] for s in subtitles})

        with open(output_path, 'w', encoding='utf-8') as f:
            for index, subtitle in enumerate(subtitles, 1):
                begin = subtitle["timestamp"]
                end = begin + default_duration
                pos = bisect.bisect_right(starts, begin)
                if pos < len(starts):
                    end = min(end, starts[pos])

                f.write(f"{index}\n")
                f.write(f"{self._format_time(begin)} --> {self._format_time(end)}\n")
                f.write(f"{subtitle['text']}\n\n")

    def _format_time(self, seconds):
        """Format seconds as an SRT timestamp ``HH:MM:SS,mmm``.

        Works in integer milliseconds so that values like 59.9996 carry into
        the next minute instead of producing an invalid "60.000" seconds field.
        Negative inputs are clamped to zero.
        """
        total_ms = max(0, int(round(float(seconds) * 1000)))
        hours, remainder = divmod(total_ms, 3_600_000)
        minutes, remainder = divmod(remainder, 60_000)
        secs, millis = divmod(remainder, 1000)
        return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}"
|
||
|
||
def main(argv=None):
    """Command-line entry point: extract subtitles from a video or a frame directory.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``. This is
            optional, so existing ``main()`` calls are unaffected.

    Returns:
        int: 0 on success, 1 if the input path is invalid or extraction fails.
    """
    parser = argparse.ArgumentParser(description="专业视频字幕OCR提取器")
    parser.add_argument("input", help="输入视频文件路径或帧图片目录")
    parser.add_argument("-e", "--engine", default="paddleocr",
                        choices=["paddleocr", "easyocr", "both"],
                        help="OCR引擎 (默认: paddleocr)")
    parser.add_argument("-l", "--language", default="ch",
                        choices=["ch", "en", "ch_en"],
                        help="语言设置 (默认: ch)")
    parser.add_argument("-i", "--interval", type=int, default=30,
                        help="帧采样间隔 (默认: 30)")
    parser.add_argument("-c", "--confidence", type=float, default=0.5,
                        help="置信度阈值 (默认: 0.5)")
    parser.add_argument("-o", "--output", default="subtitle_results",
                        help="输出目录 (默认: subtitle_results)")
    parser.add_argument("-f", "--format", default="json",
                        choices=["json", "txt", "srt"],
                        help="输出格式 (默认: json)")
    parser.add_argument("--region", nargs=4, type=int, metavar=('X1', 'Y1', 'X2', 'Y2'),
                        help="字幕区域坐标 (x1 y1 x2 y2)")

    args = parser.parse_args(argv)

    input_path = Path(args.input)
    # Validate before loading the OCR models, which is slow.
    if not (input_path.is_file() or input_path.is_dir()):
        logger.error(f"输入路径无效: {args.input}")
        return 1

    extractor = VideoSubtitleExtractor(
        ocr_engine=args.engine,
        language=args.language
    )

    try:
        if input_path.is_file():
            logger.info("处理视频文件")
            results = extractor.extract_subtitles_from_video(
                args.input,
                sample_interval=args.interval,
                confidence_threshold=args.confidence,
                subtitle_region=tuple(args.region) if args.region else None
            )
        else:
            logger.info("处理帧图片目录")
            results = extractor.extract_subtitles_from_frames(
                args.input,
                confidence_threshold=args.confidence
            )

        output_dir = Path(args.output)
        # parents=True so nested output paths such as "out/run1" work.
        output_dir.mkdir(parents=True, exist_ok=True)

        output_file = output_dir / f"{input_path.stem}_subtitles.{args.format}"
        extractor.save_results(results, output_file, args.format)

        # .get(): an empty result (e.g. a directory with no images) may lack stats.
        stats = results.get("stats", {})
        print("\n字幕提取完成!")
        print(f"输入: {args.input}")
        print(f"输出: {output_file}")
        print(f"提取时长: {results.get('extract_time', 0):.2f} 秒")
        print(f"字幕片段: {stats.get('filtered_detections', 0)} 个")
        print(f"唯一文本: {stats.get('unique_texts', 0)} 条")
        print(f"平均置信度: {stats.get('average_confidence', 0):.3f}")
        print(f"文本总长度: {stats.get('text_length', 0)} 字符")

        continuous_text = results.get("continuous_text", "")
        if continuous_text:
            print("\n提取的连续文本:")
            print("-" * 50)
            print(continuous_text[:500] + "..." if len(continuous_text) > 500 else continuous_text)

    except Exception as e:
        # logger.exception keeps the traceback, which the previous logger.error dropped.
        logger.exception(f"程序执行出错: {str(e)}")
        return 1

    return 0
|
||
|
||
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status
    # (SystemExit(None) exits with 0).
    raise SystemExit(main())