---

# wav2vec2-xlsr-multilingual-56

*56 languages, 1 model: multilingual ASR*

Fine-tuned [facebook/wav2vec2-large-xlsr-53](https://huggingface.co/facebook/wav2vec2-large-xlsr-53) on 56 languages using the [Common Voice](https://huggingface.co/datasets/common_voice) dataset.
When using this model, make sure that your speech input is sampled at 16 kHz.

For more details, see [https://github.com/voidful/wav2vec2-xlsr-multilingual-56](https://github.com/voidful/wav2vec2-xlsr-multilingual-56).

## Env setup
```
!pip install torchaudio
!pip install datasets transformers
!pip install asrp
!wget -O lang_ids.pk https://huggingface.co/voidful/wav2vec2-xlsr-multilingual-56/raw/main/lang_ids.pk
```
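
The `lang_ids.pk` pickle downloaded above maps each language code to the vocabulary indices that language may emit; `predict_lang_specific` in the Usage section relies on it. A minimal sketch to peek inside (the exact keys are whatever the pickle ships with; `'en'` is just an example):
```
import pickle

with open("lang_ids.pk", "rb") as f:
    lang_ids = pickle.load(f)

print(len(lang_ids))           # number of languages covered (expected: 56)
print(sorted(lang_ids)[:5])    # a few language codes
print(len(lang_ids["en"]))     # how many vocab ids English may use ('en' assumed present)
```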

## Usage
```
import pickle

import torch
import torchaudio
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

model_name = "voidful/wav2vec2-xlsr-multilingual-56"
processor_name = "voidful/wav2vec2-xlsr-multilingual-56"
device = "cuda"  # use "cpu" if no GPU is available

# Language code -> set of vocabulary ids that language is allowed to emit.
with open("lang_ids.pk", "rb") as f:
    lang_ids = pickle.load(f)

model = Wav2Vec2ForCTC.from_pretrained(model_name).to(device)
processor = Wav2Vec2Processor.from_pretrained(processor_name)
model.eval()

def load_file_to_data(file, sampling_rate=16_000):
    # Load an audio file and resample it to the 16 kHz the model expects.
    batch = {}
    speech, orig_sr = torchaudio.load(file)
    if orig_sr != sampling_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=orig_sr, new_freq=sampling_rate)
        batch["speech"] = resampler(speech.squeeze(0)).numpy()
    else:
        batch["speech"] = speech.squeeze(0).numpy()
    batch["sampling_rate"] = sampling_rate
    return batch

def predict(data):
    # Greedy CTC decoding over the full multilingual vocabulary.
    features = processor(data["speech"], sampling_rate=data["sampling_rate"], padding=True, return_tensors="pt")
    input_values = features.input_values.to(device)
    attention_mask = features.attention_mask.to(device)
    with torch.no_grad():
        logits = model(input_values, attention_mask=attention_mask).logits
    decoded_results = []
    for logit in logits:
        pred_ids = torch.argmax(logit, dim=-1)
        # Keep only frames whose argmax is not the padding token (id 0).
        mask = pred_ids.ge(1).unsqueeze(-1).expand(logit.size())
        vocab_size = logit.size()[-1]
        voice_prob = torch.nn.functional.softmax(torch.masked_select(logit, mask).view(-1, vocab_size), dim=-1)
        comb_pred_ids = torch.argmax(voice_prob, dim=-1)
        decoded_results.append(processor.decode(comb_pred_ids))
    return decoded_results

def predict_lang_specific(data, lang_code):
    # Greedy CTC decoding restricted to one language's vocabulary.
    features = processor(data["speech"], sampling_rate=data["sampling_rate"], padding=True, return_tensors="pt")
    input_values = features.input_values.to(device)
    attention_mask = features.attention_mask.to(device)
    with torch.no_grad():
        logits = model(input_values, attention_mask=attention_mask).logits
    decoded_results = []
    for logit in logits:
        pred_ids = torch.argmax(logit, dim=-1)
        # Drop frames predicted as padding before renormalizing.
        mask = ~pred_ids.eq(processor.tokenizer.pad_token_id).unsqueeze(-1).expand(logit.size())
        vocab_size = logit.size()[-1]
        voice_prob = torch.nn.functional.softmax(torch.masked_select(logit, mask).view(-1, vocab_size), dim=-1)
        filtered_input = pred_ids[pred_ids != processor.tokenizer.pad_token_id].view(1, -1).to(device)
        if len(filtered_input[0]) == 0:
            decoded_results.append("")
        else:
            # Zero the probability of every token outside the target language.
            lang_mask = torch.empty(voice_prob.shape[-1]).fill_(0)
            lang_index = torch.tensor(sorted(lang_ids[lang_code]))
            lang_mask.index_fill_(0, lang_index, 1)
            lang_mask = lang_mask.to(device)
            comb_pred_ids = torch.argmax(lang_mask * voice_prob, dim=-1)
            decoded_results.append(processor.decode(comb_pred_ids))
    return decoded_results

predict(load_file_to_data('audio file path'))

predict_lang_specific(load_file_to_data('audio file path'), 'en')
```
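
Since `asrp` is installed above, a transcription can be scored against a reference along these lines. This is a hedged sketch: it assumes `asrp` exposes the `wer`/`cer` helpers described in its README, and both the audio path and the reference text are placeholders:
```
import asrp

# Placeholder reference transcript; replace with the ground truth for your audio.
references = ["hello world"]
hypotheses = predict(load_file_to_data('audio file path'))

# asrp's evaluation helpers (lists of reference and predicted strings assumed).
print("WER:", asrp.wer(references, hypotheses))
print("CER:", asrp.cer(references, hypotheses))
```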

## Evaluation Results
| Common Voice Language | Num. of clips | Hours | WER (%) | CER (%) |
|-----------------------|---------------|-------|---------|---------|
| ar | 21744 | 81.5 | 75.24 | 31.27 |
| as | 394 | 1.1 | 95.37 | 46.03 |
| br | 4777 | 7.4 | 93.79 | 41.14 |
| ca | 301308 | 692.8 | 24.82 | 10.39 |
| cnh | 1563 | 2.4 | 68.22 | 23.11 |
| cs | 9773 | 39.5 | 67.89 | 12.57 |
| cv | 1749 | 5.9 | 95.43 | 34.01 |
| cy | 11615 | 106.7 | 66.98 | 23.93 |
| de | 262113 | 822.8 | 27.04 | 6.51 |
| dv | 4757 | 18.6 | 92.18 | 30.18 |
| el | 3717 | 11.1 | 94.51 | 58.69 |
| en | 580501 | 1763.6 | 34.88 | 14.84 |
| eo | 28574 | 162.3 | 37.79 | 6.23 |
| es | 176902 | 337.7 | 19.64 | 5.42 |
| et | 5473 | 35.9 | 86.87 | 20.80 |
| eu | 12677 | 90.2 | 44.83 | 7.32 |
| fa | 12806 | 290.6 | 53.85 | 15.09 |
| fi | 875 | 2.6 | 93.70 | 27.60 |
| fr | 314745 | 664.1 | 33.19 | 13.94 |
| fy-NL | 6717 | 27.2 | 72.55 | 26.58 |
| ga-IE | 1038 | 3.5 | 92.51 | 50.98 |
| hi | 292 | 2.0 | 90.84 | 57.34 |
| hsb | 980 | 2.3 | 89.52 | 27.18 |
| hu | 4782 | 9.3 | 97.11 | 36.74 |
| ia | 5078 | 10.4 | 52.08 | 11.37 |
| id | 3965 | 9.9 | 82.48 | 22.82 |
| it | 70943 | 178.0 | 39.09 | 8.72 |
| ja | 1308 | 8.2 | 99.21 | 61.91 |
| ka | 1585 | 4.0 | 90.49 | 18.57 |
| ky | 3466 | 12.2 | 76.57 | 19.83 |
| lg | 1634 | 17.1 | 98.95 | 43.84 |
| lt | 1175 | 3.9 | 92.67 | 26.82 |
| lv | 4554 | 6.3 | 90.34 | 30.79 |
| mn | 4020 | 11.6 | 82.70 | 30.15 |
| mt | 3552 | 7.8 | 84.21 | 22.94 |
| nl | 14398 | 71.8 | 57.17 | 19.01 |
| or | 517 | 0.9 | 90.93 | 27.42 |
| pa-IN | 255 | 0.8 | 88.07 | 42.00 |
| pl | 12621 | 112.0 | 56.15 | 12.07 |
| pt | 11106 | 61.3 | 53.27 | 16.33 |
| rm-sursilv | 2589 | 5.9 | 78.17 | 23.30 |
| rm-vallader | 931 | 2.3 | 73.64 | 21.70 |
| ro | 4257 | 8.7 | 83.81 | 21.93 |
| ru | 23444 | 119.1 | 61.83 | 15.18 |
| sah | 1847 | 4.4 | 94.36 | 38.47 |
| sl | 2594 | 6.7 | 84.25 | 20.52 |
| sv-SE | 4350 | 20.8 | 83.62 | 30.78 |
| ta | 3788 | 18.4 | 84.23 | 21.60 |
| th | 4839 | 11.7 | 141.83 | 37.24 |
| tr | 3478 | 22.3 | 66.79 | 15.55 |
| tt | 13338 | 26.7 | 86.81 | 33.59 |
| uk | 7271 | 39.4 | 70.19 | 14.35 |
| vi | 421 | 1.7 | 96.13 | 66.31 |
| zh-CN | 27284 | 58.7 | 89.65 | 23.94 |
| zh-HK | 12678 | 92.1 | 81.67 | 18.82 |
| zh-TW | 6402 | 56.6 | 85.04 | 29.08 |
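
For reference, both metrics are edit-distance ratios: WER counts word substitutions, deletions, and insertions against the number of reference words (CER does the same at the character level), so heavy insertion errors can push WER above 100%, as in the `th` row. A minimal self-contained sketch of the standard computation (not necessarily the exact scoring script used for the table above):
```
def wer(ref_words, hyp_words):
    # Edit distance between word sequences: (S + D + I) / len(ref).
    d = [[0] * (len(hyp_words) + 1) for _ in range(len(ref_words) + 1)]
    for i in range(len(ref_words) + 1):
        d[i][0] = i
    for j in range(len(hyp_words) + 1):
        d[0][j] = j
    for i in range(1, len(ref_words) + 1):
        for j in range(1, len(hyp_words) + 1):
            sub = d[i - 1][j - 1] + (ref_words[i - 1] != hyp_words[j - 1])
            d[i][j] = min(sub, d[i - 1][j] + 1, d[i][j - 1] + 1)
    return d[-1][-1] / len(ref_words)

print(wer("the cat sat".split(), "the cat sat down".split()))  # 1 insertion / 3 words = 0.33
```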