#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""High-accuracy Chinese speech recognition using the SenseVoice model."""

import os
import re
import sys
import time
import json
import argparse
from pathlib import Path
from datetime import datetime
import logging

# Configure logging for the whole script.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

try:
    from funasr import AutoModel
    from funasr.utils.postprocess_utils import rich_transcription_postprocess
except ImportError as exc:
    # Log right away, but defer the failure to load_model() so that the pure
    # helpers (text cleaning, SRT formatting) stay importable without FunASR.
    logger.error("请先安装FunASR: pip install funasr")
    AutoModel = None
    rich_transcription_postprocess = None
    _FUNASR_IMPORT_ERROR = exc
else:
    _FUNASR_IMPORT_ERROR = None

# Matches the special "<|...|>" tags embedded in the model's text output.
_TAG_PATTERN = re.compile(r'<\|[^|]+\|>')

# (lower-cased tag, label) pairs in priority order; the first tag found wins.
_EMOTION_TAGS = (
    ("<|happy|>", "happy"),
    ("<|sad|>", "sad"),
    ("<|angry|>", "angry"),
)
_EVENT_TAGS = (
    ("<|speech|>", "speech"),
    ("<|music|>", "music"),
    ("<|bgm|>", "background_music"),
)

# File suffixes (lower-case) treated as audio when scanning a directory.
AUDIO_EXTENSIONS = frozenset({'.wav', '.mp3', '.m4a', '.flac', '.aac', '.ogg'})


def _first_label(text, tag_table):
    """Return the label of the first tag in *tag_table* found in *text*, else None."""
    lowered = text.lower()
    for tag, label in tag_table:
        if tag in lowered:
            return label
    return None


def _segment_bounds(timestamp):
    """Convert a millisecond timestamp value into ``(start, end)`` seconds.

    Accepts a flat ``[start_ms, end_ms, ...]`` sequence or a sequence of
    ``[start_ms, end_ms]`` pairs (in which case the span from the first pair's
    start to the last pair's end is used). Returns None when no usable span
    is present.
    """
    if not timestamp:
        return None
    first, last = timestamp[0], timestamp[-1]
    if isinstance(first, (list, tuple)):
        # NOTE(review): nested pairs are presumably per-token spans from
        # FunASR -- confirm against the model version in use.
        if len(first) < 2 or not isinstance(last, (list, tuple)) or len(last) < 2:
            return None
        start_ms, end_ms = first[0], last[1]
    elif len(timestamp) >= 2:
        start_ms, end_ms = timestamp[0], timestamp[1]
    else:
        return None
    return start_ms / 1000.0, end_ms / 1000.0


def _find_audio_files(directory):
    """Return the sorted audio files directly inside *directory*.

    Suffixes are compared case-insensitively in a single pass, so a file is
    never listed twice and mixed-case suffixes such as ``.Wav`` are included.
    """
    return sorted(
        entry for entry in directory.iterdir()
        if entry.is_file() and entry.suffix.lower() in AUDIO_EXTENSIONS
    )


