LAP-DEV committed on
Commit 031b7fd · verified · 1 Parent(s): 904a73a

Upload whisper_base.py

Files changed (1)
  1. modules/whisper/whisper_base.py +567 -0
modules/whisper/whisper_base.py ADDED
@@ -0,0 +1,567 @@
import os
import torch
import whisper
import gradio as gr
import torchaudio
from abc import ABC, abstractmethod
from typing import BinaryIO, Union, Tuple, List, Optional
import numpy as np
from datetime import datetime
from faster_whisper.vad import VadOptions
from dataclasses import astuple

from modules.uvr.music_separator import MusicSeparator
from modules.utils.paths import (WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR, DEFAULT_PARAMETERS_CONFIG_PATH,
                                 UVR_MODELS_DIR)
from modules.utils.subtitle_manager import get_srt, get_vtt, get_txt, write_file, safe_filename
from modules.utils.youtube_manager import get_ytdata, get_ytaudio
from modules.utils.files_manager import get_media_files, format_gradio_files, load_yaml, save_yaml
from modules.whisper.whisper_parameter import *
from modules.diarize.diarizer import Diarizer
from modules.vad.silero_vad import SileroVAD


class WhisperBase(ABC):
    def __init__(self,
                 model_dir: str = WHISPER_MODELS_DIR,
                 diarization_model_dir: str = DIARIZATION_MODELS_DIR,
                 uvr_model_dir: str = UVR_MODELS_DIR,
                 output_dir: str = OUTPUT_DIR,
                 ):
        self.model_dir = model_dir
        self.output_dir = output_dir
        os.makedirs(self.output_dir, exist_ok=True)
        os.makedirs(self.model_dir, exist_ok=True)
        self.diarizer = Diarizer(
            model_dir=diarization_model_dir
        )
        self.vad = SileroVAD()
        self.music_separator = MusicSeparator(
            model_dir=uvr_model_dir,
            output_dir=os.path.join(output_dir, "UVR")
        )

        self.model = None
        self.current_model_size = None
        self.available_models = whisper.available_models()
        self.available_langs = sorted(list(whisper.tokenizer.LANGUAGES.values()))
        # self.translatable_models = ["large", "large-v1", "large-v2", "large-v3"]
        self.translatable_models = whisper.available_models()
        self.device = self.get_device()
        self.available_compute_types = ["float16", "float32"]
        self.current_compute_type = "float16" if self.device == "cuda" else "float32"

    @abstractmethod
    def transcribe(self,
                   audio: Union[str, BinaryIO, np.ndarray],
                   progress: gr.Progress = gr.Progress(),
                   *whisper_params,
                   ):
        """Run inference with the Whisper model to transcribe the audio"""
        pass

    @abstractmethod
    def update_model(self,
                     model_size: str,
                     compute_type: str,
                     progress: gr.Progress = gr.Progress()
                     ):
        """Initialize the Whisper model"""
        pass

    def run(self,
            audio: Union[str, BinaryIO, np.ndarray],
            progress: gr.Progress = gr.Progress(),
            add_timestamp: bool = True,
            *whisper_params,
            ) -> Tuple[List[dict], float]:
        """
        Run transcription with conditional pre-processing and post-processing.
        If enabled, VAD is performed in pre-processing to remove noise from the audio input,
        and diarization is performed in post-processing.

        Parameters
        ----------
        audio: Union[str, BinaryIO, np.ndarray]
            Audio input. This can be a file path, a binary stream, or a numpy array.
        progress: gr.Progress
            Indicator to show progress directly in gradio.
        add_timestamp: bool
            Whether to add a timestamp at the end of the filename.
        *whisper_params: tuple
            Parameters related to Whisper. These are handled by the "WhisperParameters" data class.

        Returns
        ----------
        segments_result: List[dict]
            List of dicts that includes start and end timestamps and transcribed text.
        elapsed_time: float
            Elapsed time for running.
        """
        params = WhisperParameters.as_value(*whisper_params)

        self.cache_parameters(
            whisper_params=params,
            add_timestamp=add_timestamp
        )

        if params.lang is None:
            pass
        elif params.lang == "Automatic Detection":
            params.lang = None
        else:
            language_code_dict = {value: key for key, value in whisper.tokenizer.LANGUAGES.items()}
            params.lang = language_code_dict[params.lang]

        if params.is_bgm_separate:
            music, audio, _ = self.music_separator.separate(
                audio=audio,
                model_name=params.uvr_model_size,
                device=params.uvr_device,
                segment_size=params.uvr_segment_size,
                save_file=params.uvr_save_file,
                progress=progress
            )

            if audio.ndim >= 2:
                audio = audio.mean(axis=1)

            if self.music_separator.audio_info is None:
                origin_sample_rate = 16000
            else:
                origin_sample_rate = self.music_separator.audio_info.sample_rate
            audio = self.resample_audio(audio=audio, original_sample_rate=origin_sample_rate)

            if params.uvr_enable_offload:
                self.music_separator.offload()

        if params.vad_filter:
            # Explicit value set for float('inf') from gr.Number()
            if params.max_speech_duration_s is None or params.max_speech_duration_s >= 9999:
                params.max_speech_duration_s = float('inf')

            vad_options = VadOptions(
                threshold=params.threshold,
                min_speech_duration_ms=params.min_speech_duration_ms,
                max_speech_duration_s=params.max_speech_duration_s,
                min_silence_duration_ms=params.min_silence_duration_ms,
                speech_pad_ms=params.speech_pad_ms
            )

            audio, speech_chunks = self.vad.run(
                audio=audio,
                vad_parameters=vad_options,
                progress=progress
            )

        result, elapsed_time = self.transcribe(
            audio,
            progress,
            *astuple(params)
        )

        if params.vad_filter:
            result = self.vad.restore_speech_timestamps(
                segments=result,
                speech_chunks=speech_chunks,
            )

        if params.is_diarize:
            result, elapsed_time_diarization = self.diarizer.run(
                audio=audio,
                use_auth_token=params.hf_token,
                transcribed_result=result,
            )
            elapsed_time += elapsed_time_diarization
        return result, elapsed_time

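    # Usage note for `run()` above (illustrative, not part of this commit):
    # callers pass the flattened Gradio component values, which
    # `WhisperParameters.as_value(...)` reassembles, e.g.
    #
    #     segments, elapsed = pipeline.run("sample.wav", gr.Progress(), True, *flat_params)
    #
    # where `pipeline` is a concrete subclass instance and `flat_params` is the
    # tuple of UI values; both names here are hypothetical.
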
    def transcribe_file(self,
                        files: Optional[List] = None,
                        input_folder_path: Optional[str] = None,
                        file_format: str = "SRT",
                        add_timestamp: bool = True,
                        progress=gr.Progress(),
                        *whisper_params,
                        ) -> list:
        """
        Write subtitle files from input files

        Parameters
        ----------
        files: list
            List of files to transcribe from gr.Files()
        input_folder_path: str
            Input folder path to transcribe from gr.Textbox(). If this is provided, `files` will be ignored and
            this will be used instead.
        file_format: str
            Subtitle file format to write from gr.Dropdown(). Supported formats: [SRT, WebVTT, txt]
        add_timestamp: bool
            Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the subtitle filename.
        progress: gr.Progress
            Indicator to show progress directly in gradio.
        *whisper_params: tuple
            Parameters related to Whisper. These are handled by the "WhisperParameters" data class.

        Returns
        ----------
        result_str:
            Result of transcription to return to gr.Textbox()
        result_file_path:
            Output file paths to return to gr.Files()
        total_info:
            Per-file summary (input name, predicted language, duration) to return to gr.Textbox()
        """
        try:
            if input_folder_path:
                files = get_media_files(input_folder_path)
            if isinstance(files, str):
                files = [files]
            if files and isinstance(files[0], gr.utils.NamedString):
                files = [file.name for file in files]

            # Load model to detect language
            model = whisper.load_model("base")

            files_info = {}
            files_to_download = {}
            for file in files:
                # Detect language
                # params = WhisperParameters.as_value(*whisper_params)
                # model = whisper.load_model(params.model_size)
                mel = whisper.log_mel_spectrogram(whisper.pad_or_trim(whisper.load_audio(file))).to(model.device)
                _, probs = model.detect_language(mel)
                file_language = ""
                for key, value in whisper.tokenizer.LANGUAGES.items():
                    if key == str(max(probs, key=probs.get)):
                        file_language = value.capitalize()
                        break

                transcribed_segments, time_for_task = self.run(
                    file,
                    progress,
                    add_timestamp,
                    *whisper_params,
                )

                file_name, file_ext = os.path.splitext(os.path.basename(file))
                subtitle, file_path = self.generate_and_write_file(
                    file_name=file_name,
                    transcribed_segments=transcribed_segments,
                    add_timestamp=add_timestamp,
                    file_format=file_format,
                    output_dir=self.output_dir
                )
                files_info[file_name] = {"subtitle": subtitle, "time_for_task": time_for_task, "path": file_path,
                                         "lang": file_language, "input": (file_name + file_ext)}

                # Add output file as txt
                subtitle, file_path = self.generate_and_write_file(
                    file_name=file_name,
                    transcribed_segments=transcribed_segments,
                    add_timestamp=add_timestamp,
                    file_format="txt",
                    output_dir=self.output_dir
                )
                files_to_download[file_name + "_txt"] = {"path": file_path}

                # Add output file as srt
                subtitle, file_path = self.generate_and_write_file(
                    file_name=file_name,
                    transcribed_segments=transcribed_segments,
                    add_timestamp=add_timestamp,
                    file_format="srt",
                    output_dir=self.output_dir
                )
                files_to_download[file_name + "_srt"] = {"path": file_path}

            total_result = ''
            total_info = ''
            total_time = 0
            for file_name, info in files_info.items():
                total_result += f'{info["subtitle"]}'
                total_time += info["time_for_task"]
                # total_info += f'{info["lang"]}'
                total_info += f"Input file: {info['input']}\nLanguage prediction: {info['lang']}\n"

            # result_str = f"Processing of file '{file_name}{file_ext}' done in {self.format_time(total_time)}:\n\n{total_result}"
            total_info += f"\nTranscription duration: {self.format_time(total_time)}"
            result_str = total_result
            result_file_path = [info['path'] for info in files_to_download.values()]

            return [result_str, result_file_path, total_info]

        except Exception as e:
            print(f"Error transcribing file: {e}")
        finally:
            self.release_cuda_memory()

    def transcribe_mic(self,
                       mic_audio: str,
                       file_format: str = "SRT",
                       add_timestamp: bool = True,
                       progress=gr.Progress(),
                       *whisper_params,
                       ) -> list:
        """
        Write subtitle file from microphone input

        Parameters
        ----------
        mic_audio: str
            Audio file path from gr.Microphone()
        file_format: str
            Subtitle file format to write from gr.Dropdown(). Supported formats: [SRT, WebVTT, txt]
        add_timestamp: bool
            Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
        progress: gr.Progress
            Indicator to show progress directly in gradio.
        *whisper_params: tuple
            Parameters related to Whisper. These are handled by the "WhisperParameters" data class.

        Returns
        ----------
        result_str:
            Result of transcription to return to gr.Textbox()
        result_file_path:
            Output file path to return to gr.Files()
        """
        try:
            progress(0, desc="Loading Audio..")
            transcribed_segments, time_for_task = self.run(
                mic_audio,
                progress,
                add_timestamp,
                *whisper_params,
            )
            progress(1, desc="Completed!")

            subtitle, result_file_path = self.generate_and_write_file(
                file_name="Mic",
                transcribed_segments=transcribed_segments,
                add_timestamp=add_timestamp,
                file_format=file_format,
                output_dir=self.output_dir
            )

            result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
            return [result_str, result_file_path]
        except Exception as e:
            print(f"Error transcribing microphone audio: {e}")
        finally:
            self.release_cuda_memory()

    def transcribe_youtube(self,
                           youtube_link: str,
                           file_format: str = "SRT",
                           add_timestamp: bool = True,
                           progress=gr.Progress(),
                           *whisper_params,
                           ) -> list:
        """
        Write subtitle file from a YouTube video

        Parameters
        ----------
        youtube_link: str
            URL of the YouTube video to transcribe from gr.Textbox()
        file_format: str
            Subtitle file format to write from gr.Dropdown(). Supported formats: [SRT, WebVTT, txt]
        add_timestamp: bool
            Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
        progress: gr.Progress
            Indicator to show progress directly in gradio.
        *whisper_params: tuple
            Parameters related to Whisper. These are handled by the "WhisperParameters" data class.

        Returns
        ----------
        result_str:
            Result of transcription to return to gr.Textbox()
        result_file_path:
            Output file path to return to gr.Files()
        """
        try:
            progress(0, desc="Loading Audio from Youtube..")
            yt = get_ytdata(youtube_link)
            audio = get_ytaudio(yt)

            transcribed_segments, time_for_task = self.run(
                audio,
                progress,
                add_timestamp,
                *whisper_params,
            )

            progress(1, desc="Completed!")

            file_name = safe_filename(yt.title)
            subtitle, result_file_path = self.generate_and_write_file(
                file_name=file_name,
                transcribed_segments=transcribed_segments,
                add_timestamp=add_timestamp,
                file_format=file_format,
                output_dir=self.output_dir
            )
            result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"

            if os.path.exists(audio):
                os.remove(audio)

            return [result_str, result_file_path]

        except Exception as e:
            print(f"Error transcribing YouTube video: {e}")
        finally:
            self.release_cuda_memory()

    @staticmethod
    def generate_and_write_file(file_name: str,
                                transcribed_segments: list,
                                add_timestamp: bool,
                                file_format: str,
                                output_dir: str
                                ) -> Tuple[str, str]:
        """
        Writes subtitle file

        Parameters
        ----------
        file_name: str
            Output file name
        transcribed_segments: list
            Text segments transcribed from audio
        add_timestamp: bool
            Determines whether to add a timestamp to the end of the filename.
        file_format: str
            File format to write. Supported formats: [SRT, WebVTT, txt]
        output_dir: str
            Directory path of the output

        Returns
        ----------
        content: str
            Result of the transcription
        output_path: str
            Output file path
        """
        if add_timestamp:
            timestamp = datetime.now().strftime("%m%d%H%M%S")
            output_path = os.path.join(output_dir, f"{file_name} - {timestamp}")
        else:
            output_path = os.path.join(output_dir, f"{file_name}")

        file_format = file_format.strip().lower()
        if file_format == "srt":
            content = get_srt(transcribed_segments)
            output_path += '.srt'

        elif file_format == "webvtt":
            content = get_vtt(transcribed_segments)
            output_path += '.vtt'

        elif file_format == "txt":
            content = get_txt(transcribed_segments)
            output_path += '.txt'

        write_file(content, output_path)
        return content, output_path

    @staticmethod
    def format_time(elapsed_time: float) -> str:
        """
        Get {hours} {minutes} {seconds} time format string

        Parameters
        ----------
        elapsed_time: float
            Elapsed time for transcription

        Returns
        ----------
        Time format string
        """
        hours, rem = divmod(elapsed_time, 3600)
        minutes, seconds = divmod(rem, 60)

        time_str = ""
        if hours:
            time_str += f"{hours} hours "
        if minutes:
            time_str += f"{minutes} minutes "
        seconds = round(seconds)
        time_str += f"{seconds} seconds"

        return time_str.strip()

    @staticmethod
    def get_device():
        if torch.cuda.is_available():
            return "cuda"
        elif torch.backends.mps.is_available():
            if not WhisperBase.is_sparse_api_supported():
                # Device `SparseMPS` is not supported for now. See: https://github.com/pytorch/pytorch/issues/87886
                return "cpu"
            return "mps"
        else:
            return "cpu"

    @staticmethod
    def is_sparse_api_supported():
        if not torch.backends.mps.is_available():
            return False

        try:
            device = torch.device("mps")
            sparse_tensor = torch.sparse_coo_tensor(
                indices=torch.tensor([[0, 1], [2, 3]]),
                values=torch.tensor([1, 2]),
                size=(4, 4),
                device=device
            )
            return True
        except RuntimeError:
            return False

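    # Note on the two helpers above (descriptive, not part of this commit):
    # `get_device()` resolves to "cuda", "mps", or "cpu", and only reports
    # "mps" when the sparse-tensor probe in `is_sparse_api_supported()`
    # succeeds, since `SparseMPS` is unsupported per the linked PyTorch issue.
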
    @staticmethod
    def release_cuda_memory():
        """Release memory"""
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.reset_max_memory_allocated()

    @staticmethod
    def remove_input_files(file_paths: List[str]):
        """Remove gradio cached files"""
        if not file_paths:
            return

        for file_path in file_paths:
            if file_path and os.path.exists(file_path):
                os.remove(file_path)

    @staticmethod
    def cache_parameters(
        whisper_params: WhisperValues,
        add_timestamp: bool
    ):
        """Cache parameters to the yaml file"""
        cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH)
        cached_whisper_param = whisper_params.to_yaml()
        cached_yaml = {**cached_params, **cached_whisper_param}
        cached_yaml["whisper"]["add_timestamp"] = add_timestamp

        save_yaml(cached_yaml, DEFAULT_PARAMETERS_CONFIG_PATH)

    @staticmethod
    def resample_audio(audio: Union[str, np.ndarray],
                       new_sample_rate: int = 16000,
                       original_sample_rate: Optional[int] = None,
                       ) -> np.ndarray:
        """Resamples audio to a 16k sample rate, the standard for Whisper models"""
        if isinstance(audio, str):
            audio, original_sample_rate = torchaudio.load(audio)
        else:
            if original_sample_rate is None:
                raise ValueError("original_sample_rate must be provided when audio is a numpy array.")
            audio = torch.from_numpy(audio)
        resampler = torchaudio.transforms.Resample(orig_freq=original_sample_rate, new_freq=new_sample_rate)
        resampled_audio = resampler(audio).numpy()
        return resampled_audio
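
For reference, a minimal sketch of a concrete subclass satisfying the two abstract methods, assuming the plain openai-whisper API (`whisper.load_model`, `model.transcribe`); the class name `SimpleWhisperInference`, the `import time`, and the simplified handling of `*whisper_params` are illustrative assumptions, while this project's real subclasses unpack those values through `WhisperParameters`:

import time

class SimpleWhisperInference(WhisperBase):
    def update_model(self, model_size: str, compute_type: str, progress: gr.Progress = gr.Progress()):
        progress(0, desc="Loading model..")
        self.current_model_size = model_size
        self.current_compute_type = compute_type
        # download_root keeps weights inside the directory WhisperBase manages
        self.model = whisper.load_model(model_size, device=self.device, download_root=self.model_dir)

    def transcribe(self, audio, progress: gr.Progress = gr.Progress(), *whisper_params):
        # Simplified: ignores *whisper_params instead of unpacking WhisperParameters
        start_time = time.time()
        if self.model is None:
            self.update_model("base", self.current_compute_type, progress)
        result = self.model.transcribe(audio, fp16=(self.current_compute_type == "float16"))
        # Return (segments, elapsed_time) in the shape run() expects
        segments = [{"start": seg["start"], "end": seg["end"], "text": seg["text"]}
                    for seg in result["segments"]]
        return segments, time.time() - start_time

# e.g. segments, elapsed = SimpleWhisperInference().transcribe("sample.wav")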