173 lines
4.7 KiB
Python
173 lines
4.7 KiB
Python
|
#!/usr/bin/env python3
|
|||
|
# -*- coding: utf-8 -*-
|
|||
|
"""
|
|||
|
专业提取工具测试脚本
|
|||
|
演示如何使用SenseVoice和OCR工具
|
|||
|
"""
|
|||
|
|
|||
|
import os
|
|||
|
import time
|
|||
|
from pathlib import Path
|
|||
|
|
|||
|
def test_sensevoice():
    """Run the SenseVoice transcription script on one sample audio file.

    Searches ``../video2audio`` (relative to the current working directory)
    for ``*.wav`` files, picks the first one in sorted order and runs
    ``sensevoice_transcribe.py`` on it in a child process with Chinese as the
    language and JSON output written to ``sensevoice_test_results``.

    Progress and the outcome are printed. A non-zero exit status of the child
    process, or a failure to start it, is reported as a failed test.

    Returns:
        None.
    """
    # Function-scope imports keep this block self-contained.
    import shlex
    import subprocess
    import sys

    print("🎤 测试SenseVoice语音识别...")

    # Locate the extracted audio; sorting makes the chosen sample
    # deterministic because glob order is otherwise unspecified.
    audio_dir = Path("../video2audio")
    audio_files = sorted(audio_dir.glob("*.wav"))

    if not audio_files:
        print("❌ 未找到音频文件,请先运行video2audio.py提取音频")
        return

    # Use the first audio file as the sample.
    test_audio = audio_files[0]
    print(f"📁 测试文件: {test_audio}")

    # Build the command as an argument list (no shell), so file names with
    # quotes, spaces or shell metacharacters are passed through verbatim.
    # sys.executable runs the tool with the interpreter running this script.
    cmd = [
        sys.executable or "python",
        "sensevoice_transcribe.py",
        str(test_audio),
        "--language", "zh",
        "--output", "sensevoice_test_results",
        "--format", "json",
    ]
    shown = " ".join(shlex.quote(part) for part in cmd)
    print(f"🚀 执行命令: {shown}")

    try:
        completed = subprocess.run(cmd)
    except OSError as e:
        # The child process could not be started at all.
        print(f"❌ SenseVoice测试失败: {e}")
        return

    # A failing tool only shows up in the exit status, so check it explicitly.
    if completed.returncode == 0:
        print("✅ SenseVoice测试完成")
    else:
        print(f"❌ SenseVoice测试失败: 退出码 {completed.returncode}")
|
|||
|
|
|||
|
def test_ocr_extractor():
    """Run the OCR subtitle extraction script on one sample video file.

    Searches ``../video2audio`` (relative to the current working directory)
    for ``*.mp4`` files, picks the first one in sorted order and runs
    ``ocr_subtitle_extractor.py`` on it in a child process using the
    ``paddleocr`` engine, language ``ch``, confidence ``0.5`` and JSON output
    written to ``ocr_test_results``.

    Progress and the outcome are printed. A non-zero exit status of the child
    process, or a failure to start it, is reported as a failed test.

    Returns:
        None.
    """
    # Function-scope imports keep this block self-contained.
    import shlex
    import subprocess
    import sys

    print("\n📝 测试OCR字幕提取...")

    # Locate candidate videos; sorting makes the chosen sample deterministic
    # because glob order is otherwise unspecified.
    video_dir = Path("../video2audio")
    video_files = sorted(video_dir.glob("*.mp4"))

    if not video_files:
        print("❌ 未找到视频文件")
        return

    # Use the first video file as the sample.
    test_video = video_files[0]
    print(f"📁 测试文件: {test_video}")

    # Build the command as an argument list (no shell), so file names with
    # quotes, spaces or shell metacharacters are passed through verbatim.
    # sys.executable runs the tool with the interpreter running this script.
    cmd = [
        sys.executable or "python",
        "ocr_subtitle_extractor.py",
        str(test_video),
        "--engine", "paddleocr",
        "--language", "ch",
        "--confidence", "0.5",
        "--output", "ocr_test_results",
        "--format", "json",
    ]
    shown = " ".join(shlex.quote(part) for part in cmd)
    print(f"🚀 执行命令: {shown}")

    try:
        completed = subprocess.run(cmd)
    except OSError as e:
        # The child process could not be started at all.
        print(f"❌ OCR测试失败: {e}")
        return

    # A failing tool only shows up in the exit status, so check it explicitly.
    if completed.returncode == 0:
        print("✅ OCR测试完成")
    else:
        print(f"❌ OCR测试失败: 退出码 {completed.returncode}")