class SenseVoiceTranscriber:
    """Transcribe audio with a SenseVoice model loaded through FunASR."""

    def __init__(self, model_dir="iic/SenseVoiceSmall", device="cuda:0"):
        """Create the transcriber and load the model immediately.

        Args:
            model_dir: Model path or model name.
            device: Device to run on (e.g. ``cuda:0``, ``cpu``).

        Raises:
            ImportError: If FunASR is not installed.
        """
        self.model_dir = model_dir
        self.device = device
        self.model = None
        self.load_model()

    def load_model(self):
        """Load the SenseVoice model into ``self.model``.

        Raises:
            ImportError: If FunASR could not be imported.
            Exception: Any error raised by ``AutoModel`` is logged and re-raised.
        """
        if AutoModel is None:
            raise ImportError("请先安装FunASR: pip install funasr") from _FUNASR_IMPORT_ERROR

        logger.info(f"正在加载SenseVoice模型: {self.model_dir}")
        start_time = time.time()
        try:
            self.model = AutoModel(
                model=self.model_dir,
                trust_remote_code=True,
                device=self.device,
            )
        except Exception as e:
            logger.error(f"模型加载失败: {str(e)}")
            raise
        load_time = time.time() - start_time
        logger.info(f"SenseVoice模型加载完成,耗时: {load_time:.2f} 秒")

    def transcribe_audio(self, audio_path, language="auto", use_itn=True):
        """Transcribe a single audio file.

        Args:
            audio_path: Path of the audio file.
            language: Language setting ("auto", "zh", "en", "yue", "ja", "ko").
            use_itn: Whether to apply inverse text normalization.

        Returns:
            dict: The processed result built by ``_process_result``.

        Raises:
            FileNotFoundError: If *audio_path* does not exist.
            Exception: Any error raised during generation is logged and re-raised.
        """
        audio_path = Path(audio_path)
        if not audio_path.exists():
            raise FileNotFoundError(f"音频文件不存在: {audio_path}")

        logger.info(f"开始转录音频文件: {audio_path}")
        logger.info(f"语言设置: {language}, 使用ITN: {use_itn}")
        start_time = time.time()
        try:
            result = self.model.generate(
                input=str(audio_path),
                cache={},
                language=language,
                use_itn=use_itn,
                batch_size_s=60,     # batch size, in seconds
                merge_vad=True,      # merge VAD results
                merge_length_s=15,   # merge length, in seconds
            )
            transcribe_time = time.time() - start_time
            logger.info(f"转录完成,耗时: {transcribe_time:.2f} 秒")
            return self._process_result(result, transcribe_time, audio_path)
        except Exception as e:
            logger.error(f"转录失败: {str(e)}")
            raise

    def _process_result(self, result, transcribe_time, audio_path):
        """Normalize the raw model output into a plain result dict.

        Args:
            result: Raw output of ``model.generate``; only dict items of a
                non-empty list are used, anything else yields an empty result.
            transcribe_time: Elapsed transcription time in seconds.
            audio_path: Path of the transcribed file (stored as a string).

        Returns:
            dict: Keys ``text``, ``segments``, ``emotions``, ``events``,
            ``transcribe_time``, ``model``, ``timestamp``, ``file_path``,
            ``clean_text`` and ``stats``.
        """
        processed = {
            "text": "",
            "segments": [],
            "emotions": [],
            "events": [],
            "transcribe_time": transcribe_time,
            "model": "SenseVoice",
            "timestamp": datetime.now().isoformat(),
            "file_path": str(audio_path),
        }

        if isinstance(result, list) and result:
            text_parts = []
            for item in result:
                if not isinstance(item, dict):
                    continue

                # A missing or None text is treated as empty so the tag checks
                # below cannot fail.
                text = item.get("text") or ""
                if text:
                    text_parts.append(text)

                bounds = _segment_bounds(item.get("timestamp"))
                if bounds is not None:
                    processed["segments"].append(
                        {"start": bounds[0], "end": bounds[1], "text": text}
                    )

                # At most one emotion and one event per item. Tags are matched
                # case-insensitively so a different casing is still detected.
                emotion = _first_label(text, _EMOTION_TAGS)
                if emotion is not None:
                    processed["emotions"].append({"emotion": emotion, "text": text})
                event = _first_label(text, _EVENT_TAGS)
                if event is not None:
                    processed["events"].append({"event": event, "text": text})

            processed["text"] = " ".join(text_parts)

        processed["clean_text"] = self._clean_text(processed["text"])

        clean_length = len(processed["clean_text"])
        processed["stats"] = {
            "total_segments": len(processed["segments"]),
            "total_duration": max([0, *(seg["end"] for seg in processed["segments"])]),
            "text_length": clean_length,
            # Counted in characters, identical to text_length.
            "words_count": clean_length,
            "emotions_detected": len(processed["emotions"]),
            "events_detected": len(processed["events"]),
        }
        return processed

    def _clean_text(self, text):
        """Strip ``<|...|>`` tags from *text* and collapse runs of whitespace."""
        without_tags = _TAG_PATTERN.sub('', text)
        return ' '.join(without_tags.split())

    def transcribe_multiple_files(self, audio_files, output_dir="sensevoice_transcripts",
                                  *, output_format="json", **kwargs):
        """Transcribe several files, saving one transcript per input file.

        Args:
            audio_files: Iterable of audio file paths.
            output_dir: Directory for the transcripts (created if missing,
                including parent directories).
            output_format: Output format understood by ``save_transcript``
                ("json", "txt" or "srt").
            **kwargs: Forwarded to ``transcribe_audio``.

        Returns:
            list: One entry per input file, in order: the result dict, or
            None when that file failed (the error is logged).
        """
        audio_files = list(audio_files)
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        results = []
        for i, audio_file in enumerate(audio_files, 1):
            logger.info(f"处理第 {i}/{len(audio_files)} 个文件")
            try:
                result = self.transcribe_audio(audio_file, **kwargs)
                audio_name = Path(audio_file).stem
                output_file = output_path / f"{audio_name}_sensevoice.{output_format}"
                self.save_transcript(result, output_file, output_format)
                results.append(result)
            except Exception as e:
                # Best effort: one bad file must not abort the whole batch.
                logger.error(f"处理文件 {audio_file} 时出错: {str(e)}")
                results.append(None)
        return results

    def save_transcript(self, result, output_path, format="json"):
        """Write *result* to *output_path* in the requested format.

        Args:
            result: Result dict produced by ``transcribe_audio``.
            output_path: Destination file path.
            format: "json" (full result), "txt" (clean text) or "srt".

        Raises:
            ValueError: If *format* is not supported (no file is created).
        """
        if format not in ("json", "txt", "srt"):
            raise ValueError(f"不支持的格式: {format}")

        output_path = Path(output_path)
        if format == "srt":
            self._save_srt(result, output_path)
        else:
            with open(output_path, 'w', encoding='utf-8') as f:
                if format == "json":
                    json.dump(result, f, ensure_ascii=False, indent=2)
                else:
                    f.write(result["clean_text"])
        logger.info(f"转录结果已保存: {output_path}")

    def _save_srt(self, result, output_path):
        """Write the result's segments as SRT subtitles.

        Cue text has the ``<|...|>`` tags removed; segments whose text is
        empty after cleaning are skipped and cues are numbered consecutively.
        """
        segments = result["segments"]
        if not segments:
            logger.warning("转录结果不含时间戳分段,SRT文件将为空")

        with open(output_path, 'w', encoding='utf-8') as f:
            cue_number = 0
            for segment in segments:
                text = self._clean_text(segment["text"])
                if not text:
                    continue
                cue_number += 1
                start_time = self._format_time(segment["start"])
                end_time = self._format_time(segment["end"])
                f.write(f"{cue_number}\n{start_time} --> {end_time}\n{text}\n\n")

    def _format_time(self, seconds):
        """Format *seconds* as an SRT timestamp ``HH:MM:SS,mmm``.

        Works in integer milliseconds so that rounding carries into the
        minutes/hours fields (59.9996 s becomes ``00:01:00,000`` rather than
        the invalid ``00:00:60,000``). Negative values are clamped to zero.
        """
        total_ms = max(0, int(round(seconds * 1000)))
        hours, remainder = divmod(total_ms, 3_600_000)
        minutes, remainder = divmod(remainder, 60_000)
        secs, millis = divmod(remainder, 1000)
        return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}"


