# asr-wav2vec2-ctc-aishell / custom_interface.py
"""Custom Interface for AISHELL-1 CTC inference
An external tokenizer is used so some special tokens
need to be specified during decoding
Authors
* Yingzhi Wang 2022
"""
import torch
from speechbrain.inference.interfaces import Pretrained
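

# The ``decoder`` entry in the hyperparameters is expected to be a CTC
# decoding function (in SpeechBrain recipes, typically
# ``speechbrain.decoders.ctc_greedy_decode``). The helper below is a
# minimal, self-contained sketch of that behavior, kept for reference
# only: the real decoder and its blank index come from the loaded YAML,
# and nothing in the class calls this function.
def _ctc_greedy_decode_sketch(log_probs, wav_lens, blank_id=0):
    """Collapse repeated frames and drop blanks from argmax predictions.

    log_probs : torch.tensor [batch, time, vocab] log-probabilities.
    wav_lens : torch.tensor [batch] relative lengths in [0, 1].
    blank_id : int, index of the CTC blank token (assumed 0 here).
    """
    frame_preds = log_probs.argmax(dim=-1)  # [batch, time]
    abs_lens = (wav_lens * log_probs.shape[1]).round().int()
    sequences = []
    for preds, length in zip(frame_preds, abs_lens):
        prev = blank_id
        seq = []
        for token in preds[: int(length)].tolist():
            # Standard CTC collapse: emit a token only when it is not the
            # blank and differs from the previous frame's prediction.
            if token != blank_id and token != prev:
                seq.append(token)
            prev = token
        sequences.append(seq)
    return sequences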


class CustomEncoderDecoderASR(Pretrained):
    """A ready-to-use wav2vec 2.0 + CTC ASR model for AISHELL-1.

    The class can be used either to run only the encoder (encode_batch())
    to extract CTC logits, or to run the full pipeline
    (transcribe_batch() / transcribe_file()) to transcribe speech. The
    given YAML must contain the fields specified in the *_NEEDED[] lists.

    Example
    -------
    >>> from speechbrain.inference.interfaces import foreign_class
    >>> asr_model = foreign_class(
    ...     source="speechbrain/asr-wav2vec2-ctc-aishell",
    ...     pymodule_file="custom_interface.py",
    ...     classname="CustomEncoderDecoderASR",
    ... )
    >>> asr_model.transcribe_file("path/to/audio.wav")
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.tokenizer = self.hparams.tokenizer

    def transcribe_file(self, path):
        """Transcribes the given audio file into a transcription string.

        Arguments
        ---------
        path : str
            Path to the audio file to transcribe.

        Returns
        -------
        str
            The audio file transcription produced by this ASR system.
        """
        waveform = self.load_audio(path)
        # Fake a batch of size one:
        batch = waveform.unsqueeze(0)
        rel_length = torch.tensor([1.0])
        predicted_words = self.transcribe_batch(batch, rel_length)
        # transcribe_batch returns one token list per utterance; join the
        # tokens of the single utterance into a string.
        return "".join(predicted_words[0])

    def encode_batch(self, wavs):
        """Encodes the input audio into per-frame CTC logits.

        The waveforms should already be in the model's desired format.
        You can call:
        ``normalized = EncoderDecoderASR.normalizer(signal, sample_rate)``
        to get a correctly converted signal in most cases.

        Arguments
        ---------
        wavs : torch.tensor
            Batch of waveforms [batch, time, channels] or [batch, time]
            depending on the model.

        Returns
        -------
        torch.tensor
            The CTC logits for the encoded batch.
        """
        wavs = wavs.float()
        wavs = wavs.to(self.device)
        # wav2vec 2.0 features -> encoder layers -> linear CTC head.
        outputs = self.mods.wav2vec2(wavs)
        outputs = self.mods.enc(outputs)
        outputs = self.mods.ctc_lin(outputs)
        return outputs

    def transcribe_batch(self, wavs, wav_lens):
        """Transcribes the input audio into token sequences.

        The waveforms should already be in the model's desired format.
        You can call:
        ``normalized = EncoderDecoderASR.normalizer(signal, sample_rate)``
        to get a correctly converted signal in most cases.

        Arguments
        ---------
        wavs : torch.tensor
            Batch of waveforms [batch, time, channels] or [batch, time]
            depending on the model.
        wav_lens : torch.tensor
            Lengths of the waveforms relative to the longest one in the
            batch, tensor of shape [batch]. The longest one should have
            relative length 1.0 and the others len(waveform) / max_length.
            Used for ignoring padding.

        Returns
        -------
        list
            A list of token lists, one per waveform in the batch.
        """
        with torch.no_grad():
            wav_lens = wav_lens.to(self.device)
            encoder_out = self.encode_batch(wavs)
            p_ctc = self.hparams.log_softmax(encoder_out)
            sequences = self.hparams.decoder(p_ctc, wav_lens)
            predicted_words_list = []
            for sequence in sequences:
                predicted_tokens = self.tokenizer.convert_ids_to_tokens(
                    sequence
                )
                predicted_words = []
                for c in predicted_tokens:
                    # Skip the external tokenizer's start token, stop at
                    # its end or padding tokens, and keep everything else.
                    if c == "[CLS]":
                        continue
                    elif c == "[SEP]" or c == "[PAD]":
                        break
                    else:
                        predicted_words.append(c)
                predicted_words_list.append(predicted_words)
            return predicted_words_list

    def forward(self, wavs, wav_lens):
        """Runs full transcription (note: no gradients through decoding)."""
        return self.transcribe_batch(wavs, wav_lens)
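

# Usage sketch (assumptions: this repo layout and SpeechBrain's
# ``foreign_class`` loader; the waveforms below are random placeholders,
# not real speech).
if __name__ == "__main__":
    from speechbrain.inference.interfaces import foreign_class

    asr_model = foreign_class(
        source="speechbrain/asr-wav2vec2-ctc-aishell",
        pymodule_file="custom_interface.py",
        classname="CustomEncoderDecoderASR",
    )

    # Two padded waveforms; wav_lens marks how much of each row is real
    # audio (1.0 = full length, 0.5 = first half) so padding is ignored.
    wavs = torch.randn(2, 32000)
    wav_lens = torch.tensor([1.0, 0.5])
    print(asr_model.transcribe_batch(wavs, wav_lens))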