#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""High-accuracy Chinese speech recognition using the SenseVoice model (via FunASR).

Transcribes a single audio file or every audio file directly inside a
directory, writing results as JSON, plain text, or SRT subtitles.
"""

import os  # NOTE(review): currently unused; kept in case other tooling relies on it
import re
import sys
import time
import json
import argparse
from pathlib import Path
from datetime import datetime
import logging

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

try:
    from funasr import AutoModel
    # NOTE(review): imported but unused; _clean_text strips tags with a regex instead.
    from funasr.utils.postprocess_utils import rich_transcription_postprocess
except ImportError:
    logger.error("请先安装FunASR: pip install funasr")
    raise

# Matches SenseVoice special tags of the form <|...|> (e.g. language/emotion/event markers).
_TAG_PATTERN = re.compile(r'<\|[^|]+\|>')

# Languages accepted by the CLI and passed through to the model.
SUPPORTED_LANGUAGES = ("auto", "zh", "en", "yue", "ja", "ko")

# Output formats understood by SenseVoiceTranscriber.save_transcript.
SUPPORTED_FORMATS = ("json", "txt", "srt")

# Extensions picked up when a directory is given (compared case-insensitively).
AUDIO_EXTENSIONS = ('.wav', '.mp3', '.m4a', '.flac', '.aac', '.ogg')


class SenseVoiceTranscriber:
    """High-accuracy Chinese speech recognition backed by a FunASR SenseVoice model."""

    def __init__(self, model_dir="/root/autodl-tmp/llm/sensevoice", device="cuda:0"):
        """Create the transcriber and load the model immediately.

        Args:
            model_dir: Local model path or model name understood by FunASR ``AutoModel``.
            device: Device string to run on, e.g. ``"cuda:0"`` or ``"cpu"``.

        Raises:
            Exception: Whatever ``AutoModel`` raises if loading fails (re-raised).
        """
        self.model_dir = model_dir
        self.device = device
        self.model = None
        self.load_model()

    def load_model(self):
        """Load the SenseVoice model into ``self.model`` and log how long it took."""
        logger.info(f"正在加载SenseVoice模型: {self.model_dir}")
        start_time = time.perf_counter()
        try:
            self.model = AutoModel(
                model=self.model_dir,
                trust_remote_code=True,
                device=self.device,
                disable_update=True,
            )
        except Exception as e:
            logger.error(f"模型加载失败: {str(e)}")
            raise
        load_time = time.perf_counter() - start_time
        logger.info(f"SenseVoice模型加载完成,耗时: {load_time:.2f} 秒")

    def transcribe_audio(self, audio_path, language="auto", use_itn=True):
        """Transcribe one audio file.

        Args:
            audio_path: Path to the audio file (str or Path).
            language: One of ``"auto"``, ``"zh"``, ``"en"``, ``"yue"``, ``"ja"``, ``"ko"``.
            use_itn: Whether to apply inverse text normalization.

        Returns:
            dict with keys ``raw_result`` (unmodified output of ``model.generate``),
            ``transcribe_time`` (seconds), ``model``, ``timestamp`` (ISO-8601 local
            time) and ``file_path``.

        Raises:
            FileNotFoundError: If *audio_path* does not exist.
            Exception: Whatever ``model.generate`` raises (logged, then re-raised).
        """
        audio_path = Path(audio_path)
        if not audio_path.exists():
            raise FileNotFoundError(f"音频文件不存在: {audio_path}")

        logger.info(f"开始转录音频文件: {audio_path}")
        logger.info(f"语言设置: {language}, 使用ITN: {use_itn}")

        start_time = time.perf_counter()
        try:
            result = self.model.generate(
                input=str(audio_path),
                cache={},
                language=language,
                use_itn=use_itn,
                batch_size_s=60,    # batch size, in seconds of audio
                merge_vad=True,     # merge VAD segments
                merge_length_s=15,  # target merged segment length, in seconds
            )
        except Exception as e:
            logger.error(f"转录失败: {str(e)}")
            raise

        transcribe_time = time.perf_counter() - start_time
        logger.info(f"转录完成,耗时: {transcribe_time:.2f} 秒")

        # Return the raw model output untouched; formatting happens at save time.
        return {
            "raw_result": result,
            "transcribe_time": transcribe_time,
            "model": "SenseVoice",
            "timestamp": datetime.now().isoformat(),
            "file_path": str(audio_path),
        }

    @staticmethod
    def _timestamp_span(timestamp):
        """Return ``(start_ms, end_ms)`` for a result item's ``timestamp`` field, or None.

        Two shapes are accepted:
          * a flat ``[start_ms, end_ms, ...]`` list (first two values are used), and
          * a list of per-token ``[start_ms, end_ms]`` pairs, in which case the span
            runs from the first token's start to the last token's end.
            (Presumably FunASR's token-level timestamp layout — TODO confirm for the
            SenseVoice build in use.)
        """
        if not isinstance(timestamp, (list, tuple)) or not timestamp:
            return None
        first, last = timestamp[0], timestamp[-1]
        if isinstance(first, (list, tuple)):
            if len(first) < 2 or not isinstance(last, (list, tuple)) or len(last) < 2:
                return None
            return first[0], last[1]
        if len(timestamp) >= 2:
            return timestamp[0], timestamp[1]
        return None

    def segment_by_timestamp(self, result):
        """Turn the raw model output into a list of segment dicts.

        Args:
            result: Raw ``model.generate`` output. Non-list inputs are returned unchanged.

        Returns:
            list of dicts, one per dict item in *result*. Items carrying a usable
            timestamp get ``start_time``, ``end_time`` and ``duration`` (seconds,
            converted from milliseconds); every segment gets ``text`` and
            ``clean_text``. Non-dict items are skipped.
        """
        if not isinstance(result, list):
            return result

        segmented_results = []
        for item in result:
            if not isinstance(item, dict):
                continue
            text = item.get("text", "")
            segment = {}
            span = self._timestamp_span(item.get("timestamp", []))
            if span is not None:
                start_ms, end_ms = span
                segment["start_time"] = start_ms / 1000.0  # ms -> s
                segment["end_time"] = end_ms / 1000.0
                segment["duration"] = (end_ms - start_ms) / 1000.0
            segment["text"] = text
            segment["clean_text"] = self._clean_text(text)
            segmented_results.append(segment)
        return segmented_results

    def _clean_text(self, text):
        """Strip ``<|...|>`` tags from *text* and collapse runs of whitespace.

        Returns an empty string for empty or None input.
        """
        if not text:
            return ""
        clean_text = _TAG_PATTERN.sub('', text)
        # split()/join both collapses internal whitespace and trims the ends.
        return ' '.join(clean_text.split())

    def _extract_full_text(self, raw_result):
        """Join the cleaned ``text`` of every dict item in *raw_result* with single spaces.

        Items whose text is empty after cleaning are skipped so no double spaces appear.
        """
        texts = []
        for item in raw_result:
            if isinstance(item, dict):
                cleaned = self._clean_text(item.get("text", ""))
                if cleaned:
                    texts.append(cleaned)
        return " ".join(texts)

    def transcribe_multiple_files(self, audio_files, output_dir="sensevoice_kuaishou",
                                  *, output_format="json", **kwargs):
        """Transcribe several files, saving one result file per input.

        Args:
            audio_files: Iterable of audio file paths.
            output_dir: Directory for the per-file outputs (created if missing).
            output_format: ``"json"`` (default), ``"txt"`` or ``"srt"``. Outputs are
                named ``<stem>_sensevoice.<output_format>``.
            **kwargs: Forwarded to :meth:`transcribe_audio` (``language``, ``use_itn``).

        Returns:
            list aligned with *audio_files*: the result dict for each success, or
            None for a file that failed (the error is logged, processing continues).

        Raises:
            ValueError: If *output_format* is not supported.
        """
        if output_format not in SUPPORTED_FORMATS:
            raise ValueError(f"不支持的格式: {output_format}")

        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        audio_files = list(audio_files)  # accept any iterable; len() is needed below
        total = len(audio_files)
        written = set()
        results = []
        for i, audio_file in enumerate(audio_files, 1):
            logger.info(f"处理第 {i}/{total} 个文件")
            try:
                result = self.transcribe_audio(audio_file, **kwargs)

                # Save this file's result right away so earlier work survives later failures.
                audio_name = Path(audio_file).stem
                output_file = output_path / f"{audio_name}_sensevoice.{output_format}"
                if output_file in written:
                    # Inputs like a.wav and a.mp3 share a stem and therefore an output name.
                    logger.warning(f"输出文件名冲突,将覆盖: {output_file}")
                self.save_transcript(result, output_file, output_format)
                written.add(output_file)
                results.append(result)
            except Exception as e:
                logger.error(f"处理文件 {audio_file} 时出错: {str(e)}")
                results.append(None)
        return results

    def save_transcript(self, result, output_path, format="json"):
        """Write a :meth:`transcribe_audio` result to *output_path*.

        Args:
            result: Dict returned by :meth:`transcribe_audio`.
            output_path: Destination file path.
            format: ``"json"`` (the whole dict), ``"txt"`` (cleaned text joined with
                spaces) or ``"srt"`` (subtitles from segments that carry timestamps).

        Raises:
            ValueError: For an unsupported *format*.
        """
        output_path = Path(output_path)

        if format == "json":
            with open(output_path, 'w', encoding='utf-8') as f:
                json.dump(result, f, ensure_ascii=False, indent=2)
        elif format == "txt":
            raw_result = result.get("raw_result", [])
            if isinstance(raw_result, list):
                full_text = self._extract_full_text(raw_result)
            else:
                full_text = str(raw_result)
            with open(output_path, 'w', encoding='utf-8') as f:
                f.write(full_text)
        elif format == "srt":
            raw_result = result.get("raw_result", [])
            # segment_by_timestamp passes non-lists through unchanged; SRT needs a list.
            segmented = self.segment_by_timestamp(raw_result) if isinstance(raw_result, list) else []
            self._save_srt_from_segments(segmented, output_path)
        else:
            raise ValueError(f"不支持的格式: {format}")

        logger.info(f"转录结果已保存: {output_path}")

    def _save_srt_from_segments(self, segments, output_path):
        """Write segments that have ``start_time``/``end_time`` as an SRT file.

        Cues are numbered consecutively from 1 over the timed segments only, as the
        SRT format expects; untimed segments are skipped.
        """
        timed = [s for s in segments if "start_time" in s and "end_time" in s]
        if not timed:
            logger.warning(f"未找到带时间戳的分段,SRT文件将为空: {output_path}")

        with open(output_path, 'w', encoding='utf-8') as f:
            for index, segment in enumerate(timed, 1):
                start_time = self._format_time(segment["start_time"])
                end_time = self._format_time(segment["end_time"])
                text = segment.get("clean_text", segment.get("text", "")).strip()
                f.write(f"{index}\n")
                f.write(f"{start_time} --> {end_time}\n")
                f.write(f"{text}\n\n")

    def _format_time(self, seconds):
        """Format *seconds* as an SRT timestamp ``HH:MM:SS,mmm``.

        Works in whole milliseconds so rounding can never produce ``SS == 60``
        (e.g. 59.9996 s becomes ``00:01:00,000``). Negative input is clamped to 0.
        """
        total_ms = max(0, int(round(seconds * 1000)))
        hours, remainder = divmod(total_ms, 3_600_000)
        minutes, remainder = divmod(remainder, 60_000)
        secs, millis = divmod(remainder, 1000)
        return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}"


def _build_arg_parser():
    """Build the command-line parser."""
    parser = argparse.ArgumentParser(description="使用SenseVoice模型进行高精度中文语音识别")
    parser.add_argument("input", help="输入音频文件路径或目录")
    parser.add_argument("-l", "--language", default="auto",
                        choices=SUPPORTED_LANGUAGES,
                        help="语言设置 (默认: auto)")
    # %(default)s keeps the help text in sync with the actual default.
    parser.add_argument("-o", "--output", default="/root/autodl-tmp/new_sensevoice",
                        help="输出目录 (默认: %(default)s)")
    parser.add_argument("-f", "--format", default="json",
                        choices=SUPPORTED_FORMATS,
                        help="输出格式 (默认: json)")
    parser.add_argument("--device", default="cuda:0",
                        help="运行设备 (默认: cuda:0)")
    parser.add_argument("--no-itn", action="store_true",
                        help="禁用逆文本标准化")
    return parser


def _collect_audio_files(directory):
    """Return audio files directly inside *directory*, sorted by path.

    Suffixes are compared case-insensitively against AUDIO_EXTENSIONS, so each file
    appears once even on case-insensitive filesystems and ``.Mp3``-style names match.
    """
    return sorted(
        path for path in directory.iterdir()
        if path.is_file() and path.suffix.lower() in AUDIO_EXTENSIONS
    )


def _transcribe_single_file(transcriber, input_path, args):
    """Transcribe one file, save it in the requested format, and print a report."""
    result = transcriber.transcribe_audio(
        args.input,
        language=args.language,
        use_itn=not args.no_itn,
    )

    output_dir = Path(args.output)
    output_dir.mkdir(parents=True, exist_ok=True)
    output_file = output_dir / f"{input_path.stem}_sensevoice.{args.format}"
    transcriber.save_transcript(result, output_file, args.format)

    print("\n转录完成!")
    print(f"原文件: {args.input}")
    print(f"输出文件: {output_file}")
    print(f"转录时长: {result['transcribe_time']:.2f} 秒")

    raw_result = result.get("raw_result", [])
    print("\n原始转录结果:")
    print("-" * 50)
    print(raw_result)

    if not isinstance(raw_result, list):
        return

    segmented = transcriber.segment_by_timestamp(raw_result)
    print(f"\n分段结果 (共 {len(segmented)} 段):")
    print("-" * 50)
    for i, segment in enumerate(segmented, 1):
        if "start_time" in segment:
            print(f"段 {i}: [{segment['start_time']:.2f}s - {segment['end_time']:.2f}s]")
        else:
            print(f"段 {i}: [无时间戳]")
        print(f"原文: {segment.get('text', '')}")
        print(f"清理: {segment.get('clean_text', '')}")
        print("-" * 30)

    print("\n完整清理文本:")
    print("-" * 50)
    print(transcriber._extract_full_text(raw_result))


def _transcribe_directory(transcriber, audio_files, args):
    """Transcribe *audio_files*, print a summary, and return the number of failures."""
    results = transcriber.transcribe_multiple_files(
        [str(f) for f in audio_files],
        output_dir=args.output,
        output_format=args.format,
        language=args.language,
        use_itn=not args.no_itn,
    )

    success_count = sum(1 for r in results if r is not None)
    failure_count = len(audio_files) - success_count
    print("\n批量转录完成!")
    print(f"总文件数: {len(audio_files)}")
    print(f"成功转录: {success_count}")
    print(f"失败: {failure_count}")
    print(f"输出目录: {args.output}")
    return failure_count


def main():
    """CLI entry point.

    Returns:
        Process exit code: 0 on success (or when a directory holds no audio files),
        1 for an invalid input path, an unexpected error, or any failed file in
        directory mode.
    """
    args = _build_arg_parser().parse_args()
    input_path = Path(args.input)

    try:
        # The input is validated (and files collected) before the model is loaded,
        # because loading is slow and may reserve GPU memory.
        if input_path.is_file():
            logger.info("处理单个音频文件")
            transcriber = SenseVoiceTranscriber(device=args.device)
            _transcribe_single_file(transcriber, input_path, args)
            return 0

        if input_path.is_dir():
            logger.info("处理目录中的音频文件")
            audio_files = _collect_audio_files(input_path)
            if not audio_files:
                logger.warning(f"在目录 {args.input} 中未找到音频文件")
                return 0
            logger.info(f"找到 {len(audio_files)} 个音频文件")
            transcriber = SenseVoiceTranscriber(device=args.device)
            failures = _transcribe_directory(transcriber, audio_files, args)
            return 1 if failures else 0

        logger.error(f"输入路径无效: {args.input}")
        return 1
    except Exception as e:
        logger.exception(f"程序执行出错: {str(e)}")
        return 1


if __name__ == "__main__":
    sys.exit(main())