---
language: pt
datasets:
- common_voice
- mls
- cetuc
- lapsbm
- voxforge
- tedx
- sid
metrics:
- wer
tags:
- audio
- speech
- wav2vec2
- pt
- portuguese-speech-corpus
- automatic-speech-recognition
- PyTorch
license: apache-2.0
---

# mls100-xlsr: Wav2vec 2.0 with MLS Dataset

This is a demonstration of a Wav2vec model fine-tuned for Brazilian Portuguese using the [Multilingual Librispeech in Portuguese (MLS)](http://www.openslr.org/94/) dataset.

In this notebook, the model is tested against other available Brazilian Portuguese datasets.

| Dataset                        | Train | Valid | Test  |
|--------------------------------|------:|------:|------:|
| CETUC                          |    -- |    -- |  5.4h |
| Common Voice                   |    -- |    -- |  9.5h |
| LaPS BM                        |    -- |    -- |  0.1h |
| MLS                            |  161h |    -- |  3.7h |
| Multilingual TEDx (Portuguese) |    -- |    -- |  1.8h |
| SID                            |    -- |    -- |  1.0h |
| VoxForge                       |    -- |    -- |  0.1h |
| Total                          |  161h |    -- | 21.6h |

#### Summary

All values below are word error rates (WER); lower is better.

|                                        | CETUC | CV    | LaPS  | MLS   | SID   | TEDx  | VF    | AVG   |
|----------------------------------------|-------|-------|-------|-------|-------|-------|-------|-------|
| mls100 (demonstration below)           | 0.192 | 0.260 | 0.162 | 0.163 | 0.268 | 0.492 | 0.268 | 0.258 |
| mls100 + 4-gram (demonstration below)  | 0.087 | 0.173 | 0.077 | 0.126 | 0.245 | 0.415 | 0.218 | 0.191 |
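
For a quick sanity check outside the full evaluation below, the model can be loaded directly with `transformers`. A minimal greedy-decoding sketch, assuming a 16 kHz mono WAV file (`audio.wav` is a placeholder path):

```python
import torch
import torchaudio
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

processor = Wav2Vec2Processor.from_pretrained("lgris/mls100-xlsr")
model = Wav2Vec2ForCTC.from_pretrained("lgris/mls100-xlsr")

speech, sr = torchaudio.load("audio.wav")  # placeholder; the model expects 16 kHz audio
inputs = processor(speech.squeeze(0).numpy(), sampling_rate=16_000, return_tensors="pt")

with torch.no_grad():
    logits = model(inputs.input_values).logits

pred_ids = torch.argmax(logits, dim=-1)
print(processor.batch_decode(pred_ids)[0])
```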

## Demonstration


```python
MODEL_NAME = "lgris/mls100-xlsr"
```

### Imports and dependencies


```python
%%capture
!pip install torch==1.8.2+cu111 torchvision==0.9.2+cu111 torchaudio==0.8.2 -f https://download.pytorch.org/whl/lts/1.8/torch_lts.html
!pip install datasets
!pip install jiwer
!pip install transformers
!pip install soundfile
!pip install pyctcdecode
!pip install https://github.com/kpu/kenlm/archive/master.zip
```


```python
import jiwer
import torchaudio
from datasets import load_dataset
from transformers import (
    Wav2Vec2ForCTC,
    Wav2Vec2Processor,
)
from pyctcdecode import build_ctcdecoder
import torch
import re
```

### Helpers


```python
chars_to_ignore_regex = '[\,\?\.\!\;\:\"]'  # noqa: W605

def map_to_array(batch):
    # Load the audio and normalize the reference transcription.
    speech, _ = torchaudio.load(batch["path"])
    batch["speech"] = speech.squeeze(0).numpy()
    batch["sampling_rate"] = 16_000
    batch["sentence"] = re.sub(chars_to_ignore_regex, '', batch["sentence"]).lower().replace("’", "'")
    batch["target"] = batch["sentence"]
    return batch
```
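
As a toy illustration (not from the original notebook) of the normalization step above:

```python
import re

chars_to_ignore_regex = '[\,\?\.\!\;\:\"]'  # noqa: W605
print(re.sub(chars_to_ignore_regex, '', "Olá, tudo bem?").lower())
# -> "olá tudo bem"
```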


```python
def calc_metrics(truths, hypos):
    # Average WER, MER and WIL over the individual utterance pairs.
    wers = []
    mers = []
    wils = []
    for t, h in zip(truths, hypos):
        try:
            wers.append(jiwer.wer(t, h))
            mers.append(jiwer.mer(t, h))
            wils.append(jiwer.wil(t, h))
        except ValueError:  # jiwer rejects empty reference strings
            pass
    wer = sum(wers) / len(wers)
    mer = sum(mers) / len(mers)
    wil = sum(wils) / len(wils)
    return wer, mer, wil
```
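
A quick sanity check with toy strings (not from the original notebook): one substituted word in a five-word reference yields a WER of 0.2:

```python
truths = ["o rato roeu a roupa"]
hypos = ["o rato comeu a roupa"]
print(calc_metrics(truths, hypos))  # -> (0.2, 0.2, 0.36): WER, MER, WIL
```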


```python
def load_data(dataset_dir):
    # Each dataset directory is expected to contain a test.csv file.
    data_files = {'test': f'{dataset_dir}/test.csv'}
    dataset = load_dataset('csv', data_files=data_files)["test"]
    return dataset.map(map_to_array)
```
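
The CSV layout itself is not shown in the original notebook; given `map_to_array`, each `test.csv` must provide at least `path` and `sentence` columns. A hypothetical example (paths and transcriptions are placeholders, and `pandas` is not among the pinned dependencies):

```python
import pandas as pd

# Hypothetical two-row test set illustrating the expected schema.
pd.DataFrame({
    "path": ["audio/0001.wav", "audio/0002.wav"],
    "sentence": ["primeira frase de exemplo", "segunda frase de exemplo"],
}).to_csv("my_dataset/test.csv", index=False)

# ds = load_data('my_dataset')  # then evaluated exactly like the datasets below
```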

### Model


```python
class STT:

    def __init__(self,
                 model_name,
                 device='cuda' if torch.cuda.is_available() else 'cpu',
                 lm=None):
        self.model_name = model_name
        self.model = Wav2Vec2ForCTC.from_pretrained(model_name).to(device)
        self.processor = Wav2Vec2Processor.from_pretrained(model_name)
        self.vocab_dict = self.processor.tokenizer.get_vocab()
        # pyctcdecode expects the CTC labels sorted by token id.
        self.sorted_dict = {
            k.lower(): v for k, v in sorted(self.vocab_dict.items(),
                                            key=lambda item: item[1])
        }
        self.device = device
        self.lm = lm
        if self.lm:
            self.lm_decoder = build_ctcdecoder(
                list(self.sorted_dict.keys()),
                self.lm
            )

    def batch_predict(self, batch):
        features = self.processor(batch["speech"],
                                  sampling_rate=batch["sampling_rate"][0],
                                  padding=True,
                                  return_tensors="pt")
        input_values = features.input_values.to(self.device)
        attention_mask = features.attention_mask.to(self.device)
        with torch.no_grad():
            logits = self.model(input_values, attention_mask=attention_mask).logits
        if self.lm:
            # Beam search with the KenLM language model.
            logits = logits.cpu().numpy()
            batch["predicted"] = []
            for sample_logits in logits:
                batch["predicted"].append(self.lm_decoder.decode(sample_logits))
        else:
            # Greedy CTC decoding.
            pred_ids = torch.argmax(logits, dim=-1)
            batch["predicted"] = self.processor.batch_decode(pred_ids)
        return batch
```
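
`batch_predict` is written for `datasets.map` with `batched=True`, so every field is a list. A hypothetical single-utterance call (`speech_array` is a placeholder 16 kHz numpy array):

```python
batch = {"speech": [speech_array], "sampling_rate": [16_000]}
print(STT(MODEL_NAME).batch_predict(batch)["predicted"][0])
```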

### Download datasets


```python
%%capture
!gdown --id 1HFECzIizf-bmkQRLiQD0QVqcGtOG5upI
!mkdir bp_dataset
!unzip bp_dataset -d bp_dataset/
```


```python
%cd bp_dataset/
```

    /content/bp_dataset


### Tests


```python
stt = STT(MODEL_NAME)
```

#### CETUC


```python
ds = load_data('cetuc_dataset')
result = ds.map(stt.batch_predict, batched=True, batch_size=8)
wer, mer, wil = calc_metrics(result["sentence"], result["predicted"])
print("CETUC WER:", wer)
```

    CETUC WER: 0.192586382955233


#### Common Voice


```python
ds = load_data('commonvoice_dataset')
result = ds.map(stt.batch_predict, batched=True, batch_size=8)
wer, mer, wil = calc_metrics(result["sentence"], result["predicted"])
print("CV WER:", wer)
```

    CV WER: 0.2604333640312866


#### LaPS


```python
ds = load_data('lapsbm_dataset')
result = ds.map(stt.batch_predict, batched=True, batch_size=8)
wer, mer, wil = calc_metrics(result["sentence"], result["predicted"])
print("Laps WER:", wer)
```

    Laps WER: 0.16259469696969692


#### MLS


```python
ds = load_data('mls_dataset')
result = ds.map(stt.batch_predict, batched=True, batch_size=8)
wer, mer, wil = calc_metrics(result["sentence"], result["predicted"])
print("MLS WER:", wer)
```

    MLS WER: 0.16343014413283674


#### SID


```python
ds = load_data('sid_dataset')
result = ds.map(stt.batch_predict, batched=True, batch_size=8)
wer, mer, wil = calc_metrics(result["sentence"], result["predicted"])
print("Sid WER:", wer)
```

    Sid WER: 0.2682880375992515


#### TEDx


```python
ds = load_data('tedx_dataset')
result = ds.map(stt.batch_predict, batched=True, batch_size=8)
wer, mer, wil = calc_metrics(result["sentence"], result["predicted"])
print("TEDx WER:", wer)
```

    TEDx WER: 0.49252836581485837


#### VoxForge


```python
ds = load_data('voxforge_dataset')
result = ds.map(stt.batch_predict, batched=True, batch_size=8)
wer, mer, wil = calc_metrics(result["sentence"], result["predicted"])
print("VoxForge WER:", wer)
```

    VoxForge WER: 0.2686972402597403

### Tests with LM


```python
!rm -rf ~/.cache
%cd /content/
!gdown --id '1d13Onxy9ubmJZORZ8FO2vnsnl36QMiUc'  # 4-gram LM trained on Wikipedia
stt = STT(MODEL_NAME, lm='pt-BR-wiki.word.4-gram.arpa')
# !gdown --id 1dLFldy7eguPtyJj5OAlI4Emnx0BpFywg  # alternative 4-gram LM trained on the BP corpus
# stt = STT(MODEL_NAME, lm='pt-BR.word.4-gram.arpa')
%cd bp_dataset/
```

    /content/bp_dataset
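
`pyctcdecode` exposes the language-model weight and the word-insertion bonus as tuning knobs; the runs below use its defaults. A minimal sketch (the values shown are pyctcdecode's defaults, not tuned for these datasets):

```python
from pyctcdecode import build_ctcdecoder

# Labels must be in token-id order, as prepared in the STT class above.
labels = list(stt.sorted_dict.keys())
decoder = build_ctcdecoder(
    labels,
    'pt-BR-wiki.word.4-gram.arpa',  # KenLM ARPA file downloaded above
    alpha=0.5,   # language-model weight (pyctcdecode default)
    beta=1.5,    # word-insertion bonus (pyctcdecode default)
)
# decoder.decode(sample_logits, beam_width=100) returns the best hypothesis
# for one utterance's logits (a numpy array).
```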

#### CETUC


```python
ds = load_data('cetuc_dataset')
result = ds.map(stt.batch_predict, batched=True, batch_size=8)
wer, mer, wil = calc_metrics(result["sentence"], result["predicted"])
print("CETUC WER:", wer)
```

    CETUC WER: 0.0878818926974661


#### Common Voice


```python
ds = load_data('commonvoice_dataset')
result = ds.map(stt.batch_predict, batched=True, batch_size=8)
wer, mer, wil = calc_metrics(result["sentence"], result["predicted"])
print("CV WER:", wer)
```

    CV WER: 0.173303354010221


#### LaPS


```python
ds = load_data('lapsbm_dataset')
result = ds.map(stt.batch_predict, batched=True, batch_size=8)
wer, mer, wil = calc_metrics(result["sentence"], result["predicted"])
print("Laps WER:", wer)
```

    Laps WER: 0.07691919191919189


#### MLS


```python
ds = load_data('mls_dataset')
result = ds.map(stt.batch_predict, batched=True, batch_size=8)
wer, mer, wil = calc_metrics(result["sentence"], result["predicted"])
print("MLS WER:", wer)
```

    MLS WER: 0.12624377042839321


#### SID


```python
ds = load_data('sid_dataset')
result = ds.map(stt.batch_predict, batched=True, batch_size=8)
wer, mer, wil = calc_metrics(result["sentence"], result["predicted"])
print("Sid WER:", wer)
```

    Sid WER: 0.24545473435776916


#### TEDx


```python
ds = load_data('tedx_dataset')
result = ds.map(stt.batch_predict, batched=True, batch_size=8)
wer, mer, wil = calc_metrics(result["sentence"], result["predicted"])
print("TEDx WER:", wer)
```

    TEDx WER: 0.4156272215612955


#### VoxForge


```python
ds = load_data('voxforge_dataset')
result = ds.map(stt.batch_predict, batched=True, batch_size=8)
wer, mer, wil = calc_metrics(result["sentence"], result["predicted"])
print("VoxForge WER:", wer)
```

    VoxForge WER: 0.21832386363636366