File size: 15,735 Bytes
31f7bdb
8d120bf
95261ed
8d120bf
05a2178
3fadc6e
8d120bf
3fadc6e
31f7bdb
95261ed
31f7bdb
3fadc6e
8d120bf
 
 
 
05a2178
 
883c794
c52f09b
95261ed
7ce6041
 
4514e2e
7ce6041
533d92e
 
 
6a308c6
 
 
93c4867
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
05a2178
 
883c794
31f7bdb
 
95261ed
31f7bdb
 
05a2178
f288ceb
31f7bdb
 
05a2178
48d8572
533d92e
883c794
fdd892b
 
 
 
 
31f7bdb
533d92e
74b1efd
48d8572
74b1efd
 
fdd892b
74b1efd
fdd892b
74b1efd
71950a8
fdd892b
71950a8
fdd892b
 
883c794
fdd892b
 
 
 
 
3fadc6e
95261ed
48d8572
84fa1f8
 
 
 
 
 
74b1efd
95261ed
74b1efd
 
 
d906b98
5bbbb16
95261ed
74b1efd
d906b98
5bbbb16
95261ed
d906b98
 
5bbbb16
95261ed
74b1efd
 
 
5bbbb16
95261ed
 
 
74b1efd
31f7bdb
 
 
 
 
 
 
 
 
74b1efd
 
 
95261ed
31f7bdb
95261ed
 
 
31f7bdb
 
 
 
 
 
 
 
 
 
 
95261ed
84fa1f8
 
 
 
 
 
 
 
5bbbb16
d906b98
 
 
 
5bbbb16
d906b98
 
5bbbb16
d906b98
5bbbb16
d906b98
74b1efd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
883c794
31f7bdb
74b1efd
883c794
 
fdd892b
 
74b1efd
fdd892b
 
 
3fadc6e
8f5637c
 
 
 
 
 
 
fdd892b
 
3fadc6e
fdd892b
8d120bf
883c794
 
 
 
 
 
 
 
 
 
 
6a308c6
883c794
 
 
 
 
 
3fadc6e
883c794
 
3fadc6e
883c794
 
 
 
7ce6041
883c794
05a2178
31f7bdb
 
 
 
 
 
05a2178
31f7bdb
 
 
05a2178
95261ed
 
 
71950a8
 
 
05a2178
f5884f3
38cc8a7
31f7bdb
 
93c4867
883c794
084aa80
74b1efd
31f7bdb
71950a8
8d120bf
71950a8
 
 
31f7bdb
d906b98
48d8572
 
d906b98
3fadc6e
8d120bf
3fadc6e
8d120bf
3fadc6e
7ce6041
31f7bdb
 
 
 
05a2178
71950a8
95261ed
31f7bdb
95261ed
 
 
31f7bdb
 
 
 
95261ed
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
import math
from typing import Iterator
import argparse

from io import StringIO
import os
import pathlib
import tempfile
from src.vadParallel import ParallelContext, ParallelTranscription

from src.whisperContainer import WhisperContainer, WhisperModelCache

# External programs
import ffmpeg

# UI
import gradio as gr

from src.download import ExceededMaximumDuration, download_url
from src.utils import slugify, write_srt, write_vtt
from src.vad import AbstractTranscription, NonSpeechStrategy, PeriodicTranscriptionConfig, TranscriptionConfig, VadPeriodicTranscription, VadSileroTranscription

# Upper limit on the input audio length, in seconds (set to -1 to disable)
DEFAULT_INPUT_AUDIO_MAX_DURATION = 600

# When True, uploaded files are deleted automatically to save disk space
DELETE_UPLOADED_FILES = True

# Gradio seems to truncate file names without keeping the extension, so the
# file prefix is truncated to this many characters here instead
MAX_FILE_PREFIX_LENGTH = 17

