# External programs
import abc
import os
import sys
from typing import List
from urllib.parse import urlparse

import torch
import urllib3

from src.hooks.progressListener import ProgressListener

import whisper
from whisper import Whisper

from src.config import ModelConfig, VadInitialPromptMode
from src.hooks.whisperProgressHook import create_progress_listener_handle
from src.modelCache import GLOBAL_MODEL_CACHE, ModelCache
from src.prompts.abstractPromptStrategy import AbstractPromptStrategy
from src.utils import download_file
from src.whisper.abstractWhisperContainer import AbstractWhisperCallback, AbstractWhisperContainer

class WhisperContainer(AbstractWhisperContainer):
    def __init__(self, model_name: str, device: str = None, compute_type: str = "float16",
                 download_root: str = None,
                 cache: ModelCache = None, models: List[ModelConfig] = []):
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        super().__init__(model_name, device, compute_type, download_root, cache, models)

    def ensure_downloaded(self):
        """
        Ensure that the model is downloaded. This is useful if you want to ensure that the model is downloaded
        before passing the container to a subprocess.
        """
        # Warning: Using private API here
        try:
            root_dir = self.download_root
            model_config = self._get_model_config()

            if root_dir is None:
                root_dir = os.path.join(os.path.expanduser("~"), ".cache", "whisper")

            if self.model_name in whisper._MODELS:
                whisper._download(whisper._MODELS[self.model_name], root_dir, False)
            else:
                # If the model is not in the official list, see if it needs to be downloaded
                model_config.download_url(root_dir)
            return True

        except Exception as e:
            # Given that the API is private, it could change at any time. We don't want to crash the program
            print("Error pre-downloading model: " + str(e))
            return False

    def _get_model_config(self) -> ModelConfig:
        """
        Get the model configuration for the model.
        """
        for model in self.models:
            if model.name == self.model_name:
                return model
        return None

    def _create_model(self):
        print("Loading whisper model " + self.model_name)
        model_config = self._get_model_config()
        # Note that the model will not be downloaded in the case of an official Whisper model
        model_path = self._get_model_path(model_config, self.download_root)

        return whisper.load_model(model_path, device=self.device, download_root=self.download_root)

    def create_callback(self, language: str = None, task: str = None,
                        prompt_strategy: AbstractPromptStrategy = None,
                        **decodeOptions: dict) -> AbstractWhisperCallback:
        """
        Create a WhisperCallback object that can be used to transcribe audio files.

        Parameters
        ----------
        language: str
            The target language of the transcription. If not specified, the language will be inferred from the audio content.
        task: str
            The task - either translate or transcribe.
        prompt_strategy: AbstractPromptStrategy
            The prompt strategy to use. If not specified, the prompt from Whisper will be used.
        decodeOptions: dict
            Additional options to pass to the decoder. Must be pickleable.

        Returns
        -------
        A WhisperCallback object.
        """
        return WhisperCallback(self, language=language, task=task, prompt_strategy=prompt_strategy, **decodeOptions)
""" # See if path is already set if model_config.path is not None: return model_config.path if root_dir is None: root_dir = os.path.join(os.path.expanduser("~"), ".cache", "whisper") model_type = model_config.type.lower() if model_config.type is not None else "whisper" if model_type in ["huggingface", "hf"]: model_config.path = model_config.url destination_target = os.path.join(root_dir, model_config.name + ".pt") # Convert from HuggingFace format to Whisper format if os.path.exists(destination_target): print(f"File {destination_target} already exists, skipping conversion") else: print("Saving HuggingFace model in Whisper format to " + destination_target) convert_hf_whisper(model_config.url, destination_target) model_config.path = destination_target elif model_type in ["whisper", "w"]: model_config.path = model_config.url # See if URL is just a file if model_config.url in whisper._MODELS: # No need to download anything - Whisper will handle it model_config.path = model_config.url elif model_config.url.startswith("file://"): # Get file path model_config.path = urlparse(model_config.url).path # See if it is an URL elif model_config.url.startswith("http://") or model_config.url.startswith("https://"): # Extension (or file name) extension = os.path.splitext(model_config.url)[-1] download_target = os.path.join(root_dir, model_config.name + extension) if os.path.exists(download_target) and not os.path.isfile(download_target): raise RuntimeError(f"{download_target} exists and is not a regular file") if not os.path.isfile(download_target): download_file(model_config.url, download_target) else: print(f"File {download_target} already exists, skipping download") model_config.path = download_target # Must be a local file else: model_config.path = model_config.url else: raise ValueError(f"Unknown model type {model_type}") return model_config.path class WhisperCallback(AbstractWhisperCallback): def __init__(self, model_container: WhisperContainer, language: str = None, task: str = None, prompt_strategy: AbstractPromptStrategy = None, **decodeOptions: dict): self.model_container = model_container self.language = language self.task = task self.prompt_strategy = prompt_strategy self.decodeOptions = decodeOptions def invoke(self, audio, segment_index: int, prompt: str, detected_language: str, progress_listener: ProgressListener = None): """ Peform the transcription of the given audio file or data. Parameters ---------- audio: Union[str, np.ndarray, torch.Tensor] The audio file to transcribe, or the audio data as a numpy array or torch tensor. segment_index: int The target language of the transcription. If not specified, the language will be inferred from the audio content. task: str The task - either translate or transcribe. progress_listener: ProgressListener A callback to receive progress updates. 
""" model = self.model_container.get_model() if progress_listener is not None: with create_progress_listener_handle(progress_listener): return self._transcribe(model, audio, segment_index, prompt, detected_language) else: return self._transcribe(model, audio, segment_index, prompt, detected_language) def _transcribe(self, model: Whisper, audio, segment_index: int, prompt: str, detected_language: str): decodeOptions = self.decodeOptions.copy() # Add fp16 if self.model_container.compute_type in ["fp16", "float16"]: decodeOptions["fp16"] = True initial_prompt = self.prompt_strategy.get_segment_prompt(segment_index, prompt, detected_language) \ if self.prompt_strategy else prompt result = model.transcribe(audio, \ language=self.language if self.language else detected_language, task=self.task, \ initial_prompt=initial_prompt, \ **decodeOptions ) # If we have a prompt strategy, we need to increment the current prompt if self.prompt_strategy: self.prompt_strategy.on_segment_finished(segment_index, prompt, detected_language, result) return result