340 lines
12 KiB
Python
340 lines
12 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
使用Whisper模型进行语音识别
|
||
"""
|
||
|
||
import whisper
|
||
import os
|
||
import time
|
||
from pathlib import Path
|
||
import argparse
|
||
import json
|
||
from datetime import datetime
|
||
import logging
|
||
|
||
# Directory where Whisper checkpoints are downloaded/cached. setdefault() is
# used so a WHISPER_CACHE_DIR already exported by the user is respected
# instead of being overwritten with the built-in default path.
os.environ.setdefault('WHISPER_CACHE_DIR', '/root/autodl-tmp/llm/whisper')

# Configure logging: INFO level, "<time> - <level> - <message>" records.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
|
||
|
||
class WhisperTranscriber:
    """Speech-to-text transcription backed by a Whisper model.

    The model named by ``model_size`` is loaded once in ``__init__`` and then
    reused by every transcription call on the instance.
    """

    def __init__(self, model_size="base"):
        """Create the transcriber and load the model immediately.

        Args:
            model_size: Whisper model name (tiny, base, small, medium,
                large, large-v2, large-v3).

        Raises:
            Exception: any error raised while loading the model is
                propagated unchanged (see ``load_model``).
        """
        self.model_size = model_size
        self.model = None
        self.load_model()

    def load_model(self):
        """Load the Whisper model into ``self.model``.

        The cache directory is read from the ``WHISPER_CACHE_DIR``
        environment variable and created if missing. When the variable is
        unset or empty, ``download_root=None`` is passed, leaving the choice
        of location to ``whisper.load_model``.

        Raises:
            Exception: the loading error is logged and re-raised.
        """
        cache_root = os.environ.get('WHISPER_CACHE_DIR')
        logger.info(f"正在加载Whisper模型: {self.model_size}")
        logger.info(f"模型缓存目录: {cache_root or '默认目录'}")
        start_time = time.time()

        try:
            download_root = None
            if cache_root:
                # Make sure the cache directory exists before downloading.
                cache_dir = Path(cache_root)
                cache_dir.mkdir(parents=True, exist_ok=True)
                download_root = str(cache_dir)

            self.model = whisper.load_model(self.model_size, download_root=download_root)
            load_time = time.time() - start_time
            logger.info(f"模型加载完成,耗时: {load_time:.2f} 秒")
        except Exception as e:
            logger.error(f"模型加载失败: {str(e)}")
            raise

    def transcribe_audio(self, audio_path, language="zh", task="transcribe"):
        """Transcribe a single audio file.

        Args:
            audio_path: path of the audio file.
            language: language code (zh, en, ...); ``"auto"`` omits the
                ``language`` option so the model detects it.
            task: ``"transcribe"`` or ``"translate"`` (translate to English).

        Returns:
            dict: the processed result built by ``_process_result``.

        Raises:
            FileNotFoundError: if ``audio_path`` does not exist.
            Exception: if transcription fails twice in a row; the error is
                logged and re-raised.
        """
        audio_path = Path(audio_path)
        if not audio_path.exists():
            raise FileNotFoundError(f"音频文件不存在: {audio_path}")

        logger.info(f"开始转录音频文件: {audio_path}")
        logger.info(f"语言: {language}, 任务: {task}")

        options = {
            "task": task,
            "fp16": False,     # FP32 for better compatibility
            "verbose": False,  # reduce console output
        }
        if language != "auto":
            options["language"] = language

        start_time = time.time()
        try:
            try:
                result = self.model.transcribe(str(audio_path), **options)
            except Exception as e:
                # The original fallback rebuilt an identical option dict, so
                # this is effectively one retry with the same options.
                logger.warning(f"转录失败,尝试使用简化参数: {str(e)}")
                try:
                    result = self.model.transcribe(str(audio_path), **options)
                except Exception as e2:
                    logger.error(f"简化参数也失败: {str(e2)}")
                    raise

            transcribe_time = time.time() - start_time
            logger.info(f"转录完成,耗时: {transcribe_time:.2f} 秒")
            return self._process_result(result, transcribe_time)
        except Exception as e:
            logger.error(f"转录失败: {str(e)}")
            raise

    def _process_result(self, result, transcribe_time):
        """Convert a raw transcription result into the output dictionary.

        Args:
            result: mapping with optional ``text``, ``language`` and
                ``segments`` entries.
            transcribe_time: elapsed transcription time in seconds.

        Returns:
            dict with ``text``, ``language``, ``segments``,
            ``transcribe_time``, ``model_size``, ``timestamp`` and ``stats``.
            Each segment carries ``id``, ``start``, ``end``, ``text`` and
            ``confidence`` (the segment's ``avg_logprob``, 0 when absent).
        """
        text = (result.get("text") or "").strip()
        segments = [
            {
                "id": segment.get("id"),
                "start": segment.get("start"),
                "end": segment.get("end"),
                "text": (segment.get("text") or "").strip(),
                "confidence": segment.get("avg_logprob", 0),
            }
            for segment in result.get("segments", [])
        ]

        # Segments without an end time are ignored here; comparing None with
        # a number inside max() would raise TypeError.
        end_times = [seg["end"] for seg in segments if seg["end"] is not None]

        return {
            "text": text,
            "language": result.get("language", "unknown"),
            "segments": segments,
            "transcribe_time": transcribe_time,
            "model_size": self.model_size,
            "timestamp": datetime.now().isoformat(),
            "stats": {
                "total_segments": len(segments),
                "total_duration": max(end_times, default=0),
                "text_length": len(text),
                # Whitespace-separated tokens; text written without spaces
                # therefore yields a very small count.
                "words_count": len(text.split()) if text else 0,
            },
        }

    def transcribe_multiple_files(self, audio_files, output_dir="transcripts", **kwargs):
        """Transcribe several audio files and save each result as JSON.

        Each result is written to ``<output_dir>/<stem>_transcript.json``;
        inputs that share a stem therefore overwrite each other's file.

        Args:
            audio_files: iterable of audio file paths.
            output_dir: output directory (created with parents if missing).
            **kwargs: forwarded to ``transcribe_audio``.

        Returns:
            list: one entry per input, in order; ``None`` marks a file whose
            processing failed (the error is logged, not raised).
        """
        audio_files = list(audio_files)  # allow generators; len() is needed
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        results = []
        total = len(audio_files)
        for i, audio_file in enumerate(audio_files, 1):
            logger.info(f"处理第 {i}/{total} 个文件")
            try:
                result = self.transcribe_audio(audio_file, **kwargs)

                output_file = output_path / f"{Path(audio_file).stem}_transcript.json"
                with open(output_file, 'w', encoding='utf-8') as f:
                    json.dump(result, f, ensure_ascii=False, indent=2)

                logger.info(f"转录结果已保存: {output_file}")
                results.append(result)
            except Exception as e:
                # Best effort: keep going with the remaining files.
                logger.error(f"处理文件 {audio_file} 时出错: {str(e)}")
                results.append(None)

        return results

    def save_transcript(self, result, output_path, format="json"):
        """Write a transcription result to disk.

        Args:
            result: processed result dictionary.
            output_path: destination file; missing parent directories are
                created.
            format: ``"json"``, ``"txt"`` or ``"srt"``.

        Raises:
            ValueError: if ``format`` is not one of the supported values.
        """
        if format not in ("json", "txt", "srt"):
            # Reject early so no directory or file is created for bad input.
            raise ValueError(f"不支持的格式: {format}")

        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        if format == "srt":
            self._save_srt(result, output_path)
        else:
            with open(output_path, 'w', encoding='utf-8') as f:
                if format == "json":
                    json.dump(result, f, ensure_ascii=False, indent=2)
                else:
                    f.write(result["text"])

        logger.info(f"转录结果已保存: {output_path}")

    def _save_srt(self, result, output_path):
        """Write ``result["segments"]`` to ``output_path`` as SRT subtitles."""
        with open(output_path, 'w', encoding='utf-8') as f:
            for i, segment in enumerate(result["segments"], 1):
                start_time = self._format_time(segment["start"])
                end_time = self._format_time(segment["end"])
                text = segment["text"].strip()

                f.write(f"{i}\n")
                f.write(f"{start_time} --> {end_time}\n")
                f.write(f"{text}\n\n")

    def _format_time(self, seconds):
        """Format a time offset in seconds as an SRT stamp ``HH:MM:SS,mmm``.

        The value is rounded to whole milliseconds *before* being split into
        fields, so 59.9996 becomes ``00:01:00,000`` rather than the invalid
        ``00:00:60,000``. Negative offsets are clamped to zero.
        """
        total_ms = max(0, int(round(float(seconds) * 1000)))
        hours, remainder = divmod(total_ms, 3_600_000)
        minutes, remainder = divmod(remainder, 60_000)
        secs, millis = divmod(remainder, 1000)
        return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}"
|
||
|
||
def _find_audio_files(directory):
    """Return the audio files located directly inside *directory*.

    Suffixes are compared case-insensitively, so ``.MP3`` and ``.Mp3`` are
    accepted and every file is listed exactly once. The list is sorted so
    batch runs process files in a stable order.
    """
    audio_extensions = {'.wav', '.mp3', '.m4a', '.flac', '.aac', '.ogg'}
    return sorted(
        entry
        for entry in Path(directory).iterdir()
        if entry.is_file() and entry.suffix.lower() in audio_extensions
    )


def main():
    """Command-line entry point: transcribe one audio file or a directory.

    The input path is validated (and a directory is scanned) before the
    model is loaded, so an invalid invocation fails fast without loading it.

    In directory mode ``transcribe_multiple_files`` always writes JSON; when
    another ``--format`` is requested, each successful result is additionally
    saved in that format next to the JSON file.

    Returns:
        int: 0 on success, 1 when the input is invalid, an error occurred, or
        at least one file of a batch failed.
    """
    parser = argparse.ArgumentParser(description="使用Whisper模型进行语音识别")
    parser.add_argument("input", help="输入音频文件路径或目录")
    parser.add_argument("-m", "--model", default="medium",
                        choices=["tiny", "base", "small", "medium", "large", "large-v2", "large-v3"],
                        help="Whisper模型大小 (默认: medium)")
    parser.add_argument("-l", "--language", default="zh",
                        help="语言代码 (zh=中文, en=英文, auto=自动检测, 默认: zh)")
    parser.add_argument("-t", "--task", default="transcribe",
                        choices=["transcribe", "translate"],
                        help="任务类型 (transcribe=转录, translate=翻译为英文, 默认: transcribe)")
    parser.add_argument("-o", "--output", default="transcripts",
                        help="输出目录 (默认: transcripts)")
    parser.add_argument("-f", "--format", default="json",
                        choices=["json", "txt", "srt"],
                        help="输出格式 (默认: json)")

    args = parser.parse_args()

    input_path = Path(args.input)
    batch_mode = input_path.is_dir()
    if not batch_mode and not input_path.is_file():
        logger.error(f"输入路径无效: {args.input}")
        return 1

    audio_files = []
    if batch_mode:
        logger.info("处理目录中的音频文件")
        try:
            audio_files = _find_audio_files(input_path)
        except OSError as e:
            logger.error(f"程序执行出错: {str(e)}")
            return 1
        if not audio_files:
            # Nothing to do is reported as a warning, not as a failure.
            logger.warning(f"在目录 {args.input} 中未找到音频文件")
            return 0
        logger.info(f"找到 {len(audio_files)} 个音频文件")

    # Load the model only once there is real work to do.
    transcriber = WhisperTranscriber(model_size=args.model)

    try:
        output_dir = Path(args.output)
        output_dir.mkdir(parents=True, exist_ok=True)

        if not batch_mode:
            logger.info("处理单个音频文件")
            result = transcriber.transcribe_audio(
                args.input,
                language=args.language,
                task=args.task
            )

            output_file = output_dir / f"{input_path.stem}_transcript.{args.format}"
            transcriber.save_transcript(result, output_file, args.format)

            print("\n转录完成!")
            print(f"原文件: {args.input}")
            print(f"输出文件: {output_file}")
            print(f"识别语言: {result['language']}")
            print(f"转录时长: {result['transcribe_time']:.2f} 秒")
            print(f"音频时长: {result['stats']['total_duration']:.2f} 秒")
            print(f"文字长度: {result['stats']['text_length']} 字符")
            print(f"单词数量: {result['stats']['words_count']} 个")
            print("\n转录内容:")
            print("-" * 50)
            print(result["text"])
            return 0

        results = transcriber.transcribe_multiple_files(
            [str(f) for f in audio_files],
            output_dir=args.output,
            language=args.language,
            task=args.task
        )

        # Honour --format in batch mode as well; the JSON files written by
        # transcribe_multiple_files are kept.
        if args.format != "json":
            for audio_file, result in zip(audio_files, results):
                if result is not None:
                    transcriber.save_transcript(
                        result,
                        output_dir / f"{audio_file.stem}_transcript.{args.format}",
                        args.format,
                    )

        success_count = sum(1 for r in results if r is not None)
        failed_count = len(audio_files) - success_count
        print("\n批量转录完成!")
        print(f"总文件数: {len(audio_files)}")
        print(f"成功转录: {success_count}")
        print(f"失败: {failed_count}")
        print(f"输出目录: {args.output}")
        return 1 if failed_count else 0

    except Exception as e:
        logger.error(f"程序执行出错: {str(e)}")
        return 1
|
||
|
||
if __name__ == "__main__":
    # Use main()'s return value as the process exit status; a return value
    # of None or 0 means success.
    raise SystemExit(main())
|
||
|
||
# Approximate model download sizes: base ~138 MB, medium ~1.42 GB
|