---
language: pt
datasets:
- common_voice
- mls
- cetuc
- lapsbm
- voxforge
- tedx
- sid
metrics:
- wer
tags:
- audio
- speech
- wav2vec2
- pt
- portuguese-speech-corpus
- automatic-speech-recognition
- PyTorch
license: apache-2.0
---

# sid10-xlsr: Wav2vec 2.0 with Sidney Dataset

This is a demonstration of a Wav2vec 2.0 model fine-tuned for Brazilian Portuguese using the [Sidney (SID)](https://igormq.github.io/datasets/) dataset.

In this notebook, the model is evaluated against other available Brazilian Portuguese datasets.

| Dataset                        | Train | Valid |  Test |
|--------------------------------|------:|------:|------:|
| CETUC                          |    -- |    -- |  5.4h |
| Common Voice                   |    -- |    -- |  9.5h |
| LaPS BM                        |    -- |    -- |  0.1h |
| MLS                            |    -- |    -- |  3.7h |
| Multilingual TEDx (Portuguese) |    -- |    -- |  1.8h |
| SID                            |  7.2h |    -- |  1.0h |
| VoxForge                       |    -- |    -- |  0.1h |
| Total                          |  7.2h |    -- | 21.6h |

#### Summary (WER)

| | CETUC | CV | LaPS | MLS | SID | TEDx | VF | AVG |
|----------------------|-------|-------|-------|-------|-------|-------|-------|-------|
| sid\_10 (demonstration below) | 0.186 | 0.327 | 0.207 | 0.505 | 0.124 | 0.835 | 0.472 | 0.379 |
| sid\_10 + 4-gram (demonstration below) | 0.096 | 0.223 | 0.115 | 0.432 | 0.101 | 0.791 | 0.348 | 0.301 |

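For quick standalone inference, the model can also be loaded directly with `transformers`, outside the benchmark code below. This is a minimal sketch; `audio.wav` is a placeholder path and the audio is assumed to be 16 kHz mono:

```python
import torch
import torchaudio
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

processor = Wav2Vec2Processor.from_pretrained("lgris/sid10-xlsr")
model = Wav2Vec2ForCTC.from_pretrained("lgris/sid10-xlsr")

speech, sr = torchaudio.load("audio.wav")  # placeholder file name
inputs = processor(speech.squeeze(0).numpy(), sampling_rate=16_000, return_tensors="pt")
with torch.no_grad():
    logits = model(inputs.input_values).logits
pred_ids = torch.argmax(logits, dim=-1)
print(processor.batch_decode(pred_ids)[0])  # greedy CTC transcription
```
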
## Demonstration

```python
MODEL_NAME = "lgris/sid10-xlsr"
```

### Imports and dependencies

```python
%%capture
!pip install torch==1.8.2+cu111 torchvision==0.9.2+cu111 torchaudio==0.8.2 -f https://download.pytorch.org/whl/lts/1.8/torch_lts.html
!pip install datasets
!pip install jiwer
!pip install transformers
!pip install soundfile
!pip install pyctcdecode
!pip install https://github.com/kpu/kenlm/archive/master.zip
```

```python
import jiwer
import torchaudio
from datasets import load_dataset
from transformers import (
    Wav2Vec2ForCTC,
    Wav2Vec2Processor,
)
from pyctcdecode import build_ctcdecoder
import torch
import re
```

### Helpers

```python
chars_to_ignore_regex = r'[,?.!;:"]'

def map_to_array(batch):
    # load the utterance and keep the raw 16 kHz waveform
    speech, _ = torchaudio.load(batch["path"])
    batch["speech"] = speech.squeeze(0).numpy()
    batch["sampling_rate"] = 16_000
    # normalize the reference transcript: strip punctuation, lowercase
    batch["sentence"] = re.sub(chars_to_ignore_regex, '', batch["sentence"]).lower().replace("’", "'")
    batch["target"] = batch["sentence"]
    return batch
```

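For reference, this is the effect of the normalization in `map_to_array` on a transcript (the sentence is an illustrative example, not taken from the datasets):

```python
s = 'Olá, mundo! Isso é um teste; "certo"?'
print(re.sub(chars_to_ignore_regex, '', s).lower().replace("’", "'"))
# -> olá mundo isso é um teste certo
```
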
```python
def calc_metrics(truths, hypos):
    wers = []
    mers = []
    wils = []
    for t, h in zip(truths, hypos):
        try:
            wers.append(jiwer.wer(t, h))
            mers.append(jiwer.mer(t, h))
            wils.append(jiwer.wil(t, h))
        except ValueError:  # skip empty reference strings
            pass
    wer = sum(wers) / len(wers)
    mer = sum(mers) / len(mers)
    wil = sum(wils) / len(wils)
    return wer, mer, wil
```

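A quick sanity check of the helper on toy strings; `jiwer` counts one substituted word out of three here:

```python
wer, mer, wil = calc_metrics(["o gato subiu"], ["o rato subiu"])
print(round(wer, 3))  # 0.333
```
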
```python
def load_data(dataset):
    # each dataset directory is expected to hold a test.csv
    # with at least "path" (WAV location) and "sentence" (transcript) columns
    data_files = {'test': f'{dataset}/test.csv'}
    dataset = load_dataset('csv', data_files=data_files)["test"]
    return dataset.map(map_to_array)
```

### Model

```python
class STT:

    def __init__(self,
                 model_name,
                 device='cuda' if torch.cuda.is_available() else 'cpu',
                 lm=None):
        self.model_name = model_name
        self.model = Wav2Vec2ForCTC.from_pretrained(model_name).to(device)
        self.processor = Wav2Vec2Processor.from_pretrained(model_name)
        self.vocab_dict = self.processor.tokenizer.get_vocab()
        # vocabulary sorted by token id and lowercased, as pyctcdecode expects
        self.sorted_dict = {
            k.lower(): v for k, v in sorted(self.vocab_dict.items(),
                                            key=lambda item: item[1])
        }
        self.device = device
        self.lm = lm
        if self.lm:
            self.lm_decoder = build_ctcdecoder(
                list(self.sorted_dict.keys()),
                self.lm
            )

    def batch_predict(self, batch):
        features = self.processor(batch["speech"],
                                  sampling_rate=batch["sampling_rate"][0],
                                  padding=True,
                                  return_tensors="pt")
        input_values = features.input_values.to(self.device)
        attention_mask = features.attention_mask.to(self.device)
        with torch.no_grad():
            logits = self.model(input_values, attention_mask=attention_mask).logits
        if self.lm:
            # beam-search decoding with the n-gram language model
            logits = logits.cpu().numpy()
            batch["predicted"] = []
            for sample_logits in logits:
                batch["predicted"].append(self.lm_decoder.decode(sample_logits))
        else:
            # greedy (argmax) CTC decoding
            pred_ids = torch.argmax(logits, dim=-1)
            batch["predicted"] = self.processor.batch_decode(pred_ids)
        return batch
```

### Download datasets

```python
%%capture
!gdown --id 1HFECzIizf-bmkQRLiQD0QVqcGtOG5upI
!mkdir bp_dataset
!unzip bp_dataset -d bp_dataset/
```

### Tests

```python
stt = STT(MODEL_NAME)
```

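Before looping over the benchmark sets, a single-utterance sanity check can be run through the same class. This is a sketch; `example.wav` is a placeholder, and the audio is resampled to the 16 kHz rate the model expects:

```python
speech, sr = torchaudio.load("example.wav")  # placeholder file name
if sr != 16_000:
    speech = torchaudio.transforms.Resample(sr, 16_000)(speech)
batch = {"speech": [speech.squeeze(0).numpy()], "sampling_rate": [16_000]}
print(stt.batch_predict(batch)["predicted"][0])
```
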
#### CETUC

```python
ds = load_data('cetuc_dataset')
result = ds.map(stt.batch_predict, batched=True, batch_size=8)
wer, mer, wil = calc_metrics(result["sentence"], result["predicted"])
print("CETUC WER:", wer)
```

    CETUC WER: 0.18623689076557778

#### Common Voice

```python
ds = load_data('commonvoice_dataset')
result = ds.map(stt.batch_predict, batched=True, batch_size=8)
wer, mer, wil = calc_metrics(result["sentence"], result["predicted"])
print("CV WER:", wer)
```

    CV WER: 0.3279775395502392

#### LaPS

```python
ds = load_data('lapsbm_dataset')
result = ds.map(stt.batch_predict, batched=True, batch_size=8)
wer, mer, wil = calc_metrics(result["sentence"], result["predicted"])
print("Laps WER:", wer)
```

    Laps WER: 0.20780303030303032

#### MLS

```python
ds = load_data('mls_dataset')
result = ds.map(stt.batch_predict, batched=True, batch_size=8)
wer, mer, wil = calc_metrics(result["sentence"], result["predicted"])
print("MLS WER:", wer)
```

    MLS WER: 0.5056711598536057

#### SID

```python
ds = load_data('sid_dataset')
result = ds.map(stt.batch_predict, batched=True, batch_size=8)
wer, mer, wil = calc_metrics(result["sentence"], result["predicted"])
print("Sid WER:", wer)
```

    Sid WER: 0.1247776617710105

#### TEDx

```python
ds = load_data('tedx_dataset')
result = ds.map(stt.batch_predict, batched=True, batch_size=8)
wer, mer, wil = calc_metrics(result["sentence"], result["predicted"])
print("TEDx WER:", wer)
```

    TEDx WER: 0.8350609256842175

#### VoxForge

```python
ds = load_data('voxforge_dataset')
result = ds.map(stt.batch_predict, batched=True, batch_size=8)
wer, mer, wil = calc_metrics(result["sentence"], result["predicted"])
print("VoxForge WER:", wer)
```

    VoxForge WER: 0.47242153679653687

### Tests with LM

```python
# !find -type f -name "*.wav" -delete
!rm -rf ~/.cache
!gdown --id 1GJIKseP5ZkTbllQVgOL98R4yYAcIySFP  # trained with wikipedia
stt = STT(MODEL_NAME, lm='pt-BR-wiki.word.4-gram.arpa')
# !gdown --id 1dLFldy7eguPtyJj5OAlI4Emnx0BpFywg  # trained with bp
# stt = STT(MODEL_NAME, lm='pt-BR.word.4-gram.arpa')
```

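`pyctcdecode` also exposes the language-model weight and word-insertion bonus when building the decoder. The values below are the library defaults, shown for illustration rather than tuned for this model:

```python
decoder = build_ctcdecoder(
    list(stt.sorted_dict.keys()),
    'pt-BR-wiki.word.4-gram.arpa',
    alpha=0.5,  # language model weight
    beta=1.5,   # word insertion bonus
)
```
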
#### CETUC

```python
ds = load_data('cetuc_dataset')
result = ds.map(stt.batch_predict, batched=True, batch_size=8)
wer, mer, wil = calc_metrics(result["sentence"], result["predicted"])
print("CETUC WER:", wer)
```

    CETUC WER: 0.09677271347353278

#### Common Voice

```python
ds = load_data('commonvoice_dataset')
result = ds.map(stt.batch_predict, batched=True, batch_size=8)
wer, mer, wil = calc_metrics(result["sentence"], result["predicted"])
print("CV WER:", wer)
```

    CV WER: 0.22363215674470321

#### LaPS

```python
ds = load_data('lapsbm_dataset')
result = ds.map(stt.batch_predict, batched=True, batch_size=8)
wer, mer, wil = calc_metrics(result["sentence"], result["predicted"])
print("Laps WER:", wer)
```

    Laps WER: 0.1154924242424242

#### MLS

```python
ds = load_data('mls_dataset')
result = ds.map(stt.batch_predict, batched=True, batch_size=8)
wer, mer, wil = calc_metrics(result["sentence"], result["predicted"])
print("MLS WER:", wer)
```

    MLS WER: 0.4322369152606427

#### SID

```python
ds = load_data('sid_dataset')
result = ds.map(stt.batch_predict, batched=True, batch_size=8)
wer, mer, wil = calc_metrics(result["sentence"], result["predicted"])
print("Sid WER:", wer)
```

    Sid WER: 0.10080313085145765

#### TEDx

```python
ds = load_data('tedx_dataset')
result = ds.map(stt.batch_predict, batched=True, batch_size=8)
wer, mer, wil = calc_metrics(result["sentence"], result["predicted"])
print("TEDx WER:", wer)
```

    TEDx WER: 0.7911789829264236

#### VoxForge

```python
ds = load_data('voxforge_dataset')
result = ds.map(stt.batch_predict, batched=True, batch_size=8)
wer, mer, wil = calc_metrics(result["sentence"], result["predicted"])
print("VoxForge WER:", wer)
```

    VoxForge WER: 0.34786255411255407