wav2vec2-xlsr-multilingual-56

56 languages, 1 model: multilingual ASR

Fine-tuned facebook/wav2vec2-large-xlsr-53 on 56 languages using the Common Voice dataset.
When using this model, make sure that your speech input is sampled at 16kHz.

For more details, see: https://github.com/voidful/wav2vec2-xlsr-multilingual-56

Env setup:

!pip install torchaudio
!pip install datasets transformers
!pip install asrp
!wget -O lang_ids.pk https://huggingface.co/voidful/wav2vec2-xlsr-multilingual-56/raw/main/lang_ids.pk

Usage

import torchaudio
from datasets import load_dataset, load_metric
from transformers import (
    Wav2Vec2ForCTC,
    Wav2Vec2Processor,
    AutoTokenizer, 
    AutoModelWithLMHead 
)
import torch
import re
import sys
import soundfile as sf
# Hugging Face Hub identifiers of the fine-tuned checkpoint and its processor.
model_name = "voidful/wav2vec2-xlsr-multilingual-56"
# Prefer the GPU, but fall back to the CPU so the example also runs on
# machines without CUDA instead of failing at `.to(device)`.
device = "cuda" if torch.cuda.is_available() else "cpu"
processor_name = "voidful/wav2vec2-xlsr-multilingual-56"

import pickle
# lang_ids.pk is the file fetched in the "Env setup" step; predict_lang_specific
# looks it up as lang_ids[lang_code] to get the token ids allowed for a language.
# NOTE: pickle.load can execute arbitrary code while unpickling -- only load a
# lang_ids.pk that was obtained from a source you trust.
with open("lang_ids.pk", 'rb') as lang_ids_file:
    lang_ids = pickle.load(lang_ids_file)

model = Wav2Vec2ForCTC.from_pretrained(model_name).to(device)
processor = Wav2Vec2Processor.from_pretrained(processor_name)

# Put the model in evaluation mode; the functions below only run inference.
model.eval()

def load_file_to_data(file, sampling_rate=16_000):
    """Load an audio file and return its speech data at 16 kHz.

    Args:
        file: Audio file location, handed unchanged to ``torchaudio.load``.
        sampling_rate: Sampling rate, in Hz, of the audio stored in ``file``.
            May be an int or a numeric string such as ``"16000"``. The rate
            reported by ``torchaudio.load`` is not consulted, so this value
            must match the file.

    Returns:
        A dict with two entries: ``"speech"``, the waveform as a NumPy array
        (leading dimension squeezed out when it has size 1, resampled to
        16 kHz when needed), and ``"sampling_rate"``, always the int 16000.
    """
    target_rate = 16_000
    # Normalise once so the comparison below is numeric; comparing the int
    # default against string literals made the old check always true.
    source_rate = int(sampling_rate)

    speech, _ = torchaudio.load(file)
    speech = speech.squeeze(0)
    if source_rate != target_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=source_rate, new_freq=target_rate)
        speech = resampler(speech)

    # The sampling rate is always reported as an int: the processor receives it
    # as ``sampling_rate=...`` in predict / predict_lang_specific.
    return {"speech": speech.numpy(), "sampling_rate": target_rate}


def predict(data):
    """Transcribe speech by greedy decoding over the whole vocabulary.

    Args:
        data: Dict with ``"speech"`` and ``"sampling_rate"`` entries, as
            produced by ``load_file_to_data``.

    Returns:
        A list with one decoded string per row of the model's logits.
    """
    encoded = processor(data["speech"], sampling_rate=data["sampling_rate"], padding=True, return_tensors="pt")
    inputs = encoded.input_values.to(device)
    attn = encoded.attention_mask.to(device)

    transcripts = []
    with torch.no_grad():
        batch_logits = model(inputs, attention_mask=attn).logits
        for frame_logits in batch_logits:
            # Keep only the frames whose best-scoring token id is >= 1.
            best_ids = frame_logits.argmax(dim=-1)
            kept_frames = frame_logits[best_ids.ge(1)]
            # Re-score the kept frames and take the most probable token of each.
            frame_probs = torch.nn.functional.softmax(kept_frames, dim=-1)
            token_ids = frame_probs.argmax(dim=-1)
            transcripts.append(processor.decode(token_ids))

    return transcripts

