hot_video_analyse/code/sensevoice_transcribe.py

349 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
使用SenseVoice模型进行高精度中文语音识别
"""
import argparse
import json
import logging
import os
import re
import time
from datetime import datetime
from pathlib import Path
# Set up logging. basicConfig runs at import time, so it configures the root
# logger for whoever imports this module, not only for the CLI in main().
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# FunASR is a hard dependency. On ImportError, log an installation hint and
# re-raise the original error. `logger` must already be defined above for the
# hint to be logged.
try:
    from funasr import AutoModel
    # NOTE(review): rich_transcription_postprocess is not used anywhere in the
    # visible code of this module — confirm whether it can be dropped.
    from funasr.utils.postprocess_utils import rich_transcription_postprocess
except ImportError:
    logger.error("请先安装FunASR: pip install funasr")
    raise
class SenseVoiceTranscriber:
    """High-accuracy (primarily Chinese) speech recognition with SenseVoice.

    Wraps a FunASR ``AutoModel`` and converts its raw output into a
    JSON-serialisable dict (raw text, tag-free text, segments, emotion/event
    tags, statistics). The dict can be written out as JSON, plain text or SRT.
    """

    # Emotion / audio-event tags recognised in the raw model text, keyed by
    # the upper-cased tag name. SenseVoice emits some tags in mixed case
    # (e.g. ``<|Speech|>``), so matching is case-insensitive. Dict order is
    # the priority order: at most one emotion and one event is recorded per
    # result item, namely the first key present.
    _EMOTION_TAGS = {"HAPPY": "happy", "SAD": "sad", "ANGRY": "angry"}
    _EVENT_TAGS = {"SPEECH": "speech", "MUSIC": "music", "BGM": "background_music"}

    # Special tokens such as <|zh|>, <|NEUTRAL|>, <|Speech|>, <|withitn|>;
    # group 1 captures the tag name.
    _TAG_RE = re.compile(r'<\|([^|]+)\|>')

    def __init__(self, model_dir="iic/SenseVoiceSmall", device="cuda:0"):
        """Create the transcriber and load the model immediately.

        Args:
            model_dir: model path or model-hub name passed to ``AutoModel``.
            device: device to run on, e.g. ``"cuda:0"`` or ``"cpu"``.

        Raises:
            Exception: anything raised while loading the model (see load_model).
        """
        self.model_dir = model_dir
        self.device = device
        self.model = None
        self.load_model()

    def load_model(self):
        """Load the SenseVoice model into ``self.model``.

        Raises:
            Exception: whatever FunASR raises when loading fails. It is logged
                with its traceback and then re-raised unchanged.
        """
        logger.info(f"正在加载SenseVoice模型: {self.model_dir}")
        start_time = time.perf_counter()
        try:
            # NOTE(review): no ``vad_model`` is configured here, so the
            # ``merge_vad`` option used in transcribe_audio presumably has no
            # effect and long recordings are decoded without VAD segmentation.
            # Confirm against the FunASR SenseVoice documentation.
            self.model = AutoModel(
                model=self.model_dir,
                trust_remote_code=True,
                device=self.device,
            )
        except Exception as e:
            logger.exception(f"模型加载失败: {e}")
            raise
        load_time = time.perf_counter() - start_time
        logger.info(f"SenseVoice模型加载完成耗时: {load_time:.2f}")

    def transcribe_audio(self, audio_path, language="auto", use_itn=True):
        """Transcribe a single audio file.

        Args:
            audio_path: path to the audio file (str or Path).
            language: language hint: "auto", "zh", "en", "yue", "ja" or "ko".
            use_itn: whether to apply inverse text normalisation.

        Returns:
            dict: the processed transcript (see ``_process_result``).

        Raises:
            FileNotFoundError: if ``audio_path`` does not exist.
            Exception: anything raised by the model during decoding. It is
                logged and then re-raised.
        """
        audio_path = Path(audio_path)
        if not audio_path.exists():
            raise FileNotFoundError(f"音频文件不存在: {audio_path}")
        logger.info(f"开始转录音频文件: {audio_path}")
        logger.info(f"语言设置: {language}, 使用ITN: {use_itn}")
        start_time = time.perf_counter()
        try:
            result = self.model.generate(
                input=str(audio_path),
                cache={},
                language=language,
                use_itn=use_itn,
                batch_size_s=60,  # batch size, in seconds of audio
                merge_vad=True,  # merge VAD segments
                merge_length_s=15,  # target merged segment length, in seconds
            )
        except Exception as e:
            logger.exception(f"转录失败: {e}")
            raise
        transcribe_time = time.perf_counter() - start_time
        logger.info(f"转录完成,耗时: {transcribe_time:.2f}")
        return self._process_result(result, transcribe_time, audio_path)

    def _process_result(self, result, transcribe_time, audio_path):
        """Convert raw ``AutoModel.generate`` output into the transcript dict.

        Args:
            result: expected to be a list of dicts, each with a ``"text"`` key
                and optionally a ``"timestamp"`` key in milliseconds (either a
                flat ``[start, end]`` pair or a list of per-token pairs). Any
                other shape yields an empty transcript.
            transcribe_time: transcription wall time, in seconds.
            audio_path: transcribed file path, stored as a string.

        Returns:
            dict with keys text, segments, emotions, events, transcribe_time,
            model, timestamp, file_path, clean_text and stats.
        """
        processed = {
            "text": "",
            "segments": [],
            "emotions": [],
            "events": [],
            "transcribe_time": transcribe_time,
            "model": "SenseVoice",
            "timestamp": datetime.now().isoformat(),
            "file_path": str(audio_path),
        }

        full_text_parts = []
        if isinstance(result, list):
            for item in result:
                if not isinstance(item, dict):
                    continue
                # `or ""` also guards against an explicit None text value.
                text = item.get("text") or ""
                if text:
                    full_text_parts.append(text)

                segment = self._make_segment(item.get("timestamp"), text)
                if segment is not None:
                    processed["segments"].append(segment)

                tags = {tag.upper() for tag in self._TAG_RE.findall(text)}
                emotion = self._first_tag(tags, self._EMOTION_TAGS)
                if emotion is not None:
                    processed["emotions"].append({"emotion": emotion, "text": text})
                event = self._first_tag(tags, self._EVENT_TAGS)
                if event is not None:
                    processed["events"].append({"event": event, "text": text})

        processed["text"] = " ".join(full_text_parts)
        processed["clean_text"] = self._clean_text(processed["text"])

        clean_text = processed["clean_text"]
        processed["stats"] = {
            "total_segments": len(processed["segments"]),
            "total_duration": max((s["end"] for s in processed["segments"]), default=0),
            "text_length": len(clean_text),
            # NOTE: this is a character count, identical to text_length. The
            # key is kept for output compatibility; Chinese text has no spaces
            # to count words by.
            "words_count": len(clean_text),
            "emotions_detected": len(processed["emotions"]),
            "events_detected": len(processed["events"]),
        }
        return processed

    @staticmethod
    def _make_segment(timestamp, text):
        """Build a ``{"start", "end", "text"}`` segment in seconds, or return None.

        Accepts a flat ``[start_ms, end_ms, ...]`` sequence. It also accepts a
        list of per-token ``[start_ms, end_ms]`` pairs; in that case the
        segment spans from the first token's start to the last token's end.
        Returns None when no usable timing is present.
        """
        if timestamp is None or len(timestamp) == 0:
            return None
        first = timestamp[0]
        if isinstance(first, (list, tuple)):
            last = timestamp[-1]
            if len(first) < 2 or len(last) < 2:
                return None
            start_ms, end_ms = first[0], last[1]
        elif len(timestamp) >= 2:
            start_ms, end_ms = timestamp[0], timestamp[1]
        else:
            return None
        return {"start": start_ms / 1000.0, "end": end_ms / 1000.0, "text": text}

    @staticmethod
    def _first_tag(tags, mapping):
        """Return the label of the first ``mapping`` key found in ``tags``, else None."""
        return next((label for tag, label in mapping.items() if tag in tags), None)

    def _clean_text(self, text):
        """Remove SenseVoice special tokens and collapse runs of whitespace."""
        return " ".join(self._TAG_RE.sub("", text).split())

    def transcribe_multiple_files(self, audio_files, output_dir="sensevoice_transcripts", **kwargs):
        """Transcribe several files, saving each result as JSON in ``output_dir``.

        Args:
            audio_files: any iterable of audio file paths.
            output_dir: output directory, created (with parents) if missing.
            **kwargs: forwarded to ``transcribe_audio`` (language, use_itn).

        Returns:
            list: one entry per input file, in input order. Each entry is the
            transcript dict, or None if that file failed.
        """
        audio_files = list(audio_files)
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)
        total = len(audio_files)

        results = []
        for i, audio_file in enumerate(audio_files, 1):
            logger.info(f"处理第 {i}/{total} 个文件")
            try:
                result = self.transcribe_audio(audio_file, **kwargs)
                output_file = output_path / f"{Path(audio_file).stem}_sensevoice.json"
                self.save_transcript(result, output_file, "json")
            except Exception as e:
                # Best-effort batch: one bad file must not abort the rest. The
                # None placeholder keeps results aligned with audio_files.
                logger.exception(f"处理文件 {audio_file} 时出错: {e}")
                results.append(None)
            else:
                results.append(result)
        return results

    def save_transcript(self, result, output_path, format="json"):
        """Save a transcript dict as "json", "txt" (clean text only) or "srt".

        Raises:
            ValueError: for an unsupported format. This is checked before
                anything is written or any directory is created.
        """
        output_path = Path(output_path)
        if format not in ("json", "txt", "srt"):
            raise ValueError(f"不支持的格式: {format}")
        output_path.parent.mkdir(parents=True, exist_ok=True)

        if format == "json":
            with open(output_path, 'w', encoding='utf-8') as f:
                json.dump(result, f, ensure_ascii=False, indent=2)
        elif format == "txt":
            with open(output_path, 'w', encoding='utf-8') as f:
                f.write(result["clean_text"])
        else:
            self._save_srt(result, output_path)
        logger.info(f"转录结果已保存: {output_path}")

    def _save_srt(self, result, output_path):
        """Write ``result["segments"]`` as SubRip (SRT) subtitle cues.

        Segment text still contains SenseVoice tags (``<|zh|>``, ``<|HAPPY|>``,
        ...), so it is cleaned before writing. Segments that are empty after
        cleaning are skipped, because a cue without a text line is malformed
        for most SRT parsers. Cue numbers remain consecutive.
        """
        cues = []
        for segment in result["segments"]:
            text = self._clean_text(segment.get("text") or "")
            if not text:
                continue
            start = self._format_time(segment["start"])
            end = self._format_time(segment["end"])
            cues.append(f"{len(cues) + 1}\n{start} --> {end}\n{text}\n\n")
        with open(output_path, 'w', encoding='utf-8') as f:
            f.writelines(cues)

    def _format_time(self, seconds):
        """Format a time in seconds as an SRT timestamp ``HH:MM:SS,mmm``.

        The value is rounded to whole milliseconds before it is split into
        fields. This turns 59.9996 into ``00:01:00,000``; formatting the float
        seconds directly would give the invalid ``00:00:60,000``. Negative
        inputs are clamped to zero.
        """
        total_ms = max(0, int(round(seconds * 1000)))
        hours, rem = divmod(total_ms, 3_600_000)
        minutes, rem = divmod(rem, 60_000)
        secs, millis = divmod(rem, 1000)
        return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}"
def _find_audio_files(directory):
    """Return the audio files directly inside *directory*, sorted by path.

    Each file is examined once and its extension is compared case-insensitively.
    Globbing every extension in both lower and upper case would list files twice
    on case-insensitive filesystems and miss mixed-case names such as ``.Mp3``.
    """
    audio_extensions = {'.wav', '.mp3', '.m4a', '.flac', '.aac', '.ogg'}
    return sorted(
        p for p in Path(directory).iterdir()
        if p.is_file() and p.suffix.lower() in audio_extensions
    )


def main():
    """Command-line entry point: transcribe one audio file or a whole directory.

    Returns:
        int: process exit status. It is 0 on success, including when a
        directory contains no audio files (a warning is logged). It is 1 when
        the input path is invalid, transcription raises, or any file in a
        batch fails.
    """
    parser = argparse.ArgumentParser(description="使用SenseVoice模型进行高精度中文语音识别")
    parser.add_argument("input", help="输入音频文件路径或目录")
    parser.add_argument("-l", "--language", default="auto",
                        choices=["auto", "zh", "en", "yue", "ja", "ko"],
                        help="语言设置 (默认: auto)")
    parser.add_argument("-o", "--output", default="sensevoice_transcripts",
                        help="输出目录 (默认: sensevoice_transcripts)")
    parser.add_argument("-f", "--format", default="json",
                        choices=["json", "txt", "srt"],
                        help="输出格式 (默认: json)")
    parser.add_argument("--device", default="cuda:0",
                        help="运行设备 (默认: cuda:0)")
    parser.add_argument("--no-itn", action="store_true",
                        help="禁用逆文本标准化")
    args = parser.parse_args()

    # Model loading happens here. A failure propagates as an uncaught
    # exception, which already yields a non-zero exit status.
    transcriber = SenseVoiceTranscriber(device=args.device)
    input_path = Path(args.input)
    output_dir = Path(args.output)
    use_itn = not args.no_itn

    try:
        if input_path.is_file():
            logger.info("处理单个音频文件")
            result = transcriber.transcribe_audio(
                args.input,
                language=args.language,
                use_itn=use_itn,
            )
            output_dir.mkdir(parents=True, exist_ok=True)
            output_file = output_dir / f"{input_path.stem}_sensevoice.{args.format}"
            transcriber.save_transcript(result, output_file, args.format)

            stats = result["stats"]
            print("\n转录完成!")
            print(f"原文件: {args.input}")
            print(f"输出文件: {output_file}")
            print(f"转录时长: {result['transcribe_time']:.2f}")
            print(f"音频时长: {stats['total_duration']:.2f}")
            print(f"文字长度: {stats['text_length']} 字符")
            print(f"分段数量: {stats['total_segments']}")
            print(f"情感检测: {stats['emotions_detected']}")
            print(f"事件检测: {stats['events_detected']}")
            print("\n转录内容:")
            print("-" * 50)
            print(result["clean_text"])
            return 0

        if input_path.is_dir():
            logger.info("处理目录中的音频文件")
            audio_files = _find_audio_files(input_path)
            if not audio_files:
                logger.warning(f"在目录 {args.input} 中未找到音频文件")
                return 0
            logger.info(f"找到 {len(audio_files)} 个音频文件")

            results = transcriber.transcribe_multiple_files(
                [str(f) for f in audio_files],
                output_dir=args.output,
                language=args.language,
                use_itn=use_itn,
            )

            # transcribe_multiple_files always writes JSON. To honour -f for
            # the other formats, additionally write each successful result in
            # the requested format.
            if args.format != "json":
                for audio_file, result in zip(audio_files, results):
                    if result is not None:
                        transcriber.save_transcript(
                            result,
                            output_dir / f"{audio_file.stem}_sensevoice.{args.format}",
                            args.format,
                        )

            success_count = sum(1 for r in results if r is not None)
            print("\n批量转录完成!")
            print(f"总文件数: {len(audio_files)}")
            print(f"成功转录: {success_count}")
            print(f"失败: {len(audio_files) - success_count}")
            print(f"输出目录: {args.output}")
            return 0 if success_count == len(audio_files) else 1

        logger.error(f"输入路径无效: {args.input}")
        return 1
    except Exception as e:
        # Top-level CLI boundary: log with traceback and report failure
        # through the exit status instead of exiting 0.
        logger.exception(f"程序执行出错: {e}")
        return 1
if __name__ == "__main__":
    # Use main()'s return value as the process exit status so shell scripts
    # can detect failures. SystemExit(None) exits with status 0.
    raise SystemExit(main())