# Custom SpeechBrain inference interface for a fine-tuned Whisper ASR model.
# Hugging Face Hub upload metadata (was page chrome, kept as a comment):
#   uploader: poonehmousavi — "Upload 3 files" — commit 092d812 — 4.53 kB
import torch
from speechbrain.pretrained import Pretrained
class WhisperASR(Pretrained):
    """A ready-to-use Whisper ASR model.

    The class can be used to run only the encoder (``encode_batch()``) or the
    entire encoder-decoder Whisper model (``transcribe_batch()``) to
    transcribe speech. The given YAML must contain the fields specified in
    the ``*_NEEDED[]`` lists.

    Example
    -------
    >>> from speechbrain.pretrained.interfaces import foreign_class
    >>> tmpdir = getfixture("tmpdir")
    >>> asr_model = foreign_class(source="hf",
    ...     pymodule_file="custom_interface.py",
    ...     classname="WhisperASR",
    ...     hparams_file='hparams.yaml',
    ...     savedir=tmpdir,
    ... )
    >>> asr_model.transcribe_file("tests/samples/example2.wav")
    """

    HPARAMS_NEEDED = ["language"]
    MODULES_NEEDED = ["whisper", "decoder"]

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Configure the Whisper tokenizer for the target language and the
        # "transcribe" task (third argument disables timestamp prediction),
        # then make the decoder start generation from those prefix tokens.
        self.tokenizer = self.hparams.whisper.tokenizer
        self.tokenizer.set_prefix_tokens(
            self.hparams.language, "transcribe", False
        )
        self.hparams.decoder.set_decoder_input_tokens(
            self.tokenizer.prefix_tokens
        )

    def transcribe_file(self, path):
        """Transcribes the given audio file into a sequence of words.

        Arguments
        ---------
        path : str
            Path to audio file which to transcribe.

        Returns
        -------
        list
            The batch-level output of ``transcribe_batch`` for a
            single-element batch: one entry for the file's waveform. The
            entry is a decoded string, or a list of words when
            ``hparams.normalized_transcripts`` is enabled.
        """
        waveform = self.load_audio(path)
        # Fake a batch of size one; the single waveform has relative
        # length 1.0 by definition.
        batch = waveform.unsqueeze(0)
        rel_length = torch.tensor([1.0])
        predicted_words, _ = self.transcribe_batch(batch, rel_length)
        return predicted_words

    def encode_batch(self, wavs, wav_lens):
        """Encodes the input audio into a sequence of hidden states.

        The waveforms should already be in the model's desired format.
        You can call:
        ``normalized = EncoderDecoderASR.normalizer(signal, sample_rate)``
        to get a correctly converted signal in most cases.

        Arguments
        ---------
        wavs : torch.Tensor
            Batch of waveforms [batch, time, channels].
        wav_lens : torch.Tensor
            Lengths of the waveforms relative to the longest one in the
            batch, tensor of shape [batch]. The longest one should have
            relative length 1.0 and others len(waveform) / max_length.
            Used for ignoring padding.

        Returns
        -------
        torch.Tensor
            The encoded batch.
        """
        # Cast to float32 and move everything to the model's device before
        # running the Whisper encoder.
        wavs = wavs.float()
        wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
        encoder_out = self.mods.whisper.forward_encoder(wavs)
        return encoder_out

    def transcribe_batch(self, wavs, wav_lens):
        """Transcribes the input audio into a sequence of words.

        The waveforms should already be in the model's desired format.
        You can call:
        ``normalized = EncoderDecoderASR.normalizer(signal, sample_rate)``
        to get a correctly converted signal in most cases.

        Arguments
        ---------
        wavs : torch.Tensor
            Batch of waveforms [batch, time, channels].
        wav_lens : torch.Tensor
            Lengths of the waveforms relative to the longest one in the
            batch, tensor of shape [batch]. The longest one should have
            relative length 1.0 and others len(waveform) / max_length.
            Used for ignoring padding.

        Returns
        -------
        list
            Each waveform in the batch transcribed (a string per waveform,
            or a list of words when ``hparams.normalized_transcripts`` is
            enabled).
        tensor
            Each predicted token id.
        """
        with torch.no_grad():
            wav_lens = wav_lens.to(self.device)
            encoder_out = self.encode_batch(wavs, wav_lens)
            predicted_tokens, scores = self.mods.decoder(encoder_out, wav_lens)
            predicted_words = self.tokenizer.batch_decode(
                predicted_tokens, skip_special_tokens=True
            )
            # ``normalized_transcripts`` is not in HPARAMS_NEEDED, so a YAML
            # that satisfies the declared contract may omit it; default to
            # no normalization instead of raising AttributeError.
            if getattr(self.hparams, "normalized_transcripts", False):
                predicted_words = [
                    self.tokenizer._normalize(text).split(" ")
                    for text in predicted_words
                ]
        return predicted_words, predicted_tokens

    def forward(self, wavs, wav_lens):
        """Runs the full transcription — note: no gradients through decoding."""
        return self.transcribe_batch(wavs, wav_lens)