Spaces:
Running
on
Zero
Running
on
Zero
import torch | |
class WhisperMixin:
    """Mixin adding Whisper feature extraction, transcription and embeddings.

    The host class must supply ``to(device)``, ``device``, ``clone()``,
    ``resample(sample_rate)`` and ``audio_data``; they are used by
    :meth:`get_whisper_features` but are not defined in this mixin.
    The Whisper processor and model are loaded lazily on first use.
    """

    # Set to True on the instance by setup_whisper(); the class-level False
    # lets the getters detect that the model has not been loaded yet.
    is_initialized = False

    def setup_whisper(
        self,
        pretrained_model_name_or_path: str = "openai/whisper-base.en",
        device: str = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    ):
        """Load the Whisper processor and model and attach them to ``self``.

        Parameters
        ----------
        pretrained_model_name_or_path : str
            Identifier or path handed to ``from_pretrained`` for both the
            processor and the model.
        device : str or torch.device
            Device the model is moved to. NOTE: the default is computed once,
            when the class body is executed, not on every call.
        """
        # Imported lazily so that transformers is only needed when Whisper
        # functionality is actually used.
        from transformers import WhisperForConditionalGeneration
        from transformers import WhisperProcessor

        self.whisper_device = device
        self.whisper_processor = WhisperProcessor.from_pretrained(
            pretrained_model_name_or_path
        )
        self.whisper_model = WhisperForConditionalGeneration.from_pretrained(
            pretrained_model_name_or_path
        ).to(self.whisper_device)
        self.is_initialized = True

    def get_whisper_features(self) -> torch.Tensor:
        """Preprocess audio signal as per the whisper model's training config.

        Returns
        -------
        torch.Tensor
            The input features of the audio signal, as returned by the
            Whisper processor. Shape: (1, channels, seq_len) per the original
            documentation; the leading dimension presumably follows the
            number of items in the signal's batch.
        """
        if not self.is_initialized:
            self.setup_whisper()

        target_rate = self.whisper_processor.feature_extractor.sampling_rate

        signal = self.to(self.device)
        # Work on a clone so resampling does not alter this signal, and keep
        # only index 0 of the second axis of ``audio_data``
        # (presumably the first channel -- confirm against the host class).
        first_channel = signal.clone().resample(target_rate).audio_data[:, 0, :]

        # ``Tensor.numpy()`` only works for CPU tensors that do not require
        # grad, so detach and move to host memory before converting.
        # ``list`` splits the array along its first axis into one entry per
        # batch item, which is what the processor receives.
        raw_speech = list(first_channel.detach().cpu().numpy())

        with torch.inference_mode():
            input_features = self.whisper_processor(
                raw_speech,
                sampling_rate=target_rate,
                return_tensors="pt",
            ).input_features

        return input_features

    def get_whisper_transcript(self) -> str:
        """Get the transcript of the audio signal using the whisper model.

        Only the transcript of the first item in the batch is returned.

        Returns
        -------
        str
            The transcript of the audio signal, including special tokens such as <|startoftranscript|> and <|endoftext|>.
        """
        if not self.is_initialized:
            self.setup_whisper()

        input_features = self.get_whisper_features()

        with torch.inference_mode():
            # The processor returns features that must be moved to the
            # model's device before generation.
            input_features = input_features.to(self.whisper_device)
            generated_ids = self.whisper_model.generate(inputs=input_features)

        transcription = self.whisper_processor.batch_decode(generated_ids)
        return transcription[0]

    def get_whisper_embeddings(self) -> torch.Tensor:
        """Get the last hidden state embeddings of the audio signal using the whisper model.

        Returns
        -------
        torch.Tensor
            The Whisper embeddings of the audio signal. Shape: (1, seq_len, hidden_size)
        """
        if not self.is_initialized:
            self.setup_whisper()

        input_features = self.get_whisper_features()
        encoder = self.whisper_model.get_encoder()

        with torch.inference_mode():
            input_features = input_features.to(self.whisper_device)
            embeddings = encoder(input_features)

        return embeddings.last_hidden_state