def predict_lang_specific(data, lang_code):
    """Transcribe speech while restricting output tokens to one language.

    Frames whose best-scoring token is the tokenizer's pad token are dropped.
    For the remaining frames, the probabilities of every token id not listed
    in ``lang_ids[lang_code]`` are zeroed before the most probable token is
    picked.

    Args:
        data: Dict with ``"speech"`` and ``"sampling_rate"`` entries, as
            produced by ``load_file_to_data``.
        lang_code: Key into ``lang_ids`` selecting the allowed token ids
            (for example ``'en'``).

    Returns:
        A list with one decoded string per row of the model's logits; a row
        with no non-pad frames yields ``""``.

    Raises:
        KeyError: If ``lang_code`` is not a key of ``lang_ids``.
    """
    features = processor(data["speech"], sampling_rate=data["sampling_rate"], padding=True, return_tensors="pt")
    input_values = features.input_values.to(device)
    attention_mask = features.attention_mask.to(device)
    pad_id = processor.tokenizer.pad_token_id

    with torch.no_grad():
        logits = model(input_values, attention_mask=attention_mask).logits

        # The language mask depends only on the vocabulary size and lang_code,
        # so build it once instead of once per utterance.
        lang_index = torch.tensor(sorted(lang_ids[lang_code]), dtype=torch.long)
        lang_mask = torch.zeros(logits.size(-1))
        lang_mask.index_fill_(0, lang_index, 1)
        lang_mask = lang_mask.to(logits.device)

        decoded_results = []
        for logit in logits:
            pred_ids = torch.argmax(logit, dim=-1)
            non_pad = pred_ids != pad_id
            if not bool(non_pad.any()):
                # Every frame was predicted as padding: nothing to decode.
                decoded_results.append("")
                continue
            voice_prob = torch.nn.functional.softmax(logit[non_pad], dim=-1)
            # Multiplying by the 0/1 mask zeroes tokens outside the language.
            comb_pred_ids = torch.argmax(lang_mask * voice_prob, dim=-1)
            decoded_results.append(processor.decode(comb_pred_ids))

    return decoded_results


# Example: transcribe a file with no language restriction.
# Replace 'audio file path' with the location of a real audio file.
predict(load_file_to_data('audio file path',sampling_rate=16_000)) # beware of the audio file sampling rate

# Example: transcribe the file while only allowing the token ids listed for 'en' in lang_ids.
predict_lang_specific(load_file_to_data('audio file path',sampling_rate=16_000),'en') # beware of the audio file sampling rate

Result

Common Voice language | Num. of data | Hours | WER | CER (one row per language; values in the rows below are space-separated)
ar 21744 81.5 75.29 31.23
as 394 1.1 95.37 46.05
br 4777 7.4 93.79 41.16
ca 301308 692.8 24.80 10.39
cnh 1563 2.4 68.11 23.10
cs 9773 39.5 67.86 12.57
cv 1749 5.9 95.43 34.03
cy 11615 106.7 67.03 23.97
de 262113 822.8 27.03 6.50
dv 4757 18.6 92.16 30.15
el 3717 11.1 94.48 58.67
en 580501 1763.6 34.87 14.84
eo 28574 162.3 37.77 6.23
es 176902 337.7 19.63 5.41
et 5473 35.9 86.87 20.79
eu 12677 90.2 44.80 7.32
fa 12806 290.6 53.81 15.09
fi 875 2.6 93.78 27.57
fr 314745 664.1 33.16 13.94
fy-NL 6717 27.2 72.54 26.58
ga-IE 1038 3.5 92.57 51.02
hi 292 2.0 90.95 57.43
hsb 980 2.3 89.44 27.19
hu 4782 9.3 97.15 36.75
ia 5078 10.4 52.00 11.35
id 3965 9.9 82.50 22.82
it 70943 178.0 39.09 8.72
ja 1308 8.2 99.21 62.06
ka 1585 4.0 90.53 18.57
ky 3466 12.2 76.53 19.80
lg 1634 17.1 98.95 43.84
lt 1175 3.9 92.61 26.81
lv 4554 6.3 90.34 30.81
mn 4020 11.6 82.68 30.14
mt 3552 7.8 84.18 22.96
nl 14398 71.8 57.18 19.01
or 517 0.9 90.93 27.34
pa-IN 255 0.8 87.95 42.03
pl 12621 112.0 56.14 12.06
pt 11106 61.3 53.24 16.32
rm-sursilv 2589 5.9 78.17 23.31
rm-vallader 931 2.3 73.67 21.76
ro 4257 8.7 83.84 21.95
ru 23444 119.1 61.83 15.18
sah 1847 4.4 94.38 38.46
sl 2594 6.7 84.21 20.54
sv-SE 4350 20.8 83.68 30.79
ta 3788 18.4 84.19 21.60
th 4839 11.7 141.87 37.16
tr 3478 22.3 66.77 15.55
tt 13338 26.7 86.80 33.57
uk 7271 39.4 70.23 14.34
vi 421 1.7 96.06 66.25
zh-CN 27284 58.7 89.67 23.96
zh-HK 12678 92.1 81.77 18.82
zh-TW 6402 56.6 85.08 29.07
Downloads last month
352
Hosted inference API
Automatic Speech Recognition
or
This model can be loaded on the Inference API on-demand.
Evaluation results

Model card error

This model's model-index metadata is invalid: Schema validation error. properties must have property 'metrics'