# External programs
import os

import whisper

from src.modelCache import GLOBAL_MODEL_CACHE, ModelCache


class WhisperContainer:
    def __init__(self, model_name: str, device: str = None, download_root: str = None, cache: ModelCache = None):
        self.model_name = model_name
        self.device = device
        self.download_root = download_root
        self.cache = cache

        # Will be created on demand
        self.model = None

    def get_model(self):
        if self.model is None:
            if (self.cache is None):
                self.model = self._create_model()
            else:
                model_key = "WhisperContainer." + self.model_name + ":" + (self.device if self.device else '')
                self.model = self.cache.get(model_key, self._create_model)

        return self.model

    def ensure_downloaded(self):
        """
        Ensure that the model is downloaded. This is useful if you want to ensure that the model is downloaded
        before passing the container to a subprocess.
        """
        # Warning: Using private API here
        try:
            root_dir = self.download_root

            if root_dir is None:
                root_dir = os.path.join(os.path.expanduser("~"), ".cache", "whisper")

            if self.model_name in whisper._MODELS:
                whisper._download(whisper._MODELS[self.model_name], root_dir, False)
            return True
        except Exception as e:
            # Given that the API is private, it could change at any time. We don't want to crash the program.
            print("Error pre-downloading model: " + str(e))
            return False

    def _create_model(self):
        print("Loading whisper model " + self.model_name)
        return whisper.load_model(self.model_name, device=self.device, download_root=self.download_root)

    def create_callback(self, language: str = None, task: str = None, initial_prompt: str = None, **decodeOptions: dict):
        """
        Create a WhisperCallback object that can be used to transcribe audio files.

        Parameters
        ----------
        language: str
            The target language of the transcription. If not specified, the language will be inferred from the audio content.
        task: str
            The task - either translate or transcribe.
        initial_prompt: str
            The initial prompt to use for the transcription.
        decodeOptions: dict
            Additional options to pass to the decoder. Must be pickleable.

        Returns
        -------
        A WhisperCallback object.
        """
        return WhisperCallback(self, language=language, task=task, initial_prompt=initial_prompt, **decodeOptions)

    # This is required for multiprocessing
    def __getstate__(self):
        return { "model_name": self.model_name, "device": self.device, "download_root": self.download_root }

    def __setstate__(self, state):
        self.model_name = state["model_name"]
        self.device = state["device"]
        self.download_root = state["download_root"]
        self.model = None
        # Depickled objects must use the global cache
        self.cache = GLOBAL_MODEL_CACHE


class WhisperCallback:
    def __init__(self, model_container: WhisperContainer, language: str = None, task: str = None, initial_prompt: str = None, **decodeOptions: dict):
        self.model_container = model_container
        self.language = language
        self.task = task
        self.initial_prompt = initial_prompt
        self.decodeOptions = decodeOptions

    def invoke(self, audio, segment_index: int, prompt: str, detected_language: str):
        """
        Perform the transcription of the given audio file or data.

        Parameters
        ----------
        audio: Union[str, np.ndarray, torch.Tensor]
            The audio file to transcribe, or the audio data as a numpy array or torch tensor.
        segment_index: int
            The zero-based index of the segment. The initial prompt is only prepended to the first segment.
        prompt: str
            The prompt to use for the transcription.
        detected_language: str
            The detected language of the audio file.
        Returns
        -------
        The result of the Whisper call.
        """
        model = self.model_container.get_model()

        return model.transcribe(audio,
                                language=self.language if self.language else detected_language,
                                task=self.task,
                                initial_prompt=self._concat_prompt(self.initial_prompt, prompt) if segment_index == 0 else prompt,
                                **self.decodeOptions)

    def _concat_prompt(self, prompt1, prompt2):
        if (prompt1 is None):
            return prompt2
        elif (prompt2 is None):
            return prompt1
        else:
            return prompt1 + " " + prompt2
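

# A minimal usage sketch (not part of the original module). The model name, device,
# audio path, and language below are illustrative assumptions, not values required by the API.
if __name__ == "__main__":
    container = WhisperContainer("tiny", device="cpu")

    # Optionally pre-download the model before handing the container to a subprocess
    container.ensure_downloaded()

    callback = container.create_callback(language="en", task="transcribe")

    # Transcribe a single segment; for the first segment the initial prompt (if any) is prepended to the prompt
    result = callback.invoke("audio.wav", segment_index=0, prompt=None, detected_language=None)
    print(result["text"])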