hot_video_analyse/code/sensevoice_transcribe.py

370 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
使用SenseVoice模型进行高精度中文语音识别
"""
import os
import time
import json
import argparse
from pathlib import Path
from datetime import datetime
import logging
# Configure root logging once for the whole script: INFO level, timestamped records.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# FunASR is a hard dependency. If it is missing, log an install hint and
# re-raise so the ImportError is not hidden from the caller.
# NOTE(review): rich_transcription_postprocess is not referenced in the visible
# part of this file -- confirm whether it is still needed.
try:
    from funasr import AutoModel
    from funasr.utils.postprocess_utils import rich_transcription_postprocess
except ImportError:
    logger.error("请先安装FunASR: pip install funasr")
    raise
class SenseVoiceTranscriber:
    """Speech-to-text helper built on a SenseVoice model loaded through FunASR.

    The model is loaded eagerly when the instance is created. Transcription
    results are returned as plain dicts wrapping the raw FunASR output and can
    be written out as JSON, plain text or SRT subtitles.
    """

    def __init__(self, model_dir="/root/autodl-tmp/llm/sensevoice", device="cuda:0"):
        """
        Create the transcriber and load the model immediately.

        Args:
            model_dir: Model path or model name handed to ``AutoModel``.
            device: Device string the model runs on (e.g. ``cuda:0``, ``cpu``).

        Raises:
            Exception: Whatever ``AutoModel`` raises while loading.
        """
        self.model_dir = model_dir
        self.device = device
        self.model = None
        self.load_model()

    def load_model(self):
        """Load the SenseVoice model into ``self.model`` and log the load time.

        Raises:
            Exception: Any loading error is logged and re-raised unchanged.
        """
        logger.info(f"正在加载SenseVoice模型: {self.model_dir}")
        start_time = time.time()
        try:
            self.model = AutoModel(
                model=self.model_dir,
                trust_remote_code=True,
                device=self.device,
                disable_update=True
            )
        except Exception as e:
            logger.error(f"模型加载失败: {str(e)}")
            raise
        else:
            load_time = time.time() - start_time
            logger.info(f"SenseVoice模型加载完成耗时: {load_time:.2f}")

    def transcribe_audio(self, audio_path, language="auto", use_itn=True):
        """
        Transcribe a single audio file.

        Args:
            audio_path: Path of the audio file.
            language: Language setting ("auto", "zh", "en", "yue", "ja", "ko").
            use_itn: Whether to apply inverse text normalization.

        Returns:
            dict: ``raw_result`` (the untouched value returned by the model),
            ``transcribe_time`` (seconds), ``model``, ``timestamp`` (ISO
            string of the local completion time) and ``file_path``.

        Raises:
            FileNotFoundError: If ``audio_path`` does not exist.
            Exception: Any error raised by the model is logged and re-raised.
        """
        audio_path = Path(audio_path)
        if not audio_path.exists():
            raise FileNotFoundError(f"音频文件不存在: {audio_path}")
        logger.info(f"开始转录音频文件: {audio_path}")
        logger.info(f"语言设置: {language}, 使用ITN: {use_itn}")
        start_time = time.time()
        try:
            # Run the transcription.
            result = self.model.generate(
                input=str(audio_path),
                cache={},
                language=language,
                use_itn=use_itn,
                batch_size_s=60,    # batch size (seconds)
                merge_vad=True,     # merge VAD results
                merge_length_s=15   # merge length (seconds)
            )
            transcribe_time = time.time() - start_time
            logger.info(f"转录完成,耗时: {transcribe_time:.2f}")
            # Hand back the raw model output; post-processing is left to callers.
            return {
                "raw_result": result,
                "transcribe_time": transcribe_time,
                "model": "SenseVoice",
                "timestamp": datetime.now().isoformat(),
                "file_path": str(audio_path)
            }
        except Exception as e:
            logger.error(f"转录失败: {str(e)}")
            raise

    @staticmethod
    def _timestamp_bounds(timestamp):
        """Return ``(start_ms, end_ms)`` for a timestamp value, or ``None``.

        Two shapes are accepted:

        * a flat ``[start, end, ...]`` sequence (the shape this class has
          always handled) -- the first two entries are used;
        * a sequence of ``[start, end]`` pairs -- the start of the first pair
          and the end of the last pair are used.
          NOTE(review): FunASR models are believed to emit this per-token
          shape when timestamps are enabled; confirm against the installed
          FunASR version.

        Anything else (missing, empty, too short, not a list/tuple) yields
        ``None`` so the caller emits a segment without timing information.
        """
        if not isinstance(timestamp, (list, tuple)) or not timestamp:
            return None
        first, last = timestamp[0], timestamp[-1]
        if isinstance(first, (list, tuple)):
            if not isinstance(last, (list, tuple)) or len(first) < 2 or len(last) < 2:
                return None
            return first[0], last[1]
        if len(timestamp) >= 2:
            return timestamp[0], timestamp[1]
        return None

    def segment_by_timestamp(self, result):
        """
        Turn the raw transcription result into a list of segments.

        Args:
            result: Raw transcription result.

        Returns:
            list: One dict per dict item of ``result``. Each has ``text`` and
            ``clean_text``; items with a usable timestamp also get
            ``start_time``, ``end_time`` and ``duration`` (seconds; the raw
            values are divided by 1000). Non-dict items are skipped. If
            ``result`` is not a list it is returned unchanged.
        """
        if not isinstance(result, list):
            return result
        segmented_results = []
        for item in result:
            if not isinstance(item, dict):
                continue
            text = item.get("text", "")
            segment = {}
            bounds = self._timestamp_bounds(item.get("timestamp", []))
            if bounds is not None:
                start_ms, end_ms = bounds
                segment["start_time"] = start_ms / 1000.0  # convert to seconds
                segment["end_time"] = end_ms / 1000.0
                segment["duration"] = (end_ms - start_ms) / 1000.0
            segment["text"] = text
            segment["clean_text"] = self._clean_text(text)
            segmented_results.append(segment)
        return segmented_results

    def _clean_text(self, text):
        """Strip ``<|...|>`` tags from ``text`` and collapse whitespace.

        Returns an empty string for ``None`` or empty input.
        """
        import re
        if not text:
            return ""
        # Remove SenseVoice's special tags.
        tags_pattern = r'<\|[^|]+\|>'
        clean_text = re.sub(tags_pattern, '', text)
        # Collapse runs of whitespace (split/join also trims both ends).
        return ' '.join(clean_text.split())

    def transcribe_multiple_files(self, audio_files, output_dir="sensevoice_kuaishou", **kwargs):
        """Transcribe several audio files, writing one JSON file per input.

        Args:
            audio_files: Iterable of audio file paths.
            output_dir: Directory for the ``<stem>_sensevoice.json`` files;
                created (including missing parents) if needed.
            **kwargs: Forwarded to :meth:`transcribe_audio`.

        Returns:
            list: One entry per input, in order: the result dict on success or
            ``None`` when that file failed (the error is logged, not raised).
        """
        output_path = Path(output_dir)
        # parents=True so a nested output directory does not abort the batch.
        output_path.mkdir(parents=True, exist_ok=True)
        # Materialize so len() works even when a generator is passed in.
        audio_files = list(audio_files)
        results = []
        for i, audio_file in enumerate(audio_files, 1):
            logger.info(f"处理第 {i}/{len(audio_files)} 个文件")
            try:
                result = self.transcribe_audio(audio_file, **kwargs)
                # Save the result of this single file.
                audio_name = Path(audio_file).stem
                output_file = output_path / f"{audio_name}_sensevoice.json"
                with open(output_file, 'w', encoding='utf-8') as f:
                    json.dump(result, f, ensure_ascii=False, indent=2)
                logger.info(f"转录结果已保存: {output_file}")
                results.append(result)
            except Exception as e:
                # Best effort: one bad file must not stop the rest of the batch.
                logger.error(f"处理文件 {audio_file} 时出错: {str(e)}")
                results.append(None)
        return results

    def _full_clean_text(self, raw_result):
        """Join the cleaned text of every dict item that has non-empty text.

        A non-list ``raw_result`` is simply converted with ``str()``.
        """
        if not isinstance(raw_result, list):
            return str(raw_result)
        return " ".join(
            self._clean_text(item["text"])
            for item in raw_result
            if isinstance(item, dict) and item.get("text")
        )

    def save_transcript(self, result, output_path, format="json"):
        """Write a transcription result to ``output_path``.

        Args:
            result: Dict as returned by :meth:`transcribe_audio`.
            output_path: Destination file path.
            format: ``"json"`` (whole result), ``"txt"`` (cleaned text only)
                or ``"srt"`` (subtitles built from timestamped segments).

        Raises:
            ValueError: If ``format`` is not one of the supported values.
        """
        output_path = Path(output_path)
        if format == "json":
            with open(output_path, 'w', encoding='utf-8') as f:
                json.dump(result, f, ensure_ascii=False, indent=2)
        elif format == "txt":
            # Extract the raw text and clean it.
            full_text = self._full_clean_text(result.get("raw_result", []))
            with open(output_path, 'w', encoding='utf-8') as f:
                f.write(full_text)
        elif format == "srt":
            # Build SRT cues from the timestamped segments.
            raw_result = result.get("raw_result", [])
            segmented = self.segment_by_timestamp(raw_result)
            self._save_srt_from_segments(segmented, output_path)
        else:
            raise ValueError(f"不支持的格式: {format}")
        logger.info(f"转录结果已保存: {output_path}")

    def _save_srt_from_segments(self, segments, output_path):
        """Write ``segments`` as an SRT subtitle file.

        Only dict segments carrying both ``start_time`` and ``end_time`` become
        cues. Cues are numbered 1, 2, 3, ... in the order written, so skipped
        segments do not leave gaps in the numbering.
        """
        with open(output_path, 'w', encoding='utf-8') as f:
            cue_number = 0
            for segment in segments:
                if not isinstance(segment, dict):
                    continue
                if "start_time" not in segment or "end_time" not in segment:
                    continue
                cue_number += 1
                start_time = self._format_time(segment["start_time"])
                end_time = self._format_time(segment["end_time"])
                text = segment.get("clean_text", segment.get("text", "")).strip()
                f.write(f"{cue_number}\n")
                f.write(f"{start_time} --> {end_time}\n")
                f.write(f"{text}\n\n")

    def _format_time(self, seconds):
        """Format a duration in seconds as an SRT timestamp ``HH:MM:SS,mmm``.

        The value is rounded to whole milliseconds first and then split, so a
        value such as 59.9996 carries into the minutes field instead of
        producing an invalid ``60,000`` seconds field. Negative values are
        clamped to zero.
        """
        total_ms = max(0, int(round(seconds * 1000)))
        hours, remainder = divmod(total_ms, 3600000)
        minutes, remainder = divmod(remainder, 60000)
        secs, millis = divmod(remainder, 1000)
        return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}"
def _collect_audio_files(directory):
    """Return the audio files directly inside ``directory``, sorted by path.

    The suffix is compared case-insensitively and each file is listed once,
    even on case-insensitive filesystems. Sub-directories are ignored.
    """
    audio_extensions = {'.wav', '.mp3', '.m4a', '.flac', '.aac', '.ogg'}
    return sorted(
        entry for entry in Path(directory).iterdir()
        if entry.is_file() and entry.suffix.lower() in audio_extensions
    )


def _print_single_file_report(transcriber, result, input_arg, output_file):
    """Print the summary, raw result, segments and cleaned text of one file."""
    print("\n转录完成!")
    print(f"原文件: {input_arg}")
    print(f"输出文件: {output_file}")
    print(f"转录时长: {result['transcribe_time']:.2f}")
    # Show the raw result.
    raw_result = result.get("raw_result", [])
    print("\n原始转录结果:")
    print("-" * 50)
    print(raw_result)
    if not isinstance(raw_result, list):
        return
    # Show the segmented result.
    segmented = transcriber.segment_by_timestamp(raw_result)
    print(f"\n分段结果 (共 {len(segmented)} 段):")
    print("-" * 50)
    for i, segment in enumerate(segmented, 1):
        if "start_time" in segment:
            print(f"{i}: [{segment['start_time']:.2f}s - {segment['end_time']:.2f}s]")
        else:
            print(f"{i}: [无时间戳]")
        print(f"原文: {segment.get('text', '')}")
        print(f"清理: {segment.get('clean_text', '')}")
        print("-" * 30)
    # Show the complete cleaned text.
    texts = [
        transcriber._clean_text(item["text"])
        for item in raw_result
        if isinstance(item, dict) and item.get("text")
    ]
    full_text = " ".join(texts)
    print("\n完整清理文本:")
    print("-" * 50)
    print(full_text)


def main():
    """Command-line entry point: transcribe one audio file or a directory.

    The input path is validated (and, for a directory, scanned for audio
    files) before the model is loaded, so an invalid path or an empty
    directory is reported without paying for a model load. Errors during
    transcription are logged and the function returns without raising.
    """
    parser = argparse.ArgumentParser(description="使用SenseVoice模型进行高精度中文语音识别")
    parser.add_argument("input", help="输入音频文件路径或目录")
    parser.add_argument("-l", "--language", default="auto",
                        choices=["auto", "zh", "en", "yue", "ja", "ko"],
                        help="语言设置 (默认: auto)")
    # %(default)s keeps the help text in sync with the real default.
    parser.add_argument("-o", "--output", default="/root/autodl-tmp/new_sensevoice",
                        help="输出目录 (默认: %(default)s)")
    parser.add_argument("-f", "--format", default="json",
                        choices=["json", "txt", "srt"],
                        help="输出格式 (默认: json)")
    parser.add_argument("--device", default="cuda:0",
                        help="运行设备 (默认: cuda:0)")
    parser.add_argument("--no-itn", action="store_true",
                        help="禁用逆文本标准化")
    args = parser.parse_args()

    input_path = Path(args.input)
    audio_files = None  # stays None in single-file mode
    try:
        if input_path.is_dir():
            # Process every audio file in the directory.
            logger.info("处理目录中的音频文件")
            audio_files = _collect_audio_files(input_path)
            if not audio_files:
                logger.warning(f"在目录 {args.input} 中未找到音频文件")
                return
            logger.info(f"找到 {len(audio_files)} 个音频文件")
        elif not input_path.is_file():
            logger.error(f"输入路径无效: {args.input}")
            return
    except OSError as e:
        logger.error(f"程序执行出错: {str(e)}")
        return

    # Create the transcriber (loads the model; load errors propagate).
    transcriber = SenseVoiceTranscriber(device=args.device)
    try:
        if audio_files is None:
            # Process a single file.
            logger.info("处理单个音频文件")
            result = transcriber.transcribe_audio(
                args.input,
                language=args.language,
                use_itn=not args.no_itn
            )
            # Save the result.
            output_dir = Path(args.output)
            output_dir.mkdir(parents=True, exist_ok=True)
            output_file = output_dir / f"{input_path.stem}_sensevoice.{args.format}"
            transcriber.save_transcript(result, output_file, args.format)
            _print_single_file_report(transcriber, result, args.input, output_file)
        else:
            results = transcriber.transcribe_multiple_files(
                [str(f) for f in audio_files],
                output_dir=args.output,
                language=args.language,
                use_itn=not args.no_itn
            )
            # Summarize the batch.
            success_count = sum(1 for r in results if r is not None)
            print("\n批量转录完成!")
            print(f"总文件数: {len(audio_files)}")
            print(f"成功转录: {success_count}")
            print(f"失败: {len(audio_files) - success_count}")
            print(f"输出目录: {args.output}")
    except Exception as e:
        logger.error(f"程序执行出错: {str(e)}")
        return


if __name__ == "__main__":
    main()