349 lines
13 KiB
Python
349 lines
13 KiB
Python
|
#!/usr/bin/env python3
|
|||
|
# -*- coding: utf-8 -*-
|
|||
|
"""
|
|||
|
使用SenseVoice模型进行高精度中文语音识别
|
|||
|
"""
|
|||
|
|
|||
|
import os
|
|||
|
import time
|
|||
|
import json
|
|||
|
import argparse
|
|||
|
from pathlib import Path
|
|||
|
from datetime import datetime
|
|||
|
import logging
|
|||
|
|
|||
|
# Configure the root logger once for this command-line script: INFO level and
# above, one line per record formatted as "<time> - <LEVEL> - <message>".
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
|
|||
|
|
|||
|
# FunASR is a hard dependency: AutoModel is used by
# SenseVoiceTranscriber.load_model to load the SenseVoice model.
# NOTE(review): rich_transcription_postprocess is imported but not referenced
# anywhere else in this file; tag stripping is done by
# SenseVoiceTranscriber._clean_text instead.
try:
    from funasr import AutoModel
    from funasr.utils.postprocess_utils import rich_transcription_postprocess
except ImportError:
    # Tell the user how to install the missing package, then re-raise the
    # original ImportError so the script stops here.
    logger.error("请先安装FunASR: pip install funasr")
    raise
