File size: 4,309 Bytes
295de00 052fe7e 74b7d77 8031785 295de00 74b7d77 295de00 74b7d77 052fe7e 74b7d77 295de00 052fe7e 295de00 33ee1bb 295de00 33ee1bb 295de00 1e744c4 74b7d77 8031785 295de00 1e744c4 295de00 74b7d77 295de00 33ee1bb 295de00 33ee1bb 295de00 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 |
import abc
from typing import Any, Callable, List
from src.config import ModelConfig, VadInitialPromptMode
from src.hooks.progressListener import ProgressListener
from src.modelCache import GLOBAL_MODEL_CACHE, ModelCache
from src.prompts.abstractPromptStrategy import AbstractPromptStrategy
class AbstractWhisperCallback:
    """
    Interface for an object that transcribes a single piece of audio.

    Concrete implementations override ``invoke``. Note that this class does not use
    ``abc.ABCMeta`` (nor inherit from ``abc.ABC``), so ``@abc.abstractmethod`` is not
    enforced when instantiating; calling ``invoke`` on an implementation that does not
    override it raises ``NotImplementedError`` instead.
    """

    def __init__(self):
        pass

    @abc.abstractmethod
    def invoke(self, audio, segment_index: int, prompt: str, detected_language: str, progress_listener: ProgressListener = None):
        """
        Perform the transcription of the given audio file or data.

        Parameters
        ----------
        audio: Union[str, np.ndarray, torch.Tensor]
            The audio file to transcribe, or the audio data as a numpy array or torch tensor.
        segment_index: int
            Index of the audio segment being transcribed (presumably its position in the
            sequence of segments processed by the caller - confirm against callers).
        prompt: str
            Prompt text to use for this segment (presumably handed to the model as its
            initial prompt - confirm against implementations).
        detected_language: str
            Language detected for the audio (presumably None when not yet known -
            confirm against callers).
        progress_listener: ProgressListener
            A callback to receive progress updates. Optional.

        Returns
        -------
        Implementation-defined; this base method never returns normally.

        Raises
        ------
        NotImplementedError
            Always, when not overridden by a subclass.
        """
        raise NotImplementedError()
class LambdaWhisperCallback(AbstractWhisperCallback):
    """
    Whisper callback that delegates ``invoke`` to a plain callable.
    """

    def __init__(self, callback_lambda: Callable[[Any, int, str, str, ProgressListener], Any]):
        """
        Parameters
        ----------
        callback_lambda: Callable[[Any, int, str, str, ProgressListener], Any]
            Called as ``callback_lambda(audio, segment_index, prompt, detected_language,
            progress_listener)`` on every ``invoke``. Its return value is passed back to
            the caller of ``invoke`` unchanged, so it is typed ``Any`` rather than ``None``.
        """
        super().__init__()
        self.callback_lambda = callback_lambda

    def invoke(self, audio, segment_index: int, prompt: str, detected_language: str, progress_listener: ProgressListener = None):
        """Forward all arguments positionally to ``callback_lambda`` and return its result."""
        return self.callback_lambda(audio, segment_index, prompt, detected_language, progress_listener)
class AbstractWhisperContainer:
    """
    Base class for objects that lazily create (and optionally cache) a Whisper model
    and hand out callbacks that run it.

    Subclasses implement ``_create_model`` and ``create_callback``. Instances support
    pickling via ``__getstate__``/``__setstate__`` (required for multiprocessing); the
    model itself and the cache are never pickled.

    Note: this class does not use ``abc.ABCMeta``, so ``@abc.abstractmethod`` is not
    enforced when instantiating; the abstract methods raise ``NotImplementedError``
    when called without being overridden.
    """

    def __init__(self, model_name: str, device: str = None, compute_type: str = "float16",
                 download_root: str = None,
                 cache: ModelCache = None, models: List[ModelConfig] = None):
        """
        Parameters
        ----------
        model_name: str
            Name of the model. Part of the cache key.
        device: str
            Device name, or None. Part of the cache key.
        compute_type: str
            Compute type of the model (default "float16"). Part of the cache key.
        download_root: str
            Directory for model downloads, or None.
        cache: ModelCache
            Cache used to share created models between containers. If None, this
            container creates its own model.
        models: List[ModelConfig]
            List of known models. Defaults to a new empty list per instance.
        """
        self.model_name = model_name
        self.device = device
        self.compute_type = compute_type
        self.download_root = download_root
        self.cache = cache

        # Will be created on demand
        self.model = None

        # List of known models. A fresh list is created per instance: a `[]` default
        # would be a single list object shared (and mutable) across all containers.
        self.models = [] if models is None else models

    def get_model(self):
        """
        Return the model, creating it on first use and memoizing it on the instance.

        Without a cache, ``_create_model`` is called directly. With a cache, the model
        is fetched via ``cache.get(key, self._create_model)`` (the factory is presumably
        only invoked on a cache miss). The key includes the model name, device and
        compute type, so containers differing in any of these never share a model.
        """
        if self.model is None:
            if self.cache is None:
                self.model = self._create_model()
            else:
                model_key = ("WhisperContainer." + self.model_name
                             + ":" + (self.device or '')
                             + ":" + (self.compute_type or ''))
                self.model = self.cache.get(model_key, self._create_model)
        return self.model

    @abc.abstractmethod
    def _create_model(self):
        """Create and return a new model instance. Must be implemented by subclasses."""
        raise NotImplementedError()

    def ensure_downloaded(self):
        """
        No-op by default. Presumably a hook for subclasses to fetch model files ahead
        of time - confirm against subclasses.
        """
        pass

    @abc.abstractmethod
    def create_callback(self, languageCode: str = None, task: str = None,
                        prompt_strategy: AbstractPromptStrategy = None,
                        **decodeOptions: dict) -> AbstractWhisperCallback:
        """
        Create a WhisperCallback object that can be used to transcribe audio files.

        Parameters
        ----------
        languageCode: str
            The target language code of the transcription. If not specified, the language will be inferred from the audio content.
        task: str
            The task - either translate or transcribe.
        prompt_strategy: AbstractPromptStrategy
            The prompt strategy to use for the transcription.
        decodeOptions: dict
            Additional options to pass to the decoder. Must be pickleable.

        Returns
        -------
        A WhisperCallback object.
        """
        raise NotImplementedError()

    # This is required for multiprocessing
    def __getstate__(self):
        """
        Return the picklable configuration of this container.

        The model and the cache are deliberately excluded; ``__setstate__`` resets the
        model and attaches the global cache instead.
        """
        return {
            "model_name": self.model_name,
            "device": self.device,
            "download_root": self.download_root,
            "models": self.models,
            "compute_type": self.compute_type
        }

    def __setstate__(self, state):
        """Restore the configuration produced by ``__getstate__`` (``__init__`` is not run)."""
        self.model_name = state["model_name"]
        self.device = state["device"]
        self.download_root = state["download_root"]
        self.models = state["models"]
        self.compute_type = state["compute_type"]
        self.model = None
        # Depickled objects must use the global cache
        self.cache = GLOBAL_MODEL_CACHE