Ensure model is downloaded before spawning sub-processes
- src/vadParallel.py  +4 -0
- src/whisperContainer.py  +15 -0
src/vadParallel.py CHANGED
@@ -96,6 +96,10 @@ class ParallelTranscription(AbstractTranscription):
         timestamp_segments = transcription.get_transcribe_timestamps(audio, config, 0, total_duration)
         merged = transcription.get_merged_timestamps(timestamp_segments, config, total_duration)
 
+        # We must make sure the whisper model is downloaded
+        if (len(gpu_devices) > 1):
+            whisperCallable.model_container.ensure_downloaded()
+
         # Split into a list for each device
         # TODO: Split by time instead of by number of chunks
         merged_split = list(self._split(merged, len(gpu_devices)))
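For context on why the pre-download happens in the parent process: each GPU gets its own worker process, and if every worker were left to download the checkpoint on first use they could all race on the same cache file. A minimal self-contained sketch of the pattern (illustration only; the cache path and helper names below are made up, not from this repo):

import multiprocessing
import os

CACHE_PATH = "/tmp/model.bin"  # hypothetical cache location

def ensure_downloaded():
    # Download the file only if it is not already cached
    if not os.path.exists(CACHE_PATH):
        with open(CACHE_PATH, "wb") as f:  # stand-in for a slow network download
            f.write(b"model weights")

def worker(device_id):
    # Without the parent-side call, every worker could hit the download path at once
    ensure_downloaded()
    return "device %d loaded model from %s" % (device_id, CACHE_PATH)

if __name__ == "__main__":
    ensure_downloaded()  # download once in the parent, as the diff above does
    with multiprocessing.Pool(2) as pool:
        print(pool.map(worker, [0, 1]))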
src/whisperContainer.py CHANGED
@@ -23,6 +23,21 @@ class WhisperContainer:
         self.model = self.cache.get(model_key, self._create_model)
         return self.model
 
+    def ensure_downloaded(self):
+        """
+        Ensure that the model is downloaded. This is useful if you want to ensure that the model is downloaded before
+        passing the container to a subprocess.
+        """
+        # Warning: Using private API here
+        try:
+            if self.model_name in whisper._MODELS:
+                whisper._download(whisper._MODELS[self.model_name], self.download_root, False)
+            return True
+        except Exception as e:
+            # Given that the API is private, it could change at any time. We don't want to crash the program
+            print("Error pre-downloading model: " + str(e))
+            return False
+
     def _create_model(self):
         print("Loading whisper model " + self.model_name)
         return whisper.load_model(self.model_name, device=self.device, download_root=self.download_root)
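Note that whisper._MODELS and whisper._download are private helpers, which is why the method wraps them in try/except and only logs on failure instead of crashing. For comparison, a sketch of a public-API alternative (my illustration, not part of this commit): trigger the download by loading the model on the CPU in the parent, at the cost of actually deserializing the checkpoint there.

import whisper

def ensure_downloaded_public(model_name, download_root):
    # Public-API alternative (hypothetical): load_model downloads the checkpoint as a
    # side effect, but it also loads the weights on the CPU, which is slower and uses
    # more memory in the parent process than the private download-only call above.
    whisper.load_model(model_name, device="cpu", download_root=download_root)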