kimbochen committed on
Commit 9fcc5b4
1 Parent(s): 1925aa2

Training in progress, step 200

Files changed (25)
  1. .ipynb_checkpoints/fine-tune-whisper-streaming-checkpoint.ipynb +156 -368
  2. config.json +0 -1
  3. fine-tune-whisper-streaming.ipynb +115 -694
  4. pytorch_model.bin +1 -1
  5. run.sh +38 -0
  6. run_speech_recognition_seq2seq_streaming.py +644 -0
  7. runs/Dec12_18-34-55_129-213-131-105/1670870162.6479564/events.out.tfevents.1670870162.129-213-131-105.68826.1 +3 -0
  8. runs/Dec12_18-34-55_129-213-131-105/events.out.tfevents.1670870162.129-213-131-105.68826.0 +3 -0
  9. runs/Dec12_19-04-31_129-213-131-105/1670871909.2493246/events.out.tfevents.1670871909.129-213-131-105.451160.1 +3 -0
  10. runs/Dec12_19-04-31_129-213-131-105/events.out.tfevents.1670871909.129-213-131-105.451160.0 +3 -0
  11. runs/Dec12_20-09-15_129-213-131-105/1670875765.5760763/events.out.tfevents.1670875765.129-213-131-105.451160.3 +3 -0
  12. runs/Dec12_20-09-15_129-213-131-105/events.out.tfevents.1670875765.129-213-131-105.451160.2 +3 -0
  13. runs/Dec12_20-11-02_129-213-131-105/1670875868.8091414/events.out.tfevents.1670875868.129-213-131-105.451160.5 +3 -0
  14. runs/Dec12_20-11-02_129-213-131-105/events.out.tfevents.1670875868.129-213-131-105.451160.4 +3 -0
  15. runs/Dec12_20-13-20_129-213-131-105/1670876009.054387/events.out.tfevents.1670876009.129-213-131-105.983201.1 +3 -0
  16. runs/Dec12_20-13-20_129-213-131-105/events.out.tfevents.1670876009.129-213-131-105.983201.0 +3 -0
  17. runs/Dec12_21-41-07_129-213-131-105/1670881275.6468236/events.out.tfevents.1670881275.129-213-131-105.1284650.1 +3 -0
  18. runs/Dec12_21-41-07_129-213-131-105/events.out.tfevents.1670881275.129-213-131-105.1284650.0 +3 -0
  19. runs/Dec12_21-43-12_129-213-131-105/1670881400.3312242/events.out.tfevents.1670881400.129-213-131-105.1319036.1 +3 -0
  20. runs/Dec12_21-43-12_129-213-131-105/events.out.tfevents.1670881400.129-213-131-105.1319036.0 +3 -0
  21. runs/Dec12_21-47-11_129-213-131-105/1670881639.3589363/events.out.tfevents.1670881639.129-213-131-105.1364959.1 +3 -0
  22. runs/Dec12_21-47-11_129-213-131-105/events.out.tfevents.1670881639.129-213-131-105.1364959.0 +3 -0
  23. runs/Dec12_21-54-54_129-213-131-105/1670882102.7244208/events.out.tfevents.1670882102.129-213-131-105.1405782.1 +3 -0
  24. runs/Dec12_21-54-54_129-213-131-105/events.out.tfevents.1670882102.129-213-131-105.1405782.0 +3 -0
  25. training_args.bin +1 -1
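For orientation, the substantive change in the notebook diffs below is the label-preparation step: before tokenization, each transcription is now converted from full-width to half-width characters and word-segmented with fugashi ('-Owakati'), so the labels and the WER metric operate on space-separated Japanese words. A minimal, self-contained sketch of that pipeline, assembled from the diff hunks below; the openai/whisper-small checkpoint and the language/task settings are assumptions inferred from the repo name, not stated in this commit:

    from fugashi import Tagger
    from transformers import WhisperProcessor

    # Full-width ASCII (U+FF01..U+FF5E) -> half-width, ideographic space -> ASCII space,
    # exactly as defined in the notebook.
    FULL2HALF = dict((i + 0xFEE0, i) for i in range(0x21, 0x7F))
    FULL2HALF[0x3000] = 0x20

    # Wakati-gaki mode: MeCab emits space-separated words.
    tagger = Tagger('-Owakati')

    # Assumed base checkpoint; the notebook builds `processor` from its own model id.
    processor = WhisperProcessor.from_pretrained(
        "openai/whisper-small", language="japanese", task="transcribe"
    )

    transcription = "麩菓子は、麩を主材料とした日本の菓子。"  # sample sentence from the notebook
    text = transcription.translate(FULL2HALF)     # normalize full-width characters
    text = tagger.parse(text)                     # -> "麩 菓子 は 、 麩 を 主材 料 と し た 日本 の 菓子 。"
    labels = processor.tokenizer(text).input_ids  # label ids fed to Whisper during fine-tuning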
.ipynb_checkpoints/fine-tune-whisper-streaming-checkpoint.ipynb CHANGED
@@ -19,7 +19,10 @@
19
  {
20
  "cell_type": "markdown",
21
  "id": "afe0d503-ae4e-4aa7-9af4-dbcba52db41e",
22
- "metadata": {},
 
 
 
23
  "source": [
24
  "## Introduction"
25
  ]
@@ -108,10 +111,19 @@
108
  },
109
  {
110
  "cell_type": "code",
111
- "execution_count": 5,
112
  "id": "065a8cf7-e54f-4ac3-900e-609c80714fca",
113
  "metadata": {},
114
- "outputs": [],
 
 
 
 
 
 
 
 
 
115
  "source": [
116
  "from datasets import interleave_datasets, load_dataset\n",
117
  "\n",
@@ -142,7 +154,7 @@
142
  },
143
  {
144
  "cell_type": "code",
145
- "execution_count": 6,
146
  "id": "a2787582-554f-44ce-9f38-4180a5ed6b44",
147
  "metadata": {},
148
  "outputs": [],
@@ -185,7 +197,7 @@
185
  },
186
  {
187
  "cell_type": "code",
188
- "execution_count": 7,
189
  "id": "77d9f0c5-8607-4642-a8ac-c3ab2e223ea6",
190
  "metadata": {},
191
  "outputs": [],
@@ -213,7 +225,7 @@
213
  },
214
  {
215
  "cell_type": "code",
216
- "execution_count": 8,
217
  "id": "ab5a13b4-9bd4-4aa0-aef2-b3de9b762988",
218
  "metadata": {},
219
  "outputs": [
@@ -233,7 +245,7 @@
233
  " 'segment': Value(dtype='string', id=None)}"
234
  ]
235
  },
236
- "execution_count": 8,
237
  "metadata": {},
238
  "output_type": "execute_result"
239
  }
@@ -259,7 +271,7 @@
259
  },
260
  {
261
  "cell_type": "code",
262
- "execution_count": 9,
263
  "id": "3ab6a724-3d1e-478b-a9e9-d2f85feb6c39",
264
  "metadata": {},
265
  "outputs": [],
@@ -279,7 +291,7 @@
279
  },
280
  {
281
  "cell_type": "code",
282
- "execution_count": 10,
283
  "id": "d041650e-1c48-4439-87b3-5b6f4a514107",
284
  "metadata": {},
285
  "outputs": [],
@@ -306,7 +318,29 @@
306
  },
307
  {
308
  "cell_type": "code",
309
- "execution_count": 11,
310
  "id": "c085911c-a10a-41ef-8874-306e0503e9bb",
311
  "metadata": {},
312
  "outputs": [],
@@ -326,9 +360,12 @@
326
  " transcription = transcription.lower()\n",
327
  " if do_remove_punctuation:\n",
328
  " transcription = normalizer(transcription).strip()\n",
329
- " \n",
 
 
 
330
  " # encode target text to label ids\n",
331
- " batch[\"labels\"] = processor.tokenizer(transcription).input_ids\n",
332
  " return batch"
333
  ]
334
  },
@@ -342,7 +379,7 @@
342
  },
343
  {
344
  "cell_type": "code",
345
- "execution_count": 12,
346
  "id": "a37a7cdb-9013-427f-8de9-6a8d0e9dc684",
347
  "metadata": {},
348
  "outputs": [],
@@ -360,7 +397,7 @@
360
  },
361
  {
362
  "cell_type": "code",
363
- "execution_count": 13,
364
  "id": "1b145699-acfc-4b1d-93a2-a2ad3d62674c",
365
  "metadata": {},
366
  "outputs": [],
@@ -381,7 +418,7 @@
381
  },
382
  {
383
  "cell_type": "code",
384
- "execution_count": 14,
385
  "id": "01cb25ef-4bb0-4325-9461-f59198acadf6",
386
  "metadata": {},
387
  "outputs": [],
@@ -402,7 +439,7 @@
402
  },
403
  {
404
  "cell_type": "code",
405
- "execution_count": 15,
406
  "id": "333f7f6e-6053-4d3b-8924-c733c79b82ac",
407
  "metadata": {},
408
  "outputs": [],
@@ -413,6 +450,63 @@
413
  ")"
414
  ]
415
  },
416
  {
417
  "cell_type": "markdown",
418
  "id": "263a5a58-0239-4a25-b0df-c625fc9c5810",
@@ -550,22 +644,7 @@
550
  "execution_count": 18,
551
  "id": "b22b4011-f31f-4b57-b684-c52332f92890",
552
  "metadata": {},
553
- "outputs": [
554
- {
555
- "data": {
556
- "application/vnd.jupyter.widget-view+json": {
557
- "model_id": "bffdd7b1fed44295954d9eed41a9cfd5",
558
- "version_major": 2,
559
- "version_minor": 0
560
- },
561
- "text/plain": [
562
- "Downloading builder script: 0%| | 0.00/4.49k [00:00<?, ?B/s]"
563
- ]
564
- },
565
- "metadata": {},
566
- "output_type": "display_data"
567
- }
568
- ],
569
  "source": [
570
  "import evaluate\n",
571
  "\n",
@@ -644,36 +723,7 @@
644
  "execution_count": 20,
645
  "id": "5a10cc4b-07ec-4ebd-ac1d-7c601023594f",
646
  "metadata": {},
647
- "outputs": [
648
- {
649
- "data": {
650
- "application/vnd.jupyter.widget-view+json": {
651
- "model_id": "48fee2fd3b2a4a67b3a35666fda4dfe9",
652
- "version_major": 2,
653
- "version_minor": 0
654
- },
655
- "text/plain": [
656
- "Downloading: 0%| | 0.00/1.97k [00:00<?, ?B/s]"
657
- ]
658
- },
659
- "metadata": {},
660
- "output_type": "display_data"
661
- },
662
- {
663
- "data": {
664
- "application/vnd.jupyter.widget-view+json": {
665
- "model_id": "51cdba284e8f44318868fbd013970280",
666
- "version_major": 2,
667
- "version_minor": 0
668
- },
669
- "text/plain": [
670
- "Downloading: 0%| | 0.00/967M [00:00<?, ?B/s]"
671
- ]
672
- },
673
- "metadata": {},
674
- "output_type": "display_data"
675
- }
676
- ],
677
  "source": [
678
  "from transformers import WhisperForConditionalGeneration\n",
679
  "\n",
@@ -708,6 +758,34 @@
708
  "### Define the Training Configuration"
709
  ]
710
  },
711
  {
712
  "cell_type": "markdown",
713
  "id": "c21af1e9-0188-4134-ac82-defc7bdcc436",
@@ -718,7 +796,7 @@
718
  },
719
  {
720
  "cell_type": "code",
721
- "execution_count": 22,
722
  "id": "0ae3e9af-97b7-4aa0-ae85-20b23b5bcb3a",
723
  "metadata": {},
724
  "outputs": [],
@@ -730,16 +808,16 @@
730
  " per_device_train_batch_size=64,\n",
731
  " gradient_accumulation_steps=1, # increase by 2x for every 2x decrease in batch size\n",
732
  " learning_rate=1e-5,\n",
733
- " warmup_steps=200,\n",
734
  " max_steps=1000,\n",
735
  " gradient_checkpointing=True,\n",
736
  " fp16=True,\n",
737
  " evaluation_strategy=\"steps\",\n",
738
- " per_device_eval_batch_size=8,\n",
739
  " predict_with_generate=True,\n",
740
  " generation_max_length=225,\n",
741
  " save_steps=200,\n",
742
- " eval_steps=200,\n",
743
  " logging_steps=25,\n",
744
  " report_to=[\"tensorboard\"],\n",
745
  " load_best_model_at_end=True,\n",
@@ -758,34 +836,6 @@
758
  "set `push_to_hub=False`."
759
  ]
760
  },
761
- {
762
- "cell_type": "markdown",
763
- "id": "393c883e-3e50-492c-bd58-f51dbf15ee56",
764
- "metadata": {},
765
- "source": [
766
- "We then define a custom [Callback](https://huggingface.co/docs/transformers/main_classes/callback) that is called by the 🤗 Trainer on the end of each epoch. The Callback reinitialises and reshuffles the streaming dataset at the beginning of each new epoch - this gives different shuffling across our subsets for every epoch."
767
- ]
768
- },
769
- {
770
- "cell_type": "code",
771
- "execution_count": 23,
772
- "id": "3ac16b62-b3c0-4c68-8f3d-9ecf471534b2",
773
- "metadata": {},
774
- "outputs": [],
775
- "source": [
776
- "from transformers import TrainerCallback\n",
777
- "from transformers.trainer_pt_utils import IterableDatasetShard\n",
778
- "from torch.utils.data import IterableDataset\n",
779
- "\n",
780
- "# trainer callback to reinitialise and reshuffle the streamable datasets at the beginning of each epoch\n",
781
- "class ShuffleCallback(TrainerCallback):\n",
782
- " def on_epoch_begin(self, args, state, control, train_dataloader, **kwargs):\n",
783
- " if isinstance(train_dataloader.dataset, IterableDatasetShard):\n",
784
- " pass # set_epoch() is handled by the Trainer\n",
785
- " elif isinstance(train_dataloader.dataset, IterableDataset):\n",
786
- " train_dataloader.dataset.set_epoch(train_dataloader.dataset._epoch + 1)"
787
- ]
788
- },
789
  {
790
  "cell_type": "markdown",
791
  "id": "bac29114-d226-4f54-97cf-8718c9f94e1e",
@@ -884,7 +934,9 @@
884
  "cell_type": "code",
885
  "execution_count": null,
886
  "id": "ee8b7b8e-1c9a-4d77-9137-1778a629e6de",
887
- "metadata": {},
 
 
888
  "outputs": [
889
  {
890
  "name": "stderr",
@@ -900,8 +952,8 @@
900
  " Gradient Accumulation steps = 1\n",
901
  " Total optimization steps = 1000\n",
902
  " Number of trainable parameters = 241734912\n",
903
- "Reading metadata...: 6505it [00:00, 31331.40it/s]\n",
904
- "Reading metadata...: 4485it [00:00, 41376.86it/s]\n",
905
  "The following columns in the training set don't have a corresponding argument in `WhisperForConditionalGeneration.forward` and have been ignored: input_length. If input_length are not expected by `WhisperForConditionalGeneration.forward`, you can safely ignore this message.\n"
906
  ]
907
  },
@@ -911,8 +963,8 @@
911
  "\n",
912
  " <div>\n",
913
  " \n",
914
- " <progress value='201' max='1000' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
915
- " [ 201/1000 22:31 < 1:30:27, 0.15 it/s, Epoch 1.06/9223372036854775807]\n",
916
  " </div>\n",
917
  " <table border=\"1\" class=\"dataframe\">\n",
918
  " <thead>\n",
@@ -937,12 +989,10 @@
937
  "name": "stderr",
938
  "output_type": "stream",
939
  "text": [
940
- "Reading metadata...: 6505it [00:00, 64162.65it/s]\n",
941
- "Reading metadata...: 4485it [00:00, 27834.06it/s]\n",
942
  "***** Running Evaluation *****\n",
943
  " Num examples: Unknown\n",
944
- " Batch size = 8\n",
945
- "Reading metadata...: 4604it [00:00, 27155.92it/s]\n",
946
  "The following columns in the evaluation set don't have a corresponding argument in `WhisperForConditionalGeneration.forward` and have been ignored: input_length. If input_length are not expected by `WhisperForConditionalGeneration.forward`, you can safely ignore this message.\n"
947
  ]
948
  }
@@ -973,7 +1023,7 @@
973
  },
974
  {
975
  "cell_type": "code",
976
- "execution_count": 24,
977
  "id": "6dd0e310-9b07-4133-ac14-2ed2d7524e22",
978
  "metadata": {},
979
  "outputs": [],
@@ -999,275 +1049,13 @@
999
  },
1000
  {
1001
  "cell_type": "code",
1002
- "execution_count": 31,
1003
  "id": "95737cda-c5dd-4887-a4d0-dfcb0d61d977",
1004
  "metadata": {},
1005
- "outputs": [
1006
- {
1007
- "name": "stderr",
1008
- "output_type": "stream",
1009
- "text": [
1010
- "Saving model checkpoint to ./\n",
1011
- "Configuration saved in ./config.json\n",
1012
- "Model weights saved in ./pytorch_model.bin\n",
1013
- "Feature extractor saved in ./preprocessor_config.json\n",
1014
- "tokenizer config file saved in ./tokenizer_config.json\n",
1015
- "Special tokens file saved in ./special_tokens_map.json\n",
1016
- "added tokens file saved in ./added_tokens.json\n"
1017
- ]
1018
- },
1019
- {
1020
- "data": {
1021
- "application/vnd.jupyter.widget-view+json": {
1022
- "model_id": "695c170663c94560a567be198b7181ff",
1023
- "version_major": 2,
1024
- "version_minor": 0
1025
- },
1026
- "text/plain": [
1027
- "Upload file runs/Dec10_16-23-25_129-213-27-84/1670689420.7830398/events.out.tfevents.1670689420.129-213-27-84.…"
1028
- ]
1029
- },
1030
- "metadata": {},
1031
- "output_type": "display_data"
1032
- },
1033
- {
1034
- "data": {
1035
- "application/vnd.jupyter.widget-view+json": {
1036
- "model_id": "2318836d6dd3405fabafca4370232e34",
1037
- "version_major": 2,
1038
- "version_minor": 0
1039
- },
1040
- "text/plain": [
1041
- "Upload file training_args.bin: 100%|##########| 3.50k/3.50k [00:00<?, ?B/s]"
1042
- ]
1043
- },
1044
- "metadata": {},
1045
- "output_type": "display_data"
1046
- },
1047
- {
1048
- "data": {
1049
- "application/vnd.jupyter.widget-view+json": {
1050
- "model_id": "9b673eb134984bdda227d23929b66479",
1051
- "version_major": 2,
1052
- "version_minor": 0
1053
- },
1054
- "text/plain": [
1055
- "Upload file runs/Dec10_16-23-25_129-213-27-84/events.out.tfevents.1670689420.129-213-27-84.69598.2: 100%|#####…"
1056
- ]
1057
- },
1058
- "metadata": {},
1059
- "output_type": "display_data"
1060
- },
1061
- {
1062
- "name": "stderr",
1063
- "output_type": "stream",
1064
- "text": [
1065
- "remote: Scanning LFS files for validity, may be slow... \n",
1066
- "remote: LFS file scan complete. \n",
1067
- "To https://huggingface.co/kimbochen/whisper-small-jp\n",
1068
- " 3a44fa5..05da956 main -> main\n",
1069
- "\n",
1070
- "To https://huggingface.co/kimbochen/whisper-small-jp\n",
1071
- " 05da956..30906c5 main -> main\n",
1072
- "\n"
1073
- ]
1074
- },
1075
- {
1076
- "data": {
1077
- "text/plain": [
1078
- "'https://huggingface.co/kimbochen/whisper-small-jp/commit/05da956fdc97e7c01112f45c20e56c8f6a127502'"
1079
- ]
1080
- },
1081
- "execution_count": 31,
1082
- "metadata": {},
1083
- "output_type": "execute_result"
1084
- }
1085
- ],
1086
  "source": [
1087
  "trainer.push_to_hub(**kwargs)"
1088
  ]
1089
- },
1090
- {
1091
- "cell_type": "code",
1092
- "execution_count": 28,
1093
- "id": "4df1603c-ef35-40f1-ae57-3214441073c8",
1094
- "metadata": {},
1095
- "outputs": [
1096
- {
1097
- "name": "stderr",
1098
- "output_type": "stream",
1099
- "text": [
1100
- "PyTorch: setting up devices\n"
1101
- ]
1102
- }
1103
- ],
1104
- "source": [
1105
- "training_args = Seq2SeqTrainingArguments(\n",
1106
- " output_dir=\"./\",\n",
1107
- " per_device_train_batch_size=64,\n",
1108
- " gradient_accumulation_steps=1, # increase by 2x for every 2x decrease in batch size\n",
1109
- " learning_rate=1e-5,\n",
1110
- " max_steps=1000,\n",
1111
- " num_train_epochs=-1,\n",
1112
- " gradient_checkpointing=True,\n",
1113
- " fp16=True,\n",
1114
- " evaluation_strategy=\"steps\",\n",
1115
- " per_device_eval_batch_size=8,\n",
1116
- " predict_with_generate=True,\n",
1117
- " generation_max_length=225,\n",
1118
- " save_steps=1000,\n",
1119
- " eval_steps=1000,\n",
1120
- " logging_steps=25,\n",
1121
- " report_to=[\"tensorboard\"],\n",
1122
- " load_best_model_at_end=True,\n",
1123
- " metric_for_best_model=\"wer\",\n",
1124
- " greater_is_better=False,\n",
1125
- " push_to_hub=True,\n",
1126
- ")"
1127
- ]
1128
- },
1129
- {
1130
- "cell_type": "code",
1131
- "execution_count": 29,
1132
- "id": "afc2b554-7171-48c7-95aa-b7e61b70ab20",
1133
- "metadata": {},
1134
- "outputs": [
1135
- {
1136
- "name": "stderr",
1137
- "output_type": "stream",
1138
- "text": [
1139
- "/home/ubuntu/whisper-small-jp/./ is already a clone of https://huggingface.co/kimbochen/whisper-small-jp. Make sure you pull the latest changes with `repo.git_pull()`.\n",
1140
- "max_steps is given, it will override any value given in num_train_epochs\n",
1141
- "Using cuda_amp half precision backend\n"
1142
- ]
1143
- }
1144
- ],
1145
- "source": [
1146
- "trainer = Seq2SeqTrainer(\n",
1147
- " args=training_args,\n",
1148
- " model=model,\n",
1149
- " train_dataset=vectorized_datasets[\"train\"],\n",
1150
- " eval_dataset=vectorized_datasets[\"test\"],\n",
1151
- " data_collator=data_collator,\n",
1152
- " compute_metrics=compute_metrics,\n",
1153
- " tokenizer=processor,\n",
1154
- " callbacks=[ShuffleCallback()],\n",
1155
- ")"
1156
- ]
1157
- },
1158
- {
1159
- "cell_type": "code",
1160
- "execution_count": 30,
1161
- "id": "b029a1d8-24de-46e7-b067-0f900b1db342",
1162
- "metadata": {},
1163
- "outputs": [
1164
- {
1165
- "name": "stderr",
1166
- "output_type": "stream",
1167
- "text": [
1168
- "Loading model from checkpoint-4000.\n",
1169
- "/home/ubuntu/.venv/lib/python3.8/site-packages/transformers/optimization.py:306: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
1170
- " warnings.warn(\n",
1171
- "***** Running training *****\n",
1172
- " Num examples = 64000\n",
1173
- " Num Epochs = 9223372036854775807\n",
1174
- " Instantaneous batch size per device = 64\n",
1175
- " Total train batch size (w. parallel, distributed & accumulation) = 64\n",
1176
- " Gradient Accumulation steps = 1\n",
1177
- " Total optimization steps = 1000\n",
1178
- " Number of trainable parameters = 241734912\n",
1179
- " Continuing training from checkpoint, will skip to saved global_step\n",
1180
- " Continuing training from epoch 4\n",
1181
- " Continuing training from global step 4000\n",
1182
- " Will skip the first 4 epochs then the first 0 batches in the first epoch. If this takes a lot of time, you can add the `--ignore_data_skip` flag to your launch command, but you will resume the training on data already seen by your model.\n"
1183
- ]
1184
- },
1185
- {
1186
- "data": {
1187
- "application/vnd.jupyter.widget-view+json": {
1188
- "model_id": "01337298313740d98d3cc75b6d5e3ff7",
1189
- "version_major": 2,
1190
- "version_minor": 0
1191
- },
1192
- "text/plain": [
1193
- "0it [00:00, ?it/s]"
1194
- ]
1195
- },
1196
- "metadata": {},
1197
- "output_type": "display_data"
1198
- },
1199
- {
1200
- "name": "stderr",
1201
- "output_type": "stream",
1202
- "text": [
1203
- "\n",
1204
- "Reading metadata...: 0it [00:00, ?it/s]\u001b[A\n",
1205
- "Reading metadata...: 6505it [00:00, 34246.80it/s]\n",
1206
- "The following columns in the training set don't have a corresponding argument in `WhisperForConditionalGeneration.forward` and have been ignored: input_length. If input_length are not expected by `WhisperForConditionalGeneration.forward`, you can safely ignore this message.\n",
1207
- "\n",
1208
- "Reading metadata...: 6505it [00:00, 84823.64it/s]\n",
1209
- "\n",
1210
- "Reading metadata...: 6505it [00:00, 88617.62it/s]\n",
1211
- "\n",
1212
- "Reading metadata...: 6505it [00:00, 90289.78it/s]\n",
1213
- "\n",
1214
- "Reading metadata...: 6505it [00:00, 91816.92it/s]\n"
1215
- ]
1216
- },
1217
- {
1218
- "data": {
1219
- "text/html": [
1220
- "\n",
1221
- " <div>\n",
1222
- " \n",
1223
- " <progress value='4001' max='1000' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
1224
- " [1000/1000 00:00, Epoch 4/9223372036854775807]\n",
1225
- " </div>\n",
1226
- " <table border=\"1\" class=\"dataframe\">\n",
1227
- " <thead>\n",
1228
- " <tr style=\"text-align: left;\">\n",
1229
- " <th>Step</th>\n",
1230
- " <th>Training Loss</th>\n",
1231
- " <th>Validation Loss</th>\n",
1232
- " </tr>\n",
1233
- " </thead>\n",
1234
- " <tbody>\n",
1235
- " </tbody>\n",
1236
- "</table><p>"
1237
- ],
1238
- "text/plain": [
1239
- "<IPython.core.display.HTML object>"
1240
- ]
1241
- },
1242
- "metadata": {},
1243
- "output_type": "display_data"
1244
- },
1245
- {
1246
- "name": "stderr",
1247
- "output_type": "stream",
1248
- "text": [
1249
- "\n",
1250
- "\n",
1251
- "Training completed. Do not forget to share your model on huggingface.co/models =)\n",
1252
- "\n",
1253
- "\n",
1254
- "Loading best model from ./checkpoint-4000 (score: 88.31039863810469).\n"
1255
- ]
1256
- },
1257
- {
1258
- "data": {
1259
- "text/plain": [
1260
- "TrainOutput(global_step=4001, training_loss=8.343380785802548e-08, metrics={'train_runtime': 169.0541, 'train_samples_per_second': 378.577, 'train_steps_per_second': 5.915, 'total_flos': 7.363747084345344e+19, 'train_loss': 8.343380785802548e-08, 'epoch': 4.0})"
1261
- ]
1262
- },
1263
- "execution_count": 30,
1264
- "metadata": {},
1265
- "output_type": "execute_result"
1266
- }
1267
- ],
1268
- "source": [
1269
- "trainer.train(\"checkpoint-4000\")"
1270
- ]
1271
  }
1272
  ],
1273
  "metadata": {
 
19
  {
20
  "cell_type": "markdown",
21
  "id": "afe0d503-ae4e-4aa7-9af4-dbcba52db41e",
22
+ "metadata": {
23
+ "jp-MarkdownHeadingCollapsed": true,
24
+ "tags": []
25
+ },
26
  "source": [
27
  "## Introduction"
28
  ]
 
111
  },
112
  {
113
  "cell_type": "code",
114
+ "execution_count": 1,
115
  "id": "065a8cf7-e54f-4ac3-900e-609c80714fca",
116
  "metadata": {},
117
+ "outputs": [
118
+ {
119
+ "name": "stderr",
120
+ "output_type": "stream",
121
+ "text": [
122
+ "/home/ubuntu/.venv/lib/python3.8/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
123
+ " from .autonotebook import tqdm as notebook_tqdm\n"
124
+ ]
125
+ }
126
+ ],
127
  "source": [
128
  "from datasets import interleave_datasets, load_dataset\n",
129
  "\n",
 
154
  },
155
  {
156
  "cell_type": "code",
157
+ "execution_count": 2,
158
  "id": "a2787582-554f-44ce-9f38-4180a5ed6b44",
159
  "metadata": {},
160
  "outputs": [],
 
197
  },
198
  {
199
  "cell_type": "code",
200
+ "execution_count": 3,
201
  "id": "77d9f0c5-8607-4642-a8ac-c3ab2e223ea6",
202
  "metadata": {},
203
  "outputs": [],
 
225
  },
226
  {
227
  "cell_type": "code",
228
+ "execution_count": 4,
229
  "id": "ab5a13b4-9bd4-4aa0-aef2-b3de9b762988",
230
  "metadata": {},
231
  "outputs": [
 
245
  " 'segment': Value(dtype='string', id=None)}"
246
  ]
247
  },
248
+ "execution_count": 4,
249
  "metadata": {},
250
  "output_type": "execute_result"
251
  }
 
271
  },
272
  {
273
  "cell_type": "code",
274
+ "execution_count": 5,
275
  "id": "3ab6a724-3d1e-478b-a9e9-d2f85feb6c39",
276
  "metadata": {},
277
  "outputs": [],
 
291
  },
292
  {
293
  "cell_type": "code",
294
+ "execution_count": 6,
295
  "id": "d041650e-1c48-4439-87b3-5b6f4a514107",
296
  "metadata": {},
297
  "outputs": [],
 
318
  },
319
  {
320
  "cell_type": "code",
321
+ "execution_count": 7,
322
+ "id": "ce788f4e-1270-424d-b1f3-a10d984ddb31",
323
+ "metadata": {},
324
+ "outputs": [],
325
+ "source": [
326
+ "from fugashi import Tagger\n",
327
+ "tagger = Tagger('-Owakati')"
328
+ ]
329
+ },
330
+ {
331
+ "cell_type": "code",
332
+ "execution_count": 8,
333
+ "id": "c858c814-6d32-472e-afe7-2f7273b244ba",
334
+ "metadata": {},
335
+ "outputs": [],
336
+ "source": [
337
+ "FULL2HALF = dict((i + 0xFEE0, i) for i in range(0x21, 0x7F))\n",
338
+ "FULL2HALF[0x3000] = 0x20"
339
+ ]
340
+ },
341
+ {
342
+ "cell_type": "code",
343
+ "execution_count": 9,
344
  "id": "c085911c-a10a-41ef-8874-306e0503e9bb",
345
  "metadata": {},
346
  "outputs": [],
 
360
  " transcription = transcription.lower()\n",
361
  " if do_remove_punctuation:\n",
362
  " transcription = normalizer(transcription).strip()\n",
363
+ "\n",
364
+ " input_str = transcription.translate(FULL2HALF)\n",
365
+ " input_str = tagger.parse(input_str)\n",
366
+ "\n",
367
  " # encode target text to label ids\n",
368
+ " batch[\"labels\"] = processor.tokenizer(input_str).input_ids\n",
369
  " return batch"
370
  ]
371
  },
 
379
  },
380
  {
381
  "cell_type": "code",
382
+ "execution_count": 10,
383
  "id": "a37a7cdb-9013-427f-8de9-6a8d0e9dc684",
384
  "metadata": {},
385
  "outputs": [],
 
397
  },
398
  {
399
  "cell_type": "code",
400
+ "execution_count": 11,
401
  "id": "1b145699-acfc-4b1d-93a2-a2ad3d62674c",
402
  "metadata": {},
403
  "outputs": [],
 
418
  },
419
  {
420
  "cell_type": "code",
421
+ "execution_count": 12,
422
  "id": "01cb25ef-4bb0-4325-9461-f59198acadf6",
423
  "metadata": {},
424
  "outputs": [],
 
439
  },
440
  {
441
  "cell_type": "code",
442
+ "execution_count": 13,
443
  "id": "333f7f6e-6053-4d3b-8924-c733c79b82ac",
444
  "metadata": {},
445
  "outputs": [],
 
450
  ")"
451
  ]
452
  },
453
+ {
454
+ "cell_type": "code",
455
+ "execution_count": 14,
456
+ "id": "bede1184",
457
+ "metadata": {
458
+ "jupyter": {
459
+ "outputs_hidden": true
460
+ },
461
+ "scrolled": true,
462
+ "tags": []
463
+ },
464
+ "outputs": [
465
+ {
466
+ "name": "stderr",
467
+ "output_type": "stream",
468
+ "text": [
469
+ "Reading metadata...: 6505it [00:00, 79889.28it/s]\n",
470
+ "Reading metadata...: 4485it [00:00, 81713.25it/s]\n"
471
+ ]
472
+ },
473
+ {
474
+ "data": {
475
+ "text/plain": [
476
+ "[50258, 50266, 50359, 50363, 6392, 11046, 26923, 2605, 116, 16746]"
477
+ ]
478
+ },
479
+ "execution_count": 14,
480
+ "metadata": {},
481
+ "output_type": "execute_result"
482
+ }
483
+ ],
484
+ "source": [
485
+ "xb = next(iter(vectorized_datasets['train']))\n",
486
+ "xb['labels'][:10]"
487
+ ]
488
+ },
489
+ {
490
+ "cell_type": "code",
491
+ "execution_count": 15,
492
+ "id": "ac1e8d5b",
493
+ "metadata": {},
494
+ "outputs": [
495
+ {
496
+ "data": {
497
+ "text/plain": [
498
+ "'<|startoftranscript|><|ja|><|transcribe|><|notimestamps|>多から 一 へ と いう の は 、 世界 を 因果 的 に 決定 論 的 に 考える こと で ある 、 過去 から 考える こと で ある 、 機械 的 に 考える こと で ある 。<|endoftext|>'"
499
+ ]
500
+ },
501
+ "execution_count": 15,
502
+ "metadata": {},
503
+ "output_type": "execute_result"
504
+ }
505
+ ],
506
+ "source": [
507
+ "processor.tokenizer.decode(xb['labels'])"
508
+ ]
509
+ },
510
  {
511
  "cell_type": "markdown",
512
  "id": "263a5a58-0239-4a25-b0df-c625fc9c5810",
 
644
  "execution_count": 18,
645
  "id": "b22b4011-f31f-4b57-b684-c52332f92890",
646
  "metadata": {},
647
+ "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
648
  "source": [
649
  "import evaluate\n",
650
  "\n",
 
723
  "execution_count": 20,
724
  "id": "5a10cc4b-07ec-4ebd-ac1d-7c601023594f",
725
  "metadata": {},
726
+ "outputs": [],
727
  "source": [
728
  "from transformers import WhisperForConditionalGeneration\n",
729
  "\n",
 
758
  "### Define the Training Configuration"
759
  ]
760
  },
761
+ {
762
+ "cell_type": "markdown",
763
+ "id": "393c883e-3e50-492c-bd58-f51dbf15ee56",
764
+ "metadata": {},
765
+ "source": [
766
+ "We then define a custom [Callback](https://huggingface.co/docs/transformers/main_classes/callback) that is called by the 🤗 Trainer on the end of each epoch. The Callback reinitialises and reshuffles the streaming dataset at the beginning of each new epoch - this gives different shuffling across our subsets for every epoch."
767
+ ]
768
+ },
769
+ {
770
+ "cell_type": "code",
771
+ "execution_count": 22,
772
+ "id": "3ac16b62-b3c0-4c68-8f3d-9ecf471534b2",
773
+ "metadata": {},
774
+ "outputs": [],
775
+ "source": [
776
+ "from transformers import TrainerCallback\n",
777
+ "from transformers.trainer_pt_utils import IterableDatasetShard\n",
778
+ "from torch.utils.data import IterableDataset\n",
779
+ "\n",
780
+ "# trainer callback to reinitialise and reshuffle the streamable datasets at the beginning of each epoch\n",
781
+ "class ShuffleCallback(TrainerCallback):\n",
782
+ " def on_epoch_begin(self, args, state, control, train_dataloader, **kwargs):\n",
783
+ " if isinstance(train_dataloader.dataset, IterableDatasetShard):\n",
784
+ " pass # set_epoch() is handled by the Trainer\n",
785
+ " elif isinstance(train_dataloader.dataset, IterableDataset):\n",
786
+ " train_dataloader.dataset.set_epoch(train_dataloader.dataset._epoch + 1)"
787
+ ]
788
+ },
789
  {
790
  "cell_type": "markdown",
791
  "id": "c21af1e9-0188-4134-ac82-defc7bdcc436",
 
796
  },
797
  {
798
  "cell_type": "code",
799
+ "execution_count": 23,
800
  "id": "0ae3e9af-97b7-4aa0-ae85-20b23b5bcb3a",
801
  "metadata": {},
802
  "outputs": [],
 
808
  " per_device_train_batch_size=64,\n",
809
  " gradient_accumulation_steps=1, # increase by 2x for every 2x decrease in batch size\n",
810
  " learning_rate=1e-5,\n",
811
+ " warmup_steps=500,\n",
812
  " max_steps=1000,\n",
813
  " gradient_checkpointing=True,\n",
814
  " fp16=True,\n",
815
  " evaluation_strategy=\"steps\",\n",
816
+ " per_device_eval_batch_size=32,\n",
817
  " predict_with_generate=True,\n",
818
  " generation_max_length=225,\n",
819
  " save_steps=200,\n",
820
+ " eval_steps=100,\n",
821
  " logging_steps=25,\n",
822
  " report_to=[\"tensorboard\"],\n",
823
  " load_best_model_at_end=True,\n",
 
836
  "set `push_to_hub=False`."
837
  ]
838
  },
839
  {
840
  "cell_type": "markdown",
841
  "id": "bac29114-d226-4f54-97cf-8718c9f94e1e",
 
934
  "cell_type": "code",
935
  "execution_count": null,
936
  "id": "ee8b7b8e-1c9a-4d77-9137-1778a629e6de",
937
+ "metadata": {
938
+ "scrolled": false
939
+ },
940
  "outputs": [
941
  {
942
  "name": "stderr",
 
952
  " Gradient Accumulation steps = 1\n",
953
  " Total optimization steps = 1000\n",
954
  " Number of trainable parameters = 241734912\n",
955
+ "Reading metadata...: 6505it [00:00, 33561.53it/s]\n",
956
+ "Reading metadata...: 4485it [00:00, 24503.43it/s]\n",
957
  "The following columns in the training set don't have a corresponding argument in `WhisperForConditionalGeneration.forward` and have been ignored: input_length. If input_length are not expected by `WhisperForConditionalGeneration.forward`, you can safely ignore this message.\n"
958
  ]
959
  },
 
963
  "\n",
964
  " <div>\n",
965
  " \n",
966
+ " <progress value='101' max='1000' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
967
+ " [ 101/1000 11:45 < 1:46:47, 0.14 it/s, Epoch 0.10/9223372036854775807]\n",
968
  " </div>\n",
969
  " <table border=\"1\" class=\"dataframe\">\n",
970
  " <thead>\n",
 
989
  "name": "stderr",
990
  "output_type": "stream",
991
  "text": [
 
 
992
  "***** Running Evaluation *****\n",
993
  " Num examples: Unknown\n",
994
+ " Batch size = 32\n",
995
+ "Reading metadata...: 4604it [00:00, 79017.99it/s]\n",
996
  "The following columns in the evaluation set don't have a corresponding argument in `WhisperForConditionalGeneration.forward` and have been ignored: input_length. If input_length are not expected by `WhisperForConditionalGeneration.forward`, you can safely ignore this message.\n"
997
  ]
998
  }
 
1023
  },
1024
  {
1025
  "cell_type": "code",
1026
+ "execution_count": null,
1027
  "id": "6dd0e310-9b07-4133-ac14-2ed2d7524e22",
1028
  "metadata": {},
1029
  "outputs": [],
 
1049
  },
1050
  {
1051
  "cell_type": "code",
1052
+ "execution_count": null,
1053
  "id": "95737cda-c5dd-4887-a4d0-dfcb0d61d977",
1054
  "metadata": {},
1055
+ "outputs": [],
1056
  "source": [
1057
  "trainer.push_to_hub(**kwargs)"
1058
  ]
1059
  }
1060
  ],
1061
  "metadata": {
config.json CHANGED
@@ -34,7 +34,6 @@
34
  "num_mel_bins": 80,
35
  "pad_token_id": 50257,
36
  "scale_embedding": false,
37
- "suppress_tokens": [],
38
  "torch_dtype": "float32",
39
  "transformers_version": "4.26.0.dev0",
40
  "use_cache": false,
 
34
  "num_mel_bins": 80,
35
  "pad_token_id": 50257,
36
  "scale_embedding": false,
 
37
  "torch_dtype": "float32",
38
  "transformers_version": "4.26.0.dev0",
39
  "use_cache": false,
fine-tune-whisper-streaming.ipynb CHANGED
@@ -19,7 +19,10 @@
19
  {
20
  "cell_type": "markdown",
21
  "id": "afe0d503-ae4e-4aa7-9af4-dbcba52db41e",
22
- "metadata": {},
 
 
 
23
  "source": [
24
  "## Introduction"
25
  ]
@@ -108,10 +111,19 @@
108
  },
109
  {
110
  "cell_type": "code",
111
- "execution_count": 5,
112
  "id": "065a8cf7-e54f-4ac3-900e-609c80714fca",
113
  "metadata": {},
114
- "outputs": [],
 
 
 
 
 
 
 
 
 
115
  "source": [
116
  "from datasets import interleave_datasets, load_dataset\n",
117
  "\n",
@@ -142,7 +154,7 @@
142
  },
143
  {
144
  "cell_type": "code",
145
- "execution_count": 6,
146
  "id": "a2787582-554f-44ce-9f38-4180a5ed6b44",
147
  "metadata": {},
148
  "outputs": [],
@@ -185,7 +197,7 @@
185
  },
186
  {
187
  "cell_type": "code",
188
- "execution_count": 7,
189
  "id": "77d9f0c5-8607-4642-a8ac-c3ab2e223ea6",
190
  "metadata": {},
191
  "outputs": [],
@@ -213,7 +225,7 @@
213
  },
214
  {
215
  "cell_type": "code",
216
- "execution_count": 8,
217
  "id": "ab5a13b4-9bd4-4aa0-aef2-b3de9b762988",
218
  "metadata": {},
219
  "outputs": [
@@ -233,7 +245,7 @@
233
  " 'segment': Value(dtype='string', id=None)}"
234
  ]
235
  },
236
- "execution_count": 8,
237
  "metadata": {},
238
  "output_type": "execute_result"
239
  }
@@ -259,7 +271,7 @@
259
  },
260
  {
261
  "cell_type": "code",
262
- "execution_count": 9,
263
  "id": "3ab6a724-3d1e-478b-a9e9-d2f85feb6c39",
264
  "metadata": {},
265
  "outputs": [],
@@ -279,7 +291,7 @@
279
  },
280
  {
281
  "cell_type": "code",
282
- "execution_count": 10,
283
  "id": "d041650e-1c48-4439-87b3-5b6f4a514107",
284
  "metadata": {},
285
  "outputs": [],
@@ -306,7 +318,29 @@
306
  },
307
  {
308
  "cell_type": "code",
309
- "execution_count": 44,
 
310
  "id": "c085911c-a10a-41ef-8874-306e0503e9bb",
311
  "metadata": {},
312
  "outputs": [],
@@ -326,10 +360,12 @@
326
  " transcription = transcription.lower()\n",
327
  " if do_remove_punctuation:\n",
328
  " transcription = normalizer(transcription).strip()\n",
329
- " \n",
 
 
 
330
  " # encode target text to label ids\n",
331
- "# batch[\"labels\"] = processor.tokenizer(transcription).input_ids\n",
332
- " batch['labels'] = transcription\n",
333
  " return batch"
334
  ]
335
  },
@@ -343,7 +379,7 @@
343
  },
344
  {
345
  "cell_type": "code",
346
- "execution_count": 45,
347
  "id": "a37a7cdb-9013-427f-8de9-6a8d0e9dc684",
348
  "metadata": {},
349
  "outputs": [],
@@ -361,7 +397,7 @@
361
  },
362
  {
363
  "cell_type": "code",
364
- "execution_count": 46,
365
  "id": "1b145699-acfc-4b1d-93a2-a2ad3d62674c",
366
  "metadata": {},
367
  "outputs": [],
@@ -382,7 +418,7 @@
382
  },
383
  {
384
  "cell_type": "code",
385
- "execution_count": 47,
386
  "id": "01cb25ef-4bb0-4325-9461-f59198acadf6",
387
  "metadata": {},
388
  "outputs": [],
@@ -403,7 +439,7 @@
403
  },
404
  {
405
  "cell_type": "code",
406
- "execution_count": 48,
407
  "id": "333f7f6e-6053-4d3b-8924-c733c79b82ac",
408
  "metadata": {},
409
  "outputs": [],
@@ -416,250 +452,59 @@
416
  },
417
  {
418
  "cell_type": "code",
419
- "execution_count": 49,
420
  "id": "bede1184",
421
- "metadata": {},
 
 
 
 
 
 
422
  "outputs": [
423
  {
424
  "name": "stderr",
425
  "output_type": "stream",
426
  "text": [
427
- "Reading metadata...: 6505it [00:00, 35406.66it/s]\n",
428
- "Reading metadata...: 4485it [00:00, 19930.24it/s]\n"
429
  ]
430
  },
431
  {
432
  "data": {
433
  "text/plain": [
434
- "'多から一へというのは、世界を因果的に決定論的に考えることである、過去から考えることである、機械的に考えること��ある。'"
435
  ]
436
  },
437
- "execution_count": 49,
438
  "metadata": {},
439
  "output_type": "execute_result"
440
  }
441
  ],
442
  "source": [
443
  "xb = next(iter(vectorized_datasets['train']))\n",
444
- "xb['labels']"
445
  ]
446
  },
447
  {
448
  "cell_type": "code",
449
- "execution_count": 59,
450
  "id": "ac1e8d5b",
451
  "metadata": {},
452
- "outputs": [
453
- {
454
- "name": "stdout",
455
- "output_type": "stream",
456
- "text": [
457
- "<|startoftranscript|>\n",
458
- "<|ja|>\n",
459
- "<|transcribe|>\n",
460
- "<|notimestamps|>\n",
461
- "多\n",
462
- "から\n",
463
- "一\n",
464
- "へ\n",
465
- "という\n",
466
- "のは\n",
467
- "、\n",
468
- "世界\n",
469
- "を\n",
470
- "因\n",
471
- "果\n",
472
- "的\n",
473
- "に\n",
474
- "決\n",
475
- "定\n",
476
- "論\n",
477
- "的\n",
478
- "に\n",
479
- "考\n",
480
- "える\n",
481
- "こと\n",
482
- "で\n",
483
- "ある\n",
484
- "、\n",
485
- "過去\n",
486
- "から\n",
487
- "考\n",
488
- "える\n",
489
- "こと\n",
490
- "で\n",
491
- "ある\n",
492
- "、\n",
493
- "機\n",
494
- "�\n",
495
- "�\n",
496
- "的\n",
497
- "に\n",
498
- "考\n",
499
- "える\n",
500
- "こと\n",
501
- "で\n",
502
- "ある\n",
503
- "。\n",
504
- "<|endoftext|>\n"
505
- ]
506
- }
507
- ],
508
- "source": [
509
- "idxs = processor.tokenizer(xb['labels']).input_ids\n",
510
- "for idx in idxs:\n",
511
- " print(processor.tokenizer.decode(idx))"
512
- ]
513
- },
514
- {
515
- "cell_type": "code",
516
- "execution_count": 60,
517
- "id": "d33cefc4",
518
- "metadata": {},
519
- "outputs": [
520
- {
521
- "data": {
522
- "text/plain": [
523
- "[多から,\n",
524
- " 一,\n",
525
- " へ,\n",
526
- " と,\n",
527
- " いう,\n",
528
- " の,\n",
529
- " は,\n",
530
- " 、,\n",
531
- " 世界,\n",
532
- " を,\n",
533
- " 因果,\n",
534
- " 的,\n",
535
- " に,\n",
536
- " 決定,\n",
537
- " 論,\n",
538
- " 的,\n",
539
- " に,\n",
540
- " 考える,\n",
541
- " こと,\n",
542
- " で,\n",
543
- " ある,\n",
544
- " 、,\n",
545
- " 過去,\n",
546
- " から,\n",
547
- " 考える,\n",
548
- " こと,\n",
549
- " で,\n",
550
- " ある,\n",
551
- " 、,\n",
552
- " 機械,\n",
553
- " 的,\n",
554
- " に,\n",
555
- " 考える,\n",
556
- " こと,\n",
557
- " で,\n",
558
- " ある,\n",
559
- " 。]"
560
- ]
561
- },
562
- "execution_count": 60,
563
- "metadata": {},
564
- "output_type": "execute_result"
565
- }
566
- ],
567
- "source": [
568
- "tagger(xb['labels'])"
569
- ]
570
- },
571
- {
572
- "cell_type": "code",
573
- "execution_count": 55,
574
- "id": "2cbb82ef",
575
- "metadata": {},
576
- "outputs": [
577
- {
578
- "name": "stdout",
579
- "output_type": "stream",
580
- "text": [
581
- "Help on method decode in module transformers.tokenization_utils_base:\n",
582
- "\n",
583
- "decode(token_ids: Union[int, List[int], ForwardRef('np.ndarray'), ForwardRef('torch.Tensor'), ForwardRef('tf.Tensor')], skip_special_tokens: bool = False, clean_up_tokenization_spaces: bool = True, **kwargs) -> str method of transformers.models.whisper.tokenization_whisper.WhisperTokenizer instance\n",
584
- " Converts a sequence of ids in a string, using the tokenizer and vocabulary with options to remove special\n",
585
- " tokens and clean up tokenization spaces.\n",
586
- " \n",
587
- " Similar to doing `self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))`.\n",
588
- " \n",
589
- " Args:\n",
590
- " token_ids (`Union[int, List[int], np.ndarray, torch.Tensor, tf.Tensor]`):\n",
591
- " List of tokenized input ids. Can be obtained using the `__call__` method.\n",
592
- " skip_special_tokens (`bool`, *optional*, defaults to `False`):\n",
593
- " Whether or not to remove special tokens in the decoding.\n",
594
- " clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`):\n",
595
- " Whether or not to clean up the tokenization spaces.\n",
596
- " kwargs (additional keyword arguments, *optional*):\n",
597
- " Will be passed to the underlying model specific decode method.\n",
598
- " \n",
599
- " Returns:\n",
600
- " `str`: The decoded sentence.\n",
601
- "\n"
602
- ]
603
- }
604
- ],
605
- "source": [
606
- "help(processor.tokenizer.decode)"
607
- ]
608
- },
609
- {
610
- "cell_type": "code",
611
- "execution_count": 41,
612
- "id": "b4b9bbfc",
613
- "metadata": {},
614
  "outputs": [
615
  {
616
  "data": {
617
  "text/plain": [
618
- "' 菓子 は 、 主材 日本 菓子 '"
619
  ]
620
  },
621
- "execution_count": 41,
622
  "metadata": {},
623
  "output_type": "execute_result"
624
  }
625
  ],
626
  "source": [
627
- "from fugashi import Tagger\n",
628
- "\n",
629
- "tagger = Tagger('-Owakati')\n",
630
- "text = \"麩菓子は、麩を主材料とした日本の菓子。\"\n",
631
- "tagger.parse(text)"
632
- ]
633
- },
634
- {
635
- "cell_type": "code",
636
- "execution_count": 43,
637
- "id": "833ca62d",
638
- "metadata": {},
639
- "outputs": [
640
- {
641
- "data": {
642
- "text/plain": [
643
- "[麩, 菓子, は, 、, 麩, を, 主材, 料, と, し, た, 日本, の, 菓子, 。]"
644
- ]
645
- },
646
- "execution_count": 43,
647
- "metadata": {},
648
- "output_type": "execute_result"
649
- }
650
- ],
651
- "source": [
652
- "tagger(text)"
653
- ]
654
- },
655
- {
656
- "cell_type": "code",
657
- "execution_count": null,
658
- "id": "7b7854d6",
659
- "metadata": {},
660
- "outputs": [],
661
- "source": [
662
- "raw_datasets['']"
663
  ]
664
  },
665
  {
@@ -799,22 +644,7 @@
799
  "execution_count": 18,
800
  "id": "b22b4011-f31f-4b57-b684-c52332f92890",
801
  "metadata": {},
802
- "outputs": [
803
- {
804
- "data": {
805
- "application/vnd.jupyter.widget-view+json": {
806
- "model_id": "bffdd7b1fed44295954d9eed41a9cfd5",
807
- "version_major": 2,
808
- "version_minor": 0
809
- },
810
- "text/plain": [
811
- "Downloading builder script: 0%| | 0.00/4.49k [00:00<?, ?B/s]"
812
- ]
813
- },
814
- "metadata": {},
815
- "output_type": "display_data"
816
- }
817
- ],
818
  "source": [
819
  "import evaluate\n",
820
  "\n",
@@ -893,36 +723,7 @@
893
  "execution_count": 20,
894
  "id": "5a10cc4b-07ec-4ebd-ac1d-7c601023594f",
895
  "metadata": {},
896
- "outputs": [
897
- {
898
- "data": {
899
- "application/vnd.jupyter.widget-view+json": {
900
- "model_id": "48fee2fd3b2a4a67b3a35666fda4dfe9",
901
- "version_major": 2,
902
- "version_minor": 0
903
- },
904
- "text/plain": [
905
- "Downloading: 0%| | 0.00/1.97k [00:00<?, ?B/s]"
906
- ]
907
- },
908
- "metadata": {},
909
- "output_type": "display_data"
910
- },
911
- {
912
- "data": {
913
- "application/vnd.jupyter.widget-view+json": {
914
- "model_id": "51cdba284e8f44318868fbd013970280",
915
- "version_major": 2,
916
- "version_minor": 0
917
- },
918
- "text/plain": [
919
- "Downloading: 0%| | 0.00/967M [00:00<?, ?B/s]"
920
- ]
921
- },
922
- "metadata": {},
923
- "output_type": "display_data"
924
- }
925
- ],
926
  "source": [
927
  "from transformers import WhisperForConditionalGeneration\n",
928
  "\n",
@@ -957,6 +758,34 @@
957
  "### Define the Training Configuration"
958
  ]
959
  },
960
  {
961
  "cell_type": "markdown",
962
  "id": "c21af1e9-0188-4134-ac82-defc7bdcc436",
@@ -967,7 +796,7 @@
967
  },
968
  {
969
  "cell_type": "code",
970
- "execution_count": 22,
971
  "id": "0ae3e9af-97b7-4aa0-ae85-20b23b5bcb3a",
972
  "metadata": {},
973
  "outputs": [],
@@ -979,16 +808,16 @@
979
  " per_device_train_batch_size=64,\n",
980
  " gradient_accumulation_steps=1, # increase by 2x for every 2x decrease in batch size\n",
981
  " learning_rate=1e-5,\n",
982
- " warmup_steps=200,\n",
983
  " max_steps=1000,\n",
984
  " gradient_checkpointing=True,\n",
985
  " fp16=True,\n",
986
  " evaluation_strategy=\"steps\",\n",
987
- " per_device_eval_batch_size=8,\n",
988
  " predict_with_generate=True,\n",
989
  " generation_max_length=225,\n",
990
  " save_steps=200,\n",
991
- " eval_steps=200,\n",
992
  " logging_steps=25,\n",
993
  " report_to=[\"tensorboard\"],\n",
994
  " load_best_model_at_end=True,\n",
@@ -1007,34 +836,6 @@
1007
  "set `push_to_hub=False`."
1008
  ]
1009
  },
1010
- {
1011
- "cell_type": "markdown",
1012
- "id": "393c883e-3e50-492c-bd58-f51dbf15ee56",
1013
- "metadata": {},
1014
- "source": [
1015
- "We then define a custom [Callback](https://huggingface.co/docs/transformers/main_classes/callback) that is called by the 🤗 Trainer on the end of each epoch. The Callback reinitialises and reshuffles the streaming dataset at the beginning of each new epoch - this gives different shuffling across our subsets for every epoch."
1016
- ]
1017
- },
1018
- {
1019
- "cell_type": "code",
1020
- "execution_count": 23,
1021
- "id": "3ac16b62-b3c0-4c68-8f3d-9ecf471534b2",
1022
- "metadata": {},
1023
- "outputs": [],
1024
- "source": [
1025
- "from transformers import TrainerCallback\n",
1026
- "from transformers.trainer_pt_utils import IterableDatasetShard\n",
1027
- "from torch.utils.data import IterableDataset\n",
1028
- "\n",
1029
- "# trainer callback to reinitialise and reshuffle the streamable datasets at the beginning of each epoch\n",
1030
- "class ShuffleCallback(TrainerCallback):\n",
1031
- " def on_epoch_begin(self, args, state, control, train_dataloader, **kwargs):\n",
1032
- " if isinstance(train_dataloader.dataset, IterableDatasetShard):\n",
1033
- " pass # set_epoch() is handled by the Trainer\n",
1034
- " elif isinstance(train_dataloader.dataset, IterableDataset):\n",
1035
- " train_dataloader.dataset.set_epoch(train_dataloader.dataset._epoch + 1)"
1036
- ]
1037
- },
1038
  {
1039
  "cell_type": "markdown",
1040
  "id": "bac29114-d226-4f54-97cf-8718c9f94e1e",
@@ -1131,10 +932,10 @@
1131
  },
1132
  {
1133
  "cell_type": "code",
1134
- "execution_count": 26,
1135
  "id": "ee8b7b8e-1c9a-4d77-9137-1778a629e6de",
1136
  "metadata": {
1137
- "scrolled": true
1138
  },
1139
  "outputs": [
1140
  {
@@ -1151,8 +952,8 @@
1151
  " Gradient Accumulation steps = 1\n",
1152
  " Total optimization steps = 1000\n",
1153
  " Number of trainable parameters = 241734912\n",
1154
- "Reading metadata...: 6505it [00:00, 31331.40it/s]\n",
1155
- "Reading metadata...: 4485it [00:00, 41376.86it/s]\n",
1156
  "The following columns in the training set don't have a corresponding argument in `WhisperForConditionalGeneration.forward` and have been ignored: input_length. If input_length are not expected by `WhisperForConditionalGeneration.forward`, you can safely ignore this message.\n"
1157
  ]
1158
  },
@@ -1162,8 +963,8 @@
1162
  "\n",
1163
  " <div>\n",
1164
  " \n",
1165
- " <progress value='1001' max='1000' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
1166
- " [1000/1000 3:35:08, Epoch 7.01/9223372036854775807]\n",
1167
  " </div>\n",
1168
  " <table border=\"1\" class=\"dataframe\">\n",
1169
  " <thead>\n",
@@ -1171,34 +972,9 @@
1171
  " <th>Step</th>\n",
1172
  " <th>Training Loss</th>\n",
1173
  " <th>Validation Loss</th>\n",
1174
- " <th>Wer</th>\n",
1175
  " </tr>\n",
1176
  " </thead>\n",
1177
  " <tbody>\n",
1178
- " <tr>\n",
1179
- " <td>200</td>\n",
1180
- " <td>0.220800</td>\n",
1181
- " <td>0.278119</td>\n",
1182
- " <td>81.117889</td>\n",
1183
- " </tr>\n",
1184
- " <tr>\n",
1185
- " <td>400</td>\n",
1186
- " <td>0.136700</td>\n",
1187
- " <td>0.269168</td>\n",
1188
- " <td>73.102568</td>\n",
1189
- " </tr>\n",
1190
- " <tr>\n",
1191
- " <td>600</td>\n",
1192
- " <td>0.033800</td>\n",
1193
- " <td>0.278346</td>\n",
1194
- " <td>70.960420</td>\n",
1195
- " </tr>\n",
1196
- " <tr>\n",
1197
- " <td>800</td>\n",
1198
- " <td>0.026300</td>\n",
1199
- " <td>0.298785</td>\n",
1200
- " <td>74.734005</td>\n",
1201
- " </tr>\n",
1202
  " </tbody>\n",
1203
  "</table><p>"
1204
  ],
@@ -1213,117 +989,12 @@
1213
  "name": "stderr",
1214
  "output_type": "stream",
1215
  "text": [
1216
- "Reading metadata...: 6505it [00:00, 64162.65it/s]\n",
1217
- "Reading metadata...: 4485it [00:00, 27834.06it/s]\n",
1218
  "***** Running Evaluation *****\n",
1219
  " Num examples: Unknown\n",
1220
- " Batch size = 8\n",
1221
- "Reading metadata...: 4604it [00:00, 27155.92it/s]\n",
1222
- "The following columns in the evaluation set don't have a corresponding argument in `WhisperForConditionalGeneration.forward` and have been ignored: input_length. If input_length are not expected by `WhisperForConditionalGeneration.forward`, you can safely ignore this message.\n",
1223
- "Saving model checkpoint to ./checkpoint-200\n",
1224
- "Configuration saved in ./checkpoint-200/config.json\n",
1225
- "Model weights saved in ./checkpoint-200/pytorch_model.bin\n",
1226
- "Feature extractor saved in ./checkpoint-200/preprocessor_config.json\n",
1227
- "tokenizer config file saved in ./checkpoint-200/tokenizer_config.json\n",
1228
- "Special tokens file saved in ./checkpoint-200/special_tokens_map.json\n",
1229
- "added tokens file saved in ./checkpoint-200/added_tokens.json\n",
1230
- "Feature extractor saved in ./preprocessor_config.json\n",
1231
- "tokenizer config file saved in ./tokenizer_config.json\n",
1232
- "Special tokens file saved in ./special_tokens_map.json\n",
1233
- "added tokens file saved in ./added_tokens.json\n",
1234
- "Reading metadata...: 6505it [00:00, 44457.32it/s]\n",
1235
- "Reading metadata...: 4485it [00:00, 29197.09it/s]\n",
1236
- "***** Running Evaluation *****\n",
1237
- " Num examples: Unknown\n",
1238
- " Batch size = 8\n",
1239
- "Reading metadata...: 4604it [00:00, 34447.62it/s]\n",
1240
- "The following columns in the evaluation set don't have a corresponding argument in `WhisperForConditionalGeneration.forward` and have been ignored: input_length. If input_length are not expected by `WhisperForConditionalGeneration.forward`, you can safely ignore this message.\n",
1241
- "Saving model checkpoint to ./checkpoint-400\n",
1242
- "Configuration saved in ./checkpoint-400/config.json\n",
1243
- "Model weights saved in ./checkpoint-400/pytorch_model.bin\n",
1244
- "Feature extractor saved in ./checkpoint-400/preprocessor_config.json\n",
1245
- "tokenizer config file saved in ./checkpoint-400/tokenizer_config.json\n",
1246
- "Special tokens file saved in ./checkpoint-400/special_tokens_map.json\n",
1247
- "added tokens file saved in ./checkpoint-400/added_tokens.json\n",
1248
- "Feature extractor saved in ./preprocessor_config.json\n",
1249
- "tokenizer config file saved in ./tokenizer_config.json\n",
1250
- "Special tokens file saved in ./special_tokens_map.json\n",
1251
- "added tokens file saved in ./added_tokens.json\n",
1252
- "Reading metadata...: 6505it [00:00, 33208.71it/s]\n",
1253
- "Reading metadata...: 4485it [00:00, 23213.70it/s]\n",
1254
- "Reading metadata...: 6505it [00:00, 25768.67it/s]\n",
1255
- "Reading metadata...: 4485it [00:00, 27756.07it/s]\n",
1256
- "***** Running Evaluation *****\n",
1257
- " Num examples: Unknown\n",
1258
- " Batch size = 8\n",
1259
- "Reading metadata...: 4604it [00:00, 28855.43it/s]\n",
1260
- "The following columns in the evaluation set don't have a corresponding argument in `WhisperForConditionalGeneration.forward` and have been ignored: input_length. If input_length are not expected by `WhisperForConditionalGeneration.forward`, you can safely ignore this message.\n",
1261
- "Saving model checkpoint to ./checkpoint-600\n",
1262
- "Configuration saved in ./checkpoint-600/config.json\n",
1263
- "Model weights saved in ./checkpoint-600/pytorch_model.bin\n",
1264
- "Feature extractor saved in ./checkpoint-600/preprocessor_config.json\n",
1265
- "tokenizer config file saved in ./checkpoint-600/tokenizer_config.json\n",
1266
- "Special tokens file saved in ./checkpoint-600/special_tokens_map.json\n",
1267
- "added tokens file saved in ./checkpoint-600/added_tokens.json\n",
1268
- "Feature extractor saved in ./preprocessor_config.json\n",
1269
- "tokenizer config file saved in ./tokenizer_config.json\n",
1270
- "Special tokens file saved in ./special_tokens_map.json\n",
1271
- "added tokens file saved in ./added_tokens.json\n",
1272
- "Reading metadata...: 6505it [00:00, 86030.70it/s]\n",
1273
- "Reading metadata...: 4485it [00:00, 68522.65it/s]\n",
1274
- "***** Running Evaluation *****\n",
1275
- " Num examples: Unknown\n",
1276
- " Batch size = 8\n",
1277
- "Reading metadata...: 4604it [00:00, 30988.60it/s]\n",
1278
- "The following columns in the evaluation set don't have a corresponding argument in `WhisperForConditionalGeneration.forward` and have been ignored: input_length. If input_length are not expected by `WhisperForConditionalGeneration.forward`, you can safely ignore this message.\n",
1279
- "Saving model checkpoint to ./checkpoint-800\n",
1280
- "Configuration saved in ./checkpoint-800/config.json\n",
1281
- "Model weights saved in ./checkpoint-800/pytorch_model.bin\n",
1282
- "Feature extractor saved in ./checkpoint-800/preprocessor_config.json\n",
1283
- "tokenizer config file saved in ./checkpoint-800/tokenizer_config.json\n",
1284
- "Special tokens file saved in ./checkpoint-800/special_tokens_map.json\n",
1285
- "added tokens file saved in ./checkpoint-800/added_tokens.json\n",
1286
- "Feature extractor saved in ./preprocessor_config.json\n",
1287
- "tokenizer config file saved in ./tokenizer_config.json\n",
1288
- "Special tokens file saved in ./special_tokens_map.json\n",
1289
- "added tokens file saved in ./added_tokens.json\n",
1290
- "Reading metadata...: 6505it [00:00, 36357.17it/s]\n",
1291
- "Reading metadata...: 4485it [00:00, 30574.75it/s]\n",
1292
- "Got disconnected from remote data host. Retrying in 5sec [1/20]\n",
1293
- "Got disconnected from remote data host. Retrying in 5sec [2/20]\n",
1294
- "Reading metadata...: 6505it [00:00, 31147.16it/s]\n",
1295
- "Reading metadata...: 4485it [00:00, 22808.34it/s]\n",
1296
- "***** Running Evaluation *****\n",
1297
- " Num examples: Unknown\n",
1298
- " Batch size = 8\n",
1299
- "Reading metadata...: 4604it [00:00, 28132.71it/s]\n",
1300
  "The following columns in the evaluation set don't have a corresponding argument in `WhisperForConditionalGeneration.forward` and have been ignored: input_length. If input_length are not expected by `WhisperForConditionalGeneration.forward`, you can safely ignore this message.\n"
1301
  ]
1302
- },
1303
- {
1304
- "ename": "KeyboardInterrupt",
1305
- "evalue": "",
1306
- "output_type": "error",
1307
- "traceback": [
1308
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
1309
- "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
1310
- "Cell \u001b[0;32mIn[26], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
1311
- "File \u001b[0;32m~/.venv/lib/python3.8/site-packages/transformers/trainer.py:1535\u001b[0m, in \u001b[0;36mTrainer.train\u001b[0;34m(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)\u001b[0m\n\u001b[1;32m 1530\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel_wrapped \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel\n\u001b[1;32m 1532\u001b[0m inner_training_loop \u001b[38;5;241m=\u001b[39m find_executable_batch_size(\n\u001b[1;32m 1533\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_inner_training_loop, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_train_batch_size, args\u001b[38;5;241m.\u001b[39mauto_find_batch_size\n\u001b[1;32m 1534\u001b[0m )\n\u001b[0;32m-> 1535\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43minner_training_loop\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1536\u001b[0m \u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1537\u001b[0m \u001b[43m \u001b[49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1538\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrial\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtrial\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1539\u001b[0m \u001b[43m \u001b[49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1540\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
1312
- "File \u001b[0;32m~/.venv/lib/python3.8/site-packages/transformers/trainer.py:1860\u001b[0m, in \u001b[0;36mTrainer._inner_training_loop\u001b[0;34m(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)\u001b[0m\n\u001b[1;32m 1857\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mepoch \u001b[38;5;241m=\u001b[39m epoch \u001b[38;5;241m+\u001b[39m (step \u001b[38;5;241m+\u001b[39m \u001b[38;5;241m1\u001b[39m) \u001b[38;5;241m/\u001b[39m steps_in_epoch\n\u001b[1;32m 1858\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcallback_handler\u001b[38;5;241m.\u001b[39mon_step_end(args, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol)\n\u001b[0;32m-> 1860\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_maybe_log_save_evaluate\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtr_loss\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrial\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mepoch\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1861\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 1862\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcallback_handler\u001b[38;5;241m.\u001b[39mon_substep_end(args, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol)\n",
1313
- "File \u001b[0;32m~/.venv/lib/python3.8/site-packages/transformers/trainer.py:2123\u001b[0m, in \u001b[0;36mTrainer._maybe_log_save_evaluate\u001b[0;34m(self, tr_loss, model, trial, epoch, ignore_keys_for_eval)\u001b[0m\n\u001b[1;32m 2117\u001b[0m metrics \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mevaluate(\n\u001b[1;32m 2118\u001b[0m eval_dataset\u001b[38;5;241m=\u001b[39meval_dataset,\n\u001b[1;32m 2119\u001b[0m ignore_keys\u001b[38;5;241m=\u001b[39mignore_keys_for_eval,\n\u001b[1;32m 2120\u001b[0m metric_key_prefix\u001b[38;5;241m=\u001b[39m\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124meval_\u001b[39m\u001b[38;5;132;01m{\u001b[39;00meval_dataset_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 2121\u001b[0m )\n\u001b[1;32m 2122\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 2123\u001b[0m metrics \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mevaluate\u001b[49m\u001b[43m(\u001b[49m\u001b[43mignore_keys\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2124\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_report_to_hp_search(trial, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mglobal_step, metrics)\n\u001b[1;32m 2126\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol\u001b[38;5;241m.\u001b[39mshould_save:\n",
1314
- "File \u001b[0;32m~/.venv/lib/python3.8/site-packages/transformers/trainer_seq2seq.py:78\u001b[0m, in \u001b[0;36mSeq2SeqTrainer.evaluate\u001b[0;34m(self, eval_dataset, ignore_keys, metric_key_prefix, **gen_kwargs)\u001b[0m\n\u001b[1;32m 73\u001b[0m gen_kwargs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mnum_beams\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 74\u001b[0m gen_kwargs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mnum_beams\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;28;01mif\u001b[39;00m gen_kwargs\u001b[38;5;241m.\u001b[39mget(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mnum_beams\u001b[39m\u001b[38;5;124m\"\u001b[39m) \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39mgeneration_num_beams\n\u001b[1;32m 75\u001b[0m )\n\u001b[1;32m 76\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_gen_kwargs \u001b[38;5;241m=\u001b[39m gen_kwargs\n\u001b[0;32m---> 78\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mevaluate\u001b[49m\u001b[43m(\u001b[49m\u001b[43meval_dataset\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mignore_keys\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mignore_keys\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmetric_key_prefix\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmetric_key_prefix\u001b[49m\u001b[43m)\u001b[49m\n",
1315
- "File \u001b[0;32m~/.venv/lib/python3.8/site-packages/transformers/trainer.py:2819\u001b[0m, in \u001b[0;36mTrainer.evaluate\u001b[0;34m(self, eval_dataset, ignore_keys, metric_key_prefix)\u001b[0m\n\u001b[1;32m 2816\u001b[0m start_time \u001b[38;5;241m=\u001b[39m time\u001b[38;5;241m.\u001b[39mtime()\n\u001b[1;32m 2818\u001b[0m eval_loop \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mprediction_loop \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39muse_legacy_prediction_loop \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mevaluation_loop\n\u001b[0;32m-> 2819\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[43meval_loop\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 2820\u001b[0m \u001b[43m \u001b[49m\u001b[43meval_dataloader\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2821\u001b[0m \u001b[43m \u001b[49m\u001b[43mdescription\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mEvaluation\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2822\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;66;43;03m# No point gathering the predictions if there are no metrics, otherwise we defer to\u001b[39;49;00m\n\u001b[1;32m 2823\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;66;43;03m# self.args.prediction_loss_only\u001b[39;49;00m\n\u001b[1;32m 2824\u001b[0m \u001b[43m \u001b[49m\u001b[43mprediction_loss_only\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mif\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcompute_metrics\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01mis\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01melse\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 2825\u001b[0m \u001b[43m \u001b[49m\u001b[43mignore_keys\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mignore_keys\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2826\u001b[0m \u001b[43m \u001b[49m\u001b[43mmetric_key_prefix\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmetric_key_prefix\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2827\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2829\u001b[0m total_batch_size \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39meval_batch_size \u001b[38;5;241m*\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39mworld_size\n\u001b[1;32m 2830\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mmetric_key_prefix\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m_jit_compilation_time\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;129;01min\u001b[39;00m output\u001b[38;5;241m.\u001b[39mmetrics:\n",
1316
- "File \u001b[0;32m~/.venv/lib/python3.8/site-packages/transformers/trainer.py:3001\u001b[0m, in \u001b[0;36mTrainer.evaluation_loop\u001b[0;34m(self, dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix)\u001b[0m\n\u001b[1;32m 2998\u001b[0m batch_size \u001b[38;5;241m=\u001b[39m observed_batch_size\n\u001b[1;32m 3000\u001b[0m \u001b[38;5;66;03m# Prediction step\u001b[39;00m\n\u001b[0;32m-> 3001\u001b[0m loss, logits, labels \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mprediction_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mprediction_loss_only\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mignore_keys\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mignore_keys\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3002\u001b[0m inputs_decode \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_prepare_input(inputs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124minput_ids\u001b[39m\u001b[38;5;124m\"\u001b[39m]) \u001b[38;5;28;01mif\u001b[39;00m args\u001b[38;5;241m.\u001b[39minclude_inputs_for_metrics \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 3004\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m is_torch_tpu_available():\n",
1317
- "File \u001b[0;32m~/.venv/lib/python3.8/site-packages/transformers/trainer_seq2seq.py:213\u001b[0m, in \u001b[0;36mSeq2SeqTrainer.prediction_step\u001b[0;34m(self, model, inputs, prediction_loss_only, ignore_keys)\u001b[0m\n\u001b[1;32m 211\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m has_labels:\n\u001b[1;32m 212\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcompute_loss_context_manager():\n\u001b[0;32m--> 213\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 214\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlabel_smoother \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 215\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlabel_smoother(outputs, inputs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlabels\u001b[39m\u001b[38;5;124m\"\u001b[39m])\u001b[38;5;241m.\u001b[39mmean()\u001b[38;5;241m.\u001b[39mdetach()\n",
1318
- "File \u001b[0;32m~/.venv/lib/python3.8/site-packages/torch/nn/modules/module.py:1190\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1186\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1187\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1188\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1189\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1190\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1191\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1192\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
1319
- "File \u001b[0;32m~/.venv/lib/python3.8/site-packages/transformers/models/whisper/modeling_whisper.py:1197\u001b[0m, in \u001b[0;36mWhisperForConditionalGeneration.forward\u001b[0;34m(self, input_features, decoder_input_ids, decoder_attention_mask, head_mask, decoder_head_mask, cross_attn_head_mask, encoder_outputs, past_key_values, decoder_inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 1192\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m decoder_input_ids \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m decoder_inputs_embeds \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 1193\u001b[0m decoder_input_ids \u001b[38;5;241m=\u001b[39m shift_tokens_right(\n\u001b[1;32m 1194\u001b[0m labels, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39mpad_token_id, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39mdecoder_start_token_id\n\u001b[1;32m 1195\u001b[0m )\n\u001b[0;32m-> 1197\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1198\u001b[0m \u001b[43m \u001b[49m\u001b[43minput_features\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1199\u001b[0m \u001b[43m \u001b[49m\u001b[43mdecoder_input_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdecoder_input_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1200\u001b[0m \u001b[43m \u001b[49m\u001b[43mencoder_outputs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mencoder_outputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1201\u001b[0m \u001b[43m \u001b[49m\u001b[43mdecoder_attention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdecoder_attention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1202\u001b[0m \u001b[43m \u001b[49m\u001b[43mhead_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mhead_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1203\u001b[0m \u001b[43m \u001b[49m\u001b[43mdecoder_head_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdecoder_head_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1204\u001b[0m \u001b[43m \u001b[49m\u001b[43mcross_attn_head_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcross_attn_head_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1205\u001b[0m \u001b[43m \u001b[49m\u001b[43mpast_key_values\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_values\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1206\u001b[0m \u001b[43m \u001b[49m\u001b[43mdecoder_inputs_embeds\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdecoder_inputs_embeds\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1207\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1208\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1209\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1210\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1211\u001b[0m 
\u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1212\u001b[0m lm_logits \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mproj_out(outputs[\u001b[38;5;241m0\u001b[39m])\n\u001b[1;32m 1214\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
1320
- "File \u001b[0;32m~/.venv/lib/python3.8/site-packages/torch/nn/modules/module.py:1190\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1186\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1187\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1188\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1189\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1190\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1191\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1192\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
1321
- "File \u001b[0;32m~/.venv/lib/python3.8/site-packages/transformers/models/whisper/modeling_whisper.py:1066\u001b[0m, in \u001b[0;36mWhisperModel.forward\u001b[0;34m(self, input_features, decoder_input_ids, decoder_attention_mask, head_mask, decoder_head_mask, cross_attn_head_mask, encoder_outputs, past_key_values, decoder_inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 1059\u001b[0m encoder_outputs \u001b[38;5;241m=\u001b[39m BaseModelOutput(\n\u001b[1;32m 1060\u001b[0m last_hidden_state\u001b[38;5;241m=\u001b[39mencoder_outputs[\u001b[38;5;241m0\u001b[39m],\n\u001b[1;32m 1061\u001b[0m hidden_states\u001b[38;5;241m=\u001b[39mencoder_outputs[\u001b[38;5;241m1\u001b[39m] \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(encoder_outputs) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m1\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 1062\u001b[0m attentions\u001b[38;5;241m=\u001b[39mencoder_outputs[\u001b[38;5;241m2\u001b[39m] \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(encoder_outputs) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m2\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 1063\u001b[0m )\n\u001b[1;32m 1065\u001b[0m \u001b[38;5;66;03m# decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)\u001b[39;00m\n\u001b[0;32m-> 1066\u001b[0m decoder_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdecoder\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1067\u001b[0m \u001b[43m \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdecoder_input_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1068\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdecoder_attention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1069\u001b[0m \u001b[43m \u001b[49m\u001b[43mencoder_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mencoder_outputs\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1070\u001b[0m \u001b[43m \u001b[49m\u001b[43mhead_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdecoder_head_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1071\u001b[0m \u001b[43m \u001b[49m\u001b[43mcross_attn_head_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcross_attn_head_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1072\u001b[0m \u001b[43m \u001b[49m\u001b[43mpast_key_values\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_values\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1073\u001b[0m \u001b[43m \u001b[49m\u001b[43minputs_embeds\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdecoder_inputs_embeds\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1074\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1075\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1076\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1077\u001b[0m \u001b[43m 
\u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1078\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1080\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m return_dict:\n\u001b[1;32m 1081\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m decoder_outputs \u001b[38;5;241m+\u001b[39m encoder_outputs\n",
1322
- "File \u001b[0;32m~/.venv/lib/python3.8/site-packages/torch/nn/modules/module.py:1190\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1186\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1187\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1188\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1189\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1190\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1191\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1192\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
1323
- "File \u001b[0;32m~/.venv/lib/python3.8/site-packages/transformers/models/whisper/modeling_whisper.py:866\u001b[0m, in \u001b[0;36mWhisperDecoder.forward\u001b[0;34m(self, input_ids, attention_mask, encoder_hidden_states, head_mask, cross_attn_head_mask, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 863\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m inputs_embeds \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 864\u001b[0m inputs_embeds \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39membed_tokens(input_ids)\n\u001b[0;32m--> 866\u001b[0m attention_mask \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_prepare_decoder_attention_mask\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 867\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minput_shape\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs_embeds\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpast_key_values_length\u001b[49m\n\u001b[1;32m 868\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 870\u001b[0m \u001b[38;5;66;03m# embed positions\u001b[39;00m\n\u001b[1;32m 871\u001b[0m positions \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39membed_positions(input_ids, past_key_values_length\u001b[38;5;241m=\u001b[39mpast_key_values_length)\n",
1324
- "File \u001b[0;32m~/.venv/lib/python3.8/site-packages/transformers/models/whisper/modeling_whisper.py:758\u001b[0m, in \u001b[0;36mWhisperDecoder._prepare_decoder_attention_mask\u001b[0;34m(self, attention_mask, input_shape, inputs_embeds, past_key_values_length)\u001b[0m\n\u001b[1;32m 755\u001b[0m combined_attention_mask \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 757\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m input_shape[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m] \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m1\u001b[39m:\n\u001b[0;32m--> 758\u001b[0m combined_attention_mask \u001b[38;5;241m=\u001b[39m \u001b[43m_make_causal_mask\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 759\u001b[0m \u001b[43m \u001b[49m\u001b[43minput_shape\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs_embeds\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdtype\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpast_key_values_length\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_values_length\u001b[49m\n\u001b[1;32m 760\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mto\u001b[49m\u001b[43m(\u001b[49m\u001b[43minputs_embeds\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdevice\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 762\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m attention_mask \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 763\u001b[0m \u001b[38;5;66;03m# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]\u001b[39;00m\n\u001b[1;32m 764\u001b[0m expanded_attn_mask \u001b[38;5;241m=\u001b[39m _expand_mask(attention_mask, inputs_embeds\u001b[38;5;241m.\u001b[39mdtype, tgt_len\u001b[38;5;241m=\u001b[39minput_shape[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m])\n",
1325
- "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
1326
- ]
1327
  }
1328
  ],
1329
  "source": [
@@ -1352,7 +1023,7 @@
1352
  },
1353
  {
1354
  "cell_type": "code",
1355
- "execution_count": 27,
1356
  "id": "6dd0e310-9b07-4133-ac14-2ed2d7524e22",
1357
  "metadata": {},
1358
  "outputs": [],
@@ -1378,263 +1049,13 @@
1378
  },
1379
  {
1380
  "cell_type": "code",
1381
- "execution_count": 28,
1382
  "id": "95737cda-c5dd-4887-a4d0-dfcb0d61d977",
1383
  "metadata": {},
1384
- "outputs": [
1385
- {
1386
- "name": "stderr",
1387
- "output_type": "stream",
1388
- "text": [
1389
- "Saving model checkpoint to ./\n",
1390
- "Configuration saved in ./config.json\n",
1391
- "Model weights saved in ./pytorch_model.bin\n",
1392
- "Feature extractor saved in ./preprocessor_config.json\n",
1393
- "tokenizer config file saved in ./tokenizer_config.json\n",
1394
- "Special tokens file saved in ./special_tokens_map.json\n",
1395
- "added tokens file saved in ./added_tokens.json\n"
1396
- ]
1397
- },
1398
- {
1399
- "data": {
1400
- "application/vnd.jupyter.widget-view+json": {
1401
- "model_id": "a47d7e61b9144723a4208cc4cc492eee",
1402
- "version_major": 2,
1403
- "version_minor": 0
1404
- },
1405
- "text/plain": [
1406
- "Upload file pytorch_model.bin: 0%| | 32.0k/922M [00:00<?, ?B/s]"
1407
- ]
1408
- },
1409
- "metadata": {},
1410
- "output_type": "display_data"
1411
- },
1412
- {
1413
- "data": {
1414
- "application/vnd.jupyter.widget-view+json": {
1415
- "model_id": "a7eb0d82c2fd4f978981915aa2314463",
1416
- "version_major": 2,
1417
- "version_minor": 0
1418
- },
1419
- "text/plain": [
1420
- "Upload file runs/Dec12_04-37-47_150-136-44-233/events.out.tfevents.1670819878.150-136-44-233.69039.0: 100%|###…"
1421
- ]
1422
- },
1423
- "metadata": {},
1424
- "output_type": "display_data"
1425
- },
1426
- {
1427
- "name": "stderr",
1428
- "output_type": "stream",
1429
- "text": [
1430
- "remote: Scanning LFS files for validity, may be slow... \n",
1431
- "remote: LFS file scan complete. \n",
1432
- "To https://huggingface.co/kimbochen/whisper-small-jp\n",
1433
- " d83a98f..0ff52f0 main -> main\n",
1434
- "\n",
1435
- "Dropping the following result as it does not have all the necessary fields:\n",
1436
- "{'task': {'name': 'Automatic Speech Recognition', 'type': 'automatic-speech-recognition'}, 'dataset': {'name': 'Common Voice 11.0', 'type': 'mozilla-foundation/common_voice_11_0', 'config': 'ja', 'split': 'test', 'args': 'ja'}}\n",
1437
- "To https://huggingface.co/kimbochen/whisper-small-jp\n",
1438
- " 0ff52f0..22e3a01 main -> main\n",
1439
- "\n"
1440
- ]
1441
- },
1442
- {
1443
- "data": {
1444
- "text/plain": [
1445
- "'https://huggingface.co/kimbochen/whisper-small-jp/commit/0ff52f0f1d63daf816427096a83f7bbf8f3892eb'"
1446
- ]
1447
- },
1448
- "execution_count": 28,
1449
- "metadata": {},
1450
- "output_type": "execute_result"
1451
- }
1452
- ],
1453
  "source": [
1454
  "trainer.push_to_hub(**kwargs)"
1455
  ]
1456
- },
1457
- {
1458
- "cell_type": "code",
1459
- "execution_count": 28,
1460
- "id": "4df1603c-ef35-40f1-ae57-3214441073c8",
1461
- "metadata": {},
1462
- "outputs": [
1463
- {
1464
- "name": "stderr",
1465
- "output_type": "stream",
1466
- "text": [
1467
- "PyTorch: setting up devices\n"
1468
- ]
1469
- }
1470
- ],
1471
- "source": [
1472
- "training_args = Seq2SeqTrainingArguments(\n",
1473
- " output_dir=\"./\",\n",
1474
- " per_device_train_batch_size=64,\n",
1475
- " gradient_accumulation_steps=1, # increase by 2x for every 2x decrease in batch size\n",
1476
- " learning_rate=1e-5,\n",
1477
- " max_steps=1000,\n",
1478
- " num_train_epochs=-1,\n",
1479
- " gradient_checkpointing=True,\n",
1480
- " fp16=True,\n",
1481
- " evaluation_strategy=\"steps\",\n",
1482
- " per_device_eval_batch_size=8,\n",
1483
- " predict_with_generate=True,\n",
1484
- " generation_max_length=225,\n",
1485
- " save_steps=1000,\n",
1486
- " eval_steps=1000,\n",
1487
- " logging_steps=25,\n",
1488
- " report_to=[\"tensorboard\"],\n",
1489
- " load_best_model_at_end=True,\n",
1490
- " metric_for_best_model=\"wer\",\n",
1491
- " greater_is_better=False,\n",
1492
- " push_to_hub=True,\n",
1493
- ")"
1494
- ]
1495
- },
1496
- {
1497
- "cell_type": "code",
1498
- "execution_count": 29,
1499
- "id": "afc2b554-7171-48c7-95aa-b7e61b70ab20",
1500
- "metadata": {},
1501
- "outputs": [
1502
- {
1503
- "name": "stderr",
1504
- "output_type": "stream",
1505
- "text": [
1506
- "/home/ubuntu/whisper-small-jp/./ is already a clone of https://huggingface.co/kimbochen/whisper-small-jp. Make sure you pull the latest changes with `repo.git_pull()`.\n",
1507
- "max_steps is given, it will override any value given in num_train_epochs\n",
1508
- "Using cuda_amp half precision backend\n"
1509
- ]
1510
- }
1511
- ],
1512
- "source": [
1513
- "trainer = Seq2SeqTrainer(\n",
1514
- " args=training_args,\n",
1515
- " model=model,\n",
1516
- " train_dataset=vectorized_datasets[\"train\"],\n",
1517
- " eval_dataset=vectorized_datasets[\"test\"],\n",
1518
- " data_collator=data_collator,\n",
1519
- " compute_metrics=compute_metrics,\n",
1520
- " tokenizer=processor,\n",
1521
- " callbacks=[ShuffleCallback()],\n",
1522
- ")"
1523
- ]
1524
- },
1525
- {
1526
- "cell_type": "code",
1527
- "execution_count": 30,
1528
- "id": "b029a1d8-24de-46e7-b067-0f900b1db342",
1529
- "metadata": {},
1530
- "outputs": [
1531
- {
1532
- "name": "stderr",
1533
- "output_type": "stream",
1534
- "text": [
1535
- "Loading model from checkpoint-4000.\n",
1536
- "/home/ubuntu/.venv/lib/python3.8/site-packages/transformers/optimization.py:306: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
1537
- " warnings.warn(\n",
1538
- "***** Running training *****\n",
1539
- " Num examples = 64000\n",
1540
- " Num Epochs = 9223372036854775807\n",
1541
- " Instantaneous batch size per device = 64\n",
1542
- " Total train batch size (w. parallel, distributed & accumulation) = 64\n",
1543
- " Gradient Accumulation steps = 1\n",
1544
- " Total optimization steps = 1000\n",
1545
- " Number of trainable parameters = 241734912\n",
1546
- " Continuing training from checkpoint, will skip to saved global_step\n",
1547
- " Continuing training from epoch 4\n",
1548
- " Continuing training from global step 4000\n",
1549
- " Will skip the first 4 epochs then the first 0 batches in the first epoch. If this takes a lot of time, you can add the `--ignore_data_skip` flag to your launch command, but you will resume the training on data already seen by your model.\n"
1550
- ]
1551
- },
1552
- {
1553
- "data": {
1554
- "application/vnd.jupyter.widget-view+json": {
1555
- "model_id": "01337298313740d98d3cc75b6d5e3ff7",
1556
- "version_major": 2,
1557
- "version_minor": 0
1558
- },
1559
- "text/plain": [
1560
- "0it [00:00, ?it/s]"
1561
- ]
1562
- },
1563
- "metadata": {},
1564
- "output_type": "display_data"
1565
- },
1566
- {
1567
- "name": "stderr",
1568
- "output_type": "stream",
1569
- "text": [
1570
- "\n",
1571
- "Reading metadata...: 0it [00:00, ?it/s]\u001b[A\n",
1572
- "Reading metadata...: 6505it [00:00, 34246.80it/s]\n",
1573
- "The following columns in the training set don't have a corresponding argument in `WhisperForConditionalGeneration.forward` and have been ignored: input_length. If input_length are not expected by `WhisperForConditionalGeneration.forward`, you can safely ignore this message.\n",
1574
- "\n",
1575
- "Reading metadata...: 6505it [00:00, 84823.64it/s]\n",
1576
- "\n",
1577
- "Reading metadata...: 6505it [00:00, 88617.62it/s]\n",
1578
- "\n",
1579
- "Reading metadata...: 6505it [00:00, 90289.78it/s]\n",
1580
- "\n",
1581
- "Reading metadata...: 6505it [00:00, 91816.92it/s]\n"
1582
- ]
1583
- },
1584
- {
1585
- "data": {
1586
- "text/html": [
1587
- "\n",
1588
- " <div>\n",
1589
- " \n",
1590
- " <progress value='4001' max='1000' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
1591
- " [1000/1000 00:00, Epoch 4/9223372036854775807]\n",
1592
- " </div>\n",
1593
- " <table border=\"1\" class=\"dataframe\">\n",
1594
- " <thead>\n",
1595
- " <tr style=\"text-align: left;\">\n",
1596
- " <th>Step</th>\n",
1597
- " <th>Training Loss</th>\n",
1598
- " <th>Validation Loss</th>\n",
1599
- " </tr>\n",
1600
- " </thead>\n",
1601
- " <tbody>\n",
1602
- " </tbody>\n",
1603
- "</table><p>"
1604
- ],
1605
- "text/plain": [
1606
- "<IPython.core.display.HTML object>"
1607
- ]
1608
- },
1609
- "metadata": {},
1610
- "output_type": "display_data"
1611
- },
1612
- {
1613
- "name": "stderr",
1614
- "output_type": "stream",
1615
- "text": [
1616
- "\n",
1617
- "\n",
1618
- "Training completed. Do not forget to share your model on huggingface.co/models =)\n",
1619
- "\n",
1620
- "\n",
1621
- "Loading best model from ./checkpoint-4000 (score: 88.31039863810469).\n"
1622
- ]
1623
- },
1624
- {
1625
- "data": {
1626
- "text/plain": [
1627
- "TrainOutput(global_step=4001, training_loss=8.343380785802548e-08, metrics={'train_runtime': 169.0541, 'train_samples_per_second': 378.577, 'train_steps_per_second': 5.915, 'total_flos': 7.363747084345344e+19, 'train_loss': 8.343380785802548e-08, 'epoch': 4.0})"
1628
- ]
1629
- },
1630
- "execution_count": 30,
1631
- "metadata": {},
1632
- "output_type": "execute_result"
1633
- }
1634
- ],
1635
- "source": [
1636
- "trainer.train(\"checkpoint-4000\")"
1637
- ]
1638
  }
1639
  ],
1640
  "metadata": {
 
19
  {
20
  "cell_type": "markdown",
21
  "id": "afe0d503-ae4e-4aa7-9af4-dbcba52db41e",
22
+ "metadata": {
23
+ "jp-MarkdownHeadingCollapsed": true,
24
+ "tags": []
25
+ },
26
  "source": [
27
  "## Introduction"
28
  ]
 
111
  },
112
  {
113
  "cell_type": "code",
114
+ "execution_count": 1,
115
  "id": "065a8cf7-e54f-4ac3-900e-609c80714fca",
116
  "metadata": {},
117
+ "outputs": [
118
+ {
119
+ "name": "stderr",
120
+ "output_type": "stream",
121
+ "text": [
122
+ "/home/ubuntu/.venv/lib/python3.8/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
123
+ " from .autonotebook import tqdm as notebook_tqdm\n"
124
+ ]
125
+ }
126
+ ],
127
  "source": [
128
  "from datasets import interleave_datasets, load_dataset\n",
129
  "\n",
 
154
  },
155
  {
156
  "cell_type": "code",
157
+ "execution_count": 2,
158
  "id": "a2787582-554f-44ce-9f38-4180a5ed6b44",
159
  "metadata": {},
160
  "outputs": [],
 
197
  },
198
  {
199
  "cell_type": "code",
200
+ "execution_count": 3,
201
  "id": "77d9f0c5-8607-4642-a8ac-c3ab2e223ea6",
202
  "metadata": {},
203
  "outputs": [],
 
225
  },
226
  {
227
  "cell_type": "code",
228
+ "execution_count": 4,
229
  "id": "ab5a13b4-9bd4-4aa0-aef2-b3de9b762988",
230
  "metadata": {},
231
  "outputs": [
 
245
  " 'segment': Value(dtype='string', id=None)}"
246
  ]
247
  },
248
+ "execution_count": 4,
249
  "metadata": {},
250
  "output_type": "execute_result"
251
  }
 
271
  },
272
  {
273
  "cell_type": "code",
274
+ "execution_count": 5,
275
  "id": "3ab6a724-3d1e-478b-a9e9-d2f85feb6c39",
276
  "metadata": {},
277
  "outputs": [],
 
291
  },
292
  {
293
  "cell_type": "code",
294
+ "execution_count": 6,
295
  "id": "d041650e-1c48-4439-87b3-5b6f4a514107",
296
  "metadata": {},
297
  "outputs": [],
 
318
  },
319
  {
320
  "cell_type": "code",
321
+ "execution_count": 7,
322
+ "id": "ce788f4e-1270-424d-b1f3-a10d984ddb31",
323
+ "metadata": {},
324
+ "outputs": [],
325
+ "source": [
326
+ "from fugashi import Tagger\n",
327
+ "tagger = Tagger('-Owakati')"
328
+ ]
329
+ },
330
+ {
331
+ "cell_type": "code",
332
+ "execution_count": 8,
333
+ "id": "c858c814-6d32-472e-afe7-2f7273b244ba",
334
+ "metadata": {},
335
+ "outputs": [],
336
+ "source": [
337
+ "FULL2HALF = dict((i + 0xFEE0, i) for i in range(0x21, 0x7F))\n",
338
+ "FULL2HALF[0x3000] = 0x20"
339
+ ]
340
+ },
341
+ {
342
+ "cell_type": "code",
343
+ "execution_count": 9,
344
  "id": "c085911c-a10a-41ef-8874-306e0503e9bb",
345
  "metadata": {},
346
  "outputs": [],
 
360
  " transcription = transcription.lower()\n",
361
  " if do_remove_punctuation:\n",
362
  " transcription = normalizer(transcription).strip()\n",
363
+ "\n",
364
+ " input_str = transcription.translate(FULL2HALF)\n",
365
+ " input_str = tagger.parse(input_str)\n",
366
+ "\n",
367
  " # encode target text to label ids\n",
368
+ " batch[\"labels\"] = processor.tokenizer(input_str).input_ids\n",
 
369
  " return batch"
370
  ]
371
  },
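
The two lines added to `prepare_dataset` above are the core of this commit: each Japanese transcription is first normalised by mapping full-width ASCII to half-width with the FULL2HALF table defined earlier, then segmented into space-separated words with fugashi's MeCab tagger in wakati mode, presumably so that the word-level WER metric has word boundaries to work with. A minimal standalone sketch of what these lines do (the sample sentence and variable name are invented for illustration):

    from fugashi import Tagger

    # Map full-width ASCII (U+FF01-U+FF5E) onto half-width, and the ideographic space onto a normal space
    FULL2HALF = dict((i + 0xFEE0, i) for i in range(0x21, 0x7F))
    FULL2HALF[0x3000] = 0x20

    tagger = Tagger('-Owakati')  # wakati-gaki mode: MeCab outputs space-separated words

    transcription = "私は１２３ＡＢＣと言った。"          # hypothetical transcription with full-width characters
    transcription = transcription.translate(FULL2HALF)   # -> "私は123ABCと言った。"
    transcription = tagger.parse(transcription)          # -> roughly "私 は 123 ABC と 言っ た 。"
    print(transcription)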
 
379
  },
380
  {
381
  "cell_type": "code",
382
+ "execution_count": 10,
383
  "id": "a37a7cdb-9013-427f-8de9-6a8d0e9dc684",
384
  "metadata": {},
385
  "outputs": [],
 
397
  },
398
  {
399
  "cell_type": "code",
400
+ "execution_count": 11,
401
  "id": "1b145699-acfc-4b1d-93a2-a2ad3d62674c",
402
  "metadata": {},
403
  "outputs": [],
 
418
  },
419
  {
420
  "cell_type": "code",
421
+ "execution_count": 12,
422
  "id": "01cb25ef-4bb0-4325-9461-f59198acadf6",
423
  "metadata": {},
424
  "outputs": [],
 
439
  },
440
  {
441
  "cell_type": "code",
442
+ "execution_count": 13,
443
  "id": "333f7f6e-6053-4d3b-8924-c733c79b82ac",
444
  "metadata": {},
445
  "outputs": [],
 
452
  },
453
  {
454
  "cell_type": "code",
455
+ "execution_count": 14,
456
  "id": "bede1184",
457
+ "metadata": {
458
+ "jupyter": {
459
+ "outputs_hidden": true
460
+ },
461
+ "scrolled": true,
462
+ "tags": []
463
+ },
464
  "outputs": [
465
  {
466
  "name": "stderr",
467
  "output_type": "stream",
468
  "text": [
469
+ "Reading metadata...: 6505it [00:00, 79889.28it/s]\n",
470
+ "Reading metadata...: 4485it [00:00, 81713.25it/s]\n"
471
  ]
472
  },
473
  {
474
  "data": {
475
  "text/plain": [
476
+ "[50258, 50266, 50359, 50363, 6392, 11046, 26923, 2605, 116, 16746]"
477
  ]
478
  },
479
+ "execution_count": 14,
480
  "metadata": {},
481
  "output_type": "execute_result"
482
  }
483
  ],
484
  "source": [
485
  "xb = next(iter(vectorized_datasets['train']))\n",
486
+ "xb['labels'][:10]"
487
  ]
488
  },
489
  {
490
  "cell_type": "code",
491
+ "execution_count": 15,
492
  "id": "ac1e8d5b",
493
  "metadata": {},
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
494
  "outputs": [
495
  {
496
  "data": {
497
  "text/plain": [
498
+ "'<|startoftranscript|><|ja|><|transcribe|><|notimestamps|>多から へ と いう の は 、 世界因果 決定 考える こと で ある 、 過去 から 考える こと で ある 、 機械 的 に 考える こと で ある 。<|endoftext|>'"
499
  ]
500
  },
501
+ "execution_count": 15,
502
  "metadata": {},
503
  "output_type": "execute_result"
504
  }
505
  ],
506
  "source": [
507
+ "processor.tokenizer.decode(xb['labels'])"
 
 
 
 
 
 
 
 
 
 
508
  ]
509
  },
510
  {
 
644
  "execution_count": 18,
645
  "id": "b22b4011-f31f-4b57-b684-c52332f92890",
646
  "metadata": {},
647
+ "outputs": [],
 
 
 
 
 
 
 
 
648
  "source": [
649
  "import evaluate\n",
650
  "\n",
 
723
  "execution_count": 20,
724
  "id": "5a10cc4b-07ec-4ebd-ac1d-7c601023594f",
725
  "metadata": {},
726
+ "outputs": [],
 
 
 
 
 
 
 
 
 
 
727
  "source": [
728
  "from transformers import WhisperForConditionalGeneration\n",
729
  "\n",
 
758
  "### Define the Training Configuration"
759
  ]
760
  },
761
+ {
762
+ "cell_type": "markdown",
763
+ "id": "393c883e-3e50-492c-bd58-f51dbf15ee56",
764
+ "metadata": {},
765
+ "source": [
766
+ "We then define a custom [Callback](https://huggingface.co/docs/transformers/main_classes/callback) that is called by the 🤗 Trainer at the beginning of each new epoch. The Callback reinitialises and reshuffles the streaming dataset, so that every epoch sees our subsets in a different order."
767
+ ]
768
+ },
769
+ {
770
+ "cell_type": "code",
771
+ "execution_count": 22,
772
+ "id": "3ac16b62-b3c0-4c68-8f3d-9ecf471534b2",
773
+ "metadata": {},
774
+ "outputs": [],
775
+ "source": [
776
+ "from transformers import TrainerCallback\n",
777
+ "from transformers.trainer_pt_utils import IterableDatasetShard\n",
778
+ "from torch.utils.data import IterableDataset\n",
779
+ "\n",
780
+ "# trainer callback to reinitialise and reshuffle the streamable datasets at the beginning of each epoch\n",
781
+ "class ShuffleCallback(TrainerCallback):\n",
782
+ " def on_epoch_begin(self, args, state, control, train_dataloader, **kwargs):\n",
783
+ " if isinstance(train_dataloader.dataset, IterableDatasetShard):\n",
784
+ " pass # set_epoch() is handled by the Trainer\n",
785
+ " elif isinstance(train_dataloader.dataset, IterableDataset):\n",
786
+ " train_dataloader.dataset.set_epoch(train_dataloader.dataset._epoch + 1)"
787
+ ]
788
+ },
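
For context on the cell above: with 🤗 Datasets in streaming mode, a shuffled `IterableDataset` derives its effective shuffle seed from `seed + epoch`, so calling `set_epoch()` with a new value re-orders the stream for the next pass without reloading it. A rough standalone sketch, not part of the notebook, assuming access to the gated Common Voice dataset:

    from datasets import load_dataset

    stream = load_dataset(
        "mozilla-foundation/common_voice_11_0", "ja",
        split="train", streaming=True, use_auth_token=True,
    )
    stream = stream.shuffle(seed=0, buffer_size=500)

    for epoch in range(2):
        stream.set_epoch(epoch)          # effective shuffle seed becomes seed + epoch
        for sample in stream.take(3):    # the first few examples differ between epochs
            print(epoch, sample["sentence"])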
789
  {
790
  "cell_type": "markdown",
791
  "id": "c21af1e9-0188-4134-ac82-defc7bdcc436",
 
796
  },
797
  {
798
  "cell_type": "code",
799
+ "execution_count": 23,
800
  "id": "0ae3e9af-97b7-4aa0-ae85-20b23b5bcb3a",
801
  "metadata": {},
802
  "outputs": [],
 
808
  " per_device_train_batch_size=64,\n",
809
  " gradient_accumulation_steps=1, # increase by 2x for every 2x decrease in batch size\n",
810
  " learning_rate=1e-5,\n",
811
+ " warmup_steps=500,\n",
812
  " max_steps=1000,\n",
813
  " gradient_checkpointing=True,\n",
814
  " fp16=True,\n",
815
  " evaluation_strategy=\"steps\",\n",
816
+ " per_device_eval_batch_size=32,\n",
817
  " predict_with_generate=True,\n",
818
  " generation_max_length=225,\n",
819
  " save_steps=200,\n",
820
+ " eval_steps=100,\n",
821
  " logging_steps=25,\n",
822
  " report_to=[\"tensorboard\"],\n",
823
  " load_best_model_at_end=True,\n",
 
836
  "set `push_to_hub=False`."
837
  ]
838
  },
 
 
 
 
 
 
 
 
 
 
839
  {
840
  "cell_type": "markdown",
841
  "id": "bac29114-d226-4f54-97cf-8718c9f94e1e",
 
932
  },
933
  {
934
  "cell_type": "code",
935
+ "execution_count": null,
936
  "id": "ee8b7b8e-1c9a-4d77-9137-1778a629e6de",
937
  "metadata": {
938
+ "scrolled": false
939
  },
940
  "outputs": [
941
  {
 
952
  " Gradient Accumulation steps = 1\n",
953
  " Total optimization steps = 1000\n",
954
  " Number of trainable parameters = 241734912\n",
955
+ "Reading metadata...: 6505it [00:00, 33561.53it/s]\n",
956
+ "Reading metadata...: 4485it [00:00, 24503.43it/s]\n",
957
  "The following columns in the training set don't have a corresponding argument in `WhisperForConditionalGeneration.forward` and have been ignored: input_length. If input_length are not expected by `WhisperForConditionalGeneration.forward`, you can safely ignore this message.\n"
958
  ]
959
  },
 
963
  "\n",
964
  " <div>\n",
965
  " \n",
966
+ " <progress value='101' max='1000' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
967
+ " [ 101/1000 11:45 < 1:46:47, 0.14 it/s, Epoch 0.10/9223372036854775807]\n",
968
  " </div>\n",
969
  " <table border=\"1\" class=\"dataframe\">\n",
970
  " <thead>\n",
 
972
  " <th>Step</th>\n",
973
  " <th>Training Loss</th>\n",
974
  " <th>Validation Loss</th>\n",
 
975
  " </tr>\n",
976
  " </thead>\n",
977
  " <tbody>\n",
 
 
 
 
 
 
 
 
 
 
978
  " </tbody>\n",
979
  "</table><p>"
980
  ],
 
989
  "name": "stderr",
990
  "output_type": "stream",
991
  "text": [
 
 
992
  "***** Running Evaluation *****\n",
993
  " Num examples: Unknown\n",
994
+ " Batch size = 32\n",
995
+ "Reading metadata...: 4604it [00:00, 79017.99it/s]\n",
 
 
 
 
 
 
 
 
 
 
996
  "The following columns in the evaluation set don't have a corresponding argument in `WhisperForConditionalGeneration.forward` and have been ignored: input_length. If input_length are not expected by `WhisperForConditionalGeneration.forward`, you can safely ignore this message.\n"
997
  ]
 
 
 
 
 
 
 
 
 
 
998
  }
999
  ],
1000
  "source": [
 
1023
  },
1024
  {
1025
  "cell_type": "code",
1026
+ "execution_count": null,
1027
  "id": "6dd0e310-9b07-4133-ac14-2ed2d7524e22",
1028
  "metadata": {},
1029
  "outputs": [],
 
1049
  },
1050
  {
1051
  "cell_type": "code",
1052
+ "execution_count": null,
1053
  "id": "95737cda-c5dd-4887-a4d0-dfcb0d61d977",
1054
  "metadata": {},
1055
+ "outputs": [],
 
 
 
 
 
 
 
 
 
 
1056
  "source": [
1057
  "trainer.push_to_hub(**kwargs)"
1058
  ]
 
 
 
 
 
 
 
 
 
 
1059
  }
1060
  ],
1061
  "metadata": {
pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:f9e06d3e42c138244efef713a71edd16e7d80d9f5c735a8f6d28405049e4324d
3
  size 967102601
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a52725380668355f0448b3218516e7537f6c679336c6c10c74fef50e1b86494e
3
  size 967102601
run.sh ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
1
+ python run_speech_recognition_seq2seq_streaming.py \
2
+ --output_dir="./" \
3
+ --model_name_or_path="openai/whisper-small" \
4
+ --model_index_name="Whisper Small Japanese" \
5
+ --dataset_name="mozilla-foundation/common_voice_11_0" \
6
+ --dataset_config_name="ja" \
7
+ --language="japanese" \
8
+ --train_split_name="train+validation" \
9
+ --eval_split_name="test" \
10
+ --learning_rate="1e-5" \
11
+ --per_device_train_batch_size="64" \
12
+ --per_device_eval_batch_size="32" \
13
+ --max_steps="1000" \
14
+ --warmup_steps="500" \
15
+ --logging_steps="25" \
16
+ --evaluation_strategy="steps" \
17
+ --eval_steps="200" \
18
+ --save_strategy="steps" \
19
+ --save_steps="200" \
20
+ --generation_max_length="225" \
21
+ --length_column_name="input_length" \
22
+ --max_duration_in_seconds="30" \
23
+ --text_column_name="sentence" \
24
+ --freeze_feature_encoder="False" \
25
+ --report_to="tensorboard" \
26
+ --metric_for_best_model="wer" \
27
+ --greater_is_better="False" \
28
+ --load_best_model_at_end \
29
+ --gradient_checkpointing \
30
+ --fp16 \
31
+ --overwrite_output_dir \
32
+ --do_train \
33
+ --do_eval \
34
+ --predict_with_generate \
35
+ --do_normalize_eval \
36
+ --streaming \
37
+ --use_auth_token \
38
+ --push_to_hub
run_speech_recognition_seq2seq_streaming.py ADDED
@@ -0,0 +1,644 @@
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Fine-tuning the library models for sequence to sequence speech recognition
18
+ with 🤗 Datasets' streaming mode.
19
+ """
20
+ # You can also adapt this script for your own sequence to sequence speech
21
+ # recognition task. Pointers for this are left as comments.
22
+
23
+ import logging
24
+ import os
25
+ import sys
26
+ from dataclasses import dataclass, field
27
+ from typing import Any, Dict, List, Optional, Union
28
+
29
+ import datasets
30
+ import torch
31
+ from datasets import DatasetDict, IterableDatasetDict, interleave_datasets, load_dataset
32
+ from torch.utils.data import IterableDataset
33
+
34
+ import evaluate
35
+ import transformers
36
+ from transformers import (
37
+ AutoConfig,
38
+ AutoFeatureExtractor,
39
+ AutoModelForSpeechSeq2Seq,
40
+ AutoProcessor,
41
+ AutoTokenizer,
42
+ HfArgumentParser,
43
+ Seq2SeqTrainer,
44
+ Seq2SeqTrainingArguments,
45
+ TrainerCallback,
46
+ set_seed,
47
+ )
48
+ from transformers.models.whisper.english_normalizer import BasicTextNormalizer
49
+ from transformers.trainer_pt_utils import IterableDatasetShard
50
+ from transformers.trainer_utils import get_last_checkpoint, is_main_process
51
+ from transformers.utils import check_min_version, send_example_telemetry
52
+ from transformers.utils.versions import require_version
53
+
54
+ from fugashi import Tagger
55
+ import warnings
56
+ import logging
57
+
58
+ warnings.simplefilter(action='ignore', category=FutureWarning)
59
+ logging.basicConfig(level=logging.ERROR)
60
+
61
+
62
+ # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
63
+ check_min_version("4.25.0.dev0")
64
+
65
+ require_version("datasets>=1.18.2", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")
66
+
67
+ logger = logging.getLogger(__name__)
68
+
69
+
70
+ @dataclass
71
+ class ModelArguments:
72
+ """
73
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
74
+ """
75
+
76
+ model_name_or_path: str = field(
77
+ metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
78
+ )
79
+ config_name: Optional[str] = field(
80
+ default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
81
+ )
82
+ tokenizer_name: Optional[str] = field(
83
+ default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
84
+ )
85
+ feature_extractor_name: Optional[str] = field(
86
+ default=None, metadata={"help": "feature extractor name or path if not the same as model_name"}
87
+ )
88
+ cache_dir: Optional[str] = field(
89
+ default=None,
90
+ metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
91
+ )
92
+ use_fast_tokenizer: bool = field(
93
+ default=True,
94
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
95
+ )
96
+ model_revision: str = field(
97
+ default="main",
98
+ metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
99
+ )
100
+ use_auth_token: bool = field(
101
+ default=False,
102
+ metadata={
103
+ "help": (
104
+ "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
105
+ "with private models)."
106
+ )
107
+ },
108
+ )
109
+ freeze_feature_encoder: bool = field(
110
+ default=True, metadata={"help": "Whether to freeze the feature encoder layers of the model."}
111
+ )
112
+ freeze_encoder: bool = field(
113
+ default=False, metadata={"help": "Whether to freeze the entire encoder of the seq2seq model."}
114
+ )
115
+ forced_decoder_ids: List[List[int]] = field(
116
+ default=None,
117
+ metadata={
118
+ "help": (
119
+ "A list of pairs of integers which indicates a mapping from generation indices to token indices "
120
+ "that will be forced before sampling. For example, [[0, 123]] means the first generated token "
121
+ "will always be a token of index 123."
122
+ )
123
+ },
124
+ )
125
+ suppress_tokens: List[int] = field(
126
+ default=None, metadata={"help": "A list of tokens that will be suppressed at generation."}
127
+ )
128
+ model_index_name: str = field(default=None, metadata={"help": "Pretty name for the model card."})
129
+
130
+
131
+ @dataclass
132
+ class DataTrainingArguments:
133
+ """
134
+ Arguments pertaining to what data we are going to input our model for training and eval.
135
+ """
136
+
137
+ dataset_name: str = field(
138
+ default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
139
+ )
140
+ dataset_config_name: Optional[str] = field(
141
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
142
+ )
143
+ text_column: Optional[str] = field(
144
+ default=None,
145
+ metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."},
146
+ )
147
+ max_train_samples: Optional[int] = field(
148
+ default=None,
149
+ metadata={
150
+ "help": (
151
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
152
+ "value if set."
153
+ )
154
+ },
155
+ )
156
+ max_eval_samples: Optional[int] = field(
157
+ default=None,
158
+ metadata={
159
+ "help": (
160
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
161
+ "value if set."
162
+ )
163
+ },
164
+ )
165
+ audio_column_name: str = field(
166
+ default="audio",
167
+ metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"},
168
+ )
169
+ text_column_name: str = field(
170
+ default="text",
171
+ metadata={"help": "The name of the dataset column containing the text data. Defaults to 'text'"},
172
+ )
173
+ max_duration_in_seconds: float = field(
174
+ default=20.0,
175
+ metadata={
176
+ "help": (
177
+ "Truncate audio files that are longer than `max_duration_in_seconds` seconds to"
178
+ " 'max_duration_in_seconds`"
179
+ )
180
+ },
181
+ )
182
+ min_duration_in_seconds: float = field(
183
+ default=0.0, metadata={"help": "Filter audio files that are shorter than `min_duration_in_seconds` seconds"}
184
+ )
185
+ train_split_name: str = field(
186
+ default="train",
187
+ metadata={
188
+ "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
189
+ },
190
+ )
191
+ eval_split_name: str = field(
192
+ default="test",
193
+ metadata={
194
+ "help": "The name of the evaluation data set split to use (via the datasets library). Defaults to 'test'"
195
+ },
196
+ )
197
+ do_lower_case: bool = field(
198
+ default=False,
199
+ metadata={"help": "Whether the target text should be lower cased."},
200
+ )
201
+ do_remove_punctuation: bool = field(
202
+ default=False,
203
+ metadata={"help": "Whether the target text should be stripped of punctuation."},
204
+ )
205
+ do_normalize_eval: bool = field(
206
+ default=True,
207
+ metadata={"help": "Whether to normalise the references and predictions in the eval WER calculation."},
208
+ )
209
+ language: str = field(
210
+ default=None,
211
+ metadata={
212
+ "help": (
213
+ "Language for multilingual fine-tuning. This argument should be set for multilingual fine-tuning "
214
+ "only. For English speech recognition, it should be set to `None`."
215
+ )
216
+ },
217
+ )
218
+ task: str = field(
219
+ default="transcribe",
220
+ metadata={"help": "Task, either `transcribe` for speech recognition or `translate` for speech translation."},
221
+ )
222
+ shuffle_buffer_size: Optional[int] = field(
223
+ default=500,
224
+ metadata={
225
+ "help": (
226
+ "The number of streamed examples to download before shuffling them. The large the buffer, "
227
+ "the closer it is to real offline shuffling."
228
+ )
229
+ },
230
+ )
231
+ streaming: bool = field(
232
+ default=True,
233
+ metadata={"help": "Whether to use streaming mode to load and pre-process the data."},
234
+ )
235
+
236
+
237
+ @dataclass
238
+ class DataCollatorSpeechSeq2SeqWithPadding:
239
+ """
240
+ Data collator that will dynamically pad the inputs received.
241
+ Args:
242
+ processor ([`WhisperProcessor`])
243
+ The processor used for processing the data.
244
+ decoder_start_token_id (`int`)
245
+ The begin-of-sentence token id of the decoder.
246
+ """
247
+
248
+ processor: Any
249
+ decoder_start_token_id: int
250
+
251
+ def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
252
+ # split inputs and labels since they have to be of different lengths and need
253
+ # different padding methods
254
+ model_input_name = self.processor.model_input_names[0]
255
+ input_features = [{model_input_name: feature[model_input_name]} for feature in features]
256
+ label_features = [{"input_ids": feature["labels"]} for feature in features]
257
+
258
+ batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
259
+
260
+ labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
261
+
262
+ # replace padding with -100 to ignore loss correctly
263
+ labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
264
+
265
+ # if bos token is appended in previous tokenization step,
266
+ # cut bos token here as it's appended later anyway
267
+ if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
268
+ labels = labels[:, 1:]
269
+
270
+ batch["labels"] = labels
271
+
272
+ return batch
273
+
274
+
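To make the collator's role concrete, here is a minimal sketch of how it is instantiated and applied to a couple of pre-processed examples; `processor` and `model` are assumed to be loaded as in `main()` below, and `example_a`/`example_b` are stand-ins for outputs of `prepare_dataset`:

# Sketch: building a padded batch from two pre-processed examples.
data_collator = DataCollatorSpeechSeq2SeqWithPadding(
    processor=processor,
    decoder_start_token_id=model.config.decoder_start_token_id,
)
features = [
    {"input_features": example_a["input_features"], "labels": example_a["labels"]},
    {"input_features": example_b["input_features"], "labels": example_b["labels"]},
]
batch = data_collator(features)
print(batch["input_features"].shape)  # (2, 80, 3000) log-Mel features, padded to 30 s by the extractor
print(batch["labels"].shape)          # (2, max_label_len) with padding positions set to -100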
275
+ def load_maybe_streaming_dataset(dataset_name, dataset_config_name, split="train", streaming=True, **kwargs):
276
+ """
277
+ Utility function to load a dataset in streaming mode. For datasets with multiple splits,
278
+ each split is loaded individually and then splits combined by taking alternating examples from
279
+ each (interleaving).
280
+ """
281
+ if "+" in split:
282
+ # load multiple splits separated by the `+` symbol with streaming mode
283
+ dataset_splits = [
284
+ load_dataset(dataset_name, dataset_config_name, split=split_name, streaming=streaming, **kwargs)
285
+ for split_name in split.split("+")
286
+ ]
287
+ # interleave multiple splits to form one dataset
288
+ interleaved_dataset = interleave_datasets(dataset_splits)
289
+ return interleaved_dataset
290
+ else:
291
+ # load a single split *with* streaming mode
292
+ dataset = load_dataset(dataset_name, dataset_config_name, split=split, streaming=streaming, **kwargs)
293
+ return dataset
294
+
295
+
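For example, a `+`-separated split string merges the train and validation splits of a Common Voice config into a single streamed dataset; the dataset and config names below are purely illustrative:

# Sketch: interleave two splits of Common Voice 11 (Japanese) in streaming mode.
train_data = load_maybe_streaming_dataset(
    "mozilla-foundation/common_voice_11_0",
    "ja",
    split="train+validation",
    streaming=True,
    use_auth_token=True,
)
print(next(iter(train_data)).keys())  # one streamed example with columns such as 'audio' and 'sentence'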
296
+ def main():
297
+ # 1. Parse input arguments
298
+ # See all possible arguments in src/transformers/training_args.py
299
+ # or by passing the --help flag to this script.
300
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
301
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
302
+
303
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
304
+ # If we pass only one argument to the script and it's the path to a json file,
305
+ # let's parse it to get our arguments.
306
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
307
+ else:
308
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
309
+
310
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
311
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
312
+ send_example_telemetry("run_speech_recognition_seq2seq_streaming", model_args, data_args)
313
+
314
+ # 2. Setup logging
315
+ logging.basicConfig(
316
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
317
+ datefmt="%m/%d/%Y %H:%M:%S",
318
+ handlers=[logging.StreamHandler(sys.stdout)],
319
+ )
320
+ log_level = training_args.get_process_log_level()
321
+ logger.setLevel(log_level)
322
+ datasets.utils.logging.set_verbosity(log_level)
323
+ transformers.utils.logging.set_verbosity(log_level)
324
+ transformers.utils.logging.enable_default_handler()
325
+ transformers.utils.logging.enable_explicit_format()
326
+
327
+ logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
328
+
329
+ # Log on each process the small summary:
330
+ logger.warning(
331
+ f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
332
+ f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
333
+ )
334
+ logger.info(f"Training/evaluation parameters {training_args}")
335
+
336
+ # Set the verbosity to info of the Transformers logger (on main process only):
337
+ if is_main_process(training_args.local_rank):
338
+ transformers.utils.logging.set_verbosity_info()
339
+ logger.info("Training/evaluation parameters %s", training_args)
340
+
341
+ # 3. Detecting last checkpoint and eventually continue from last checkpoint
342
+ last_checkpoint = None
343
+ if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
344
+ last_checkpoint = get_last_checkpoint(training_args.output_dir)
345
+ if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
346
+ raise ValueError(
347
+ f"Output directory ({training_args.output_dir}) already exists and is not empty. "
348
+ "Use --overwrite_output_dir to overcome."
349
+ )
350
+ elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
351
+ logger.info(
352
+ f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
353
+ "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
354
+ )
355
+
356
+ # Set seed before initializing model.
357
+ set_seed(training_args.seed)
358
+
359
+ # 4. Load dataset
360
+ raw_datasets = IterableDatasetDict() if data_args.streaming else DatasetDict()
361
+
362
+ if training_args.do_train:
363
+ raw_datasets["train"] = load_maybe_streaming_dataset(
364
+ data_args.dataset_name,
365
+ data_args.dataset_config_name,
366
+ split=data_args.train_split_name,
367
+ use_auth_token=True if model_args.use_auth_token else None,
368
+ streaming=data_args.streaming,
369
+ )
370
+
371
+ if training_args.do_eval:
372
+ raw_datasets["eval"] = load_maybe_streaming_dataset(
373
+ data_args.dataset_name,
374
+ data_args.dataset_config_name,
375
+ split=data_args.eval_split_name,
376
+ use_auth_token=True if model_args.use_auth_token else None,
377
+ streaming=data_args.streaming,
378
+ )
379
+
380
+ raw_datasets_features = list(next(iter(raw_datasets.values())).features.keys())
381
+
382
+ if data_args.audio_column_name not in raw_datasets_features:
383
+ raise ValueError(
384
+ f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'. "
385
+ "Make sure to set `--audio_column_name` to the correct audio column - one of "
386
+ f"{', '.join(raw_datasets_features)}."
387
+ )
388
+
389
+ if data_args.text_column_name not in raw_datasets_features:
390
+ raise ValueError(
391
+ f"--text_column_name {data_args.text_column_name} not found in dataset '{data_args.dataset_name}'. "
392
+ "Make sure to set `--text_column_name` to the correct text column - one of "
393
+ f"{', '.join(raw_datasets_features)}."
394
+ )
395
+
396
+ # 5. Load pretrained model, tokenizer, and feature extractor
397
+ #
398
+ # Distributed training:
399
+ # The .from_pretrained methods guarantee that only one local process can concurrently download model & vocab.
400
+ config = AutoConfig.from_pretrained(
401
+ model_args.config_name if model_args.config_name else model_args.model_name_or_path,
402
+ cache_dir=model_args.cache_dir,
403
+ revision=model_args.model_revision,
404
+ use_auth_token=True if model_args.use_auth_token else None,
405
+ )
406
+
407
+ config.update({"forced_decoder_ids": model_args.forced_decoder_ids, "suppress_tokens": model_args.suppress_tokens})
408
+
409
+ if training_args.gradient_checkpointing:
410
+ config.update({"use_cache": False})
411
+
412
+ feature_extractor = AutoFeatureExtractor.from_pretrained(
413
+ model_args.feature_extractor_name if model_args.feature_extractor_name else model_args.model_name_or_path,
414
+ cache_dir=model_args.cache_dir,
415
+ revision=model_args.model_revision,
416
+ use_auth_token=True if model_args.use_auth_token else None,
417
+ )
418
+ tokenizer = AutoTokenizer.from_pretrained(
419
+ model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
420
+ cache_dir=model_args.cache_dir,
421
+ use_fast=model_args.use_fast_tokenizer,
422
+ revision=model_args.model_revision,
423
+ use_auth_token=True if model_args.use_auth_token else None,
424
+ )
425
+ model = AutoModelForSpeechSeq2Seq.from_pretrained(
426
+ model_args.model_name_or_path,
427
+ config=config,
428
+ cache_dir=model_args.cache_dir,
429
+ revision=model_args.model_revision,
430
+ use_auth_token=True if model_args.use_auth_token else None,
431
+ )
432
+
433
+ if model.config.decoder_start_token_id is None:
434
+ raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
435
+
436
+ if model_args.freeze_feature_encoder:
437
+ model.freeze_feature_encoder()
438
+
439
+ if model_args.freeze_encoder:
440
+ model.freeze_encoder()
441
+
442
+ if data_args.language is not None:
443
+ # We only need to set the task id when the language is specified (i.e. in a multilingual setting)
444
+ tokenizer.set_prefix_tokens(language=data_args.language, task=data_args.task)
445
+
446
+ # 6. Resample speech dataset if necessary
447
+ dataset_sampling_rate = next(iter(raw_datasets.values())).features[data_args.audio_column_name].sampling_rate
448
+ if dataset_sampling_rate != feature_extractor.sampling_rate:
449
+ raw_datasets = raw_datasets.cast_column(
450
+ data_args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
451
+ )
452
+
453
+ # 7. Preprocessing the datasets.
454
+ # We need to read the audio files as arrays and tokenize the targets.
455
+ max_input_length = data_args.max_duration_in_seconds * feature_extractor.sampling_rate
456
+ min_input_length = data_args.min_duration_in_seconds * feature_extractor.sampling_rate
457
+ audio_column_name = data_args.audio_column_name
458
+ text_column_name = data_args.text_column_name
459
+ model_input_name = feature_extractor.model_input_names[0]
460
+ do_lower_case = data_args.do_lower_case
461
+ do_remove_punctuation = data_args.do_remove_punctuation
462
+ normalizer = BasicTextNormalizer() # 'official' text normalizer from OpenAI
463
+
464
+ if data_args.max_train_samples is not None:
465
+ raw_datasets["train"] = (
466
+ raw_datasets["train"].take(data_args.max_train_samples)
467
+ if data_args.streaming
468
+ else raw_datasets["train"].select(range(data_args.max_train_samples))
469
+ )
470
+
471
+ if data_args.max_eval_samples is not None:
472
+ raw_datasets["eval"] = (
473
+ raw_datasets["eval"].take(data_args.max_eval_samples)
474
+ if data_args.streaming
475
+ else raw_datasets["eval"].select(range(data_args.max_eval_samples))
476
+ )
477
+
478
+ tagger = Tagger('-Owakati')
479
+ FULL2HALF = dict((i + 0xFEE0, i) for i in range(0x21, 0x7F))
480
+ FULL2HALF[0x3000] = 0x20
481
+
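Two Japanese-specific helpers are set up here: `Tagger('-Owakati')` (fugashi/MeCab) splits a sentence into whitespace-separated words, and `FULL2HALF` is a `str.translate` table mapping full-width ASCII (U+FF01-U+FF5E) and the ideographic space (U+3000) to their half-width equivalents. A small sketch of their effect, assuming fugashi and a MeCab dictionary such as unidic-lite are installed:

# Sketch: normalising one string with the helpers defined above.
text = "ＡＩで音声認識をする"        # contains full-width "ＡＩ"
text = text.translate(FULL2HALF)     # -> "AIで音声認識をする"
text = tagger.parse(text)            # -> roughly "AI で 音声 認識 を する" (wakati segmentation)
print(text)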
482
+ def prepare_dataset(batch):
483
+ # process audio
484
+ sample = batch[audio_column_name]
485
+ inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])
486
+ # process audio length
487
+ batch[model_input_name] = inputs.get(model_input_name)[0]
488
+ batch["input_length"] = len(sample["array"])
489
+
490
+ # process targets
491
+ input_str = batch[text_column_name].lower() if do_lower_case else batch[text_column_name]
492
+ if do_remove_punctuation:
493
+ input_str = normalizer(input_str).strip()
494
+
495
+ input_str = input_str.translate(FULL2HALF)
496
+ input_str = tagger.parse(input_str)
497
+
498
+ batch["labels"] = tokenizer(input_str).input_ids
499
+ return batch
500
+
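Applied to a single streamed example, `prepare_dataset` yields the log-Mel features, the raw sample count, and the tokenized (segmented) transcription; a sketch, assuming `raw_datasets["train"]` is the streaming split loaded above:

# Sketch: inspecting one processed example.
example = prepare_dataset(next(iter(raw_datasets["train"])))
print(len(example["input_features"]), len(example["input_features"][0]))  # 80 mel bins x 3000 frames
print(example["input_length"])  # length of the clip in raw audio samples
print(example["labels"][:5])    # first few label token ids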
501
+ with training_args.main_process_first(desc="dataset map pre-processing"):
502
+ vectorized_datasets = raw_datasets.map(
503
+ prepare_dataset,
504
+ remove_columns=raw_datasets_features,
505
+ ).with_format("torch")
506
+
507
+ if training_args.do_train and data_args.streaming:
508
+ # manually shuffle if streaming (done by the trainer for non-streaming)
509
+ vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(
510
+ buffer_size=data_args.shuffle_buffer_size,
511
+ seed=training_args.seed,
512
+ )
513
+
514
+ # filter training data that is shorter than min_input_length or longer than
515
+ # max_input_length
516
+ def is_audio_in_length_range(length):
517
+ return min_input_length < length < max_input_length
518
+
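Note that the bounds compared here are sample counts rather than seconds, since `input_length` stores `len(sample["array"])`; with Whisper's 16 kHz feature extractor and the default durations above, the arithmetic works out as follows:

# Sketch: duration limits expressed as sample counts (assuming the defaults above).
assert 20.0 * 16_000 == 320_000           # max_duration_in_seconds -> max_input_length
print(is_audio_in_length_range(319_999))  # True: just under the 20 s cap
print(is_audio_in_length_range(320_001))  # False: filtered out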
519
+ if training_args.do_train:
520
+ vectorized_datasets["train"] = vectorized_datasets["train"].filter(
521
+ is_audio_in_length_range,
522
+ input_columns=["input_length"],
523
+ )
524
+
525
+ # 8. Load Metric
526
+ metric = evaluate.load("wer")
527
+ do_normalize_eval = data_args.do_normalize_eval
528
+
529
+ def compute_metrics(pred):
530
+ pred_ids = pred.predictions
531
+
532
+ pred.label_ids[pred.label_ids == -100] = tokenizer.pad_token_id
533
+
534
+ pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
535
+ # we do not want to group tokens when computing the metrics
536
+ label_str = tokenizer.batch_decode(pred.label_ids, skip_special_tokens=True)
537
+
538
+ if do_normalize_eval:
539
+ pred_str = [normalizer(pred) for pred in pred_str]
540
+ label_str = [normalizer(label) for label in label_str]
541
+ # filtering step to only evaluate the samples that correspond to non-zero references:
542
+ pred_str = [pred_str[i] for i in range(len(pred_str)) if len(label_str[i]) > 0]
543
+ label_str = [label_str[i] for i in range(len(label_str)) if len(label_str[i]) > 0]
544
+
545
+ wer = 100 * metric.compute(predictions=pred_str, references=label_str)
546
+
547
+ return {"wer": wer}
548
+
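The WER metric from `evaluate` returns a fraction, hence the `100 *` scaling above; a quick worked example with made-up strings:

# Sketch: one deletion against a three-word reference -> WER of 1/3.
example_wer = 100 * metric.compute(
    predictions=["hello world"],
    references=["hello there world"],
)
print(round(example_wer, 2))  # ~33.33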
549
+ # 9. Create a single speech processor
550
+ if is_main_process(training_args.local_rank):
551
+ # save feature extractor, tokenizer and config
552
+ feature_extractor.save_pretrained(training_args.output_dir)
553
+ tokenizer.save_pretrained(training_args.output_dir)
554
+ config.save_pretrained(training_args.output_dir)
555
+
556
+ processor = AutoProcessor.from_pretrained(training_args.output_dir)
557
+
558
+ # 10. Define data collator
559
+ data_collator = DataCollatorSpeechSeq2SeqWithPadding(
560
+ processor=processor,
561
+ decoder_start_token_id=model.config.decoder_start_token_id,
562
+ )
563
+
564
+ # 11. Configure Trainer
565
+ # Trainer callback to reinitialise and reshuffle the streamable datasets at the beginning of each epoch
566
+ # Only required for streaming: Trainer automatically shuffles non-streaming datasets
567
+ class ShuffleCallback(TrainerCallback):
568
+ def on_epoch_begin(self, args, state, control, train_dataloader, **kwargs):
569
+ if isinstance(train_dataloader.dataset, IterableDatasetShard):
570
+ pass # set_epoch() is handled by the Trainer
571
+ elif isinstance(train_dataloader.dataset, IterableDataset):
572
+ train_dataloader.dataset.set_epoch(train_dataloader.dataset._epoch + 1)
573
+
574
+ # Initialize Trainer
575
+ trainer = Seq2SeqTrainer(
576
+ model=model,
577
+ args=training_args,
578
+ train_dataset=vectorized_datasets["train"] if training_args.do_train else None,
579
+ eval_dataset=vectorized_datasets["eval"] if training_args.do_eval else None,
580
+ tokenizer=feature_extractor,
581
+ data_collator=data_collator,
582
+ compute_metrics=compute_metrics if training_args.predict_with_generate else None,
583
+ callbacks=[ShuffleCallback()] if data_args.streaming else None,
584
+ )
585
+
586
+ # 12. Training
587
+ if training_args.do_train:
588
+ checkpoint = None
589
+ if training_args.resume_from_checkpoint is not None:
590
+ checkpoint = training_args.resume_from_checkpoint
591
+ elif last_checkpoint is not None:
592
+ checkpoint = last_checkpoint
593
+ train_result = trainer.train(resume_from_checkpoint=checkpoint)
594
+ trainer.save_model() # Saves the feature extractor too for easy upload
595
+
596
+ metrics = train_result.metrics
597
+ if data_args.max_train_samples:
598
+ metrics["train_samples"] = data_args.max_train_samples
599
+ trainer.log_metrics("train", metrics)
600
+ trainer.save_metrics("train", metrics)
601
+ trainer.save_state()
602
+
603
+ # 13. Evaluation
604
+ results = {}
605
+ if training_args.do_eval:
606
+ logger.info("*** Evaluate ***")
607
+ metrics = trainer.evaluate(
608
+ metric_key_prefix="eval",
609
+ max_length=training_args.generation_max_length,
610
+ num_beams=training_args.generation_num_beams,
611
+ )
612
+ if data_args.max_eval_samples:
613
+ metrics["eval_samples"] = data_args.max_eval_samples
614
+
615
+ trainer.log_metrics("eval", metrics)
616
+ trainer.save_metrics("eval", metrics)
617
+
618
+ # 14. Write Training Stats
619
+ kwargs = {
620
+ "finetuned_from": model_args.model_name_or_path,
621
+ "tasks": "automatic-speech-recognition",
622
+ "tags": "whisper-event",
623
+ }
624
+ if data_args.dataset_name is not None:
625
+ kwargs["dataset_tags"] = data_args.dataset_name
626
+ if data_args.dataset_config_name is not None:
627
+ kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}"
628
+ else:
629
+ kwargs["dataset"] = data_args.dataset_name
630
+ if "common_voice" in data_args.dataset_name:
631
+ kwargs["language"] = data_args.dataset_config_name[:2]
632
+ if model_args.model_index_name is not None:
633
+ kwargs["model_name"] = model_args.model_index_name
634
+
635
+ if training_args.push_to_hub:
636
+ trainer.push_to_hub(**kwargs)
637
+ else:
638
+ trainer.create_model_card(**kwargs)
639
+
640
+ return results
641
+
642
+
643
+ if __name__ == "__main__":
644
+ main()
runs/Dec12_18-34-55_129-213-131-105/1670870162.6479564/events.out.tfevents.1670870162.129-213-131-105.68826.1 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ece26b007dd4241195f82f895387be76fc7cc1c070b15d6f83ae2cc25b41f8c6
3
+ size 5865
runs/Dec12_18-34-55_129-213-131-105/events.out.tfevents.1670870162.129-213-131-105.68826.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:704f9f0cca7b13ef1f58bb28efd54ba803cdba17d54dd80adcf64079cc38ed7f
3
+ size 5370
runs/Dec12_19-04-31_129-213-131-105/1670871909.2493246/events.out.tfevents.1670871909.129-213-131-105.451160.1 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a4b59cfc6e1d7df33976f6ce86c76510a7f41fead860dec2c69fe7cd2bf5e167
3
+ size 5865
runs/Dec12_19-04-31_129-213-131-105/events.out.tfevents.1670871909.129-213-131-105.451160.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b3452ff74759aee1840801aed3d7356082411fd624c88f84ddd9ceb9bae889f6
3
+ size 5527
runs/Dec12_20-09-15_129-213-131-105/1670875765.5760763/events.out.tfevents.1670875765.129-213-131-105.451160.3 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3b008ba42cd993b9bd02a9144eff9cf9407ecbbeb529d55c4dc9a1ec94b44df3
3
+ size 5865
runs/Dec12_20-09-15_129-213-131-105/events.out.tfevents.1670875765.129-213-131-105.451160.2 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fbdab32f753d53633663898b271753c8b3b5ea9e96877ae405591009e23d7a94
3
+ size 4287
runs/Dec12_20-11-02_129-213-131-105/1670875868.8091414/events.out.tfevents.1670875868.129-213-131-105.451160.5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:23ebc2f17552904c895f3c671b4094d006a214efda5d17991beffa79c33e8d6d
3
+ size 5865
runs/Dec12_20-11-02_129-213-131-105/events.out.tfevents.1670875868.129-213-131-105.451160.4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:54ef5fa9c89efaa03ab70ca5b5048b390e7834395f47878d50d4b6c88c5963a5
3
+ size 4287
runs/Dec12_20-13-20_129-213-131-105/1670876009.054387/events.out.tfevents.1670876009.129-213-131-105.983201.1 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a9d78faf345393965670df0c13ba215d9d2f3388530185c1388dda5a8fa750dc
3
+ size 5865
runs/Dec12_20-13-20_129-213-131-105/events.out.tfevents.1670876009.129-213-131-105.983201.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b488dc85e8c956dad5b3e36797db99372014a9dc6d5b6a04273fa276526b714b
3
+ size 4903
runs/Dec12_21-41-07_129-213-131-105/1670881275.6468236/events.out.tfevents.1670881275.129-213-131-105.1284650.1 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bfec8cffe4dc0c9cab5530627cb8e095f271d05e9e92a21b424b32ae423e8f49
3
+ size 5871
runs/Dec12_21-41-07_129-213-131-105/events.out.tfevents.1670881275.129-213-131-105.1284650.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7a93dc6d2dfd1ac5560c91f552c89c96897229798b978c8df009d41553c206e9
3
+ size 4266
runs/Dec12_21-43-12_129-213-131-105/1670881400.3312242/events.out.tfevents.1670881400.129-213-131-105.1319036.1 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:50a8322f4fb98acacd534d77bca4396a48dc39c5c52ff7aca9cf34b75fe67bf8
3
+ size 5871
runs/Dec12_21-43-12_129-213-131-105/events.out.tfevents.1670881400.129-213-131-105.1319036.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b878099c11902d0b4199ee621f3a5da722a287c0f380503b446684af2e9dc233
3
+ size 4265
runs/Dec12_21-47-11_129-213-131-105/1670881639.3589363/events.out.tfevents.1670881639.129-213-131-105.1364959.1 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:aeab408d92a5d13a0ba12c052c4f63d7fd4000ddd0547a93a98e06e0ded2b8c9
3
+ size 5871
runs/Dec12_21-47-11_129-213-131-105/events.out.tfevents.1670881639.129-213-131-105.1364959.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:32bc02edeb3a6602e541887d81fa5dcd7ffa46c2e2fb356d90f01b73977763eb
3
+ size 4577
runs/Dec12_21-54-54_129-213-131-105/1670882102.7244208/events.out.tfevents.1670882102.129-213-131-105.1405782.1 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:abcce00ed71b81568f3e4404a3337617a043299eeab15de4aed07163ec387333
3
+ size 5871
runs/Dec12_21-54-54_129-213-131-105/events.out.tfevents.1670882102.129-213-131-105.1405782.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e9553563c509a988e7f213ee237136fa3c489c62602abaaa1cacafe7558aff9b
3
+ size 5825
training_args.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:728d6cd7b154a86029fc38c737217977eb35dd910ed073d6628129742d876d7e
3
  size 3579
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2f9ea5f3e1c3a7983069a437425a4e13e3bbae038786ac8163dfd02cc6f10148
3
  size 3579