EzAudio / audiotools /core /whisper.py
OpenSound's picture
Upload 33 files
71de706 verified
raw
history blame
3.1 kB
import torch
class WhisperMixin:
    """Mixin adding Whisper feature extraction, transcription and embeddings.

    The host class must provide ``to``, ``device``, ``clone``, ``resample``
    and ``audio_data``, which ``get_whisper_features`` uses to prepare the
    waveform. The Whisper processor and model are loaded lazily the first
    time one of the ``get_whisper_*`` methods is called.
    """

    # Set to True on the instance once ``setup_whisper`` has completed.
    is_initialized = False

    def setup_whisper(
        self,
        pretrained_model_name_or_path: str = "openai/whisper-base.en",
        device: "str | torch.device | None" = None,
    ):
        """Load the Whisper processor and model and attach them to ``self``.

        Parameters
        ----------
        pretrained_model_name_or_path : str, optional
            Identifier or path passed to ``from_pretrained`` for both the
            processor and the model, by default "openai/whisper-base.en"
        device : str or torch.device, optional
            Device the model is moved to. When ``None`` (the default), CUDA
            is used if available, otherwise the CPU.
        """
        from transformers import WhisperForConditionalGeneration
        from transformers import WhisperProcessor

        # Resolve the default lazily so that importing this module does not
        # probe CUDA; the resulting device is the same as before.
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.whisper_device = device
        self.whisper_processor = WhisperProcessor.from_pretrained(
            pretrained_model_name_or_path
        )
        self.whisper_model = WhisperForConditionalGeneration.from_pretrained(
            pretrained_model_name_or_path
        ).to(self.whisper_device)
        self.is_initialized = True

    def get_whisper_features(self) -> torch.Tensor:
        """Preprocess the audio signal with the Whisper processor.

        The signal is cloned, resampled to the processor's sampling rate and
        only the first channel of every batch item is passed on.

        Returns
        -------
        torch.Tensor
            The ``input_features`` produced by the Whisper processor.
            NOTE(review): presumably shaped (batch, n_mels, n_frames) for
            Hugging Face Whisper processors; confirm for the checkpoint used.
        """
        if not self.is_initialized:
            self.setup_whisper()

        sampling_rate = self.whisper_processor.feature_extractor.sampling_rate

        signal = self.to(self.device)
        # ``Tensor.numpy()`` only works on CPU tensors that do not require
        # grad, so detach and move to the CPU before converting.
        waveforms = (
            signal.clone()
            .resample(sampling_rate)
            .audio_data[:, 0, :]
            .detach()
            .cpu()
            .numpy()
        )
        # One 1-D waveform per batch item, as the processor expects a list.
        raw_speech = list(waveforms)

        with torch.inference_mode():
            input_features = self.whisper_processor(
                raw_speech,
                sampling_rate=sampling_rate,
                return_tensors="pt",
            ).input_features

        return input_features

    def get_whisper_transcript(self) -> str:
        """Get the transcript of the audio signal using the whisper model.

        Returns
        -------
        str
            The decoded transcript of the first item in the batch, including
            special tokens such as <|startoftranscript|> and <|endoftext|>.
        """
        if not self.is_initialized:
            self.setup_whisper()

        input_features = self.get_whisper_features()

        with torch.inference_mode():
            input_features = input_features.to(self.whisper_device)
            generated_ids = self.whisper_model.generate(inputs=input_features)

        transcription = self.whisper_processor.batch_decode(generated_ids)
        # Only the first batch item's transcript is returned.
        return transcription[0]

    def get_whisper_embeddings(self) -> torch.Tensor:
        """Get the encoder's last hidden state for the audio signal.

        Returns
        -------
        torch.Tensor
            The ``last_hidden_state`` of the Whisper encoder.
            NOTE(review): presumably shaped (batch, seq_len, hidden_size);
            confirm against the model in use.
        """
        if not self.is_initialized:
            self.setup_whisper()

        input_features = self.get_whisper_features()
        encoder = self.whisper_model.get_encoder()

        with torch.inference_mode():
            input_features = input_features.to(self.whisper_device)
            embeddings = encoder(input_features)

        return embeddings.last_hidden_state