|
import matplotlib |
|
import torch |
|
import torchaudio |
|
|
|
# Default matplotlib figure size (width, height) for figures created here.
matplotlib.rcParams.update({"figure.figsize": [16.0, 4.8]})

# Fix PyTorch's global RNG seed so runs are reproducible.
torch.manual_seed(0)

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class GreedyCTCDecoder(torch.nn.Module):
    """Greedy (best-path) CTC decoder.

    At every frame the highest-scoring label is chosen. Consecutive repeats
    are collapsed, the blank label is dropped, and the remaining indices are
    mapped to their label strings.
    """

    def __init__(self, labels, blank=0):
        """
        Args:
            labels (Sequence[str]): Label string for each class index.
            blank (int): Index of the CTC blank label. Default: 0.
        """
        super().__init__()

        self.labels = labels

        self.blank = blank

    def forward(self, emission: torch.Tensor) -> str:
        """Given a sequence emission over labels, get the best path string.

        Args:
            emission (Tensor): Logit tensors. Shape `[num_seq, num_label]`.

        Returns:
            str: The resulting transcript.
        """
        indices = torch.argmax(emission, dim=-1)  # [num_seq]

        # Collapse runs *before* removing blanks, so a blank between two
        # identical labels keeps them as two separate characters.
        indices = torch.unique_consecutive(indices, dim=-1)

        # Convert to Python ints once. Iterating a tensor element-wise creates
        # a 0-d tensor per step and, on CUDA, forces a device sync for every
        # comparison and every label lookup.
        return "".join(self.labels[i] for i in indices.tolist() if i != self.blank)
|
|
|
def predict(file):
    """Transcribe an audio file with torchaudio's WAV2VEC2_ASR_BASE_960H model.

    Args:
        file: Audio source, passed unchanged to ``torchaudio.load``.

    Returns:
        str: Greedy CTC transcript built from ``bundle.get_labels()``.
        NOTE(review): the label set presumably uses "|" as the word
        separator, so the transcript may contain "|" instead of spaces;
        verify against the bundle's labels.
    """
    bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H

    # NOTE(review): the model is rebuilt and its weights reloaded on every
    # call. Callers transcribing many files may want to cache it.
    model = bundle.get_model().to(device)

    waveform, sample_rate = torchaudio.load(file)
    waveform = waveform.to(device)

    # The model expects audio at the bundle's sample rate.
    if sample_rate != bundle.sample_rate:
        waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)

    # A single forward pass produces the per-frame label logits. No separate
    # feature-extraction pass is needed for transcription.
    with torch.inference_mode():
        emission, _ = model(waveform)

    decoder = GreedyCTCDecoder(labels=bundle.get_labels())

    # Each row of `waveform` (one per channel, as returned by torchaudio.load)
    # yields one row of `emission`. Only the first row is decoded.
    return decoder(emission[0])
|
|
|
|