2025-06-09 09:52:32 +08:00
|
|
|
|
#!/usr/bin/env python3
|
|
|
|
|
|
# -*- coding: utf-8 -*-
|
|
|
|
|
|
"""
|
|
|
|
|
|
专业视频字幕OCR提取器
|
2025-06-30 10:27:06 +08:00
|
|
|
|
支持PaddleOCR、EasyOCR和CnOCR三种引擎
|
2025-06-09 09:52:32 +08:00
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
import cv2
|
|
|
|
|
|
import os
|
|
|
|
|
|
import time
|
|
|
|
|
|
import json
|
|
|
|
|
|
import argparse
|
|
|
|
|
|
from pathlib import Path
|
|
|
|
|
|
from datetime import datetime
|
|
|
|
|
|
import logging
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
|
from typing import List, Dict, Tuple
|
|
|
|
|
|
|
|
|
|
|
|
# Logging setup: configure the root logger once at import time (INFO level,
# timestamped records) and create the logger used throughout this module.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
def convert_numpy_types(obj):
    """Recursively convert NumPy values into JSON-serializable Python objects.

    NumPy booleans, integers and floats become ``bool``/``int``/``float``;
    arrays become nested lists; lists, tuples, sets and frozensets become
    lists with converted items; dict values are converted (keys are left
    untouched). Anything else is returned unchanged.

    Args:
        obj: Arbitrary (possibly nested) value.

    Returns:
        The converted value.
    """
    # np.bool_ is not a subclass of np.integer, so it needs its own branch;
    # without it json.dump() rejects the value.
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    # Sets are not JSON-serializable either, so they are flattened to lists.
    if isinstance(obj, (list, tuple, set, frozenset)):
        return [convert_numpy_types(item) for item in obj]
    if isinstance(obj, dict):
        return {key: convert_numpy_types(value) for key, value in obj.items()}
    return obj
|
|
|
|
|
|
|
|
|
|
|
|
class VideoSubtitleExtractor:
    """Extract on-screen subtitle text from videos or frame images via OCR.

    Up to three OCR back ends (PaddleOCR, EasyOCR, CnOCR) can be loaded. Every
    loaded engine is run on each sampled frame and all detections above the
    confidence threshold are collected, de-duplicated and summarised.
    """

    # Directories handed to EasyOCR / CnOCR as model locations. Kept as class
    # attributes so a subclass or caller can point them elsewhere.
    EASYOCR_MODEL_DIR = '/root/autodl-tmp/llm/easyocr'
    CNOCR_MODEL_DIR = "/root/autodl-tmp/llm/cnocr"

    # Image suffixes (lower case) accepted by extract_subtitles_from_frames.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp')

    def __init__(self, ocr_engine="paddleocr", language="ch"):
        """Store the configuration and load the requested OCR engine(s).

        Args:
            ocr_engine: OCR engine ("paddleocr", "easyocr", "cnocr", "all").
            language: Language setting ("ch", "en", "ch_en"); any other value
                falls back to Chinese.
        """
        self.ocr_engine = ocr_engine
        self.language = language
        self.paddle_ocr = None
        self.easy_ocr = None
        self.cn_ocr = None
        self.load_ocr_engines()

    def load_ocr_engines(self):
        """Load the engine(s) selected by ``self.ocr_engine``.

        A failure to import or initialise an engine is logged. It is re-raised
        only when that engine was requested explicitly; in "all" mode the
        remaining engines are still tried.
        """
        logger.info(f"正在加载OCR引擎: {self.ocr_engine}")

        if self.ocr_engine in ["paddleocr", "all"]:
            self._load_paddleocr()
        if self.ocr_engine in ["easyocr", "all"]:
            self._load_easyocr()
        if self.ocr_engine in ["cnocr", "all"]:
            self._load_cnocr()

        if not (self.paddle_ocr or self.easy_ocr or self.cn_ocr):
            logger.warning("没有可用的OCR引擎,将无法识别任何文本")

    def _load_paddleocr(self):
        """Import and initialise PaddleOCR into ``self.paddle_ocr``."""
        try:
            from paddleocr import PaddleOCR

            # "ch_en" maps to "ch": per the original author's note the PaddleOCR
            # "ch" model covers both Chinese and English.
            paddle_lang_map = {
                "ch": "ch",
                "en": "en",
                "ch_en": "ch",
            }
            paddle_lang = paddle_lang_map.get(self.language, "ch")

            # NOTE(review): use_textline_orientation looks like a PaddleOCR 3.x
            # argument while show_log looks like a 2.x one - confirm both are
            # accepted by the installed version.
            self.paddle_ocr = PaddleOCR(
                use_textline_orientation=True,
                lang=paddle_lang,
                show_log=False,  # reduce log output
            )
            logger.info("PaddleOCR加载完成")
        except ImportError:
            logger.error("请安装PaddleOCR: pip install paddleocr")
            if self.ocr_engine == "paddleocr":
                raise
        except Exception as e:
            # Same policy as CnOCR: only fatal when explicitly requested.
            logger.error(f"PaddleOCR初始化失败: {e}")
            if self.ocr_engine == "paddleocr":
                raise

    def _load_easyocr(self):
        """Import and initialise EasyOCR into ``self.easy_ocr``."""
        try:
            import easyocr

            model_storage_directory = self.EASYOCR_MODEL_DIR

            # EasyOCR language codes; unknown settings default to simplified Chinese.
            easy_lang_map = {
                "ch": ['ch_sim'],
                "en": ['en'],
                "ch_en": ['ch_sim', 'en'],
            }
            lang_list = easy_lang_map.get(self.language, ['ch_sim'])

            logger.info(f"EasyOCR模型路径: {model_storage_directory}")
            self.easy_ocr = easyocr.Reader(
                lang_list,
                model_storage_directory=model_storage_directory,
            )
            logger.info("EasyOCR加载完成")
        except ImportError:
            logger.error("请安装EasyOCR: pip install easyocr")
            if self.ocr_engine == "easyocr":
                raise
        except Exception as e:
            # Same policy as CnOCR: only fatal when explicitly requested.
            logger.error(f"EasyOCR初始化失败: {e}")
            if self.ocr_engine == "easyocr":
                raise

    def _load_cnocr(self):
        """Import and initialise CnOCR into ``self.cn_ocr``."""
        try:
            from cnocr import CnOcr

            cn_ocr_model_dir = self.CNOCR_MODEL_DIR
            os.makedirs(cn_ocr_model_dir, exist_ok=True)

            # CNOCR_HOME is set so CnOCR stores its models in that directory.
            os.environ['CNOCR_HOME'] = cn_ocr_model_dir

            logger.info(f"CnOCR模型路径: {cn_ocr_model_dir}")
            logger.info("使用CnOCR默认模型配置")

            # Default configuration; CnOCR picks its own default models.
            self.cn_ocr = CnOcr()
            logger.info("CnOCR加载完成")
        except ImportError:
            logger.error("请安装CnOCR: pip install cnocr")
            if self.ocr_engine == "cnocr":
                raise
        except Exception as e:
            logger.error(f"CnOCR初始化失败: {e}")
            if self.ocr_engine == "cnocr":
                raise

    def calculate_subtitle_region(self, frame_width: int, frame_height: int, position: str = "bottom") -> Tuple:
        """Compute the crop rectangle in which subtitles are searched.

        Args:
            frame_width: Video width in pixels.
            frame_height: Video height in pixels.
            position: "full" (whole frame), "center" (rows from 50% to 80% of
                the height) or "bottom" (rows from 50% of the height to the
                bottom edge). Any other value is treated as "full".

        Returns:
            Tuple: (x1, y1, x2, y2) pixel coordinates of the region.
        """
        if position == "center":
            # Band between 50% and 80% of the frame height.
            region = (0, int(frame_height * 0.5), frame_width, int(frame_height * 0.8))
        elif position == "bottom":
            # Lower half of the frame.
            region = (0, int(frame_height * 0.5), frame_width, frame_height)
        else:
            # "full" and unknown values: whole frame.
            region = (0, 0, frame_width, frame_height)

        logger.info(f"字幕区域 ({position}): {region}, 视频尺寸: {frame_width}x{frame_height}")
        return region

    @staticmethod
    def _record_detections(results: Dict, seen_texts: Dict, detections: List[Dict]) -> None:
        """Append detections to ``results`` and remember each text in first-seen order."""
        results["subtitles"].extend(detections)
        for detection in detections:
            seen_texts.setdefault(detection["text"], None)

    def extract_subtitles_from_video(self, video_path: str,
                                     sample_interval: int = 30,
                                     confidence_threshold: float = 0.5,
                                     subtitle_position: str = "bottom") -> Dict:
        """Extract subtitles from a video file.

        Every ``sample_interval``-th frame is cropped to the subtitle region
        and passed to the loaded OCR engines. Bounding boxes in the result are
        relative to the cropped region, not to the full frame.

        Args:
            video_path: Path of the video file.
            sample_interval: Sampling interval in frames (must be >= 1).
            confidence_threshold: Minimum confidence for a detection to be kept.
            subtitle_position: Subtitle position ("full", "center", "bottom").

        Returns:
            Dict: Video metadata, the filtered detections ("subtitles"),
            "unique_texts", "continuous_text", "stats" and "extract_time".

        Raises:
            FileNotFoundError: If ``video_path`` does not exist.
            ValueError: If ``sample_interval`` is < 1 or the frame rate is unknown.
            OSError: If the video cannot be opened.
        """
        video_path = Path(video_path)
        if not video_path.exists():
            raise FileNotFoundError(f"视频文件不存在: {video_path}")
        if sample_interval < 1:
            # A zero interval would otherwise fail later with ZeroDivisionError.
            raise ValueError(f"采样间隔必须为正整数: {sample_interval}")

        logger.info(f"开始提取视频字幕: {video_path}")
        logger.info(f"采样间隔: {sample_interval}帧, 置信度阈值: {confidence_threshold}")

        start_time = time.time()
        seen_texts: Dict = {}

        cap = cv2.VideoCapture(str(video_path))
        # Everything after opening runs inside try/finally so the capture is
        # released even when reading the metadata fails.
        try:
            if not cap.isOpened():
                raise OSError(f"无法打开视频文件: {video_path}")

            fps = cap.get(cv2.CAP_PROP_FPS)
            if not fps or fps <= 0:
                # Timestamps are derived from the frame rate; without it they
                # cannot be computed.
                raise ValueError(f"无法获取视频帧率: {video_path}")
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            duration = total_frames / fps
            frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

            subtitle_region = self.calculate_subtitle_region(frame_width, frame_height, subtitle_position)
            x1, y1, x2, y2 = subtitle_region

            logger.info(f"视频信息: {total_frames}帧, {fps:.2f}FPS, {duration:.2f}秒, 分辨率: {frame_width}x{frame_height}")

            results = {
                "video_path": str(video_path),
                "duration": duration,
                "fps": fps,
                "total_frames": total_frames,
                "frame_width": frame_width,
                "frame_height": frame_height,
                "sample_interval": sample_interval,
                "confidence_threshold": confidence_threshold,
                "subtitle_position": subtitle_position,
                "subtitle_region": subtitle_region,
                "subtitles": [],
                "unique_texts": [],
                "ocr_engine": self.ocr_engine,
                "timestamp": datetime.now().isoformat(),
            }

            frame_count = 0
            processed_frames = 0

            # grab() advances to the next frame without decoding it; only the
            # sampled frames are decoded with retrieve().
            while cap.grab():
                if frame_count % sample_interval == 0:
                    ret, frame = cap.retrieve()
                    if not ret:
                        break

                    timestamp = frame_count / fps
                    frame_crop = frame[y1:y2, x1:x2]

                    ocr_results = self._ocr_frame(frame_crop, timestamp, confidence_threshold)
                    if ocr_results:
                        self._record_detections(results, seen_texts, ocr_results)

                    processed_frames += 1

                    if processed_frames % 10 == 0:
                        # Some containers report no frame count; avoid dividing by zero.
                        progress = (frame_count / total_frames) * 100 if total_frames > 0 else 0.0
                        logger.info(f"处理进度: {progress:.1f}% ({processed_frames}帧)")

                frame_count += 1
        finally:
            cap.release()

        results["unique_texts"] = list(seen_texts)

        # Post-processing: de-duplicate and summarise.
        results = self._postprocess_subtitles(results)

        extract_time = time.time() - start_time
        results["extract_time"] = extract_time

        logger.info(f"字幕提取完成,耗时: {extract_time:.2f}秒")
        logger.info(f"提取到 {len(results['subtitles'])} 个字幕片段")
        logger.info(f"唯一文本 {len(results['unique_texts'])} 条")

        return results

    @staticmethod
    def _make_detection(timestamp, text, confidence, bbox, engine: str) -> Dict:
        """Build one detection record in the common output format."""
        return {
            "timestamp": timestamp,
            "text": text.strip(),
            "confidence": confidence,
            "bbox": bbox,
            "engine": engine,
        }

    def _ocr_frame(self, frame: np.ndarray, timestamp: float, confidence_threshold: float) -> List[Dict]:
        """Run every loaded OCR engine on one frame.

        Args:
            frame: Image passed unchanged to each engine.
            timestamp: Value stored in each detection's "timestamp" field.
            confidence_threshold: Detections below this confidence are dropped.

        Returns:
            List[Dict]: Detections with keys "timestamp", "text", "confidence",
            "bbox" and "engine"; PaddleOCR results first, then EasyOCR, then
            CnOCR. Engine errors are logged and never propagate.
        """
        ocr_results: List[Dict] = []
        if self.paddle_ocr:
            ocr_results.extend(self._ocr_with_paddle(frame, timestamp, confidence_threshold))
        if self.easy_ocr:
            ocr_results.extend(self._ocr_with_easyocr(frame, timestamp, confidence_threshold))
        if self.cn_ocr:
            ocr_results.extend(self._ocr_with_cnocr(frame, timestamp, confidence_threshold))
        return ocr_results

    def _ocr_with_paddle(self, frame, timestamp: float, confidence_threshold: float) -> List[Dict]:
        """Run PaddleOCR on the frame and normalise its detections."""
        found: List[Dict] = []
        try:
            paddle_result = self.paddle_ocr.ocr(frame)
            if not (paddle_result and paddle_result[0]):
                return found

            for detection in paddle_result[0]:
                if not detection:
                    continue
                # The result layout differs between versions, so each entry is
                # parsed defensively and unparseable ones are skipped.
                try:
                    if len(detection) == 2:
                        # Layout: [bbox, (text, confidence)]
                        bbox, (text, confidence) = detection
                    elif len(detection) == 3:
                        # Layout: [bbox, text, confidence]
                        bbox, text, confidence = detection
                    else:
                        continue

                    if confidence >= confidence_threshold and text.strip():
                        found.append(self._make_detection(timestamp, text, confidence, bbox, "PaddleOCR"))
                except Exception as parse_error:
                    logger.debug(f"解析单个检测结果失败: {parse_error}, 数据: {detection}")
        except Exception as e:
            logger.warning(f"PaddleOCR识别失败 (时间戳:{timestamp:.2f}): {e}")
        return found

    def _ocr_with_easyocr(self, frame, timestamp: float, confidence_threshold: float) -> List[Dict]:
        """Run EasyOCR on the frame and normalise its detections."""
        found: List[Dict] = []
        try:
            for bbox, text, confidence in self.easy_ocr.readtext(frame):
                if confidence >= confidence_threshold and text.strip():
                    found.append(self._make_detection(timestamp, text, confidence, bbox, "EasyOCR"))
        except Exception as e:
            logger.warning(f"EasyOCR识别失败 (时间戳:{timestamp:.2f}): {e}")
        return found

    def _ocr_with_cnocr(self, frame, timestamp: float, confidence_threshold: float) -> List[Dict]:
        """Run CnOCR on the frame and normalise its detections."""
        found: List[Dict] = []
        try:
            cn_result = self.cn_ocr.ocr(frame)
            if not isinstance(cn_result, list):
                return found

            for detection in cn_result:
                # Entries are read as dicts with 'text', 'score' and 'position'.
                if not isinstance(detection, dict):
                    continue
                text = detection.get('text', '')
                confidence = detection.get('score', 0.0)
                position = detection.get('position', None)

                if not (confidence >= confidence_threshold
                        and isinstance(text, str) and text.strip()):
                    continue

                # Convert array-like positions to plain lists for the bbox field.
                bbox = []
                if position is not None:
                    try:
                        bbox = position.tolist() if hasattr(position, 'tolist') else position
                    except Exception as e:
                        logger.debug(f"转换position失败: {e}")
                        bbox = []

                found.append(self._make_detection(timestamp, text, float(confidence), bbox, "CnOCR"))
        except Exception as e:
            logger.warning(f"CnOCR识别失败 (时间戳:{timestamp:.2f}): {e}")
        return found

    def _postprocess_subtitles(self, results: Dict) -> Dict:
        """Sort, de-duplicate and summarise the raw detections in ``results``.

        Detections are sorted by timestamp (in place). A detection is dropped
        when its text equals the previously kept one and it lies within 5
        seconds of it. Adds "filtered_count", "continuous_text" and "stats",
        and replaces "subtitles" with the filtered list.

        Args:
            results: Dict with at least "subtitles" and "unique_texts".

        Returns:
            Dict: The same ``results`` object, updated.
        """
        subtitles = results["subtitles"]
        subtitles.sort(key=lambda x: x["timestamp"])

        filtered_subtitles = []
        last_text = ""
        last_timestamp = -1

        for subtitle in subtitles:
            text = subtitle["text"]
            timestamp = subtitle["timestamp"]

            # Keep when the text changed or enough time passed since the last kept entry.
            if text != last_text or (timestamp - last_timestamp) > 5.0:
                filtered_subtitles.append(subtitle)
                last_text = text
                last_timestamp = timestamp

        results["subtitles"] = filtered_subtitles
        results["filtered_count"] = len(subtitles) - len(filtered_subtitles)
        results["continuous_text"] = " ".join([s["text"] for s in filtered_subtitles])

        results["stats"] = {
            "total_detections": len(subtitles),
            "filtered_detections": len(filtered_subtitles),
            "unique_texts": len(results["unique_texts"]),
            "text_length": len(results["continuous_text"]),
            # float() keeps the value a plain Python number for JSON output.
            "average_confidence": float(np.mean([s["confidence"] for s in filtered_subtitles])) if filtered_subtitles else 0,
        }

        return results

    def _list_image_files(self, frame_dir: Path) -> List[Path]:
        """Return the image files directly inside ``frame_dir``, sorted by path.

        The suffix is compared case-insensitively, so each file is listed
        exactly once regardless of filesystem case sensitivity.
        """
        return sorted(
            entry for entry in frame_dir.iterdir()
            if entry.is_file() and entry.suffix.lower() in self.IMAGE_EXTENSIONS
        )

    def extract_subtitles_from_frames(self, frame_dir: str, confidence_threshold: float = 0.5) -> Dict:
        """Extract subtitles from a directory of frame images.

        Images (.jpg, .jpeg, .png, .bmp) are processed in sorted path order and
        the index of each image in that order is used as its timestamp.

        Args:
            frame_dir: Directory containing the frame images.
            confidence_threshold: Minimum confidence for a detection to be kept.

        Returns:
            Dict: Extraction results, or an empty dict when the directory
            contains no image files.

        Raises:
            FileNotFoundError: If ``frame_dir`` does not exist.
        """
        frame_dir = Path(frame_dir)
        if not frame_dir.exists():
            raise FileNotFoundError(f"帧目录不存在: {frame_dir}")

        logger.info(f"开始从帧目录提取字幕: {frame_dir}")

        image_files = self._list_image_files(frame_dir)
        if not image_files:
            logger.warning(f"在目录 {frame_dir} 中未找到图片文件")
            return {}

        logger.info(f"找到 {len(image_files)} 个图片文件")

        results = {
            "frame_dir": str(frame_dir),
            "total_images": len(image_files),
            "confidence_threshold": confidence_threshold,
            "subtitles": [],
            "unique_texts": [],
            "ocr_engine": self.ocr_engine,
            "timestamp": datetime.now().isoformat(),
        }
        seen_texts: Dict = {}

        for i, image_file in enumerate(image_files):
            try:
                frame = cv2.imread(str(image_file))
                if frame is None:
                    # imread signals an unreadable image by returning None.
                    logger.warning(f"无法读取图片: {image_file}")
                else:
                    # No real time information is available; use the index.
                    timestamp = float(i)
                    ocr_results = self._ocr_frame(frame, timestamp, confidence_threshold)
                    if ocr_results:
                        self._record_detections(results, seen_texts, ocr_results)

                if (i + 1) % 10 == 0:
                    progress = ((i + 1) / len(image_files)) * 100
                    logger.info(f"处理进度: {progress:.1f}%")
            except Exception as e:
                logger.warning(f"处理图片 {image_file} 时出错: {e}")

        results["unique_texts"] = list(seen_texts)
        results = self._postprocess_subtitles(results)

        logger.info(f"字幕提取完成,提取到 {len(results['subtitles'])} 个字幕片段")

        return results

    def save_results(self, results: Dict, output_path: str, format: str = "json"):
        """Write extraction results to ``output_path``.

        Args:
            results: Result dict produced by one of the extract methods.
            output_path: Destination file; missing parent directories are created.
            format: "json" (full result), "txt" (continuous text only) or
                "srt" (subtitle file).

        Raises:
            ValueError: If ``format`` is not supported.
        """
        output_path = Path(output_path)

        if format not in ("json", "txt", "srt"):
            # Validate before touching the filesystem.
            raise ValueError(f"不支持的格式: {format}")

        output_path.parent.mkdir(parents=True, exist_ok=True)

        if format == "json":
            # Shallow copy so the caller's dict is not modified.
            results_copy = results.copy()
            if "unique_texts" in results_copy and isinstance(results_copy["unique_texts"], set):
                results_copy["unique_texts"] = list(results_copy["unique_texts"])

            # NumPy scalars/arrays are not JSON-serializable.
            results_copy = convert_numpy_types(results_copy)

            with open(output_path, 'w', encoding='utf-8') as f:
                json.dump(results_copy, f, ensure_ascii=False, indent=2)
        elif format == "txt":
            with open(output_path, 'w', encoding='utf-8') as f:
                f.write(results.get("continuous_text", ""))
        else:
            self._save_srt(results, output_path)

        logger.info(f"结果已保存: {output_path}")

    def _save_srt(self, results: Dict, output_path: Path):
        """Write ``results["subtitles"]`` as an SRT file.

        Only a single timestamp per detection is known, so every cue is given
        a fixed display duration of 3 seconds.
        """
        with open(output_path, 'w', encoding='utf-8') as f:
            for i, subtitle in enumerate(results["subtitles"], 1):
                timestamp = subtitle["timestamp"]
                text = subtitle["text"]

                start_time = self._format_time(timestamp)
                end_time = self._format_time(timestamp + 3.0)

                f.write(f"{i}\n")
                f.write(f"{start_time} --> {end_time}\n")
                f.write(f"{text}\n\n")

    def _format_time(self, seconds):
        """Format a time in seconds as an SRT timestamp ``HH:MM:SS,mmm``.

        The value is rounded to whole milliseconds before it is split into
        fields, so a fractional part that rounds up carries into the next
        second/minute/hour instead of producing "60,000". Negative values are
        clamped to zero.
        """
        total_ms = max(0, int(round(float(seconds) * 1000)))
        hours, remainder = divmod(total_ms, 3_600_000)
        minutes, remainder = divmod(remainder, 60_000)
        secs, millis = divmod(remainder, 1000)
        return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}"
|
|
|
|
|
|
|
|
|
|
|
|
def main():
    """Command-line entry point: parse arguments, extract subtitles, save them.

    The input may be a video file or a directory of frame images. Results are
    written to ``<output>/<input stem>_subtitles.<format>`` and a short
    summary is printed.

    Returns:
        int: 0 on success, 1 when the input is invalid, nothing could be
        extracted, or processing failed.
    """
    parser = argparse.ArgumentParser(description="专业视频字幕OCR提取器")
    parser.add_argument("input", help="输入视频文件路径或帧图片目录")
    # The help text names the actual default (cnocr).
    parser.add_argument("-e", "--engine", default="cnocr",
                        choices=["paddleocr", "easyocr", "cnocr", "all"],
                        help="OCR引擎 (默认: cnocr)")
    parser.add_argument("-l", "--language", default="ch",
                        choices=["ch", "en", "ch_en"],
                        help="语言设置 (默认: ch)")
    parser.add_argument("-i", "--interval", type=int, default=30,
                        help="帧采样间隔 (默认: 30)")
    parser.add_argument("-c", "--confidence", type=float, default=0.5,
                        help="置信度阈值 (默认: 0.5)")
    parser.add_argument("-o", "--output", default="subtitle_results",
                        help="输出目录 (默认: subtitle_results)")
    parser.add_argument("-f", "--format", default="json",
                        choices=["json", "txt", "srt"],
                        help="输出格式 (默认: json)")
    # The ratios in the help text match what calculate_subtitle_region computes.
    parser.add_argument("--position", default="bottom",
                        choices=["full", "center", "bottom"],
                        help="字幕区域位置 (full=全屏, center=居中0.5-0.8, bottom=居下0.5-1.0)")

    args = parser.parse_args()

    extractor = VideoSubtitleExtractor(
        ocr_engine=args.engine,
        language=args.language,
    )

    input_path = Path(args.input)

    try:
        if input_path.is_file():
            logger.info("处理视频文件")
            results = extractor.extract_subtitles_from_video(
                args.input,
                sample_interval=args.interval,
                confidence_threshold=args.confidence,
                subtitle_position=args.position,
            )
        elif input_path.is_dir():
            logger.info("处理帧图片目录")
            results = extractor.extract_subtitles_from_frames(
                args.input,
                confidence_threshold=args.confidence,
            )
        else:
            logger.error(f"输入路径无效: {args.input}")
            return 1

        if not results:
            # extract_subtitles_from_frames returns {} for a directory without
            # images; there is nothing to save or summarise in that case.
            logger.error(f"未提取到任何字幕结果: {args.input}")
            return 1

        output_dir = Path(args.output)
        # parents=True so nested output directories can be created.
        output_dir.mkdir(parents=True, exist_ok=True)

        input_name = input_path.stem
        output_file = output_dir / f"{input_name}_subtitles.{args.format}"

        extractor.save_results(results, output_file, args.format)

        stats = results['stats']
        print("\n字幕提取完成!")
        print(f"输入: {args.input}")
        print(f"输出: {output_file}")
        print(f"提取时长: {results.get('extract_time', 0):.2f} 秒")
        print(f"字幕片段: {stats['filtered_detections']} 个")
        print(f"唯一文本: {stats['unique_texts']} 条")
        print(f"平均置信度: {stats['average_confidence']:.3f}")
        print(f"文本总长度: {stats['text_length']} 字符")

        continuous_text = results.get("continuous_text")
        if continuous_text:
            print("\n提取的连续文本:")
            print("-" * 50)
            # Preview only: at most the first 500 characters.
            if len(continuous_text) > 500:
                print(continuous_text[:500] + "...")
            else:
                print(continuous_text)

    except Exception as e:
        # Top-level boundary: report the failure instead of a raw traceback.
        logger.error(f"程序执行出错: {str(e)}")
        return 1

    return 0
|
|
|
|
|
|
|
|
|
|
|
|
# Script entry point: hand main()'s return value to the interpreter as the
# process exit status (a return value of None means exit status 0).
if __name__ == "__main__":
    raise SystemExit(main())
|