Spaces:
Running
Running
File size: 3,210 Bytes
9fdfe53 501c404 c7bfcf2 9fdfe53 5f8203e 0e9733d 5f8203e 4c4ecfe 5f8203e 7ff9bba 0e9733d 7ff9bba 9fdfe53 c7bfcf2 9fdfe53 7b06fe8 bd79b3c 7b06fe8 9fdfe53 a88b526 501c404 9fdfe53 c7bfcf2 9fdfe53 c7bfcf2 9fdfe53 85f37ae c7bfcf2 85f37ae c7bfcf2 9fdfe53 c7bfcf2 9fdfe53 c7bfcf2 9fdfe53 5f8203e 9fdfe53 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 |
from modules.whisper.whisper_factory import WhisperFactory
from modules.whisper.data_classes import *
from modules.utils.subtitle_manager import read_file
from modules.utils.paths import WEBUI_DIR
from test_config import *
import requests
import pytest
import gradio as gr
import os
def _first_cue_text(subtitle_path):
    """Return line 2 of a subtitle file, normalized for WER comparison.

    The test below writes "SRT" output; in SRT layout each cue is
    index / timestamp / text, so line 2 is presumably the first cue's text
    (assumes ``read_file`` returns the file's raw contents -- confirm).
    Surrounding whitespace is stripped and commas and periods are removed.
    """
    lines = read_file(subtitle_path).split("\n")
    return lines[2].strip().replace(",", "").replace(".", "")


@pytest.mark.parametrize(
    "whisper_type,vad_filter,bgm_separation,diarization",
    [
        (impl.value, False, False, False)
        for impl in (
            WhisperImpl.WHISPER,
            WhisperImpl.FASTER_WHISPER,
            WhisperImpl.INSANELY_FAST_WHISPER,
        )
    ],
)
def test_transcribe(
    whisper_type: str,
    vad_filter: bool,
    bgm_separation: bool,
    diarization: bool,
):
    """Run the jfk.wav sample through each transcription entry point.

    For the given Whisper implementation this checks:
      * ``transcribe_file``: first cue's WER against the expected answer < 0.1;
      * ``transcribe_youtube`` (skipped when ``is_pytube_detected_bot()`` is
        truthy): returns a non-empty string and an output path that exists;
      * ``transcribe_mic``: first cue's WER against the expected answer < 0.1.

    Every parametrized case currently disables VAD, BGM separation and
    diarization, so the diarization-prefixed answer is not exercised yet.
    """
    test_dir = os.path.join(WEBUI_DIR, "tests")
    sample_path = os.path.join(test_dir, "jfk.wav")
    if not os.path.exists(sample_path):
        download_file(TEST_FILE_DOWNLOAD_URL, test_dir)

    # With diarization enabled the transcript text is expected to carry a
    # speaker label prefix.
    expected = ("SPEAKER_00|" + TEST_ANSWER) if diarization else TEST_ANSWER

    inferencer = WhisperFactory.create_whisper_inference(whisper_type=whisper_type)
    print(
        f"Whisper Device : {inferencer.device}\n"
        f"BGM Separation Device: {inferencer.music_separator.device}\n"
        f"Diarization Device: {inferencer.diarizer.device}"
    )

    # Flattened positional parameter list, splatted into every entry point.
    pipeline_params = TranscriptionPipelineParams(
        whisper=WhisperParams(
            model_size=TEST_WHISPER_MODEL,
            compute_type=inferencer.current_compute_type,
        ),
        vad=VadParams(vad_filter=vad_filter),
        bgm_separation=BGMSeparationParams(
            is_separate_bgm=bgm_separation,
            enable_offload=True,
        ),
        diarization=DiarizationParams(is_diarize=diarization),
    ).to_list()

    # File input: second return value is a sequence of written subtitle paths.
    _, written_files = inferencer.transcribe_file(
        [sample_path], None, "SRT", False, gr.Progress(), *pipeline_params,
    )
    assert calculate_wer(expected, _first_cue_text(written_files[0])) < 0.1

    if not is_pytube_detected_bot():
        yt_text, yt_path = inferencer.transcribe_youtube(
            TEST_YOUTUBE_URL, "SRT", False, gr.Progress(), *pipeline_params,
        )
        assert isinstance(yt_text, str) and yt_text
        assert os.path.exists(yt_path)

    _, mic_path = inferencer.transcribe_mic(
        sample_path, "SRT", False, gr.Progress(), *pipeline_params,
    )
    assert calculate_wer(expected, _first_cue_text(mic_path)) < 0.1
def download_file(url, save_dir, timeout=60):
    """Download ``url`` into ``save_dir`` unless the test fixture already exists.

    Nothing is fetched when ``TEST_FILE_PATH`` (from ``test_config``) already
    exists. Otherwise the file is saved under the last segment of the URL's
    path (any query string or fragment is ignored).

    Args:
        url: HTTP(S) URL to fetch.
        save_dir: Directory to write into; created if it does not exist.
        timeout: Seconds passed to ``requests.get`` so a stalled server
            fails the test instead of hanging it.

    Raises:
        requests.HTTPError: If the server answers with an error status. It is
            raised before anything is written, so an error page is never
            saved in place of the fixture.
    """
    from urllib.parse import urlsplit

    if os.path.exists(TEST_FILE_PATH):
        return

    os.makedirs(save_dir, exist_ok=True)

    # Use only the URL path so "?download=true"-style suffixes don't leak
    # into the file name; identical to url.split("/")[-1] for plain URLs.
    file_name = urlsplit(url).path.split("/")[-1]
    file_path = os.path.join(save_dir, file_name)

    response = requests.get(url, timeout=timeout)
    # Without this an error body would be written to disk, and callers that
    # check for the file's existence could skip re-downloading afterwards.
    response.raise_for_status()

    with open(file_path, "wb") as file:
        file.write(response.content)
    print(f"File downloaded to: {file_path}")
|