aadnk commited on
Commit
7c5d37e
1 Parent(s): 82aecd5

Ensure model is downloaded before spawning sub-processes

Browse files
Files changed (2) hide show
  1. src/vadParallel.py +4 -0
  2. src/whisperContainer.py +15 -0
src/vadParallel.py CHANGED
@@ -96,6 +96,10 @@ class ParallelTranscription(AbstractTranscription):
96
  timestamp_segments = transcription.get_transcribe_timestamps(audio, config, 0, total_duration)
97
  merged = transcription.get_merged_timestamps(timestamp_segments, config, total_duration)
98
 
 
 
 
 
99
  # Split into a list for each device
100
  # TODO: Split by time instead of by number of chunks
101
  merged_split = list(self._split(merged, len(gpu_devices)))
 
96
  timestamp_segments = transcription.get_transcribe_timestamps(audio, config, 0, total_duration)
97
  merged = transcription.get_merged_timestamps(timestamp_segments, config, total_duration)
98
 
99
+ # We must make sure the whisper model is downloaded
100
+ if (len(gpu_devices) > 1):
101
+ whisperCallable.model_container.ensure_downloaded()
102
+
103
  # Split into a list for each device
104
  # TODO: Split by time instead of by number of chunks
105
  merged_split = list(self._split(merged, len(gpu_devices)))
src/whisperContainer.py CHANGED
@@ -23,6 +23,21 @@ class WhisperContainer:
23
  self.model = self.cache.get(model_key, self._create_model)
24
  return self.model
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  def _create_model(self):
27
  print("Loading whisper model " + self.model_name)
28
  return whisper.load_model(self.model_name, device=self.device, download_root=self.download_root)
 
23
  self.model = self.cache.get(model_key, self._create_model)
24
  return self.model
25
 
26
+ def ensure_downloaded(self):
27
+ """
28
+ Ensure that the model is downloaded. This is useful if you want to ensure that the model is downloaded before
29
+ passing the container to a subprocess.
30
+ """
31
+ # Warning: Using private API here
32
+ try:
33
+ if self.model_name in whisper._MODELS:
34
+ whisper._download(whisper._MODELS[self.model_name], self.download_root, False)
35
+ return True
36
+ except Exception as e:
37
+ # Given that the API is private, it could change at any time. We don't want to crash the program
38
+ print("Error pre-downloading model: " + str(e))
39
+ return False
40
+
41
  def _create_model(self):
42
  print("Loading whisper model " + self.model_name)
43
  return whisper.load_model(self.model_name, device=self.device, download_root=self.download_root)