|
|||
|
|
|||
|
class SenseVoiceTranscriber:
    """Speech recognition (primarily Chinese) with the SenseVoice model.

    Wraps a FunASR ``AutoModel`` and post-processes its raw output into a
    JSON-serialisable dict (full text, tag-free text, segments, detected
    emotion/event tags, statistics). That dict can then be saved as JSON,
    plain text or SRT subtitles.
    """

    # SenseVoice prefixes decoded text with special tokens such as
    # "<|zh|><|NEUTRAL|><|Speech|><|withitn|>". These ordered tables map the
    # tags this class reports to the labels stored in the result. For each
    # result item only the FIRST matching tag of each table is recorded.
    # Matching is case-insensitive because SenseVoice emits mixed-case event
    # tokens (e.g. "<|Speech|>", "<|BGM|>"); a literal "<|SPEECH|>" check never
    # matched real output.
    _EMOTION_TAGS = (
        ("<|HAPPY|>", "happy"),
        ("<|SAD|>", "sad"),
        ("<|ANGRY|>", "angry"),
    )
    _EVENT_TAGS = (
        ("<|SPEECH|>", "speech"),
        ("<|MUSIC|>", "music"),
        ("<|BGM|>", "background_music"),
    )

    def __init__(self, model_dir="iic/SenseVoiceSmall", device="cuda:0"):
        """Create a transcriber and load the model immediately.

        Args:
            model_dir: Model path or model name passed to FunASR ``AutoModel``.
            device: Device to run on, e.g. "cuda:0" or "cpu".
        """
        self.model_dir = model_dir
        self.device = device
        self.model = None
        self.load_model()

    def load_model(self):
        """Load the SenseVoice model into ``self.model``.

        Raises:
            Exception: whatever ``AutoModel`` raises; it is logged and re-raised.
        """
        logger.info(f"正在加载SenseVoice模型: {self.model_dir}")
        start_time = time.perf_counter()

        try:
            self.model = AutoModel(
                model=self.model_dir,
                trust_remote_code=True,
                device=self.device,
            )
        except Exception as e:
            logger.error(f"模型加载失败: {str(e)}")
            raise

        load_time = time.perf_counter() - start_time
        logger.info(f"SenseVoice模型加载完成,耗时: {load_time:.2f} 秒")

    def transcribe_audio(self, audio_path, language="auto", use_itn=True):
        """Transcribe one audio file.

        Args:
            audio_path: Path to the audio file (str or Path).
            language: Language setting ("auto", "zh", "en", "yue", "ja", "ko").
            use_itn: Whether to apply inverse text normalisation.

        Returns:
            dict: The processed result built by ``_process_result``.

        Raises:
            FileNotFoundError: If *audio_path* does not exist.
            Exception: Any error from the model or post-processing; it is
                logged and re-raised.
        """
        audio_path = Path(audio_path)
        if not audio_path.exists():
            raise FileNotFoundError(f"音频文件不存在: {audio_path}")

        logger.info(f"开始转录音频文件: {audio_path}")
        logger.info(f"语言设置: {language}, 使用ITN: {use_itn}")

        start_time = time.perf_counter()
        try:
            result = self.model.generate(
                input=str(audio_path),
                cache={},
                language=language,
                use_itn=use_itn,
                batch_size_s=60,     # batch size, in seconds of audio
                # NOTE(review): load_model passes no vad_model to AutoModel,
                # so these two VAD-merge options may have no effect -- confirm.
                merge_vad=True,      # merge VAD segments
                merge_length_s=15,   # merge length, in seconds
            )

            transcribe_time = time.perf_counter() - start_time
            logger.info(f"转录完成,耗时: {transcribe_time:.2f} 秒")

            return self._process_result(result, transcribe_time, audio_path)
        except Exception as e:
            logger.error(f"转录失败: {str(e)}")
            raise

    @staticmethod
    def _first_tag_label(text, tag_table):
        """Return the label of the first tag in *tag_table* found in *text*.

        The comparison is case-insensitive (tags in the table are upper-case).
        Returns None when no tag matches.
        """
        upper_text = text.upper()
        for tag, label in tag_table:
            if tag in upper_text:
                return label
        return None

    @staticmethod
    def _segment_bounds(timestamp):
        """Return ``(start_ms, end_ms)`` for a FunASR ``timestamp`` value, or None.

        Two shapes are accepted:
          * a flat pair ``[start_ms, end_ms]``;
          * a list of per-token pairs ``[[start_ms, end_ms], ...]``. The span
            runs from the first token's start to the last token's end.
        Dividing the nested form directly (``timestamp[0] / 1000``) raised a
        TypeError, which is why both shapes are normalised here.
        """
        if not timestamp:
            return None
        first, last = timestamp[0], timestamp[-1]
        if isinstance(first, (list, tuple)):
            if len(first) >= 2 and isinstance(last, (list, tuple)) and len(last) >= 2:
                return first[0], last[1]
            return None
        if len(timestamp) >= 2:
            return timestamp[0], timestamp[1]
        return None

    def _process_result(self, result, transcribe_time, audio_path):
        """Convert the raw ``model.generate`` output into the result dict.

        Args:
            result: Raw model output. Only a list of dicts is interpreted;
                anything else yields an empty transcript.
            transcribe_time: Wall-clock transcription time in seconds.
            audio_path: Source file, recorded as ``file_path``.

        Returns:
            dict with keys ``text`` (raw text, tags included, items joined by
            spaces), ``clean_text`` (tags removed), ``segments``, ``emotions``,
            ``events``, ``transcribe_time``, ``model``, ``timestamp``,
            ``file_path`` and ``stats``.
        """
        processed = {
            "text": "",
            "segments": [],
            "emotions": [],
            "events": [],
            "transcribe_time": transcribe_time,
            "model": "SenseVoice",
            "timestamp": datetime.now().isoformat(),
            "file_path": str(audio_path),
        }

        full_text_parts = []
        if isinstance(result, list):
            for item in result:
                if not isinstance(item, dict):
                    continue

                # A missing or None "text" is treated as empty so the tag
                # checks below never see a non-string.
                text = item.get("text") or ""
                if text:
                    full_text_parts.append(text)

                bounds = self._segment_bounds(item.get("timestamp"))
                if bounds is not None:
                    start_ms, end_ms = bounds
                    processed["segments"].append({
                        "start": start_ms / 1000.0,  # milliseconds -> seconds
                        "end": end_ms / 1000.0,
                        "text": text,
                    })

                emotion = self._first_tag_label(text, self._EMOTION_TAGS)
                if emotion is not None:
                    processed["emotions"].append({"emotion": emotion, "text": text})

                event = self._first_tag_label(text, self._EVENT_TAGS)
                if event is not None:
                    processed["events"].append({"event": event, "text": text})

        processed["text"] = " ".join(full_text_parts)
        processed["clean_text"] = self._clean_text(processed["text"])

        processed["stats"] = {
            "total_segments": len(processed["segments"]),
            "total_duration": max([s["end"] for s in processed["segments"]] + [0]),
            "text_length": len(processed["clean_text"]),
            # NOTE: this is a character count (identical to text_length),
            # not a whitespace-delimited word count.
            "words_count": len(processed["clean_text"]) if processed["clean_text"] else 0,
            "emotions_detected": len(processed["emotions"]),
            "events_detected": len(processed["events"]),
        }
        return processed

    def _clean_text(self, text):
        """Remove SenseVoice special tags (``<|...|>``) and collapse whitespace."""
        import re
        without_tags = re.sub(r'<\|[^|]+\|>', '', text)
        # split()/join both collapses internal runs and trims the ends.
        return ' '.join(without_tags.split())

    def transcribe_multiple_files(self, audio_files, output_dir="sensevoice_transcripts", **kwargs):
        """Transcribe several files, writing one JSON result per file.

        Each result is saved as ``<output_dir>/<stem>_sensevoice.json``.
        A failure on one file is logged and recorded as None, so the
        remaining files are still processed.

        Args:
            audio_files: Sequence of audio file paths.
            output_dir: Directory for the JSON files; created (with parents)
                if missing.
            **kwargs: Forwarded to ``transcribe_audio`` (``language``, ``use_itn``).

        Returns:
            list: One entry per input file, either the result dict or None.
        """
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        results = []
        for i, audio_file in enumerate(audio_files, 1):
            logger.info(f"处理第 {i}/{len(audio_files)} 个文件")
            try:
                result = self.transcribe_audio(audio_file, **kwargs)

                output_file = output_path / f"{Path(audio_file).stem}_sensevoice.json"
                with open(output_file, 'w', encoding='utf-8') as f:
                    json.dump(result, f, ensure_ascii=False, indent=2)

                logger.info(f"转录结果已保存: {output_file}")
                results.append(result)
            except Exception as e:
                logger.error(f"处理文件 {audio_file} 时出错: {str(e)}")
                results.append(None)

        return results

    def save_transcript(self, result, output_path, format="json"):
        """Save a result dict to *output_path*.

        Args:
            result: Result dict from ``transcribe_audio``.
            output_path: Destination file; its parent directory is created if
                missing.
            format: "json" (whole dict), "txt" (``clean_text`` only) or "srt"
                (subtitles from ``segments``). The name shadows the builtin
                but is kept for keyword-argument compatibility.

        Raises:
            ValueError: For an unsupported *format*.
        """
        output_path = Path(output_path)
        if format not in ("json", "txt", "srt"):
            raise ValueError(f"不支持的格式: {format}")

        output_path.parent.mkdir(parents=True, exist_ok=True)

        if format == "json":
            with open(output_path, 'w', encoding='utf-8') as f:
                json.dump(result, f, ensure_ascii=False, indent=2)
        elif format == "txt":
            with open(output_path, 'w', encoding='utf-8') as f:
                f.write(result["clean_text"])
        else:
            self._save_srt(result, output_path)

        logger.info(f"转录结果已保存: {output_path}")

    def _save_srt(self, result, output_path):
        """Write ``result["segments"]`` as SRT subtitles.

        Special tags are stripped from each cue's text so the subtitles do not
        show raw tokens such as ``<|zh|><|NEUTRAL|>``.
        """
        with open(output_path, 'w', encoding='utf-8') as f:
            for i, segment in enumerate(result["segments"], 1):
                start_time = self._format_time(segment["start"])
                end_time = self._format_time(segment["end"])
                text = self._clean_text(segment["text"] or "")

                f.write(f"{i}\n")
                f.write(f"{start_time} --> {end_time}\n")
                f.write(f"{text}\n\n")

    def _format_time(self, seconds):
        """Format *seconds* as an SRT timestamp ``HH:MM:SS,mmm``.

        The value is rounded to whole milliseconds before being split into
        fields. This way 59.9996 s becomes "00:01:00,000" rather than the
        invalid "00:00:60,000" that formatting the float seconds field would
        give. Negative inputs are clamped to zero.
        """
        total_ms = max(0, int(round(seconds * 1000)))
        hours, remainder = divmod(total_ms, 3_600_000)
        minutes, remainder = divmod(remainder, 60_000)
        secs, millis = divmod(remainder, 1000)
        return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}"
