#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Speech recognition with the Whisper model."""

import argparse
import json
import logging
import time
from datetime import datetime
from pathlib import Path

import whisper

# Configure logging
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


class WhisperTranscriber:
    """Speech-to-text transcription using the Whisper model."""

    def __init__(self, model_size="base"):
        """
        Initialize the Whisper transcriber.

        Args:
            model_size: Model size (tiny, base, small, medium, large, large-v2, large-v3)
        """
        self.model_size = model_size
        self.model = None
        self.load_model()

    def load_model(self):
        """Load the Whisper model."""
        logger.info(f"Loading Whisper model: {self.model_size}")
        start_time = time.time()
        try:
            self.model = whisper.load_model(self.model_size)
            load_time = time.time() - start_time
            logger.info(f"Model loaded in {load_time:.2f} s")
        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            raise

    def transcribe_audio(self, audio_path, language="zh", task="transcribe"):
        """
        Transcribe an audio file.

        Args:
            audio_path: Path to the audio file
            language: Language code (zh=Chinese, en=English, auto=auto-detect)
            task: Task type (transcribe=transcription, translate=translate to English)

        Returns:
            dict: Processed transcription result
        """
        audio_path = Path(audio_path)
        if not audio_path.exists():
            raise FileNotFoundError(f"Audio file not found: {audio_path}")

        logger.info(f"Transcribing audio file: {audio_path}")
        logger.info(f"Language: {language}, task: {task}")
        start_time = time.time()

        try:
            # Transcription options
            options = {
                "task": task,
                "fp16": False,  # use FP32 for broader (CPU) compatibility
            }
            # Only pass a language if one was specified; otherwise let Whisper detect it
            if language != "auto":
                options["language"] = language

            # Run the transcription
            result = self.model.transcribe(str(audio_path), **options)
            transcribe_time = time.time() - start_time
            logger.info(f"Transcription finished in {transcribe_time:.2f} s")

            # Post-process the raw result
            return self._process_result(result, transcribe_time)
        except Exception as e:
            logger.error(f"Transcription failed: {e}")
            raise

    def _process_result(self, result, transcribe_time):
        """Post-process the raw transcription result."""
        processed = {
            "text": result["text"].strip(),
            "language": result.get("language", "unknown"),
            "segments": [],
            "transcribe_time": transcribe_time,
            "model_size": self.model_size,
            "timestamp": datetime.now().isoformat()
        }

        # Per-segment information
        for segment in result.get("segments", []):
            processed["segments"].append({
                "id": segment.get("id"),
                "start": segment.get("start"),
                "end": segment.get("end"),
                "text": segment.get("text", "").strip(),
                "confidence": segment.get("avg_logprob", 0)
            })

        # Summary statistics
        processed["stats"] = {
            "total_segments": len(processed["segments"]),
            "total_duration": max([s.get("end", 0) for s in processed["segments"]] + [0]),
            "text_length": len(processed["text"]),
            "words_count": len(processed["text"].split()) if processed["text"] else 0
        }
        return processed

    def transcribe_multiple_files(self, audio_files, output_dir="transcripts", **kwargs):
        """
        Transcribe multiple audio files in a batch.

        Args:
            audio_files: List of audio file paths
            output_dir: Output directory
            **kwargs: Extra arguments passed to transcribe_audio

        Returns:
            list: Transcription results (None entries mark files that failed)
        """
        output_path = Path(output_dir)
        output_path.mkdir(exist_ok=True)

        results = []
        for i, audio_file in enumerate(audio_files, 1):
            logger.info(f"Processing file {i}/{len(audio_files)}")
            try:
                result = self.transcribe_audio(audio_file, **kwargs)

                # Save the result for this file
                audio_name = Path(audio_file).stem
                output_file = output_path / f"{audio_name}_transcript.json"
                with open(output_file, 'w', encoding='utf-8') as f:
                    json.dump(result, f, ensure_ascii=False, indent=2)
                logger.info(f"Transcript saved: {output_file}")

                results.append(result)
            except Exception as e:
                logger.error(f"Error while processing {audio_file}: {e}")
                results.append(None)
        return results

    def save_transcript(self, result, output_path, format="json"):
        """
        Save a transcription result.

        Args:
            result: Transcription result
            output_path: Output file path
            format: Output format (json, txt, srt)
        """
        output_path = Path(output_path)

        if format == "json":
            with open(output_path, 'w', encoding='utf-8') as f:
                json.dump(result, f, ensure_ascii=False, indent=2)
        elif format == "txt":
            with open(output_path, 'w', encoding='utf-8') as f:
                f.write(result["text"])
        elif format == "srt":
            self._save_srt(result, output_path)
        else:
            raise ValueError(f"Unsupported format: {format}")

        logger.info(f"Transcript saved: {output_path}")

    def _save_srt(self, result, output_path):
        """Save the result as an SRT subtitle file."""
        with open(output_path, 'w', encoding='utf-8') as f:
            for i, segment in enumerate(result["segments"], 1):
                start_time = self._format_time(segment["start"])
                end_time = self._format_time(segment["end"])
                text = segment["text"].strip()
                f.write(f"{i}\n")
                f.write(f"{start_time} --> {end_time}\n")
                f.write(f"{text}\n\n")

    def _format_time(self, seconds):
        """Format a time in seconds as an SRT timestamp (HH:MM:SS,mmm)."""
        hours = int(seconds // 3600)
        minutes = int((seconds % 3600) // 60)
        seconds = seconds % 60
        return f"{hours:02d}:{minutes:02d}:{seconds:06.3f}".replace('.', ',')


def main():
    """Entry point."""
    parser = argparse.ArgumentParser(description="Speech recognition with the Whisper model")
    parser.add_argument("input", help="Input audio file or directory")
    parser.add_argument("-m", "--model", default="base",
                        choices=["tiny", "base", "small", "medium", "large", "large-v2", "large-v3"],
                        help="Whisper model size (default: base)")
    parser.add_argument("-l", "--language", default="zh",
                        help="Language code (zh=Chinese, en=English, auto=auto-detect; default: zh)")
    parser.add_argument("-t", "--task", default="transcribe",
                        choices=["transcribe", "translate"],
                        help="Task type (transcribe=transcription, translate=translate to English; default: transcribe)")
    parser.add_argument("-o", "--output", default="transcripts",
                        help="Output directory (default: transcripts)")
    parser.add_argument("-f", "--format", default="json",
                        choices=["json", "txt", "srt"],
                        help="Output format (default: json)")

    args = parser.parse_args()

    # Create the transcriber
    transcriber = WhisperTranscriber(model_size=args.model)
    input_path = Path(args.input)

    try:
        if input_path.is_file():
            # Single audio file
            logger.info("Processing a single audio file")
            result = transcriber.transcribe_audio(
                args.input,
                language=args.language,
                task=args.task
            )

            # Save the result
            output_dir = Path(args.output)
            output_dir.mkdir(exist_ok=True)
            file_stem = input_path.stem
            output_file = output_dir / f"{file_stem}_transcript.{args.format}"
            transcriber.save_transcript(result, output_file, args.format)

            # Print a summary
            print("\nTranscription complete!")
            print(f"Input file: {args.input}")
            print(f"Output file: {output_file}")
            print(f"Detected language: {result['language']}")
            print(f"Transcription time: {result['transcribe_time']:.2f} s")
            print(f"Audio duration: {result['stats']['total_duration']:.2f} s")
            print(f"Text length: {result['stats']['text_length']} characters")
            print(f"Word count: {result['stats']['words_count']}")
            print("\nTranscript:")
            print("-" * 50)
            print(result["text"])

        elif input_path.is_dir():
            # All audio files in the directory
            logger.info("Processing audio files in a directory")
            audio_extensions = ['.wav', '.mp3', '.m4a', '.flac', '.aac', '.ogg']
            audio_files = []
            for ext in audio_extensions:
                audio_files.extend(input_path.glob(f"*{ext}"))
                audio_files.extend(input_path.glob(f"*{ext.upper()}"))

            if not audio_files:
                logger.warning(f"No audio files found in directory {args.input}")
                return

            logger.info(f"Found {len(audio_files)} audio files")
            results = transcriber.transcribe_multiple_files(
                [str(f) for f in audio_files],
                output_dir=args.output,
                language=args.language,
                task=args.task
            )

            # Summary
            success_count = sum(1 for r in results if r is not None)
            print("\nBatch transcription complete!")
            print(f"Total files: {len(audio_files)}")
            print(f"Succeeded: {success_count}")
            print(f"Failed: {len(audio_files) - success_count}")
            print(f"Output directory: {args.output}")
        else:
            logger.error(f"Invalid input path: {args.input}")
            return
    except Exception as e:
        logger.error(f"Program error: {e}")
        return


if __name__ == "__main__":
    main()
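
# ---------------------------------------------------------------------------
# Usage examples (illustrative sketch only; the script name "transcribe.py"
# and the audio file/directory names below are placeholders, not defined by
# this file):
#
#   Transcribe a single recording and write an SRT subtitle file:
#       python transcribe.py meeting.wav -m small -l zh -f srt
#
#   Batch-transcribe every supported audio file in a directory, letting
#   Whisper auto-detect the language:
#       python transcribe.py ./recordings -l auto -o transcripts
#
#   Programmatic use (assumes the module is importable as `transcribe`):
#       from transcribe import WhisperTranscriber
#       transcriber = WhisperTranscriber(model_size="base")
#       result = transcriber.transcribe_audio("meeting.wav", language="auto")
#       print(result["text"])
# ---------------------------------------------------------------------------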