---
language: pt
datasets:
- common_voice
- mls
- cetuc
- lapsbm
- voxforge
- tedx
- sid
metrics:
- wer
tags:
- audio
- speech
- wav2vec2
- pt
- portuguese-speech-corpus
- automatic-speech-recognition
- PyTorch
license: apache-2.0
---

# tedx100-xlsr: Wav2vec 2.0 with TEDx Dataset

This is a demonstration of a Wav2vec 2.0 model fine-tuned for Brazilian Portuguese using the [TEDx multilingual in Portuguese](http://www.openslr.org/100) dataset.

In this notebook, the model is evaluated against the other available Brazilian Portuguese datasets.

| Dataset                        |  Train | Valid |  Test |
|--------------------------------|-------:|------:|------:|
| CETUC                          |     -- |    -- |  5.4h |
| Common Voice                   |     -- |    -- |  9.5h |
| LaPS BM                        |     -- |    -- |  0.1h |
| MLS                            |     -- |    -- |  3.7h |
| Multilingual TEDx (Portuguese) | 148.8h |    -- |  1.8h |
| SID                            |     -- |    -- |  1.0h |
| VoxForge                       |     -- |    -- |  0.1h |
| Total                          | 148.8h |    -- | 21.6h |

#### Summary

|                                           | CETUC | CV    | LaPS  | MLS   | SID   | TEDx  | VF    | AVG   |
|-------------------------------------------|-------|-------|-------|-------|-------|-------|-------|-------|
| tedx\_100 (demonstration below)           | 0.138 | 0.369 | 0.169 | 0.165 | 0.794 | 0.222 | 0.395 | 0.321 |
| tedx\_100 + 4-gram (demonstration below)  | 0.123 | 0.414 | 0.171 | 0.152 | 0.982 | 0.215 | 0.395 | 0.350 |

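The AVG column appears to be the unweighted mean of the seven per-dataset WERs. A quick sanity check, using the rounded values from the table above (small discrepancies against the table are rounding artifacts):

```python
# Per-dataset WERs copied from the summary table above.
tedx_100 = [0.138, 0.369, 0.169, 0.165, 0.794, 0.222, 0.395]
tedx_100_lm = [0.123, 0.414, 0.171, 0.152, 0.982, 0.215, 0.395]

avg = sum(tedx_100) / len(tedx_100)
avg_lm = sum(tedx_100_lm) / len(tedx_100_lm)
# Both means land within rounding of the reported 0.321 and 0.350.
print(avg, avg_lm)
```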
## Demonstration

```python
MODEL_NAME = "lgris/tedx100-xlsr"
```

### Imports and dependencies

```python
%%capture
!pip install torch==1.8.2+cu111 torchvision==0.9.2+cu111 torchaudio==0.8.2 -f https://download.pytorch.org/whl/lts/1.8/torch_lts.html
!pip install datasets
!pip install jiwer
!pip install transformers
!pip install soundfile
!pip install pyctcdecode
!pip install https://github.com/kpu/kenlm/archive/master.zip
```

```python
import re

import jiwer
import torch
import torchaudio
from datasets import load_dataset
from transformers import (
    Wav2Vec2ForCTC,
    Wav2Vec2Processor,
)
from pyctcdecode import build_ctcdecoder
```

### Helpers

```python
chars_to_ignore_regex = r'[,?.!;:"]'

def map_to_array(batch):
    # Load the audio and normalize the reference text: strip punctuation,
    # lowercase, and map the typographic apostrophe to a plain one.
    speech, _ = torchaudio.load(batch["path"])
    batch["speech"] = speech.squeeze(0).numpy()
    batch["sampling_rate"] = 16_000
    batch["sentence"] = re.sub(chars_to_ignore_regex, '', batch["sentence"]).lower().replace("’", "'")
    batch["target"] = batch["sentence"]
    return batch
```
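As a quick illustration of the text normalization above, here is the same transform applied to a made-up sentence (punctuation is stripped, case is lowered, accents are kept):

```python
import re

chars_to_ignore_regex = r'[,?.!;:"]'

def normalize(sentence):
    # Same normalization applied to references in map_to_array.
    return re.sub(chars_to_ignore_regex, '', sentence).lower().replace("’", "'")

print(normalize("Olá, Mundo! Tudo bem?"))  # → olá mundo tudo bem
```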

```python
def calc_metrics(truths, hypos):
    # Average WER, MER, and WIL over all (reference, hypothesis) pairs.
    wers = []
    mers = []
    wils = []
    for t, h in zip(truths, hypos):
        try:
            wers.append(jiwer.wer(t, h))
            mers.append(jiwer.mer(t, h))
            wils.append(jiwer.wil(t, h))
        except ValueError:  # e.g. empty reference string
            pass
    wer = sum(wers) / len(wers)
    mer = sum(mers) / len(mers)
    wil = sum(wils) / len(wils)
    return wer, mer, wil
```
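`jiwer.wer` computes the word error rate: the word-level edit distance between reference and hypothesis, divided by the reference length. A minimal pure-Python sketch of the same quantity, for intuition only (the notebook itself uses `jiwer`):

```python
def word_error_rate(reference, hypothesis):
    """Levenshtein distance over words, divided by the reference length."""
    ref, hyp = reference.split(), hypothesis.split()
    # dp[i][j] = edit distance between ref[:i] and hyp[:j]
    dp = [[0] * (len(hyp) + 1) for _ in range(len(ref) + 1)]
    for i in range(len(ref) + 1):
        dp[i][0] = i
    for j in range(len(hyp) + 1):
        dp[0][j] = j
    for i in range(1, len(ref) + 1):
        for j in range(1, len(hyp) + 1):
            cost = 0 if ref[i - 1] == hyp[j - 1] else 1
            dp[i][j] = min(dp[i - 1][j] + 1,         # deletion
                           dp[i][j - 1] + 1,         # insertion
                           dp[i - 1][j - 1] + cost)  # substitution
    return dp[-1][-1] / len(ref)

# One deleted word out of five: WER = 1/5.
print(word_error_rate("o gato subiu no telhado", "o gato subiu telhado"))  # → 0.2
```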

```python
def load_data(dataset_dir):
    # Each dataset directory is expected to contain a test.csv index file.
    data_files = {'test': f'{dataset_dir}/test.csv'}
    test_dataset = load_dataset('csv', data_files=data_files)["test"]
    return test_dataset.map(map_to_array)
```
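`load_data` assumes each `test.csv` has at least `path` and `sentence` columns (audio file path and reference transcription), since `map_to_array` reads those two fields. A minimal sketch of the expected layout; the file name and row contents here are hypothetical:

```python
import csv
import io

# Hypothetical contents of e.g. cetuc_dataset/test.csv:
# one row per utterance, with the audio path and its reference text.
csv_text = "path,sentence\ncetuc_dataset/audio/utt_0001.wav,bom dia\n"

rows = list(csv.DictReader(io.StringIO(csv_text)))
print(rows[0]["path"], "->", rows[0]["sentence"])
```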

### Model

```python
class STT:

    def __init__(self,
                 model_name,
                 device='cuda' if torch.cuda.is_available() else 'cpu',
                 lm=None):
        self.model_name = model_name
        self.model = Wav2Vec2ForCTC.from_pretrained(model_name).to(device)
        self.processor = Wav2Vec2Processor.from_pretrained(model_name)
        self.vocab_dict = self.processor.tokenizer.get_vocab()
        self.sorted_dict = {
            k.lower(): v for k, v in sorted(self.vocab_dict.items(),
                                            key=lambda item: item[1])
        }
        self.device = device
        self.lm = lm
        if self.lm:
            self.lm_decoder = build_ctcdecoder(
                list(self.sorted_dict.keys()),
                self.lm
            )

    def batch_predict(self, batch):
        features = self.processor(batch["speech"],
                                  sampling_rate=batch["sampling_rate"][0],
                                  padding=True,
                                  return_tensors="pt")
        input_values = features.input_values.to(self.device)
        attention_mask = features.attention_mask.to(self.device)
        with torch.no_grad():
            logits = self.model(input_values, attention_mask=attention_mask).logits
        if self.lm:
            logits = logits.cpu().numpy()
            batch["predicted"] = []
            for sample_logits in logits:
                batch["predicted"].append(self.lm_decoder.decode(sample_logits))
        else:
            pred_ids = torch.argmax(logits, dim=-1)
            batch["predicted"] = self.processor.batch_decode(pred_ids)
        return batch
```
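When no language model is given, `batch_predict` falls back to greedy CTC decoding: take the argmax token per frame, collapse consecutive repeats, and drop the blank token (this is what `processor.batch_decode` does internally on the argmax ids). A pure-Python sketch of that collapse step, with a made-up frame sequence and blank id:

```python
def ctc_greedy_collapse(frame_ids, blank_id=0):
    """Collapse repeated frame predictions, then remove CTC blanks."""
    out = []
    prev = None
    for t in frame_ids:
        if t != prev:          # collapse consecutive repeats
            if t != blank_id:  # drop the blank token
                out.append(t)
        prev = t
    return out

# Frames ␣ ␣ a a ␣ a c c (␣ = blank) decode to the tokens a a c.
print(ctc_greedy_collapse([0, 0, 1, 1, 0, 1, 2, 2]))  # → [1, 1, 2]
```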

### Download datasets

```python
%%capture
!gdown --id 1HFECzIizf-bmkQRLiQD0QVqcGtOG5upI
!mkdir bp_dataset
!unzip bp_dataset -d bp_dataset/
```

### Tests

```python
stt = STT(MODEL_NAME)
```

#### CETUC

```python
ds = load_data('cetuc_dataset')
result = ds.map(stt.batch_predict, batched=True, batch_size=8)
wer, mer, wil = calc_metrics(result["sentence"], result["predicted"])
print("CETUC WER:", wer)
```

    CETUC WER: 0.13846663354859937

#### Common Voice

```python
ds = load_data('commonvoice_dataset')
result = ds.map(stt.batch_predict, batched=True, batch_size=8)
wer, mer, wil = calc_metrics(result["sentence"], result["predicted"])
print("CV WER:", wer)
```

    CV WER: 0.36960721735520236

#### LaPS

```python
ds = load_data('lapsbm_dataset')
result = ds.map(stt.batch_predict, batched=True, batch_size=8)
wer, mer, wil = calc_metrics(result["sentence"], result["predicted"])
print("Laps WER:", wer)
```

    Laps WER: 0.16941287878787875

#### MLS

```python
ds = load_data('mls_dataset')
result = ds.map(stt.batch_predict, batched=True, batch_size=8)
wer, mer, wil = calc_metrics(result["sentence"], result["predicted"])
print("MLS WER:", wer)
```

    MLS WER: 0.16586103382107384

#### SID

```python
ds = load_data('sid_dataset')
result = ds.map(stt.batch_predict, batched=True, batch_size=8)
wer, mer, wil = calc_metrics(result["sentence"], result["predicted"])
print("Sid WER:", wer)
```

    Sid WER: 0.7943364822145216

#### TEDx

```python
ds = load_data('tedx_dataset')
result = ds.map(stt.batch_predict, batched=True, batch_size=8)
wer, mer, wil = calc_metrics(result["sentence"], result["predicted"])
print("TEDx WER:", wer)
```

    TEDx WER: 0.22221476803982182

#### VoxForge

```python
ds = load_data('voxforge_dataset')
result = ds.map(stt.batch_predict, batched=True, batch_size=8)
wer, mer, wil = calc_metrics(result["sentence"], result["predicted"])
print("VoxForge WER:", wer)
```

    VoxForge WER: 0.39486066017315996

### Tests with LM

```python
# !find -type f -name "*.wav" -delete
!rm -rf ~/.cache
!gdown --id 1GJIKseP5ZkTbllQVgOL98R4yYAcIySFP  # 4-gram LM trained with Wikipedia
stt = STT(MODEL_NAME, lm='pt-BR-wiki.word.4-gram.arpa')
# !gdown --id 1dLFldy7eguPtyJj5OAlI4Emnx0BpFywg  # 4-gram LM trained with bp
# stt = STT(MODEL_NAME, lm='pt-BR.word.4-gram.arpa')
```
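With an LM attached, `pyctcdecode` combines the model's per-frame CTC scores with 4-gram KenLM scores during beam search, so hypotheses that read like fluent Portuguese are preferred. As intuition only (this is not pyctcdecode's implementation), here is a toy add-one-smoothed bigram model scoring a fluent word order above a shuffled one; the corpus and sentences are made up:

```python
import math
from collections import Counter

# Tiny made-up corpus; a real setup uses a 4-gram ARPA model over Wikipedia.
corpus = "o gato subiu no telhado . o cachorro dormiu no sofá .".split()
bigrams = Counter(zip(corpus, corpus[1:]))
unigrams = Counter(corpus)

def bigram_logprob(sentence):
    # Add-one-smoothed bigram log-probability of the word sequence.
    words = sentence.split()
    return sum(math.log((bigrams[(a, b)] + 1) / (unigrams[a] + len(unigrams)))
               for a, b in zip(words, words[1:]))

# The fluent hypothesis scores higher than the shuffled one.
print(bigram_logprob("o gato subiu no telhado") >
      bigram_logprob("gato o telhado subiu no"))  # → True
```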

#### CETUC

```python
ds = load_data('cetuc_dataset')
result = ds.map(stt.batch_predict, batched=True, batch_size=8)
wer, mer, wil = calc_metrics(result["sentence"], result["predicted"])
print("CETUC WER:", wer)
```

    CETUC WER: 0.12338749517028079

#### Common Voice

```python
ds = load_data('commonvoice_dataset')
result = ds.map(stt.batch_predict, batched=True, batch_size=8)
wer, mer, wil = calc_metrics(result["sentence"], result["predicted"])
print("CV WER:", wer)
```

    CV WER: 0.4146185693398481

#### LaPS

```python
ds = load_data('lapsbm_dataset')
result = ds.map(stt.batch_predict, batched=True, batch_size=8)
wer, mer, wil = calc_metrics(result["sentence"], result["predicted"])
print("Laps WER:", wer)
```

    Laps WER: 0.17142676767676762

#### MLS

```python
ds = load_data('mls_dataset')
result = ds.map(stt.batch_predict, batched=True, batch_size=8)
wer, mer, wil = calc_metrics(result["sentence"], result["predicted"])
print("MLS WER:", wer)
```

    MLS WER: 0.15212081808962674

#### SID

```python
ds = load_data('sid_dataset')
result = ds.map(stt.batch_predict, batched=True, batch_size=8)
wer, mer, wil = calc_metrics(result["sentence"], result["predicted"])
print("Sid WER:", wer)
```

    Sid WER: 0.982518441309493

#### TEDx

```python
ds = load_data('tedx_dataset')
result = ds.map(stt.batch_predict, batched=True, batch_size=8)
wer, mer, wil = calc_metrics(result["sentence"], result["predicted"])
print("TEDx WER:", wer)
```

    TEDx WER: 0.21567860841157235

#### VoxForge

```python
ds = load_data('voxforge_dataset')
result = ds.map(stt.batch_predict, batched=True, batch_size=8)
wer, mer, wil = calc_metrics(result["sentence"], result["predicted"])
print("VoxForge WER:", wer)
```

    VoxForge WER: 0.3952218614718614