update README.md
Browse files
README.md
CHANGED
@@ -20,12 +20,167 @@ model-index:
|
|
20 |
---
|
21 |
|
22 |
# wav2vec2-xlsr-multilingual-56
|
|
|
|
|
|
|
23 |
Fine-tuned [facebook/wav2vec2-large-xlsr-53](https://huggingface.co/facebook/wav2vec2-large-xlsr-53) on 56 languages using the [Common Voice](https://huggingface.co/datasets/common_voice) dataset.
|
24 |
When using this model, make sure that your speech input is sampled at 16kHz.
|
25 |
|
|
|
|
|
26 |
## Env setup:
|
27 |
```
|
28 |
!pip install torchaudio
|
29 |
!pip install datasets transformers
|
30 |
!pip install asrp
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
---
|
21 |
|
22 |
# wav2vec2-xlsr-multilingual-56
|
23 |
+
|
24 |
+
*56 languages, 1 model: multilingual ASR*
|
25 |
+
|
26 |
Fine-tuned [facebook/wav2vec2-large-xlsr-53](https://huggingface.co/facebook/wav2vec2-large-xlsr-53) on 56 languages using the [Common Voice](https://huggingface.co/datasets/common_voice) dataset.
|
27 |
When using this model, make sure that your speech input is sampled at 16kHz.
|
28 |
|
29 |
+
For more details, see [https://github.com/voidful/wav2vec2-xlsr-multilingual-56](https://github.com/voidful/wav2vec2-xlsr-multilingual-56)
|
30 |
+
|
31 |
## Env setup:
|
32 |
```
|
33 |
!pip install torchaudio
|
34 |
!pip install datasets transformers
|
35 |
!pip install asrp
|
36 |
+
!wget -O lang_ids.pk https://huggingface.co/voidful/wav2vec2-xlsr-multilingual-56/raw/main/lang_ids.pk
|
37 |
+
```
|
38 |
+
|
39 |
+
## Usage
|
40 |
+
```
|
41 |
+
import torchaudio
|
42 |
+
from datasets import load_dataset, load_metric
|
43 |
+
from transformers import (
|
44 |
+
Wav2Vec2ForCTC,
|
45 |
+
Wav2Vec2Processor,
|
46 |
+
AutoTokenizer,
|
47 |
+
AutoModelWithLMHead
|
48 |
+
)
|
49 |
+
import torch
|
50 |
+
import re
|
51 |
+
import sys
|
52 |
+
import soundfile as sf
|
53 |
+
model_name = "voidful/wav2vec2-xlsr-multilingual-56"
processor_name = "voidful/wav2vec2-xlsr-multilingual-56"
# Fall back to the CPU so the example also runs on machines without a
# CUDA-capable GPU instead of failing when the model is moved to "cuda".
device = "cuda" if torch.cuda.is_available() else "cpu"

import pickle

# lang_ids.pk is the file fetched by the wget command in the "Env setup"
# section; predict_lang_specific() indexes it by language code.
# NOTE(review): pickle.load can execute arbitrary code contained in the file,
# so only load a copy obtained from a source you trust.
with open("lang_ids.pk", "rb") as lang_ids_file:
    lang_ids = pickle.load(lang_ids_file)

model = Wav2Vec2ForCTC.from_pretrained(model_name).to(device)
processor = Wav2Vec2Processor.from_pretrained(processor_name)

# Put the model in evaluation mode; the example only runs inference.
model.eval()
|
65 |
+
|
66 |
+
def load_file_to_data(file, sampling_rate=16_000):
    """Load an audio file and return it as 16 kHz input for the model.

    Args:
        file: Audio file passed straight to ``torchaudio.load``.
        sampling_rate: Sampling rate (Hz, integer) the audio in ``file`` is
            assumed to have. The rate reported by ``torchaudio.load`` is
            discarded, so the caller is responsible for passing the real
            rate of the file.

    Returns:
        A dict with two keys: ``"speech"`` (the waveform as a NumPy array,
        resampled to 16 kHz when ``sampling_rate`` differs from it) and
        ``"sampling_rate"`` (always the integer ``16000``).
    """
    target_rate = 16_000
    speech, _ = torchaudio.load(file)
    # Drop the leading dimension when it has size 1.
    # (presumably the channel axis of a mono file; confirm for multi-channel input)
    speech = speech.squeeze(0)

    # Compare integers: the previous test compared against the strings
    # '16_000' / '16000' joined with `or`, which is true for every value, so
    # the no-resampling branch could never run.
    if sampling_rate != target_rate:
        resampler = torchaudio.transforms.Resample(
            orig_freq=sampling_rate, new_freq=target_rate
        )
        speech = resampler(speech)

    return {"speech": speech.numpy(), "sampling_rate": target_rate}
|
77 |
+
|
78 |
+
|
79 |
+
def predict(data):
    """Transcribe audio with greedy decoding over the full vocabulary.

    Args:
        data: Mapping with ``"speech"`` and ``"sampling_rate"`` entries, such
            as the dict returned by ``load_file_to_data``.

    Returns:
        A list holding one decoded string per row of the model's logits.
    """
    encoded = processor(
        data["speech"],
        sampling_rate=data["sampling_rate"],
        padding=True,
        return_tensors="pt",
    )
    model_input = encoded.input_values.to(device)
    input_mask = encoded.attention_mask.to(device)
    with torch.no_grad():
        batch_logits = model(model_input, attention_mask=input_mask).logits

    transcripts = []
    for frame_logits in batch_logits:
        top_ids = frame_logits.argmax(dim=-1)
        # Keep only the frames whose best-scoring token id is >= 1.
        # (id 0 is presumably the blank/padding token; confirm against the
        # tokenizer's vocabulary)
        kept_logits = frame_logits[top_ids >= 1]
        kept_probs = torch.nn.functional.softmax(kept_logits, dim=-1)
        transcripts.append(processor.decode(kept_probs.argmax(dim=-1)))

    return transcripts
|
95 |
+
|
96 |
+
def predict_lang_specific(data, lang_code):
    """Transcribe audio while restricting output tokens to one language.

    Args:
        data: Mapping with ``"speech"`` and ``"sampling_rate"`` entries, such
            as the dict returned by ``load_file_to_data``.
        lang_code: Key into the module-level ``lang_ids`` mapping (for
            example ``'en'``); the value is the collection of token ids
            allowed for that language.

    Returns:
        A list holding one decoded string per row of the model's logits. A
        row in which every frame's best token is the padding token yields
        ``""``.

    Raises:
        KeyError: If ``lang_code`` is not present in ``lang_ids``.
    """
    features = processor(
        data["speech"],
        sampling_rate=data["sampling_rate"],
        padding=True,
        return_tensors="pt",
    )
    input_values = features.input_values.to(device)
    attention_mask = features.attention_mask.to(device)
    with torch.no_grad():
        logits = model(input_values, attention_mask=attention_mask).logits

    pad_token_id = processor.tokenizer.pad_token_id

    # The language mask depends only on lang_code and the vocabulary size, so
    # build it once instead of once per utterance: 1 for token ids allowed in
    # the requested language, 0 everywhere else.
    lang_index = torch.tensor(sorted(lang_ids[lang_code]), dtype=torch.long)
    lang_mask = torch.zeros(logits.size(-1))
    lang_mask.index_fill_(0, lang_index, 1)
    lang_mask = lang_mask.to(device)

    decoded_results = []
    for logit in logits:
        pred_ids = torch.argmax(logit, dim=-1)
        # Keep the frames whose unrestricted best token is not padding.
        voiced_logits = logit[pred_ids != pad_token_id]
        if voiced_logits.size(0) == 0:
            decoded_results.append("")
            continue

        voice_prob = torch.nn.functional.softmax(voiced_logits, dim=-1)
        # Zero the probability of tokens outside the language before taking
        # the per-frame best token.
        comb_pred_ids = torch.argmax(lang_mask * voice_prob, dim=-1)
        decoded_results.append(processor.decode(comb_pred_ids))

    return decoded_results
|
120 |
+
|
121 |
+
|
122 |
+
# Replace the placeholder with the path of the recording to transcribe.
audio_path = 'audio file path'

# Transcription over the full vocabulary.
predict(load_file_to_data(audio_path))

# Transcription restricted to the tokens listed for English.
predict_lang_specific(load_file_to_data(audio_path), 'en')
|
125 |
+
|
126 |
```
|
127 |
+
|
128 |
+
## Evaluation Result
|
129 |
+
| Common Voice Languages | Num. of data | Hour | WER | CER |
|
130 |
+
|------------------------|--------------|--------|--------|-------|
|
131 |
+
| ar | 21744 | 81.5 | 75.24 | 31.27 |
|
132 |
+
| as | 394 | 1.1 | 95.37 | 46.03 |
|
133 |
+
| br | 4777 | 7.4 | 93.79 | 41.14 |
|
134 |
+
| ca | 301308 | 692.8 | 24.82 | 10.39 |
|
135 |
+
| cnh | 1563 | 2.4 | 68.22 | 23.11 |
|
136 |
+
| cs | 9773 | 39.5 | 67.89 | 12.57 |
|
137 |
+
| cv | 1749 | 5.9 | 95.43 | 34.01 |
|
138 |
+
| cy | 11615 | 106.7 | 66.98 | 23.93 |
|
139 |
+
| de | 262113 | 822.8 | 27.04 | 6.51 |
|
140 |
+
| dv | 4757 | 18.6 | 92.18 | 30.18 |
|
141 |
+
| el | 3717 | 11.1 | 94.51 | 58.69 |
|
142 |
+
| en | 580501 | 1763.6 | 34.88 | 14.84 |
|
143 |
+
| eo | 28574 | 162.3 | 37.79 | 6.23 |
|
144 |
+
| es | 176902 | 337.7 | 19.64 | 5.42 |
|
145 |
+
| et | 5473 | 35.9 | 86.87 | 20.80 |
|
146 |
+
| eu | 12677 | 90.2 | 44.83 | 7.32 |
|
147 |
+
| fa | 12806 | 290.6 | 53.85 | 15.09 |
|
148 |
+
| fi | 875 | 2.6 | 93.70 | 27.60 |
|
149 |
+
| fr | 314745 | 664.1 | 33.19 | 13.94 |
|
150 |
+
| fy-NL | 6717 | 27.2 | 72.55 | 26.58 |
|
151 |
+
| ga-IE | 1038 | 3.5 | 92.51 | 50.98 |
|
152 |
+
| hi | 292 | 2.0 | 90.84 | 57.34 |
|
153 |
+
| hsb | 980 | 2.3 | 89.52 | 27.18 |
|
154 |
+
| hu | 4782 | 9.3 | 97.11 | 36.74 |
|
155 |
+
| ia | 5078 | 10.4 | 52.08 | 11.37 |
|
156 |
+
| id | 3965 | 9.9 | 82.48 | 22.82 |
|
157 |
+
| it | 70943 | 178.0 | 39.09 | 8.72 |
|
158 |
+
| ja | 1308 | 8.2 | 99.21 | 61.91 |
|
159 |
+
| ka | 1585 | 4.0 | 90.49 | 18.57 |
|
160 |
+
| ky | 3466 | 12.2 | 76.57 | 19.83 |
|
161 |
+
| lg | 1634 | 17.1 | 98.95 | 43.84 |
|
162 |
+
| lt | 1175 | 3.9 | 92.67 | 26.82 |
|
163 |
+
| lv | 4554 | 6.3 | 90.34 | 30.79 |
|
164 |
+
| mn | 4020 | 11.6 | 82.70 | 30.15 |
|
165 |
+
| mt | 3552 | 7.8 | 84.21 | 22.94 |
|
166 |
+
| nl | 14398 | 71.8 | 57.17 | 19.01 |
|
167 |
+
| or | 517 | 0.9 | 90.93 | 27.42 |
|
168 |
+
| pa-IN | 255 | 0.8 | 88.07 | 42.00 |
|
169 |
+
| pl | 12621 | 112.0 | 56.15 | 12.07 |
|
170 |
+
| pt | 11106 | 61.3 | 53.27 | 16.33 |
|
171 |
+
| rm-sursilv | 2589 | 5.9 | 78.17 | 23.30 |
|
172 |
+
| rm-vallader | 931 | 2.3 | 73.64 | 21.70 |
|
173 |
+
| ro | 4257 | 8.7 | 83.81 | 21.93 |
|
174 |
+
| ru | 23444 | 119.1 | 61.83 | 15.18 |
|
175 |
+
| sah | 1847 | 4.4 | 94.36 | 38.47 |
|
176 |
+
| sl | 2594 | 6.7 | 84.25 | 20.52 |
|
177 |
+
| sv-SE | 4350 | 20.8 | 83.62 | 30.78 |
|
178 |
+
| ta | 3788 | 18.4 | 84.23 | 21.60 |
|
179 |
+
| th | 4839 | 11.7 | 141.83 | 37.24 |
|
180 |
+
| tr | 3478 | 22.3 | 66.79 | 15.55 |
|
181 |
+
| tt | 13338 | 26.7 | 86.81 | 33.59 |
|
182 |
+
| uk | 7271 | 39.4 | 70.19 | 14.35 |
|
183 |
+
| vi | 421 | 1.7 | 96.13 | 66.31 |
|
184 |
+
| zh-CN | 27284 | 58.7 | 89.65 | 23.94 |
|
185 |
+
| zh-HK | 12678 | 92.1 | 81.67 | 18.82 |
|
186 |
+
| zh-TW | 6402 | 56.6 | 85.04 | 29.08 |
|