aadnk committed
Commit c0e541b
1 Parent(s): 60f71a4

Support parallel execution of Silero VAD

Files changed (7):
  1. app.py +28 -13
  2. cli.py +3 -1
  3. src/modelCache.py +17 -0
  4. src/vad.py +60 -28
  5. src/vadParallel.py +93 -23
  6. src/whisperContainer.py +12 -26
  7. tests/vad_test.py +2 -2
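The new vad_cpu_cores, vad_process_timeout and vad_parallel_devices options are threaded from the argument parsers into WhisperTranscriber. A minimal sketch (not part of the commit; the values are illustrative) of how they are wired together, mirroring create_ui() in app.py:

# Sketch only: the GPU device list and the CPU core count are independent, so
# Silero VAD pre-processing can run on several CPU cores even with one GPU.
from app import WhisperTranscriber

transcriber = WhisperTranscriber(
    input_audio_max_duration=600,   # seconds, as in DEFAULT_INPUT_AUDIO_MAX_DURATION
    vad_process_timeout=1800,       # idle seconds before worker pools are cleaned up
    vad_cpu_cores=4)                # CPU processes used for VAD pre-processing
transcriber.set_parallel_devices("0,1")  # comma delimited CUDA devices for Whisper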
app.py CHANGED
@@ -6,10 +6,9 @@ from io import StringIO
 import os
 import pathlib
 import tempfile
+from src.modelCache import ModelCache
 from src.vadParallel import ParallelContext, ParallelTranscription
 
-from src.whisperContainer import WhisperContainer, WhisperModelCache
-
 # External programs
 import ffmpeg
 
@@ -19,6 +18,7 @@ import gradio as gr
 from src.download import ExceededMaximumDuration, download_url
 from src.utils import slugify, write_srt, write_vtt
 from src.vad import AbstractTranscription, NonSpeechStrategy, PeriodicTranscriptionConfig, TranscriptionConfig, VadPeriodicTranscription, VadSileroTranscription
+from src.whisperContainer import WhisperContainer
 
 # Limitations (set to -1 to disable)
 DEFAULT_INPUT_AUDIO_MAX_DURATION = 600 # seconds
@@ -50,11 +50,13 @@ LANGUAGES = [
 ]
 
 class WhisperTranscriber:
-    def __init__(self, input_audio_max_duration: float = DEFAULT_INPUT_AUDIO_MAX_DURATION, vad_process_timeout: float = None, delete_uploaded_files: bool = DELETE_UPLOADED_FILES):
-        self.model_cache = WhisperModelCache()
+    def __init__(self, input_audio_max_duration: float = DEFAULT_INPUT_AUDIO_MAX_DURATION, vad_process_timeout: float = None, vad_cpu_cores: int = 1, delete_uploaded_files: bool = DELETE_UPLOADED_FILES):
+        self.model_cache = ModelCache()
         self.parallel_device_list = None
-        self.parallel_context = None
+        self.gpu_parallel_context = None
+        self.cpu_parallel_context = None
         self.vad_process_timeout = vad_process_timeout
+        self.vad_cpu_cores = vad_cpu_cores
 
         self.vad_model = None
         self.inputAudioMaxDuration = input_audio_max_duration
@@ -142,17 +144,27 @@ class WhisperTranscriber:
             # No parallel devices, so just run the VAD and Whisper in sequence
             return vadModel.transcribe(audio_path, whisperCallable, vadConfig)
 
+        gpu_devices = self.parallel_device_list
+
+        if (gpu_devices is None or len(gpu_devices) == 0):
+            # No GPU devices specified, pass the current environment variable to the first GPU process. This may be NULL.
+            gpu_devices = [os.environ.get("CUDA_VISIBLE_DEVICES", None)]
+
         # Create parallel context if needed
-        if (self.parallel_context is None):
+        if (self.gpu_parallel_context is None):
             # Create a context wih processes and automatically clear the pool after 1 hour of inactivity
-            self.parallel_context = ParallelContext(num_processes=len(self.parallel_device_list), auto_cleanup_timeout_seconds=self.vad_process_timeout)
+            self.gpu_parallel_context = ParallelContext(num_processes=len(gpu_devices), auto_cleanup_timeout_seconds=self.vad_process_timeout)
+        # We also need a CPU context for the VAD
+        if (self.cpu_parallel_context is None):
+            self.cpu_parallel_context = ParallelContext(num_processes=self.vad_cpu_cores, auto_cleanup_timeout_seconds=self.vad_process_timeout)
 
         parallel_vad = ParallelTranscription()
         return parallel_vad.transcribe_parallel(transcription=vadModel, audio=audio_path, whisperCallable=whisperCallable,
-                                                config=vadConfig, devices=self.parallel_device_list, parallel_context=self.parallel_context)
+                                                config=vadConfig, cpu_device_count=self.vad_cpu_cores, gpu_devices=gpu_devices,
+                                                cpu_parallel_context=self.cpu_parallel_context, gpu_parallel_context=self.gpu_parallel_context)
 
     def _has_parallel_devices(self):
-        return self.parallel_device_list is not None and len(self.parallel_device_list) > 0
+        return (self.parallel_device_list is not None and len(self.parallel_device_list) > 0) or self.vad_cpu_cores > 1
 
     def _concat_prompt(self, prompt1, prompt2):
         if (prompt1 is None):
@@ -249,13 +261,15 @@ class WhisperTranscriber:
     def close(self):
         self.clear_cache()
 
-        if (self.parallel_context is not None):
-            self.parallel_context.close()
+        if (self.gpu_parallel_context is not None):
+            self.gpu_parallel_context.close()
+        if (self.cpu_parallel_context is not None):
+            self.cpu_parallel_context.close()
 
 
 def create_ui(input_audio_max_duration, share=False, server_name: str = None, server_port: int = 7860,
-              default_model_name: str = "medium", default_vad: str = None, vad_parallel_devices: str = None, vad_process_timeout: float = None):
-    ui = WhisperTranscriber(input_audio_max_duration, vad_process_timeout)
+              default_model_name: str = "medium", default_vad: str = None, vad_parallel_devices: str = None, vad_process_timeout: float = None, vad_cpu_cores: int = 1):
+    ui = WhisperTranscriber(input_audio_max_duration, vad_process_timeout, vad_cpu_cores)
 
     # Specify a list of devices to use for parallel processing
     ui.set_parallel_devices(vad_parallel_devices)
@@ -303,6 +317,7 @@ if __name__ == '__main__':
     parser.add_argument("--default_model_name", type=str, default="medium", help="The default model name.")
     parser.add_argument("--default_vad", type=str, default="silero-vad", help="The default VAD.")
     parser.add_argument("--vad_parallel_devices", type=str, default="", help="A commma delimited list of CUDA devices to use for parallel processing. If None, disable parallel processing.")
+    parser.add_argument("--vad_cpu_cores", type=int, default=1, help="The number of CPU cores to use for VAD pre-processing.")
    parser.add_argument("--vad_process_timeout", type=float, default="1800", help="The number of seconds before inactivate processes are terminated. Use 0 to close processes immediately, or None for no timeout.")
 
    args = parser.parse_args().__dict__
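WhisperTranscriber now holds two long-lived ParallelContext instances, one sized by the GPU device list and one by vad_cpu_cores. A sketch of the pool lifecycle these contexts provide, assuming only the constructor arguments and methods that appear in this diff (get_pool, return_pool, close); the work submitted here is a placeholder:

from src.vadParallel import ParallelContext

# Sketch only: borrow the pool, run work, hand it back so the inactivity timer
# (auto_cleanup_timeout_seconds) can reap idle processes, and close on shutdown.
context = ParallelContext(num_processes=4, auto_cleanup_timeout_seconds=1800)

pool = context.get_pool()
try:
    results = pool.starmap(print, [("chunk", 1), ("chunk", 2)])
finally:
    context.return_pool(pool)

context.close()  # mirrors WhisperTranscriber.close()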
cli.py CHANGED
@@ -32,6 +32,7 @@ def cli():
     parser.add_argument("--vad_max_merge_size", type=optional_float, default=30, help="The maximum size (in seconds) of a voice segment")
     parser.add_argument("--vad_padding", type=optional_float, default=1, help="The padding (in seconds) to add to each voice segment")
     parser.add_argument("--vad_prompt_window", type=optional_float, default=3, help="The window size of the prompt to pass to Whisper")
+    parser.add_argument("--vad_cpu_cores", type=int, default=1, help="The number of CPU cores to use for VAD pre-processing.")
     parser.add_argument("--vad_parallel_devices", type=str, default="", help="A commma delimited list of CUDA devices to use for parallel processing. If None, disable parallel processing.")
 
     parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
@@ -73,8 +74,9 @@ def cli():
     vad_max_merge_size = args.pop("vad_max_merge_size")
     vad_padding = args.pop("vad_padding")
     vad_prompt_window = args.pop("vad_prompt_window")
+    vad_cpu_cores = args.pop("vad_cpu_cores")
 
-    model = WhisperContainer(model_name, device=device, download_root=model_dir)
+    model = WhisperContainer(model_name, device=device, download_root=model_dir, vad_cpu_cores=vad_cpu_cores)
     transcriber = WhisperTranscriber(delete_uploaded_files=False)
     transcriber.set_parallel_devices(args.pop("vad_parallel_devices"))
 
src/modelCache.py ADDED
@@ -0,0 +1,17 @@
+class ModelCache:
+    def __init__(self):
+        self._cache = dict()
+
+    def get(self, model_key: str, model_factory):
+        result = self._cache.get(model_key)
+
+        if result is None:
+            result = model_factory()
+            self._cache[model_key] = result
+        return result
+
+    def clear(self):
+        self._cache.clear()
+
+# A global cache of models. This is mainly used by the daemon processes to avoid loading the same model multiple times.
+GLOBAL_MODEL_CACHE = ModelCache()
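ModelCache.get() takes a key and a zero-argument factory, so callers decide how a model is built while the cache decides whether to build it. A small, self-contained usage sketch (the factory below is a stand-in, not real model loading):

from src.modelCache import ModelCache

cache = ModelCache()

def load_dummy_model():
    # Stand-in for an expensive call such as whisper.load_model() or torch.hub.load()
    print("cache miss: constructing model")
    return object()

first = cache.get("my-model", load_dummy_model)   # constructs and stores the model
second = cache.get("my-model", load_dummy_model)  # returns the cached instance, no print
assert first is second
cache.clear()                                     # drops all cached models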
src/vad.py CHANGED
@@ -1,9 +1,11 @@
 from abc import ABC, abstractmethod
 from collections import Counter, deque
+import time
 
 from typing import Any, Deque, Iterator, List, Dict
 
 from pprint import pprint
+from src.modelCache import GLOBAL_MODEL_CACHE, ModelCache
 
 from src.segments import merge_timestamps
 from src.whisperContainer import WhisperCallback
@@ -76,7 +78,7 @@ class AbstractTranscription(ABC):
         return load_audio(str, self.sampling_rate, start_time, duration)
 
     @abstractmethod
-    def get_transcribe_timestamps(self, audio: str, config: TranscriptionConfig):
+    def get_transcribe_timestamps(self, audio: str, config: TranscriptionConfig, start_time: float, end_time: float):
         """
         Get the start and end timestamps of the sections that should be transcribed by this VAD method.
 
@@ -93,10 +95,10 @@ class AbstractTranscription(ABC):
         """
         return
 
-    def get_merged_timestamps(self, audio: str, config: TranscriptionConfig):
+    def get_merged_timestamps(self, timestamps: List[Dict[str, Any]], config: TranscriptionConfig, total_duration: float):
         """
         Get the start and end timestamps of the sections that should be transcribed by this VAD method,
-        after merging the segments using the specified configuration.
+        after merging the given segments using the specified configuration.
 
         Parameters
         ----------
@@ -109,21 +111,17 @@ class AbstractTranscription(ABC):
         -------
         A list of start and end timestamps, in fractional seconds.
         """
-        seconds_timestamps = self.get_transcribe_timestamps(audio, config)
-
-        merged = merge_timestamps(seconds_timestamps, config.max_silent_period, config.max_merge_size,
+        merged = merge_timestamps(timestamps, config.max_silent_period, config.max_merge_size,
                                   config.segment_padding_left, config.segment_padding_right)
 
         if config.non_speech_strategy != NonSpeechStrategy.SKIP:
-            max_audio_duration = get_audio_duration(audio)
-
             # Expand segments to include the gaps between them
             if (config.non_speech_strategy == NonSpeechStrategy.CREATE_SEGMENT):
                 # When we have a prompt window, we create speech segments betwen each segment if we exceed the merge size
-                merged = self.fill_gaps(merged, total_duration=max_audio_duration, max_expand_size=config.max_merge_size)
+                merged = self.fill_gaps(merged, total_duration=total_duration, max_expand_size=config.max_merge_size)
            elif config.non_speech_strategy == NonSpeechStrategy.EXPAND_SEGMENT:
                 # With no prompt window, it is better to just expand the segments (this effectively passes the prompt to the next segment)
-                merged = self.expand_gaps(merged, total_duration=max_audio_duration)
+                merged = self.expand_gaps(merged, total_duration=total_duration)
            else:
                 raise Exception("Unknown non-speech strategy: " + str(config.non_speech_strategy))
 
@@ -147,8 +145,11 @@ class AbstractTranscription(ABC):
         A list of start and end timestamps, in fractional seconds.
         """
 
+        max_audio_duration = get_audio_duration(audio)
+        timestamp_segments = self.get_transcribe_timestamps(audio, config, 0, max_audio_duration)
+
         # Get speech timestamps from full audio file
-        merged = self.get_merged_timestamps(audio, config)
+        merged = self.get_merged_timestamps(timestamp_segments, config, max_audio_duration)
 
         # A deque of transcribed segments that is passed to the next segment as a prompt
         prompt_window = deque()
@@ -392,22 +393,41 @@ class AbstractTranscription(ABC):
 
 
 class VadSileroTranscription(AbstractTranscription):
-    def __init__(self, sampling_rate: int = 16000):
+    def __init__(self, sampling_rate: int = 16000, cache: ModelCache = None):
         super().__init__(sampling_rate=sampling_rate)
-
-        self.model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad', model='silero_vad')
-        (self.get_speech_timestamps, _, _, _, _) = utils
-
-
-    def get_transcribe_timestamps(self, audio: str, config: TranscriptionConfig):
-        audio_duration = get_audio_duration(audio)
+        self.model = None
+        self.cache = cache
+        self._initialize_model()
+
+    def _initialize_model(self):
+        if (self.cache is not None):
+            model_key = "VadSileroTranscription"
+            self.model, self.get_speech_timestamps = self.cache.get(model_key, self._create_model)
+            print("Loaded Silerio model from cache.")
+        else:
+            self.model, self.get_speech_timestamps = self._create_model()
+            print("Created Silerio model")
+
+    def _create_model(self):
+        model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad', model='silero_vad')
+
+        # Silero does not benefit from multi-threading
+        torch.set_num_threads(1) # JIT
+        (get_speech_timestamps, _, _, _, _) = utils
+
+        return model, get_speech_timestamps
+
+    def get_transcribe_timestamps(self, audio: str, config: TranscriptionConfig, start_time: float, end_time: float):
         result = []
 
+        print("Getting timestamps from audio file: {}, start: {}, duration: {}".format(audio, start_time, end_time))
+        perf_start_time = time.perf_counter()
+
         # Divide procesisng of audio into chunks
-        chunk_start = 0.0
+        chunk_start = start_time
 
-        while (chunk_start < audio_duration):
-            chunk_duration = min(audio_duration - chunk_start, VAD_MAX_PROCESSING_CHUNK)
+        while (chunk_start < end_time):
+            chunk_duration = min(end_time - chunk_start, VAD_MAX_PROCESSING_CHUNK)
 
             print("Processing VAD in chunk from {} to {}".format(format_timestamp(chunk_start), format_timestamp(chunk_start + chunk_duration)))
             wav = self.get_audio_segment(audio, str(chunk_start), str(chunk_duration))
@@ -421,23 +441,35 @@ class VadSileroTranscription(AbstractTranscription):
             result.extend(adjusted)
             chunk_start += chunk_duration
 
+        perf_end_time = time.perf_counter()
+        print("VAD processing took {} seconds".format(perf_end_time - perf_start_time))
+
         return result
 
+    def __getstate__(self):
+        # We only need the sampling rate
+        return { 'sampling_rate': self.sampling_rate }
+
+    def __setstate__(self, state):
+        self.sampling_rate = state['sampling_rate']
+        self.model = None
+        # Use the global cache
+        self.cache = GLOBAL_MODEL_CACHE
+        self._initialize_model()
+
 # A very simple VAD that just marks every N seconds as speech
 class VadPeriodicTranscription(AbstractTranscription):
     def __init__(self, sampling_rate: int = 16000):
         super().__init__(sampling_rate=sampling_rate)
 
-    def get_transcribe_timestamps(self, audio: str, config: PeriodicTranscriptionConfig):
-        # Get duration in seconds
-        audio_duration = get_audio_duration(audio)
+    def get_transcribe_timestamps(self, audio: str, config: PeriodicTranscriptionConfig, start_time: float, end_time: float):
         result = []
 
         # Generate a timestamp every N seconds
-        start_timestamp = 0
+        start_timestamp = start_time
 
-        while (start_timestamp < audio_duration):
-            end_timestamp = min(start_timestamp + config.periodic_duration, audio_duration)
+        while (start_timestamp < end_time):
+            end_timestamp = min(start_timestamp + config.periodic_duration, end_time)
             segment_duration = end_timestamp - start_timestamp
 
             # Minimum duration is 1 second
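VadSileroTranscription now pickles only its sampling rate; on unpickle it re-attaches to GLOBAL_MODEL_CACHE and re-initializes, so the torch model itself never travels between processes. An illustrative sketch of that behavior (actually running it requires torch and the snakers4/silero-vad hub dependencies):

import pickle
from src.vad import VadSileroTranscription

vad = VadSileroTranscription()             # loads (or reuses) the Silero model
clone = pickle.loads(pickle.dumps(vad))    # __setstate__ re-initializes via GLOBAL_MODEL_CACHE
print(clone.sampling_rate)                 # 16000; the model was rebuilt or fetched, not copied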
src/vadParallel.py CHANGED
@@ -1,12 +1,12 @@
 import multiprocessing
 import threading
 import time
-from src.vad import AbstractTranscription, TranscriptionConfig
+from src.vad import AbstractTranscription, TranscriptionConfig, get_audio_duration
 from src.whisperContainer import WhisperCallback
 
 from multiprocessing import Pool
 
-from typing import List
+from typing import Any, Dict, List
 import os
 
 
@@ -76,19 +76,28 @@ class ParallelTranscriptionConfig(TranscriptionConfig):
         super().__init__(copy.non_speech_strategy, copy.segment_padding_left, copy.segment_padding_right, copy.max_silent_period, copy.max_merge_size, copy.max_prompt_window, initial_segment_index)
         self.device_id = device_id
         self.override_timestamps = override_timestamps
-
+
 class ParallelTranscription(AbstractTranscription):
+    # Silero VAD typically takes about 3 seconds per minute, so there's no need to split the chunks
+    # into smaller segments than 2 minute (min 6 seconds per CPU core)
+    MIN_CPU_CHUNK_SIZE_SECONDS = 2 * 60
+
     def __init__(self, sampling_rate: int = 16000):
         super().__init__(sampling_rate=sampling_rate)
 
-
-    def transcribe_parallel(self, transcription: AbstractTranscription, audio: str, whisperCallable: WhisperCallback, config: TranscriptionConfig, devices: List[str], parallel_context: ParallelContext = None):
+    def transcribe_parallel(self, transcription: AbstractTranscription, audio: str, whisperCallable: WhisperCallback, config: TranscriptionConfig,
+                            cpu_device_count: int, gpu_devices: List[str], cpu_parallel_context: ParallelContext = None, gpu_parallel_context: ParallelContext = None):
+        total_duration = get_audio_duration(audio)
+
         # First, get the timestamps for the original audio
-        merged = transcription.get_merged_timestamps(audio, config)
+        if (cpu_device_count > 1):
+            merged = self._get_merged_timestamps_parallel(transcription, audio, config, total_duration, cpu_device_count, cpu_parallel_context)
+        else:
+            merged = transcription.get_merged_timestamps(audio, config, total_duration)
 
         # Split into a list for each device
         # TODO: Split by time instead of by number of chunks
-        merged_split = list(self._split(merged, len(devices)))
+        merged_split = list(self._split(merged, len(gpu_devices)))
 
         # Parameters that will be passed to the transcribe function
         parameters = []
@@ -96,15 +105,15 @@ class ParallelTranscription(AbstractTranscription):
 
         for i in range(len(merged_split)):
             device_segment_list = list(merged_split[i])
-            device_id = devices[i]
+            device_id = gpu_devices[i]
 
             if (len(device_segment_list) <= 0):
                 continue
 
-            print("Device " + device_id + " (index " + str(i) + ") has " + str(len(device_segment_list)) + " segments")
+            print("Device " + str(device_id) + " (index " + str(i) + ") has " + str(len(device_segment_list)) + " segments")
 
             # Create a new config with the given device ID
-            device_config = ParallelTranscriptionConfig(devices[i], device_segment_list, segment_index, config)
+            device_config = ParallelTranscriptionConfig(device_id, device_segment_list, segment_index, config)
             segment_index += len(device_segment_list)
 
             parameters.append([audio, whisperCallable, device_config]);
@@ -119,12 +128,12 @@ class ParallelTranscription(AbstractTranscription):
 
         # Spawn a separate process for each device
         try:
-            if (parallel_context is None):
-                parallel_context = ParallelContext(len(devices))
+            if (gpu_parallel_context is None):
+                gpu_parallel_context = ParallelContext(len(gpu_devices))
                 created_context = True
 
             # Get a pool of processes
-            pool = parallel_context.get_pool()
+            pool = gpu_parallel_context.get_pool()
 
             # Run the transcription in parallel
             results = pool.starmap(self.transcribe, parameters)
@@ -140,29 +149,90 @@ class ParallelTranscription(AbstractTranscription):
 
         finally:
             # Return the pool to the context
-            if (parallel_context is not None):
-                parallel_context.return_pool(pool)
+            if (gpu_parallel_context is not None):
+                gpu_parallel_context.return_pool(pool)
             # Always close the context if we created it
             if (created_context):
-                parallel_context.close()
+                gpu_parallel_context.close()
 
         return merged
 
-    def get_transcribe_timestamps(self, audio: str, config: ParallelTranscriptionConfig):
+    def _get_merged_timestamps_parallel(self, transcription: AbstractTranscription, audio: str, config: TranscriptionConfig, total_duration: float,
+                                        cpu_device_count: int, cpu_parallel_context: ParallelContext = None):
+        parameters = []
+
+        chunk_size = max(total_duration / cpu_device_count, self.MIN_CPU_CHUNK_SIZE_SECONDS)
+        chunk_start = 0
+        cpu_device_id = 0
+
+        perf_start_time = time.perf_counter()
+
+        # Create chunks that will be processed on the CPU
+        while (chunk_start < total_duration):
+            chunk_end = min(chunk_start + chunk_size, total_duration)
+
+            print("Parallel VAD: Executing chunk from " + str(chunk_start) + " to " +
+                  str(chunk_end) + " on CPU device " + str(cpu_device_id))
+            parameters.append([audio, config, chunk_start, chunk_end]);
+
+            cpu_device_id += 1
+            chunk_start = chunk_end
+
+        created_context = False
+
+        # Spawn a separate process for each device
+        try:
+            if (cpu_parallel_context is None):
+                cpu_parallel_context = ParallelContext(cpu_device_count)
+                created_context = True
+
+            # Get a pool of processes
+            pool = cpu_parallel_context.get_pool()
+
+            # Run the transcription in parallel. Note that transcription must be picklable.
+            results = pool.starmap(transcription.get_transcribe_timestamps, parameters)
+
+            timestamps = []
+
+            # Flatten the results
+            for result in results:
+                timestamps.extend(result)
+
+            merged = transcription.get_merged_timestamps(timestamps, config, total_duration)
+
+            perf_end_time = time.perf_counter()
+            print("Parallel VAD processing took {} seconds".format(perf_end_time - perf_start_time))
+            return merged
+
+        finally:
+            # Return the pool to the context
+            if (cpu_parallel_context is not None):
+                cpu_parallel_context.return_pool(pool)
+            # Always close the context if we created it
+            if (created_context):
+                cpu_parallel_context.close()
+
+    def get_transcribe_timestamps(self, audio: str, config: ParallelTranscriptionConfig, start_time: float, duration: float):
         return []
 
-    def get_merged_timestamps(self, audio: str, config: ParallelTranscriptionConfig):
+    def get_merged_timestamps(self, timestamps: List[Dict[str, Any]], config: ParallelTranscriptionConfig, total_duration: float):
         # Override timestamps that will be processed
         if (config.override_timestamps is not None):
             print("Using override timestamps of size " + str(len(config.override_timestamps)))
             return config.override_timestamps
-        return super().get_merged_timestamps(audio, config)
+        return super().get_merged_timestamps(timestamps, config, total_duration)
 
     def transcribe(self, audio: str, whisperCallable: WhisperCallback, config: ParallelTranscriptionConfig):
-        # Override device ID
-        if (config.device_id is not None):
-            print("Using device " + config.device_id)
-            os.environ["CUDA_VISIBLE_DEVICES"] = config.device_id
+        # Override device ID the first time
+        if (os.environ.get("INITIALIZED", None) is None):
+            os.environ["INITIALIZED"] = "1"
+
+            # Note that this may be None if the user didn't specify a device. In that case, Whisper will
+            # just use the default GPU device.
+            if (config.device_id is not None):
+                print("Using device " + config.device_id)
+                os.environ["CUDA_VISIBLE_DEVICES"] = config.device_id
+
         return super().transcribe(audio, whisperCallable, config)
 
     def _split(self, a, n):
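_get_merged_timestamps_parallel splits the audio into at most cpu_device_count chunks, but never into chunks shorter than MIN_CPU_CHUNK_SIZE_SECONDS, so short files are not fanned out across every core. A standalone sketch of just that splitting rule (the helper name is illustrative, not part of the commit):

# Sketch of the CPU chunking rule used by _get_merged_timestamps_parallel.
MIN_CPU_CHUNK_SIZE_SECONDS = 2 * 60

def split_into_cpu_chunks(total_duration: float, cpu_device_count: int):
    chunk_size = max(total_duration / cpu_device_count, MIN_CPU_CHUNK_SIZE_SECONDS)
    chunks = []
    chunk_start = 0.0
    while chunk_start < total_duration:
        chunk_end = min(chunk_start + chunk_size, total_duration)
        chunks.append((chunk_start, chunk_end))
        chunk_start = chunk_end
    return chunks

# A 10 minute file on 8 cores still yields only 5 two-minute chunks.
print(split_into_cpu_chunks(600.0, 8))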
src/whisperContainer.py CHANGED
@@ -1,29 +1,10 @@
 # External programs
 import whisper
 
-class WhisperModelCache:
-    def __init__(self):
-        self._cache = dict()
-
-    def get(self, model_name, device: str = None):
-        key = model_name + ":" + (device if device else '')
-
-        result = self._cache.get(key)
-
-        if result is None:
-            print("Loading whisper model " + model_name)
-            result = whisper.load_model(name=model_name, device=device)
-            self._cache[key] = result
-        return result
-
-    def clear(self):
-        self._cache.clear()
-
-# A global cache of models. This is mainly used by the daemon processes to avoid loading the same model multiple times.
-GLOBAL_WHISPER_MODEL_CACHE = WhisperModelCache()
+from src.modelCache import GLOBAL_MODEL_CACHE, ModelCache
 
 class WhisperContainer:
-    def __init__(self, model_name: str, device: str = None, download_root: str = None, cache: WhisperModelCache = None):
+    def __init__(self, model_name: str, device: str = None, download_root: str = None, cache: ModelCache = None):
         self.model_name = model_name
         self.device = device
         self.download_root = download_root
@@ -36,12 +17,16 @@ class WhisperContainer:
         if self.model is None:
 
             if (self.cache is None):
-                print("Loading whisper model " + self.model_name)
-                self.model = whisper.load_model(self.model_name, device=self.device, download_root=self.download_root)
+                self.model = self._create_model()
             else:
-                self.model = self.cache.get(self.model_name, device=self.device)
+                model_key = "WhisperContainer." + self.model_name + ":" + (self.device if self.device else '')
+                self.model = self.cache.get(model_key, self._create_model)
         return self.model
 
+    def _create_model(self):
+        print("Loading whisper model " + self.model_name)
+        return whisper.load_model(self.model_name, device=self.device, download_root=self.download_root)
+
     def create_callback(self, language: str = None, task: str = None, initial_prompt: str = None, **decodeOptions: dict):
         """
         Create a WhisperCallback object that can be used to transcript audio files.
@@ -65,14 +50,15 @@ class WhisperContainer:
 
     # This is required for multiprocessing
     def __getstate__(self):
-        return { "model_name": self.model_name, "device": self.device }
+        return { "model_name": self.model_name, "device": self.device, "download_root": self.download_root }
 
     def __setstate__(self, state):
        self.model_name = state["model_name"]
        self.device = state["device"]
+        self.download_root = state["download_root"]
        self.model = None
        # Depickled objects must use the global cache
-        self.cache = GLOBAL_WHISPER_MODEL_CACHE
+        self.cache = GLOBAL_MODEL_CACHE
 
 
 class WhisperCallback:
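WhisperContainer's pickle state now also carries download_root, and a depickled container re-attaches to the shared GLOBAL_MODEL_CACHE so each worker process loads a given model at most once. An illustrative sketch of what actually crosses the process boundary (model name, device and path are examples only):

import pickle
from src.whisperContainer import WhisperContainer

container = WhisperContainer("medium", device="cuda", download_root="/tmp/whisper-models")
payload = pickle.dumps(container)     # only model_name, device and download_root are serialized
restored = pickle.loads(payload)      # model is None; cache points at GLOBAL_MODEL_CACHE
print(restored.model_name, restored.download_root)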
tests/vad_test.py CHANGED
@@ -5,7 +5,7 @@ import sys
 
 sys.path.append('../whisper-webui')
 
-from src.vad import AbstractTranscription, VadSileroTranscription
+from src.vad import AbstractTranscription, TranscriptionConfig, VadSileroTranscription
 
 class TestVad(unittest.TestCase):
     def __init__(self, *args, **kwargs):
@@ -55,7 +55,7 @@ class MockVadTranscription(AbstractTranscription):
         # For mocking, this just returns a simple numppy array
         return np.array([start_time_seconds, duration_seconds], dtype=np.float64)
 
-    def get_transcribe_timestamps(self, audio: str):
+    def get_transcribe_timestamps(self, audio: str, config: TranscriptionConfig, start_time: float, duration: float):
         result = []
 
         result.append( { 'start': 30, 'end': 60 } )