# Language names that can be selected for transcription
LANGUAGES = [
    "English", "Chinese", "German", "Spanish", "Russian",
    "Korean", "French", "Japanese", "Portuguese", "Turkish",
    "Polish", "Catalan", "Dutch", "Arabic", "Swedish",
    "Italian", "Indonesian", "Hindi", "Finnish", "Vietnamese",
    "Hebrew", "Ukrainian", "Greek", "Malay", "Czech",
    "Romanian", "Danish", "Hungarian", "Tamil", "Norwegian",
    "Thai", "Urdu", "Croatian", "Bulgarian", "Lithuanian",
    "Latin", "Maori", "Malayalam", "Welsh", "Slovak",
    "Telugu", "Persian", "Latvian", "Bengali", "Serbian",
    "Azerbaijani", "Slovenian", "Kannada", "Estonian", "Macedonian",
    "Breton", "Basque", "Icelandic", "Armenian", "Nepali",
    "Mongolian", "Bosnian", "Kazakh", "Albanian", "Swahili",
    "Galician", "Marathi", "Punjabi", "Sinhala", "Khmer",
    "Shona", "Yoruba", "Somali", "Afrikaans", "Occitan",
    "Georgian", "Belarusian", "Tajik", "Sindhi", "Gujarati",
    "Amharic", "Yiddish", "Lao", "Uzbek", "Faroese",
    "Haitian Creole", "Pashto", "Turkmen", "Nynorsk", "Maltese",
    "Sanskrit", "Luxembourgish", "Myanmar", "Tibetan", "Tagalog",
    "Malagasy", "Assamese", "Tatar", "Hawaiian", "Lingala",
    "Hausa", "Bashkir", "Javanese", "Sundanese",
]

