Spaces:
No application file
No application file
Upload whisperContainer.py
Browse files- whisperContainer.py +127 -0
whisperContainer.py
ADDED
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# External programs
|
2 |
+
import os
|
3 |
+
import whisper
|
4 |
+
|
5 |
+
from src.modelCache import GLOBAL_MODEL_CACHE, ModelCache
|
6 |
+
|
7 |
+
class WhisperContainer:
|
8 |
+
def __init__(self, model_name: str, device: str = None, download_root: str = None, cache: ModelCache = None):
|
9 |
+
self.model_name = model_name
|
10 |
+
self.device = device
|
11 |
+
self.download_root = download_root
|
12 |
+
self.cache = cache
|
13 |
+
|
14 |
+
# Will be created on demand
|
15 |
+
self.model = None
|
16 |
+
|
17 |
+
def get_model(self):
|
18 |
+
if self.model is None:
|
19 |
+
|
20 |
+
if (self.cache is None):
|
21 |
+
self.model = self._create_model()
|
22 |
+
else:
|
23 |
+
model_key = "WhisperContainer." + self.model_name + ":" + (self.device if self.device else '')
|
24 |
+
self.model = self.cache.get(model_key, self._create_model)
|
25 |
+
return self.model
|
26 |
+
|
27 |
+
def ensure_downloaded(self):
|
28 |
+
"""
|
29 |
+
Ensure that the model is downloaded. This is useful if you want to ensure that the model is downloaded before
|
30 |
+
passing the container to a subprocess.
|
31 |
+
"""
|
32 |
+
# Warning: Using private API here
|
33 |
+
try:
|
34 |
+
root_dir = self.download_root
|
35 |
+
|
36 |
+
if root_dir is None:
|
37 |
+
root_dir = os.path.join(os.path.expanduser("~"), ".cache", "whisper")
|
38 |
+
|
39 |
+
if self.model_name in whisper._MODELS:
|
40 |
+
whisper._download(whisper._MODELS[self.model_name], root_dir, False)
|
41 |
+
return True
|
42 |
+
except Exception as e:
|
43 |
+
# Given that the API is private, it could change at any time. We don't want to crash the program
|
44 |
+
print("Error pre-downloading model: " + str(e))
|
45 |
+
return False
|
46 |
+
|
47 |
+
def _create_model(self):
|
48 |
+
print("Loading whisper model " + self.model_name)
|
49 |
+
return whisper.load_model(self.model_name, device=self.device, download_root=self.download_root)
|
50 |
+
|
51 |
+
def create_callback(self, language: str = None, task: str = None, initial_prompt: str = None, **decodeOptions: dict):
|
52 |
+
"""
|
53 |
+
Create a WhisperCallback object that can be used to transcript audio files.
|
54 |
+
|
55 |
+
Parameters
|
56 |
+
----------
|
57 |
+
language: str
|
58 |
+
The target language of the transcription. If not specified, the language will be inferred from the audio content.
|
59 |
+
task: str
|
60 |
+
The task - either translate or transcribe.
|
61 |
+
initial_prompt: str
|
62 |
+
The initial prompt to use for the transcription.
|
63 |
+
decodeOptions: dict
|
64 |
+
Additional options to pass to the decoder. Must be pickleable.
|
65 |
+
|
66 |
+
Returns
|
67 |
+
-------
|
68 |
+
A WhisperCallback object.
|
69 |
+
"""
|
70 |
+
return WhisperCallback(self, language=language, task=task, initial_prompt=initial_prompt, **decodeOptions)
|
71 |
+
|
72 |
+
# This is required for multiprocessing
|
73 |
+
def __getstate__(self):
|
74 |
+
return { "model_name": self.model_name, "device": self.device, "download_root": self.download_root }
|
75 |
+
|
76 |
+
def __setstate__(self, state):
|
77 |
+
self.model_name = state["model_name"]
|
78 |
+
self.device = state["device"]
|
79 |
+
self.download_root = state["download_root"]
|
80 |
+
self.model = None
|
81 |
+
# Depickled objects must use the global cache
|
82 |
+
self.cache = GLOBAL_MODEL_CACHE
|
83 |
+
|
84 |
+
|
85 |
+
class WhisperCallback:
|
86 |
+
def __init__(self, model_container: WhisperContainer, language: str = None, task: str = None, initial_prompt: str = None, **decodeOptions: dict):
|
87 |
+
self.model_container = model_container
|
88 |
+
self.language = language
|
89 |
+
self.task = task
|
90 |
+
self.initial_prompt = initial_prompt
|
91 |
+
self.decodeOptions = decodeOptions
|
92 |
+
|
93 |
+
def invoke(self, audio, segment_index: int, prompt: str, detected_language: str):
|
94 |
+
"""
|
95 |
+
Peform the transcription of the given audio file or data.
|
96 |
+
|
97 |
+
Parameters
|
98 |
+
----------
|
99 |
+
audio: Union[str, np.ndarray, torch.Tensor]
|
100 |
+
The audio file to transcribe, or the audio data as a numpy array or torch tensor.
|
101 |
+
segment_index: int
|
102 |
+
The target language of the transcription. If not specified, the language will be inferred from the audio content.
|
103 |
+
task: str
|
104 |
+
The task - either translate or transcribe.
|
105 |
+
prompt: str
|
106 |
+
The prompt to use for the transcription.
|
107 |
+
detected_language: str
|
108 |
+
The detected language of the audio file.
|
109 |
+
|
110 |
+
Returns
|
111 |
+
-------
|
112 |
+
The result of the Whisper call.
|
113 |
+
"""
|
114 |
+
model = self.model_container.get_model()
|
115 |
+
|
116 |
+
return model.transcribe(audio, \
|
117 |
+
language=self.language if self.language else detected_language, task=self.task, \
|
118 |
+
initial_prompt=self._concat_prompt(self.initial_prompt, prompt) if segment_index == 0 else prompt, \
|
119 |
+
**self.decodeOptions)
|
120 |
+
|
121 |
+
def _concat_prompt(self, prompt1, prompt2):
|
122 |
+
if (prompt1 is None):
|
123 |
+
return prompt2
|
124 |
+
elif (prompt2 is None):
|
125 |
+
return prompt1
|
126 |
+
else:
|
127 |
+
return prompt1 + " " + prompt2
|