whisper-webui / src /whisper /abstractWhisperContainer.py
aadnk's picture
Add initial prompt mode. GITLAB #7
8031785
raw history blame
No virus
4.88 kB
import abc
from typing import List
from src.config import ModelConfig, VadInitialPromptMode
from src.hooks.progressListener import ProgressListener
from src.modelCache import GLOBAL_MODEL_CACHE, ModelCache
class AbstractWhisperCallback:
    """
    Base class for a callable that transcribes a single segment of audio.

    Subclasses must implement invoke(). The prompt helpers below
    (_get_initial_prompt / _concat_prompt) are shared by implementations.

    NOTE: this class does not derive from abc.ABC, so @abc.abstractmethod is not
    enforced at instantiation time; a subclass that does not override invoke()
    only fails when invoke() is actually called (NotImplementedError).
    """

    @abc.abstractmethod
    def invoke(self, audio, segment_index: int, prompt: str, detected_language: str, progress_listener: ProgressListener = None):
        """
        Perform the transcription of the given audio file or data.

        Parameters
        ----------
        audio: Union[str, np.ndarray, torch.Tensor]
            The audio file to transcribe, or the audio data as a numpy array or torch tensor.
        segment_index: int
            The index of the audio segment being transcribed (0 is the first segment).
        prompt: str
            The prompt to condition this segment on, or None.
        detected_language: str
            The language detected for the audio. Presumably None when it has not
            been detected yet - confirm against callers.
        progress_listener: ProgressListener
            A callback to receive progress updates.

        Returns
        -------
        Defined by the concrete implementation.
        """
        raise NotImplementedError()

    def _get_initial_prompt(self, initial_prompt: str, initial_prompt_mode: VadInitialPromptMode,
                            prompt: str, segment_index: int):
        """
        Combine the user-supplied initial prompt with the prompt of the current segment.

        Parameters
        ----------
        initial_prompt: str
            The user-supplied initial prompt, or None.
        initial_prompt_mode: VadInitialPromptMode
            PREPEND_ALL_SEGMENTS prepends initial_prompt to every segment's prompt;
            PREPREND_FIRST_SEGMENT prepends it only when segment_index == 0.
        prompt: str
            The prompt of the current segment, or None.
        segment_index: int
            Index of the current segment.

        Returns
        -------
        The resulting prompt; None if neither prompt is set.

        Raises
        ------
        ValueError
            If initial_prompt_mode is not a recognized mode.
        """
        if initial_prompt_mode == VadInitialPromptMode.PREPEND_ALL_SEGMENTS:
            return self._concat_prompt(initial_prompt, prompt)
        if initial_prompt_mode == VadInitialPromptMode.PREPREND_FIRST_SEGMENT:
            # Only the very first segment receives the initial prompt.
            return self._concat_prompt(initial_prompt, prompt) if segment_index == 0 else prompt
        raise ValueError(f"Unknown initial prompt mode {initial_prompt_mode}")

    def _concat_prompt(self, prompt1, prompt2):
        """
        Join two prompts with a single space, skipping a missing one.

        None and the empty string are both treated as "no prompt", so joining
        never introduces an extra leading or trailing space. If prompt1 is missing,
        prompt2 is returned unchanged (and vice versa).
        """
        if not prompt1:
            return prompt2
        if not prompt2:
            return prompt1
        return prompt1 + " " + prompt2
class AbstractWhisperContainer:
    """
    Base class that lazily creates and hands out a Whisper model.

    Subclasses implement _create_model() and create_callback(). get_model() creates
    the model on first use and, when a ModelCache is supplied, shares it through
    that cache.

    NOTE: this class does not derive from abc.ABC, so the @abc.abstractmethod
    markers are not enforced at instantiation time; an unimplemented method raises
    NotImplementedError only when it is called.
    """

    def __init__(self, model_name: str, device: str = None, compute_type: str = "float16",
                 download_root: str = None,
                 cache: ModelCache = None, models: List[ModelConfig] = None):
        """
        Parameters
        ----------
        model_name: str
            Name of the model to load.
        device: str
            Device to load the model on, or None.
        compute_type: str
            Compute type passed on to the implementation (default "float16").
        download_root: str
            Directory for downloaded model files, or None.
        cache: ModelCache
            Optional cache used by get_model() to share model instances.
            If None, this container creates its own model.
        models: List[ModelConfig]
            Known model configurations. Defaults to a new, empty list.
        """
        self.model_name = model_name
        self.device = device
        self.compute_type = compute_type
        self.download_root = download_root
        self.cache = cache

        # Will be created on demand by get_model()
        self.model = None

        # List of known models. Built per instance: a `[]` default argument would
        # be a single list object shared by every container.
        self.models = models if models is not None else []

    def get_model(self):
        """
        Return the model, creating it on first use.

        Without a cache, _create_model() is called directly. With a cache, the model
        is requested from it under a key built from the model name, device and compute
        type, with _create_model passed as the factory. Containers that differ in any
        of these settings therefore do not receive each other's model.
        """
        if self.model is None:
            if self.cache is None:
                self.model = self._create_model()
            else:
                model_key = ("WhisperContainer." + self.model_name
                             + ":" + (self.device if self.device else '')
                             + ":" + (self.compute_type if self.compute_type else ''))
                self.model = self.cache.get(model_key, self._create_model)
        return self.model

    @abc.abstractmethod
    def _create_model(self):
        """Create and return a new model instance. Must be implemented by subclasses."""
        raise NotImplementedError()

    def ensure_downloaded(self):
        """
        Hook for subclasses that need to fetch model files ahead of time.

        The base implementation does nothing.
        """
        pass

    @abc.abstractmethod
    def create_callback(self, language: str = None, task: str = None, initial_prompt: str = None,
                        initial_prompt_mode: VadInitialPromptMode = VadInitialPromptMode.PREPREND_FIRST_SEGMENT,
                        **decodeOptions: dict) -> AbstractWhisperCallback:
        """
        Create a WhisperCallback object that can be used to transcribe audio files.

        Parameters
        ----------
        language: str
            The target language of the transcription. If not specified, the language will be inferred from the audio content.
        task: str
            The task - either translate or transcribe.
        initial_prompt: str
            The initial prompt to use for the transcription.
        initial_prompt_mode: VadInitialPromptMode
            The mode to use for the initial prompt. If set to PREPREND_FIRST_SEGMENT (the default), the initial prompt
            will be prepended to the first segment of audio. If set to PREPEND_ALL_SEGMENTS, the initial prompt will be
            prepended to all segments of audio.
        decodeOptions: dict
            Additional options to pass to the decoder. Must be pickleable.

        Returns
        -------
        A WhisperCallback object.
        """
        raise NotImplementedError()

    # This is required for multiprocessing
    def __getstate__(self):
        """
        Return the picklable state of the container.

        The loaded model and the cache are deliberately excluded; the model is
        recreated lazily by get_model() after unpickling.
        """
        return {
            "model_name": self.model_name,
            "device": self.device,
            "download_root": self.download_root,
            "models": self.models,
            "compute_type": self.compute_type
        }

    def __setstate__(self, state):
        """Restore the state produced by __getstate__, with no model loaded yet."""
        self.model_name = state["model_name"]
        self.device = state["device"]
        self.download_root = state["download_root"]
        self.models = state["models"]
        self.compute_type = state["compute_type"]
        self.model = None
        # Depickled objects must use the global cache
        self.cache = GLOBAL_MODEL_CACHE