class WhisperTranscriber:
    """Connects the web UI callback to the Whisper and VAD transcription helpers.

    An instance owns the shared model cache, the lazily created Silero VAD
    model and the optional parallel-processing context. ``transcribe_webui``
    is the entry point used as the UI callback; ``close`` releases the
    cached resources.
    """

    def __init__(self, input_audio_max_duration: float = DEFAULT_INPUT_AUDIO_MAX_DURATION, vad_process_timeout: float = None, delete_uploaded_files: bool = DELETE_UPLOADED_FILES):
        """Create a transcriber.

        Args:
            input_audio_max_duration: Maximum accepted audio length in seconds.
                It is forwarded to ``download_url`` for URLs and checked with
                ffmpeg for local files; the local check is skipped when the
                value is not greater than zero.
            vad_process_timeout: Passed as ``auto_cleanup_timeout_seconds`` to
                the ``ParallelContext`` created for parallel processing.
            delete_uploaded_files: When true, the source file is removed once
                ``transcribe_webui`` has finished with it.
        """
        self.model_cache = WhisperModelCache()
        # Set by the caller to a list of device names to enable parallel processing
        self.parallel_device_list = None
        self.parallel_context = None
        self.vad_process_timeout = vad_process_timeout

        # The Silero VAD model is created on first use (see _create_silero_config)
        self.vad_model = None
        self.inputAudioMaxDuration = input_audio_max_duration
        self.deleteUploadedFiles = delete_uploaded_files

    def transcribe_webui(self, modelName, languageName, urlData, uploadFile, microphoneData, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow):
        """Transcribe the audio selected in the UI and write the result files.

        The source is taken from ``urlData`` if set, otherwise from
        ``uploadFile``, otherwise from ``microphoneData``.

        Returns:
            A tuple ``(files, text, vtt)``: the list of written file paths,
            the plain transcript and the VTT subtitles. If the source exceeds
            the maximum duration, ``([], <error message>, "[ERROR]")`` is
            returned instead.

        Raises:
            ValueError: If no URL, upload or microphone input was provided.
        """
        try:
            source, sourceName = self.__get_source(urlData, uploadFile, microphoneData)

            try:
                # An unset language may arrive as None or as an empty string;
                # both mean "no language selected"
                selectedLanguage = languageName.lower() if languageName else None
                selectedModel = modelName if modelName is not None else "base"

                model = WhisperContainer(model_name=selectedModel, cache=self.model_cache)

                # Execute whisper
                result = self.transcribe_file(model, source, selectedLanguage, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow)

                # Write result to a fresh temporary directory
                downloadDirectory = tempfile.mkdtemp()

                filePrefix = slugify(sourceName, allow_unicode=True)
                download, text, vtt = self.write_result(result, filePrefix, downloadDirectory)

                return download, text, vtt

            finally:
                # Cleanup source, whether or not the transcription succeeded
                if self.deleteUploadedFiles:
                    print("Deleting source file " + source)
                    os.remove(source)

        except ExceededMaximumDuration as e:
            return [], ("[ERROR]: Maximum remote video length is " + str(e.maxDuration) + "s, file was " + str(e.videoDuration) + "s"), "[ERROR]"

    def transcribe_file(self, model: WhisperContainer, audio_path: str, language: str, task: str = None, vad: str = None, 
                        vadMergeWindow: float = 5, vadMaxMergeSize: float = 150, vadPadding: float = 1, vadPromptWindow: float = 1, **decodeOptions: dict):
        """Transcribe ``audio_path`` with ``model`` using the selected VAD mode.

        Args:
            model: Container used to create the Whisper callback.
            audio_path: Path of the audio file to transcribe.
            language: Language passed to the Whisper callback (may be None).
            task: Task passed to the Whisper callback; a ``task`` entry in
                ``decodeOptions`` takes precedence over this argument.
            vad: One of ``'silero-vad'``, ``'silero-vad-skip-gaps'``,
                ``'silero-vad-expand-into-gaps'`` or ``'periodic-vad'``. Any
                other value runs Whisper without VAD (or with a single
                unbounded period when parallel devices are configured).
            vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow:
                Forwarded to the VAD configuration.
            **decodeOptions: Remaining options forwarded to
                ``model.create_callback``; ``initial_prompt`` and ``task`` are
                extracted first.

        Returns:
            The result produced by the VAD transcription or the Whisper callback.
        """
        initial_prompt = decodeOptions.pop('initial_prompt', None)

        # An explicit 'task' in the decode options overrides the argument
        task = decodeOptions.pop('task', task)

        # Callable for processing an audio file
        whisperCallable = model.create_callback(language, task, initial_prompt, **decodeOptions)

        # The Silero modes differ only in how non-speech gaps are handled:
        #   silero-vad                  - gaps are transcribed as their own segments
        #   silero-vad-skip-gaps        - gaps are simply ignored
        #   silero-vad-expand-into-gaps - speech segments are expanded into the gaps
        silero_strategies = {
            'silero-vad': NonSpeechStrategy.CREATE_SEGMENT,
            'silero-vad-skip-gaps': NonSpeechStrategy.SKIP,
            'silero-vad-expand-into-gaps': NonSpeechStrategy.EXPAND_SEGMENT,
        }

        if vad in silero_strategies:
            # Creating the config also makes sure self.vad_model is initialised
            silero_config = self._create_silero_config(silero_strategies[vad], vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow)
            return self.process_vad(audio_path, whisperCallable, self.vad_model, silero_config)

        if vad == 'periodic-vad':
            # Very simple VAD - mark every period of vadMaxMergeSize seconds as speech. This makes it less likely
            # that Whisper enters an infinite loop, but it may create a break in the middle of a sentence,
            # causing some artifacts.
            periodic_vad = VadPeriodicTranscription()
            period_config = PeriodicTranscriptionConfig(periodic_duration=vadMaxMergeSize, max_prompt_window=vadPromptWindow)
            return self.process_vad(audio_path, whisperCallable, periodic_vad, period_config)

        if self._has_parallel_devices():
            # Use a simple period transcription instead, as we need to use the parallel context
            periodic_vad = VadPeriodicTranscription()
            period_config = PeriodicTranscriptionConfig(periodic_duration=math.inf, max_prompt_window=1)
            return self.process_vad(audio_path, whisperCallable, periodic_vad, period_config)

        # No VAD: run Whisper directly on the whole file
        return whisperCallable(audio_path, 0, None, None)

    def process_vad(self, audio_path, whisperCallable, vadModel: AbstractTranscription, vadConfig: TranscriptionConfig):
        """Run ``vadModel`` over ``audio_path``, in parallel if devices are configured.

        Without parallel devices the VAD model transcribes directly. Otherwise
        a ``ParallelContext`` is created on first use and the work is handed
        to ``ParallelTranscription.transcribe_parallel``.
        """
        if (not self._has_parallel_devices()):
            # No parallel devices, so just run the VAD and Whisper in sequence
            return vadModel.transcribe(audio_path, whisperCallable, vadConfig)

        # Create parallel context if needed
        if (self.parallel_context is None):
            # One process per device; the context's auto-cleanup timeout is the configured vad_process_timeout
            self.parallel_context = ParallelContext(num_processes=len(self.parallel_device_list), auto_cleanup_timeout_seconds=self.vad_process_timeout)

        parallel_vad = ParallelTranscription()
        return parallel_vad.transcribe_parallel(transcription=vadModel, audio=audio_path, whisperCallable=whisperCallable,  
                                                config=vadConfig, devices=self.parallel_device_list, parallel_context=self.parallel_context) 

    def _has_parallel_devices(self):
        """Return True if a non-empty parallel device list has been set."""
        return self.parallel_device_list is not None and len(self.parallel_device_list) > 0

    def _concat_prompt(self, prompt1, prompt2):
        """Join two prompts with a space; if either is None, return the other."""
        if (prompt1 is None):
            return prompt2
        elif (prompt2 is None):
            return prompt1
        else:
            return prompt1 + " " + prompt2

    def _create_silero_config(self, non_speech_strategy: NonSpeechStrategy, vadMergeWindow: float = 5, vadMaxMergeSize: float = 150, vadPadding: float = 1, vadPromptWindow: float = 1):
        """Build a ``TranscriptionConfig`` for Silero VAD.

        As a side effect, the Silero VAD model is created and stored in
        ``self.vad_model`` if it does not exist yet. ``vadPadding`` is applied
        to both the left and the right side of each segment.
        """
        # Use Silero VAD 
        if (self.vad_model is None):
            self.vad_model = VadSileroTranscription()

        config = TranscriptionConfig(non_speech_strategy = non_speech_strategy, 
                max_silent_period=vadMergeWindow, max_merge_size=vadMaxMergeSize, 
                segment_padding_left=vadPadding, segment_padding_right=vadPadding, 
                max_prompt_window=vadPromptWindow)

        return config

    def write_result(self, result: dict, source_name: str, output_dir: str):
        """Write the SRT, VTT and plain-text outputs for ``result``.

        Args:
            result: Transcription result; its ``"text"``, ``"language"`` and
                ``"segments"`` entries are read.
            source_name: Prefix used for the generated file names.
            output_dir: Target directory; created if it does not exist.

        Returns:
            A tuple ``(output_files, text, vtt)`` where ``output_files`` lists
            the paths of the ``.srt``, ``.vtt`` and ``.txt`` files, in that order.
        """
        # exist_ok avoids the race between checking for and creating the directory
        os.makedirs(output_dir, exist_ok=True)

        text = result["text"]
        language = result["language"]
        languageMaxLineWidth = self.__get_max_line_width(language)

        print("Max line width " + str(languageMaxLineWidth))
        vtt = self.__get_subs(result["segments"], "vtt", languageMaxLineWidth)
        srt = self.__get_subs(result["segments"], "srt", languageMaxLineWidth)

        output_files = [
            self.__create_file(srt, output_dir, source_name + "-subs.srt"),
            self.__create_file(vtt, output_dir, source_name + "-subs.vtt"),
            self.__create_file(text, output_dir, source_name + "-transcript.txt"),
        ]

        return output_files, text, vtt

    def clear_cache(self):
        """Clear the model cache and drop the cached VAD model."""
        self.model_cache.clear()
        self.vad_model = None

    def __get_source(self, urlData, uploadFile, microphoneData):
        """Resolve the audio source and a display name for it.

        Returns:
            A tuple ``(source, sourceName)``: the path of the audio file and
            its file name with the stem truncated to ``MAX_FILE_PREFIX_LENGTH``
            characters (the suffix is kept).

        Raises:
            ValueError: If no URL, upload or microphone input was provided.
            ExceededMaximumDuration: If a local file is longer than the
                configured maximum duration.
        """
        if urlData:
            # Download from YouTube
            source = download_url(urlData, self.inputAudioMaxDuration)[0]
        else:
            # File input
            source = uploadFile if uploadFile is not None else microphoneData

            if not source:
                # Fail early with a clear message instead of passing None on to ffmpeg/pathlib
                raise ValueError("No audio source was provided: specify a URL, upload a file or record from the microphone")

            if self.inputAudioMaxDuration > 0:
                # Calculate audio length
                audioDuration = ffmpeg.probe(source)["format"]["duration"]

                if float(audioDuration) > self.inputAudioMaxDuration:
                    raise ExceededMaximumDuration(videoDuration=audioDuration, maxDuration=self.inputAudioMaxDuration, message="Video is too long")

        file_path = pathlib.Path(source)
        sourceName = file_path.stem[:MAX_FILE_PREFIX_LENGTH] + file_path.suffix

        return source, sourceName

    def __get_max_line_width(self, language: str) -> int:
        """Return the maximum subtitle line width (in characters) for ``language``."""
        if (language and language.lower() in ["japanese", "ja", "chinese", "zh"]):
            # Chinese characters and kana are wider, so limit line length to 40 characters
            return 40
        else:
            # TODO: Add more languages
            # 80 latin characters should fit on a 1080p/720p screen
            return 80

    def __get_subs(self, segments: Iterator[dict], format: str, maxLineWidth: int) -> str:
        """Render ``segments`` as subtitle text.

        Args:
            segments: The segments to render.
            format: Either ``'vtt'`` or ``'srt'``.
            maxLineWidth: Maximum line width passed to the subtitle writer.

        Raises:
            Exception: If ``format`` is not one of the supported formats.
        """
        segmentStream = StringIO()

        if format == 'vtt':
            write_vtt(segments, file=segmentStream, maxLineWidth=maxLineWidth)
        elif format == 'srt':
            write_srt(segments, file=segmentStream, maxLineWidth=maxLineWidth)
        else:
            raise Exception("Unknown format " + format)

        # getvalue() returns the whole buffer regardless of the stream position
        return segmentStream.getvalue()

    def __create_file(self, text: str, directory: str, fileName: str) -> str:
        """Write ``text`` as UTF-8 to ``directory/fileName`` and return that path."""
        file_path = os.path.join(directory, fileName)

        # Write the text to a file
        with open(file_path, 'w+', encoding="utf-8") as file:
            file.write(text)

        return file_path

    def close(self):
        """Clear the caches and close the parallel context, if one was created."""
        self.clear_cache()

        if (self.parallel_context is not None):
            self.parallel_context.close()
            # Drop the closed context so that process_vad creates a fresh one if needed
            self.parallel_context = None