def main():
    """Command-line entry point.

    Returns:
        int: 0 on success, 1 when the input is invalid or any transcription
        failed.
    """
    parser = argparse.ArgumentParser(description="使用SenseVoice模型进行高精度中文语音识别")
    parser.add_argument("input", help="输入音频文件路径或目录")
    parser.add_argument("-l", "--language", default="auto",
                        choices=["auto", "zh", "en", "yue", "ja", "ko"],
                        help="语言设置 (默认: auto)")
    parser.add_argument("-o", "--output", default="sensevoice_transcripts",
                        help="输出目录 (默认: sensevoice_transcripts)")
    parser.add_argument("-f", "--format", default="json",
                        choices=["json", "txt", "srt"],
                        help="输出格式 (默认: json)")
    parser.add_argument("--device", default="cuda:0",
                        help="运行设备 (默认: cuda:0)")
    parser.add_argument("--no-itn", action="store_true",
                        help="禁用逆文本标准化")
    args = parser.parse_args()

    input_path = Path(args.input)
    use_itn = not args.no_itn

    # Validate the input before paying for the model load.
    is_single_file = input_path.is_file()
    audio_files = []
    if not is_single_file:
        if not input_path.is_dir():
            logger.error(f"输入路径无效: {args.input}")
            return 1
        audio_files = _find_audio_files(input_path)
        if not audio_files:
            logger.warning(f"在目录 {args.input} 中未找到音频文件")
            return 0

    transcriber = SenseVoiceTranscriber(device=args.device)

    try:
        if is_single_file:
            logger.info("处理单个音频文件")
            result = transcriber.transcribe_audio(
                args.input,
                language=args.language,
                use_itn=use_itn,
            )

            output_dir = Path(args.output)
            output_dir.mkdir(parents=True, exist_ok=True)
            output_file = output_dir / f"{input_path.stem}_sensevoice.{args.format}"
            transcriber.save_transcript(result, output_file, args.format)

            stats = result['stats']
            print(f"\n转录完成!")
            print(f"原文件: {args.input}")
            print(f"输出文件: {output_file}")
            print(f"转录时长: {result['transcribe_time']:.2f} 秒")
            print(f"音频时长: {stats['total_duration']:.2f} 秒")
            print(f"文字长度: {stats['text_length']} 字符")
            print(f"分段数量: {stats['total_segments']} 个")
            print(f"情感检测: {stats['emotions_detected']} 个")
            print(f"事件检测: {stats['events_detected']} 个")
            print(f"\n转录内容:")
            print("-" * 50)
            print(result["clean_text"])
            return 0

        logger.info("处理目录中的音频文件")
        logger.info(f"找到 {len(audio_files)} 个音频文件")
        results = transcriber.transcribe_multiple_files(
            [str(f) for f in audio_files],
            output_dir=args.output,
            output_format=args.format,
            language=args.language,
            use_itn=use_itn,
        )

        success_count = sum(1 for r in results if r is not None)
        failed_count = len(audio_files) - success_count
        print(f"\n批量转录完成!")
        print(f"总文件数: {len(audio_files)}")
        print(f"成功转录: {success_count}")
        print(f"失败: {failed_count}")
        print(f"输出目录: {args.output}")
        return 1 if failed_count else 0
    except Exception as e:
        # Top-level boundary: report the failure and signal it via exit code.
        logger.error(f"程序执行出错: {str(e)}")
        return 1


if __name__ == "__main__":
    sys.exit(main())