import torch

from speechbrain.pretrained import Pretrained


class WhisperASR(Pretrained):
    """A ready-to-use Whisper ASR model.

    The class can be used to run only the encoder (encode_batch()) or the
    entire encoder-decoder Whisper model (transcribe_batch()) to transcribe
    speech. The given YAML must contain the fields specified in the
    *_NEEDED[] lists.

    Example
    -------
    >>> from speechbrain.pretrained.interfaces import foreign_class
    >>> tmpdir = getfixture("tmpdir")
    >>> asr_model = foreign_class(source="hf",
    ...     pymodule_file="custom_interface.py",
    ...     classname="WhisperASR",
    ...     hparams_file='hparams.yaml',
    ...     savedir=tmpdir,
    ... )
    >>> asr_model.transcribe_file("tests/samples/example2.wav")
    """

    HPARAMS_NEEDED = ["language"]
    MODULES_NEEDED = ["whisper", "decoder"]

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.tokenizer = self.hparams.whisper.tokenizer
        self.tokenizer.set_prefix_tokens(self.hparams.language, "transcribe", False)
        self.hparams.decoder.set_decoder_input_tokens(
            self.tokenizer.prefix_tokens
        )

    def transcribe_file(self, path):
        """Transcribes the given audio file into a sequence of words.

        Arguments
        ---------
        path : str
            Path to the audio file to transcribe.

        Returns
        -------
        list
            The transcription produced by this ASR system, batched
            (a single entry for the single input file).
        """
        waveform = self.load_audio(path)
        # Fake a batch:
        batch = waveform.unsqueeze(0)
        rel_length = torch.tensor([1.0])
        predicted_words, predicted_tokens = self.transcribe_batch(
            batch, rel_length
        )
        return predicted_words

    def encode_batch(self, wavs, wav_lens):
        """Encodes the input audio into a sequence of hidden states.

        The waveforms should already be in the model's desired format.
        You can call:
        ``normalized = self.audio_normalizer(signal, sample_rate)``
        to get a correctly converted signal in most cases.

        Arguments
        ---------
        wavs : torch.tensor
            Batch of waveforms [batch, time, channels].
        wav_lens : torch.tensor
            Lengths of the waveforms relative to the longest one in the
            batch, tensor of shape [batch]. The longest one should have
            relative length 1.0 and the others len(waveform) / max_length.
            Used for ignoring padding.

        Returns
        -------
        torch.tensor
            The encoded batch.
        """
        wavs = wavs.float()
        wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
        encoder_out = self.mods.whisper.forward_encoder(wavs)
        return encoder_out

    def transcribe_batch(self, wavs, wav_lens):
        """Transcribes the input audio into a sequence of words.

        The waveforms should already be in the model's desired format.
        You can call:
        ``normalized = self.audio_normalizer(signal, sample_rate)``
        to get a correctly converted signal in most cases.

        Arguments
        ---------
        wavs : torch.tensor
            Batch of waveforms [batch, time, channels].
        wav_lens : torch.tensor
            Lengths of the waveforms relative to the longest one in the
            batch, tensor of shape [batch]. The longest one should have
            relative length 1.0 and the others len(waveform) / max_length.
            Used for ignoring padding.

        Returns
        -------
        list
            Each waveform in the batch transcribed.
        tensor
            Each predicted token id.
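
        Example
        -------
        A minimal sketch (marked to be skipped as a doctest): it assumes
        ``asr_model`` was loaded via ``foreign_class`` as in the class-level
        example, and that the audio already matches the model's expected
        sampling rate.

        >>> import torchaudio  # doctest: +SKIP
        >>> wavs, sr = torchaudio.load("tests/samples/example2.wav")  # doctest: +SKIP
        >>> words, tokens = asr_model.transcribe_batch(  # doctest: +SKIP
        ...     wavs, torch.tensor([1.0])
        ... )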
""" with torch.no_grad(): wav_lens = wav_lens.to(self.device) encoder_out = self.encode_batch(wavs, wav_lens) predicted_tokens, scores = self.mods.decoder(encoder_out, wav_lens) predicted_words = self.tokenizer.batch_decode( predicted_tokens, skip_special_tokens=True) if self.hparams.normalized_transcripts: predicted_words = [ self.tokenizer._normalize(text).split(" ") for text in predicted_words ] return predicted_words, predicted_tokens def forward(self, wavs, wav_lens): """Runs full transcription - note: no gradients through decoding""" return self.transcribe_batch(wavs, wav_lens)