---
language: pt
datasets:
- common_voice
- mls
- cetuc
- lapsbm
- voxforge
- tedx
- sid
metrics:
- wer
tags:
- audio
- speech
- wav2vec2
- pt
- portuguese-speech-corpus
- automatic-speech-recognition
- PyTorch
license: apache-2.0
---

# commonvoice10-xlsr: Wav2vec 2.0 with Common Voice Dataset

This is a demonstration of a Wav2vec 2.0 model fine-tuned for Brazilian Portuguese using the [Common Voice 7.0](https://commonvoice.mozilla.org/pt) dataset.

In this notebook, the model is evaluated against other available Brazilian Portuguese datasets.

| Dataset                        | Train | Valid | Test  |
|--------------------------------|------:|------:|------:|
| CETUC                          |       | --    | 5.4h  |
| Common Voice                   | 37.8h | --    | 9.5h  |
| LaPS BM                        |       | --    | 0.1h  |
| MLS                            |       | --    | 3.7h  |
| Multilingual TEDx (Portuguese) |       | --    | 1.8h  |
| SID                            |       | --    | 1.0h  |
| VoxForge                       |       | --    | 0.1h  |
| Total                          |       | --    | 21.6h |

#### Summary

| Model                                        | CETUC | CV    | LaPS  | MLS   | SID   | TEDx  | VF    | AVG   |
|----------------------------------------------|------:|------:|------:|------:|------:|------:|------:|------:|
| commonvoice10 (demonstration below)          | 0.133 | 0.189 | 0.165 | 0.189 | 0.247 | 0.474 | 0.251 | 0.235 |
| commonvoice10 + 4-gram (demonstration below) | 0.060 | 0.117 | 0.088 | 0.136 | 0.181 | 0.394 | 0.227 | 0.171 |

## Demonstration

```python
MODEL_NAME = "lgris/commonvoice10-xlsr"
```
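
Before running the full evaluation below, a quick way to sanity-check the model is to transcribe a single recording. This is a minimal sketch, assuming a local 16 kHz mono WAV file (the `sample.wav` path is hypothetical):

```python
import torch
import torchaudio
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

processor = Wav2Vec2Processor.from_pretrained(MODEL_NAME)
model = Wav2Vec2ForCTC.from_pretrained(MODEL_NAME)

speech, sr = torchaudio.load("sample.wav")  # hypothetical file, expected at 16 kHz
inputs = processor(speech.squeeze(0).numpy(), sampling_rate=16_000, return_tensors="pt")
with torch.no_grad():
    logits = model(inputs.input_values).logits
pred_ids = torch.argmax(logits, dim=-1)  # greedy CTC decoding
print(processor.batch_decode(pred_ids)[0])
```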

### Imports and dependencies

```python
%%capture
!pip install torch==1.8.2+cu111 torchvision==0.9.2+cu111 torchaudio==0.8.2 -f https://download.pytorch.org/whl/lts/1.8/torch_lts.html
!pip install datasets
!pip install jiwer
!pip install transformers
!pip install soundfile
!pip install pyctcdecode
!pip install https://github.com/kpu/kenlm/archive/master.zip
```

```python
import re

import jiwer
import torch
import torchaudio
from datasets import load_dataset
from pyctcdecode import build_ctcdecoder
from transformers import (
    Wav2Vec2ForCTC,
    Wav2Vec2Processor,
)
```

### Helpers

```python
chars_to_ignore_regex = r'[,?.!;:"]'  # punctuation stripped before scoring

def map_to_array(batch):
    # Load the audio and flatten it to a 1-D numpy array.
    speech, _ = torchaudio.load(batch["path"])
    batch["speech"] = speech.squeeze(0).numpy()
    batch["sampling_rate"] = 16_000
    # Normalize the reference transcription: drop punctuation and lowercase.
    batch["sentence"] = re.sub(chars_to_ignore_regex, '', batch["sentence"]).lower().replace("’", "'")
    batch["target"] = batch["sentence"]
    return batch
```
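
Note that `map_to_array` assumes the audio on disk is already sampled at 16 kHz, which is what the model expects. If a dataset ships at a different rate, a resampling variant such as the sketch below (a hypothetical helper, not part of the original notebook) can be used instead:

```python
def map_to_array_resampled(batch, target_sr=16_000):
    # Like map_to_array, but resamples to the model's expected rate when needed.
    speech, sr = torchaudio.load(batch["path"])
    if sr != target_sr:
        speech = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sr)(speech)
    batch["speech"] = speech.squeeze(0).numpy()
    batch["sampling_rate"] = target_sr
    batch["sentence"] = re.sub(chars_to_ignore_regex, '', batch["sentence"]).lower().replace("’", "'")
    batch["target"] = batch["sentence"]
    return batch
```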

```python
def calc_metrics(truths, hypos):
    # Sentence-level WER/MER/WIL, averaged over the corpus.
    wers = []
    mers = []
    wils = []
    for t, h in zip(truths, hypos):
        try:
            wers.append(jiwer.wer(t, h))
            mers.append(jiwer.mer(t, h))
            wils.append(jiwer.wil(t, h))
        except ValueError:  # jiwer rejects empty reference strings
            pass
    wer = sum(wers) / len(wers)
    mer = sum(mers) / len(mers)
    wil = sum(wils) / len(wils)
    return wer, mer, wil
```
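
As a quick illustration of what `calc_metrics` returns, here is a toy call on two made-up sentence pairs (the strings are illustrative only):

```python
truths = ["o gato subiu no telhado", "hoje está frio"]
hypos = ["o gato subiu telhado", "hoje está frio"]
wer, mer, wil = calc_metrics(truths, hypos)
# One deletion out of five words in the first pair and a perfect second pair:
# per-sentence WERs are 0.2 and 0.0, so the averaged WER is 0.1.
print(wer, mer, wil)
```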

```python
def load_data(dataset_path):
    # Each dataset directory holds a test.csv listing audio paths and transcriptions.
    data_files = {'test': f'{dataset_path}/test.csv'}
    dataset = load_dataset('csv', data_files=data_files)["test"]
    return dataset.map(map_to_array)
```
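
For reference, the `test.csv` files consumed above need at least a `path` column (pointing at the audio file) and a `sentence` column (the reference transcription), since those are the fields `map_to_array` reads. A minimal sketch that writes such a file, with hypothetical paths and sentences:

```python
import csv

rows = [
    {"path": "audio/utt_0001.wav", "sentence": "exemplo de transcrição"},  # hypothetical
    {"path": "audio/utt_0002.wav", "sentence": "outra frase de teste"},    # hypothetical
]
with open("my_dataset/test.csv", "w", newline="") as f:
    writer = csv.DictWriter(f, fieldnames=["path", "sentence"])
    writer.writeheader()
    writer.writerows(rows)
```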

### Model

```python
class STT:

    def __init__(self,
                 model_name,
                 device='cuda' if torch.cuda.is_available() else 'cpu',
                 lm=None):
        self.model_name = model_name
        self.model = Wav2Vec2ForCTC.from_pretrained(model_name).to(device)
        self.processor = Wav2Vec2Processor.from_pretrained(model_name)
        self.vocab_dict = self.processor.tokenizer.get_vocab()
        # pyctcdecode expects the vocabulary ordered by token id.
        self.sorted_dict = {
            k.lower(): v for k, v in sorted(self.vocab_dict.items(),
                                            key=lambda item: item[1])
        }
        self.device = device
        self.lm = lm
        if self.lm:
            self.lm_decoder = build_ctcdecoder(
                list(self.sorted_dict.keys()),
                self.lm
            )

    def batch_predict(self, batch):
        features = self.processor(batch["speech"],
                                  sampling_rate=batch["sampling_rate"][0],
                                  padding=True,
                                  return_tensors="pt")
        input_values = features.input_values.to(self.device)
        attention_mask = features.attention_mask.to(self.device)
        with torch.no_grad():
            logits = self.model(input_values, attention_mask=attention_mask).logits
        if self.lm:
            # Beam-search decoding rescored with the n-gram language model.
            logits = logits.cpu().numpy()
            batch["predicted"] = []
            for sample_logits in logits:
                batch["predicted"].append(self.lm_decoder.decode(sample_logits))
        else:
            # Greedy (argmax) CTC decoding.
            pred_ids = torch.argmax(logits, dim=-1)
            batch["predicted"] = self.processor.batch_decode(pred_ids)
        return batch
```
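
`batch_predict` is written for `datasets.map(..., batched=True)`, so it takes each column as a list. A minimal sketch of calling it directly on a single utterance (the `sample.wav` path is hypothetical):

```python
stt = STT(MODEL_NAME)

speech, _ = torchaudio.load("sample.wav")  # hypothetical 16 kHz file
batch = {"speech": [speech.squeeze(0).numpy()], "sampling_rate": [16_000]}
print(stt.batch_predict(batch)["predicted"][0])
```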

### Download datasets

```python
%%capture
!gdown --id 1HFECzIizf-bmkQRLiQD0QVqcGtOG5upI
!mkdir bp_dataset
!unzip bp_dataset -d bp_dataset/
```

### Tests

```python
stt = STT(MODEL_NAME)
```

#### CETUC

```python
ds = load_data('cetuc_dataset')
result = ds.map(stt.batch_predict, batched=True, batch_size=8)
wer, mer, wil = calc_metrics(result["sentence"], result["predicted"])
print("CETUC WER:", wer)
```

    CETUC WER: 0.13291846056190185

#### Common Voice

```python
ds = load_data('commonvoice_dataset')
result = ds.map(stt.batch_predict, batched=True, batch_size=8)
wer, mer, wil = calc_metrics(result["sentence"], result["predicted"])
print("CV WER:", wer)
```

    CV WER: 0.18909733896486755

#### LaPS

```python
ds = load_data('lapsbm_dataset')
result = ds.map(stt.batch_predict, batched=True, batch_size=8)
wer, mer, wil = calc_metrics(result["sentence"], result["predicted"])
print("Laps WER:", wer)
```

    Laps WER: 0.1655429292929293

#### MLS

```python
ds = load_data('mls_dataset')
result = ds.map(stt.batch_predict, batched=True, batch_size=8)
wer, mer, wil = calc_metrics(result["sentence"], result["predicted"])
print("MLS WER:", wer)
```

    MLS WER: 0.1894711228284466

#### SID

```python
ds = load_data('sid_dataset')
result = ds.map(stt.batch_predict, batched=True, batch_size=8)
wer, mer, wil = calc_metrics(result["sentence"], result["predicted"])
print("Sid WER:", wer)
```

    Sid WER: 0.2471983709551264

#### TEDx

```python
ds = load_data('tedx_dataset')
result = ds.map(stt.batch_predict, batched=True, batch_size=8)
wer, mer, wil = calc_metrics(result["sentence"], result["predicted"])
print("TEDx WER:", wer)
```

    TEDx WER: 0.4739658565194102

#### VoxForge

```python
ds = load_data('voxforge_dataset')
result = ds.map(stt.batch_predict, batched=True, batch_size=8)
wer, mer, wil = calc_metrics(result["sentence"], result["predicted"])
print("VoxForge WER:", wer)
```

    VoxForge WER: 0.2510294913419914

### Tests with LM

```python
# !find -type f -name "*.wav" -delete
!rm -rf ~/.cache
!gdown --id 1GJIKseP5ZkTbllQVgOL98R4yYAcIySFP  # trained with wikipedia
stt = STT(MODEL_NAME, lm='pt-BR-wiki.word.4-gram.arpa')
# !gdown --id 1dLFldy7eguPtyJj5OAlI4Emnx0BpFywg  # trained with bp
# stt = STT(MODEL_NAME, lm='pt-BR.word.4-gram.arpa')
```
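
The ARPA files downloaded above are word-level 4-gram KenLM models. For anyone who wants to build a comparable language model from their own text, a minimal sketch using KenLM's `lmplz` tool, assuming the KenLM binaries have been built and the corpus file name is hypothetical:

```python
# Build a word-level 4-gram ARPA model from a plain-text corpus
# (one sentence per line). corpus.txt is a hypothetical file name.
!lmplz -o 4 < corpus.txt > my-pt.word.4-gram.arpa
```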

#### CETUC

```python
ds = load_data('cetuc_dataset')
result = ds.map(stt.batch_predict, batched=True, batch_size=8)
wer, mer, wil = calc_metrics(result["sentence"], result["predicted"])
print("CETUC WER:", wer)
```

    CETUC WER: 0.060609303416680915

#### Common Voice

```python
ds = load_data('commonvoice_dataset')
result = ds.map(stt.batch_predict, batched=True, batch_size=8)
wer, mer, wil = calc_metrics(result["sentence"], result["predicted"])
print("CV WER:", wer)
```

    CV WER: 0.11758415681158373

#### LaPS

```python
ds = load_data('lapsbm_dataset')
result = ds.map(stt.batch_predict, batched=True, batch_size=8)
wer, mer, wil = calc_metrics(result["sentence"], result["predicted"])
print("Laps WER:", wer)
```

    Laps WER: 0.08815340909090909

#### MLS

```python
ds = load_data('mls_dataset')
result = ds.map(stt.batch_predict, batched=True, batch_size=8)
wer, mer, wil = calc_metrics(result["sentence"], result["predicted"])
print("MLS WER:", wer)
```

    MLS WER: 0.1359966791836458

#### SID

```python
ds = load_data('sid_dataset')
result = ds.map(stt.batch_predict, batched=True, batch_size=8)
wer, mer, wil = calc_metrics(result["sentence"], result["predicted"])
print("Sid WER:", wer)
```

    Sid WER: 0.1818429601530829

#### TEDx

```python
ds = load_data('tedx_dataset')
result = ds.map(stt.batch_predict, batched=True, batch_size=8)
wer, mer, wil = calc_metrics(result["sentence"], result["predicted"])
print("TEDx WER:", wer)
```

    TEDx WER: 0.39469326522731385

#### VoxForge

```python
ds = load_data('voxforge_dataset')
result = ds.map(stt.batch_predict, batched=True, batch_size=8)
wer, mer, wil = calc_metrics(result["sentence"], result["predicted"])
print("VoxForge WER:", wer)
```

    VoxForge WER: 0.22779897186147183