lighteternal committed on
Commit
8aa27c7
1 Parent(s): c9229a1

Added inference script

Browse files
.ipynb_checkpoints/ASR_Inference-checkpoint.ipynb ADDED
@@ -0,0 +1,550 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {
7
+ "ExecuteTime": {
8
+ "end_time": "2021-03-14T09:33:41.892030Z",
9
+ "start_time": "2021-03-14T09:33:40.729163Z"
10
+ }
11
+ },
12
+ "outputs": [
13
+ {
14
+ "name": "stderr",
15
+ "output_type": "stream",
16
+ "text": [
17
+ "/home/earendil/anaconda3/envs/cuda110/lib/python3.8/site-packages/torchaudio/backend/utils.py:53: UserWarning: \"sox\" backend is being deprecated. The default backend will be changed to \"sox_io\" backend in 0.8.0 and \"sox\" backend will be removed in 0.9.0. Please migrate to \"sox_io\" backend. Please refer to https://github.com/pytorch/audio/issues/903 for the detail.\n",
18
+ " warnings.warn(\n"
19
+ ]
20
+ }
21
+ ],
22
+ "source": [
23
+ "from transformers import Wav2Vec2ForCTC\n",
24
+ "from transformers import Wav2Vec2Processor\n",
25
+ "from datasets import load_dataset, load_metric\n",
26
+ "import re\n",
27
+ "import torchaudio\n",
28
+ "import librosa\n",
29
+ "import numpy as np\n",
30
+ "from datasets import load_dataset, load_metric\n",
31
+ "import torch"
32
+ ]
33
+ },
34
+ {
35
+ "cell_type": "code",
36
+ "execution_count": 2,
37
+ "metadata": {
38
+ "ExecuteTime": {
39
+ "end_time": "2021-03-14T09:33:41.909851Z",
40
+ "start_time": "2021-03-14T09:33:41.906327Z"
41
+ }
42
+ },
43
+ "outputs": [],
44
+ "source": [
45
+ "chars_to_ignore_regex = '[\\,\\?\\.\\!\\-\\;\\:\\\"\\“\\%\\‘\\”\\�]'\n",
46
+ "\n",
47
+ "def remove_special_characters(batch):\n",
48
+ " batch[\"text\"] = re.sub(chars_to_ignore_regex, '', batch[\"sentence\"]).lower() + \" \"\n",
49
+ " return batch\n",
50
+ "\n",
51
+ "def speech_file_to_array_fn(batch):\n",
52
+ " speech_array, sampling_rate = torchaudio.load(batch[\"path\"])\n",
53
+ " batch[\"speech\"] = speech_array[0].numpy()\n",
54
+ " batch[\"sampling_rate\"] = sampling_rate\n",
55
+ " batch[\"target_text\"] = batch[\"text\"]\n",
56
+ " return batch\n",
57
+ "\n",
58
+ "def resample(batch):\n",
59
+ " batch[\"speech\"] = librosa.resample(np.asarray(batch[\"speech\"]), 48_000, 16_000)\n",
60
+ " batch[\"sampling_rate\"] = 16_000\n",
61
+ " return batch\n",
62
+ "\n",
63
+ "def prepare_dataset(batch):\n",
64
+ " # check that all files have the correct sampling rate\n",
65
+ " assert (\n",
66
+ " len(set(batch[\"sampling_rate\"])) == 1\n",
67
+ " ), f\"Make sure all inputs have the same sampling rate of {processor.feature_extractor.sampling_rate}.\"\n",
68
+ "\n",
69
+ " batch[\"input_values\"] = processor(batch[\"speech\"], sampling_rate=batch[\"sampling_rate\"][0]).input_values\n",
70
+ " \n",
71
+ " with processor.as_target_processor():\n",
72
+ " batch[\"labels\"] = processor(batch[\"target_text\"]).input_ids\n",
73
+ " return batch"
74
+ ]
75
+ },
76
+ {
77
+ "cell_type": "code",
78
+ "execution_count": 3,
79
+ "metadata": {
80
+ "ExecuteTime": {
81
+ "end_time": "2021-03-14T09:33:49.053762Z",
82
+ "start_time": "2021-03-14T09:33:41.922683Z"
83
+ }
84
+ },
85
+ "outputs": [
86
+ {
87
+ "name": "stderr",
88
+ "output_type": "stream",
89
+ "text": [
90
+ "Special tokens have been added in the vocabulary, make sure the associated word embedding are fine-tuned or trained.\n"
91
+ ]
92
+ }
93
+ ],
94
+ "source": [
95
+ "model = Wav2Vec2ForCTC.from_pretrained(\"wav2vec2-large-xlsr-greek/checkpoint-9200/\").to(\"cuda\")\n",
96
+ "processor = Wav2Vec2Processor.from_pretrained(\"wav2vec2-large-xlsr-greek/\")"
97
+ ]
98
+ },
99
+ {
100
+ "cell_type": "code",
101
+ "execution_count": 4,
102
+ "metadata": {
103
+ "ExecuteTime": {
104
+ "end_time": "2021-03-14T09:33:52.413558Z",
105
+ "start_time": "2021-03-14T09:33:49.078466Z"
106
+ }
107
+ },
108
+ "outputs": [
109
+ {
110
+ "name": "stderr",
111
+ "output_type": "stream",
112
+ "text": [
113
+ "Using custom data configuration el-afd0a157f05ee080\n",
114
+ "Reusing dataset common_voice (/home/earendil/.cache/huggingface/datasets/common_voice/el-afd0a157f05ee080/6.1.0/32954a9015faa0d840f6c6894938545c5d12bc5d8936a80079af74bf50d71564)\n"
115
+ ]
116
+ }
117
+ ],
118
+ "source": [
119
+ "common_voice_test = load_dataset(\"common_voice\", \"el\", data_dir=\"cv-corpus-6.1-2020-12-11\", split=\"test\")"
120
+ ]
121
+ },
122
+ {
123
+ "cell_type": "code",
124
+ "execution_count": 5,
125
+ "metadata": {
126
+ "ExecuteTime": {
127
+ "end_time": "2021-03-14T09:33:52.444418Z",
128
+ "start_time": "2021-03-14T09:33:52.441338Z"
129
+ }
130
+ },
131
+ "outputs": [],
132
+ "source": [
133
+ "common_voice_test = common_voice_test.remove_columns([\"accent\", \"age\", \"client_id\", \"down_votes\", \"gender\", \"locale\", \"segment\", \"up_votes\"])"
134
+ ]
135
+ },
136
+ {
137
+ "cell_type": "code",
138
+ "execution_count": 6,
139
+ "metadata": {
140
+ "ExecuteTime": {
141
+ "end_time": "2021-03-14T09:33:52.473087Z",
142
+ "start_time": "2021-03-14T09:33:52.468014Z"
143
+ }
144
+ },
145
+ "outputs": [
146
+ {
147
+ "name": "stderr",
148
+ "output_type": "stream",
149
+ "text": [
150
+ "Loading cached processed dataset at /home/earendil/.cache/huggingface/datasets/common_voice/el-afd0a157f05ee080/6.1.0/32954a9015faa0d840f6c6894938545c5d12bc5d8936a80079af74bf50d71564/cache-0ce2ebca66096fff.arrow\n"
151
+ ]
152
+ }
153
+ ],
154
+ "source": [
155
+ "common_voice_test = common_voice_test.map(remove_special_characters, remove_columns=[\"sentence\"])"
156
+ ]
157
+ },
158
+ {
159
+ "cell_type": "code",
160
+ "execution_count": 7,
161
+ "metadata": {
162
+ "ExecuteTime": {
163
+ "end_time": "2021-03-14T09:33:52.510377Z",
164
+ "start_time": "2021-03-14T09:33:52.501677Z"
165
+ }
166
+ },
167
+ "outputs": [
168
+ {
169
+ "name": "stderr",
170
+ "output_type": "stream",
171
+ "text": [
172
+ "Loading cached processed dataset at /home/earendil/.cache/huggingface/datasets/common_voice/el-afd0a157f05ee080/6.1.0/32954a9015faa0d840f6c6894938545c5d12bc5d8936a80079af74bf50d71564/cache-38a09981767eff59.arrow\n"
173
+ ]
174
+ }
175
+ ],
176
+ "source": [
177
+ "common_voice_test = common_voice_test.map(speech_file_to_array_fn, remove_columns=common_voice_test.column_names)"
178
+ ]
179
+ },
180
+ {
181
+ "cell_type": "code",
182
+ "execution_count": 8,
183
+ "metadata": {
184
+ "ExecuteTime": {
185
+ "end_time": "2021-03-14T09:33:53.321810Z",
186
+ "start_time": "2021-03-14T09:33:52.533233Z"
187
+ }
188
+ },
189
+ "outputs": [
190
+ {
191
+ "name": "stdout",
192
+ "output_type": "stream",
193
+ "text": [
194
+ " "
195
+ ]
196
+ },
197
+ {
198
+ "name": "stderr",
199
+ "output_type": "stream",
200
+ "text": [
201
+ "Loading cached processed dataset at /home/earendil/.cache/huggingface/datasets/common_voice/el-afd0a157f05ee080/6.1.0/32954a9015faa0d840f6c6894938545c5d12bc5d8936a80079af74bf50d71564/cache-ba8c6dd59eb8ccf2.arrow\n",
202
+ "Loading cached processed dataset at /home/earendil/.cache/huggingface/datasets/common_voice/el-afd0a157f05ee080/6.1.0/32954a9015faa0d840f6c6894938545c5d12bc5d8936a80079af74bf50d71564/cache-2e240883a5f827fd.arrow\n"
203
+ ]
204
+ },
205
+ {
206
+ "name": "stdout",
207
+ "output_type": "stream",
208
+ "text": [
209
+ " "
210
+ ]
211
+ },
212
+ {
213
+ "name": "stderr",
214
+ "output_type": "stream",
215
+ "text": [
216
+ "Loading cached processed dataset at /home/earendil/.cache/huggingface/datasets/common_voice/el-afd0a157f05ee080/6.1.0/32954a9015faa0d840f6c6894938545c5d12bc5d8936a80079af74bf50d71564/cache-485c00dc9048ed50.arrow\n",
217
+ "Loading cached processed dataset at /home/earendil/.cache/huggingface/datasets/common_voice/el-afd0a157f05ee080/6.1.0/32954a9015faa0d840f6c6894938545c5d12bc5d8936a80079af74bf50d71564/cache-44bf1791baae8e2e.arrow\n"
218
+ ]
219
+ },
220
+ {
221
+ "name": "stdout",
222
+ "output_type": "stream",
223
+ "text": [
224
+ " "
225
+ ]
226
+ },
227
+ {
228
+ "name": "stderr",
229
+ "output_type": "stream",
230
+ "text": [
231
+ "Loading cached processed dataset at /home/earendil/.cache/huggingface/datasets/common_voice/el-afd0a157f05ee080/6.1.0/32954a9015faa0d840f6c6894938545c5d12bc5d8936a80079af74bf50d71564/cache-ecc0dfac5615a58e.arrow\n"
232
+ ]
233
+ },
234
+ {
235
+ "name": "stdout",
236
+ "output_type": "stream",
237
+ "text": [
238
+ " "
239
+ ]
240
+ },
241
+ {
242
+ "name": "stderr",
243
+ "output_type": "stream",
244
+ "text": [
245
+ "Loading cached processed dataset at /home/earendil/.cache/huggingface/datasets/common_voice/el-afd0a157f05ee080/6.1.0/32954a9015faa0d840f6c6894938545c5d12bc5d8936a80079af74bf50d71564/cache-923d905502a8661d.arrow\n"
246
+ ]
247
+ },
248
+ {
249
+ "name": "stdout",
250
+ "output_type": "stream",
251
+ "text": [
252
+ " "
253
+ ]
254
+ },
255
+ {
256
+ "name": "stderr",
257
+ "output_type": "stream",
258
+ "text": [
259
+ "Loading cached processed dataset at /home/earendil/.cache/huggingface/datasets/common_voice/el-afd0a157f05ee080/6.1.0/32954a9015faa0d840f6c6894938545c5d12bc5d8936a80079af74bf50d71564/cache-062aeafc3b8816c1.arrow\n"
260
+ ]
261
+ },
262
+ {
263
+ "name": "stdout",
264
+ "output_type": "stream",
265
+ "text": [
266
+ " "
267
+ ]
268
+ },
269
+ {
270
+ "name": "stderr",
271
+ "output_type": "stream",
272
+ "text": [
273
+ "Loading cached processed dataset at /home/earendil/.cache/huggingface/datasets/common_voice/el-afd0a157f05ee080/6.1.0/32954a9015faa0d840f6c6894938545c5d12bc5d8936a80079af74bf50d71564/cache-bb54bb00dae79669.arrow\n"
274
+ ]
275
+ }
276
+ ],
277
+ "source": [
278
+ "common_voice_test = common_voice_test.map(resample, num_proc=8)"
279
+ ]
280
+ },
281
+ {
282
+ "cell_type": "code",
283
+ "execution_count": 9,
284
+ "metadata": {
285
+ "ExecuteTime": {
286
+ "end_time": "2021-03-14T09:33:53.611415Z",
287
+ "start_time": "2021-03-14T09:33:53.342487Z"
288
+ }
289
+ },
290
+ "outputs": [
291
+ {
292
+ "name": "stderr",
293
+ "output_type": "stream",
294
+ "text": [
295
+ "/home/earendil/anaconda3/envs/cuda110/lib/python3.8/site-packages/numpy/core/_asarray.py:83: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray\n",
296
+ " return array(a, dtype, copy=False, order=order)\n"
297
+ ]
298
+ },
299
+ {
300
+ "name": "stdout",
301
+ "output_type": "stream",
302
+ "text": [
303
+ " "
304
+ ]
305
+ },
306
+ {
307
+ "name": "stderr",
308
+ "output_type": "stream",
309
+ "text": [
310
+ "Loading cached processed dataset at /home/earendil/.cache/huggingface/datasets/common_voice/el-afd0a157f05ee080/6.1.0/32954a9015faa0d840f6c6894938545c5d12bc5d8936a80079af74bf50d71564/cache-6dfad29ca815f865.arrow\n"
311
+ ]
312
+ },
313
+ {
314
+ "name": "stdout",
315
+ "output_type": "stream",
316
+ "text": [
317
+ " "
318
+ ]
319
+ },
320
+ {
321
+ "name": "stderr",
322
+ "output_type": "stream",
323
+ "text": [
324
+ "Loading cached processed dataset at /home/earendil/.cache/huggingface/datasets/common_voice/el-afd0a157f05ee080/6.1.0/32954a9015faa0d840f6c6894938545c5d12bc5d8936a80079af74bf50d71564/cache-61e9ae0296df46f8.arrow\n"
325
+ ]
326
+ },
327
+ {
328
+ "name": "stdout",
329
+ "output_type": "stream",
330
+ "text": [
331
+ " "
332
+ ]
333
+ },
334
+ {
335
+ "name": "stderr",
336
+ "output_type": "stream",
337
+ "text": [
338
+ "Loading cached processed dataset at /home/earendil/.cache/huggingface/datasets/common_voice/el-afd0a157f05ee080/6.1.0/32954a9015faa0d840f6c6894938545c5d12bc5d8936a80079af74bf50d71564/cache-7f5aae16804e0788.arrow\n"
339
+ ]
340
+ },
341
+ {
342
+ "name": "stdout",
343
+ "output_type": "stream",
344
+ "text": [
345
+ " "
346
+ ]
347
+ },
348
+ {
349
+ "name": "stderr",
350
+ "output_type": "stream",
351
+ "text": [
352
+ "Loading cached processed dataset at /home/earendil/.cache/huggingface/datasets/common_voice/el-afd0a157f05ee080/6.1.0/32954a9015faa0d840f6c6894938545c5d12bc5d8936a80079af74bf50d71564/cache-b9636a5d30ffb973.arrow\n"
353
+ ]
354
+ },
355
+ {
356
+ "name": "stdout",
357
+ "output_type": "stream",
358
+ "text": [
359
+ " "
360
+ ]
361
+ },
362
+ {
363
+ "name": "stderr",
364
+ "output_type": "stream",
365
+ "text": [
366
+ "Loading cached processed dataset at /home/earendil/.cache/huggingface/datasets/common_voice/el-afd0a157f05ee080/6.1.0/32954a9015faa0d840f6c6894938545c5d12bc5d8936a80079af74bf50d71564/cache-7e60f2d73a65610a.arrow\n"
367
+ ]
368
+ },
369
+ {
370
+ "name": "stdout",
371
+ "output_type": "stream",
372
+ "text": [
373
+ " "
374
+ ]
375
+ },
376
+ {
377
+ "name": "stderr",
378
+ "output_type": "stream",
379
+ "text": [
380
+ "Loading cached processed dataset at /home/earendil/.cache/huggingface/datasets/common_voice/el-afd0a157f05ee080/6.1.0/32954a9015faa0d840f6c6894938545c5d12bc5d8936a80079af74bf50d71564/cache-3c99781789816a60.arrow\n"
381
+ ]
382
+ },
383
+ {
384
+ "name": "stdout",
385
+ "output_type": "stream",
386
+ "text": [
387
+ " "
388
+ ]
389
+ },
390
+ {
391
+ "name": "stderr",
392
+ "output_type": "stream",
393
+ "text": [
394
+ "Loading cached processed dataset at /home/earendil/.cache/huggingface/datasets/common_voice/el-afd0a157f05ee080/6.1.0/32954a9015faa0d840f6c6894938545c5d12bc5d8936a80079af74bf50d71564/cache-bae077f32f9eb290.arrow\n"
395
+ ]
396
+ },
397
+ {
398
+ "name": "stdout",
399
+ "output_type": "stream",
400
+ "text": [
401
+ " "
402
+ ]
403
+ },
404
+ {
405
+ "name": "stderr",
406
+ "output_type": "stream",
407
+ "text": [
408
+ "Loading cached processed dataset at /home/earendil/.cache/huggingface/datasets/common_voice/el-afd0a157f05ee080/6.1.0/32954a9015faa0d840f6c6894938545c5d12bc5d8936a80079af74bf50d71564/cache-4fb6951626f7548e.arrow\n"
409
+ ]
410
+ }
411
+ ],
412
+ "source": [
413
+ "common_voice_test = common_voice_test.map(prepare_dataset, remove_columns=common_voice_test.column_names, batch_size=8, num_proc=8, batched=True)"
414
+ ]
415
+ },
416
+ {
417
+ "cell_type": "code",
418
+ "execution_count": 10,
419
+ "metadata": {
420
+ "ExecuteTime": {
421
+ "end_time": "2021-03-14T09:33:56.243678Z",
422
+ "start_time": "2021-03-14T09:33:53.632436Z"
423
+ }
424
+ },
425
+ "outputs": [
426
+ {
427
+ "name": "stderr",
428
+ "output_type": "stream",
429
+ "text": [
430
+ "Using custom data configuration el-ac779bf2c9f7c09b\n",
431
+ "Reusing dataset common_voice (/home/earendil/.cache/huggingface/datasets/common_voice/el-ac779bf2c9f7c09b/6.1.0/32954a9015faa0d840f6c6894938545c5d12bc5d8936a80079af74bf50d71564)\n"
432
+ ]
433
+ }
434
+ ],
435
+ "source": [
436
+ "common_voice_test_transcription = load_dataset(\"common_voice\", \"el\", data_dir=\"./cv-corpus-6.1-2020-12-11\", split=\"test\")"
437
+ ]
438
+ },
439
+ {
440
+ "cell_type": "code",
441
+ "execution_count": 19,
442
+ "metadata": {
443
+ "ExecuteTime": {
444
+ "end_time": "2021-03-14T09:36:50.076837Z",
445
+ "start_time": "2021-03-14T09:36:24.943947Z"
446
+ }
447
+ },
448
+ "outputs": [],
449
+ "source": [
450
+ "# Change this value to try inference on different CommonVoice extracts\n",
451
+ "example = 123\n",
452
+ "\n",
453
+ "input_dict = processor(common_voice_test[\"input_values\"][example], return_tensors=\"pt\", sampling_rate=16_000, padding=True)\n",
454
+ "\n",
455
+ "logits = model(input_dict.input_values.to(\"cuda\")).logits\n",
456
+ "\n",
457
+ "pred_ids = torch.argmax(logits, dim=-1)"
458
+ ]
459
+ },
460
+ {
461
+ "cell_type": "code",
462
+ "execution_count": 20,
463
+ "metadata": {
464
+ "ExecuteTime": {
465
+ "end_time": "2021-03-14T09:36:50.137886Z",
466
+ "start_time": "2021-03-14T09:36:50.134218Z"
467
+ }
468
+ },
469
+ "outputs": [
470
+ {
471
+ "name": "stdout",
472
+ "output_type": "stream",
473
+ "text": [
474
+ "Prediction:\n",
475
+ "καμιά φορά τα έπαιρνε και έπαιζε όταν η δουλειά ήταν πιο χαλαρί\n",
476
+ "\n",
477
+ "Reference:\n",
478
+ "καμιά φορά τα έπαιρνε και έπαιζε όταν η δουλειά ήταν πιο χαλαρή\n"
479
+ ]
480
+ }
481
+ ],
482
+ "source": [
483
+ "print(\"Prediction:\")\n",
484
+ "print(processor.decode(pred_ids[0]))\n",
485
+ "# καμιά φορά τα έπαιρνε και έπαιζε όταν η δουλειά ήταν πιο χαλαρί\n",
486
+ "\n",
487
+ "print(\"\\nReference:\")\n",
488
+ "print(common_voice_test_transcription[\"sentence\"][example].lower())\n",
489
+ "# καμιά φορά τα έπαιρνε και έπαιζε όταν η δουλειά ήταν πιο χαλαρή"
490
+ ]
491
+ },
492
+ {
493
+ "cell_type": "code",
494
+ "execution_count": null,
495
+ "metadata": {},
496
+ "outputs": [],
497
+ "source": []
498
+ }
499
+ ],
500
+ "metadata": {
501
+ "kernelspec": {
502
+ "display_name": "cuda110",
503
+ "language": "python",
504
+ "name": "cuda110"
505
+ },
506
+ "language_info": {
507
+ "codemirror_mode": {
508
+ "name": "ipython",
509
+ "version": 3
510
+ },
511
+ "file_extension": ".py",
512
+ "mimetype": "text/x-python",
513
+ "name": "python",
514
+ "nbconvert_exporter": "python",
515
+ "pygments_lexer": "ipython3",
516
+ "version": "3.8.5"
517
+ },
518
+ "varInspector": {
519
+ "cols": {
520
+ "lenName": 16,
521
+ "lenType": 16,
522
+ "lenVar": 40
523
+ },
524
+ "kernels_config": {
525
+ "python": {
526
+ "delete_cmd_postfix": "",
527
+ "delete_cmd_prefix": "del ",
528
+ "library": "var_list.py",
529
+ "varRefreshCmd": "print(var_dic_list())"
530
+ },
531
+ "r": {
532
+ "delete_cmd_postfix": ") ",
533
+ "delete_cmd_prefix": "rm(",
534
+ "library": "var_list.r",
535
+ "varRefreshCmd": "cat(var_dic_list()) "
536
+ }
537
+ },
538
+ "types_to_exclude": [
539
+ "module",
540
+ "function",
541
+ "builtin_function_or_method",
542
+ "instance",
543
+ "_Feature"
544
+ ],
545
+ "window_display": false
546
+ }
547
+ },
548
+ "nbformat": 4,
549
+ "nbformat_minor": 4
550
+ }
ASR_Inference.ipynb ADDED
@@ -0,0 +1,279 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {
7
+ "ExecuteTime": {
8
+ "end_time": "2021-03-14T09:33:41.892030Z",
9
+ "start_time": "2021-03-14T09:33:40.729163Z"
10
+ }
11
+ },
12
+ "outputs": [],
13
+ "source": [
14
+ "from transformers import Wav2Vec2ForCTC\n",
15
+ "from transformers import Wav2Vec2Processor\n",
16
+ "from datasets import load_dataset, load_metric\n",
17
+ "import re\n",
18
+ "import torchaudio\n",
19
+ "import librosa\n",
20
+ "import numpy as np\n",
21
+ "from datasets import load_dataset, load_metric\n",
22
+ "import torch"
23
+ ]
24
+ },
25
+ {
26
+ "cell_type": "code",
27
+ "execution_count": null,
28
+ "metadata": {
29
+ "ExecuteTime": {
30
+ "end_time": "2021-03-14T09:33:41.909851Z",
31
+ "start_time": "2021-03-14T09:33:41.906327Z"
32
+ }
33
+ },
34
+ "outputs": [],
35
+ "source": [
36
+ "chars_to_ignore_regex = '[\\,\\?\\.\\!\\-\\;\\:\\\"\\“\\%\\‘\\”\\�]'\n",
37
+ "\n",
38
+ "def remove_special_characters(batch):\n",
39
+ " batch[\"text\"] = re.sub(chars_to_ignore_regex, '', batch[\"sentence\"]).lower() + \" \"\n",
40
+ " return batch\n",
41
+ "\n",
42
+ "def speech_file_to_array_fn(batch):\n",
43
+ " speech_array, sampling_rate = torchaudio.load(batch[\"path\"])\n",
44
+ " batch[\"speech\"] = speech_array[0].numpy()\n",
45
+ " batch[\"sampling_rate\"] = sampling_rate\n",
46
+ " batch[\"target_text\"] = batch[\"text\"]\n",
47
+ " return batch\n",
48
+ "\n",
49
+ "def resample(batch):\n",
50
+ " batch[\"speech\"] = librosa.resample(np.asarray(batch[\"speech\"]), 48_000, 16_000)\n",
51
+ " batch[\"sampling_rate\"] = 16_000\n",
52
+ " return batch\n",
53
+ "\n",
54
+ "def prepare_dataset(batch):\n",
55
+ " # check that all files have the correct sampling rate\n",
56
+ " assert (\n",
57
+ " len(set(batch[\"sampling_rate\"])) == 1\n",
58
+ " ), f\"Make sure all inputs have the same sampling rate of {processor.feature_extractor.sampling_rate}.\"\n",
59
+ "\n",
60
+ " batch[\"input_values\"] = processor(batch[\"speech\"], sampling_rate=batch[\"sampling_rate\"][0]).input_values\n",
61
+ " \n",
62
+ " with processor.as_target_processor():\n",
63
+ " batch[\"labels\"] = processor(batch[\"target_text\"]).input_ids\n",
64
+ " return batch"
65
+ ]
66
+ },
67
+ {
68
+ "cell_type": "code",
69
+ "execution_count": null,
70
+ "metadata": {
71
+ "ExecuteTime": {
72
+ "end_time": "2021-03-14T09:33:49.053762Z",
73
+ "start_time": "2021-03-14T09:33:41.922683Z"
74
+ }
75
+ },
76
+ "outputs": [],
77
+ "source": [
78
+ "model = Wav2Vec2ForCTC.from_pretrained(\"wav2vec2-large-xlsr-greek/checkpoint-9200/\").to(\"cuda\")\n",
79
+ "processor = Wav2Vec2Processor.from_pretrained(\"wav2vec2-large-xlsr-greek/\")"
80
+ ]
81
+ },
82
+ {
83
+ "cell_type": "code",
84
+ "execution_count": null,
85
+ "metadata": {
86
+ "ExecuteTime": {
87
+ "end_time": "2021-03-14T09:33:52.413558Z",
88
+ "start_time": "2021-03-14T09:33:49.078466Z"
89
+ }
90
+ },
91
+ "outputs": [],
92
+ "source": [
93
+ "common_voice_test = load_dataset(\"common_voice\", \"el\", data_dir=\"cv-corpus-6.1-2020-12-11\", split=\"test\")"
94
+ ]
95
+ },
96
+ {
97
+ "cell_type": "code",
98
+ "execution_count": null,
99
+ "metadata": {
100
+ "ExecuteTime": {
101
+ "end_time": "2021-03-14T09:33:52.444418Z",
102
+ "start_time": "2021-03-14T09:33:52.441338Z"
103
+ }
104
+ },
105
+ "outputs": [],
106
+ "source": [
107
+ "common_voice_test = common_voice_test.remove_columns([\"accent\", \"age\", \"client_id\", \"down_votes\", \"gender\", \"locale\", \"segment\", \"up_votes\"])"
108
+ ]
109
+ },
110
+ {
111
+ "cell_type": "code",
112
+ "execution_count": null,
113
+ "metadata": {
114
+ "ExecuteTime": {
115
+ "end_time": "2021-03-14T09:33:52.473087Z",
116
+ "start_time": "2021-03-14T09:33:52.468014Z"
117
+ }
118
+ },
119
+ "outputs": [],
120
+ "source": [
121
+ "common_voice_test = common_voice_test.map(remove_special_characters, remove_columns=[\"sentence\"])"
122
+ ]
123
+ },
124
+ {
125
+ "cell_type": "code",
126
+ "execution_count": null,
127
+ "metadata": {
128
+ "ExecuteTime": {
129
+ "end_time": "2021-03-14T09:33:52.510377Z",
130
+ "start_time": "2021-03-14T09:33:52.501677Z"
131
+ }
132
+ },
133
+ "outputs": [],
134
+ "source": [
135
+ "common_voice_test = common_voice_test.map(speech_file_to_array_fn, remove_columns=common_voice_test.column_names)"
136
+ ]
137
+ },
138
+ {
139
+ "cell_type": "code",
140
+ "execution_count": null,
141
+ "metadata": {
142
+ "ExecuteTime": {
143
+ "end_time": "2021-03-14T09:33:53.321810Z",
144
+ "start_time": "2021-03-14T09:33:52.533233Z"
145
+ }
146
+ },
147
+ "outputs": [],
148
+ "source": [
149
+ "common_voice_test = common_voice_test.map(resample, num_proc=8)"
150
+ ]
151
+ },
152
+ {
153
+ "cell_type": "code",
154
+ "execution_count": null,
155
+ "metadata": {
156
+ "ExecuteTime": {
157
+ "end_time": "2021-03-14T09:33:53.611415Z",
158
+ "start_time": "2021-03-14T09:33:53.342487Z"
159
+ }
160
+ },
161
+ "outputs": [],
162
+ "source": [
163
+ "common_voice_test = common_voice_test.map(prepare_dataset, remove_columns=common_voice_test.column_names, batch_size=8, num_proc=8, batched=True)"
164
+ ]
165
+ },
166
+ {
167
+ "cell_type": "code",
168
+ "execution_count": null,
169
+ "metadata": {
170
+ "ExecuteTime": {
171
+ "end_time": "2021-03-14T09:33:56.243678Z",
172
+ "start_time": "2021-03-14T09:33:53.632436Z"
173
+ }
174
+ },
175
+ "outputs": [],
176
+ "source": [
177
+ "common_voice_test_transcription = load_dataset(\"common_voice\", \"el\", data_dir=\"./cv-corpus-6.1-2020-12-11\", split=\"test\")"
178
+ ]
179
+ },
180
+ {
181
+ "cell_type": "code",
182
+ "execution_count": null,
183
+ "metadata": {
184
+ "ExecuteTime": {
185
+ "end_time": "2021-03-14T09:36:50.076837Z",
186
+ "start_time": "2021-03-14T09:36:24.943947Z"
187
+ }
188
+ },
189
+ "outputs": [],
190
+ "source": [
191
+ "# Change this value to try inference on different CommonVoice extracts\n",
192
+ "example = 123\n",
193
+ "\n",
194
+ "input_dict = processor(common_voice_test[\"input_values\"][example], return_tensors=\"pt\", sampling_rate=16_000, padding=True)\n",
195
+ "\n",
196
+ "logits = model(input_dict.input_values.to(\"cuda\")).logits\n",
197
+ "\n",
198
+ "pred_ids = torch.argmax(logits, dim=-1)"
199
+ ]
200
+ },
201
+ {
202
+ "cell_type": "code",
203
+ "execution_count": null,
204
+ "metadata": {
205
+ "ExecuteTime": {
206
+ "end_time": "2021-03-14T09:36:50.137886Z",
207
+ "start_time": "2021-03-14T09:36:50.134218Z"
208
+ }
209
+ },
210
+ "outputs": [],
211
+ "source": [
212
+ "print(\"Prediction:\")\n",
213
+ "print(processor.decode(pred_ids[0]))\n",
214
+ "# καμιά φορά τα έπαιρνε και έπαιζε όταν η δουλειά ήταν πιο χαλαρί\n",
215
+ "\n",
216
+ "print(\"\\nReference:\")\n",
217
+ "print(common_voice_test_transcription[\"sentence\"][example].lower())\n",
218
+ "# καμιά φορά τα έπαιρνε και έπαιζε όταν η δουλειά ήταν πιο χαλαρή"
219
+ ]
220
+ },
221
+ {
222
+ "cell_type": "code",
223
+ "execution_count": null,
224
+ "metadata": {},
225
+ "outputs": [],
226
+ "source": []
227
+ }
228
+ ],
229
+ "metadata": {
230
+ "kernelspec": {
231
+ "display_name": "cuda110",
232
+ "language": "python",
233
+ "name": "cuda110"
234
+ },
235
+ "language_info": {
236
+ "codemirror_mode": {
237
+ "name": "ipython",
238
+ "version": 3
239
+ },
240
+ "file_extension": ".py",
241
+ "mimetype": "text/x-python",
242
+ "name": "python",
243
+ "nbconvert_exporter": "python",
244
+ "pygments_lexer": "ipython3",
245
+ "version": "3.8.5"
246
+ },
247
+ "varInspector": {
248
+ "cols": {
249
+ "lenName": 16,
250
+ "lenType": 16,
251
+ "lenVar": 40
252
+ },
253
+ "kernels_config": {
254
+ "python": {
255
+ "delete_cmd_postfix": "",
256
+ "delete_cmd_prefix": "del ",
257
+ "library": "var_list.py",
258
+ "varRefreshCmd": "print(var_dic_list())"
259
+ },
260
+ "r": {
261
+ "delete_cmd_postfix": ") ",
262
+ "delete_cmd_prefix": "rm(",
263
+ "library": "var_list.r",
264
+ "varRefreshCmd": "cat(var_dic_list()) "
265
+ }
266
+ },
267
+ "types_to_exclude": [
268
+ "module",
269
+ "function",
270
+ "builtin_function_or_method",
271
+ "instance",
272
+ "_Feature"
273
+ ],
274
+ "window_display": false
275
+ }
276
+ },
277
+ "nbformat": 4,
278
+ "nbformat_minor": 4
279
+ }
README.md CHANGED
@@ -24,9 +24,107 @@ Wav2Vec2 is a pretrained model for Automatic Speech Recognition (ASR) and was re
24
 
