File size: 10,306 Bytes
8d120bf
 
05a2178
3fadc6e
8d120bf
3fadc6e
 
8d120bf
 
 
 
 
05a2178
 
883c794
c52f09b
 
7ce6041
 
4514e2e
7ce6041
533d92e
 
 
6a308c6
 
 
93c4867
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
05a2178
 
883c794
 
 
05a2178
f288ceb
71950a8
883c794
05a2178
883c794
533d92e
883c794
fdd892b
 
 
 
 
883c794
533d92e
fdd892b
 
883c794
533d92e
bbbf06e
 
 
fdd892b
f288ceb
bc0cb58
f288ceb
084aa80
 
 
7f502b4
 
084aa80
bc0cb58
 
 
084aa80
bc0cb58
084aa80
7f502b4
 
bc0cb58
bbbf06e
 
 
084aa80
bbbf06e
f288ceb
 
bbbf06e
fdd892b
 
71950a8
fdd892b
883c794
533d92e
fdd892b
883c794
 
6a308c6
fdd892b
 
 
6a308c6
fdd892b
883c794
 
 
71950a8
fdd892b
71950a8
fdd892b
 
883c794
fdd892b
 
 
 
 
3fadc6e
883c794
 
 
 
fdd892b
 
883c794
fdd892b
 
 
3fadc6e
8f5637c
 
 
 
 
 
 
fdd892b
 
3fadc6e
fdd892b
8d120bf
883c794
 
 
 
 
 
 
 
 
 
 
6a308c6
883c794
 
 
 
 
 
3fadc6e
883c794
 
3fadc6e
883c794
 
 
 
7ce6041
883c794
05a2178
 
883c794
 
05a2178
71950a8
 
 
05a2178
084aa80
38cc8a7
71950a8
 
93c4867
883c794
084aa80
883c794
71950a8
 
8d120bf
71950a8
 
 
bc0cb58
b1d4eff
7f502b4
 
3fadc6e
8d120bf
3fadc6e
8d120bf
3fadc6e
7ce6041
d5154e9
05a2178
71950a8
883c794
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
from typing import Iterator

from io import StringIO
import os
import pathlib
import tempfile

# External programs
import whisper
import ffmpeg

# UI
import gradio as gr

from src.download import ExceededMaximumDuration, download_url
from src.utils import slugify, write_srt, write_vtt
from src.vad import VadPeriodicTranscription, VadSileroTranscription

# Longest input audio (in seconds) accepted from uploads/recordings.
# Any value <= 0 (e.g. -1) switches the length check off.
DEFAULT_INPUT_AUDIO_MAX_DURATION = 600

# Remove each uploaded/downloaded source file once it has been processed, to save disk space.
DELETE_UPLOADED_FILES = True

# Gradio appears to truncate file names without preserving the extension, so the file
# prefix (stem) is truncated here ourselves and the extension re-appended afterwards.
MAX_FILE_PREFIX_LENGTH = 17

# Language names offered in the UI dropdown; the lower-cased name is passed to Whisper.
LANGUAGES = [
    "English", "Chinese", "German", "Spanish", "Russian", "Korean", "French", "Japanese", "Portuguese", "Turkish",
    "Polish", "Catalan", "Dutch", "Arabic", "Swedish", "Italian", "Indonesian", "Hindi", "Finnish", "Vietnamese",
    "Hebrew", "Ukrainian", "Greek", "Malay", "Czech", "Romanian", "Danish", "Hungarian", "Tamil", "Norwegian",
    "Thai", "Urdu", "Croatian", "Bulgarian", "Lithuanian", "Latin", "Maori", "Malayalam", "Welsh", "Slovak",
    "Telugu", "Persian", "Latvian", "Bengali", "Serbian", "Azerbaijani", "Slovenian", "Kannada", "Estonian", "Macedonian",
    "Breton", "Basque", "Icelandic", "Armenian", "Nepali", "Mongolian", "Bosnian", "Kazakh", "Albanian", "Swahili",
    "Galician", "Marathi", "Punjabi", "Sinhala", "Khmer", "Shona", "Yoruba", "Somali", "Afrikaans", "Occitan",
    "Georgian", "Belarusian", "Tajik", "Sindhi", "Gujarati", "Amharic", "Yiddish", "Lao", "Uzbek", "Faroese",
    "Haitian Creole", "Pashto", "Turkmen", "Nynorsk", "Maltese", "Sanskrit", "Luxembourgish", "Myanmar", "Tibetan", "Tagalog",
    "Malagasy", "Assamese", "Tatar", "Hawaiian", "Lingala", "Hausa", "Bashkir", "Javanese", "Sundanese",
]

