hot_video_analyse/code/video2audio.py

289 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import sys
import shutil
from pathlib import Path
import numpy as np
import soundfile as sf
from decord import VideoReader, AudioReader, cpu
import argparse
from typing import Optional, Tuple
import logging
# 设置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
class Video2AudioExtractor:
"""使用decord库提取视频中的音频"""
def __init__(self, output_dir: str = "output"):
"""
初始化提取器
Args:
output_dir: 输出目录路径
"""
self.output_dir = Path(output_dir)
self.output_dir.mkdir(exist_ok=True)
def extract_audio_from_video(self, video_path: str, audio_format: str = "wav") -> Tuple[str, str]:
"""
从视频文件中提取音频
Args:
video_path: 视频文件路径
audio_format: 音频格式 (wav, flac, ogg)
Returns:
Tuple[str, str]: (视频文件路径, 音频文件路径)
"""
video_path = Path(video_path)
if not video_path.exists():
raise FileNotFoundError(f"视频文件不存在: {video_path}")
logger.info(f"开始处理视频文件: {video_path}")
try:
# 首先获取视频信息
vr = VideoReader(str(video_path), ctx=cpu(0))
logger.info(f"视频信息: 总帧数={len(vr)}, FPS={vr.get_avg_fps():.2f}")
# 尝试使用AudioReader提取音频
audio_data = None
sample_rate = None
try:
# 使用AudioReader读取音频
ar = AudioReader(str(video_path), ctx=cpu(0))
# 获取音频信息
audio_length = len(ar)
original_sample_rate = ar.sample_rate
logger.info(f"音频信息: 长度={audio_length}, 原始采样率={original_sample_rate}")
if audio_length > 0:
# 读取所有音频数据
audio_data = ar[:]
# 确保音频数据是numpy数组
if not isinstance(audio_data, np.ndarray):
audio_data = audio_data.asnumpy()
# 设置采样率
if original_sample_rate > 0:
sample_rate = original_sample_rate
else:
# 如果原始采样率无效,尝试使用指定采样率重新读取
logger.info("原始采样率无效尝试使用44100Hz重新读取")
ar_44k = AudioReader(str(video_path), ctx=cpu(0), sample_rate=44100)
audio_data = ar_44k[:]
if not isinstance(audio_data, np.ndarray):
audio_data = audio_data.asnumpy()
sample_rate = 44100
logger.info(f"成功提取音频: 数据形状={audio_data.shape}, 采样率={sample_rate}")
else:
logger.warning("音频长度为0")
except Exception as e:
logger.warning(f"AudioReader提取失败: {str(e)}")
# 尝试使用不同的参数
try:
logger.info("尝试使用默认参数重新读取音频")
ar_default = AudioReader(str(video_path), sample_rate=44100, mono=False)
audio_data = ar_default[:]
if not isinstance(audio_data, np.ndarray):
audio_data = audio_data.asnumpy()
sample_rate = 44100
logger.info(f"默认参数成功: 数据形状={audio_data.shape}, 采样率={sample_rate}")
except Exception as e2:
logger.warning(f"默认参数也失败: {str(e2)}")
# 如果成功获取音频数据
if audio_data is not None and len(audio_data) > 0:
logger.info(f"最终音频数据形状: {audio_data.shape}")
logger.info(f"音频数据类型: {audio_data.dtype}")
logger.info(f"音频数据范围: [{audio_data.min():.6f}, {audio_data.max():.6f}]")
# 处理音频数据格式
if len(audio_data.shape) == 1:
logger.info("检测到单声道音频")
elif len(audio_data.shape) == 2:
logger.info(f"检测到多声道音频: {audio_data.shape[1]} 个声道")
# 如果是单声道但形状是(1, N),转换为(N,)
if audio_data.shape[0] == 1:
audio_data = audio_data.squeeze(0)
logger.info(f"转换后音频数据形状: {audio_data.shape}")
else:
logger.warning(f"未知的音频数据格式: {audio_data.shape}")
# 确保采样率有效
if sample_rate is None or sample_rate <= 0:
sample_rate = 44100
logger.warning(f"使用默认采样率: {sample_rate} Hz")
else:
logger.info(f"使用采样率: {sample_rate} Hz")
# 生成输出文件名
video_output_path = self.output_dir / video_path.name
audio_output_path = self.output_dir / f"{video_path.stem}.{audio_format}"
# 复制视频文件到输出目录
if video_path != video_output_path:
shutil.copy2(video_path, video_output_path)
logger.info(f"视频文件已复制到: {video_output_path}")
# 保存音频文件
try:
sf.write(str(audio_output_path), audio_data, sample_rate)
logger.info(f"音频文件已保存到: {audio_output_path}")
# 验证保存的文件
if audio_output_path.exists() and audio_output_path.stat().st_size > 0:
logger.info(f"音频文件验证成功,大小: {audio_output_path.stat().st_size} 字节")
return str(video_output_path), str(audio_output_path)
else:
logger.error("音频文件保存失败或文件为空")
return self._copy_video_only(video_path)
except Exception as e:
logger.error(f"保存音频文件时出错: {str(e)}")
return self._copy_video_only(video_path)
else:
logger.warning(f"视频文件 {video_path} 不包含音频轨道或无法提取音频")
return self._copy_video_only(video_path)
except Exception as e:
logger.error(f"处理视频文件时出错: {str(e)}")
# 尝试至少复制视频文件
return self._copy_video_only(video_path)
def _copy_video_only(self, video_path: Path) -> Tuple[str, str]:
"""
仅复制视频文件(当没有音频时)
Args:
video_path: 视频文件路径
Returns:
Tuple[str, str]: (视频文件路径, 空字符串)
"""
video_output_path = self.output_dir / video_path.name
if video_path != video_output_path:
shutil.copy2(video_path, video_output_path)
logger.info(f"视频文件已复制到: {video_output_path}")
return str(video_output_path), ""
def process_multiple_videos(self, video_paths: list, audio_format: str = "wav") -> list:
"""
批量处理多个视频文件
Args:
video_paths: 视频文件路径列表
audio_format: 音频格式
Returns:
list: 处理结果列表,每个元素为 (视频路径, 音频路径) 元组
"""
results = []
for i, video_path in enumerate(video_paths, 1):
logger.info(f"处理第 {i}/{len(video_paths)} 个视频文件")
try:
video_out, audio_out = self.extract_audio_from_video(video_path, audio_format)
results.append((video_out, audio_out))
except Exception as e:
logger.error(f"处理视频 {video_path} 时出错: {str(e)}")
results.append((None, None))
return results
def process_directory(self, input_dir: str, video_extensions: list = None, audio_format: str = "wav") -> list:
"""
处理目录中的所有视频文件
Args:
input_dir: 输入目录路径
video_extensions: 支持的视频文件扩展名列表
audio_format: 音频格式
Returns:
list: 处理结果列表
"""
if video_extensions is None:
video_extensions = ['.mp4', '.avi', '.mov', '.mkv', '.flv', '.wmv', '.webm']
input_path = Path(input_dir)
if not input_path.exists():
raise FileNotFoundError(f"输入目录不存在: {input_dir}")
# 查找所有视频文件
video_files = []
for ext in video_extensions:
video_files.extend(input_path.glob(f"*{ext}"))
video_files.extend(input_path.glob(f"*{ext.upper()}"))
if not video_files:
logger.warning(f"在目录 {input_dir} 中未找到视频文件")
return []
logger.info(f"找到 {len(video_files)} 个视频文件")
return self.process_multiple_videos([str(f) for f in video_files], audio_format)
def main():
"""主函数"""
parser = argparse.ArgumentParser(description="使用decord从视频中提取音频")
parser.add_argument("input", help="输入视频文件路径或目录路径")
parser.add_argument("-o", "--output", default="output", help="输出目录路径 (默认: output)")
parser.add_argument("-f", "--format", default="wav", choices=["wav", "flac", "ogg"],
help="音频格式 (默认: wav)")
parser.add_argument("--extensions", nargs="+",
default=['.mp4', '.avi', '.mov', '.mkv', '.flv', '.wmv', '.webm'],
help="支持的视频文件扩展名")
args = parser.parse_args()
# 创建提取器
extractor = Video2AudioExtractor(args.output)
input_path = Path(args.input)
try:
if input_path.is_file():
# 处理单个文件
logger.info("处理单个视频文件")
video_out, audio_out = extractor.extract_audio_from_video(args.input, args.format)
print(f"\n处理完成!")
print(f"视频文件: {video_out}")
if audio_out:
print(f"音频文件: {audio_out}")
else:
print("该视频文件不包含音频轨道")
elif input_path.is_dir():
# 处理目录
logger.info("处理目录中的所有视频文件")
results = extractor.process_directory(args.input, args.extensions, args.format)
print(f"\n处理完成! 共处理 {len(results)} 个文件")
success_count = sum(1 for r in results if r[0] is not None)
print(f"成功: {success_count}, 失败: {len(results) - success_count}")
print(f"\n所有文件已保存到目录: {extractor.output_dir}")
else:
logger.error(f"输入路径无效: {args.input}")
sys.exit(1)
except Exception as e:
logger.error(f"程序执行出错: {str(e)}")
sys.exit(1)
if __name__ == "__main__":
main()