# 289 lines | 12 KiB | Python
import argparse
import logging
import os
import shutil
import sys
from pathlib import Path
from typing import Optional, Tuple

import numpy as np
import soundfile as sf
from decord import VideoReader, AudioReader, cpu

# Configure logging once at import time; every message in this module goes
# through the module-level ``logger`` below.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

class Video2AudioExtractor:
    """Extract the audio track of video files using the decord library.

    Every processed video is copied into ``output_dir``; when an audio track
    can be decoded it is written next to the copy with ``soundfile``.
    """

    def __init__(self, output_dir: str = "output"):
        """Create the extractor and make sure the output directory exists.

        Args:
            output_dir: Directory that receives copied videos and audio files.
        """
        self.output_dir = Path(output_dir)
        # parents=True so nested output paths such as "out/run1" also work.
        self.output_dir.mkdir(parents=True, exist_ok=True)

    def extract_audio_from_video(self, video_path: str, audio_format: str = "wav") -> Tuple[str, str]:
        """Extract the audio of one video file and copy the video to the output directory.

        Args:
            video_path: Path of the video file.
            audio_format: Audio file extension (wav, flac, ogg).

        Returns:
            Tuple[str, str]: (copied video path, audio path). The audio path is
            an empty string when no audio could be extracted or saved.

        Raises:
            FileNotFoundError: If ``video_path`` does not exist.
        """
        video_path = Path(video_path)
        if not video_path.exists():
            raise FileNotFoundError(f"视频文件不存在: {video_path}")

        logger.info(f"开始处理视频文件: {video_path}")

        try:
            # Opening the video first validates the file and logs basic info.
            vr = VideoReader(str(video_path), ctx=cpu(0))
            logger.info(f"视频信息: 总帧数={len(vr)}, FPS={vr.get_avg_fps():.2f}")

            audio_data, sample_rate = self._read_audio(video_path)

            # ``size`` (not ``len``) so an array with channels but zero samples
            # is also treated as "no audio".
            if audio_data is None or audio_data.size == 0:
                logger.warning(f"视频文件 {video_path} 不包含音频轨道或无法提取音频")
                return self._copy_video_only(video_path)

            logger.info(f"最终音频数据形状: {audio_data.shape}")
            logger.info(f"音频数据类型: {audio_data.dtype}")
            logger.info(f"音频数据范围: [{audio_data.min():.6f}, {audio_data.max():.6f}]")

            # decord documents its audio layout as (channels, samples), so the
            # channel count is the first axis.
            if audio_data.ndim == 1 or (audio_data.ndim == 2 and audio_data.shape[0] == 1):
                logger.info("检测到单声道音频")
            elif audio_data.ndim == 2:
                logger.info(f"检测到多声道音频: {audio_data.shape[0]} 个声道")
            else:
                logger.warning(f"未知的音频数据格式: {audio_data.shape}")

            if audio_data.ndim == 2:
                # soundfile expects (frames, channels); see _to_frames_first.
                audio_data = self._to_frames_first(audio_data)
                logger.info(f"转换后音频数据形状: {audio_data.shape}")

            # Fall back to 44100 Hz when no usable sample rate was reported.
            if sample_rate is None or sample_rate <= 0:
                sample_rate = 44100
                logger.warning(f"使用默认采样率: {sample_rate} Hz")
            else:
                sample_rate = int(sample_rate)
                logger.info(f"使用采样率: {sample_rate} Hz")

            audio_output_path = self.output_dir / f"{video_path.stem}.{audio_format}"
            video_output_path = self._copy_video(video_path)

            try:
                sf.write(str(audio_output_path), audio_data, sample_rate)
                logger.info(f"音频文件已保存到: {audio_output_path}")
            except Exception as e:
                logger.error(f"保存音频文件时出错: {str(e)}")
                # The video has already been copied above.
                return str(video_output_path), ""

            # Verify that a non-empty file really ended up on disk.
            if audio_output_path.exists() and audio_output_path.stat().st_size > 0:
                logger.info(f"音频文件验证成功,大小: {audio_output_path.stat().st_size} 字节")
                return str(video_output_path), str(audio_output_path)

            logger.error("音频文件保存失败或文件为空")
            return str(video_output_path), ""

        except Exception as e:
            logger.error(f"处理视频文件时出错: {str(e)}")
            # Best effort: still try to deliver the video itself.
            return self._copy_video_only(video_path)

    def _read_audio(self, video_path: Path) -> Tuple[Optional[np.ndarray], Optional[int]]:
        """Decode the audio track with decord, retrying with explicit parameters.

        Returns:
            (audio_data, sample_rate); either may be None when decoding failed.
        """
        audio_data = None
        sample_rate = None

        try:
            ar = AudioReader(str(video_path), ctx=cpu(0))
            audio_length = len(ar)
            original_sample_rate = ar.sample_rate
            logger.info(f"音频信息: 长度={audio_length}, 原始采样率={original_sample_rate}")

            if audio_length > 0:
                audio_data = self._as_numpy(ar[:])

                if original_sample_rate > 0:
                    sample_rate = original_sample_rate
                else:
                    # Reported rate is unusable: re-read with an explicit 44100 Hz.
                    logger.info("原始采样率无效,尝试使用44100Hz重新读取")
                    ar_44k = AudioReader(str(video_path), ctx=cpu(0), sample_rate=44100)
                    audio_data = self._as_numpy(ar_44k[:])
                    sample_rate = 44100

                logger.info(f"成功提取音频: 数据形状={audio_data.shape}, 采样率={sample_rate}")
            else:
                logger.warning("音频长度为0")

        except Exception as e:
            logger.warning(f"AudioReader提取失败: {str(e)}")

            # Second attempt with explicit sample rate and all channels kept.
            try:
                logger.info("尝试使用默认参数重新读取音频")
                ar_default = AudioReader(str(video_path), sample_rate=44100, mono=False)
                audio_data = self._as_numpy(ar_default[:])
                sample_rate = 44100
                logger.info(f"默认参数成功: 数据形状={audio_data.shape}, 采样率={sample_rate}")
            except Exception as e2:
                logger.warning(f"默认参数也失败: {str(e2)}")

        return audio_data, sample_rate

    @staticmethod
    def _as_numpy(data) -> np.ndarray:
        """Return ``data`` as a numpy array, converting via ``asnumpy()`` if needed."""
        if isinstance(data, np.ndarray):
            return data
        return data.asnumpy()

    @staticmethod
    def _to_frames_first(audio_data: np.ndarray) -> np.ndarray:
        """Convert (channels, samples) audio into the layout soundfile writes.

        A single-channel 2-D array becomes 1-D ``(samples,)``; a multi-channel
        array is transposed to ``(samples, channels)``. Arrays that are not 2-D
        are returned unchanged.
        """
        if audio_data.ndim != 2:
            return audio_data
        if audio_data.shape[0] == 1:
            return audio_data.squeeze(0)
        # Without the transpose soundfile would read the array as
        # "<channels> frames of <samples> channels" and fail or write garbage.
        return np.ascontiguousarray(audio_data.T)

    def _copy_video(self, video_path: Path) -> Path:
        """Copy the video into the output directory and return the target path."""
        video_output_path = self.output_dir / video_path.name
        # Compare resolved paths: the same file reached through a different
        # spelling (relative vs. absolute, "..") would otherwise make
        # shutil.copy2 raise SameFileError.
        if video_path.resolve() != video_output_path.resolve():
            shutil.copy2(video_path, video_output_path)
            logger.info(f"视频文件已复制到: {video_output_path}")
        return video_output_path

    def _copy_video_only(self, video_path: Path) -> Tuple[str, str]:
        """Copy only the video file (used when no audio is available).

        Args:
            video_path: Path of the video file.

        Returns:
            Tuple[str, str]: (copied video path, empty string).
        """
        return str(self._copy_video(Path(video_path))), ""

    def process_multiple_videos(self, video_paths: list, audio_format: str = "wav") -> list:
        """Process several video files one after another.

        Args:
            video_paths: List of video file paths.
            audio_format: Audio file extension.

        Returns:
            list: One ``(video path, audio path)`` tuple per input, in input
            order; ``(None, None)`` marks a file whose processing raised.
        """
        results = []

        for i, video_path in enumerate(video_paths, 1):
            logger.info(f"处理第 {i}/{len(video_paths)} 个视频文件")
            try:
                results.append(self.extract_audio_from_video(video_path, audio_format))
            except Exception as e:
                logger.error(f"处理视频 {video_path} 时出错: {str(e)}")
                results.append((None, None))

        return results

    @staticmethod
    def _find_video_files(input_path: Path, video_extensions: list) -> list:
        """Return the regular files directly inside ``input_path`` with a wanted extension.

        Matching is case-insensitive, each file is returned once, and the
        result is sorted by path. Extensions may be given with or without the
        leading dot.
        """
        wanted = tuple(
            (ext if ext.startswith(".") else f".{ext}").lower()
            for ext in video_extensions
            if ext
        )
        if not input_path.is_dir():
            return []
        return sorted(
            entry
            for entry in input_path.iterdir()
            if entry.is_file() and entry.name.lower().endswith(wanted)
        )

    def process_directory(self, input_dir: str, video_extensions: Optional[list] = None, audio_format: str = "wav") -> list:
        """Process every video file found directly inside a directory.

        Args:
            input_dir: Input directory path.
            video_extensions: Video file extensions to pick up; a default list
                of common extensions is used when None.
            audio_format: Audio file extension.

        Returns:
            list: Result tuples as returned by ``process_multiple_videos``;
            an empty list when no video file was found.

        Raises:
            FileNotFoundError: If ``input_dir`` does not exist.
        """
        if video_extensions is None:
            video_extensions = ['.mp4', '.avi', '.mov', '.mkv', '.flv', '.wmv', '.webm']

        input_path = Path(input_dir)
        if not input_path.exists():
            raise FileNotFoundError(f"输入目录不存在: {input_dir}")

        video_files = self._find_video_files(input_path, video_extensions)

        if not video_files:
            logger.warning(f"在目录 {input_dir} 中未找到视频文件")
            return []

        logger.info(f"找到 {len(video_files)} 个视频文件")

        return self.process_multiple_videos([str(f) for f in video_files], audio_format)

