video_template_gen/code/whisper_audio_transcribe.py

340 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Speech recognition (audio -> text) using an OpenAI Whisper model.
"""
import whisper
import os
import time
from pathlib import Path
import argparse
import json
from datetime import datetime
import logging

# Directory where Whisper downloads/caches model checkpoints.
os.environ['WHISPER_CACHE_DIR'] = '/root/autodl-tmp/llm/whisper'

# Module-wide logging configuration.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
class WhisperTranscriber:
    """Speech-to-text transcription backed by an OpenAI Whisper model.

    The model is loaded eagerly in ``__init__`` from the cache directory
    named by the ``WHISPER_CACHE_DIR`` environment variable (set at module
    import time).
    """

    def __init__(self, model_size="base"):
        """Create a transcriber and load the model immediately.

        Args:
            model_size: Whisper checkpoint name
                (tiny, base, small, medium, large, large-v2, large-v3).
        """
        self.model_size = model_size
        self.model = None  # replaced by load_model() below
        self.load_model()

    def load_model(self):
        """Load the Whisper model, downloading it into the cache dir if needed.

        Raises:
            Exception: re-raises whatever ``whisper.load_model`` raised,
                after logging it.
        """
        logger.info(f"正在加载Whisper模型: {self.model_size}")
        logger.info(f"模型缓存目录: {os.environ.get('WHISPER_CACHE_DIR', '默认目录')}")
        start_time = time.time()
        try:
            # Ensure the cache directory exists before Whisper downloads
            # into it. NOTE(review): raises KeyError if WHISPER_CACHE_DIR is
            # unset — the module sets it at import time, so this holds here.
            cache_dir = Path(os.environ['WHISPER_CACHE_DIR'])
            cache_dir.mkdir(parents=True, exist_ok=True)
            self.model = whisper.load_model(self.model_size, download_root=str(cache_dir))
            load_time = time.time() - start_time
            logger.info(f"模型加载完成,耗时: {load_time:.2f}")
        except Exception as e:
            logger.error(f"模型加载失败: {str(e)}")
            raise

    def transcribe_audio(self, audio_path, language="zh", task="transcribe"):
        """Transcribe a single audio file.

        Args:
            audio_path: Path to the audio file.
            language: Language code (zh=Chinese, en=English,
                auto=let Whisper detect).
            task: "transcribe" or "translate" (translate = to English).

        Returns:
            dict: Processed result from :meth:`_process_result`.

        Raises:
            FileNotFoundError: If ``audio_path`` does not exist.
            Exception: Re-raises the underlying transcription error after
                one retry.
        """
        audio_path = Path(audio_path)
        if not audio_path.exists():
            raise FileNotFoundError(f"音频文件不存在: {audio_path}")
        logger.info(f"开始转录音频文件: {audio_path}")
        logger.info(f"语言: {language}, 任务: {task}")
        start_time = time.time()
        try:
            options = {
                "task": task,
                "fp16": False,   # FP32 for wider hardware compatibility
                "verbose": False,  # suppress Whisper's progress output
            }
            # "auto" means: omit the language option and let Whisper detect.
            if language != "auto":
                options["language"] = language
            try:
                result = self.model.transcribe(str(audio_path), **options)
            except Exception as e:
                # BUG FIX: the original rebuilt an options dict identical to
                # the one that just failed, while the log claimed "simplified
                # parameters".  The retry can only help with transient
                # failures, so retry once with the same options and give up
                # after that (log messages preserved verbatim).
                logger.warning(f"转录失败,尝试使用简化参数: {str(e)}")
                try:
                    result = self.model.transcribe(str(audio_path), **options)
                except Exception as e2:
                    logger.error(f"简化参数也失败: {str(e2)}")
                    raise e2
            transcribe_time = time.time() - start_time
            logger.info(f"转录完成,耗时: {transcribe_time:.2f}")
            return self._process_result(result, transcribe_time)
        except Exception as e:
            logger.error(f"转录失败: {str(e)}")
            raise

    def _process_result(self, result, transcribe_time):
        """Normalize Whisper's raw result dict into this module's schema.

        Args:
            result: Raw dict returned by ``model.transcribe`` (must contain
                "text"; "language" and "segments" are optional).
            transcribe_time: Wall-clock seconds the transcription took.

        Returns:
            dict with keys: text, language, segments, transcribe_time,
            model_size, timestamp, stats.
        """
        processed = {
            "text": result["text"].strip(),
            "language": result.get("language", "unknown"),
            "segments": [],
            "transcribe_time": transcribe_time,
            "model_size": self.model_size,
            "timestamp": datetime.now().isoformat()
        }
        for segment in result.get("segments", []):
            processed["segments"].append({
                "id": segment.get("id"),
                "start": segment.get("start"),
                "end": segment.get("end"),
                "text": segment.get("text", "").strip(),
                # avg_logprob is Whisper's per-segment average log-probability,
                # used here as a rough confidence proxy.
                "confidence": segment.get("avg_logprob", 0)
            })
        processed["stats"] = {
            "total_segments": len(processed["segments"]),
            # the trailing [0] keeps max() safe when there are no segments
            "total_duration": max([s.get("end", 0) for s in processed["segments"]] + [0]),
            "text_length": len(processed["text"]),
            "words_count": len(processed["text"].split()) if processed["text"] else 0
        }
        return processed

    def transcribe_multiple_files(self, audio_files, output_dir="transcripts", **kwargs):
        """Transcribe several audio files, saving each result as JSON.

        Args:
            audio_files: Iterable of audio file paths.
            output_dir: Directory for the per-file ``*_transcript.json``.
            **kwargs: Forwarded to :meth:`transcribe_audio`.

        Returns:
            list: One result dict per input, or ``None`` where that file
            failed (errors are logged, not raised).
        """
        output_path = Path(output_dir)
        # BUG FIX: parents=True so a nested output_dir does not crash mkdir.
        output_path.mkdir(parents=True, exist_ok=True)
        results = []
        for i, audio_file in enumerate(audio_files, 1):
            logger.info(f"处理第 {i}/{len(audio_files)} 个文件")
            try:
                result = self.transcribe_audio(audio_file, **kwargs)
                audio_name = Path(audio_file).stem
                output_file = output_path / f"{audio_name}_transcript.json"
                with open(output_file, 'w', encoding='utf-8') as f:
                    json.dump(result, f, ensure_ascii=False, indent=2)
                logger.info(f"转录结果已保存: {output_file}")
                results.append(result)
            except Exception as e:
                # Best-effort batch: log the failure and keep going.
                logger.error(f"处理文件 {audio_file} 时出错: {str(e)}")
                results.append(None)
        return results

    def save_transcript(self, result, output_path, format="json"):
        """Save a processed result to disk.

        Args:
            result: Dict produced by :meth:`_process_result`.
            output_path: Destination file path.
            format: One of "json", "txt", "srt".

        Raises:
            ValueError: For an unsupported ``format``.
        """
        output_path = Path(output_path)
        if format == "json":
            with open(output_path, 'w', encoding='utf-8') as f:
                json.dump(result, f, ensure_ascii=False, indent=2)
        elif format == "txt":
            with open(output_path, 'w', encoding='utf-8') as f:
                f.write(result["text"])
        elif format == "srt":
            self._save_srt(result, output_path)
        else:
            raise ValueError(f"不支持的格式: {format}")
        logger.info(f"转录结果已保存: {output_path}")

    def _save_srt(self, result, output_path):
        """Write ``result["segments"]`` as an SRT subtitle file."""
        with open(output_path, 'w', encoding='utf-8') as f:
            for i, segment in enumerate(result["segments"], 1):
                start_time = self._format_time(segment["start"])
                end_time = self._format_time(segment["end"])
                text = segment["text"].strip()
                f.write(f"{i}\n")
                f.write(f"{start_time} --> {end_time}\n")
                f.write(f"{text}\n\n")

    def _format_time(self, seconds):
        """Format a duration in seconds as an SRT timestamp (HH:MM:SS,mmm)."""
        hours = int(seconds // 3600)
        minutes = int((seconds % 3600) // 60)
        seconds = seconds % 60
        # SRT uses a comma as the decimal separator.
        return f"{hours:02d}:{minutes:02d}:{seconds:06.3f}".replace('.', ',')
def main():
    """CLI entry point: transcribe one audio file, or every audio file in a
    directory, and save the results under the chosen output directory."""
    parser = argparse.ArgumentParser(description="使用Whisper模型进行语音识别")
    parser.add_argument("input", help="输入音频文件路径或目录")
    # BUG FIX: the help text hard-coded "默认: base" while the actual default
    # is "medium"; %(default)s lets argparse report the real default and
    # keeps the two from drifting apart again.
    parser.add_argument("-m", "--model", default="medium",
                        choices=["tiny", "base", "small", "medium", "large", "large-v2", "large-v3"],
                        help="Whisper模型大小 (默认: %(default)s)")
    parser.add_argument("-l", "--language", default="zh",
                        help="语言代码 (zh=中文, en=英文, auto=自动检测, 默认: zh)")
    parser.add_argument("-t", "--task", default="transcribe",
                        choices=["transcribe", "translate"],
                        help="任务类型 (transcribe=转录, translate=翻译为英文, 默认: transcribe)")
    parser.add_argument("-o", "--output", default="transcripts",
                        help="输出目录 (默认: transcripts)")
    parser.add_argument("-f", "--format", default="json",
                        choices=["json", "txt", "srt"],
                        help="输出格式 (默认: json)")
    args = parser.parse_args()

    # Loads the model eagerly; this is the slow step.
    transcriber = WhisperTranscriber(model_size=args.model)
    input_path = Path(args.input)
    try:
        if input_path.is_file():
            # Single-file mode: transcribe, save, and print a summary.
            logger.info("处理单个音频文件")
            result = transcriber.transcribe_audio(
                args.input,
                language=args.language,
                task=args.task
            )
            output_dir = Path(args.output)
            # BUG FIX: parents=True so a nested output path does not crash.
            output_dir.mkdir(parents=True, exist_ok=True)
            output_file = output_dir / f"{input_path.stem}_transcript.{args.format}"
            transcriber.save_transcript(result, output_file, args.format)
            print(f"\n转录完成!")
            print(f"原文件: {args.input}")
            print(f"输出文件: {output_file}")
            print(f"识别语言: {result['language']}")
            print(f"转录时长: {result['transcribe_time']:.2f}")
            print(f"音频时长: {result['stats']['total_duration']:.2f}")
            print(f"文字长度: {result['stats']['text_length']} 字符")
            print(f"单词数量: {result['stats']['words_count']}")
            print(f"\n转录内容:")
            print("-" * 50)
            print(result["text"])
        elif input_path.is_dir():
            # Directory mode: collect audio files by extension, then batch.
            logger.info("处理目录中的音频文件")
            audio_extensions = ['.wav', '.mp3', '.m4a', '.flac', '.aac', '.ogg']
            audio_files = []
            for ext in audio_extensions:
                audio_files.extend(input_path.glob(f"*{ext}"))
                audio_files.extend(input_path.glob(f"*{ext.upper()}"))
            # BUG FIX: on case-insensitive filesystems the lower- and
            # upper-case globs match the same file twice; de-duplicate and
            # sort for a deterministic processing order.
            audio_files = sorted(set(audio_files))
            if not audio_files:
                logger.warning(f"在目录 {args.input} 中未找到音频文件")
                return
            logger.info(f"找到 {len(audio_files)} 个音频文件")
            results = transcriber.transcribe_multiple_files(
                [str(f) for f in audio_files],
                output_dir=args.output,
                language=args.language,
                task=args.task
            )
            success_count = sum(1 for r in results if r is not None)
            print(f"\n批量转录完成!")
            print(f"总文件数: {len(audio_files)}")
            print(f"成功转录: {success_count}")
            print(f"失败: {len(audio_files) - success_count}")
            print(f"输出目录: {args.output}")
        else:
            logger.error(f"输入路径无效: {args.input}")
            return
    except Exception as e:
        # Top-level boundary: log and exit without a traceback.
        logger.error(f"程序执行出错: {str(e)}")
        return
if __name__ == "__main__":
    main()

# Approximate model download sizes (for sizing the cache directory):
#   base   ~138 MB
#   medium ~1.42 GB