|
|||
|
|
|||
|
def _find_audio_files(directory, extensions=('.wav', '.mp3', '.m4a', '.flac', '.aac', '.ogg')):
    """Return the audio files directly inside *directory*, sorted by path.

    Suffixes are compared case-insensitively, so "a.WAV" and "a.Wav" are
    found as well. Each file is listed once: globbing "*.wav" and "*.WAV"
    separately returns the same file twice on case-insensitive file systems.
    """
    wanted = {ext.lower() for ext in extensions}
    return sorted(
        p for p in Path(directory).iterdir()
        if p.is_file() and p.suffix.lower() in wanted
    )


def _transcribe_single(args, input_path):
    """Transcribe one file, save it in the requested format and print a summary.

    Returns:
        int: 0 (errors propagate to the caller as exceptions).
    """
    logger.info("处理单个音频文件")
    transcriber = SenseVoiceTranscriber(model_dir=args.model, device=args.device)
    result = transcriber.transcribe_audio(
        args.input,
        language=args.language,
        use_itn=not args.no_itn,
    )

    output_dir = Path(args.output)
    output_dir.mkdir(parents=True, exist_ok=True)
    output_file = output_dir / f"{input_path.stem}_sensevoice.{args.format}"
    transcriber.save_transcript(result, output_file, args.format)

    stats = result["stats"]
    print("\n转录完成!")
    print(f"原文件: {args.input}")
    print(f"输出文件: {output_file}")
    print(f"转录时长: {result['transcribe_time']:.2f} 秒")
    print(f"音频时长: {stats['total_duration']:.2f} 秒")
    print(f"文字长度: {stats['text_length']} 字符")
    print(f"分段数量: {stats['total_segments']} 个")
    print(f"情感检测: {stats['emotions_detected']} 个")
    print(f"事件检测: {stats['events_detected']} 个")
    print("\n转录内容:")
    print("-" * 50)
    print(result["clean_text"])
    return 0


