hot_video_analyse/code/sensevoice_transcribe.py

370 lines
14 KiB
Python
Raw Permalink Normal View History

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
使用SenseVoice模型进行高精度中文语音识别
"""
import argparse
import json
import logging
import os
import re
import time
from datetime import datetime
from pathlib import Path
# Configure logging for the whole process at import time: INFO level,
# "time - level - message" format.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# FunASR is a hard dependency. `logger` must already be defined at this point
# because the import guard uses it to print an install hint before re-raising.
try:
    from funasr import AutoModel
    # NOTE(review): not referenced anywhere in the visible code; confirm it is
    # needed before relying on it (or remove it).
    from funasr.utils.postprocess_utils import rich_transcription_postprocess
except ImportError:
    # Message: "Please install FunASR first: pip install funasr".
    logger.error("请先安装FunASR: pip install funasr")
    raise
class SenseVoiceTranscriber:
    """High-accuracy (Chinese-focused) speech recognition with the SenseVoice model.

    The model is loaded through FunASR's ``AutoModel`` as soon as an instance is
    constructed; use :meth:`transcribe_audio` / :meth:`transcribe_multiple_files`
    to run it and :meth:`save_transcript` to write results as json, txt or srt.
    """

    # Inline tags of the form ``<|...|>`` that SenseVoice embeds in its raw text.
    # Compiled once because _clean_text runs for every segment.
    _TAG_PATTERN = re.compile(r'<\|[^|]+\|>')

    def __init__(self, model_dir="/root/autodl-tmp/llm/sensevoice", device="cuda:0"):
        """Create a transcriber and load the model immediately.

        Args:
            model_dir: Local model path or model name passed to ``AutoModel``.
            device: Device to run on, e.g. ``"cuda:0"`` or ``"cpu"``.
        """
        self.model_dir = model_dir
        self.device = device
        self.model = None
        self.load_model()

    def load_model(self):
        """Load the SenseVoice model into ``self.model``.

        Raises:
            Exception: whatever ``AutoModel`` raises; it is logged, then re-raised.
        """
        logger.info(f"正在加载SenseVoice模型: {self.model_dir}")
        start_time = time.time()
        try:
            self.model = AutoModel(
                model=self.model_dir,
                trust_remote_code=True,
                device=self.device,
                disable_update=True,
            )
        except Exception as e:
            logger.error(f"模型加载失败: {str(e)}")
            raise
        load_time = time.time() - start_time
        logger.info(f"SenseVoice模型加载完成耗时: {load_time:.2f}")

    def transcribe_audio(self, audio_path, language="auto", use_itn=True):
        """Transcribe a single audio file.

        Args:
            audio_path: Path to the audio file (str or Path).
            language: Language hint: "auto", "zh", "en", "yue", "ja" or "ko".
            use_itn: Whether to apply inverse text normalization.

        Returns:
            dict: ``raw_result`` (the unmodified return value of
            ``model.generate``), ``transcribe_time`` (seconds), ``model``
            ("SenseVoice"), ``timestamp`` (ISO-8601 local time when transcription
            finished) and ``file_path``.

        Raises:
            FileNotFoundError: if ``audio_path`` does not exist.
            Exception: anything raised by the model; it is logged, then re-raised.
        """
        audio_path = Path(audio_path)
        if not audio_path.exists():
            raise FileNotFoundError(f"音频文件不存在: {audio_path}")
        logger.info(f"开始转录音频文件: {audio_path}")
        logger.info(f"语言设置: {language}, 使用ITN: {use_itn}")
        start_time = time.time()
        try:
            result = self.model.generate(
                input=str(audio_path),
                cache={},
                language=language,
                use_itn=use_itn,
                batch_size_s=60,    # batch size (seconds of audio)
                merge_vad=True,     # merge VAD results
                merge_length_s=15,  # merge length (seconds)
            )
        except Exception as e:
            logger.error(f"转录失败: {str(e)}")
            raise
        transcribe_time = time.time() - start_time
        logger.info(f"转录完成,耗时: {transcribe_time:.2f}")
        # The model output is returned untouched; callers post-process it.
        return {
            "raw_result": result,
            "transcribe_time": transcribe_time,
            "model": "SenseVoice",
            "timestamp": datetime.now().isoformat(),
            "file_path": str(audio_path),
        }

    def segment_by_timestamp(self, result):
        """Turn raw model output into a list of segment dicts.

        Args:
            result: The ``raw_result`` from :meth:`transcribe_audio`. Anything that
                is not a list is returned unchanged.

        Returns:
            list: One dict per dict item in ``result`` (other items are skipped).
            Items with a usable ``timestamp`` get ``start_time``, ``end_time`` and
            ``duration`` in seconds; every segment has ``text`` (raw) and
            ``clean_text`` (tags removed, whitespace collapsed).
        """
        if not isinstance(result, list):
            return result
        segmented_results = []
        for item in result:
            if not isinstance(item, dict):
                continue
            text = item.get("text", "")
            bounds = self._timestamp_bounds(item.get("timestamp", []))
            if bounds is not None:
                start_ms, end_ms = bounds
                # Timestamps are in milliseconds; convert to seconds.
                segment = {
                    "start_time": start_ms / 1000.0,
                    "end_time": end_ms / 1000.0,
                    "duration": (end_ms - start_ms) / 1000.0,
                    "text": text,
                    "clean_text": self._clean_text(text),
                }
            else:
                segment = {
                    "text": text,
                    "clean_text": self._clean_text(text),
                }
            segmented_results.append(segment)
        return segmented_results

    @staticmethod
    def _timestamp_bounds(timestamp):
        """Return ``(start_ms, end_ms)`` for an item's ``timestamp`` field, or ``None``.

        Two shapes are accepted:

        * a flat sequence whose first two entries are start and end (ms);
        * a list of ``[start_ms, end_ms]`` pairs (per-token timestamps), in which
          case the span runs from the first pair's start to the last pair's end.
          NOTE(review): FunASR emits this pair-list shape for some models; confirm
          which shape SenseVoice produces with the options used here.
        """
        if not timestamp:
            return None
        first, last = timestamp[0], timestamp[-1]
        if isinstance(first, (list, tuple)):
            if isinstance(last, (list, tuple)) and len(first) >= 2 and len(last) >= 2:
                return first[0], last[1]
            return None
        if len(timestamp) >= 2:
            return first, timestamp[1]
        return None

    def _clean_text(self, text):
        """Strip SenseVoice ``<|...|>`` tags from *text* and collapse whitespace runs."""
        return ' '.join(self._TAG_PATTERN.sub('', text).split())

    def _join_clean_text(self, raw_result):
        """Join the cleaned ``text`` of each dict item in *raw_result* with spaces.

        Non-dict items and items with an empty/missing ``text`` are skipped.
        """
        return " ".join(
            self._clean_text(item["text"])
            for item in raw_result
            if isinstance(item, dict) and item.get("text")
        )

    def transcribe_multiple_files(self, audio_files, output_dir="sensevoice_kuaishou", **kwargs):
        """Transcribe several files, saving each result as JSON.

        Each result is written to ``<output_dir>/<stem>_sensevoice.json``. A failure
        on one file is logged and recorded as ``None`` so the batch continues.

        Args:
            audio_files: Iterable of audio file paths.
            output_dir: Directory for the JSON files; created (with parents) if missing.
            **kwargs: Forwarded to :meth:`transcribe_audio` (``language``, ``use_itn``).

        Returns:
            list: One entry per input file: the result dict, or ``None`` on failure.
        """
        audio_files = list(audio_files)
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)
        written = set()
        results = []
        for i, audio_file in enumerate(audio_files, 1):
            logger.info(f"处理第 {i}/{len(audio_files)} 个文件")
            try:
                result = self.transcribe_audio(audio_file, **kwargs)
                output_file = output_path / f"{Path(audio_file).stem}_sensevoice.json"
                # Files sharing a stem (e.g. a.wav and a.mp3) map to the same output.
                if output_file in written:
                    logger.warning(f"输出文件重名, 将被覆盖: {output_file}")
                with open(output_file, 'w', encoding='utf-8') as f:
                    json.dump(result, f, ensure_ascii=False, indent=2)
                written.add(output_file)
                logger.info(f"转录结果已保存: {output_file}")
                results.append(result)
            except Exception as e:
                logger.error(f"处理文件 {audio_file} 时出错: {str(e)}")
                results.append(None)
        return results

    def save_transcript(self, result, output_path, format="json"):
        """Write a transcription result to *output_path*.

        Args:
            result: Dict returned by :meth:`transcribe_audio`.
            output_path: Destination file; missing parent directories are created.
            format: "json" (whole result dict), "txt" (cleaned text joined with
                spaces) or "srt" (subtitles built from timestamped segments).

        Raises:
            ValueError: for any other ``format``.
        """
        if format not in ("json", "txt", "srt"):
            raise ValueError(f"不支持的格式: {format}")
        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)
        raw_result = result.get("raw_result", [])
        if format == "json":
            with open(output_path, 'w', encoding='utf-8') as f:
                json.dump(result, f, ensure_ascii=False, indent=2)
        elif format == "txt":
            if isinstance(raw_result, list):
                full_text = self._join_clean_text(raw_result)
            else:
                full_text = str(raw_result)
            with open(output_path, 'w', encoding='utf-8') as f:
                f.write(full_text)
        else:  # "srt"
            segmented = self.segment_by_timestamp(raw_result)
            self._save_srt_from_segments(segmented, output_path)
        logger.info(f"转录结果已保存: {output_path}")

    def _save_srt_from_segments(self, segments, output_path):
        """Write timestamped segments to *output_path* as SRT subtitles.

        Segments without ``start_time``/``end_time`` are skipped. Cue numbers stay
        consecutive (1, 2, 3, ...) regardless of skipped segments. A warning is
        logged when no cue could be written.
        """
        if not isinstance(segments, list):
            segments = []
        cue = 0
        with open(output_path, 'w', encoding='utf-8') as f:
            for segment in segments:
                if not isinstance(segment, dict):
                    continue
                if "start_time" not in segment or "end_time" not in segment:
                    continue
                cue += 1
                start_time = self._format_time(segment["start_time"])
                end_time = self._format_time(segment["end_time"])
                text = segment.get("clean_text", segment.get("text", "")).strip()
                f.write(f"{cue}\n{start_time} --> {end_time}\n{text}\n\n")
        if cue == 0:
            logger.warning(f"没有带时间戳的分段, SRT文件为空: {output_path}")

    def _format_time(self, seconds):
        """Format *seconds* as an SRT timestamp ``HH:MM:SS,mmm``.

        The value is rounded to whole milliseconds before splitting into fields.
        As a result, 59.9996 becomes ``00:01:00,000`` rather than the invalid
        ``00:00:60,000``. Negative input is clamped to zero.
        """
        total_ms = max(0, int(round(seconds * 1000)))
        hours, rem = divmod(total_ms, 3_600_000)
        minutes, rem = divmod(rem, 60_000)
        secs, millis = divmod(rem, 1000)
        return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}"
def _collect_audio_files(directory):
    """Return the audio files directly inside *directory*, sorted by path.

    Suffixes are compared case-insensitively, so ``a.WAV`` and ``a.Wav`` are both
    found. Each file is listed once, even where pathlib glob matching is
    case-insensitive (e.g. Windows), in which case globbing ``*.wav`` and ``*.WAV``
    separately would return the same file twice.
    """
    audio_extensions = {'.wav', '.mp3', '.m4a', '.flac', '.aac', '.ogg'}
    return sorted(
        path for path in Path(directory).iterdir()
        if path.is_file() and path.suffix.lower() in audio_extensions
    )


def _print_single_result(transcriber, result, source, output_file):
    """Print a console report for one transcription.

    The report covers metadata, the raw model output, a per-segment breakdown
    and the full cleaned text.
    """
    print("\n转录完成!")
    print(f"原文件: {source}")
    print(f"输出文件: {output_file}")
    print(f"转录时长: {result['transcribe_time']:.2f}")

    raw_result = result.get("raw_result", [])
    print("\n原始转录结果:")
    print("-" * 50)
    print(raw_result)
    # The segment view and joined text are only produced for list output.
    if not isinstance(raw_result, list):
        return

    segmented = transcriber.segment_by_timestamp(raw_result)
    print(f"\n分段结果 (共 {len(segmented)} 段):")
    print("-" * 50)
    for i, segment in enumerate(segmented, 1):
        if "start_time" in segment:
            print(f"{i}: [{segment['start_time']:.2f}s - {segment['end_time']:.2f}s]")
        else:
            print(f"{i}: [无时间戳]")
        print(f"原文: {segment.get('text', '')}")
        print(f"清理: {segment.get('clean_text', '')}")
        print("-" * 30)

    full_text = " ".join(
        transcriber._clean_text(item["text"])
        for item in raw_result
        if isinstance(item, dict) and item.get("text")
    )
    print("\n完整清理文本:")
    print("-" * 50)
    print(full_text)


def main(argv=None):
    """Command-line entry point: transcribe one audio file or every audio file in a directory.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        int: Exit status. 0 on success (or when a directory holds no audio files).
        1 when the input path is invalid, a batch file failed, or an error occurred.
    """
    parser = argparse.ArgumentParser(description="使用SenseVoice模型进行高精度中文语音识别")
    parser.add_argument("input", help="输入音频文件路径或目录")
    parser.add_argument("-l", "--language", default="auto",
                        choices=["auto", "zh", "en", "yue", "ja", "ko"],
                        help="语言设置 (默认: auto)")
    # %(default)s keeps the help text in sync with the actual default
    # (it previously advertised a different directory).
    parser.add_argument("-o", "--output", default="/root/autodl-tmp/new_sensevoice",
                        help="输出目录 (默认: %(default)s)")
    parser.add_argument("-f", "--format", default="json",
                        choices=["json", "txt", "srt"],
                        help="输出格式 (默认: json)")
    parser.add_argument("--device", default="cuda:0",
                        help="运行设备 (默认: cuda:0)")
    parser.add_argument("--no-itn", action="store_true",
                        help="禁用逆文本标准化")
    args = parser.parse_args(argv)

    input_path = Path(args.input)
    use_itn = not args.no_itn
    try:
        # The model is only loaded once we know there is something to transcribe:
        # loading is slow and needs the device.
        if input_path.is_file():
            logger.info("处理单个音频文件")
            transcriber = SenseVoiceTranscriber(device=args.device)
            result = transcriber.transcribe_audio(
                args.input,
                language=args.language,
                use_itn=use_itn,
            )
            output_dir = Path(args.output)
            output_dir.mkdir(parents=True, exist_ok=True)
            output_file = output_dir / f"{input_path.stem}_sensevoice.{args.format}"
            transcriber.save_transcript(result, output_file, args.format)
            _print_single_result(transcriber, result, args.input, output_file)
            return 0

        if input_path.is_dir():
            logger.info("处理目录中的音频文件")
            audio_files = _collect_audio_files(input_path)
            if not audio_files:
                logger.warning(f"在目录 {args.input} 中未找到音频文件")
                return 0
            logger.info(f"找到 {len(audio_files)} 个音频文件")
            transcriber = SenseVoiceTranscriber(device=args.device)
            results = transcriber.transcribe_multiple_files(
                [str(f) for f in audio_files],
                output_dir=args.output,
                language=args.language,
                use_itn=use_itn,
            )
            success_count = sum(1 for r in results if r is not None)
            print("\n批量转录完成!")
            print(f"总文件数: {len(audio_files)}")
            print(f"成功转录: {success_count}")
            print(f"失败: {len(audio_files) - success_count}")
            print(f"输出目录: {args.output}")
            return 0 if success_count == len(audio_files) else 1

        logger.error(f"输入路径无效: {args.input}")
        return 1
    except Exception as e:
        # Top-level boundary: log with traceback and report failure via the exit status.
        logger.exception(f"程序执行出错: {str(e)}")
        return 1
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    # SystemExit(None) exits with 0, so a main() that returns None still exits cleanly.
    raise SystemExit(main())