emilios committed
Commit 60b3f40
1 Parent(s): baaee71

Training in progress, step 6000

config.json CHANGED
@@ -34,6 +34,7 @@
34
  "num_mel_bins": 80,
35
  "pad_token_id": 50257,
36
  "scale_embedding": false,
 
37
  "torch_dtype": "float32",
38
  "transformers_version": "4.26.0.dev0",
39
  "use_cache": false,
 
34
  "num_mel_bins": 80,
35
  "pad_token_id": 50257,
36
  "scale_embedding": false,
37
+ "suppress_tokens": [],
38
  "torch_dtype": "float32",
39
  "transformers_version": "4.26.0.dev0",
40
  "use_cache": false,
e5_interleaving-cl.ipynb CHANGED
@@ -92,8 +92,8 @@
92
  "Building dependency tree \n",
93
  "Reading state information... Done\n",
94
  "git-lfs is already the newest version (2.9.2-1).\n",
95
- "0 upgraded, 0 newly installed, 0 to remove and 154 not upgraded.\n",
96
- "Error: Failed to call git rev-parse --git-dir: exit status 128 \n",
97
  "Git LFS initialized.\n"
98
  ]
99
  }
@@ -341,7 +341,7 @@
341
  },
342
  {
343
  "cell_type": "code",
344
- "execution_count": 3,
345
  "id": "d74b38c5-a1fb-4214-b4f4-b5bf0869f169",
346
  "metadata": {
347
  "colab": {
@@ -356,7 +356,7 @@
356
  "name": "stdout",
357
  "output_type": "stream",
358
  "text": [
359
- "Sun Dec 11 12:49:44 2022 \n",
360
  "+-----------------------------------------------------------------------------+\n",
361
  "| NVIDIA-SMI 515.65.01 Driver Version: 515.65.01 CUDA Version: 11.7 |\n",
362
  "|-------------------------------+----------------------+----------------------+\n",
@@ -365,7 +365,7 @@
365
  "| | | MIG M. |\n",
366
  "|===============================+======================+======================|\n",
367
  "| 0 NVIDIA A100-SXM... On | 00000000:06:00.0 Off | 0 |\n",
368
- "| N/A 32C P0 44W / 400W | 0MiB / 40960MiB | 0% Default |\n",
369
  "| | | Disabled |\n",
370
  "+-------------------------------+----------------------+----------------------+\n",
371
  " \n",
@@ -374,7 +374,7 @@
374
  "| GPU GI CI PID Type Process name GPU Memory |\n",
375
  "| ID ID Usage |\n",
376
  "|=============================================================================|\n",
377
- "| No running processes found |\n",
378
  "+-----------------------------------------------------------------------------+\n"
379
  ]
380
  }
@@ -735,7 +735,7 @@
735
  },
736
  {
737
  "cell_type": "code",
738
- "execution_count": 6,
739
  "id": "dff27c76-575c-432b-8916-b1b810efef4a",
740
  "metadata": {
741
  "colab": {
@@ -768,7 +768,7 @@
768
  {
769
  "data": {
770
  "application/vnd.jupyter.widget-view+json": {
771
- "model_id": "172290fa003f462aa40a9b03e56e91b1",
772
  "version_major": 2,
773
  "version_minor": 0
774
  },
@@ -850,7 +850,7 @@
850
  },
851
  {
852
  "cell_type": "code",
853
- "execution_count": 7,
854
  "id": "065a8cf7-e54f-4ac3-900e-609c80714fca",
855
  "metadata": {
856
  "id": "065a8cf7-e54f-4ac3-900e-609c80714fca"
@@ -898,7 +898,7 @@
898
  },
899
  {
900
  "cell_type": "code",
901
- "execution_count": 8,
902
  "id": "a2787582-554f-44ce-9f38-4180a5ed6b44",
903
  "metadata": {
904
  "id": "a2787582-554f-44ce-9f38-4180a5ed6b44",
@@ -960,7 +960,7 @@
960
  },
961
  {
962
  "cell_type": "code",
963
- "execution_count": 9,
964
  "id": "qOwlctMhNmCG",
965
  "metadata": {
966
  "id": "qOwlctMhNmCG",
@@ -984,7 +984,7 @@
984
  },
985
  {
986
  "cell_type": "code",
987
- "execution_count": 10,
988
  "id": "imRHJOpm4V_j",
989
  "metadata": {
990
  "id": "imRHJOpm4V_j"
@@ -1031,7 +1031,7 @@
1031
  },
1032
  {
1033
  "cell_type": "code",
1034
- "execution_count": 11,
1035
  "id": "77d9f0c5-8607-4642-a8ac-c3ab2e223ea6",
1036
  "metadata": {
1037
  "id": "77d9f0c5-8607-4642-a8ac-c3ab2e223ea6"
@@ -1066,7 +1066,7 @@
1066
  },
1067
  {
1068
  "cell_type": "code",
1069
- "execution_count": 12,
1070
  "id": "ab5a13b4-9bd4-4aa0-aef2-b3de9b762988",
1071
  "metadata": {
1072
  "colab": {
@@ -1083,7 +1083,7 @@
1083
  " 'sentence': Value(dtype='string', id=None)}"
1084
  ]
1085
  },
1086
- "execution_count": 12,
1087
  "metadata": {},
1088
  "output_type": "execute_result"
1089
  }
@@ -1111,7 +1111,7 @@
1111
  },
1112
  {
1113
  "cell_type": "code",
1114
- "execution_count": 13,
1115
  "id": "3ab6a724-3d1e-478b-a9e9-d2f85feb6c39",
1116
  "metadata": {
1117
  "id": "3ab6a724-3d1e-478b-a9e9-d2f85feb6c39"
@@ -1135,7 +1135,7 @@
1135
  },
1136
  {
1137
  "cell_type": "code",
1138
- "execution_count": 14,
1139
  "id": "d041650e-1c48-4439-87b3-5b6f4a514107",
1140
  "metadata": {
1141
  "id": "d041650e-1c48-4439-87b3-5b6f4a514107"
@@ -1145,7 +1145,7 @@
1145
  "import string\n",
1146
  "import re\n",
1147
  "\n",
1148
- "do_lower_case = True\n",
1149
  "do_remove_punctuation = False\n",
1150
  "\n",
1151
  "punctuation_to_remove = string.punctuation.replace(\"'\", \"\") # don't remove apostrophes\n",
@@ -1171,7 +1171,7 @@
1171
  },
1172
  {
1173
  "cell_type": "code",
1174
- "execution_count": 15,
1175
  "id": "c085911c-a10a-41ef-8874-306e0503e9bb",
1176
  "metadata": {
1177
  "id": "c085911c-a10a-41ef-8874-306e0503e9bb"
@@ -1212,7 +1212,7 @@
1212
  },
1213
  {
1214
  "cell_type": "code",
1215
- "execution_count": 16,
1216
  "id": "a37a7cdb-9013-427f-8de9-6a8d0e9dc684",
1217
  "metadata": {
1218
  "id": "a37a7cdb-9013-427f-8de9-6a8d0e9dc684"
@@ -1234,7 +1234,7 @@
1234
  },
1235
  {
1236
  "cell_type": "code",
1237
- "execution_count": 17,
1238
  "id": "1b145699-acfc-4b1d-93a2-a2ad3d62674c",
1239
  "metadata": {
1240
  "id": "1b145699-acfc-4b1d-93a2-a2ad3d62674c"
@@ -1259,7 +1259,7 @@
1259
  },
1260
  {
1261
  "cell_type": "code",
1262
- "execution_count": 18,
1263
  "id": "01cb25ef-4bb0-4325-9461-f59198acadf6",
1264
  "metadata": {
1265
  "id": "01cb25ef-4bb0-4325-9461-f59198acadf6"
@@ -1284,7 +1284,7 @@
1284
  },
1285
  {
1286
  "cell_type": "code",
1287
- "execution_count": 19,
1288
  "id": "333f7f6e-6053-4d3b-8924-c733c79b82ac",
1289
  "metadata": {
1290
  "id": "333f7f6e-6053-4d3b-8924-c733c79b82ac"
@@ -1364,7 +1364,7 @@
1364
  },
1365
  {
1366
  "cell_type": "code",
1367
- "execution_count": 20,
1368
  "id": "8326221e-ec13-4731-bb4e-51e5fc1486c5",
1369
  "metadata": {
1370
  "id": "8326221e-ec13-4731-bb4e-51e5fc1486c5"
@@ -1416,7 +1416,7 @@
1416
  },
1417
  {
1418
  "cell_type": "code",
1419
- "execution_count": 21,
1420
  "id": "fc834702-c0d3-4a96-b101-7b87be32bf42",
1421
  "metadata": {
1422
  "id": "fc834702-c0d3-4a96-b101-7b87be32bf42"
@@ -1449,7 +1449,7 @@
1449
  },
1450
  {
1451
  "cell_type": "code",
1452
- "execution_count": 22,
1453
  "id": "b22b4011-f31f-4b57-b684-c52332f92890",
1454
  "metadata": {
1455
  "colab": {
@@ -1500,7 +1500,7 @@
1500
  },
1501
  {
1502
  "cell_type": "code",
1503
- "execution_count": 23,
1504
  "id": "a11d1bfc-9e28-460f-a287-72d8f7bc1acb",
1505
  "metadata": {
1506
  "id": "a11d1bfc-9e28-460f-a287-72d8f7bc1acb"
@@ -1549,7 +1549,7 @@
1549
  },
1550
  {
1551
  "cell_type": "code",
1552
- "execution_count": 24,
1553
  "id": "5a10cc4b-07ec-4ebd-ac1d-7c601023594f",
1554
  "metadata": {
1555
  "colab": {
@@ -1605,7 +1605,7 @@
1605
  },
1606
  {
1607
  "cell_type": "code",
1608
- "execution_count": 25,
1609
  "id": "62038ba3-88ed-4fce-84db-338f50dcd04f",
1610
  "metadata": {
1611
  "id": "62038ba3-88ed-4fce-84db-338f50dcd04f"
@@ -1614,7 +1614,10 @@
1614
  "source": [
1615
  "model.config.forced_decoder_ids = None\n",
1616
  "model.config.suppress_tokens = []\n",
1617
- "model.config.use_cache = False"
 
 
 
1618
  ]
1619
  },
1620
  {
@@ -1639,7 +1642,7 @@
1639
  },
1640
  {
1641
  "cell_type": "code",
1642
- "execution_count": 26,
1643
  "id": "0ae3e9af-97b7-4aa0-ae85-20b23b5bcb3a",
1644
  "metadata": {
1645
  "colab": {
@@ -1659,8 +1662,11 @@
1659
  " gradient_accumulation_steps=1, # increase by 2x for every 2x decrease in batch size\n",
1660
  " learning_rate=1e-5,\n",
1661
  " warmup_steps=500,\n",
1662
- " max_steps=5000,\n",
1663
  " gradient_checkpointing=True,\n",
 
 
 
1664
  " fp16=True,\n",
1665
  " evaluation_strategy=\"steps\",\n",
1666
  " per_device_eval_batch_size=16,\n",
@@ -1680,7 +1686,7 @@
1680
  },
1681
  {
1682
  "cell_type": "code",
1683
- "execution_count": 27,
1684
  "id": "o72eOpGzD_sK",
1685
  "metadata": {
1686
  "colab": {
@@ -1694,7 +1700,7 @@
1694
  "name": "stdout",
1695
  "output_type": "stream",
1696
  "text": [
1697
- "Sun Dec 11 12:50:47 2022 \n",
1698
  "+-----------------------------------------------------------------------------+\n",
1699
  "| NVIDIA-SMI 515.65.01 Driver Version: 515.65.01 CUDA Version: 11.7 |\n",
1700
  "|-------------------------------+----------------------+----------------------+\n",
@@ -1744,7 +1750,7 @@
1744
  },
1745
  {
1746
  "cell_type": "code",
1747
- "execution_count": 28,
1748
  "id": "3ac16b62-b3c0-4c68-8f3d-9ecf471534b2",
1749
  "metadata": {
1750
  "id": "3ac16b62-b3c0-4c68-8f3d-9ecf471534b2"
@@ -1777,7 +1783,7 @@
1777
  },
1778
  {
1779
  "cell_type": "code",
1780
- "execution_count": 29,
1781
  "id": "d546d7fe-0543-479a-b708-2ebabec19493",
1782
  "metadata": {
1783
  "colab": {
@@ -2355,7 +2361,7 @@
2355
  "name": "stderr",
2356
  "output_type": "stream",
2357
  "text": [
2358
- "/home/ubuntu/./whisper-medium-el is already a clone of https://huggingface.co/emilios/whisper-medium-el. Make sure you pull the latest changes with `repo.git_pull()`.\n",
2359
  "max_steps is given, it will override any value given in num_train_epochs\n",
2360
  "Using cuda_amp half precision backend\n"
2361
  ]
@@ -2391,7 +2397,7 @@
2391
  },
2392
  {
2393
  "cell_type": "code",
2394
- "execution_count": 30,
2395
  "id": "a1ccb9ed-cbc8-4419-91c0-651e9424b672",
2396
  "metadata": {
2397
  "id": "a1ccb9ed-cbc8-4419-91c0-651e9424b672"
@@ -2401,12 +2407,7 @@
2401
  "name": "stderr",
2402
  "output_type": "stream",
2403
  "text": [
2404
- "Configuration saved in ./whisper-medium-el/config.json\n",
2405
- "Model weights saved in ./whisper-medium-el/pytorch_model.bin\n",
2406
- "Feature extractor saved in ./whisper-medium-el/preprocessor_config.json\n",
2407
- "tokenizer config file saved in ./whisper-medium-el/tokenizer_config.json\n",
2408
- "Special tokens file saved in ./whisper-medium-el/special_tokens_map.json\n",
2409
- "added tokens file saved in ./whisper-medium-el/added_tokens.json\n"
2410
  ]
2411
  }
2412
  ],
@@ -2479,67 +2480,10 @@
2479
  "metadata": {
2480
  "id": "ee8b7b8e-1c9a-4d77-9137-1778a629e6de"
2481
  },
2482
- "outputs": [
2483
- {
2484
- "name": "stderr",
2485
- "output_type": "stream",
2486
- "text": [
2487
- "/home/ubuntu/.local/lib/python3.8/site-packages/transformers/optimization.py:306: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
2488
- " warnings.warn(\n",
2489
- "***** Running training *****\n",
2490
- " Num examples = 160000\n",
2491
- " Num Epochs = 9223372036854775807\n",
2492
- " Instantaneous batch size per device = 32\n",
2493
- " Total train batch size (w. parallel, distributed & accumulation) = 32\n",
2494
- " Gradient Accumulation steps = 1\n",
2495
- " Total optimization steps = 5000\n",
2496
- " Number of trainable parameters = 763857920\n",
2497
- "Reading metadata...: 1914it [00:00, 58899.17it/s]\n",
2498
- "The following columns in the training set don't have a corresponding argument in `WhisperForConditionalGeneration.forward` and have been ignored: input_length. If input_length are not expected by `WhisperForConditionalGeneration.forward`, you can safely ignore this message.\n"
2499
- ]
2500
- },
2501
- {
2502
- "data": {
2503
- "text/html": [
2504
- "\n",
2505
- " <div>\n",
2506
- " \n",
2507
- " <progress value='21' max='5000' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
2508
- " [ 21/5000 02:01 < 8:48:39, 0.16 it/s, Epoch 0.00/9223372036854775807]\n",
2509
- " </div>\n",
2510
- " <table border=\"1\" class=\"dataframe\">\n",
2511
- " <thead>\n",
2512
- " <tr style=\"text-align: left;\">\n",
2513
- " <th>Step</th>\n",
2514
- " <th>Training Loss</th>\n",
2515
- " <th>Validation Loss</th>\n",
2516
- " </tr>\n",
2517
- " </thead>\n",
2518
- " <tbody>\n",
2519
- " </tbody>\n",
2520
- "</table><p>"
2521
- ],
2522
- "text/plain": [
2523
- "<IPython.core.display.HTML object>"
2524
- ]
2525
- },
2526
- "metadata": {},
2527
- "output_type": "display_data"
2528
- },
2529
- {
2530
- "name": "stderr",
2531
- "output_type": "stream",
2532
- "text": [
2533
- "***** Running Evaluation *****\n",
2534
- " Num examples: Unknown\n",
2535
- " Batch size = 16\n",
2536
- "Reading metadata...: 1696it [00:00, 51485.08it/s]\n",
2537
- "The following columns in the evaluation set don't have a corresponding argument in `WhisperForConditionalGeneration.forward` and have been ignored: path, accent, down_votes, age, gender, segment, up_votes, input_length, locale, client_id. If path, accent, down_votes, age, gender, segment, up_votes, input_length, locale, client_id are not expected by `WhisperForConditionalGeneration.forward`, you can safely ignore this message.\n"
2538
- ]
2539
- }
2540
- ],
2541
  "source": [
2542
- "trainer.train()"
 
2543
  ]
2544
  },
2545
  {
@@ -2580,10 +2524,11 @@
2580
  " \"dataset\": \"Common Voice 11.0\", # a 'pretty' name for the training dataset\n",
2581
  " #\"dataset\": \"Google FLEURS\", # a 'pretty' name for the training dataset\n",
2582
  " \"language\": \"el\",\n",
2583
- " \"model_name\": \"Whisper Medium El Greco Greek\", # a 'pretty' name for your model\n",
2584
  " \"finetuned_from\": \"openai/whisper-medium\",\n",
2585
  " \"tasks\": \"automatic-speech-recognition\",\n",
2586
- " \"tags\": \"hf-asr-leaderboard, whisper-medium, mozilla-foundation/common_voice_11_0, greek, whisper-event\",\n",
 
2587
  "}"
2588
  ]
2589
  },
@@ -2608,6 +2553,14 @@
2608
  "source": [
2609
  "trainer.push_to_hub(**kwargs)"
2610
 ]
2611
  }
2612
  ],
2613
  "metadata": {
 
92
  "Building dependency tree \n",
93
  "Reading state information... Done\n",
94
  "git-lfs is already the newest version (2.9.2-1).\n",
95
+ "0 upgraded, 0 newly installed, 0 to remove and 155 not upgraded.\n",
96
+ "Updated git hooks.\n",
97
  "Git LFS initialized.\n"
98
  ]
99
  }
 
341
  },
342
  {
343
  "cell_type": "code",
344
+ "execution_count": 2,
345
  "id": "d74b38c5-a1fb-4214-b4f4-b5bf0869f169",
346
  "metadata": {
347
  "colab": {
 
356
  "name": "stdout",
357
  "output_type": "stream",
358
  "text": [
359
+ "Mon Dec 12 18:05:43 2022 \n",
360
  "+-----------------------------------------------------------------------------+\n",
361
  "| NVIDIA-SMI 515.65.01 Driver Version: 515.65.01 CUDA Version: 11.7 |\n",
362
  "|-------------------------------+----------------------+----------------------+\n",
 
365
  "| | | MIG M. |\n",
366
  "|===============================+======================+======================|\n",
367
  "| 0 NVIDIA A100-SXM... On | 00000000:06:00.0 Off | 0 |\n",
368
+ "| N/A 74C P0 350W / 400W | 36965MiB / 40960MiB | 100% Default |\n",
369
  "| | | Disabled |\n",
370
  "+-------------------------------+----------------------+----------------------+\n",
371
  " \n",
 
374
  "| GPU GI CI PID Type Process name GPU Memory |\n",
375
  "| ID ID Usage |\n",
376
  "|=============================================================================|\n",
377
+ "| 0 N/A N/A 3651714 C python 36963MiB |\n",
378
  "+-----------------------------------------------------------------------------+\n"
379
  ]
380
  }
 
735
  },
736
  {
737
  "cell_type": "code",
738
+ "execution_count": 1,
739
  "id": "dff27c76-575c-432b-8916-b1b810efef4a",
740
  "metadata": {
741
  "colab": {
 
768
  {
769
  "data": {
770
  "application/vnd.jupyter.widget-view+json": {
771
+ "model_id": "6a2bfa275b6f4cdba66b6abdab11e859",
772
  "version_major": 2,
773
  "version_minor": 0
774
  },
 
850
  },
851
  {
852
  "cell_type": "code",
853
+ "execution_count": 1,
854
  "id": "065a8cf7-e54f-4ac3-900e-609c80714fca",
855
  "metadata": {
856
  "id": "065a8cf7-e54f-4ac3-900e-609c80714fca"
 
898
  },
899
  {
900
  "cell_type": "code",
901
+ "execution_count": 2,
902
  "id": "a2787582-554f-44ce-9f38-4180a5ed6b44",
903
  "metadata": {
904
  "id": "a2787582-554f-44ce-9f38-4180a5ed6b44",
 
960
  },
961
  {
962
  "cell_type": "code",
963
+ "execution_count": 3,
964
  "id": "qOwlctMhNmCG",
965
  "metadata": {
966
  "id": "qOwlctMhNmCG",
 
984
  },
985
  {
986
  "cell_type": "code",
987
+ "execution_count": 4,
988
  "id": "imRHJOpm4V_j",
989
  "metadata": {
990
  "id": "imRHJOpm4V_j"
 
1031
  },
1032
  {
1033
  "cell_type": "code",
1034
+ "execution_count": 5,
1035
  "id": "77d9f0c5-8607-4642-a8ac-c3ab2e223ea6",
1036
  "metadata": {
1037
  "id": "77d9f0c5-8607-4642-a8ac-c3ab2e223ea6"
 
1066
  },
1067
  {
1068
  "cell_type": "code",
1069
+ "execution_count": 6,
1070
  "id": "ab5a13b4-9bd4-4aa0-aef2-b3de9b762988",
1071
  "metadata": {
1072
  "colab": {
 
1083
  " 'sentence': Value(dtype='string', id=None)}"
1084
  ]
1085
  },
1086
+ "execution_count": 6,
1087
  "metadata": {},
1088
  "output_type": "execute_result"
1089
  }
 
1111
  },
1112
  {
1113
  "cell_type": "code",
1114
+ "execution_count": 7,
1115
  "id": "3ab6a724-3d1e-478b-a9e9-d2f85feb6c39",
1116
  "metadata": {
1117
  "id": "3ab6a724-3d1e-478b-a9e9-d2f85feb6c39"
 
1135
  },
1136
  {
1137
  "cell_type": "code",
1138
+ "execution_count": 8,
1139
  "id": "d041650e-1c48-4439-87b3-5b6f4a514107",
1140
  "metadata": {
1141
  "id": "d041650e-1c48-4439-87b3-5b6f4a514107"
 
1145
  "import string\n",
1146
  "import re\n",
1147
  "\n",
1148
+ "do_lower_case = False\n",
1149
  "do_remove_punctuation = False\n",
1150
  "\n",
1151
  "punctuation_to_remove = string.punctuation.replace(\"'\", \"\") # don't remove apostrophes\n",
 
1171
  },
1172
  {
1173
  "cell_type": "code",
1174
+ "execution_count": 9,
1175
  "id": "c085911c-a10a-41ef-8874-306e0503e9bb",
1176
  "metadata": {
1177
  "id": "c085911c-a10a-41ef-8874-306e0503e9bb"
 
1212
  },
1213
  {
1214
  "cell_type": "code",
1215
+ "execution_count": 10,
1216
  "id": "a37a7cdb-9013-427f-8de9-6a8d0e9dc684",
1217
  "metadata": {
1218
  "id": "a37a7cdb-9013-427f-8de9-6a8d0e9dc684"
 
1234
  },
1235
  {
1236
  "cell_type": "code",
1237
+ "execution_count": 11,
1238
  "id": "1b145699-acfc-4b1d-93a2-a2ad3d62674c",
1239
  "metadata": {
1240
  "id": "1b145699-acfc-4b1d-93a2-a2ad3d62674c"
 
1259
  },
1260
  {
1261
  "cell_type": "code",
1262
+ "execution_count": 12,
1263
  "id": "01cb25ef-4bb0-4325-9461-f59198acadf6",
1264
  "metadata": {
1265
  "id": "01cb25ef-4bb0-4325-9461-f59198acadf6"
 
1284
  },
1285
  {
1286
  "cell_type": "code",
1287
+ "execution_count": 13,
1288
  "id": "333f7f6e-6053-4d3b-8924-c733c79b82ac",
1289
  "metadata": {
1290
  "id": "333f7f6e-6053-4d3b-8924-c733c79b82ac"
 
1364
  },
1365
  {
1366
  "cell_type": "code",
1367
+ "execution_count": 14,
1368
  "id": "8326221e-ec13-4731-bb4e-51e5fc1486c5",
1369
  "metadata": {
1370
  "id": "8326221e-ec13-4731-bb4e-51e5fc1486c5"
 
1416
  },
1417
  {
1418
  "cell_type": "code",
1419
+ "execution_count": 15,
1420
  "id": "fc834702-c0d3-4a96-b101-7b87be32bf42",
1421
  "metadata": {
1422
  "id": "fc834702-c0d3-4a96-b101-7b87be32bf42"
 
1449
  },
1450
  {
1451
  "cell_type": "code",
1452
+ "execution_count": 16,
1453
  "id": "b22b4011-f31f-4b57-b684-c52332f92890",
1454
  "metadata": {
1455
  "colab": {
 
1500
  },
1501
  {
1502
  "cell_type": "code",
1503
+ "execution_count": 17,
1504
  "id": "a11d1bfc-9e28-460f-a287-72d8f7bc1acb",
1505
  "metadata": {
1506
  "id": "a11d1bfc-9e28-460f-a287-72d8f7bc1acb"
 
1549
  },
1550
  {
1551
  "cell_type": "code",
1552
+ "execution_count": 18,
1553
  "id": "5a10cc4b-07ec-4ebd-ac1d-7c601023594f",
1554
  "metadata": {
1555
  "colab": {
 
1605
  },
1606
  {
1607
  "cell_type": "code",
1608
+ "execution_count": 19,
1609
  "id": "62038ba3-88ed-4fce-84db-338f50dcd04f",
1610
  "metadata": {
1611
  "id": "62038ba3-88ed-4fce-84db-338f50dcd04f"
 
1614
  "source": [
1615
  "model.config.forced_decoder_ids = None\n",
1616
  "model.config.suppress_tokens = []\n",
1617
+ "model.config.use_cache = False\n",
1618
+ "model.config.dropout=0.1\n",
1619
+ "#model.config.dropout=0.05\n",
1620
+ "#model.config.dropout = model.config.attention_dropout = 0.05"
1621
  ]
1622
  },
1623
  {
 
1642
  },
1643
  {
1644
  "cell_type": "code",
1645
+ "execution_count": 20,
1646
  "id": "0ae3e9af-97b7-4aa0-ae85-20b23b5bcb3a",
1647
  "metadata": {
1648
  "colab": {
 
1662
  " gradient_accumulation_steps=1, # increase by 2x for every 2x decrease in batch size\n",
1663
  " learning_rate=1e-5,\n",
1664
  " warmup_steps=500,\n",
1665
+ " max_steps=7000,\n",
1666
  " gradient_checkpointing=True,\n",
1667
+ " resume_from_checkpoint=5000,\n",
1668
+ " #resume_from_checkpoint=\"checkpoint-4000\"\n",
1669
+ " ignore_data_skip=True,\n",
1670
  " fp16=True,\n",
1671
  " evaluation_strategy=\"steps\",\n",
1672
  " per_device_eval_batch_size=16,\n",
 
1686
  },
1687
  {
1688
  "cell_type": "code",
1689
+ "execution_count": 21,
1690
  "id": "o72eOpGzD_sK",
1691
  "metadata": {
1692
  "colab": {
 
1700
  "name": "stdout",
1701
  "output_type": "stream",
1702
  "text": [
1703
+ "Tue Dec 13 00:36:23 2022 \n",
1704
  "+-----------------------------------------------------------------------------+\n",
1705
  "| NVIDIA-SMI 515.65.01 Driver Version: 515.65.01 CUDA Version: 11.7 |\n",
1706
  "|-------------------------------+----------------------+----------------------+\n",
 
1750
  },
1751
  {
1752
  "cell_type": "code",
1753
+ "execution_count": 22,
1754
  "id": "3ac16b62-b3c0-4c68-8f3d-9ecf471534b2",
1755
  "metadata": {
1756
  "id": "3ac16b62-b3c0-4c68-8f3d-9ecf471534b2"
 
1783
  },
1784
  {
1785
  "cell_type": "code",
1786
+ "execution_count": 23,
1787
  "id": "d546d7fe-0543-479a-b708-2ebabec19493",
1788
  "metadata": {
1789
  "colab": {
 
2361
  "name": "stderr",
2362
  "output_type": "stream",
2363
  "text": [
2364
+ "/home/ubuntu/whisper-medium-el/./whisper-medium-el is already a clone of https://huggingface.co/emilios/whisper-medium-el. Make sure you pull the latest changes with `repo.git_pull()`.\n",
2365
  "max_steps is given, it will override any value given in num_train_epochs\n",
2366
  "Using cuda_amp half precision backend\n"
2367
  ]
 
2397
  },
2398
  {
2399
  "cell_type": "code",
2400
+ "execution_count": null,
2401
  "id": "a1ccb9ed-cbc8-4419-91c0-651e9424b672",
2402
  "metadata": {
2403
  "id": "a1ccb9ed-cbc8-4419-91c0-651e9424b672"
 
2407
  "name": "stderr",
2408
  "output_type": "stream",
2409
  "text": [
2410
+ "Configuration saved in ./whisper-medium-el/config.json\n"
 
 
 
 
 
2411
  ]
2412
  }
2413
  ],
 
2480
  "metadata": {
2481
  "id": "ee8b7b8e-1c9a-4d77-9137-1778a629e6de"
2482
  },
2483
+ "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2484
  "source": [
2485
+ "#trainer.train()\n",
2486
+ "trainer.train(resume_from_checkpoint = True)\n"
2487
  ]
2488
  },
2489
  {
 
2524
  " \"dataset\": \"Common Voice 11.0\", # a 'pretty' name for the training dataset\n",
2525
  " #\"dataset\": \"Google FLEURS\", # a 'pretty' name for the training dataset\n",
2526
  " \"language\": \"el\",\n",
2527
+ " \"model_name\": \"Whisper Medium El Greco\", # a 'pretty' name for your model\n",
2528
  " \"finetuned_from\": \"openai/whisper-medium\",\n",
2529
  " \"tasks\": \"automatic-speech-recognition\",\n",
2530
+ " \"tags\": \"hf-asr-leaderboard\",\n",
2531
+ " #\"tags\": \"hf-asr-leaderboard, whisper-medium, mozilla-foundation/common_voice_11_0, greek, whisper-event\",\n",
2532
  "}"
2533
  ]
2534
  },
 
2553
  "source": [
2554
  "trainer.push_to_hub(**kwargs)"
2555
  ]
2556
+ },
2557
+ {
2558
+ "cell_type": "code",
2559
+ "execution_count": null,
2560
+ "id": "afa8a29a-59ab-463f-b2bf-2fbf2bae9b82",
2561
+ "metadata": {},
2562
+ "outputs": [],
2563
+ "source": []
2564
  }
2565
  ],
2566
  "metadata": {
push.py ADDED
@@ -0,0 +1,758 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Fine-tuning the library models for sequence to sequence speech recognition
18
+ with 🤗 Datasets' streaming mode.
19
+ """
20
+ # This program was modified by Michael Kamfonas (mkamfonas@infokarta.com) on Dec 11 2022
21
+ # - added options for dropout, gradient_checkpointing, use_cache, stopping_strategy and streaming
22
+ # - restructured it to enable both streaming and non-streaming modes
23
+ # - allows concatenation of multiple datasets (single-string comma-separated) for interleaving
24
+ # The following params must have the same number of comma-separated (,) elements:
25
+ # dataset_name,
26
+ # dataset_config_name,
27
+ # train_split_name and eval_split_name (each element plus-separated (+) for multiple splits),
28
+ # text_column_name and audio_column_name
29
+
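To make the comma/plus-separated convention above concrete, here is a small illustrative sketch; the values are taken from run.sh further down in this commit, and the snippet only demonstrates how the strings line up:

# Illustrative only: how the comma/plus-separated arguments line up, one element per dataset.
dataset_name        = "mozilla-foundation/common_voice_11_0,google/fleurs"
dataset_config_name = "el,el_gr"
train_split_name    = "train+validation,train+validation"  # '+' joins several splits of one dataset
eval_split_name     = "test,-"                             # '-' skips evaluation for that dataset
text_column_name    = "sentence,transcription"
audio_column_name   = "audio,audio"
# Every comma-separated string must split into the same number of elements:
parts = [dataset_name, dataset_config_name, train_split_name, eval_split_name,
         text_column_name, audio_column_name]
assert len({len(p.split(",")) for p in parts}) == 1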
30
+
31
+ import logging
32
+ import os
33
+ import sys
34
+ from dataclasses import dataclass, field
35
+ from typing import Any, Dict, List, Optional, Union
36
+
37
+ import datasets
38
+ import torch
39
+ from datasets import Audio, DatasetDict, IterableDatasetDict, interleave_datasets, load_dataset
40
+ from torch.utils.data import IterableDataset
41
+
42
+ import evaluate
43
+ import transformers
44
+ from transformers import (
45
+ AutoConfig,
46
+ AutoFeatureExtractor,
47
+ AutoModelForSpeechSeq2Seq,
48
+ AutoProcessor,
49
+ AutoTokenizer,
50
+ HfArgumentParser,
51
+ Seq2SeqTrainer,
52
+ Seq2SeqTrainingArguments,
53
+ TrainerCallback,
54
+ set_seed,
55
+ )
56
+ from transformers.trainer_pt_utils import IterableDatasetShard
57
+ from transformers.trainer_utils import get_last_checkpoint, is_main_process
58
+ from transformers.utils import check_min_version, send_example_telemetry
59
+ from transformers.utils.versions import require_version
60
+ from transformers.models.whisper.english_normalizer import BasicTextNormalizer
61
+
62
+ #TEXT_COL_NAME="text"
63
+ TEXT_COL_NAME="sentence,transcription"
64
+ AUDIO_COL_NAME="audio"
65
+
66
+ # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
67
+ check_min_version("4.25.0.dev0")
68
+
69
+ require_version("datasets>=1.18.2", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")
70
+
71
+ logger = logging.getLogger(__name__)
72
+
73
+
74
+ @dataclass
75
+ class ModelArguments:
76
+ """
77
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
78
+ """
79
+
80
+ model_name_or_path: str = field(
81
+ metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
82
+ )
83
+ config_name: Optional[str] = field(
84
+ default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
85
+ )
86
+ tokenizer_name: Optional[str] = field(
87
+ default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
88
+ )
89
+ feature_extractor_name: Optional[str] = field(
90
+ default=None, metadata={"help": "feature extractor name or path if not the same as model_name"}
91
+ )
92
+ cache_dir: Optional[str] = field(
93
+ default=None,
94
+ metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
95
+ )
96
+ use_fast_tokenizer: bool = field(
97
+ default=True,
98
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
99
+ )
100
+ model_revision: str = field(
101
+ default="main",
102
+ metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
103
+ )
104
+ use_auth_token: bool = field(
105
+ default=False,
106
+ metadata={
107
+ "help": (
108
+ "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
109
+ "with private models)."
110
+ )
111
+ },
112
+ )
113
+ freeze_feature_encoder: bool = field(
114
+ default=True, metadata={"help": "Deprecated - Whether to freeze the feature encoder layers of the model."}
115
+ )
116
+ freeze_encoder: bool = field(
117
+ default=False, metadata={"help": "Whether to freeze the entire encoder of the seq2seq model."}
118
+ )
119
+ forced_decoder_ids: List[List[int]] = field(
120
+ default=None,
121
+ metadata={
122
+ "help": (
123
+ "A list of pairs of integers which indicates a mapping from generation indices to token indices "
124
+ "that will be forced before sampling. For example, [[0, 123]] means the first generated token "
125
+ "will always be a token of index 123."
126
+ )
127
+ },
128
+ )
129
+ suppress_tokens: List[int] = field(
130
+ default=None, metadata={"help": "A list of tokens that will be suppressed at generation."}
131
+ )
132
+ model_index_name: str = field(default=None, metadata={"help": "Pretty name for the model card."})
133
+
134
+ ## added by Michael Kamfonas
135
+ use_cache: bool = field(
136
+ default=False, metadata={"help": "Whether to use cache."}
137
+ )
138
+
139
+ dropout: float = field(
140
+ default = 0.0, metadata = {"help": "dropout probability."}
141
+ )
142
+
143
+
144
+ @dataclass
145
+ class DataTrainingArguments:
146
+ """
147
+ Arguments pertaining to what data we are going to input our model for training and eval.
148
+ """
149
+
150
+ dataset_name: str = field(
151
+ default=None,
152
+ metadata={"help": "The name of the dataset to use (via the datasets library)."}
153
+ )
154
+ dataset_config_name: Optional[str] = field(
155
+ default=None,
156
+ metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
157
+ )
158
+ text_column: Optional[str] = field(
159
+ default=None,
160
+ metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."},
161
+ )
162
+ max_train_samples: Optional[int] = field(
163
+ default=None,
164
+ metadata={
165
+ "help": (
166
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
167
+ "value if set."
168
+ )
169
+ },
170
+ )
171
+ max_eval_samples: Optional[int] = field(
172
+ default=None,
173
+ metadata={
174
+ "help": (
175
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
176
+ "value if set."
177
+ )
178
+ },
179
+ )
180
+ audio_column_name: str = field(
181
+ default="audio",
182
+ metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"},
183
+ )
184
+ text_column_name: str = field(
185
+ default="text",
186
+ metadata={"help": "The name of the dataset column containing the text data. Defaults to 'text'"},
187
+ )
188
+ max_duration_in_seconds: float = field(
189
+ default=20.0,
190
+ metadata={
191
+ "help": (
192
+ "Truncate audio files that are longer than `max_duration_in_seconds` seconds to"
193
+ " 'max_duration_in_seconds`"
194
+ )
195
+ },
196
+ )
197
+ min_duration_in_seconds: float = field(
198
+ default=0.0, metadata={"help": "Filter audio files that are shorter than `min_duration_in_seconds` seconds"}
199
+ )
200
+ train_split_name: str = field(
201
+ default="train",
202
+ metadata={
203
+ "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
204
+ },
205
+ )
206
+ eval_split_name: str = field(
207
+ default="test",
208
+ metadata={
209
+ "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
210
+ },
211
+ )
212
+ do_lower_case: bool = field(
213
+ default=False,
214
+ metadata={"help": "Whether the target text should be lower cased."},
215
+ )
216
+ do_remove_punctuation: bool = field(
217
+ default=False,
218
+ metadata={"help": "Whether the target text should be striped of punctuation."},
219
+ )
220
+ do_normalize_eval: bool = field(
221
+ default=True,
222
+ metadata={"help": "Whether to normalise the references and predictions in the eval WER calculation."},
223
+ )
224
+ language: str = field(
225
+ default=None,
226
+ metadata={
227
+ "help": (
228
+ "Language for multilingual fine-tuning. This argument should be set for multilingual fine-tuning "
229
+ "only. For English speech recognition, it should be set to `None`."
230
+ )
231
+ },
232
+ )
233
+ task: str = field(
234
+ default="transcribe",
235
+ metadata={"help": "Task, either `transcribe` for speech recognition or `translate` for speech translation."},
236
+ )
237
+ shuffle_buffer_size: Optional[int] = field(
238
+ default=500,
239
+ metadata={
240
+ "help": (
241
+ "The number of streamed examples to download before shuffling them. The large the buffer, "
242
+ "the closer it is to real offline shuffling."
243
+ )
244
+ },
245
+ )
246
+ stopping_strategy: Optional[str] = field(
247
+ default="all_exhausted",
248
+ metadata={
249
+ "help": "Strategy used to consume interleaved data. Default = 'all_exhausted'"
250
+ }
251
+ )
252
+ streaming: bool = field(
253
+ default=True,
254
+ metadata={"help": "Whether to use streaming mode to load and pre-process the data."},
255
+ )
256
+
257
+ @dataclass
258
+ class DataCollatorSpeechSeq2SeqWithPadding:
259
+ """
260
+ Data collator that will dynamically pad the inputs received.
261
+ Args:
262
+ processor ([`WhisperProcessor`])
263
+ The processor used for processing the data.
264
+ decoder_start_token_id (`int`)
265
+ The begin-of-sentence token of the decoder.
266
+ """
267
+
268
+ processor: Any
269
+ decoder_start_token_id: int
270
+
271
+ def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
272
+ # split inputs and labels since they have to be of different lengths and need
273
+ # different padding methods
274
+ model_input_name = self.processor.model_input_names[0]
275
+ input_features = [{model_input_name: feature[model_input_name]} for feature in features]
276
+ label_features = [{"input_ids": feature["labels"]} for feature in features]
277
+
278
+ batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
279
+
280
+ labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
281
+
282
+ # replace padding with -100 to ignore loss correctly
283
+ labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
284
+
285
+ # if bos token is appended in previous tokenization step,
286
+ # cut bos token here as it's appended later anyway
287
+ if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
288
+ labels = labels[:, 1:]
289
+
290
+ batch["labels"] = labels
291
+
292
+ return batch
293
+
294
+
295
+ def load_streaming_dataset(dataset_name, dataset_config_name, split="train", **kwargs):
296
+ """
297
+ Utility function to load a dataset in streaming mode. For datasets with multiple splits,
298
+ each split is loaded individually and then splits combined by taking alternating examples from
299
+ each (interleaving).
300
+ """
301
+ if "+" in split:
302
+ # load multiple splits separated by the `+` symbol with streaming mode
303
+ dataset_splits = [
304
+ load_dataset(dataset_name, dataset_config_name, split=split_name, streaming=True, **kwargs)
305
+ for split_name in split.split("+")
306
+ ]
307
+ # interleave multiple splits to form one dataset
308
+ interleaved_dataset = interleave_datasets(dataset_splits)
309
+ return interleaved_dataset
310
+ else:
311
+ # load a single split *with* streaming mode
312
+ dataset = load_dataset(dataset_name, dataset_config_name, split=split, streaming=True, **kwargs)
313
+ return dataset
314
+
315
+ def load_multiple_streaming_datasets(
316
+ dataset_names: List,
317
+ dataset_config_names: List,
318
+ splits: Optional[List] = None,
319
+ text_column_names: Optional[List] = None,
320
+ audio_column_names: Optional[List] = None,
321
+ sampling_rate: Optional[int] = 16000,
322
+ stopping_strategy: Optional[str] = "all_exhausted",
323
+ streaming = True,
324
+ **kwargs
325
+ ):
326
+
327
+ if len(dataset_names) != len(dataset_config_names):
328
+ raise ValueError(
329
+ f"Ensure one config is passed for each dataset, got {len(dataset_names)} datasets and"
330
+ f" {len(dataset_config_names)} configs."
331
+ )
332
+
333
+ if splits is not None and len(splits) != len(dataset_names):
334
+ raise ValueError(
335
+ f"Ensure one train_split is passed for each dataset, got {len(dataset_names)} datasets and {len(splits)} splits."
336
+ )
337
+
338
+ if text_column_names is not None and len(text_column_names) != len(dataset_names):
339
+ raise ValueError(
340
+ f"Ensure one text column name is passed for each dataset, got {len(dataset_names)} datasets and"
341
+ f" {len(text_column_names)} text column names."
342
+ )
343
+
344
+ if audio_column_names is not None and len(audio_column_names) != len(dataset_names):
345
+ raise ValueError(
346
+ f"Ensure one text column name is passed for each dataset, got {len(dataset_names)} datasets and"
347
+ f" {len(audio_column_names)} text column names."
348
+ )
349
+
350
+ splits = splits if splits is not None \
351
+ else ["train" for i in range(len(dataset_names))]
352
+
353
+ text_column_names = (
354
+ text_column_names if text_column_names is not None \
355
+ else [TEXT_COL_NAME for i in range(len(dataset_names))]
356
+ )
357
+
358
+ audio_column_names = (
359
+ audio_column_names if audio_column_names is not None \
360
+ else [AUDIO_COL_NAME for i in range(len(dataset_names))]
361
+ )
362
+
363
+ all_data_splits = []
364
+ # iterate over the datasets we want to interleave
365
+ for dset, cfgNm, splt, txtColNm, audColNm in zip(dataset_names,dataset_config_names,\
366
+ splits,text_column_names, audio_column_names):
367
+
368
+ dset_splits = [load_dataset(dset, cfgNm, split=c, streaming=streaming, **kwargs) \
369
+ for c in splt.split('+') if c != '-']
370
+
371
+ if streaming:
372
+ dset_splits = [ds if TEXT_COL_NAME in ds.features else ds.rename_column(txtColNm, TEXT_COL_NAME) \
373
+ for ds in dset_splits ]
374
+ dset_splits = [ds if AUDIO_COL_NAME in ds.features else ds.rename_column(audColNm, AUDIO_COL_NAME) \
375
+ for ds in dset_splits]
376
+
377
+ if len(dset_splits)>0 and sampling_rate != next(iter(dset_splits[0]))[AUDIO_COL_NAME]['sampling_rate']:
378
+ dset_splits = [ds.cast_column(AUDIO_COL_NAME, Audio(sampling_rate)) for ds in dset_splits]
379
+ else:
380
+
381
+ dset_splits = [ds if TEXT_COL_NAME in ds.column_names else ds.rename_column(txtColNm, TEXT_COL_NAME) \
382
+ for ds in dset_splits ]
383
+ dset_splits = [ds if AUDIO_COL_NAME in ds.column_names else ds.rename_column(audColNm, AUDIO_COL_NAME) \
384
+ for ds in dset_splits]
385
+
386
+ if len(dset_splits)>0 and sampling_rate != next(iter(dset_splits[0]))[AUDIO_COL_NAME]['sampling_rate']:
387
+ dset_splits = [ds.cast_column(AUDIO_COL_NAME, Audio(sampling_rate)) for ds in dset_splits]
388
+
389
+ cols2keep = set([AUDIO_COL_NAME, TEXT_COL_NAME])
390
+
391
+ dset_splits = [ds.remove_columns(set(ds.features.keys()) - cols2keep) for ds in dset_splits]
392
+
393
+ all_data_splits += dset_splits
394
+
395
+ return interleave_datasets(all_data_splits, stopping_strategy=stopping_strategy)
396
+
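A minimal usage sketch of load_multiple_streaming_datasets as defined above; it assumes access to the two datasets named in run.sh in this commit (Common Voice needs an authenticated token), and is illustrative rather than part of the script:

# Illustrative call: interleave Common Voice 11.0 (el) with FLEURS (el_gr) in streaming mode.
train_set = load_multiple_streaming_datasets(
    dataset_names=["mozilla-foundation/common_voice_11_0", "google/fleurs"],
    dataset_config_names=["el", "el_gr"],
    splits=["train+validation", "train+validation"],
    text_column_names=["sentence", "transcription"],
    sampling_rate=16000,
    streaming=True,
    use_auth_token=True,  # forwarded to load_dataset via **kwargs
)
# Examples from both sources are interleaved until every source is exhausted
# (stopping_strategy="all_exhausted"), with the text and audio columns renamed to the
# shared TEXT_COL_NAME / AUDIO_COL_NAME constants defined at the top of the script.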
397
+ def main():
398
+ # 1. Parse input arguments
399
+ # See all possible arguments in src/transformers/training_args.py
400
+ # or by passing the --help flag to this script.
401
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
402
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
403
+
404
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
405
+ # If we pass only one argument to the script and it's the path to a json file,
406
+ # let's parse it to get our arguments.
407
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
408
+ else:
409
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
410
+
411
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
412
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
413
+ send_example_telemetry("run_speech_recognition_seq2seq_streaming", model_args, data_args)
414
+
415
+ # 2. Setup logging
416
+ logging.basicConfig(
417
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
418
+ datefmt="%m/%d/%Y %H:%M:%S",
419
+ handlers=[logging.StreamHandler(sys.stdout)],
420
+ )
421
+ log_level = training_args.get_process_log_level()
422
+ logger.setLevel(log_level)
423
+ datasets.utils.logging.set_verbosity(log_level)
424
+ transformers.utils.logging.set_verbosity(log_level)
425
+ transformers.utils.logging.enable_default_handler()
426
+ transformers.utils.logging.enable_explicit_format()
427
+
428
+ logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
429
+
430
+ # Log on each process the small summary:
431
+ logger.warning(
432
+ f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
433
+ f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
434
+ )
435
+ logger.info(f"Training/evaluation parameters {training_args}")
436
+
437
+ # Set the verbosity to info of the Transformers logger (on main process only):
438
+ if is_main_process(training_args.local_rank):
439
+ transformers.utils.logging.set_verbosity_info()
440
+ logger.info("Training/evaluation parameters %s", training_args)
441
+
442
+ # 3. Detecting last checkpoint and eventually continue from last checkpoint
443
+ last_checkpoint = None
444
+ if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
445
+ last_checkpoint = get_last_checkpoint(training_args.output_dir)
446
+ if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
447
+ raise ValueError(
448
+ f"Output directory ({training_args.output_dir}) already exists and is not empty. "
449
+ "Use --overwrite_output_dir to overcome."
450
+ )
451
+ elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
452
+ logger.info(
453
+ f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
454
+ "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
455
+ )
456
+
457
+ # Set seed before initializing model.
458
+ set_seed(training_args.seed)
459
+
460
+ # 5. Load pretrained model, tokenizer, and feature extractor
461
+ #
462
+ # Distributed training:
463
+ # The .from_pretrained methods guarantee that only one local process can concurrently
464
+ config = AutoConfig.from_pretrained(
465
+ model_args.config_name if model_args.config_name else model_args.model_name_or_path,
466
+ cache_dir=model_args.cache_dir,
467
+ revision=model_args.model_revision,
468
+ use_auth_token=True if model_args.use_auth_token else None,
469
+ )
470
+
471
+ config.update({ "forced_decoder_ids": model_args.forced_decoder_ids,
472
+ "suppress_tokens": model_args.suppress_tokens})
473
+
474
+ feature_extractor = AutoFeatureExtractor.from_pretrained(
475
+ model_args.feature_extractor_name if model_args.feature_extractor_name else model_args.model_name_or_path,
476
+ cache_dir=model_args.cache_dir,
477
+ revision=model_args.model_revision,
478
+ use_auth_token=True if model_args.use_auth_token else None,
479
+ )
480
+ tokenizer = AutoTokenizer.from_pretrained(
481
+ model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
482
+ cache_dir=model_args.cache_dir,
483
+ use_fast=model_args.use_fast_tokenizer,
484
+ revision=model_args.model_revision,
485
+ use_auth_token=True if model_args.use_auth_token else None,
486
+ )
487
+ model = AutoModelForSpeechSeq2Seq.from_pretrained(
488
+ model_args.model_name_or_path,
489
+ config=config,
490
+ cache_dir=model_args.cache_dir,
491
+ revision=model_args.model_revision,
492
+ use_auth_token=True if model_args.use_auth_token else None,
493
+ )
494
+
495
+ model.config.use_cache = model_args.use_cache
496
+ model.config.dropout = model_args.dropout
497
+ if training_args.gradient_checkpointing:
498
+ model.gradient_checkpointing_enable()
499
+
500
+ if model.config.decoder_start_token_id is None:
501
+ raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
502
+
503
+ # deprecated
504
+ #if model_args.freeze_feature_encoder:
505
+ # model.freeze_feature_encoder()
506
+
507
+ if model_args.freeze_encoder:
508
+ model.freeze_encoder()
509
+ model.model.encoder.gradient_checkpointing = False
510
+
511
+ if data_args.language is not None:
512
+ # We only need to set the task id when the language is specified (i.e. in a multilingual setting)
513
+ tokenizer.set_prefix_tokens(language=data_args.language, task=data_args.task)
514
+
515
+
516
+ # 4. Load dataset
517
+ raw_datasets = IterableDatasetDict() if data_args.streaming else DatasetDict()
518
+
519
+ # if training_args.do_train:
520
+ # raw_datasets["train"] = load_streaming_dataset(
521
+ # data_args.dataset_name,
522
+ # data_args.dataset_config_name,
523
+ # split=data_args.train_split_name,
524
+ # use_auth_token=True if model_args.use_auth_token else None,
525
+ # )
526
+
527
+ # if training_args.do_eval:
528
+ # raw_datasets["eval"] = load_streaming_dataset(
529
+ # data_args.dataset_name,
530
+ # data_args.dataset_config_name,
531
+ # split=data_args.eval_split_name,
532
+ # use_auth_token=True if model_args.use_auth_token else None,
533
+ # )
534
+
535
+ if training_args.do_train:
536
+ raw_datasets["train"] = load_multiple_streaming_datasets(
537
+ dataset_names=data_args.dataset_name.split(","),
538
+ dataset_config_names=data_args.dataset_config_name.split(","),
539
+ splits = data_args.train_split_name.split(","),
540
+ text_column_names = data_args.text_column_name.split(","),
541
+ sampling_rate = feature_extractor.sampling_rate,
542
+ streaming=data_args.streaming,
543
+ use_auth_token=True if model_args.use_auth_token else None,
544
+ )
545
+
546
+ if training_args.do_eval:
547
+ raw_datasets["eval"] = load_multiple_streaming_datasets(
548
+ dataset_names=data_args.dataset_name.split(","),
549
+ dataset_config_names=data_args.dataset_config_name.split(","),
550
+ splits = data_args.eval_split_name.split(","),
551
+ text_column_names = data_args.text_column_name.split(","),
552
+ sampling_rate = feature_extractor.sampling_rate,
553
+ streaming=False,
554
+ use_auth_token=True if model_args.use_auth_token else None,
555
+ )
556
+
557
+ raw_datasets_features = list(next(iter(raw_datasets.values())).features.keys())
558
+
559
+ if AUDIO_COL_NAME not in raw_datasets_features:
560
+ raise ValueError(
561
+ f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'. "
562
+ "Make sure to set `--audio_column_name` to the correct audio column - one of "
563
+ f"{', '.join(raw_datasets_features)}."
564
+ )
565
+
566
+ if TEXT_COL_NAME not in raw_datasets_features:
567
+ raise ValueError(
568
+ f"--text_column_name {TEXT_COL_NAME} not found in dataset. "
569
+ "Make sure to set `--text_column_name` to the the respective correct text columns."
570
+ )
571
+
572
+
573
+ # 6. Resample speech dataset if necessary
574
+ #dataset_sampling_rate = next(iter(raw_datasets.values())).features[data_args.audio_column_name].sampling_rate
575
+ #if dataset_sampling_rate != feature_extractor.sampling_rate:
576
+ # raw_datasets = raw_datasets.cast_column(
577
+ # data_args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
578
+ # )
579
+
580
+ # 7. Preprocessing the datasets.
581
+ # We need to read the audio files as arrays and tokenize the targets.
582
+ max_input_length = data_args.max_duration_in_seconds * feature_extractor.sampling_rate
583
+ min_input_length = data_args.min_duration_in_seconds * feature_extractor.sampling_rate
584
+ audio_column_name = AUDIO_COL_NAME
585
+ text_column_name = TEXT_COL_NAME
586
+ model_input_name = feature_extractor.model_input_names[0]
587
+ do_lower_case = data_args.do_lower_case
588
+ do_remove_punctuation = data_args.do_remove_punctuation
589
+ normalizer = BasicTextNormalizer() # 'official' text normalizer from OpenAI
590
+
591
+ if data_args.max_train_samples is not None:
592
+ raw_datasets["train"] = (
593
+ raw_datasets["train"].take(data_args.max_train_samples)
594
+ if data_args.streaming
595
+ else raw_datasets["train"].select(range(data_args.max_train_samples))
596
+ )
597
+
598
+ if data_args.max_eval_samples is not None:
599
+ raw_datasets["eval"] = (
600
+ raw_datasets["eval"].take(data_args.max_eval_samples)
601
+ if data_args.streaming
602
+ else raw_datasets["eval"].select(range(data_args.max_eval_samples))
603
+ )
604
+
605
+ def prepare_dataset(batch):
606
+ # process audio
607
+ sample = batch[audio_column_name]
608
+ inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])
609
+ # process audio length
610
+ batch[model_input_name] = inputs.get(model_input_name)[0]
611
+ batch["input_length"] = len(sample["array"])
612
+
613
+ # process targets
614
+ input_str = batch[text_column_name].lower() if do_lower_case else batch[text_column_name]
615
+ if do_remove_punctuation:
616
+ input_str = normalizer(input_str).strip()
617
+ batch["labels"] = tokenizer(input_str).input_ids
618
+ return batch
619
+
620
+ with training_args.main_process_first(desc="dataset map pre-processing"):
621
+ vectorized_datasets = raw_datasets.map(
622
+ prepare_dataset,
623
+ remove_columns=raw_datasets_features
624
+ ).with_format("torch")
625
+
626
+ if training_args.do_train and data_args.streaming:
627
+ # manually shuffle if streaming (done by the trainer for non-streaming)
628
+ vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(
629
+ buffer_size=data_args.shuffle_buffer_size,
630
+ seed=training_args.seed,
631
+ )
632
+
633
+ # filter training data that is shorter than min_input_length or longer than
634
+ # max_input_length
635
+ def is_audio_in_length_range(length):
636
+ return min_input_length < length < max_input_length
637
+
638
+ if training_args.do_train:
639
+ vectorized_datasets["train"] = vectorized_datasets["train"].filter(
640
+ is_audio_in_length_range,
641
+ input_columns=["input_length"],
642
+ )
643
+
644
+ # 8. Load Metric
645
+ metric = evaluate.load("wer")
646
+ do_normalize_eval = data_args.do_normalize_eval
647
+
648
+ def compute_metrics(pred):
649
+ pred_ids = pred.predictions
650
+
651
+ pred.label_ids[pred.label_ids == -100] = tokenizer.pad_token_id
652
+
653
+ pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
654
+ # we do not want to group tokens when computing the metrics
655
+ label_str = tokenizer.batch_decode(pred.label_ids, skip_special_tokens=True)
656
+
657
+ if do_normalize_eval:
658
+ pred_str = [normalizer(pred) for pred in pred_str]
659
+ label_str = [normalizer(label) for label in label_str]
660
+ # filtering step to only evaluate the samples that correspond to non-zero references:
661
+ pred_str = [pred_str[i] for i in range(len(pred_str)) if len(label_str[i]) > 0]
662
+ label_str = [label_str[i] for i in range(len(label_str)) if len(label_str[i]) > 0]
663
+
664
+ wer = 100 * metric.compute(predictions=pred_str, references=label_str)
665
+
666
+ return {"wer": wer}
667
+
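For reference, the "wer" metric computed above is the standard word error rate: word-level substitutions, deletions and insertions divided by the number of reference words, reported here as a percentage after normalisation. A tiny self-contained check with made-up Greek strings, using the same evaluate package this script relies on:

# WER = (substitutions + deletions + insertions) / number of reference words
import evaluate

wer_metric = evaluate.load("wer")
refs  = ["το πρωί πήγα στη θάλασσα"]    # 5 reference words
preds = ["το πρωί πήγαμε στη θάλασσα"]  # one word substituted
print(100 * wer_metric.compute(predictions=preds, references=refs))  # -> 20.0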
668
+ # 9. Create a single speech processor
669
+ if is_main_process(training_args.local_rank):
670
+ # save feature extractor, tokenizer and config
671
+ feature_extractor.save_pretrained(training_args.output_dir)
672
+ tokenizer.save_pretrained(training_args.output_dir)
673
+ config.save_pretrained(training_args.output_dir)
674
+
675
+ processor = AutoProcessor.from_pretrained(training_args.output_dir)
676
+
677
+ # 10. Define data collator
678
+ data_collator = DataCollatorSpeechSeq2SeqWithPadding(
679
+ processor=processor,
680
+ decoder_start_token_id=model.config.decoder_start_token_id,
681
+ )
682
+
683
+ # 11. Configure Trainer
684
+ # Trainer callback to reinitialise and reshuffle the streamable datasets at the beginning of each epoch
685
+ # Only required for streaming: Trainer automatically shuffles non-streaming datasets
686
+ class ShuffleCallback(TrainerCallback):
687
+ def on_epoch_begin(self, args, state, control, train_dataloader, **kwargs):
688
+ if isinstance(train_dataloader.dataset, IterableDatasetShard):
689
+ pass # set_epoch() is handled by the Trainer
690
+ elif isinstance(train_dataloader.dataset, IterableDataset):
691
+ train_dataloader.dataset.set_epoch(train_dataloader.dataset._epoch + 1)
692
+
693
+ # Initialize Trainer
694
+ trainer = Seq2SeqTrainer(
695
+ model=model,
696
+ args=training_args,
697
+ train_dataset=vectorized_datasets["train"] if training_args.do_train else None,
698
+ eval_dataset=vectorized_datasets["eval"] if training_args.do_eval else None,
699
+ tokenizer=feature_extractor,
700
+ data_collator=data_collator,
701
+ compute_metrics=compute_metrics if training_args.predict_with_generate else None,
702
+ callbacks=[ShuffleCallback()] if data_args.streaming else None,
703
+ )
704
+
705
+ # 12. Training
706
+ if training_args.do_train:
707
+ checkpoint = None
708
+ if training_args.resume_from_checkpoint is not None:
709
+ checkpoint = training_args.resume_from_checkpoint
710
+ elif last_checkpoint is not None:
711
+ checkpoint = last_checkpoint
712
+ train_result = trainer.train(resume_from_checkpoint=checkpoint)
713
+ trainer.save_model() # Saves the feature extractor too for easy upload
714
+
715
+ metrics = train_result.metrics
716
+ if data_args.max_train_samples:
717
+ metrics["train_samples"] = data_args.max_train_samples
718
+ trainer.log_metrics("train", metrics)
719
+ trainer.save_metrics("train", metrics)
720
+ trainer.save_state()
721
+
722
+ # 13. Evaluation
723
+ results = {}
724
+ if training_args.do_eval:
725
+ logger.info("*** Evaluate ***")
726
+ metrics = trainer.evaluate(
727
+ metric_key_prefix="eval",
728
+ max_length=training_args.generation_max_length,
729
+ num_beams=training_args.generation_num_beams,
730
+ )
731
+ if data_args.max_eval_samples:
732
+ metrics["eval_samples"] = data_args.max_eval_samples
733
+
734
+ trainer.log_metrics("eval", metrics)
735
+ trainer.save_metrics("eval", metrics)
736
+
737
+ # 14. Write Training Stats
738
+ kwargs = {
739
+ "finetuned_from": model_args.model_name_or_path,
740
+ "tasks": "automatic-speech-recognition",
741
+ "tags": "whisper-event",
742
+ }
743
+ if data_args.dataset_name is not None:
744
+ kwargs["dataset_tags"] = data_args.dataset_name
745
+ if data_args.dataset_config_name is not None:
746
+ kwargs["dataset"] = f"{data_args.dataset_name} el"
747
+ else:
748
+ kwargs["dataset"] = data_args.dataset_name
749
+ if "common_voice" in data_args.dataset_name:
750
+ kwargs["language"] = data_args.dataset_config_name[:2]
751
+ if model_args.model_index_name is not None:
752
+ kwargs["model_name"] = model_args.model_index_name
753
+
754
+ trainer.push_to_hub(**kwargs)
755
+
756
+
757
+ if __name__ == "__main__":
758
+ main()
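Editorial note on the training block above (steps 12-14): the checkpoint handed to trainer.train() follows a simple precedence rule, an explicit --resume_from_checkpoint first, otherwise whatever get_last_checkpoint() detected in the output directory. A minimal sketch of that rule, assuming a hypothetical output directory that already contains checkpoint-* folders (not part of the committed script):

    # Minimal sketch, not part of the commit; directory names are hypothetical.
    from transformers.trainer_utils import get_last_checkpoint

    def resolve_checkpoint(resume_from_checkpoint, output_dir):
        # An explicit --resume_from_checkpoint always wins; otherwise fall back to the
        # newest checkpoint-* folder inside output_dir (None if no checkpoint exists).
        if resume_from_checkpoint is not None:
            return resume_from_checkpoint
        return get_last_checkpoint(output_dir)

    # e.g. resolve_checkpoint(None, "./") -> "./checkpoint-4000" once that folder exists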
pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:aa9c1cc306eaa444f99ef35b5e3ffb717f457ea531fd3a8d4727ffdbd4b4cccd
3
  size 3055754841
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9b29d215205456ef40d6ac76daabab944f90a03dcb2153898b46883c2733cf26
3
  size 3055754841
run.sh ADDED
@@ -0,0 +1,39 @@
1
+ python run_speech_recognition_seq2seq_streaming.py \
2
+ --model_name_or_path="emilios/whisper-medium-el" \
3
+ --dataset_name="mozilla-foundation/common_voice_11_0" \
4
+ --dataset_name="mozilla-foundation/common_voice_11_0,google/fleurs" \
5
+ --dataset_config_name="el,el_gr" \
6
+ --language="greek" \
7
+ --train_split_name="train+validation,train+validation" \
8
+ --eval_split_name="test,-" \
9
+ --model_index_name="Whisper Medium El Greco" \
10
+ --text_column_name="sentence,transcription" \
11
+ --audio_column_name="audio,audio" \
12
+ --resume_from_checkpoint="4000" \
13
+ --streaming="False" \
14
+ --max_steps="5000" \
15
+ --max_steps="9000" \
16
+ --output_dir="./" \
17
+ --per_device_train_batch_size="32" \
18
+ --per_device_eval_batch_size="16" \
19
+ --logging_steps="25" \
20
+ --learning_rate="1e-5" \
21
+ --warmup_steps="500" \
22
+ --evaluation_strategy="steps" \
23
+ --eval_steps="1000" \
24
+ --save_strategy="steps" \
25
+ --save_steps="1000" \
26
+ --generation_max_length="225" \
27
+ --length_column_name="input_length" \
28
+ --max_duration_in_seconds="30" \
29
+ --freeze_feature_encoder="False" \
30
+ --report_to="tensorboard" \
31
+ --gradient_checkpointing \
32
+ --fp16 \
33
+ --overwrite_output_dir \
34
+ --do_train="True" \
35
+ --do_eval="True" \
36
+ --predict_with_generate \
37
+ --do_normalize_eval \
38
+ --use_auth_token \
39
+ --push_to_hub
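Editorial note on run.sh: --dataset_name and --max_steps are each passed twice. HfArgumentParser is built on argparse, whose default "store" action keeps the last occurrence of a repeated flag, so the effective values here are the comma-separated dataset pair and 9000 steps. A quick standalone check of that behaviour (illustrative snippet, not part of the commit):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--max_steps", type=int)
    args = parser.parse_args(["--max_steps", "5000", "--max_steps", "9000"])
    print(args.max_steps)  # 9000 -- the last value of a repeated flag wins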
run_eval_whisper_streaming.py ADDED
@@ -0,0 +1,150 @@
1
+ import argparse
2
+
3
+ from transformers import pipeline
4
+ from transformers.models.whisper.english_normalizer import BasicTextNormalizer
5
+ from datasets import load_dataset, Audio
6
+ import evaluate
7
+
8
+ wer_metric = evaluate.load("wer")
9
+
10
+
11
+ def is_target_text_in_range(ref):
12
+ if ref.strip() == "ignore time segment in scoring":
13
+ return False
14
+ else:
15
+ return ref.strip() != ""
16
+
17
+
18
+ def get_text(sample):
19
+ if "text" in sample:
20
+ return sample["text"]
21
+ elif "sentence" in sample:
22
+ return sample["sentence"]
23
+ elif "normalized_text" in sample:
24
+ return sample["normalized_text"]
25
+ elif "transcript" in sample:
26
+ return sample["transcript"]
27
+ elif "transcription" in sample:
28
+ return sample["transcription"]
29
+ else:
30
+ raise ValueError(
31
+ f"Expected transcript column of either 'text', 'sentence', 'normalized_text' or 'transcript'. Got sample of "
32
+ ".join{sample.keys()}. Ensure a text column name is present in the dataset."
33
+ )
34
+
35
+
36
+ whisper_norm = BasicTextNormalizer()
37
+
38
+
39
+ def normalise(batch):
40
+ batch["norm_text"] = whisper_norm(get_text(batch))
41
+ return batch
42
+
43
+
44
+ def data(dataset):
45
+ for i, item in enumerate(dataset):
46
+ yield {**item["audio"], "reference": item["norm_text"]}
47
+
48
+
49
+ def main(args):
50
+ batch_size = args.batch_size
51
+ whisper_asr = pipeline(
52
+ "automatic-speech-recognition", model=args.model_id, device=args.device
53
+ )
54
+
55
+ whisper_asr.model.config.forced_decoder_ids = (
56
+ whisper_asr.tokenizer.get_decoder_prompt_ids(
57
+ language=args.language, task="transcribe"
58
+ )
59
+ )
60
+
61
+ dataset = load_dataset(
62
+ args.dataset,
63
+ args.config,
64
+ split=args.split,
65
+ streaming=args.streaming,
66
+ use_auth_token=True,
67
+ )
68
+
69
+ # Optionally limit the number of evaluation samples (set --max_eval_samples only for debugging)
70
+ dataset = dataset.take(args.max_eval_samples)
71
+
72
+ dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
73
+ dataset = dataset.map(normalise)
74
+ dataset = dataset.filter(is_target_text_in_range, input_columns=["norm_text"])
75
+
76
+ predictions = []
77
+ references = []
78
+
79
+ # run streamed inference
80
+ for out in whisper_asr(data(dataset), batch_size=batch_size):
81
+ predictions.append(whisper_norm(out["text"]))
82
+ references.append(out["reference"][0])
83
+
84
+ wer = wer_metric.compute(references=references, predictions=predictions)
85
+ wer = round(100 * wer, 2)
86
+
87
+ print("WER:", wer)
88
+
89
+
90
+ if __name__ == "__main__":
91
+ parser = argparse.ArgumentParser()
92
+
93
+ parser.add_argument(
94
+ "--model_id",
95
+ type=str,
96
+ required=True,
97
+ help="Model identifier. Should be loadable with 🤗 Transformers",
98
+ )
99
+ parser.add_argument(
100
+ "--dataset",
101
+ type=str,
102
+ default="mozilla-foundation/common_voice_11_0",
103
+ help="Dataset name to evaluate the `model_id`. Should be loadable with 🤗 Datasets",
104
+ )
105
+ parser.add_argument(
106
+ "--config",
107
+ type=str,
108
+ required=True,
109
+ help="Config of the dataset. *E.g.* `'en'` for the English split of Common Voice",
110
+ )
111
+ parser.add_argument(
112
+ "--split",
113
+ type=str,
114
+ default="test",
115
+ help="Split of the dataset. *E.g.* `'test'`",
116
+ )
117
+
118
+ parser.add_argument(
119
+ "--device",
120
+ type=int,
121
+ default=-1,
122
+ help="The device to run the pipeline on. -1 for CPU (default), 0 for the first GPU and so on.",
123
+ )
124
+ parser.add_argument(
125
+ "--batch_size",
126
+ type=int,
127
+ default=16,
128
+ help="Number of samples to go through each streamed batch.",
129
+ )
130
+ parser.add_argument(
131
+ "--max_eval_samples",
132
+ type=int,
133
+ default=None,
134
+ help="Number of samples to be evaluated. Put a lower number e.g. 64 for testing this script.",
135
+ )
136
+ parser.add_argument(
137
+ "--streaming",
138
+ type=bool,
139
+ default=True,
140
+ help="Choose whether you'd like to download the entire dataset or stream it during the evaluation.",
141
+ )
142
+ parser.add_argument(
143
+ "--language",
144
+ type=str,
145
+ required=True,
146
+ help="Two letter language code for the transcription language, e.g. use 'en' for English.",
147
+ )
148
+ args = parser.parse_args()
149
+
150
+ main(args)
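Editorial note on run_eval_whisper_streaming.py: the number printed at the end is evaluate's word error rate, scaled to a percentage and rounded to two decimals. A toy example of the same computation (standalone snippet, not part of the commit):

    import evaluate

    wer_metric = evaluate.load("wer")
    wer = wer_metric.compute(
        references=["the cat sat on the mat"],
        predictions=["the cat sat on mat"],
    )
    print(round(100 * wer, 2))  # 16.67 -- one deleted word out of six reference words

One command-line caveat: because --streaming is declared with type=bool, argparse converts any non-empty string (including "False") to True, so in practice streaming is only disabled by passing an empty string (--streaming "") or by changing the flag to a string/choice argument.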
run_interleave.py CHANGED
@@ -61,7 +61,7 @@ from transformers.models.whisper.english_normalizer import BasicTextNormalizer
61
 
62
  #TEXT_COL_NAME="text"
63
  TEXT_COL_NAME="sentence,transcription"
64
- AUDIO_COL_NAME="audio,audio"
65
 
66
  # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
67
  check_min_version("4.25.0.dev0")
 
61
 
62
  #TEXT_COL_NAME="text"
63
  TEXT_COL_NAME="sentence,transcription"
64
+ AUDIO_COL_NAME="audio"
65
 
66
  # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
67
  check_min_version("4.25.0.dev0")
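Editorial note on the run_interleave.py change above: both Common Voice 11 and FLEURS expose their audio under a column literally named "audio", so a single AUDIO_COL_NAME suffices, while the text columns still differ ("sentence" vs "transcription") and keep the comma-separated pair. interleave_datasets() generally expects the mixed datasets to share the same column layout, which is why the script unifies names before mixing. A tiny self-contained illustration of the interleaving itself (toy data, not part of the commit):

    from datasets import Dataset, interleave_datasets

    cv = Dataset.from_dict({"sentence": ["cv-1", "cv-2"], "audio": ["cv-a1", "cv-a2"]})
    fleurs = Dataset.from_dict({"sentence": ["fl-1", "fl-2"], "audio": ["fl-a1", "fl-a2"]})

    mixed = interleave_datasets([cv, fleurs])  # alternates one example from each source
    print(mixed["sentence"])  # ['cv-1', 'fl-1', 'cv-2', 'fl-2']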
run_speech_recognition_seq2seq_streaming.py ADDED
@@ -0,0 +1,629 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Fine-tuning the library models for sequence to sequence speech recognition
18
+ with 🤗 Datasets' streaming mode.
19
+ """
20
+ # You can also adapt this script for your own sequence to sequence speech
21
+ # recognition task. Pointers for this are left as comments.
22
+
23
+ import logging
24
+ import os
25
+ import sys
26
+ from dataclasses import dataclass, field
27
+ from typing import Any, Dict, List, Optional, Union
28
+
29
+ import datasets
30
+ import torch
31
+ from datasets import DatasetDict, IterableDatasetDict, interleave_datasets, load_dataset
32
+ from torch.utils.data import IterableDataset
33
+
34
+ import evaluate
35
+ import transformers
36
+ from transformers import (
37
+ AutoConfig,
38
+ AutoFeatureExtractor,
39
+ AutoModelForSpeechSeq2Seq,
40
+ AutoProcessor,
41
+ AutoTokenizer,
42
+ HfArgumentParser,
43
+ Seq2SeqTrainer,
44
+ Seq2SeqTrainingArguments,
45
+ TrainerCallback,
46
+ set_seed,
47
+ )
48
+ from transformers.models.whisper.english_normalizer import BasicTextNormalizer
49
+ from transformers.trainer_pt_utils import IterableDatasetShard
50
+ from transformers.trainer_utils import get_last_checkpoint, is_main_process
51
+ from transformers.utils import check_min_version, send_example_telemetry
52
+ from transformers.utils.versions import require_version
53
+
54
+
55
+ # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
56
+ check_min_version("4.25.0.dev0")
57
+
58
+ require_version("datasets>=1.18.2", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")
59
+
60
+ logger = logging.getLogger(__name__)
61
+
62
+
63
+ @dataclass
64
+ class ModelArguments:
65
+ """
66
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
67
+ """
68
+
69
+ model_name_or_path: str = field(
70
+ metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
71
+ )
72
+ config_name: Optional[str] = field(
73
+ default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
74
+ )
75
+ tokenizer_name: Optional[str] = field(
76
+ default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
77
+ )
78
+ feature_extractor_name: Optional[str] = field(
79
+ default=None, metadata={"help": "feature extractor name or path if not the same as model_name"}
80
+ )
81
+ cache_dir: Optional[str] = field(
82
+ default=None,
83
+ metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
84
+ )
85
+ use_fast_tokenizer: bool = field(
86
+ default=True,
87
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
88
+ )
89
+ model_revision: str = field(
90
+ default="main",
91
+ metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
92
+ )
93
+ use_auth_token: bool = field(
94
+ default=False,
95
+ metadata={
96
+ "help": (
97
+ "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
98
+ "with private models)."
99
+ )
100
+ },
101
+ )
102
+ freeze_feature_encoder: bool = field(
103
+ default=True, metadata={"help": "Whether to freeze the feature encoder layers of the model."}
104
+ )
105
+ freeze_encoder: bool = field(
106
+ default=False, metadata={"help": "Whether to freeze the entire encoder of the seq2seq model."}
107
+ )
108
+ forced_decoder_ids: List[List[int]] = field(
109
+ default=None,
110
+ metadata={
111
+ "help": (
112
+ "A list of pairs of integers which indicates a mapping from generation indices to token indices "
113
+ "that will be forced before sampling. For example, [[0, 123]] means the first generated token "
114
+ "will always be a token of index 123."
115
+ )
116
+ },
117
+ )
118
+ suppress_tokens: List[int] = field(
119
+ default=None, metadata={"help": "A list of tokens that will be suppressed at generation."}
120
+ )
121
+ model_index_name: str = field(default=None, metadata={"help": "Pretty name for the model card."})
122
+
123
+
124
+ @dataclass
125
+ class DataTrainingArguments:
126
+ """
127
+ Arguments pertaining to what data we are going to input our model for training and eval.
128
+ """
129
+
130
+ dataset_name: str = field(
131
+ default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
132
+ )
133
+ dataset_config_name: Optional[str] = field(
134
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
135
+ )
136
+ text_column: Optional[str] = field(
137
+ default=None,
138
+ metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."},
139
+ )
140
+ max_train_samples: Optional[int] = field(
141
+ default=None,
142
+ metadata={
143
+ "help": (
144
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
145
+ "value if set."
146
+ )
147
+ },
148
+ )
149
+ max_eval_samples: Optional[int] = field(
150
+ default=None,
151
+ metadata={
152
+ "help": (
153
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
154
+ "value if set."
155
+ )
156
+ },
157
+ )
158
+ audio_column_name: str = field(
159
+ default="audio",
160
+ metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"},
161
+ )
162
+ text_column_name: str = field(
163
+ default="text",
164
+ metadata={"help": "The name of the dataset column containing the text data. Defaults to 'text'"},
165
+ )
166
+ max_duration_in_seconds: float = field(
167
+ default=20.0,
168
+ metadata={
169
+ "help": (
170
+ "Truncate audio files that are longer than `max_duration_in_seconds` seconds to"
171
+ " 'max_duration_in_seconds`"
172
+ )
173
+ },
174
+ )
175
+ min_duration_in_seconds: float = field(
176
+ default=0.0, metadata={"help": "Filter audio files that are shorter than `min_duration_in_seconds` seconds"}
177
+ )
178
+ train_split_name: str = field(
179
+ default="train",
180
+ metadata={
181
+ "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
182
+ },
183
+ )
184
+ eval_split_name: str = field(
185
+ default="test",
186
+ metadata={
187
+ "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
188
+ },
189
+ )
190
+ do_lower_case: bool = field(
191
+ default=False,
192
+ metadata={"help": "Whether the target text should be lower cased."},
193
+ )
194
+ do_remove_punctuation: bool = field(
195
+ default=False,
196
+ metadata={"help": "Whether the target text should be striped of punctuation."},
197
+ )
198
+ do_normalize_eval: bool = field(
199
+ default=True,
200
+ metadata={"help": "Whether to normalise the references and predictions in the eval WER calculation."},
201
+ )
202
+ language: str = field(
203
+ default=None,
204
+ metadata={
205
+ "help": (
206
+ "Language for multilingual fine-tuning. This argument should be set for multilingual fine-tuning "
207
+ "only. For English speech recognition, it should be set to `None`."
208
+ )
209
+ },
210
+ )
211
+ task: str = field(
212
+ default="transcribe",
213
+ metadata={"help": "Task, either `transcribe` for speech recognition or `translate` for speech translation."},
214
+ )
215
+ shuffle_buffer_size: Optional[int] = field(
216
+ default=500,
217
+ metadata={
218
+ "help": (
219
+ "The number of streamed examples to download before shuffling them. The large the buffer, "
220
+ "the closer it is to real offline shuffling."
221
+ )
222
+ },
223
+ )
224
+ streaming: bool = field(
225
+ default=True,
226
+ metadata={"help": "Whether to use streaming mode to load and pre-process the data."},
227
+ )
228
+
229
+
230
+ @dataclass
231
+ class DataCollatorSpeechSeq2SeqWithPadding:
232
+ """
233
+ Data collator that will dynamically pad the inputs received.
234
+ Args:
235
+ processor ([`WhisperProcessor`])
236
+ The processor used for processing the data.
237
+ decoder_start_token_id (`int`)
238
+ The begin-of-sentence of the decoder.
239
+ """
240
+
241
+ processor: Any
242
+ decoder_start_token_id: int
243
+
244
+ def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
245
+ # split inputs and labels since they have to be of different lengths and need
246
+ # different padding methods
247
+ model_input_name = self.processor.model_input_names[0]
248
+ input_features = [{model_input_name: feature[model_input_name]} for feature in features]
249
+ label_features = [{"input_ids": feature["labels"]} for feature in features]
250
+
251
+ batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
252
+
253
+ labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
254
+
255
+ # replace padding with -100 to ignore loss correctly
256
+ labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
257
+
258
+ # if bos token is appended in previous tokenization step,
259
+ # cut bos token here as it's append later anyways
260
+ if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
261
+ labels = labels[:, 1:]
262
+
263
+ batch["labels"] = labels
264
+
265
+ return batch
266
+
267
+
268
+ def load_maybe_streaming_dataset(dataset_name, dataset_config_name, split="train", streaming=True, **kwargs):
269
+ """
270
+ Utility function to load a dataset in streaming mode. For datasets with multiple splits,
271
+ each split is loaded individually and then splits combined by taking alternating examples from
272
+ each (interleaving).
273
+ """
274
+ if "+" in split:
275
+ # load multiple splits separated by the `+` symbol with streaming mode
276
+ dataset_splits = [
277
+ load_dataset(dataset_name, dataset_config_name, split=split_name, streaming=streaming, **kwargs)
278
+ for split_name in split.split("+")
279
+ ]
280
+ # interleave multiple splits to form one dataset
281
+ interleaved_dataset = interleave_datasets(dataset_splits)
282
+ return interleaved_dataset
283
+ else:
284
+ # load a single split *with* streaming mode
285
+ dataset = load_dataset(dataset_name, dataset_config_name, split=split, streaming=streaming, **kwargs)
286
+ return dataset
287
+
288
+
289
+ def main():
290
+ # 1. Parse input arguments
291
+ # See all possible arguments in src/transformers/training_args.py
292
+ # or by passing the --help flag to this script.
293
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
294
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
295
+
296
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
297
+ # If we pass only one argument to the script and it's the path to a json file,
298
+ # let's parse it to get our arguments.
299
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
300
+ else:
301
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
302
+
303
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
304
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
305
+ send_example_telemetry("run_speech_recognition_seq2seq_streaming", model_args, data_args)
306
+
307
+ # 2. Setup logging
308
+ logging.basicConfig(
309
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
310
+ datefmt="%m/%d/%Y %H:%M:%S",
311
+ handlers=[logging.StreamHandler(sys.stdout)],
312
+ )
313
+ log_level = training_args.get_process_log_level()
314
+ logger.setLevel(log_level)
315
+ datasets.utils.logging.set_verbosity(log_level)
316
+ transformers.utils.logging.set_verbosity(log_level)
317
+ transformers.utils.logging.enable_default_handler()
318
+ transformers.utils.logging.enable_explicit_format()
319
+
320
+ logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
321
+
322
+ # Log on each process the small summary:
323
+ logger.warning(
324
+ f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
325
+ f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
326
+ )
327
+ logger.info(f"Training/evaluation parameters {training_args}")
328
+
329
+ # Set the verbosity to info of the Transformers logger (on main process only):
330
+ if is_main_process(training_args.local_rank):
331
+ transformers.utils.logging.set_verbosity_info()
332
+ logger.info("Training/evaluation parameters %s", training_args)
333
+
334
+ # 3. Detecting last checkpoint and eventually continue from last checkpoint
335
+ last_checkpoint = None
336
+ if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
337
+ last_checkpoint = get_last_checkpoint(training_args.output_dir)
338
+ if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
339
+ raise ValueError(
340
+ f"Output directory ({training_args.output_dir}) already exists and is not empty. "
341
+ "Use --overwrite_output_dir to overcome."
342
+ )
343
+ elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
344
+ logger.info(
345
+ f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
346
+ "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
347
+ )
348
+
349
+ # Set seed before initializing model.
350
+ set_seed(training_args.seed)
351
+
352
+ # 4. Load dataset
353
+ raw_datasets = IterableDatasetDict() if data_args.streaming else DatasetDict()
354
+
355
+ if training_args.do_train:
356
+ raw_datasets["train"] = load_maybe_streaming_dataset(
357
+ data_args.dataset_name,
358
+ data_args.dataset_config_name,
359
+ split=data_args.train_split_name,
360
+ use_auth_token=True if model_args.use_auth_token else None,
361
+ streaming=data_args.streaming,
362
+ )
363
+
364
+ if training_args.do_eval:
365
+ raw_datasets["eval"] = load_maybe_streaming_dataset(
366
+ data_args.dataset_name,
367
+ data_args.dataset_config_name,
368
+ split=data_args.eval_split_name,
369
+ use_auth_token=True if model_args.use_auth_token else None,
370
+ streaming=data_args.streaming,
371
+ )
372
+
373
+ raw_datasets_features = list(next(iter(raw_datasets.values())).features.keys())
374
+
375
+ if data_args.audio_column_name not in raw_datasets_features:
376
+ raise ValueError(
377
+ f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'. "
378
+ "Make sure to set `--audio_column_name` to the correct audio column - one of "
379
+ f"{', '.join(raw_datasets_features)}."
380
+ )
381
+
382
+ if data_args.text_column_name not in raw_datasets_features:
383
+ raise ValueError(
384
+ f"--text_column_name {data_args.text_column_name} not found in dataset '{data_args.dataset_name}'. "
385
+ "Make sure to set `--text_column_name` to the correct text column - one of "
386
+ f"{', '.join(raw_datasets_features)}."
387
+ )
388
+
389
+ # 5. Load pretrained model, tokenizer, and feature extractor
390
+ #
391
+ # Distributed training:
392
+ # The .from_pretrained methods guarantee that only one local process can concurrently
393
+ config = AutoConfig.from_pretrained(
394
+ model_args.config_name if model_args.config_name else model_args.model_name_or_path,
395
+ cache_dir=model_args.cache_dir,
396
+ revision=model_args.model_revision,
397
+ use_auth_token=True if model_args.use_auth_token else None,
398
+ )
399
+
400
+ config.update({"forced_decoder_ids": model_args.forced_decoder_ids, "suppress_tokens": model_args.suppress_tokens})
401
+
402
+ if training_args.gradient_checkpointing:
403
+ config.update({"use_cache": False})
404
+
405
+ feature_extractor = AutoFeatureExtractor.from_pretrained(
406
+ model_args.feature_extractor_name if model_args.feature_extractor_name else model_args.model_name_or_path,
407
+ cache_dir=model_args.cache_dir,
408
+ revision=model_args.model_revision,
409
+ use_auth_token=True if model_args.use_auth_token else None,
410
+ )
411
+ tokenizer = AutoTokenizer.from_pretrained(
412
+ model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
413
+ cache_dir=model_args.cache_dir,
414
+ use_fast=model_args.use_fast_tokenizer,
415
+ revision=model_args.model_revision,
416
+ use_auth_token=True if model_args.use_auth_token else None,
417
+ )
418
+ model = AutoModelForSpeechSeq2Seq.from_pretrained(
419
+ model_args.model_name_or_path,
420
+ config=config,
421
+ cache_dir=model_args.cache_dir,
422
+ revision=model_args.model_revision,
423
+ use_auth_token=True if model_args.use_auth_token else None,
424
+ )
425
+
426
+ if model.config.decoder_start_token_id is None:
427
+ raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
428
+
429
+ if model_args.freeze_feature_encoder:
430
+ model.freeze_feature_encoder()
431
+
432
+ if model_args.freeze_encoder:
433
+ model.freeze_encoder()
434
+
435
+ if data_args.language is not None:
436
+ # We only need to set the task id when the language is specified (i.e. in a multilingual setting)
437
+ tokenizer.set_prefix_tokens(language=data_args.language, task=data_args.task)
438
+
439
+ # 6. Resample speech dataset if necessary
440
+ dataset_sampling_rate = next(iter(raw_datasets.values())).features[data_args.audio_column_name].sampling_rate
441
+ if dataset_sampling_rate != feature_extractor.sampling_rate:
442
+ raw_datasets = raw_datasets.cast_column(
443
+ data_args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
444
+ )
445
+
446
+ # 7. Preprocessing the datasets.
447
+ # We need to read the audio files as arrays and tokenize the targets.
448
+ max_input_length = data_args.max_duration_in_seconds * feature_extractor.sampling_rate
449
+ min_input_length = data_args.min_duration_in_seconds * feature_extractor.sampling_rate
450
+ audio_column_name = data_args.audio_column_name
451
+ text_column_name = data_args.text_column_name
452
+ model_input_name = feature_extractor.model_input_names[0]
453
+ do_lower_case = data_args.do_lower_case
454
+ do_remove_punctuation = data_args.do_remove_punctuation
455
+ normalizer = BasicTextNormalizer() # 'official' text normalizer from OpenAI
456
+
457
+ if data_args.max_train_samples is not None:
458
+ raw_datasets["train"] = (
459
+ raw_datasets["train"].take(data_args.max_train_samples)
460
+ if data_args.streaming
461
+ else raw_datasets["train"].select(range(data_args.max_train_samples))
462
+ )
463
+
464
+ if data_args.max_eval_samples is not None:
465
+ raw_datasets["eval"] = (
466
+ raw_datasets["eval"].take(data_args.max_eval_samples)
467
+ if data_args.streaming
468
+ else raw_datasets["eval"].select(range(data_args.max_eval_samples))
469
+ )
470
+
471
+ def prepare_dataset(batch):
472
+ # process audio
473
+ sample = batch[audio_column_name]
474
+ inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])
475
+ # process audio length
476
+ batch[model_input_name] = inputs.get(model_input_name)[0]
477
+ batch["input_length"] = len(sample["array"])
478
+
479
+ # process targets
480
+ input_str = batch[text_column_name].lower() if do_lower_case else batch[text_column_name]
481
+ if do_remove_punctuation:
482
+ input_str = normalizer(input_str).strip()
483
+ batch["labels"] = tokenizer(input_str).input_ids
484
+ return batch
485
+
486
+ with training_args.main_process_first(desc="dataset map pre-processing"):
487
+ vectorized_datasets = raw_datasets.map(
488
+ prepare_dataset,
489
+ remove_columns=raw_datasets_features,
490
+ ).with_format("torch")
491
+
492
+ if training_args.do_train and data_args.streaming:
493
+ # manually shuffle if streaming (done by the trainer for non-streaming)
494
+ vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(
495
+ buffer_size=data_args.shuffle_buffer_size,
496
+ seed=training_args.seed,
497
+ )
498
+
499
+ # filter training data that is shorter than min_input_length or longer than
500
+ # max_input_length
501
+ def is_audio_in_length_range(length):
502
+ return min_input_length < length < max_input_length
503
+
504
+ if training_args.do_train:
505
+ vectorized_datasets["train"] = vectorized_datasets["train"].filter(
506
+ is_audio_in_length_range,
507
+ input_columns=["input_length"],
508
+ )
509
+
510
+ # 8. Load Metric
511
+ metric = evaluate.load("wer")
512
+ do_normalize_eval = data_args.do_normalize_eval
513
+
514
+ def compute_metrics(pred):
515
+ pred_ids = pred.predictions
516
+
517
+ pred.label_ids[pred.label_ids == -100] = tokenizer.pad_token_id
518
+
519
+ pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
520
+ # we do not want to group tokens when computing the metrics
521
+ label_str = tokenizer.batch_decode(pred.label_ids, skip_special_tokens=True)
522
+
523
+ if do_normalize_eval:
524
+ pred_str = [normalizer(pred) for pred in pred_str]
525
+ label_str = [normalizer(label) for label in label_str]
526
+ # filtering step to only evaluate the samples that correspond to non-zero references:
527
+ pred_str = [pred_str[i] for i in range(len(pred_str)) if len(label_str[i]) > 0]
528
+ label_str = [label_str[i] for i in range(len(label_str)) if len(label_str[i]) > 0]
529
+
530
+ wer = 100 * metric.compute(predictions=pred_str, references=label_str)
531
+
532
+ return {"wer": wer}
533
+
534
+ # 9. Create a single speech processor
535
+ if is_main_process(training_args.local_rank):
536
+ # save feature extractor, tokenizer and config
537
+ feature_extractor.save_pretrained(training_args.output_dir)
538
+ tokenizer.save_pretrained(training_args.output_dir)
539
+ config.save_pretrained(training_args.output_dir)
540
+
541
+ processor = AutoProcessor.from_pretrained(training_args.output_dir)
542
+
543
+ # 10. Define data collator
544
+ data_collator = DataCollatorSpeechSeq2SeqWithPadding(
545
+ processor=processor,
546
+ decoder_start_token_id=model.config.decoder_start_token_id,
547
+ )
548
+
549
+ # 11. Configure Trainer
550
+ # Trainer callback to reinitialise and reshuffle the streamable datasets at the beginning of each epoch
551
+ # Only required for streaming: Trainer automatically shuffles non-streaming datasets
552
+ class ShuffleCallback(TrainerCallback):
553
+ def on_epoch_begin(self, args, state, control, train_dataloader, **kwargs):
554
+ if isinstance(train_dataloader.dataset, IterableDatasetShard):
555
+ pass # set_epoch() is handled by the Trainer
556
+ elif isinstance(train_dataloader.dataset, IterableDataset):
557
+ train_dataloader.dataset.set_epoch(train_dataloader.dataset._epoch + 1)
558
+
559
+ # Initialize Trainer
560
+ trainer = Seq2SeqTrainer(
561
+ model=model,
562
+ args=training_args,
563
+ train_dataset=vectorized_datasets["train"] if training_args.do_train else None,
564
+ eval_dataset=vectorized_datasets["eval"] if training_args.do_eval else None,
565
+ tokenizer=feature_extractor,
566
+ data_collator=data_collator,
567
+ compute_metrics=compute_metrics if training_args.predict_with_generate else None,
568
+ callbacks=[ShuffleCallback()] if data_args.streaming else None,
569
+ )
570
+
571
+ # 12. Training
572
+ if training_args.do_train:
573
+ checkpoint = None
574
+ if training_args.resume_from_checkpoint is not None:
575
+ checkpoint = training_args.resume_from_checkpoint
576
+ elif last_checkpoint is not None:
577
+ checkpoint = last_checkpoint
578
+ train_result = trainer.train(resume_from_checkpoint=checkpoint)
579
+ trainer.save_model() # Saves the feature extractor too for easy upload
580
+
581
+ metrics = train_result.metrics
582
+ if data_args.max_train_samples:
583
+ metrics["train_samples"] = data_args.max_train_samples
584
+ trainer.log_metrics("train", metrics)
585
+ trainer.save_metrics("train", metrics)
586
+ trainer.save_state()
587
+
588
+ # 13. Evaluation
589
+ results = {}
590
+ if training_args.do_eval:
591
+ logger.info("*** Evaluate ***")
592
+ metrics = trainer.evaluate(
593
+ metric_key_prefix="eval",
594
+ max_length=training_args.generation_max_length,
595
+ num_beams=training_args.generation_num_beams,
596
+ )
597
+ if data_args.max_eval_samples:
598
+ metrics["eval_samples"] = data_args.max_eval_samples
599
+
600
+ trainer.log_metrics("eval", metrics)
601
+ trainer.save_metrics("eval", metrics)
602
+
603
+ # 14. Write Training Stats
604
+ kwargs = {
605
+ "finetuned_from": model_args.model_name_or_path,
606
+ "tasks": "automatic-speech-recognition",
607
+ "tags": "whisper-event",
608
+ }
609
+ if data_args.dataset_name is not None:
610
+ kwargs["dataset_tags"] = data_args.dataset_name
611
+ if data_args.dataset_config_name is not None:
612
+ kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}"
613
+ else:
614
+ kwargs["dataset"] = data_args.dataset_name
615
+ if "common_voice" in data_args.dataset_name:
616
+ kwargs["language"] = data_args.dataset_config_name[:2]
617
+ if model_args.model_index_name is not None:
618
+ kwargs["model_name"] = model_args.model_index_name
619
+
620
+ if training_args.push_to_hub:
621
+ trainer.push_to_hub(**kwargs)
622
+ else:
623
+ trainer.create_model_card(**kwargs)
624
+
625
+ return results
626
+
627
+
628
+ if __name__ == "__main__":
629
+ main()
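Editorial note on DataCollatorSpeechSeq2SeqWithPadding above: the key step is replacing padded label positions with -100 so the cross-entropy loss ignores them. A standalone illustration with toy token ids (the ids are made up; only 50257 mirrors the pad_token_id from config.json):

    import torch

    labels = torch.tensor([[50258, 291, 13, 50257, 50257]])   # toy ids; 50257 = pad
    attention_mask = torch.tensor([[1, 1, 1, 0, 0]])

    masked = labels.masked_fill(attention_mask.ne(1), -100)   # same call as in the collator
    print(masked)  # tensor([[50258, 291, 13, -100, -100]])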
run_whisper-md-el-intlv-xs.sh CHANGED
@@ -29,6 +29,7 @@ python run_interleave.py \
29
  --dropout 0.1 \
30
  --warmup_steps 500 \
31
  --max_steps 5000 \
 
32
  --eval_steps 1000 \
33
  --gradient_checkpointing True \
34
  --cache_dir '~/.cache' \
@@ -42,8 +43,7 @@ python run_interleave.py \
42
  --load_best_model_at_end True \
43
  --metric_for_best_model wer \
44
  --greater_is_better False \
45
- --push_to_hub False
46
 
47
 
48
- #
49
 
 
29
  --dropout 0.1 \
30
  --warmup_steps 500 \
31
  --max_steps 5000 \
32
+ --resume_from_checkpoint="4000" \
33
  --eval_steps 1000 \
34
  --gradient_checkpointing True \
35
  --cache_dir '~/.cache' \
 
43
  --load_best_model_at_end True \
44
  --metric_for_best_model wer \
45
  --greater_is_better False \
46
+ --push_to_hub True
47
 
48
 
 
49
 
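Editorial note on the shell-script change above: with --load_best_model_at_end True, --metric_for_best_model wer and --greater_is_better False, the Trainer reloads the checkpoint with the lowest eval WER before the final save/push. The selection rule in spirit (hypothetical WER values, not taken from this run):

    # Hypothetical eval results -- illustrates the selection rule only.
    eval_wer = {"checkpoint-3000": 21.4, "checkpoint-4000": 19.8, "checkpoint-5000": 20.3}

    best = min(eval_wer, key=eval_wer.get)  # lowest WER wins because greater_is_better=False
    print(best)  # checkpoint-4000

Note also that the Trainer interprets a string --resume_from_checkpoint as a path to a checkpoint folder, so the usual form is the folder the Trainer itself writes (e.g. "checkpoint-4000"); a bare "4000" is resolved relative to the working directory.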
runs/Dec13_00-20-31_150-136-33-0/1670891059.409924/events.out.tfevents.1670891059.150-136-33-0.3897722.1 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dd75ccf3b1bd98778e44cae4793e895aeec95a0d406fd4842436a2557e542e15
3
+ size 5917
runs/Dec13_00-20-31_150-136-33-0/events.out.tfevents.1670891059.150-136-33-0.3897722.0 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c9c9398f6ed427bd41ff5741f352098952e1b913961b916e2d365d493d692c78
3
+ size 4344
runs/Dec13_00-26-09_150-136-33-0/1670891185.5348835/events.out.tfevents.1670891185.150-136-33-0.3897722.3 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:35265950f470f0bb0bc82f0f6bb1ea4a83bc795bf5d255645f4f880d140d506a
3
+ size 5917
runs/Dec13_00-26-09_150-136-33-0/events.out.tfevents.1670891185.150-136-33-0.3897722.2 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:94d957e6f71dfd826ff305f81cafa3d73556930b286f440116360b825b11332f
3
+ size 4343
runs/Dec13_00-29-21_150-136-33-0/1670891404.5181074/events.out.tfevents.1670891404.150-136-33-0.127865.1 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2de5049f3bf41fae21055257bab18c7324d889c6f89810468c3ba3c622d9eb3b
3
+ size 5917
runs/Dec13_00-29-21_150-136-33-0/1670891530.7157552/events.out.tfevents.1670891530.150-136-33-0.127865.2 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:367be3341ff0617872ec5a6d9835404033aeb617bd43a6c309b3eaf37b96a8c1
3
+ size 5917
runs/Dec13_00-29-21_150-136-33-0/events.out.tfevents.1670891404.150-136-33-0.127865.0 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:49d3593a4263cbf140645c09337c65104f6b6257df0e3fdff3f10a4027c867c4
3
+ size 9354
runs/Dec13_00-33-53_150-136-33-0/1670891644.5885968/events.out.tfevents.1670891644.150-136-33-0.127865.4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:33cf05b8a4dd04ead085260d1bbb46d5e480d28366bcf7be1970bc9b5e4eced2
3
+ size 5917
runs/Dec13_00-33-53_150-136-33-0/events.out.tfevents.1670891644.150-136-33-0.127865.3 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0417f0824bb1e400e36093b415ea327c75a275f99909bd52a68f2639b4a2d2db
3
+ size 4343
runs/Dec13_00-36-22_150-136-33-0/1670891797.5398705/events.out.tfevents.1670891797.150-136-33-0.152213.1 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:707fb35872449414d5446b923f39ed46d3efaf89754df4c9cdff9d2ad7e27dc3
3
+ size 5917
runs/Dec13_00-36-22_150-136-33-0/events.out.tfevents.1670891797.150-136-33-0.152213.0 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f972921d72e7ec59f8f92edf0ffa2b0b0a0b3e0f1000b294806893018aba7201
3
+ size 10941
training_args.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:a3ebd5ec6b9b2945948080568a87ff00951d03abd4956eab0351149cc37e37fe
3
- size 3579
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dccf5c02020cb4db0b8e5e421cc7fbb47dfb469165fb362c0937177d17244b57
3
+ size 3643