def create_ui(input_audio_max_duration, share=False, server_name: str = None, server_port: int = 7860, 
              default_model_name: str = "medium", default_vad: str = None, vad_parallel_devices: str = None, vad_process_timeout: float = None):
    """Build the Gradio interface, launch it, and close the transcriber afterwards.

    Args:
        input_audio_max_duration: Maximum audio length in seconds. It is passed
            to the transcriber and shown in the description when greater than zero.
        share: Forwarded to ``demo.launch``.
        server_name: Forwarded to ``demo.launch``.
        server_port: Forwarded to ``demo.launch``.
        default_model_name: Initial value of the "Model" dropdown.
        default_vad: Initial value of the "VAD" dropdown.
        vad_parallel_devices: Comma-separated device names used for parallel
            processing; an empty value or None leaves parallel processing off.
        vad_process_timeout: Forwarded to the transcriber.
    """
    ui = WhisperTranscriber(input_audio_max_duration, vad_process_timeout)

    # Specify a list of devices to use for parallel processing
    ui.parallel_device_list = [ device.strip() for device in vad_parallel_devices.split(",") ] if vad_parallel_devices else None

    ui_description = "Whisper is a general-purpose speech recognition model. It is trained on a large dataset of diverse " 
    ui_description += " audio and is also a multi-task model that can perform multilingual speech recognition "
    ui_description += " as well as speech translation and language identification. "

    ui_description += "\n\n\n\nFor longer audio files (>10 minutes) not in English, it is recommended that you select Silero VAD (Voice Activity Detector) in the VAD option."

    if input_audio_max_duration > 0:
        ui_description += "\n\n" + "Max audio file length: " + str(input_audio_max_duration) + " s"

    ui_article = "Read the [documentation here](https://huggingface.co/spaces/aadnk/whisper-webui/blob/main/docs/options.md)"

    # The order of the inputs must match the parameters of transcribe_webui
    demo = gr.Interface(fn=ui.transcribe_webui, description=ui_description, article=ui_article, inputs=[
        gr.Dropdown(choices=["tiny", "base", "small", "medium", "large"], value=default_model_name, label="Model"),
        gr.Dropdown(choices=sorted(LANGUAGES), label="Language"),
        gr.Text(label="URL (YouTube, etc.)"),
        gr.Audio(source="upload", type="filepath", label="Upload Audio"), 
        gr.Audio(source="microphone", type="filepath", label="Microphone Input"),
        gr.Dropdown(choices=["transcribe", "translate"], label="Task"),
        gr.Dropdown(choices=["none", "silero-vad", "silero-vad-skip-gaps", "silero-vad-expand-into-gaps", "periodic-vad"], value=default_vad, label="VAD"),
        gr.Number(label="VAD - Merge Window (s)", precision=0, value=5),
        gr.Number(label="VAD - Max Merge Size (s)", precision=0, value=30),
        gr.Number(label="VAD - Padding (s)", precision=None, value=1),
        gr.Number(label="VAD - Prompt Window (s)", precision=None, value=3)
    ], outputs=[
        gr.File(label="Download"),
        gr.Text(label="Transcription"), 
        gr.Text(label="Segments")
    ])

    try:
        demo.launch(share=share, server_name=server_name, server_port=server_port)
    finally:
        # Clean up even if launching fails or is interrupted, so that the model
        # cache and any parallel context are always released
        ui.close()