class WhisperTranscriber:
    """Transcribe/translate audio with Whisper and render the result as subtitles and plain text.

    Whisper models are loaded on demand and cached by model name. The Silero VAD model is
    created lazily on the first Silero VAD request and then passed as ``copy=`` to every
    per-request ``VadSileroTranscription``.
    """

    def __init__(self, inputAudioMaxDuration: float = DEFAULT_INPUT_AUDIO_MAX_DURATION, deleteUploadedFiles: bool = DELETE_UPLOADED_FILES):
        """Create a transcriber.

        Args:
            inputAudioMaxDuration: Maximum input length in seconds. For uploaded or recorded
                files a value <= 0 disables the check; for URLs it is passed on to ``download_url``.
            deleteUploadedFiles: Delete the source audio file after each request.
        """
        # Loaded Whisper models, keyed by model name.
        self.model_cache = dict()

        # Silero VAD model; created on the first Silero VAD request (see __get_vad_model).
        self.vad_model = None
        self.inputAudioMaxDuration = inputAudioMaxDuration
        self.deleteUploadedFiles = deleteUploadedFiles

    def transcribe_file(self, modelName, languageName, urlData, uploadFile, microphoneData, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding):
        """Transcribe (or translate) one audio source and write subtitle/transcript files.

        The source is chosen in priority order: URL, then uploaded file, then microphone recording.

        Args:
            modelName: Whisper model name (e.g. "base", "medium"); "base" is used when None.
            languageName: Language name (e.g. "English"); empty or None lets Whisper detect it.
            urlData: URL to download audio from, or empty/None.
            uploadFile: Path of an uploaded audio file, or None.
            microphoneData: Path of a microphone recording, or None.
            task: Whisper task ("transcribe" or "translate" in the UI).
            vad: "silero-vad", "silero-vad-skip-gaps", "periodic-vad"; anything else means no VAD.
            vadMergeWindow: Silero VAD max_silent_period (seconds).
            vadMaxMergeSize: Silero VAD max_merge_size, or the periodic VAD period (seconds).
            vadPadding: Silero VAD padding added on both sides of each segment (seconds).

        Returns:
            ``(files, text, vtt)``: paths of the generated .srt/.vtt/.txt files, the transcript
            text and the WebVTT subtitles. If the input is too long, ``([], error_message, "[ERROR]")``.

        Raises:
            ValueError: If no URL, upload or microphone recording was provided.
        """
        try:
            source, sourceName = self.__get_source(urlData, uploadFile, microphoneData)

            try:
                # Gradio passes None for an untouched dropdown; treat it like "auto-detect".
                selectedLanguage = languageName.lower() if languageName else None
                selectedModel = modelName if modelName is not None else "base"

                model = self.model_cache.get(selectedModel, None)

                if model is None:
                    model = whisper.load_model(selectedModel)
                    self.model_cache[selectedModel] = model

                # Callable for processing an audio file (or a VAD-selected part of it)
                whisperCallable = lambda audio : model.transcribe(audio, language=selectedLanguage, task=task)

                result = self.__transcribe_with_vad(source, whisperCallable, vad, vadMergeWindow, vadMaxMergeSize, vadPadding)

                text = result["text"]

                language = result["language"]
                languageMaxLineWidth = self.__get_max_line_width(language)

                print("Max line width " + str(languageMaxLineWidth))
                vtt = self.__get_subs(result["segments"], "vtt", languageMaxLineWidth)
                srt = self.__get_subs(result["segments"], "srt", languageMaxLineWidth)

                # Files that can be downloaded
                downloadDirectory = tempfile.mkdtemp()
                filePrefix = slugify(sourceName, allow_unicode=True)

                download = [
                    self.__create_file(srt, downloadDirectory, filePrefix + "-subs.srt"),
                    self.__create_file(vtt, downloadDirectory, filePrefix + "-subs.vtt"),
                    self.__create_file(text, downloadDirectory, filePrefix + "-transcript.txt"),
                ]

                return download, text, vtt

            finally:
                # Cleanup source
                if self.deleteUploadedFiles:
                    self.__delete_source(source)

        except ExceededMaximumDuration as e:
            return [], ("[ERROR]: Maximum remote video length is " + str(e.maxDuration) + "s, file was " + str(e.videoDuration) + "s"), "[ERROR]"

    def clear_cache(self):
        """Drop all cached Whisper models."""
        self.model_cache = dict()

    def __get_vad_model(self):
        """Return the shared Silero VAD model, creating it on first use."""
        if self.vad_model is None:
            self.vad_model = VadSileroTranscription()
        return self.vad_model

    def __transcribe_with_vad(self, source, whisperCallable, vad, vadMergeWindow, vadMaxMergeSize, vadPadding):
        """Run ``whisperCallable`` on ``source`` using the VAD strategy named by ``vad``."""
        if vad == 'silero-vad' or vad == 'silero-vad-skip-gaps':
            # 'silero-vad' also transcribes the non-speech gaps; 'silero-vad-skip-gaps' skips them.
            sileroVad = VadSileroTranscription(transcribe_non_speech = (vad == 'silero-vad'),
                            max_silent_period=vadMergeWindow, max_merge_size=vadMaxMergeSize,
                            segment_padding_left=vadPadding, segment_padding_right=vadPadding, copy=self.__get_vad_model())
            return sileroVad.transcribe(source, whisperCallable)
        elif vad == 'periodic-vad':
            # Very simple VAD: treat every vadMaxMergeSize-second period as speech. This makes it less
            # likely that Whisper enters an infinite loop, but it may break a sentence in the middle,
            # causing some artifacts.
            periodicVad = VadPeriodicTranscription(periodic_duration=vadMaxMergeSize)
            return periodicVad.transcribe(source, whisperCallable)
        else:
            # No VAD: hand the whole file to Whisper.
            return whisperCallable(source)

    def __delete_source(self, source: str):
        """Best-effort removal of a processed source file; failures are logged, not raised."""
        print("Deleting source file " + source)
        try:
            os.remove(source)
        except OSError as e:
            # A failed cleanup must not discard an otherwise successful transcription.
            print("Failed to delete source file " + source + ": " + str(e))

    def __get_source(self, urlData, uploadFile, microphoneData):
        """Resolve the input to a local audio path plus a shortened display name.

        Raises:
            ExceededMaximumDuration: If a local file is longer than ``inputAudioMaxDuration`` (when > 0).
            ValueError: If no input was provided at all.
        """
        if urlData:
            # Download from YouTube (or another URL supported by download_url)
            source = download_url(urlData, self.inputAudioMaxDuration)
        else:
            # File input
            source = uploadFile if uploadFile is not None else microphoneData

            if source is None:
                raise ValueError("No audio input: provide a URL, upload a file or record from the microphone")

            if self.inputAudioMaxDuration > 0:
                # Calculate audio length
                audioDuration = ffmpeg.probe(source)["format"]["duration"]

                if float(audioDuration) > self.inputAudioMaxDuration:
                    raise ExceededMaximumDuration(videoDuration=audioDuration, maxDuration=self.inputAudioMaxDuration, message="Video is too long")

        # Truncate the stem ourselves but keep the extension (see MAX_FILE_PREFIX_LENGTH).
        file_path = pathlib.Path(source)
        sourceName = file_path.stem[:MAX_FILE_PREFIX_LENGTH] + file_path.suffix

        return source, sourceName

    def __get_max_line_width(self, language: str) -> int:
        """Return the subtitle line width (in characters) to use for a detected language."""
        if (language and language.lower() in ["japanese", "ja", "chinese", "zh"]):
            # Chinese characters and kana are wider, so limit line length to 40 characters
            return 40
        else:
            # TODO: Add more languages
            # 80 latin characters should fit on a 1080p/720p screen
            return 80

    def __get_subs(self, segments: Iterator[dict], format: str, maxLineWidth: int) -> str:
        """Render Whisper segments as a subtitle string in ``format`` ("vtt" or "srt")."""
        segmentStream = StringIO()

        if format == 'vtt':
            write_vtt(segments, file=segmentStream, maxLineWidth=maxLineWidth)
        elif format == 'srt':
            write_srt(segments, file=segmentStream, maxLineWidth=maxLineWidth)
        else:
            raise Exception("Unknown format " + format)

        return segmentStream.getvalue()

    def __create_file(self, text: str, directory: str, fileName: str) -> str:
        """Write ``text`` as UTF-8 to ``directory/fileName`` and return that path."""
        filePath = os.path.join(directory, fileName)
        with open(filePath, 'w+', encoding="utf-8") as file:
            file.write(text)

        return filePath


