hot_video_analyse/code/ocr_subtitle_extractor.py

496 lines
19 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
专业视频字幕OCR提取器
支持PaddleOCR和EasyOCR两种引擎
"""
import argparse
import json
import logging
import os
import re
import time
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import cv2
import numpy as np
# Module-level logger; basicConfig sets INFO-level root logging with timestamps.
logger = logging.getLogger(__name__)
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
def convert_numpy_types(obj):
    """Recursively convert NumPy values into JSON-serializable Python builtins.

    NumPy booleans, integers and floats become bool/int/float, arrays become
    (nested) lists, lists/tuples/sets become lists, and dicts are rebuilt with
    converted values (keys are left unchanged). Anything else is returned as-is.
    """
    # np.bool_ is neither an np.integer nor a Python bool, so json.dump would
    # reject it without this branch.
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    # JSON has no set type; emit sets as lists like tuples.
    if isinstance(obj, (list, tuple, set, frozenset)):
        return [convert_numpy_types(item) for item in obj]
    if isinstance(obj, dict):
        return {key: convert_numpy_types(value) for key, value in obj.items()}
    return obj
class VideoSubtitleExtractor:
    """Extract on-screen subtitle text from a video file or a directory of frames via OCR.

    PaddleOCR, EasyOCR, or both engines are run on every sampled image.
    """

    # Values accepted for the ``ocr_engine`` constructor argument.
    SUPPORTED_ENGINES = ("paddleocr", "easyocr", "both")
    # Image suffixes (compared case-insensitively) collected from frame directories.
    IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp")
    # Post-processing drops a repeat of the previously kept text when it occurs
    # within this many timestamp units (seconds for video, frame index for
    # directory input).
    DUPLICATE_WINDOW_SECONDS = 5.0
    # Fixed duration, in seconds, given to every SRT cue; the real on-screen
    # duration is not measured.
    SRT_CUE_DURATION = 3.0

    def __init__(self, ocr_engine="paddleocr", language="ch"):
        """Create the extractor and eagerly load the requested OCR engine(s).

        Args:
            ocr_engine: "paddleocr", "easyocr" or "both".
            language: "ch", "en" or "ch_en"; other values fall back to Chinese.

        Raises:
            ValueError: If ``ocr_engine`` is not in SUPPORTED_ENGINES.
            ImportError: If a requested OCR package is not installed.
        """
        if ocr_engine not in self.SUPPORTED_ENGINES:
            # An unknown engine used to load nothing, silently yielding zero
            # detections for every frame.
            raise ValueError(f"不支持的OCR引擎: {ocr_engine}")
        self.ocr_engine = ocr_engine
        self.language = language
        self.paddle_ocr = None
        self.easy_ocr = None
        self.load_ocr_engines()

    def load_ocr_engines(self):
        """Import and instantiate the engine(s) selected by ``self.ocr_engine``.

        Raises:
            ImportError: If a required OCR package is missing (an install hint
                is logged before re-raising).
        """
        logger.info(f"正在加载OCR引擎: {self.ocr_engine}")
        if self.ocr_engine in ("paddleocr", "both"):
            try:
                from paddleocr import PaddleOCR
                # PaddleOCR's "ch" model also reads English, so "ch_en" maps to "ch".
                paddle_lang_map = {"ch": "ch", "en": "en", "ch_en": "ch"}
                self.paddle_ocr = PaddleOCR(
                    use_textline_orientation=True,
                    lang=paddle_lang_map.get(self.language, "ch"),
                )
                logger.info("PaddleOCR加载完成")
            except ImportError:
                logger.error("请安装PaddleOCR: pip install paddleocr")
                raise
        if self.ocr_engine in ("easyocr", "both"):
            try:
                import easyocr
                # EasyOCR takes explicit language codes; unknown settings
                # default to Simplified Chinese.
                easy_lang_map = {"ch": ["ch_sim"], "en": ["en"], "ch_en": ["ch_sim", "en"]}
                self.easy_ocr = easyocr.Reader(easy_lang_map.get(self.language, ["ch_sim"]))
                logger.info("EasyOCR加载完成")
            except ImportError:
                logger.error("请安装EasyOCR: pip install easyocr")
                raise

    def extract_subtitles_from_video(self, video_path: str,
                                     sample_interval: int = 30,
                                     confidence_threshold: float = 0.5,
                                     subtitle_region: Optional[Tuple[int, int, int, int]] = None) -> Dict:
        """Sample frames from a video and OCR them for subtitle text.

        Args:
            video_path: Path to the video file.
            sample_interval: OCR every ``sample_interval``-th frame (must be >= 1).
            confidence_threshold: Minimum engine confidence for a detection to be kept.
            subtitle_region: Optional crop box ``(x1, y1, x2, y2)`` in pixels;
                ``None`` (or an empty tuple) OCRs the full frame. Detection
                bboxes are relative to the cropped image.

        Returns:
            Dict with video metadata, the post-processed "subtitles" list,
            "unique_texts", "continuous_text", "stats" and "extract_time".

        Raises:
            FileNotFoundError: If ``video_path`` does not exist.
            ValueError: If ``sample_interval`` < 1, the region is malformed, or
                the region lies entirely outside the frame.
            OSError: If OpenCV cannot open the video or report a positive FPS.
        """
        if sample_interval < 1:
            raise ValueError(f"采样间隔必须大于0: {sample_interval}")
        if subtitle_region:
            subtitle_region = self._validate_region(subtitle_region)
        video_path = Path(video_path)
        if not video_path.exists():
            raise FileNotFoundError(f"视频文件不存在: {video_path}")
        logger.info(f"开始提取视频字幕: {video_path}")
        logger.info(f"采样间隔: {sample_interval}帧, 置信度阈值: {confidence_threshold}")
        start_time = time.time()

        cap = cv2.VideoCapture(str(video_path))
        try:
            # Without these checks an unreadable file surfaced as a
            # ZeroDivisionError when computing the duration.
            if not cap.isOpened():
                raise OSError(f"无法打开视频文件: {video_path}")
            fps = cap.get(cv2.CAP_PROP_FPS)
            if not (fps > 0):  # also rejects NaN
                raise OSError(f"无法获取视频帧率: {video_path}")
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            duration = total_frames / fps
            logger.info(f"视频信息: {total_frames}帧, {fps:.2f}FPS, {duration:.2f}")
            results = {
                "video_path": str(video_path),
                "duration": duration,
                "fps": fps,
                "total_frames": total_frames,
                "sample_interval": sample_interval,
                "confidence_threshold": confidence_threshold,
                "subtitle_region": subtitle_region,
                "subtitles": [],
                "unique_texts": set(),
                "ocr_engine": self.ocr_engine,
                "timestamp": datetime.now().isoformat()
            }
            frame_count = 0
            processed_frames = 0
            # grab() advances without converting the frame; only sampled frames
            # are retrieve()d, avoiding the conversion/copy for skipped ones.
            while cap.grab():
                if frame_count % sample_interval == 0:
                    ret, frame = cap.retrieve()
                    if not ret:
                        break
                    timestamp = frame_count / fps
                    frame_crop = self._crop_region(frame, subtitle_region)
                    ocr_results = self._ocr_frame(frame_crop, timestamp, confidence_threshold)
                    results["subtitles"].extend(ocr_results)
                    results["unique_texts"].update(r["text"] for r in ocr_results)
                    processed_frames += 1
                    if processed_frames % 10 == 0:
                        # Some containers report a frame count of 0 (or less).
                        progress = (frame_count / total_frames) * 100 if total_frames > 0 else 0.0
                        logger.info(f"处理进度: {progress:.1f}% ({processed_frames}帧)")
                frame_count += 1
        finally:
            cap.release()

        results["unique_texts"] = list(results["unique_texts"])
        results = self._postprocess_subtitles(results)
        extract_time = time.time() - start_time
        results["extract_time"] = extract_time
        logger.info(f"字幕提取完成,耗时: {extract_time:.2f}")
        logger.info(f"提取到 {len(results['subtitles'])} 个字幕片段")
        logger.info(f"唯一文本 {len(results['unique_texts'])}")
        return results

    @staticmethod
    def _validate_region(region) -> Tuple[int, int, int, int]:
        """Return ``region`` as four ints ``(x1, y1, x2, y2)``, rejecting malformed boxes.

        Raises:
            ValueError: If the region does not have four entries, has negative
                coordinates (NumPy would wrap them around), or has zero/negative
                width or height.
        """
        if len(region) != 4:
            raise ValueError(f"字幕区域必须包含4个坐标: {region}")
        x1, y1, x2, y2 = (int(v) for v in region)
        if min(x1, y1) < 0 or x2 <= x1 or y2 <= y1:
            raise ValueError(f"无效的字幕区域: {region}")
        return x1, y1, x2, y2

    @staticmethod
    def _crop_region(frame: np.ndarray, region) -> np.ndarray:
        """Return the ``(x1, y1, x2, y2)`` sub-image of ``frame``, or ``frame`` if no region.

        Raises:
            ValueError: If the region does not overlap the frame at all.
        """
        if not region:
            return frame
        x1, y1, x2, y2 = region
        crop = frame[y1:y2, x1:x2]
        if crop.size == 0:
            raise ValueError(f"字幕区域超出画面范围: {region}, 画面尺寸: {frame.shape[1]}x{frame.shape[0]}")
        return crop

    def _ocr_frame(self, frame: np.ndarray, timestamp: float, confidence_threshold: float) -> List[Dict]:
        """Run each loaded OCR engine on one image and return the accepted detections.

        Each detection dict has "timestamp", "text" (whitespace-stripped),
        "confidence", "bbox" (coordinates within ``frame``) and "engine".
        Detections with blank/non-string text or confidence below
        ``confidence_threshold`` are dropped. An engine that raises is logged as
        a warning and contributes nothing further for this frame.
        """
        ocr_results = []
        if self.paddle_ocr is not None:
            try:
                raw_result = self.paddle_ocr.ocr(frame)
                for bbox, text, confidence in self._iter_paddle_detections(raw_result):
                    detection = self._make_detection(
                        timestamp, text, confidence, bbox, "PaddleOCR", confidence_threshold)
                    if detection is not None:
                        ocr_results.append(detection)
            except Exception as e:
                logger.warning(f"PaddleOCR识别失败 (时间戳:{timestamp:.2f}): {e}")
        if self.easy_ocr is not None:
            try:
                for bbox, text, confidence in self.easy_ocr.readtext(frame):
                    detection = self._make_detection(
                        timestamp, text, confidence, bbox, "EasyOCR", confidence_threshold)
                    if detection is not None:
                        ocr_results.append(detection)
            except Exception as e:
                logger.warning(f"EasyOCR识别失败 (时间戳:{timestamp:.2f}): {e}")
        return ocr_results

    @staticmethod
    def _iter_paddle_detections(paddle_result):
        """Yield ``(bbox, text, confidence)`` from the PaddleOCR result layouts handled here.

        Handled layouts (one entry per input image/page):
        * dict-like pages carrying parallel "rec_polys"/"rec_texts"/"rec_scores"
          sequences (PaddleOCR 3.x);
        * lists of ``[bbox, (text, score)]`` (PaddleOCR 2.x), or ``None`` when
          nothing was detected;
        * lists of flat ``[bbox, text, score]`` entries.
        Entries in any other shape are skipped.
        """
        for page in paddle_result or []:
            if not page:
                continue
            if hasattr(page, "keys") and hasattr(page, "get"):
                # Iterating a dict-like page directly would yield its key names
                # rather than detections, silently discarding every result.
                texts = page.get("rec_texts")
                scores = page.get("rec_scores")
                if texts is None or scores is None:
                    logger.debug(f"未识别的PaddleOCR结果格式: {list(page.keys())}")
                    continue
                polys = page.get("rec_polys")
                if polys is None:
                    polys = [None] * len(texts)
                for bbox, text, confidence in zip(polys, texts, scores):
                    yield bbox, text, confidence
                continue
            for detection in page:
                if not detection:
                    continue
                try:
                    if len(detection) == 2:
                        bbox, (text, confidence) = detection
                    elif len(detection) == 3:
                        bbox, text, confidence = detection
                    else:
                        continue
                except Exception as parse_error:
                    logger.debug(f"解析单个检测结果失败: {parse_error}, 数据: {detection}")
                    continue
                yield bbox, text, confidence

    @staticmethod
    def _make_detection(timestamp, text, confidence, bbox, engine, confidence_threshold):
        """Build one detection dict, or return None if it fails the text/confidence filter."""
        if not isinstance(text, str):
            return None
        text = text.strip()
        try:
            accepted = bool(text) and bool(confidence >= confidence_threshold)
        except (TypeError, ValueError):
            return None
        if not accepted:
            return None
        return {
            "timestamp": timestamp,
            "text": text,
            "confidence": confidence,
            "bbox": bbox,
            "engine": engine
        }

    def _postprocess_subtitles(self, results: Dict) -> Dict:
        """Sort detections by time, drop near-duplicate repeats, and add summary fields.

        A detection is dropped when its text equals the previously *kept*
        detection's text and it occurs within DUPLICATE_WINDOW_SECONDS of it.
        Mutates and returns ``results``, replacing "subtitles" and adding
        "filtered_count", "continuous_text" (kept texts joined by single spaces)
        and "stats".
        """
        subtitles = results["subtitles"]
        subtitles.sort(key=lambda s: s["timestamp"])
        window = self.DUPLICATE_WINDOW_SECONDS
        filtered_subtitles = []
        last_text = ""
        last_timestamp = -1
        # NOTE: only the immediately preceding kept entry is compared, so texts
        # that alternate within a frame (e.g. subtitle + watermark) are kept.
        for subtitle in subtitles:
            text, timestamp = subtitle["text"], subtitle["timestamp"]
            if text != last_text or (timestamp - last_timestamp) > window:
                filtered_subtitles.append(subtitle)
                last_text, last_timestamp = text, timestamp
        results["subtitles"] = filtered_subtitles
        results["filtered_count"] = len(subtitles) - len(filtered_subtitles)
        results["continuous_text"] = " ".join(s["text"] for s in filtered_subtitles)
        confidences = [s["confidence"] for s in filtered_subtitles]
        results["stats"] = {
            "total_detections": len(subtitles),
            "filtered_detections": len(filtered_subtitles),
            "unique_texts": len(results["unique_texts"]),
            "text_length": len(results["continuous_text"]),
            # Plain float so the value is JSON-native even without conversion.
            "average_confidence": float(np.mean(confidences)) if confidences else 0.0
        }
        return results

    @staticmethod
    def _natural_sort_key(path: Path):
        """Sort key ordering embedded integers numerically (``frame_2`` before ``frame_10``)."""
        # re.split with a capture group alternates text (even index) and digit
        # runs (odd index), so compared positions always share a type.
        parts = re.split(r"(\d+)", path.name)
        return [int(p) if i % 2 else p for i, p in enumerate(parts)], path.name

    def extract_subtitles_from_frames(self, frame_dir: str, confidence_threshold: float = 0.5) -> Dict:
        """OCR every image in ``frame_dir`` and return post-processed subtitle results.

        Images with an IMAGE_EXTENSIONS suffix (any letter case) are processed
        in natural filename order. Each image's "timestamp" is its position in
        that order (0.0, 1.0, ...), not a time in seconds. Unreadable images
        are skipped with a warning.

        Returns:
            Result dict with "subtitles", "unique_texts", "continuous_text" and
            "stats"; when the directory holds no images these are empty
            (previously ``{}`` was returned, breaking callers that read "stats").

        Raises:
            FileNotFoundError: If ``frame_dir`` does not exist.
            NotADirectoryError: If ``frame_dir`` is not a directory.
        """
        frame_dir = Path(frame_dir)
        if not frame_dir.exists():
            raise FileNotFoundError(f"帧目录不存在: {frame_dir}")
        if not frame_dir.is_dir():
            raise NotADirectoryError(f"不是目录: {frame_dir}")
        logger.info(f"开始从帧目录提取字幕: {frame_dir}")
        # A single pass with a case-insensitive suffix test: globbing "*.jpg"
        # and "*.JPG" separately listed every file twice on case-insensitive
        # filesystems such as Windows.
        image_files = sorted(
            (p for p in frame_dir.iterdir()
             if p.is_file() and p.suffix.lower() in self.IMAGE_EXTENSIONS),
            key=self._natural_sort_key,
        )
        if image_files:
            logger.info(f"找到 {len(image_files)} 个图片文件")
        else:
            logger.warning(f"在目录 {frame_dir} 中未找到图片文件")
        results = {
            "frame_dir": str(frame_dir),
            "total_images": len(image_files),
            "confidence_threshold": confidence_threshold,
            "subtitles": [],
            "unique_texts": set(),
            "ocr_engine": self.ocr_engine,
            "timestamp": datetime.now().isoformat()
        }
        total = len(image_files)
        for index, image_file in enumerate(image_files):
            try:
                frame = cv2.imread(str(image_file))
                if frame is None:
                    logger.warning(f"无法读取图片: {image_file}")
                else:
                    ocr_results = self._ocr_frame(frame, float(index), confidence_threshold)
                    results["subtitles"].extend(ocr_results)
                    results["unique_texts"].update(r["text"] for r in ocr_results)
            except Exception as e:
                logger.warning(f"处理图片 {image_file} 时出错: {e}")
            if (index + 1) % 10 == 0:
                progress = ((index + 1) / total) * 100
                logger.info(f"处理进度: {progress:.1f}%")
        results["unique_texts"] = list(results["unique_texts"])
        results = self._postprocess_subtitles(results)
        logger.info(f"字幕提取完成,提取到 {len(results['subtitles'])} 个字幕片段")
        return results

    def save_results(self, results: Dict, output_path: str, format: str = "json"):
        """Write ``results`` to ``output_path`` as "json", "txt" or "srt".

        "json" writes the whole result dict (sets and NumPy values converted to
        JSON-native types); "txt" writes only "continuous_text"; "srt" writes
        one cue per kept detection.

        Raises:
            ValueError: For any other ``format`` (nothing is written).
        """
        # ``format`` shadows the builtin but is part of the public signature.
        output_path = Path(output_path)
        if format == "json":
            results_copy = dict(results)
            if isinstance(results_copy.get("unique_texts"), set):
                results_copy["unique_texts"] = list(results_copy["unique_texts"])
            results_copy = convert_numpy_types(results_copy)
            with open(output_path, 'w', encoding='utf-8') as f:
                json.dump(results_copy, f, ensure_ascii=False, indent=2)
        elif format == "txt":
            with open(output_path, 'w', encoding='utf-8') as f:
                f.write(results.get("continuous_text", ""))
        elif format == "srt":
            self._save_srt(results, output_path)
        else:
            raise ValueError(f"不支持的格式: {format}")
        logger.info(f"结果已保存: {output_path}")

    def _save_srt(self, results: Dict, output_path: Path):
        """Write kept detections as SRT cues numbered from 1, each lasting SRT_CUE_DURATION seconds.

        A result without a "subtitles" key produces an empty file.
        """
        with open(output_path, 'w', encoding='utf-8') as f:
            for index, subtitle in enumerate(results.get("subtitles", []), 1):
                start = subtitle["timestamp"]
                f.write(
                    f"{index}\n"
                    f"{self._format_time(start)} --> {self._format_time(start + self.SRT_CUE_DURATION)}\n"
                    f"{subtitle['text']}\n\n"
                )

    def _format_time(self, seconds):
        """Format a time in seconds as an SRT timestamp ``HH:MM:SS,mmm``.

        Works in whole milliseconds so rounding carries into the higher fields:
        float formatting of the seconds alone turned 59.9996 s into the invalid
        ``00:00:60,000``. Negative inputs are clamped to zero.
        """
        total_ms = int(round(max(seconds, 0) * 1000))
        hours, remainder = divmod(total_ms, 3_600_000)
        minutes, remainder = divmod(remainder, 60_000)
        secs, millis = divmod(remainder, 1000)
        return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}"
def main():
    """Command-line entry point: parse arguments, extract subtitles, save and summarize them.

    Returns:
        0 on success; 1 if the input path is invalid or extraction/saving
        fails (the error is logged together with its traceback).
    """
    parser = argparse.ArgumentParser(description="专业视频字幕OCR提取器")
    parser.add_argument("input", help="输入视频文件路径或帧图片目录")
    parser.add_argument("-e", "--engine", default="paddleocr",
                        choices=["paddleocr", "easyocr", "both"],
                        help="OCR引擎 (默认: paddleocr)")
    parser.add_argument("-l", "--language", default="ch",
                        choices=["ch", "en", "ch_en"],
                        help="语言设置 (默认: ch)")
    parser.add_argument("-i", "--interval", type=int, default=30,
                        help="帧采样间隔 (默认: 30)")
    parser.add_argument("-c", "--confidence", type=float, default=0.5,
                        help="置信度阈值 (默认: 0.5)")
    parser.add_argument("-o", "--output", default="subtitle_results",
                        help="输出目录 (默认: subtitle_results)")
    parser.add_argument("-f", "--format", default="json",
                        choices=["json", "txt", "srt"],
                        help="输出格式 (默认: json)")
    parser.add_argument("--region", nargs=4, type=int, metavar=('X1', 'Y1', 'X2', 'Y2'),
                        help="字幕区域坐标 (x1 y1 x2 y2)")
    args = parser.parse_args()

    input_path = Path(args.input)
    # Validate the input before loading OCR models, which can take a long time.
    if not (input_path.is_file() or input_path.is_dir()):
        logger.error(f"输入路径无效: {args.input}")
        return 1

    extractor = VideoSubtitleExtractor(
        ocr_engine=args.engine,
        language=args.language
    )

    try:
        if input_path.is_file():
            logger.info("处理视频文件")
            results = extractor.extract_subtitles_from_video(
                args.input,
                sample_interval=args.interval,
                confidence_threshold=args.confidence,
                subtitle_region=tuple(args.region) if args.region else None
            )
        else:
            logger.info("处理帧图片目录")
            results = extractor.extract_subtitles_from_frames(
                args.input,
                confidence_threshold=args.confidence
            )

        output_dir = Path(args.output)
        # parents=True so nested output paths like "out/run1" work.
        output_dir.mkdir(parents=True, exist_ok=True)
        output_file = output_dir / f"{input_path.stem}_subtitles.{args.format}"
        extractor.save_results(results, output_file, args.format)

        # Tolerate results without "stats" (e.g. an empty frame directory).
        stats = results.get("stats", {})
        continuous_text = results.get("continuous_text", "")
        print("\n字幕提取完成!")
        print(f"输入: {args.input}")
        print(f"输出: {output_file}")
        print(f"提取时长: {results.get('extract_time', 0):.2f}")
        print(f"字幕片段: {stats.get('filtered_detections', 0)}")
        print(f"唯一文本: {stats.get('unique_texts', 0)}")
        print(f"平均置信度: {stats.get('average_confidence', 0):.3f}")
        print(f"文本总长度: {stats.get('text_length', 0)} 字符")
        if continuous_text:
            print("\n提取的连续文本:")
            print("-" * 50)
            print(continuous_text[:500] + "..." if len(continuous_text) > 500 else continuous_text)
    except Exception as e:
        # Top-level boundary: keep the traceback instead of only the message.
        logger.exception(f"程序执行出错: {str(e)}")
        return 1
    return 0
if __name__ == "__main__":
    # Use main()'s return value as the process exit status (None -> 0), so
    # failures are visible to shells and schedulers.
    raise SystemExit(main())