def parse_bool_arg(value) -> bool:
    """Convert a command-line value to a bool.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    "False") as True, so the text is interpreted explicitly here. Matching is
    case-insensitive; an empty string counts as False.

    Raises:
        argparse.ArgumentTypeError: If the value is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value

    normalized = str(value).strip().lower()

    if normalized in ("true", "t", "yes", "y", "on", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "off", "0", ""):
        return False

    raise argparse.ArgumentTypeError("Expected a boolean value (for instance True or False), got " + repr(value))


def parse_optional_float_arg(value):
    """Convert a command-line value to a float, mapping "None" to None.

    Raises:
        argparse.ArgumentTypeError: If the value is neither "None" nor a number.
    """
    if value is None:
        return None
    if isinstance(value, str) and value.strip().lower() == "none":
        return None

    try:
        return float(value)
    except (TypeError, ValueError):
        raise argparse.ArgumentTypeError("Expected a number or None, got " + repr(value))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--input_audio_max_duration", type=int, default=600, help="Maximum audio file length in seconds, or -1 for no limit.")
    # type=bool would turn "--share False" into True, hence the explicit converter
    parser.add_argument("--share", type=parse_bool_arg, default=False, help="True to share the app on HuggingFace.")
    parser.add_argument("--server_name", type=str, default=None, help="The host or IP to bind to. If None, bind to localhost.")
    parser.add_argument("--server_port", type=int, default=7860, help="The port to bind to.")
    parser.add_argument("--default_model_name", type=str, default="medium", help="The default model name.")
    parser.add_argument("--default_vad", type=str, default="silero-vad", help="The default VAD.")
    parser.add_argument("--vad_parallel_devices", type=str, default="", help="A commma delimited list of CUDA devices to use for parallel processing. If None, disable parallel processing.")
    # The help text allows "None", which plain type=float cannot parse
    parser.add_argument("--vad_process_timeout", type=parse_optional_float_arg, default=1800.0, help="The number of seconds before inactivate processes are terminated. Use 0 to close processes immediately, or None for no timeout.")

    # The argument names match the parameters of create_ui
    args = vars(parser.parse_args())
    create_ui(**args)