25
  Similar to Wav2Vec2, XLSR-Wav2Vec2 learns powerful speech representations from hundreds of thousands of hours of unlabeled speech in more than 50 languages. Similarly to BERT's masked language modeling, the model learns contextualized speech representations by randomly masking feature vectors before passing them to a transformer network.
26
 
27
- ### How to use
28
 
29
- Instructions to replicate the process are included in the Jupyter notebook.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
 
32
  ## Metrics
@@ -38,6 +136,6 @@ Instructions to replicate the process are included in the Jupyter notebook.
38
  | WER | 0.45049 |
39
 
40
 
41
- ### BibTeX entry and citation info
42
  Based on the tutorial of Patrick von Platen: https://huggingface.co/blog/fine-tune-xlsr-wav2vec2
43
  Original colab notebook here: https://colab.research.google.com/github/patrickvonplaten/notebooks/blob/master/Fine_Tune_XLSR_Wav2Vec2_on_Turkish_ASR_with_%F0%9F%A4%97_Transformers.ipynb#scrollTo=V7YOT2mnUiea
 
24
 
25
  Similar to Wav2Vec2, XLSR-Wav2Vec2 learns powerful speech representations from hundreds of thousands of hours of unlabeled speech in more than 50 languages. Similarly to BERT's masked language modeling, the model learns contextualized speech representations by randomly masking feature vectors before passing them to a transformer network.
