import abc
from typing import Any, Callable, List

from src.config import ModelConfig, VadInitialPromptMode

from src.hooks.progressListener import ProgressListener
from src.modelCache import GLOBAL_MODEL_CACHE, ModelCache
from src.prompts.abstractPromptStrategy import AbstractPromptStrategy

class AbstractWhisperCallback:
    def __init__(self):
        pass

    @abc.abstractmethod
    def invoke(self, audio, segment_index: int, prompt: str, detected_language: str, progress_listener: ProgressListener = None):
        """
        Perform the transcription of the given audio file or data.

        Parameters
        ----------
        audio: Union[str, np.ndarray, torch.Tensor]
            The audio file to transcribe, or the audio data as a numpy array or torch tensor.
        segment_index: int
            The index of the audio segment that is being transcribed.
        prompt: str
            The prompt to use for the transcription.
        detected_language: str
            The language that has been detected for the audio, if any.
        progress_listener: ProgressListener
            A callback to receive progress updates.
        """
        raise NotImplementedError()

class LambdaWhisperCallback(AbstractWhisperCallback):
    def __init__(self, callback_lambda: Callable[[Any, int, str, str, ProgressListener], None]):
        super().__init__()
        self.callback_lambda = callback_lambda

    def invoke(self, audio, segment_index: int, prompt: str, detected_language: str, progress_listener: ProgressListener = None):
        return self.callback_lambda(audio, segment_index, prompt, detected_language, progress_listener)

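# Illustrative sketch: LambdaWhisperCallback simply forwards invoke() to the callable it
# wraps. The transcribe_segment function below is hypothetical and only shows the expected
# signature; it is not part of this module.
#
#   def transcribe_segment(audio, segment_index, prompt, detected_language, progress_listener=None):
#       ...  # run the actual model here and return its result
#
#   callback = LambdaWhisperCallback(transcribe_segment)
#   result = callback.invoke(audio_data, segment_index=0, prompt="", detected_language=None)
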
class AbstractWhisperContainer:
    def __init__(self, model_name: str, device: str = None, compute_type: str = "float16",
                 download_root: str = None,
                 cache: ModelCache = None, models: List[ModelConfig] = []):
        self.model_name = model_name
        self.device = device
        self.compute_type = compute_type
        self.download_root = download_root
        self.cache = cache

        # The model is loaded on demand by get_model
        self.model = None

        # List of known model configurations
        self.models = models

    def get_model(self):
        if self.model is None:
            if self.cache is None:
                self.model = self._create_model()
            else:
                # Reuse an already loaded model if possible, keyed on model name and device
                model_key = "WhisperContainer." + self.model_name + ":" + (self.device if self.device else '')
                self.model = self.cache.get(model_key, self._create_model)
        return self.model

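# Illustrative sketch: containers that share a ModelCache and use the same model name and
# device will also share the loaded model, since get_model() builds the same cache key for
# both (assuming ModelCache.get returns the existing instance for a known key). The
# subclass SomeWhisperContainer is assumed here purely for illustration.
#
#   shared_cache = ModelCache()
#   container_a = SomeWhisperContainer("medium", device="cuda", cache=shared_cache)
#   container_b = SomeWhisperContainer("medium", device="cuda", cache=shared_cache)
#   assert container_a.get_model() is container_b.get_model()
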
    @abc.abstractmethod
    def _create_model(self):
        raise NotImplementedError()

    def ensure_downloaded(self):
        pass

    @abc.abstractmethod
    def create_callback(self, languageCode: str = None, task: str = None,
                        prompt_strategy: AbstractPromptStrategy = None,
                        **decodeOptions: dict) -> AbstractWhisperCallback:
        """
        Create a WhisperCallback object that can be used to transcribe audio files.

        Parameters
        ----------
        languageCode: str
            The target language code of the transcription. If not specified, the language will be inferred from the audio content.
        task: str
            The task - either translate or transcribe.
        prompt_strategy: AbstractPromptStrategy
            The prompt strategy to use for the transcription.
        decodeOptions: dict
            Additional options to pass to the decoder. Must be pickleable.

        Returns
        -------
        A WhisperCallback object.
        """
        raise NotImplementedError()

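# Illustrative sketch: how a caller might drive a concrete container. The subclass
# SomeWhisperContainer, the prompt strategy, the decode options and the audio segments are
# assumed here purely for illustration; only create_callback() and invoke() come from this
# module.
#
#   container = SomeWhisperContainer("medium", device="cuda")
#   callback = container.create_callback(languageCode="en", task="transcribe",
#                                        prompt_strategy=my_prompt_strategy, beam_size=5)
#   for index, segment in enumerate(audio_segments):
#       result = callback.invoke(segment, segment_index=index, prompt="", detected_language=None)
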
    def __getstate__(self):
        # Only pickle the lightweight configuration; the loaded model and the cache are excluded
        return {
            "model_name": self.model_name,
            "device": self.device,
            "download_root": self.download_root,
            "models": self.models,
            "compute_type": self.compute_type
        }

    def __setstate__(self, state):
        self.model_name = state["model_name"]
        self.device = state["device"]
        self.download_root = state["download_root"]
        self.models = state["models"]
        self.compute_type = state["compute_type"]
        self.model = None

        # Depickled objects must use the global cache
        self.cache = GLOBAL_MODEL_CACHE
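# Note on pickling: __getstate__ deliberately omits the loaded model and the cache, so a
# container can be sent to a worker process cheaply. After unpickling, __setstate__ binds
# the container to GLOBAL_MODEL_CACHE and leaves self.model as None, so the model is
# re-created (or fetched from the global cache) on the next get_model() call.
#
# Illustrative sketch, assuming a concrete subclass SomeWhisperContainer:
#
#   import pickle
#   container = SomeWhisperContainer("medium", device="cuda")
#   restored = pickle.loads(pickle.dumps(container))
#   model = restored.get_model()  # loaded lazily after unpickling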