def main(argv: Optional[list] = None):
    """Command-line entry point.

    Args:
        argv: Argument list to parse; ``None`` (the default) makes argparse
            read ``sys.argv[1:]``.

    Exits with status 1 when the input path is neither a file nor a directory
    or when processing raises.
    """
    parser = argparse.ArgumentParser(description="使用decord从视频中提取音频")
    parser.add_argument("input", help="输入视频文件路径或目录路径")
    parser.add_argument("-o", "--output", default="output", help="输出目录路径 (默认: output)")
    parser.add_argument("-f", "--format", default="wav", choices=["wav", "flac", "ogg"],
                        help="音频格式 (默认: wav)")
    parser.add_argument("--extensions", nargs="+",
                        default=['.mp4', '.avi', '.mov', '.mkv', '.flv', '.wmv', '.webm'],
                        help="支持的视频文件扩展名")

    args = parser.parse_args(argv)

    input_path = Path(args.input)

    # Validate the input before the extractor is built, because constructing
    # it creates the output directory as a side effect.
    if not (input_path.is_file() or input_path.is_dir()):
        logger.error(f"输入路径无效: {args.input}")
        sys.exit(1)

    try:
        extractor = Video2AudioExtractor(args.output)

        if input_path.is_file():
            # Single file
            logger.info("处理单个视频文件")
            video_out, audio_out = extractor.extract_audio_from_video(args.input, args.format)

            print("\n处理完成!")
            print(f"视频文件: {video_out}")
            if audio_out:
                print(f"音频文件: {audio_out}")
            else:
                print("该视频文件不包含音频轨道")
        else:
            # Directory
            logger.info("处理目录中的所有视频文件")
            results = extractor.process_directory(args.input, args.extensions, args.format)

            print(f"\n处理完成! 共处理 {len(results)} 个文件")
            # A result whose video path is None marks a file that raised.
            success_count = sum(1 for r in results if r[0] is not None)
            print(f"成功: {success_count}, 失败: {len(results) - success_count}")

            print(f"\n所有文件已保存到目录: {extractor.output_dir}")

    except Exception as e:
        logger.error(f"程序执行出错: {str(e)}")
        sys.exit(1)

# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()