Spaces:
Runtime error
Runtime error
# External programs
import os
import sys
from typing import List, Optional

import whisper
from whisper import Whisper

from src.config import ModelConfig
from src.hooks.whisperProgressHook import ProgressListener, create_progress_listener_handle
from src.modelCache import GLOBAL_MODEL_CACHE, ModelCache
class WhisperContainer:
    """
    Holds the settings needed to load a Whisper model and loads the model lazily on first use.

    Only the configuration (model name, device, download root, known models) is pickled; the
    loaded model is dropped and, after unpickling, the global model cache is used instead.
    """

    def __init__(self, model_name: str, device: Optional[str] = None, download_root: Optional[str] = None,
                 cache: Optional["ModelCache"] = None, models: Optional[List["ModelConfig"]] = None):
        """
        Parameters
        ----------
        model_name: str
            Name of the model; matched against `models` and against Whisper's official model list.
        device: str
            Device passed to whisper.load_model (None = Whisper's default).
        download_root: str
            Download directory passed to Whisper. When None, ensure_downloaded() uses ~/.cache/whisper.
        cache: ModelCache
            Optional cache consulted by get_model() before loading a new model.
        models: List[ModelConfig]
            Known model configurations. Defaults to a new empty list for every container.
        """
        self.model_name = model_name
        self.device = device
        self.download_root = download_root
        self.cache = cache

        # Will be created on demand
        self.model = None

        # List of known models. A fresh list per instance: a `[]` default would be shared by
        # every container created without an explicit list.
        self.models = models if models is not None else []

    def get_model(self):
        """
        Return the Whisper model, creating it on the first call.

        Without a cache the model is created directly. With a cache, the model is requested via
        `cache.get(key, self._create_model)` using a key built from the model name and device
        (presumably the cache calls the factory only on a miss - behavior defined in ModelCache).
        """
        if self.model is None:
            if self.cache is None:
                self.model = self._create_model()
            else:
                model_key = "WhisperContainer." + self.model_name + ":" + (self.device if self.device else '')
                self.model = self.cache.get(model_key, self._create_model)
        return self.model

    def ensure_downloaded(self) -> bool:
        """
        Ensure that the model is downloaded. This is useful if you want to ensure that the model is downloaded before
        passing the container to a subprocess.

        Returns True on success, False on failure. Failures are printed rather than raised.
        """
        # Warning: Using private API here
        try:
            root_dir = self.download_root
            if root_dir is None:
                root_dir = os.path.join(os.path.expanduser("~"), ".cache", "whisper")

            if self.model_name in whisper._MODELS:
                whisper._download(whisper._MODELS[self.model_name], root_dir, False)
            else:
                # If the model is not in the official list, see if it needs to be downloaded
                model_config = self.get_model_config()
                if model_config is None:
                    # Report this explicitly instead of failing with an AttributeError on None.
                    print("Error pre-downloading model: no configuration found for model " + self.model_name)
                    return False
                model_config.download_url(root_dir)
            return True
        except Exception as e:
            # Given that the API is private, it could change at any time. We don't want to crash the program
            print("Error pre-downloading model: " + str(e))
            return False

    def get_model_config(self) -> Optional["ModelConfig"]:
        """
        Return the first entry in `models` whose name equals `model_name`, or None if there is none.
        """
        return next((model for model in self.models if model.name == self.model_name), None)

    def _create_model(self):
        """
        Load the Whisper model described by this container.

        Models with a registered configuration are resolved through `ModelConfig.download_url`.
        Otherwise the model name is passed to whisper.load_model unchanged, so Whisper can resolve
        it (an official model name or a checkpoint path).
        """
        print("Loading whisper model " + self.model_name)

        model_config = self.get_model_config()

        if model_config is None:
            model_path = self.model_name
        else:
            # Note that the model will not be downloaded in the case of an official Whisper model
            model_path = model_config.download_url(self.download_root)
        return whisper.load_model(model_path, device=self.device, download_root=self.download_root)

    def create_callback(self, language: Optional[str] = None, task: Optional[str] = None,
                        initial_prompt: Optional[str] = None, **decodeOptions: dict):
        """
        Create a WhisperCallback object that can be used to transcribe audio files.

        Parameters
        ----------
        language: str
            The target language of the transcription. If not specified, the language will be inferred from the audio content.
        task: str
            The task - either translate or transcribe.
        initial_prompt: str
            The initial prompt to use for the transcription.
        decodeOptions: dict
            Additional options to pass to the decoder. Must be pickleable.

        Returns
        -------
        A WhisperCallback object.
        """
        return WhisperCallback(self, language=language, task=task, initial_prompt=initial_prompt, **decodeOptions)

    # This is required for multiprocessing
    def __getstate__(self):
        # The loaded model and the cache are deliberately not pickled.
        return { "model_name": self.model_name, "device": self.device, "download_root": self.download_root, "models": self.models }

    def __setstate__(self, state):
        self.model_name = state["model_name"]
        self.device = state["device"]
        self.download_root = state["download_root"]
        self.models = state["models"]
        self.model = None
        # Depickled objects must use the global cache
        self.cache = GLOBAL_MODEL_CACHE