def create_ui(inputAudioMaxDuration, share=False, server_name: str = None):
    """Build the Gradio interface around a WhisperTranscriber and launch it.

    Args:
        inputAudioMaxDuration: Maximum input length in seconds (<= 0 disables the limit);
            shown in the description when positive.
        share: Forwarded to ``demo.launch`` (Gradio's public share link option).
        server_name: Forwarded to ``demo.launch`` as the host name to bind.
    """
    ui = WhisperTranscriber(inputAudioMaxDuration)

    # Each fragment ends with exactly one space so the joined sentences have no double spaces.
    ui_description = "Whisper is a general-purpose speech recognition model. It is trained on a large dataset of diverse "
    ui_description += "audio and is also a multi-task model that can perform multilingual speech recognition "
    ui_description += "as well as speech translation and language identification. "

    ui_description += "\n\n\n\nFor longer audio files (>10 minutes), it is recommended that you select Silero VAD (Voice Activity Detector) in the VAD option."

    if inputAudioMaxDuration > 0:
        ui_description += "\n\n" + "Max audio file length: " + str(inputAudioMaxDuration) + " s"

    ui_article = "Read the [documentation here](https://huggingface.co/spaces/aadnk/whisper-webui/blob/main/docs/options.md)"

    # Input order must match the positional parameters of WhisperTranscriber.transcribe_file.
    demo = gr.Interface(fn=ui.transcribe_file, description=ui_description, article=ui_article, inputs=[
        gr.Dropdown(choices=["tiny", "base", "small", "medium", "large"], value="medium", label="Model"),
        # No default: an empty selection lets Whisper detect the language.
        gr.Dropdown(choices=sorted(LANGUAGES), label="Language"),
        gr.Text(label="URL (YouTube, etc.)"),
        gr.Audio(source="upload", type="filepath", label="Upload Audio"),
        gr.Audio(source="microphone", type="filepath", label="Microphone Input"),
        # Explicit defaults so the handler never receives None for these options.
        gr.Dropdown(choices=["transcribe", "translate"], value="transcribe", label="Task"),
        gr.Dropdown(choices=["none", "silero-vad", "silero-vad-skip-gaps", "periodic-vad"], value="none", label="VAD"),
        gr.Number(label="VAD - Merge Window (s)", precision=0, value=5),
        gr.Number(label="VAD - Max Merge Size (s)", precision=0, value=150),
        gr.Number(label="VAD - Padding (s)", precision=None, value=1)
    ], outputs=[
        gr.File(label="Download"),
        gr.Text(label="Transcription"),
        gr.Text(label="Segments")
    ])

    demo.launch(share=share, server_name=server_name)

# Running this module directly starts the web UI with the default input-length limit.
if __name__ == "__main__":
    create_ui(inputAudioMaxDuration=DEFAULT_INPUT_AUDIO_MAX_DURATION)