26
 
27
+ ### How to use for inference:
28
 
29
+ Instructions for testing on CommonVoice extracts are provided in ASR_Inference.ipynb. The snippet is also available below:
30
+
31
+ ```
32
+ #!/usr/bin/env python
33
+ # coding: utf-8
34
+
35
+ # Loading dependencies and defining preprocessing functions
36
+
37
+ from transformers import Wav2Vec2ForCTC
38
+ from transformers import Wav2Vec2Processor
39
+ from datasets import load_dataset, load_metric
40
+ import re
41
+ import torchaudio
42
+ import librosa
43
+ import numpy as np
44
+ from datasets import load_dataset, load_metric
45
+ import torch
46
+
47
+ chars_to_ignore_regex = '[\,\?\.\!\-\;\:\"\“\%\‘\”\�]'
48
+
49
+ def remove_special_characters(batch):
50
+ batch["text"] = re.sub(chars_to_ignore_regex, '', batch["sentence"]).lower() + " "
51
+ return batch
52
+
53
+ def speech_file_to_array_fn(batch):
54
+ speech_array, sampling_rate = torchaudio.load(batch["path"])
55
+ batch["speech"] = speech_array[0].numpy()
56
+ batch["sampling_rate"] = sampling_rate
57
+ batch["target_text"] = batch["text"]
58
+ return batch
59
+
60
+ def resample(batch):
61
+ batch["speech"] = librosa.resample(np.asarray(batch["speech"]), 48_000, 16_000)
62
+ batch["sampling_rate"] = 16_000
63
+ return batch
64
+
65
+ def prepare_dataset(batch):
66
+ # check that all files have the correct sampling rate
67
+ assert (
68
+ len(set(batch["sampling_rate"])) == 1
69
+ ), f"Make sure all inputs have the same sampling rate of {processor.feature_extractor.sampling_rate}."
70
+
71
+ batch["input_values"] = processor(batch["speech"], sampling_rate=batch["sampling_rate"][0]).input_values
72
+
73
+ with processor.as_target_processor():
74
+ batch["labels"] = processor(batch["target_text"]).input_ids
75
+ return batch
76
+
77
+
78
+ # Loading model and dataset processor
79
+
80
+ model = Wav2Vec2ForCTC.from_pretrained("wav2vec2-large-xlsr-greek/checkpoint-9200/").to("cuda")
81
+ processor = Wav2Vec2Processor.from_pretrained("wav2vec2-large-xlsr-greek/")
82
+
83
+
84
+ # Preparing speech dataset to be suitable for inference
85
+
86
+ common_voice_test = load_dataset("common_voice", "el", data_dir="cv-corpus-6.1-2020-12-11", split="test")
87
+
88
+ common_voice_test = common_voice_test.remove_columns(["accent", "age", "client_id", "down_votes", "gender", "locale", "segment", "up_votes"])
89
+
90
+ common_voice_test = common_voice_test.map(remove_special_characters, remove_columns=["sentence"])
91
+
92
+ common_voice_test = common_voice_test.map(speech_file_to_array_fn, remove_columns=common_voice_test.column_names)
93
+
94
+ common_voice_test = common_voice_test.map(resample, num_proc=8)
95
+
96
+ common_voice_test = common_voice_test.map(prepare_dataset, remove_columns=common_voice_test.column_names, batch_size=8, num_proc=8, batched=True)
97
+
98
+
99
+ # Loading test dataset
100
+
101
+ common_voice_test_transcription = load_dataset("common_voice", "el", data_dir="./cv-corpus-6.1-2020-12-11", split="test")
102
+
103
+
104
+ #Performing inference on a random sample. Change the "example" value to try inference on different CommonVoice extracts
105
+
106
+ example = 123
107
+
108
+ input_dict = processor(common_voice_test["input_values"][example], return_tensors="pt", sampling_rate=16_000, padding=True)
109
+
110
+ logits = model(input_dict.input_values.to("cuda")).logits
111
+
112
+ pred_ids = torch.argmax(logits, dim=-1)
113
+
114
+ print("Prediction:")
115
+ print(processor.decode(pred_ids[0]))
116
+ # καμιά φορά τα έπαιρνε και έπαιζε όταν η δουλειά ήταν πιο χαλαρί
117
+
118
+ print("\nReference:")
119
+ print(common_voice_test_transcription["sentence"][example].lower())
120
+ # καμιά φορά τα έπαιρνε και έπαιζε όταν η δουλειά ήταν πιο χαλαρή
121
+
122
+
123
+ ```
124
+
125
+ ### How to use for training:
126
+
127
+ Instructions and code to replicate the process are provided in the Fine_Tune_XLSR_Wav2Vec2_on_Greek_ASR_with_🤗_Transformers.ipynb notebook.
128
 
129
 
130
  ## Metrics
 
136
  | WER | 0.45049 |
137
 
138
 
139
+ ### Acknowledgment
140
  Based on the tutorial of Patrick von Platen: https://huggingface.co/blog/fine-tune-xlsr-wav2vec2
141
  Original colab notebook here: https://colab.research.google.com/github/patrickvonplaten/notebooks/blob/master/Fine_Tune_XLSR_Wav2Vec2_on_Turkish_ASR_with_%F0%9F%A4%97_Transformers.ipynb#scrollTo=V7YOT2mnUiea