lgris commited on
Commit
3bc89fd
1 Parent(s): c23ab5b

Create README.md

Browse files
Files changed (1) hide show
  1. README.md +370 -0
README.md ADDED
@@ -0,0 +1,370 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ language: pt
3
+ datasets:
4
+ - common_voice
5
+ - mls
6
+ - cetuc
7
+ - lapsbm
8
+ - voxforge
9
+ - tedx
10
+ - sid
11
+ metrics:
12
+ - wer
13
+ tags:
14
+ - audio
15
+ - speech
16
+ - wav2vec2
17
+ - pt
18
+ - portuguese-speech-corpus
19
+ - automatic-speech-recognition
20
+ - speech
21
+ - PyTorch
22
+ license: apache-2.0
23
+ ---
24
+
25
+ # commonvoice100-xlsr: Wav2vec 2.0 with Common Voice Dataset
26
+
27
+ This is a the demonstration of a fine-tuned Wav2vec model for Brazilian Portuguese using the [Common Voice 7.0](https://commonvoice.mozilla.org/pt) dataset.
28
+
29
+ In this notebook the model is tested against other available Brazilian Portuguese datasets.
30
+
31
+ | Dataset | Train | Valid | Test |
32
+ |--------------------------------|-------:|------:|------:|
33
+ | CETUC | | -- | 5.4h |
34
+ | Common Voice | 37.8h | -- | 9.5h |
35
+ | LaPS BM | | -- | 0.1h |
36
+ | MLS | | -- | 3.7h |
37
+ | Multilingual TEDx (Portuguese) | | -- | 1.8h |
38
+ | SID | | -- | 1.0h |
39
+ | VoxForge | | -- | 0.1h |
40
+ | Total | | -- | 21.6h |
41
+
42
+
43
+ #### Summary
44
+
45
+ | | CETUC | CV | LaPS | MLS | SID | TEDx | VF | AVG |
46
+ |----------------------|---------------|----------------|----------------|----------------|----------------|----------------|----------------|----------------|
47
+ | commonvoice\_100 (demonstration below) |0.088 | 0.126 | 0.121 | 0.173 | 0.177 | 0.424 | 0.145 | 0.179 |
48
+ | commonvoice\_100 + 4-gram (demonstration below) |0.057 | 0.095 | 0.076 | 0.138 | 0.146 | 0.382 | 0.130 | 0.146|
49
+
50
+ ## Demonstration
51
+
52
+
53
+ ```python
54
+ MODEL_NAME = "lgris/commonvoice100-xlsr"
55
+ ```
56
+
57
+ ### Imports and dependencies
58
+
59
+
60
+ ```python
61
+ %%capture
62
+ !pip install torch==1.8.2+cu111 torchvision==0.9.2+cu111 torchaudio===0.8.2 -f https://download.pytorch.org/whl/lts/1.8/torch_lts.html
63
+ !pip install datasets
64
+ !pip install jiwer
65
+ !pip install transformers
66
+ !pip install soundfile
67
+ !pip install pyctcdecode
68
+ !pip install https://github.com/kpu/kenlm/archive/master.zip
69
+ ```
70
+
71
+
72
+ ```python
73
+ import jiwer
74
+ import torchaudio
75
+ from datasets import load_dataset, load_metric
76
+ from transformers import (
77
+ Wav2Vec2ForCTC,
78
+ Wav2Vec2Processor,
79
+ )
80
+ from pyctcdecode import build_ctcdecoder
81
+ import torch
82
+ import re
83
+ import sys
84
+ ```
85
+
86
+ ### Helpers
87
+
88
+
89
+ ```python
90
+ chars_to_ignore_regex = '[\,\?\.\!\;\:\"]' # noqa: W605
91
+
92
+ def map_to_array(batch):
93
+ speech, _ = torchaudio.load(batch["path"])
94
+ batch["speech"] = speech.squeeze(0).numpy()
95
+ batch["sampling_rate"] = 16_000
96
+ batch["sentence"] = re.sub(chars_to_ignore_regex, '', batch["sentence"]).lower().replace("’", "'")
97
+ batch["target"] = batch["sentence"]
98
+ return batch
99
+ ```
100
+
101
+
102
+ ```python
103
+ def calc_metrics(truths, hypos):
104
+ wers = []
105
+ mers = []
106
+ wils = []
107
+ for t, h in zip(truths, hypos):
108
+ try:
109
+ wers.append(jiwer.wer(t, h))
110
+ mers.append(jiwer.mer(t, h))
111
+ wils.append(jiwer.wil(t, h))
112
+ except: # Empty string?
113
+ pass
114
+ wer = sum(wers)/len(wers)
115
+ mer = sum(mers)/len(mers)
116
+ wil = sum(wils)/len(wils)
117
+ return wer, mer, wil
118
+ ```
119
+
120
+
121
+ ```python
122
+ def load_data(dataset):
123
+ data_files = {'test': f'{dataset}/test.csv'}
124
+ dataset = load_dataset('csv', data_files=data_files)["test"]
125
+ return dataset.map(map_to_array)
126
+ ```
127
+
128
+ ### Model
129
+
130
+
131
+ ```python
132
+ class STT:
133
+
134
+ def __init__(self,
135
+ model_name,
136
+ device='cuda' if torch.cuda.is_available() else 'cpu',
137
+ lm=None):
138
+ self.model_name = model_name
139
+ self.model = Wav2Vec2ForCTC.from_pretrained(model_name).to(device)
140
+ self.processor = Wav2Vec2Processor.from_pretrained(model_name)
141
+ self.vocab_dict = self.processor.tokenizer.get_vocab()
142
+ self.sorted_dict = {
143
+ k.lower(): v for k, v in sorted(self.vocab_dict.items(),
144
+ key=lambda item: item[1])
145
+ }
146
+ self.device = device
147
+ self.lm = lm
148
+ if self.lm:
149
+ self.lm_decoder = build_ctcdecoder(
150
+ list(self.sorted_dict.keys()),
151
+ self.lm
152
+ )
153
+
154
+ def batch_predict(self, batch):
155
+ features = self.processor(batch["speech"],
156
+ sampling_rate=batch["sampling_rate"][0],
157
+ padding=True,
158
+ return_tensors="pt")
159
+ input_values = features.input_values.to(self.device)
160
+ attention_mask = features.attention_mask.to(self.device)
161
+ with torch.no_grad():
162
+ logits = self.model(input_values, attention_mask=attention_mask).logits
163
+ if self.lm:
164
+ logits = logits.cpu().numpy()
165
+ batch["predicted"] = []
166
+ for sample_logits in logits:
167
+ batch["predicted"].append(self.lm_decoder.decode(sample_logits))
168
+ else:
169
+ pred_ids = torch.argmax(logits, dim=-1)
170
+ batch["predicted"] = self.processor.batch_decode(pred_ids)
171
+ return batch
172
+ ```
173
+
174
+ ### Download datasets
175
+
176
+
177
+ ```python
178
+ %%capture
179
+ !gdown --id 1HFECzIizf-bmkQRLiQD0QVqcGtOG5upI
180
+ !mkdir bp_dataset
181
+ !unzip bp_dataset -d bp_dataset/
182
+ ```
183
+
184
+ ### Tests
185
+
186
+
187
+ ```python
188
+ stt = STT(MODEL_NAME)
189
+ ```
190
+
191
+ #### CETUC
192
+
193
+
194
+ ```python
195
+ ds = load_data('cetuc_dataset')
196
+ result = ds.map(stt.batch_predict, batched=True, batch_size=8)
197
+ wer, mer, wil = calc_metrics(result["sentence"], result["predicted"])
198
+ print("CETUC WER:", wer)
199
+ ```
200
+ CETUC WER: 0.08868880057404624
201
+
202
+
203
+ #### Common Voice
204
+
205
+
206
+ ```python
207
+ ds = load_data('commonvoice_dataset')
208
+ result = ds.map(stt.batch_predict, batched=True, batch_size=8)
209
+ wer, mer, wil = calc_metrics(result["sentence"], result["predicted"])
210
+ print("CV WER:", wer)
211
+ ```
212
+ CV WER: 0.12601035333655114
213
+
214
+
215
+ #### LaPS
216
+
217
+
218
+ ```python
219
+ ds = load_data('lapsbm_dataset')
220
+ result = ds.map(stt.batch_predict, batched=True, batch_size=8)
221
+ wer, mer, wil = calc_metrics(result["sentence"], result["predicted"])
222
+ print("Laps WER:", wer)
223
+ ```
224
+ Laps WER: 0.12149621212121209
225
+
226
+
227
+ #### MLS
228
+
229
+
230
+ ```python
231
+ ds = load_data('mls_dataset')
232
+ result = ds.map(stt.batch_predict, batched=True, batch_size=8)
233
+ wer, mer, wil = calc_metrics(result["sentence"], result["predicted"])
234
+ print("MLS WER:", wer)
235
+ ```
236
+ MLS WER: 0.173594387890256
237
+
238
+
239
+ #### SID
240
+
241
+
242
+ ```python
243
+ ds = load_data('sid_dataset')
244
+ result = ds.map(stt.batch_predict, batched=True, batch_size=8)
245
+ wer, mer, wil = calc_metrics(result["sentence"], result["predicted"])
246
+ print("Sid WER:", wer)
247
+ ```
248
+ Sid WER: 0.1775290775992294
249
+
250
+
251
+ #### TEDx
252
+
253
+
254
+ ```python
255
+ ds = load_data('tedx_dataset')
256
+ result = ds.map(stt.batch_predict, batched=True, batch_size=8)
257
+ wer, mer, wil = calc_metrics(result["sentence"], result["predicted"])
258
+ print("TEDx WER:", wer)
259
+ ```
260
+ TEDx WER: 0.4245704568241374
261
+
262
+
263
+ #### VoxForge
264
+
265
+
266
+ ```python
267
+ ds = load_data('voxforge_dataset')
268
+ result = ds.map(stt.batch_predict, batched=True, batch_size=8)
269
+ wer, mer, wil = calc_metrics(result["sentence"], result["predicted"])
270
+ print("VoxForge WER:", wer)
271
+ ```
272
+ VoxForge WER: 0.14541801948051947
273
+
274
+
275
+ ### Tests with LM
276
+
277
+
278
+ ```python
279
+ # !find -type f -name "*.wav" -delete
280
+ !rm -rf ~/.cache
281
+ !gdown --id 1GJIKseP5ZkTbllQVgOL98R4yYAcIySFP # trained with wikipedia
282
+ stt = STT(MODEL_NAME, lm='pt-BR-wiki.word.4-gram.arpa')
283
+ # !gdown --id 1dLFldy7eguPtyJj5OAlI4Emnx0BpFywg # trained with bp
284
+ # stt = STT(MODEL_NAME, lm='pt-BR.word.4-gram.arpa')
285
+ ```
286
+
287
+
288
+ #### CETUC
289
+
290
+
291
+ ```python
292
+ ds = load_data('cetuc_dataset')
293
+ result = ds.map(stt.batch_predict, batched=True, batch_size=8)
294
+ wer, mer, wil = calc_metrics(result["sentence"], result["predicted"])
295
+ print("CETUC WER:", wer)
296
+ ```
297
+ CETUC WER: 0.05764220069547976
298
+
299
+
300
+ #### Common Voice
301
+
302
+
303
+ ```python
304
+ ds = load_data('commonvoice_dataset')
305
+ result = ds.map(stt.batch_predict, batched=True, batch_size=8)
306
+ wer, mer, wil = calc_metrics(result["sentence"], result["predicted"])
307
+ print("CV WER:", wer)
308
+ ```
309
+ CV WER: 0.09569130510737103
310
+
311
+
312
+ #### LaPS
313
+
314
+
315
+ ```python
316
+ ds = load_data('lapsbm_dataset')
317
+ result = ds.map(stt.batch_predict, batched=True, batch_size=8)
318
+ wer, mer, wil = calc_metrics(result["sentence"], result["predicted"])
319
+ print("Laps WER:", wer)
320
+ ```
321
+ Laps WER: 0.07688131313131312
322
+
323
+
324
+ #### MLS
325
+
326
+
327
+ ```python
328
+ ds = load_data('mls_dataset')
329
+ result = ds.map(stt.batch_predict, batched=True, batch_size=8)
330
+ wer, mer, wil = calc_metrics(result["sentence"], result["predicted"])
331
+ print("MLS WER:", wer)
332
+ ```
333
+ MLS WER: 0.13814768877494732
334
+
335
+
336
+ #### SID
337
+
338
+
339
+ ```python
340
+ ds = load_data('sid_dataset')
341
+ result = ds.map(stt.batch_predict, batched=True, batch_size=8)
342
+ wer, mer, wil = calc_metrics(result["sentence"], result["predicted"])
343
+ print("Sid WER:", wer)
344
+ ```
345
+ Sid WER: 0.14652459944499036
346
+
347
+
348
+ #### TEDx
349
+
350
+
351
+ ```python
352
+ ds = load_data('tedx_dataset')
353
+ result = ds.map(stt.batch_predict, batched=True, batch_size=8)
354
+ wer, mer, wil = calc_metrics(result["sentence"], result["predicted"])
355
+ print("TEDx WER:", wer)
356
+ ```
357
+ TEDx WER: 0.38196090002435623
358
+
359
+
360
+ #### VoxForge
361
+
362
+
363
+ ```python
364
+ ds = load_data('voxforge_dataset')
365
+ result = ds.map(stt.batch_predict, batched=True, batch_size=8)
366
+ wer, mer, wil = calc_metrics(result["sentence"], result["predicted"])
367
+ print("VoxForge WER:", wer)
368
+ ```
369
+ VoxForge WER: 0.13054112554112554
370
+