pragyaa committed on
Commit
cdeec9e
1 Parent(s): 207320e

Upload whisperContainer.py

Browse files
Files changed (1) hide show
  1. whisperContainer.py +127 -0
whisperContainer.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # External programs
2
+ import os
3
+ import whisper
4
+
5
+ from src.modelCache import GLOBAL_MODEL_CACHE, ModelCache
6
+
7
class WhisperContainer:
    """
    A picklable, lazily-loaded handle to a Whisper model.

    The container stores only the parameters needed to load the model. The model itself is created
    the first time get_model() is called, optionally through a shared cache.
    """
    def __init__(self, model_name: str, device: str = None, download_root: str = None, cache: "ModelCache" = None):
        """
        Parameters
        ----------
        model_name: str
            The name of the model, passed on to whisper.load_model.
        device: str
            The device to load the model on, passed on to whisper.load_model. May be None.
        download_root: str
            The directory the model is downloaded to, passed on to whisper.load_model. May be None.
        cache: ModelCache
            An optional cache used to share loaded models. If None, the model is loaded directly.
        """
        self.model_name = model_name
        self.device = device
        self.download_root = download_root
        self.cache = cache

        # Will be created on demand
        self.model = None

    def get_model(self):
        """
        Return the model, loading it on first use.

        When a cache is configured, the model is fetched from (or created through) the cache using
        a key built from the model name and the device.
        """
        if self.model is None:
            if self.cache is None:
                self.model = self._create_model()
            else:
                model_key = "WhisperContainer." + self.model_name + ":" + (self.device if self.device else '')
                self.model = self.cache.get(model_key, self._create_model)
        return self.model

    @staticmethod
    def _default_download_root() -> str:
        """
        Return the directory used for downloads when no download_root is configured.

        NOTE(review): this is intended to mirror the default of whisper.load_model
        ($XDG_CACHE_HOME/whisper, falling back to ~/.cache/whisper), so that a pre-downloaded
        model is found again when it is loaded - confirm against the installed whisper version.
        """
        default_cache = os.path.join(os.path.expanduser("~"), ".cache")
        return os.path.join(os.getenv("XDG_CACHE_HOME", default_cache), "whisper")

    def ensure_downloaded(self):
        """
        Ensure that the model is downloaded. This is useful if you want to ensure that the model is downloaded before
        passing the container to a subprocess.

        Returns
        -------
        True if no error occurred (including when the model name is not a known downloadable model),
        False if the download attempt raised an exception.
        """
        # Warning: Using private API here
        try:
            root_dir = self.download_root

            if root_dir is None:
                root_dir = self._default_download_root()

            if self.model_name in whisper._MODELS:
                whisper._download(whisper._MODELS[self.model_name], root_dir, False)
            return True
        except Exception as e:
            # Given that the API is private, it could change at any time. We don't want to crash the program
            print("Error pre-downloading model: " + str(e))
            return False

    def _create_model(self):
        """Load the Whisper model using the stored model name, device and download root."""
        print("Loading whisper model " + self.model_name)
        return whisper.load_model(self.model_name, device=self.device, download_root=self.download_root)

    def create_callback(self, language: str = None, task: str = None, initial_prompt: str = None, **decodeOptions: dict):
        """
        Create a WhisperCallback object that can be used to transcribe audio files.

        Parameters
        ----------
        language: str
            The target language of the transcription. If not specified, the language will be inferred from the audio content.
        task: str
            The task - either translate or transcribe.
        initial_prompt: str
            The initial prompt to use for the transcription.
        decodeOptions: dict
            Additional options to pass to the decoder. Must be pickleable.

        Returns
        -------
        A WhisperCallback object.
        """
        return WhisperCallback(self, language=language, task=task, initial_prompt=initial_prompt, **decodeOptions)

    # This is required for multiprocessing
    def __getstate__(self):
        # Only the load parameters are pickled; the loaded model and the cache are left out
        return { "model_name": self.model_name, "device": self.device, "download_root": self.download_root }

    def __setstate__(self, state):
        self.model_name = state["model_name"]
        self.device = state["device"]
        self.download_root = state["download_root"]
        self.model = None
        # Depickled objects must use the global cache
        self.cache = GLOBAL_MODEL_CACHE
83
+
84
+
85
class WhisperCallback:
    """
    A callable transcription job bound to a model container and a fixed set of transcription options.
    """
    def __init__(self, model_container: "WhisperContainer", language: str = None, task: str = None, initial_prompt: str = None, **decodeOptions: dict):
        """
        Parameters
        ----------
        model_container: WhisperContainer
            The container that provides the model through get_model().
        language: str
            The target language of the transcription. If not set, the detected language given to invoke is used.
        task: str
            The task, passed on to the model's transcribe call.
        initial_prompt: str
            The initial prompt, only applied to the first segment (segment index 0).
        decodeOptions: dict
            Additional keyword options passed on to the model's transcribe call.
        """
        self.model_container = model_container
        self.language = language
        self.task = task
        self.initial_prompt = initial_prompt
        self.decodeOptions = decodeOptions

    def invoke(self, audio, segment_index: int, prompt: str, detected_language: str):
        """
        Perform the transcription of the given audio file or data.

        Parameters
        ----------
        audio: Union[str, np.ndarray, torch.Tensor]
            The audio file to transcribe, or the audio data as a numpy array or torch tensor.
        segment_index: int
            The index of the segment being transcribed. The initial prompt of this callback
            is only prepended to the prompt when this is 0.
        prompt: str
            The prompt to use for the transcription.
        detected_language: str
            The detected language of the audio file. Only used if no language was configured on this callback.

        Returns
        -------
        The result of the Whisper call.
        """
        model = self.model_container.get_model()

        # Only the first segment gets the configured initial prompt prepended
        if segment_index == 0:
            initial_prompt = self._concat_prompt(self.initial_prompt, prompt)
        else:
            initial_prompt = prompt

        return model.transcribe(
            audio,
            language=self.language if self.language else detected_language,
            task=self.task,
            initial_prompt=initial_prompt,
            **self.decodeOptions
        )

    def _concat_prompt(self, prompt1, prompt2):
        """
        Join two prompts with a single space.

        A missing (None) or empty prompt is skipped, so the result never starts or ends
        with a stray space. If the first prompt is missing or empty, the second is returned as is.
        """
        if not prompt1:
            return prompt2
        if not prompt2:
            return prompt1
        return prompt1 + " " + prompt2