def _transcribe_directory(args, input_path):
    """Transcribe every audio file directly inside *input_path*.

    The directory is scanned before the model is loaded, so an empty
    directory does not pay the model-loading cost.

    Returns:
        int: 0 when every file succeeded (or none were found), 1 when at least
        one file failed.
    """
    logger.info("处理目录中的音频文件")
    audio_files = _find_audio_files(input_path)
    if not audio_files:
        logger.warning(f"在目录 {args.input} 中未找到音频文件")
        return 0

    logger.info(f"找到 {len(audio_files)} 个音频文件")

    transcriber = SenseVoiceTranscriber(model_dir=args.model, device=args.device)
    results = transcriber.transcribe_multiple_files(
        [str(f) for f in audio_files],
        output_dir=args.output,
        language=args.language,
        use_itn=not args.no_itn,
    )

    # transcribe_multiple_files always writes JSON. To honour -f/--format,
    # also write each successful result in the requested format alongside it.
    if args.format != "json":
        output_dir = Path(args.output)
        for audio_file, result in zip(audio_files, results):
            if result is not None:
                output_file = output_dir / f"{audio_file.stem}_sensevoice.{args.format}"
                transcriber.save_transcript(result, output_file, args.format)

    success_count = sum(1 for r in results if r is not None)
    print("\n批量转录完成!")
    print(f"总文件数: {len(audio_files)}")
    print(f"成功转录: {success_count}")
    print(f"失败: {len(audio_files) - success_count}")
    print(f"输出目录: {args.output}")
    return 0 if success_count == len(audio_files) else 1


def main():
    """Command-line entry point.

    Transcribes a single audio file, or every audio file directly inside a
    directory, and writes the results into the output directory.

    Returns:
        int: Process exit status. 0 on success, 1 when the input path is
        invalid, an error occurs, or (directory mode) any file failed.
    """
    parser = argparse.ArgumentParser(description="使用SenseVoice模型进行高精度中文语音识别")
    parser.add_argument("input", help="输入音频文件路径或目录")
    parser.add_argument("-l", "--language", default="auto",
                        choices=["auto", "zh", "en", "yue", "ja", "ko"],
                        help="语言设置 (默认: auto)")
    parser.add_argument("-o", "--output", default="sensevoice_transcripts",
                        help="输出目录 (默认: sensevoice_transcripts)")
    parser.add_argument("-f", "--format", default="json",
                        choices=["json", "txt", "srt"],
                        help="输出格式 (默认: json)")
    parser.add_argument("-m", "--model", default="iic/SenseVoiceSmall",
                        help="模型路径或名称 (默认: iic/SenseVoiceSmall)")
    parser.add_argument("--device", default="cuda:0",
                        help="运行设备 (默认: cuda:0)")
    parser.add_argument("--no-itn", action="store_true",
                        help="禁用逆文本标准化")
    args = parser.parse_args()

    input_path = Path(args.input)

    try:
        # Dispatch on the input type before loading the model. Model loading
        # is slow, so an invalid path should fail fast.
        if input_path.is_file():
            return _transcribe_single(args, input_path)
        if input_path.is_dir():
            return _transcribe_directory(args, input_path)

        logger.error(f"输入路径无效: {args.input}")
        return 1
    except Exception as e:
        # Top-level boundary: log with traceback and report failure via the
        # exit status instead of silently returning success.
        logger.exception(f"程序执行出错: {str(e)}")
        return 1
|
|||
|
|
|||
|
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status, so shell
    # scripts and CI can detect failures. None (no explicit status) exits 0.
    raise SystemExit(main())
|