File size: 3,210 Bytes
9fdfe53
501c404
c7bfcf2
9fdfe53
 
 
 
 
 
 
 
 
5f8203e
0e9733d
5f8203e
4c4ecfe
 
 
5f8203e
 
7ff9bba
 
 
 
0e9733d
7ff9bba
9fdfe53
 
 
 
 
c7bfcf2
 
 
 
9fdfe53
 
 
7b06fe8
bd79b3c
 
7b06fe8
 
9fdfe53
a88b526
501c404
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9fdfe53
c7bfcf2
9fdfe53
 
 
 
 
 
 
c7bfcf2
 
9fdfe53
85f37ae
c7bfcf2
85f37ae
 
 
 
 
 
 
c7bfcf2
9fdfe53
c7bfcf2
9fdfe53
 
 
 
 
 
c7bfcf2
 
9fdfe53
 
 
5f8203e
 
 
9fdfe53
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
from modules.whisper.whisper_factory import WhisperFactory
from modules.whisper.data_classes import *
from modules.utils.subtitle_manager import read_file
from modules.utils.paths import WEBUI_DIR
from test_config import *

import requests
import pytest
import gradio as gr
import os


@pytest.mark.parametrize(
    "whisper_type,vad_filter,bgm_separation,diarization",
    [
        (WhisperImpl.WHISPER.value, False, False, False),
        (WhisperImpl.FASTER_WHISPER.value, False, False, False),
        (WhisperImpl.INSANELY_FAST_WHISPER.value, False, False, False)
    ]
)
def test_transcribe(
    whisper_type: str,
    vad_filter: bool,
    bgm_separation: bool,
    diarization: bool,
):
    """Run the JFK sample through the file, YouTube and mic transcription paths.

    The sample ``jfk.wav`` is fetched into ``<WEBUI_DIR>/tests`` when it is
    missing. For the file and mic paths, line index 2 of the written SRT
    (presumably the first cue's text) has commas and periods stripped and must
    reach a WER below 0.1 against ``TEST_ANSWER``. When diarization is enabled,
    the expected text is prefixed with ``SPEAKER_00|``. The YouTube path only
    runs when ``is_pytube_detected_bot()`` is falsy, and then it only checks
    that a non-empty subtitle string and an existing output file came back.
    """
    tests_dir = os.path.join(WEBUI_DIR, "tests")
    sample_wav = os.path.join(tests_dir, "jfk.wav")
    if not os.path.exists(sample_wav):
        download_file(TEST_FILE_DOWNLOAD_URL, tests_dir)

    # Diarized output tags each cue with the speaker label.
    expected = "SPEAKER_00|" + TEST_ANSWER if diarization else TEST_ANSWER

    inferencer = WhisperFactory.create_whisper_inference(whisper_type=whisper_type)
    print(
        f"Whisper Device : {inferencer.device}\n"
        f"BGM Separation Device: {inferencer.music_separator.device}\n"
        f"Diarization Device: {inferencer.diarizer.device}"
    )

    pipeline_params = TranscriptionPipelineParams(
        whisper=WhisperParams(
            model_size=TEST_WHISPER_MODEL,
            compute_type=inferencer.current_compute_type,
        ),
        vad=VadParams(vad_filter=vad_filter),
        bgm_separation=BGMSeparationParams(is_separate_bgm=bgm_separation, enable_offload=True),
        diarization=DiarizationParams(is_diarize=diarization),
    )
    hparams = pipeline_params.to_list()

    def cue_text(srt_path):
        # Line index 2 of the SRT file; punctuation is dropped so WER only
        # measures word differences.
        srt_lines = read_file(srt_path).split("\n")
        return srt_lines[2].strip().replace(",", "").replace(".", "")

    _, output_paths = inferencer.transcribe_file(
        [sample_wav], None, "SRT", False, gr.Progress(), *hparams,
    )
    assert calculate_wer(expected, cue_text(output_paths[0])) < 0.1

    # YouTube may block automated pytube access; skip that path when detected.
    if not is_pytube_detected_bot():
        yt_subtitle, yt_path = inferencer.transcribe_youtube(
            TEST_YOUTUBE_URL, "SRT", False, gr.Progress(), *hparams,
        )
        assert isinstance(yt_subtitle, str) and yt_subtitle
        assert os.path.exists(yt_path)

    _, mic_path = inferencer.transcribe_mic(
        sample_wav, "SRT", False, gr.Progress(), *hparams,
    )
    assert calculate_wer(expected, cue_text(mic_path)) < 0.1


def download_file(url, save_dir):
    """Download ``url`` into ``save_dir`` unless the target file already exists.

    The local file name is the last component of the URL's path; any query
    string or fragment is ignored.

    Args:
        url: HTTP(S) URL of the file to fetch.
        save_dir: Directory to store the file in; created if it does not exist.

    Returns:
        The path of the downloaded (or already present) file.

    Raises:
        ValueError: If the URL path has no file name (e.g. it ends with "/").
        requests.HTTPError: If the server answers with an error status, so an
            error page is never saved in place of the expected file.
        requests.Timeout: If the server does not respond in time.
    """
    from urllib.parse import urlparse

    file_name = os.path.basename(urlparse(url).path)
    if not file_name:
        raise ValueError(f"Cannot derive a file name from URL: {url!r}")
    file_path = os.path.join(save_dir, file_name)

    # Skip the network round-trip when this exact file is already present.
    if os.path.exists(file_path):
        return file_path

    os.makedirs(save_dir, exist_ok=True)

    # A timeout keeps a stalled server from hanging the whole test run.
    response = requests.get(url, timeout=60)
    # Fail loudly instead of persisting an error page: because of the
    # existence check above, a bad file would otherwise be reused forever.
    response.raise_for_status()

    with open(file_path, "wb") as file:
        file.write(response.content)

    print(f"File downloaded to: {file_path}")
    return file_path