class WhisperCallback:
    """
    Transcribes audio with the model held by a WhisperContainer, using a fixed language, task,
    initial prompt and extra decode options.
    """

    def __init__(self, model_container: "WhisperContainer", language: Optional[str] = None, task: Optional[str] = None,
                 initial_prompt: Optional[str] = None, **decodeOptions: dict):
        self.model_container = model_container
        self.language = language
        self.task = task
        self.initial_prompt = initial_prompt
        # Forwarded verbatim as keyword arguments to model.transcribe().
        self.decodeOptions = decodeOptions

    def invoke(self, audio, segment_index: int, prompt: Optional[str], detected_language: Optional[str],
               progress_listener: Optional["ProgressListener"] = None):
        """
        Perform the transcription of the given audio file or data.

        Parameters
        ----------
        audio: Union[str, np.ndarray, torch.Tensor]
            The audio file to transcribe, or the audio data as a numpy array or torch tensor.
        segment_index: int
            Index of the segment being transcribed. For segment 0 the callback's initial prompt is
            prepended to `prompt`. For later segments `prompt` is used as-is.
        prompt: str
            The prompt to use for the transcription. May be None.
        detected_language: str
            The detected language of the audio file. Used only when the callback has no language set.
        progress_listener: ProgressListener
            Optional. When given, the transcription runs inside
            `create_progress_listener_handle(progress_listener)`.

        Returns
        -------
        The result of the Whisper call.
        """
        model = self.model_container.get_model()

        if progress_listener is None:
            return self._transcribe(model, audio, segment_index, prompt, detected_language)

        with create_progress_listener_handle(progress_listener):
            return self._transcribe(model, audio, segment_index, prompt, detected_language)

    def _transcribe(self, model: "Whisper", audio, segment_index: int, prompt: Optional[str],
                    detected_language: Optional[str]):
        """Call `model.transcribe` with this callback's options merged with the per-call values."""
        # An explicitly configured language takes precedence over the detected one.
        language = self.language if self.language else detected_language
        # Only the first segment receives the user-supplied initial prompt.
        initial_prompt = self._concat_prompt(self.initial_prompt, prompt) if segment_index == 0 else prompt

        return model.transcribe(
            audio,
            language=language,
            task=self.task,
            initial_prompt=initial_prompt,
            **self.decodeOptions,
        )

    def _concat_prompt(self, prompt1, prompt2):
        """
        Join two optional prompts with a single space.

        Returns the other prompt if either one is None, and None if both are.
        """
        if (prompt1 is None):
            return prompt2
        elif (prompt2 is None):
            return prompt1
        else:
            return prompt1 + " " + prompt2