Transcribe_V0.2/src/whisper/abstractWhisperContainer.py
import abc
from typing import List
from src.config import ModelConfig, VadInitialPromptMode
from src.hooks.progressListener import ProgressListener
from src.modelCache import GLOBAL_MODEL_CACHE, ModelCache

class AbstractWhisperCallback:
@abc.abstractmethod
def invoke(self, audio, segment_index: int, prompt: str, detected_language: str, progress_listener: ProgressListener = None):
"""
Peform the transcription of the given audio file or data.
Parameters
----------
audio: Union[str, np.ndarray, torch.Tensor]
The audio file to transcribe, or the audio data as a numpy array or torch tensor.
segment_index: int
The target language of the transcription. If not specified, the language will be inferred from the audio content.
task: str
The task - either translate or transcribe.
progress_listener: ProgressListener
A callback to receive progress updates.
"""
raise NotImplementedError()

    def _get_initial_prompt(self, initial_prompt: str, initial_prompt_mode: VadInitialPromptMode,
                            prompt: str, segment_index: int):
        # Combine the user-supplied initial prompt with the per-segment prompt, depending on the mode
        if (initial_prompt_mode == VadInitialPromptMode.PREPEND_ALL_SEGMENTS):
            return self._concat_prompt(initial_prompt, prompt)
        elif (initial_prompt_mode == VadInitialPromptMode.PREPREND_FIRST_SEGMENT):
            # Only the very first segment receives the initial prompt
            return self._concat_prompt(initial_prompt, prompt) if segment_index == 0 else prompt
        else:
            raise ValueError(f"Unknown initial prompt mode {initial_prompt_mode}")

    def _concat_prompt(self, prompt1, prompt2):
        # Join two prompts with a space, tolerating either being None
        if (prompt1 is None):
            return prompt2
        elif (prompt2 is None):
            return prompt1
        else:
            return prompt1 + " " + prompt2
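
# Illustrative sketch only (not part of the original project): a minimal concrete callback showing
# how invoke() is expected to use _get_initial_prompt(). The `transcribe_fn` argument is a
# hypothetical callable supplied by the caller; a real implementation would run the underlying
# Whisper model here and report progress through `progress_listener`.
class ExampleWhisperCallback(AbstractWhisperCallback):
    def __init__(self, transcribe_fn, initial_prompt: str = None,
                 initial_prompt_mode: VadInitialPromptMode = VadInitialPromptMode.PREPREND_FIRST_SEGMENT):
        self.transcribe_fn = transcribe_fn
        self.initial_prompt = initial_prompt
        self.initial_prompt_mode = initial_prompt_mode

    def invoke(self, audio, segment_index: int, prompt: str, detected_language: str,
               progress_listener: ProgressListener = None):
        # Combine the configured initial prompt with the rolling prompt for this segment
        effective_prompt = self._get_initial_prompt(self.initial_prompt, self.initial_prompt_mode,
                                                    prompt, segment_index)
        return self.transcribe_fn(audio, prompt=effective_prompt, language=detected_language)
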
class AbstractWhisperContainer:
def __init__(self, model_name: str, device: str = None, compute_type: str = "float16",
download_root: str = None,
cache: ModelCache = None, models: List[ModelConfig] = []):
self.model_name = model_name
self.device = device
self.compute_type = compute_type
self.download_root = download_root
self.cache = cache
# Will be created on demand
self.model = None
# List of known models
self.models = models

    def get_model(self):
        if self.model is None:
            if (self.cache is None):
                self.model = self._create_model()
            else:
                # Cache the model by name and device so repeated lookups reuse the same instance
                model_key = "WhisperContainer." + self.model_name + ":" + (self.device if self.device else '')
                self.model = self.cache.get(model_key, self._create_model)
        return self.model

    @abc.abstractmethod
    def _create_model(self):
        raise NotImplementedError()

    def ensure_downloaded(self):
        # Subclasses may override this to pre-download the model; the default is a no-op
        pass

    @abc.abstractmethod
    def create_callback(self, language: str = None, task: str = None, initial_prompt: str = None,
                        initial_prompt_mode: VadInitialPromptMode = VadInitialPromptMode.PREPREND_FIRST_SEGMENT,
                        **decodeOptions: dict) -> AbstractWhisperCallback:
        """
        Create a WhisperCallback object that can be used to transcribe audio files.

Parameters
----------
language: str
The target language of the transcription. If not specified, the language will be inferred from the audio content.
task: str
The task - either translate or transcribe.
initial_prompt: str
The initial prompt to use for the transcription.
initial_prompt_mode: VadInitialPromptMode
The mode to use for the initial prompt. If set to PREPEND_FIRST_SEGMENT, the initial prompt will be prepended to the first segment of audio.
If set to PREPEND_ALL_SEGMENTS, the initial prompt will be prepended to all segments of audio.
decodeOptions: dict
Additional options to pass to the decoder. Must be pickleable.
Returns
-------
A WhisperCallback object.
"""
raise NotImplementedError()

    # This is required for multiprocessing
def __getstate__(self):
return {
"model_name": self.model_name,
"device": self.device,
"download_root": self.download_root,
"models": self.models,
"compute_type": self.compute_type
}

    def __setstate__(self, state):
self.model_name = state["model_name"]
self.device = state["device"]
self.download_root = state["download_root"]
self.models = state["models"]
self.compute_type = state["compute_type"]
self.model = None
# Depickled objects must use the global cache
self.cache = GLOBAL_MODEL_CACHE
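
# Illustrative sketch only (not part of the original project): a stub container plus a small
# __main__ block showing the lazy model creation in get_model(), and the __getstate__/__setstate__
# round trip that makes containers safe to send to worker processes. The names `_StubModel` and
# `StubWhisperContainer` are hypothetical and exist purely for this demonstration.
class _StubModel:
    def __init__(self, name: str):
        self.name = name


class StubWhisperContainer(AbstractWhisperContainer):
    def _create_model(self):
        # A real subclass would load the actual Whisper model here
        return _StubModel(self.model_name)

    def create_callback(self, language: str = None, task: str = None, initial_prompt: str = None,
                        initial_prompt_mode: VadInitialPromptMode = VadInitialPromptMode.PREPREND_FIRST_SEGMENT,
                        **decodeOptions: dict) -> AbstractWhisperCallback:
        # Reuses the ExampleWhisperCallback sketch above with a no-op transcription function
        return ExampleWhisperCallback(lambda audio, prompt, language: None,
                                      initial_prompt, initial_prompt_mode)


if __name__ == "__main__":
    import pickle

    container = StubWhisperContainer("tiny", device="cpu", cache=GLOBAL_MODEL_CACHE)

    # The model is created on first use and reused afterwards
    assert container.get_model() is container.get_model()

    # Pickling keeps only the lightweight configuration; the restored copy has no model yet
    # and falls back to GLOBAL_MODEL_CACHE, which is what the multiprocessing code relies on
    restored = pickle.loads(pickle.dumps(container))
    assert restored.model is None
    assert restored.cache is GLOBAL_MODEL_CACHE
    print(restored.model_name, restored.device, restored.compute_type)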