|
|||
|
|
|||
|
def test_whisper_comparison():
    """Run the Whisper transcription script for comparison.

    Requires at least one ``*.wav`` file in ``../video2audio`` (relative to
    the current working directory), then runs ``whisper_audio_transcribe.py``
    in a child process with the ``base`` model and ``whisper_test_results``
    as the output location. The script is handed the audio *directory*, not
    a single file.

    Progress and the outcome are printed. A non-zero exit status of the child
    process, or a failure to start it, is reported as a failed test.

    Returns:
        None.
    """
    # Function-scope imports keep this block self-contained.
    import shlex
    import subprocess
    import sys

    print("\n🎯 测试Whisper对比...")

    audio_dir = Path("../video2audio")

    # Only existence matters here: the directory itself is passed to the tool.
    if not any(audio_dir.glob("*.wav")):
        print("❌ 未找到音频文件")
        return

    # Build the command as an argument list (no shell), so paths with quotes,
    # spaces or shell metacharacters are passed through verbatim.
    # sys.executable runs the tool with the interpreter running this script.
    cmd = [
        sys.executable or "python",
        "whisper_audio_transcribe.py",
        str(audio_dir),
        "--model", "base",
        "--output", "whisper_test_results",
    ]
    shown = " ".join(shlex.quote(part) for part in cmd)
    print(f"🚀 执行Whisper命令: {shown}")

    try:
        completed = subprocess.run(cmd)
    except OSError as e:
        # The child process could not be started at all.
        print(f"❌ Whisper测试失败: {e}")
        return

    # A failing tool only shows up in the exit status, so check it explicitly.
    if completed.returncode == 0:
        print("✅ Whisper测试完成")
    else:
        print(f"❌ Whisper测试失败: 退出码 {completed.returncode}")
|
|||
|
|
|||
|
def compare_results():
    """Print an overview of the output directories written by the tests.

    For ``sensevoice_test_results``, ``ocr_test_results`` and
    ``whisper_test_results`` (relative to the current working directory) this
    prints whether each path exists. For every one that does, it prints the
    number of entries and a preview of the first three entries in sorted
    order with their sizes in bytes (``0`` for anything that is not a
    regular file).

    Returns:
        None.
    """
    print("\n📊 对比测试结果...")

    # Output locations of the three tools.
    sensevoice_dir = Path("sensevoice_test_results")
    ocr_dir = Path("ocr_test_results")
    whisper_dir = Path("whisper_test_results")

    def mark(path):
        """Return the check/cross symbol for whether *path* exists."""
        return "✅" if path.exists() else "❌"

    print("📂 输出目录检查:")
    print(f" SenseVoice: {mark(sensevoice_dir)} {sensevoice_dir}")
    print(f" OCR提取器: {mark(ocr_dir)} {ocr_dir}")
    print(f" Whisper: {mark(whisper_dir)} {whisper_dir}")

    # List a short preview of each existing output directory.
    preview_limit = 3
    listings = [
        ("SenseVoice", sensevoice_dir),
        ("OCR", ocr_dir),
        ("Whisper", whisper_dir),
    ]
    for name, dir_path in listings:
        if not dir_path.exists():
            continue

        # glob() yields entries in no particular order; sort so the preview
        # is the same on every run and every file system.
        files = sorted(dir_path.glob("*"))
        print(f"\n{name} 输出文件 ({len(files)}个):")
        for file in files[:preview_limit]:
            file_size = file.stat().st_size if file.is_file() else 0
            print(f" 📄 {file.name} ({file_size} bytes)")
|
|||
|
|
|||
|
def main():
    """Run the whole extraction-tool test suite and print a report.

    Steps, in order:

    1. Report for each optional dependency (``funasr``, ``paddleocr``,
       ``cv2``, ``whisper``) whether it can be located.
    2. Run ``test_sensevoice``, ``test_ocr_extractor``,
       ``test_whisper_comparison`` and ``compare_results``. A
       ``KeyboardInterrupt`` or any other exception raised by them is
       printed and ends this step instead of propagating.
    3. Print the total elapsed wall-clock time.

    Returns:
        None.
    """
    # Function-scope import keeps this block self-contained.
    from importlib import util as importlib_util

    print("🔧 专业提取工具测试套件")
    print("=" * 50)

    start_time = time.time()

    # Dependency check.
    print("\n📋 检查依赖安装...")

    dependencies = [
        ("funasr", "SenseVoice语音识别"),
        ("paddleocr", "PaddleOCR字幕识别"),
        ("cv2", "OpenCV图像处理"),
        ("whisper", "Whisper语音识别"),
    ]

    for module, desc in dependencies:
        # find_spec locates the module without executing it. Importing these
        # packages just to probe for them is slow, and an installed package
        # that fails during import with something other than ImportError
        # would abort the suite before any test runs.
        try:
            installed = importlib_util.find_spec(module) is not None
        except (ImportError, ValueError):
            installed = False

        if installed:
            print(f" ✅ {desc} - 已安装")
        else:
            print(f" ❌ {desc} - 未安装")

    print("\n" + "=" * 50)

    # Run the tests; this is the top-level boundary, so failures are reported
    # here and the timing summary below is still printed.
    try:
        # Speech recognition.
        test_sensevoice()

        # OCR subtitle extraction.
        test_ocr_extractor()

        # Whisper comparison run.
        test_whisper_comparison()

        # Summary of the produced outputs.
        compare_results()

    except KeyboardInterrupt:
        print("\n⏹️ 测试被用户中断")
    except Exception as e:
        print(f"\n❌ 测试过程中出错: {e}")

    total_time = time.time() - start_time
    print(f"\n⏱️ 总测试时间: {total_time:.2f} 秒")
    print("🎉 测试完成!")
|
|||
|
|
|||
|
# Script entry point: run the suite only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|