sanchit-gandhi committed
Commit 87bc511
1 Parent(s): 8211932

17l2s3v2: saving weights and logs of step 10k
.gitattributes CHANGED
@@ -30,3 +30,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
30
  *.zip filter=lfs diff=lfs merge=lfs -text
31
  *.zst filter=lfs diff=lfs merge=lfs -text
32
  *tfevents* filter=lfs diff=lfs merge=lfs -text
33
+ *.wandb filter=lfs diff=lfs merge=lfs -text
config.json ADDED
@@ -0,0 +1,291 @@
1
+ {
2
+ "_name_or_path": "sanchit-gandhi/flax-wav2vec2-2-bart-large-scan",
3
+ "architectures": [
4
+ "SpeechEncoderDecoderModel"
5
+ ],
6
+ "decoder": {
7
+ "_name_or_path": "",
8
+ "activation_dropout": 0.2,
9
+ "activation_function": "gelu",
10
+ "add_bias_logits": false,
11
+ "add_cross_attention": true,
12
+ "add_final_layer_norm": false,
13
+ "architectures": [
14
+ "BartModel"
15
+ ],
16
+ "attention_dropout": 0.1,
17
+ "bad_words_ids": null,
18
+ "bos_token_id": 0,
19
+ "chunk_size_feed_forward": 0,
20
+ "classif_dropout": 0.1,
21
+ "classifier_dropout": 0.0,
22
+ "cross_attention_hidden_size": null,
23
+ "d_model": 1024,
24
+ "decoder_attention_heads": 16,
25
+ "decoder_ffn_dim": 4096,
26
+ "decoder_layerdrop": 0.0,
27
+ "decoder_layers": 12,
28
+ "decoder_start_token_id": 2,
29
+ "diversity_penalty": 0.0,
30
+ "do_sample": false,
31
+ "dropout": 0.2,
32
+ "early_stopping": true,
33
+ "encoder_attention_heads": 16,
34
+ "encoder_ffn_dim": 4096,
35
+ "encoder_layerdrop": 0.0,
36
+ "encoder_layers": 12,
37
+ "encoder_no_repeat_ngram_size": 0,
38
+ "eos_token_id": 2,
39
+ "exponential_decay_length_penalty": null,
40
+ "finetuning_task": null,
41
+ "forced_bos_token_id": 0,
42
+ "forced_eos_token_id": 2,
43
+ "fuse_matmuls": false,
44
+ "gradient_checkpointing": true,
45
+ "id2label": {
46
+ "0": "LABEL_0",
47
+ "1": "LABEL_1",
48
+ "2": "LABEL_2"
49
+ },
50
+ "init_std": 0.02,
51
+ "is_decoder": true,
52
+ "is_encoder_decoder": false,
53
+ "label2id": {
54
+ "LABEL_0": 0,
55
+ "LABEL_1": 1,
56
+ "LABEL_2": 2
57
+ },
58
+ "length_penalty": 1.0,
59
+ "max_length": 20,
60
+ "max_position_embeddings": 1024,
61
+ "min_length": 0,
62
+ "model_type": "bart",
63
+ "no_repeat_ngram_size": 3,
64
+ "normalize_before": false,
65
+ "num_beam_groups": 1,
66
+ "num_beams": 4,
67
+ "num_hidden_layers": 12,
68
+ "num_return_sequences": 1,
69
+ "output_attentions": false,
70
+ "output_hidden_states": false,
71
+ "output_scores": false,
72
+ "pad_token_id": 1,
73
+ "prefix": null,
74
+ "problem_type": null,
75
+ "pruned_heads": {},
76
+ "remove_invalid_values": false,
77
+ "repetition_penalty": 1.0,
78
+ "return_dict": true,
79
+ "return_dict_in_generate": false,
80
+ "scale_embedding": false,
81
+ "sep_token_id": null,
82
+ "task_specific_params": {
83
+ "summarization": {
84
+ "length_penalty": 1.0,
85
+ "max_length": 128,
86
+ "min_length": 12,
87
+ "num_beams": 4
88
+ },
89
+ "summarization_cnn": {
90
+ "length_penalty": 2.0,
91
+ "max_length": 142,
92
+ "min_length": 56,
93
+ "num_beams": 4
94
+ },
95
+ "summarization_xsum": {
96
+ "length_penalty": 1.0,
97
+ "max_length": 62,
98
+ "min_length": 11,
99
+ "num_beams": 6
100
+ }
101
+ },
102
+ "temperature": 1.0,
103
+ "tf_legacy_loss": false,
104
+ "tie_encoder_decoder": false,
105
+ "tie_word_embeddings": true,
106
+ "tokenizer_class": null,
107
+ "top_k": 50,
108
+ "top_p": 1.0,
109
+ "torch_dtype": "float32",
110
+ "torchscript": false,
111
+ "transformers_version": "4.21.0.dev0",
112
+ "typical_p": 1.0,
113
+ "use_bfloat16": false,
114
+ "use_cache": true,
115
+ "use_scan": true,
116
+ "vocab_size": 50265
117
+ },
118
+ "decoder_start_token_id": 0,
119
+ "encoder": {
120
+ "_name_or_path": "",
121
+ "activation_dropout": 0.2,
122
+ "adapter_kernel_size": 3,
123
+ "adapter_stride": 2,
124
+ "add_adapter": true,
125
+ "add_cross_attention": false,
126
+ "apply_spec_augment": true,
127
+ "architectures": [
128
+ "Wav2Vec2ForPreTraining"
129
+ ],
130
+ "attention_dropout": 0.1,
131
+ "bad_words_ids": null,
132
+ "bos_token_id": 1,
133
+ "chunk_size_feed_forward": 0,
134
+ "classifier_proj_size": 256,
135
+ "codevector_dim": 768,
136
+ "contrastive_logits_temperature": 0.1,
137
+ "conv_bias": true,
138
+ "conv_dim": [
139
+ 512,
140
+ 512,
141
+ 512,
142
+ 512,
143
+ 512,
144
+ 512,
145
+ 512
146
+ ],
147
+ "conv_kernel": [
148
+ 10,
149
+ 3,
150
+ 3,
151
+ 3,
152
+ 3,
153
+ 2,
154
+ 2
155
+ ],
156
+ "conv_stride": [
157
+ 5,
158
+ 2,
159
+ 2,
160
+ 2,
161
+ 2,
162
+ 2,
163
+ 2
164
+ ],
165
+ "cross_attention_hidden_size": null,
166
+ "ctc_loss_reduction": "sum",
167
+ "ctc_zero_infinity": false,
168
+ "decoder_start_token_id": null,
169
+ "diversity_loss_weight": 0.1,
170
+ "diversity_penalty": 0.0,
171
+ "do_sample": false,
172
+ "do_stable_layer_norm": true,
173
+ "early_stopping": false,
174
+ "encoder_no_repeat_ngram_size": 0,
175
+ "eos_token_id": 2,
176
+ "exponential_decay_length_penalty": null,
177
+ "feat_extract_activation": "gelu",
178
+ "feat_extract_dropout": 0.0,
179
+ "feat_extract_norm": "layer",
180
+ "feat_proj_dropout": 0.2,
181
+ "feat_quantizer_dropout": 0.0,
182
+ "final_dropout": 0.0,
183
+ "finetuning_task": null,
184
+ "forced_bos_token_id": null,
185
+ "forced_eos_token_id": null,
186
+ "fuse_matmuls": false,
187
+ "gradient_checkpointing": true,
188
+ "hidden_act": "gelu",
189
+ "hidden_dropout": 0.2,
190
+ "hidden_dropout_prob": 0.1,
191
+ "hidden_size": 1024,
192
+ "id2label": {
193
+ "0": "LABEL_0",
194
+ "1": "LABEL_1"
195
+ },
196
+ "initializer_range": 0.02,
197
+ "intermediate_size": 4096,
198
+ "is_decoder": false,
199
+ "is_encoder_decoder": false,
200
+ "label2id": {
201
+ "LABEL_0": 0,
202
+ "LABEL_1": 1
203
+ },
204
+ "layer_norm_eps": 1e-05,
205
+ "layerdrop": 0.0,
206
+ "length_penalty": 1.0,
207
+ "mask_feature_length": 10,
208
+ "mask_feature_min_masks": 0,
209
+ "mask_feature_prob": 0.0,
210
+ "mask_time_length": 10,
211
+ "mask_time_min_masks": 2,
212
+ "mask_time_prob": 0.1,
213
+ "max_length": 20,
214
+ "min_length": 0,
215
+ "model_type": "wav2vec2",
216
+ "no_repeat_ngram_size": 0,
217
+ "num_adapter_layers": 3,
218
+ "num_attention_heads": 16,
219
+ "num_beam_groups": 1,
220
+ "num_beams": 1,
221
+ "num_codevector_groups": 2,
222
+ "num_codevectors_per_group": 320,
223
+ "num_conv_pos_embedding_groups": 16,
224
+ "num_conv_pos_embeddings": 128,
225
+ "num_feat_extract_layers": 7,
226
+ "num_hidden_layers": 24,
227
+ "num_negatives": 100,
228
+ "num_return_sequences": 1,
229
+ "output_attentions": false,
230
+ "output_hidden_size": 1024,
231
+ "output_hidden_states": false,
232
+ "output_scores": false,
233
+ "pad_token_id": 0,
234
+ "prefix": null,
235
+ "problem_type": null,
236
+ "proj_codevector_dim": 768,
237
+ "pruned_heads": {},
238
+ "remove_invalid_values": false,
239
+ "repetition_penalty": 1.0,
240
+ "return_dict": true,
241
+ "return_dict_in_generate": false,
242
+ "sep_token_id": null,
243
+ "task_specific_params": null,
244
+ "tdnn_dilation": [
245
+ 1,
246
+ 2,
247
+ 3,
248
+ 1,
249
+ 1
250
+ ],
251
+ "tdnn_dim": [
252
+ 512,
253
+ 512,
254
+ 512,
255
+ 512,
256
+ 1500
257
+ ],
258
+ "tdnn_kernel": [
259
+ 5,
260
+ 3,
261
+ 3,
262
+ 1,
263
+ 1
264
+ ],
265
+ "temperature": 1.0,
266
+ "tf_legacy_loss": false,
267
+ "tie_encoder_decoder": false,
268
+ "tie_word_embeddings": true,
269
+ "tokenizer_class": null,
270
+ "top_k": 50,
271
+ "top_p": 1.0,
272
+ "torch_dtype": null,
273
+ "torchscript": false,
274
+ "transformers_version": "4.21.0.dev0",
275
+ "typical_p": 1.0,
276
+ "use_bfloat16": false,
277
+ "use_scan": true,
278
+ "use_weighted_layer_sum": false,
279
+ "vocab_size": 32,
280
+ "xvector_output_dim": 512
281
+ },
282
+ "eos_token_id": 2,
283
+ "is_encoder_decoder": true,
284
+ "max_length": 40,
285
+ "model_type": "speech-encoder-decoder",
286
+ "pad_token_id": 1,
287
+ "processor_class": "Wav2Vec2Processor",
288
+ "tie_word_embeddings": false,
289
+ "transformers_version": null,
290
+ "use_cache": false
291
+ }
flax_model.msgpack ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:fb2940f024c5e43c87e730444eadbefe3090b57f6f0db63d22deae94fe2101f0
+ size 2353616717
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
preprocessor_config.json ADDED
@@ -0,0 +1,9 @@
+ {
+ "do_normalize": true,
+ "feature_extractor_type": "Wav2Vec2FeatureExtractor",
+ "feature_size": 1,
+ "padding_side": "right",
+ "padding_value": 0.0,
+ "return_attention_mask": true,
+ "sampling_rate": 16000
+ }
run_earnings22.sh ADDED
@@ -0,0 +1,40 @@
+ #!/usr/bin/env bash
+ python run_flax_speech_recognition_seq2seq.py \
+ --dataset_name="sanchit-gandhi/earnings22" \
+ --model_name_or_path="sanchit-gandhi/flax-wav2vec2-2-bart-large-scan" \
+ --dataset_config_name="all" \
+ --train_split_name="train" \
+ --eval_split_name="validation" \
+ --test_split_name="test" \
+ --text_column_name="sentence" \
+ --id_column_name="id" \
+ --output_dir="./flax-wav2vec2-2-bart-large-earnings22-black-box" \
+ --wandb_project="earnings22" \
+ --wandb_name="flax-wav2vec2-2-bart-large-earnings22-black-box" \
+ --dataset_cache_dir="/home/sanchitgandhi/cache/huggingface/datasets" \
+ --per_device_train_batch_size="8" \
+ --per_device_eval_batch_size="4" \
+ --logging_steps="25" \
+ --max_steps="50000" \
+ --eval_steps="10000" \
+ --save_steps="10000" \
+ --generation_max_length="40" \
+ --generation_num_beams="1" \
+ --generation_length_penalty="1.2" \
+ --final_generation_max_length="200" \
+ --final_generation_num_beams="5" \
+ --learning_rate="1e-4" \
+ --warmup_steps="500" \
+ --do_lower_case="False" \
+ --hidden_dropout="0.2" \
+ --activation_dropout="0.2" \
+ --feat_proj_dropout="0.2" \
+ --overwrite_output_dir \
+ --gradient_checkpointing \
+ --freeze_feature_encoder \
+ --predict_with_generate \
+ --do_eval \
+ --do_train \
+ --do_predict \
+ --push_to_hub \
+ --use_auth_token
run_flax_speech_recognition_seq2seq.py ADDED
@@ -0,0 +1,1646 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2022 The HuggingFace Team All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Fine-tuning the Flax library models for sequence to sequence speech recognition.
18
+ """
19
+ # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
20
+
21
+ import logging
22
+ import math
23
+ import os
24
+ import re
25
+ import sys
26
+ import time
27
+ from dataclasses import dataclass, field
28
+ from pathlib import Path
29
+ from typing import Any, Callable, Dict, List, Optional, Union
30
+
31
+ import datasets
32
+ import numpy as np
33
+ from datasets import DatasetDict, load_dataset, load_metric
34
+ from tqdm import tqdm
35
+
36
+ import flax
37
+ import jax
38
+ import jax.numpy as jnp
39
+ import optax
40
+ import transformers
41
+ import wandb as wandb
42
+ from flax import core, jax_utils, struct, traverse_util
43
+ from flax.jax_utils import pad_shard_unpad, unreplicate
44
+ from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
45
+ from huggingface_hub import Repository
46
+ from models import FlaxSpeechEncoderDecoderModel
47
+ from optax._src import linear_algebra
48
+ from transformers import (
49
+ AutoConfig,
50
+ AutoFeatureExtractor,
51
+ AutoProcessor,
52
+ AutoTokenizer,
53
+ HfArgumentParser,
54
+ Seq2SeqTrainingArguments,
55
+ is_tensorboard_available,
56
+ )
57
+ from transformers.file_utils import get_full_repo_name
58
+ from transformers.trainer_utils import get_last_checkpoint
59
+ from transformers.utils import check_min_version
60
+ from transformers.utils.versions import require_version
61
+
62
+
63
+ # Will error if the minimal version of Transformers is not installed. Remove at your own risk.
64
+ check_min_version("4.17.0.dev0")
65
+
66
+ require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")
67
+
68
+ logger = logging.getLogger(__name__)
69
+
70
+
71
+ @flax.struct.dataclass
72
+ class ModelArguments:
73
+ """
74
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
75
+ """
76
+
77
+ model_name_or_path: str = field(
78
+ metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
79
+ )
80
+ config_name: Optional[str] = field(
81
+ default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
82
+ )
83
+ tokenizer_name: Optional[str] = field(
84
+ default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
85
+ )
86
+ feature_extractor_name: Optional[str] = field(
87
+ default=None, metadata={"help": "feature extractor name or path if not the same as model_name"}
88
+ )
89
+ cache_dir: Optional[str] = field(
90
+ default=None,
91
+ metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
92
+ )
93
+ use_fast_tokenizer: bool = field(
94
+ default=True,
95
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
96
+ )
97
+ model_revision: str = field(
98
+ default="main",
99
+ metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
100
+ )
101
+ use_auth_token: bool = field(
102
+ default=False,
103
+ metadata={
104
+ "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
105
+ "with private models)."
106
+ },
107
+ )
108
+ freeze_feature_encoder: bool = field(
109
+ default=True, metadata={"help": "Whether to freeze the feature encoder layers of the model."}
110
+ )
111
+ activation_dropout: float = field(
112
+ default=0.1,
113
+ metadata={
114
+ "help": "The hidden activation dropout probability in the embeddings, encoder, and pooler."
115
+ },
116
+ )
117
+ hidden_dropout: float = field(
118
+ default=0.1,
119
+ metadata={
120
+ "help": "The dropout probability for all fully connected layers in the embeddings, encoder, and pooler."
121
+ },
122
+ )
123
+ feat_proj_dropout: float = field(
124
+ default=0.0,
125
+ metadata={
126
+ "help": "The feat proj dropout probability for feature encoder representations."
127
+ },
128
+ )
129
+ mask_time_prob: float = field(
130
+ default=0.1,
131
+ metadata={
132
+ "help": "The spec aug dropout probability for feature encoder representations."
133
+ },
134
+ )
135
+ encoder_add_adapter: bool = field(
136
+ default=True, metadata={"help": "Whether to add an adapter layer between the encoder and decoder."}
137
+ )
138
+
139
+
140
+ @flax.struct.dataclass
141
+ class DataTrainingArguments:
142
+ """
143
+ Arguments pertaining to what data we are going to input our model for training and eval.
144
+ """
145
+
146
+ dataset_name: str = field(
147
+ default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
148
+ )
149
+ dataset_config_name: Optional[str] = field(
150
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
151
+ )
152
+ text_column: Optional[str] = field(
153
+ default=None,
154
+ metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."},
155
+ )
156
+ dataset_cache_dir: Optional[str] = field(
157
+ default=None, metadata={"help": "Path to cache directory for saving and loading datasets"}
158
+ )
159
+ overwrite_cache: bool = field(
160
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
161
+ )
162
+ preprocessing_num_workers: Optional[int] = field(
163
+ default=None,
164
+ metadata={"help": "The number of processes to use for the preprocessing."},
165
+ )
166
+ max_train_samples: Optional[int] = field(
167
+ default=None,
168
+ metadata={
169
+ "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
170
+ "value if set."
171
+ },
172
+ )
173
+ max_eval_samples: Optional[int] = field(
174
+ default=None,
175
+ metadata={
176
+ "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
177
+ "value if set."
178
+ },
179
+ )
180
+ max_test_samples: Optional[int] = field(
181
+ default=None,
182
+ metadata={
183
+ "help": "For debugging purposes or quicker training, truncate the number of test examples to this "
184
+ "value if set."
185
+ },
186
+ )
187
+ audio_column_name: str = field(
188
+ default="audio",
189
+ metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"},
190
+ )
191
+ text_column_name: str = field(
192
+ default="text",
193
+ metadata={"help": "The name of the dataset column containing the text data. Defaults to 'text'"},
194
+ )
195
+ id_column_name: str = field(
196
+ default="id",
197
+ metadata={"help": "The name of the dataset column containing the id data. Defaults to 'id'"},
198
+ )
199
+ max_duration_in_seconds: float = field(
200
+ default=20.0,
201
+ metadata={
202
+ "help": "Filter audio files in the training set that are longer than `max_duration_in_seconds` seconds"
203
+ },
204
+ )
205
+ min_duration_in_seconds: float = field(
206
+ default=0.0, metadata={"help": "Filter audio files in the training set that are shorter than `min_duration_in_seconds` seconds"}
207
+ )
208
+ max_eval_duration_in_seconds: float = field(
209
+ default=None,
210
+ metadata={
211
+ "help": "Filter audio files in the eval/test set that are longer than `max_duration_in_seconds` seconds"
212
+ },
213
+ )
214
+ max_target_length: Optional[int] = field(
215
+ default=128,
216
+ metadata={
217
+ "help": "The maximum total sequence length for target text after tokenization. Sequences longer "
218
+ "than this will be truncated, sequences shorter will be padded."
219
+ },
220
+ )
221
+ min_target_length: Optional[int] = field(
222
+ default=0,
223
+ metadata={
224
+ "help": "The minimum total sequence length for target text after tokenization. Sequences shorter "
225
+ "than this will be filtered."
226
+ },
227
+ )
228
+ pad_input_to_multiple_of: Optional[int] = field(
229
+ default=24000,
230
+ metadata={
231
+ "help": "If set will pad the input sequence to a multiple of the provided value. "
232
+ "This is important to avoid triggering recompilations on TPU."
233
+ },
234
+ )
235
+ pad_target_to_multiple_of: Optional[int] = field(
236
+ default=None,
237
+ metadata={
238
+ "help": "If set will pad the target sequence to a multiple of the provided value. "
239
+ "This is important to avoid triggering recompilations on TPU. If unspecified, will default to `max_target_length`, "
240
+ " the equivalent of padding the targets to max length."
241
+ },
242
+ )
243
+ preprocessing_only: bool = field(
244
+ default=False,
245
+ metadata={
246
+ "help": "Whether to only do data preprocessing and skip training. "
247
+ "This is especially useful when data preprocessing errors out in distributed training due to timeout. "
248
+ "In this case, one should run the preprocessing in a non-distributed setup with `preprocessing_only=True` "
249
+ "so that the cached datasets can consequently be loaded in distributed training"
250
+ },
251
+ )
252
+ train_split_name: str = field(
253
+ default="train",
254
+ metadata={
255
+ "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
256
+ },
257
+ )
258
+ eval_split_name: str = field(
259
+ default="validation",
260
+ metadata={
261
+ "help": "The name of the evaluation data set split to use (via the datasets library). Defaults to 'validation'"
262
+ },
263
+ )
264
+ test_split_name: str = field(
265
+ default="test",
266
+ metadata={"help": "The name of the test data set split to use (via the datasets library). Defaults to 'test'"},
267
+ )
268
+ do_lower_case: bool = field(
269
+ default=True,
270
+ metadata={"help": "Whether the target text should be lower cased."},
271
+ )
272
+ wandb_project: str = field(
273
+ default="flax-speech-recognition-seq2seq",
274
+ metadata={"help": "The name of the wandb project."},
275
+ )
276
+ wandb_name: str = field(
277
+ default=None,
278
+ metadata={"help": "The name of the wandb run."},
279
+ )
280
+ wandb_job_type: str = field(
281
+ default="Seq2Seq",
282
+ metadata={"help": "The name of the wandb job type."},
283
+ )
284
+ log_first_ids: bool = field(
285
+ default=True,
286
+ metadata={
287
+ "help": "Whether to log the first id's from the dataset. Defaults to `True`. If `False`, will log the first id's returned by the grouped length sampler."
288
+ },
289
+ )
290
+
291
+
292
+ # @flax.struct.dataclass
293
+ @dataclass
294
+ class FlaxSeq2SeqTrainingArguments(Seq2SeqTrainingArguments):
295
+ precision: str = field(
296
+ default="full",
297
+ metadata={
298
+ "help": "Whether to enable mixed-precision training. If true, the optimizer is stored in half-precision (bfloat16) and computations are executed in half-precision"
299
+ "**Note that this only specifies the dtype of the computation and optimizer state. It does not influence the dtype of model parameters.**"
300
+ },
301
+ )
302
+ matmul_precision: str = field(
303
+ default="default",
304
+ metadata={
305
+ "help": "Default floating-point precision of internal computations used in TPU matrix multiplications and convolutions. "
306
+ "This configuration option controls the default precision for JAX operations that take an optional precision argument (e.g. `lax.conv_general_dilated` and `lax.dot`). "
307
+ "This configuration option does not change the behaviours of such calls with explicit precision arguments; "
308
+ "it only changes the behaviors of calls with no such argument provided. "
309
+ "One of `['highest', 'float32', 'high', 'bfloat16_3x', 'default', 'bfloat16', 'fastest', None]`."
310
+ },
311
+ )
312
+ generation_length_penalty: float = field(
313
+ default=1,
314
+ metadata={
315
+ "help": "Exponential penalty to the length. 1.0 (default) means no penalty. Set to values < 1.0 in order to encourage the model"
316
+ "to generate shorter sequences, to a value > 1.0 in order to encourage the model to produce longer sequences."
317
+ },
318
+ )
319
+ final_generation_max_length: int = field(
320
+ default=None,
321
+ metadata={
322
+ "help": "The `max_length` to use on each evaluation loop when `predict_with_generate=True`. If unspecified, will default "
323
+ "to the `max_length` value of the model configuration."
324
+ },
325
+ )
326
+ final_generation_num_beams: int = field(
327
+ default=None,
328
+ metadata={
329
+ "help": "The `num_beams` to use on each evaluation loop when `predict_with_generate=True`. If unspecified, will default "
330
+ "to the `num_beams` value of the model configuration."
331
+ },
332
+ )
333
+
334
+ def __post_init__(self):
335
+ if self.final_generation_max_length is None:
336
+ self.final_generation_max_length = self.generation_max_length
337
+ if self.final_generation_num_beams is None:
338
+ self.final_generation_num_beams = self.generation_num_beams
339
+
340
+
341
+ def to_fp32(t):
342
+ return jax.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t)
343
+
344
+
345
+ def to_bf16(t):
346
+ return jax.tree_map(lambda x: x.astype(jnp.bfloat16) if x.dtype == jnp.float32 else x, t)
347
+
348
+
349
+ class MixedPrecisionTrainState(struct.PyTreeNode):
350
+ """Train state for use with a single Optax optimizer.
351
+ Adapted from flax train_state https://github.com/google/flax/blob/main/flax/training/train_state.py
352
+
353
+ Synopsis::
354
+
355
+ state = TrainState.create(
356
+ apply_fn=model.apply,
357
+ params=variables['params'],
358
+ tx=tx)
359
+ grad_fn = jax.grad(make_loss_fn(state.apply_fn))
360
+ for batch in data:
361
+ grads = grad_fn(state.params, batch)
362
+ state = state.apply_gradients(grads=grads)
363
+
364
+ Args:
365
+ step: Counter starts at 0 and is incremented by every call to
366
+ `.apply_gradients()`.
367
+ apply_fn: Usually set to `model.apply()`. Kept in this dataclass for
368
+ convenience to have a shorter params list for the `train_step()` function
369
+ in your training loop.
370
+ params: The parameters to be updated by `tx` and used by `apply_fn`.
371
+ tx: An Optax gradient transformation.
372
+ opt_state: The state for `tx`.
373
+ dropout_rng: PRNG key for stochastic operations.
374
+ bf16: Whether to use bf16 16-bit (mixed) precision training instead of 32-bit training.
375
+ """
376
+
377
+ step: int
378
+ apply_fn: Callable = struct.field(pytree_node=False)
379
+ params: core.FrozenDict[str, Any]
380
+ tx: optax.GradientTransformation = struct.field(pytree_node=False)
381
+ opt_state: optax.OptState
382
+ dropout_rng: jnp.ndarray
383
+ max_grad_norm: Optional[float] = 1.0
384
+
385
+ def apply_gradients(self, *, grads, to_dtype, **kwargs):
386
+ """Updates `step`, `params`, `opt_state` and `**kwargs` in return value.
387
+
388
+ Note that internally this function calls `.tx.update()` followed by a call
389
+ to `optax.apply_updates()` to update `params` and `opt_state`.
390
+
391
+ Args:
392
+ grads: Gradients that have the same pytree structure as `.params`.
393
+ **kwargs: Additional dataclass attributes that should be `.replace()`-ed.
394
+
395
+ Returns:
396
+ An updated instance of `self` with `step` incremented by one, `params`
397
+ and `opt_state` updated by applying `grads`, and additional attributes
398
+ replaced as specified by `kwargs`.
399
+ """
400
+
401
+ # clip gradients by global l2 norm
402
+ casted_max_grad_norm = to_dtype(self.max_grad_norm)
403
+ g_norm = linear_algebra.global_norm(grads)
404
+ g_norm = jnp.maximum(casted_max_grad_norm, g_norm)
405
+ grads = jax.tree_map(lambda t: (t / g_norm) * casted_max_grad_norm, grads)
406
+
407
+ # perform update step in fp32 and subsequently downcast optimizer states if mixed precision training
408
+ # grads and opt_state in bf16 (need to upcast), params in fp32 (leave as is)
409
+ updates, new_opt_state = self.tx.update(to_fp32(grads), to_fp32(self.opt_state), self.params)
410
+
411
+ new_params = optax.apply_updates(self.params, updates)
412
+ return self.replace(
413
+ step=self.step + 1,
414
+ params=new_params,
415
+ opt_state=to_dtype(new_opt_state),
416
+ **kwargs,
417
+ )
418
+
419
+ @classmethod
420
+ def create(cls, *, apply_fn, params, tx, to_dtype, **kwargs):
421
+ """Creates a new instance with `step=0` and initialized `opt_state`."""
422
+ # downcast optimizer state to bf16 if mixed-precision training
423
+ opt_state = tx.init(to_dtype(params)) if tx is not None else None
424
+ return cls(
425
+ step=0,
426
+ apply_fn=apply_fn,
427
+ params=params,
428
+ tx=tx,
429
+ opt_state=opt_state,
430
+ **kwargs,
431
+ )
432
+
433
+ def replicate(self):
434
+ return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
435
+
436
+
437
+ def pad_to_max_length(data, tokenizer):
438
+ # Get lengths of each row of data
439
+ lens = np.array([len(i) for i in data])
440
+
441
+ # Mask of valid places in each row
442
+ mask = np.arange(lens.max()) < lens[:, None]
443
+
444
+ # Setup output array and put elements from data into masked positions
445
+ out = np.ones_like(mask, dtype=data.dtype) * tokenizer.pad_token_id
446
+ out[mask] = np.concatenate(data)
447
+ return out
448
+
449
+
450
+ def shift_tokens_right(label_ids: np.array, decoder_start_token_id: int) -> np.ndarray:
451
+ """
452
+ Shift label ids one token to the right.
453
+ """
454
+ shifted_label_ids = np.zeros_like(label_ids)
455
+ shifted_label_ids[:, 1:] = label_ids[:, :-1]
456
+ shifted_label_ids[:, 0] = decoder_start_token_id
457
+
458
+ return shifted_label_ids
459
+
460
+
461
+ @flax.struct.dataclass
462
+ class FlaxDataCollatorSpeechSeq2SeqWithPadding:
463
+ """
464
+ Data collator that will dynamically pad the inputs received.
465
+ Args:
466
+ processor ([`Wav2Vec2Processor`])
467
+ The processor used for processing the data.
468
+ decoder_start_token_id (:obj: `int`)
469
+ The begin-of-sentence of the decoder.
470
+ input_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
471
+ Select a strategy to pad the returned input sequences (according to the model's padding side and padding index)
472
+ among:
473
+ * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
474
+ sequence is provided).
475
+ * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
476
+ maximum acceptable input length for the model if that argument is not provided.
477
+ * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
478
+ different lengths).
479
+ target_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
480
+ Select a strategy to pad the returned target sequences (according to the model's padding side and padding index).
481
+ See above for details.
482
+ max_input_length (:obj:`float`, `optional`):
483
+ Maximum length of the ``input_values`` of the returned list and optionally padding length (see above).
484
+ max_target_length (:obj:`int`, `optional`):
485
+ Maximum length of the ``labels`` of the returned list and optionally padding length (see above).
486
+ pad_input_to_multiple_of (:obj:`int`, `optional`):
487
+ If set will pad the input sequence to a multiple of the provided value.
488
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
489
+ 7.5 (Volta).
490
+ pad_target_to_multiple_of (:obj:`int`, `optional`):
491
+ If set will pad the target sequence to a multiple of the provided value.
492
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
493
+ 7.5 (Volta).
494
+ """
495
+
496
+ processor: Any
497
+ decoder_start_token_id: int
498
+ input_padding: Union[bool, str] = "longest"
499
+ target_padding: Union[bool, str] = "max_length"
500
+ max_input_length: Optional[float] = None
501
+ max_target_length: Optional[int] = None
502
+ pad_input_to_multiple_of: Optional[int] = None
503
+ pad_target_to_multiple_of: Optional[int] = None
504
+
505
+ def __call__(self, features: List[Dict[str, Union[List[int], np.ndarray]]]) -> Dict[str, np.ndarray]:
506
+ # split inputs and labels since they have to be of different lengths and need
507
+ # different padding methods
508
+ input_features = [{"input_values": feature["input_values"]} for feature in features]
509
+ input_ids = [feature["input_id"] for feature in features]
510
+ label_features = [{"input_ids": feature["labels"]} for feature in features]
511
+
512
+ # reformat list to dict and set to numpy format
513
+ batch = self.processor.feature_extractor.pad(
514
+ input_features,
515
+ max_length=self.max_input_length,
516
+ padding=self.input_padding,
517
+ pad_to_multiple_of=self.pad_input_to_multiple_of,
518
+ return_tensors="np",
519
+ )
520
+
521
+ labels_batch = self.processor.tokenizer.pad(
522
+ label_features,
523
+ max_length=self.max_target_length,
524
+ padding=self.target_padding,
525
+ pad_to_multiple_of=self.pad_target_to_multiple_of,
526
+ return_tensors="np",
527
+ )
528
+
529
+ # if bos token is appended in previous tokenization step,
530
+ # cut bos token here as it's appended later anyway
531
+ labels = labels_batch["input_ids"]
532
+ if (labels[:, 0] == self.decoder_start_token_id).all().item():
533
+ labels = labels[:, 1:]
534
+ labels_batch.attention_mask = labels_batch.attention_mask[:, 1:]
535
+
536
+ decoder_input_ids = shift_tokens_right(labels, self.decoder_start_token_id)
537
+
538
+ # replace padding with -100 to ignore correctly when computing the loss
539
+ labels = np.ma.array(labels, mask=np.not_equal(labels_batch.attention_mask, 1))
540
+ labels = labels.filled(fill_value=-100)
541
+
542
+ batch["inputs"] = batch.pop("input_values")
543
+ batch["input_ids"] = input_ids
544
+ batch["labels"] = labels
545
+ batch["decoder_input_ids"] = decoder_input_ids
546
+
547
+ return batch
548
+
549
+
550
+ def get_grouped_indices(
551
+ dataset, batch_size: int, rng: Optional[List[int]] = None, mega_batch_mult: Optional[int] = None
552
+ ) -> np.array:
553
+ """
554
+ Adapted from the `get_length_grouped_indices` function in the PyTorch Trainer utils file (https://github.com/huggingface/transformers/blob/main/src/transformers/trainer_pt_utils.py#L486)
555
+ Function that returns a list of indices in which each slice of `batch_size` consecutive indices correspond to elements of similar
556
+ lengths. To do this, the indices are:
557
+
558
+ - randomly permuted (if a JAX rng is specified)
559
+ - grouped in mega-batches of size `mega_batch_mult * batch_size`
560
+ - sorted by length in each mega-batch
561
+
562
+ The result is the concatenation of all mega-batches, with the batch of `batch_size` containing the element of
563
+ maximum length placed first, so that an OOM happens sooner rather than later.
564
+ """
565
+ lengths = dataset["input_length"]
566
+
567
+ # Default for mega_batch_mult: 50 or the number to get 4 megabatches, whichever is smaller.
568
+ if mega_batch_mult is None:
569
+ mega_batch_mult = min(len(lengths) // (batch_size * 4), 50)
570
+ # Just in case, for tiny datasets
571
+ if mega_batch_mult == 0:
572
+ mega_batch_mult = 1
573
+
574
+ # We need to use JAX for the random permutation as the PRNG key will be set based on the seed outside of the sampler.
575
+ num_samples = len(lengths)
576
+ indices = jax.random.permutation(rng, np.arange(num_samples)) if rng is not None else np.arange(num_samples)
577
+ indices = np.asarray(indices)
578
+
579
+ megabatch_size = mega_batch_mult * batch_size
580
+ megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
581
+ megabatches = [list(sorted(megabatch, key=lambda i: lengths[i], reverse=True)) for megabatch in megabatches]
582
+
583
+ # The rest is to get the biggest batch first.
584
+ # Since each megabatch is sorted by descending length, the longest element is the first
585
+ megabatch_maximums = [lengths[megabatch[0]] for megabatch in megabatches]
586
+ max_idx = np.argmax(megabatch_maximums).item()
587
+ # Switch to put the longest batch in first position
588
+ # (note that this is different to the PT grouped sampler in which we only put the longest element in the first position, and not its batch)
589
+ megabatches[0], megabatches[max_idx] = megabatches[max_idx], megabatches[0]
590
+
591
+ megabatches = np.array([i for megabatch in megabatches for i in megabatch])
592
+
593
+ return megabatches
594
+
595
+
596
+ def generate_batch_splits(samples_idx: np.ndarray, batch_size: int, drop_last_batch=True) -> np.ndarray:
597
+ """Generate batches of data for a specified batch size from sample indices. If the dataset size is not divisible by
598
+ the batch size and `drop_last` is `True`, the last incomplete batch is dropped. Else, it is returned."""
599
+ num_samples = len(samples_idx)
600
+ if drop_last_batch:
601
+ samples_to_remove = num_samples % batch_size
602
+ if samples_to_remove != 0:
603
+ samples_idx = samples_idx[:-samples_to_remove]
604
+ sections_split = num_samples // batch_size
605
+ samples_idx = samples_idx.reshape((sections_split, batch_size))
606
+ else:
607
+ sections_split = math.ceil(num_samples / batch_size)
608
+ samples_idx = np.array_split(samples_idx, sections_split)
609
+ return samples_idx
610
+
611
+
612
+ def write_train_metric(summary_writer, train_metrics, train_time, step):
613
+ summary_writer.scalar("train_time", train_time, step)
614
+
615
+ train_metrics = get_metrics(train_metrics)
616
+ for key, vals in train_metrics.items():
617
+ tag = f"train_{key}"
618
+ for i, val in enumerate(vals):
619
+ summary_writer.scalar(tag, val, step - len(vals) + i + 1)
620
+
621
+
622
+ def write_eval_metric(summary_writer, eval_metrics, step, pred_str=None):
623
+ for metric_name, value in eval_metrics.items():
624
+ summary_writer.scalar(f"eval_{metric_name}", value, step)
625
+
626
+ if pred_str is not None:
627
+ # write out actual predictions for debugging
628
+ summary_writer.text("eval_predictions", "\n".join(pred_str), step)
629
+
630
+
631
+ def write_wandb_log(metrics, step, prefix=None):
632
+ if jax.process_index() == 0:
633
+ log_metrics = {}
634
+ for k, v in metrics.items():
635
+ if "layer" in k:
636
+ log_metrics[f"{k}/"] = v
637
+ elif prefix is not None:
638
+ log_metrics[f"{prefix}/{k}"] = v
639
+ else:
640
+ log_metrics[k] = v
641
+ wandb.log(log_metrics, step)
642
+
643
+
644
+ def write_wandb_pred(pred_str, label_str, eval_ids, step, prefix="eval", top_ids=None, final_step=True):
645
+ if jax.process_index() == 0:
646
+ top_ids = top_ids if top_ids else eval_ids
647
+ num_beams = len(pred_str)
648
+ # convert str data to a wandb compatible format
649
+ str_data = []
650
+ for id in top_ids:
651
+ if id in eval_ids:
652
+ idx = eval_ids.index(id)
653
+ str_data.append([eval_ids[idx], label_str[idx]] + [pred_str[beam][idx] for beam in range(num_beams)])
654
+ columns = ["id", "label_str"] + [f"beam_{i + 1}" for i in range(num_beams)]
655
+ wandb.log(
656
+ {f"{prefix}/step_{int(step / 1000)}k": wandb.Table(columns=columns, data=str_data[:50])},
657
+ step,
658
+ )
659
+ if final_step:
660
+ str_data = np.array(str_data)
661
+ wandb.log(
662
+ {f"{prefix}/step_{int(step / 1000)}k_all": wandb.Table(columns=columns, data=str_data[:200000])},
663
+ step,
664
+ )
665
+ str_data = str_data[str_data[:, 1] != str_data[:, 2]]
666
+ wandb.log(
667
+ {f"{prefix}/step_{int(step / 1000)}k_incorrect": wandb.Table(columns=columns, data=str_data[:200000])},
668
+ step,
669
+ )
670
+
671
+
672
+ def create_learning_rate_fn(
673
+ num_train_steps: int, num_warmup_steps: int, learning_rate: float
674
+ ) -> Callable[[int], jnp.array]:
675
+ """Returns a linear warmup, linear_decay learning rate function."""
676
+ warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
677
+ decay_fn = optax.linear_schedule(
678
+ init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
679
+ )
680
+ schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
681
+ return schedule_fn
682
+
683
+
684
+ def main():
685
+ # 1. Parse input arguments
686
+ # See all possible arguments in src/transformers/training_args.py
687
+ # or by passing the --help flag to this script.
688
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
689
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, FlaxSeq2SeqTrainingArguments))
690
+
691
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
692
+ # If we pass only one argument to the script and it's the path to a json file,
693
+ # let's parse it to get our arguments.
694
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
695
+ else:
696
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
697
+
698
+ # 2. Setup logging
699
+ # Make one log on every process with the configuration for debugging.
700
+ logging.basicConfig(
701
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
702
+ datefmt="%m/%d/%Y %H:%M:%S",
703
+ handlers=[logging.StreamHandler(sys.stdout)],
704
+ )
705
+ # Set the verbosity to info of the Transformers logger.
706
+ # We only want one process per machine to log things on the screen.
707
+ logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
708
+ if jax.process_index() == 0:
709
+ datasets.utils.logging.set_verbosity_warning()
710
+ transformers.utils.logging.set_verbosity_info()
711
+ else:
712
+ datasets.utils.logging.set_verbosity_error()
713
+ transformers.utils.logging.set_verbosity_error()
714
+
715
+ # Set up wandb run
716
+ if jax.process_index() == 0:
717
+ wandb.init(project=data_args.wandb_project, name=data_args.wandb_name, job_type=data_args.wandb_job_type)
718
+
719
+ logger.info("Training/evaluation parameters %s", training_args)
720
+
721
+ # Set the default TPU matmul precision and display the number of devices
722
+ jax.config.update("jax_default_matmul_precision", training_args.matmul_precision)
723
+ logger.info(f"JAX devices: {jax.device_count()}, matmul precision: {training_args.matmul_precision}")
724
+
725
+ # TODO: 3. Detecting last checkpoint and eventually continue from last checkpoint
726
+ last_checkpoint = None
727
+ if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
728
+ last_checkpoint = get_last_checkpoint(training_args.output_dir)
729
+ if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
730
+ raise ValueError(
731
+ f"Output directory ({training_args.output_dir}) already exists and is not empty. "
732
+ "Use --overwrite_output_dir to overcome."
733
+ )
734
+ elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
735
+ logger.info(
736
+ f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
737
+ "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
738
+ )
739
+
740
+ # 4. Load dataset
741
+ raw_datasets = DatasetDict()
742
+
743
+ if training_args.do_train:
744
+ raw_datasets["train"] = load_dataset(
745
+ data_args.dataset_name,
746
+ data_args.dataset_config_name,
747
+ split=data_args.train_split_name,
748
+ cache_dir=data_args.dataset_cache_dir,
749
+ use_auth_token=True if model_args.use_auth_token else None,
750
+ )
751
+
752
+ if training_args.do_eval:
753
+ raw_datasets["eval"] = load_dataset(
754
+ data_args.dataset_name,
755
+ data_args.dataset_config_name,
756
+ split=data_args.eval_split_name,
757
+ cache_dir=data_args.dataset_cache_dir,
758
+ use_auth_token=True if model_args.use_auth_token else None,
759
+ )
760
+
761
+ if training_args.do_predict:
762
+ test_split = data_args.test_split_name.split("+")
763
+ for split in test_split:
764
+ raw_datasets[split] = load_dataset(
765
+ data_args.dataset_name,
766
+ data_args.dataset_config_name,
767
+ split=split,
768
+ cache_dir=data_args.dataset_cache_dir,
769
+ use_auth_token=True if model_args.use_auth_token else None,
770
+ )
771
+
772
+ if not training_args.do_train and not training_args.do_eval and not training_args.do_predict:
773
+ raise ValueError(
774
+ "Cannot not train, not do evaluation and not do prediction. At least one of "
775
+ "training, evaluation or prediction has to be done."
776
+ )
777
+
778
+ # if not training, there is no need to run multiple epochs
779
+ if not training_args.do_train:
780
+ training_args.num_train_epochs = 1
781
+
782
+ if data_args.audio_column_name not in next(iter(raw_datasets.values())).column_names:
783
+ raise ValueError(
784
+ f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'. "
785
+ "Make sure to set `--audio_column_name` to the correct audio column - one of "
786
+ f"{', '.join(next(iter(raw_datasets.values())).column_names)}."
787
+ )
788
+
789
+ if data_args.text_column_name not in next(iter(raw_datasets.values())).column_names:
790
+ raise ValueError(
791
+ f"--text_column_name {data_args.text_column_name} not found in dataset '{data_args.dataset_name}'. "
792
+ "Make sure to set `--text_column_name` to the correct text column - one of "
793
+ f"{', '.join(next(iter(raw_datasets.values())).column_names)}."
794
+ )
795
+
796
+ if data_args.log_first_ids and data_args.id_column_name not in next(iter(raw_datasets.values())).column_names:
797
+ raise ValueError(
798
+ f"--id_column_name {data_args.id_column_name} not found in dataset '{data_args.dataset_name}'. "
799
+ "Make sure to set `--id_column_name` to the correct id column - one of "
800
+ f"{', '.join(next(iter(raw_datasets.values())).column_names)}."
801
+ )
802
+
803
+ # 5. Load pretrained model, tokenizer, and feature extractor
804
+ #
805
+ # Distributed training:
806
+ # The .from_pretrained methods guarantee that only one local process can concurrently
807
+ config = AutoConfig.from_pretrained(
808
+ model_args.config_name if model_args.config_name else model_args.model_name_or_path,
809
+ cache_dir=model_args.cache_dir,
810
+ revision=model_args.model_revision,
811
+ use_auth_token=True if model_args.use_auth_token else None,
812
+ )
813
+
814
+ # update config according to training and model args
815
+ config.encoder.update(
816
+ {
817
+ "gradient_checkpointing": training_args.gradient_checkpointing,
818
+ "hidden_dropout": model_args.hidden_dropout,
819
+ "activation_dropout": model_args.activation_dropout,
820
+ "feat_proj_dropout": model_args.feat_proj_dropout,
821
+ "mask_time_prob": model_args.mask_time_prob,
822
+ "add_adapter": model_args.encoder_add_adapter,
823
+ }
824
+ )
825
+ config.decoder.update(
826
+ {
827
+ "gradient_checkpointing": training_args.gradient_checkpointing,
828
+ "dropout": model_args.hidden_dropout,
829
+ "activation_dropout": model_args.activation_dropout,
830
+ }
831
+ )
832
+
833
+ feature_extractor = AutoFeatureExtractor.from_pretrained(
834
+ model_args.feature_extractor_name if model_args.feature_extractor_name else model_args.model_name_or_path,
835
+ cache_dir=model_args.cache_dir,
836
+ revision=model_args.model_revision,
837
+ use_auth_token=True if model_args.use_auth_token else None,
838
+ )
839
+ tokenizer = AutoTokenizer.from_pretrained(
840
+ model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
841
+ cache_dir=model_args.cache_dir,
842
+ use_fast=model_args.use_fast_tokenizer,
843
+ revision=model_args.model_revision,
844
+ use_auth_token=True if model_args.use_auth_token else None,
845
+ )
846
+
847
+ if training_args.precision == "full_mixed":
848
+ dtype = jnp.bfloat16
849
+ training_args.mixed_precision = True
850
+ elif training_args.precision == "half_mixed":
851
+ dtype = jnp.bfloat16
852
+ training_args.mixed_precision = False
853
+ else:
854
+ dtype = jnp.float32
855
+ training_args.mixed_precision = False
856
+
857
+ model = FlaxSpeechEncoderDecoderModel.from_pretrained(
858
+ model_args.model_name_or_path,
859
+ config=config,
860
+ dtype=dtype,
861
+ cache_dir=model_args.cache_dir,
862
+ revision=model_args.model_revision,
863
+ use_auth_token=True if model_args.use_auth_token else None,
864
+ )
865
+
866
+ if model.config.decoder_start_token_id is None:
867
+ raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
868
+
869
+ # 6. Resample speech dataset ALWAYS
870
+ raw_datasets = raw_datasets.cast_column(
871
+ data_args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
872
+ )
873
+
874
+ # 7. Preprocessing the datasets.
875
+ # We need to read the audio files as arrays and tokenize the targets.
876
+ max_input_length = int(data_args.max_duration_in_seconds * feature_extractor.sampling_rate)
877
+ min_input_length = int(data_args.min_duration_in_seconds * feature_extractor.sampling_rate)
878
+ max_eval_input_length = int(data_args.max_eval_duration_in_seconds * feature_extractor.sampling_rate) if data_args.max_eval_duration_in_seconds else None
879
+ max_target_length = data_args.max_target_length
880
+ min_target_length = data_args.min_target_length
881
+ pad_input_to_multiple_of = data_args.pad_input_to_multiple_of
882
+ pad_target_to_multiple_of = data_args.pad_target_to_multiple_of
883
+ audio_column_name = data_args.audio_column_name
884
+ num_workers = data_args.preprocessing_num_workers
885
+ text_column_name = data_args.text_column_name
886
+ id_column_name = data_args.id_column_name
887
+ model_input_name = feature_extractor.model_input_names[0]
888
+ do_lower_case = data_args.do_lower_case
889
+ log_first_ids = data_args.log_first_ids
890
+ dataset_name = data_args.dataset_name
891
+ tedlium_contractions = [" 's", " 't", " 're", " 've", " 'm", " 'll", " 'd", " 'clock", " 'all"]
892
+ gigaspeech_punctuation = {" <comma>": ",", " <period>": ".", " <questionmark>": "?", " <exclamationpoint>": "!"}
893
+ gigaspeech_disfluencies = ["<other>", "<sil>"]
894
+ swb_disfluencies = ["[noise]", "[laughter]", "[silence]", "[vocalized-noise]", "<a_aside>", "<b_aside>", "<e_aside>",
895
+ "[laughter-", "_1", "[laugh]", "[sigh]", "[cough]", "[mn]", "[breath]", "[lipsmack]",
896
+ "[sneeze]", "[skip]", "[pause]", "(%hesitation)", "(%HESITATION)"]
897
+ swb_punctuations = ["{", "}", "[", "]-", "]", "((", "))", "(", ")"]
898
+ earnings_disfluencies = ["<noise>", "<crosstalk>", "<affirmative>", "<inaudible>", "inaudible", "<laugh>"]
899
+ ignore_segments = ["ignore_time_segment_in_scoring", "<noise>", "<music>", "[noise]", "[laughter]", "[silence]",
900
+ "[vocalized-noise]", "<crosstalk>", "<affirmative>", "<inaudible>", "<laugh>", "<other>", "<sil>", ""]
901
+ ignore_segments += swb_disfluencies
902
+
903
+ if training_args.do_train and data_args.max_train_samples is not None:
904
+ raw_datasets["train"] = raw_datasets["train"].select(range(data_args.max_train_samples))
905
+
906
+ if training_args.do_eval and data_args.max_eval_samples is not None:
907
+ raw_datasets["eval"] = raw_datasets["eval"].select(range(data_args.max_eval_samples))
908
+
909
+ if training_args.do_predict and data_args.max_test_samples is not None:
910
+ for split in test_split:
911
+ raw_datasets[split] = raw_datasets[split].select(range(data_args.max_eval_samples))
912
+
913
+ # filter data where the targets are ignored in scoring
914
+ def is_target_labels(input_str):
915
+ return input_str.lower() not in ignore_segments
916
+
917
+ raw_datasets = raw_datasets.filter(
918
+ is_target_labels,
919
+ num_proc=num_workers,
920
+ input_columns=[text_column_name],
921
+ desc="filtering data where the targets are ignored in scoring",
922
+ )
923
+
924
+ def prepare_dataset(batch):
925
+ # Pre-process audio
926
+ try:
927
+ sample = batch[audio_column_name]
928
+ except ValueError:
929
+ # E22: some samples are empty (no audio). Reading the empty audio array will trigger
930
+ # a soundfile ValueError. For now, we'll manually set these arrays to a zero array.
931
+ # They will be filtered in the subsequent filtering stage and so are
932
+ # explicitly ignored during training.
933
+ sample = {"array": np.array([0.]), "sampling_rate": feature_extractor.sampling_rate}
934
+
935
+ # normalise audio (mean, std) to (0, 1)
936
+ inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])
937
+ # process audio length
938
+ batch[model_input_name] = inputs.input_values[0]
939
+ batch["input_length"] = len(batch["input_values"])
940
+ batch["input_id"] = batch[id_column_name] if log_first_ids else None
941
+
942
+ # 'Error correction' of targets
943
+ input_str = batch[text_column_name].lower() if do_lower_case else batch[text_column_name]
944
+
945
+ # LibriSpeech ASR
946
+ if dataset_name == "librispeech_asr":
947
+ pass # no error correction necessary
948
+
949
+ # VoxPopuli
950
+ if dataset_name == "google/xtreme_s":
951
+ pass # no error correction necessary
952
+
953
+ # Common Voice 9
954
+ if dataset_name == "mozilla-foundation/common_voice_9_0":
955
+ if input_str.startswith('"') and input_str.endswith('"'):
956
+ # we can remove trailing quotation marks as they do not affect the transcription
957
+ input_str = input_str[1:-1]
958
+ # replace double quotation marks with single
959
+ input_str = input_str.replace('""', '"')
960
+
961
+ # TED-LIUM (Release 3)
962
+ if dataset_name == "LIUM/tedlium":
963
+ # delete the <unk> token from the text
964
+ input_str = input_str.replace("<unk>", "")
965
+ # replace spaced apostrophes with un-spaced (it 's -> it's)
966
+ for contraction in tedlium_contractions:
967
+ input_str = input_str.replace(contraction, contraction[1:])
968
+
969
+ # GigaSpeech
970
+ if dataset_name == "speechcolab/gigaspeech":
971
+ for disfluency in gigaspeech_disfluencies:
972
+ input_str = input_str.replace(disfluency, "")
973
+ # convert spelled out punctuation to symbolic form
974
+ for punctuation, replacement in gigaspeech_punctuation.items():
975
+ input_str = input_str.replace(punctuation, replacement)
976
+
977
+ # SWB: hide the path to the private HF dataset
978
+ if "switchboard" in dataset_name:
979
+ # In one conversation people speak some German phrases that are tagged as
980
+ # <german (( ja wohl )) > -- we remove these
981
+ input_str = re.sub("<[^>]*>", "", input_str)
982
+
983
+ # Remove junk tokens
984
+ for disfluency in swb_disfluencies:
985
+ input_str = input_str.replace(disfluency, "")
986
+
987
+ # Replace partially pronounced words (square brackets + hyphen): westmin[ster]- to westmin- or -[go]ing to -ing
988
+ # Replace anomalous words (square brackets + forward slash): [lemguini/linguini] to linguini
989
+ # Replace the combo of the two: [lem[guini]-/linguini] to lem-
990
+ # Example: we [ah/are] -[go]ing to westmin[ster]- for [lem[guini]-/linguini]
991
+ # Target: we ah -ing to westmin- for lem-
992
+ # Treat anomalous words first, then remove the content of all remaining square brackets (partially pronounced words)
993
+
994
+ # First treat partially pronounced anomalous words by removing the correct word: [lem[guini]-/linguini] to [lem[guini]-
995
+ input_str = re.sub(r"\-\/.*?\]", "-", input_str)
996
+
997
+ # Now replace anomalous words with their correct transcriptions: [lemguini/linguini] to linguini
998
+ split_str = input_str.split("/")
999
+ if len(split_str) > 1:
1000
+ input_str = " ".join(
1001
+ [" ".join([" ".join(i.split(" ")[:-1]) for i in split_str])] + [split_str[-1].split(" ")[-1]])
1002
+
1003
+ # Remove leftover opening/closing brackets at the start or end of words
1004
+ processed_str = []
1005
+ for word in input_str.split():
1006
+ if word[0] == "[":
1007
+ processed_str.append(word[1:])
1008
+ elif word[-1] == "]":
1009
+ processed_str.append(word[:-1])
1010
+ else:
1011
+ processed_str.append(word)
1012
+
1013
+ # Stick the processed words back together
1014
+ input_str = " ".join(processed_str)
1015
+
1016
+ # Now we can remove all words in square brackets: -[go]ing to -ing
1017
+ input_str = re.sub(r"\-\[(.*?)\]", "-", input_str)
1018
+
1019
+ # westmin[ster]- to westmin-
1020
+ input_str = re.sub(r"\[(.*?)\]\-", "-", input_str)
1021
+
1022
+ # tech[n]ology to tech-ology
1023
+ input_str = re.sub(r"\[(.*?)\]", "-", input_str)
1024
+
1025
+ # partially pronounced words are now done!
1026
+ # remove erroneous punctuations (curly braces, trailing square brackets, etc.)
1027
+ for punctuation in swb_punctuations:
1028
+ input_str = input_str.replace(punctuation, "")
1029
+
1030
+ # Earnings 22: still figuring out the best segmentation method. Thus, the dataset name is subject to change
1031
+ if "earnings22" in dataset_name:
1032
+ # Remove the 100ms offset at the end of the sample
1033
+ sampling_rate = sample["sampling_rate"]
1034
+ offset = int(100 * (10 ** -3) * sampling_rate)
1035
+ batch["input_ids"] = sample["array"][:-offset]
1036
+ batch["input_lengths"] = len(batch["input_ids"])
1037
+ # Remove junk tokens
1038
+ for disfluency in earnings_disfluencies:
1039
+ input_str = input_str.replace(disfluency, "")
1040
+
1041
+ # SPGISpeech
1042
+ if dataset_name == "kensho/spgispeech":
1043
+ pass # no error correction necessary
1044
+
1045
+ # JIWER compliance (for WER/CER calc.)
1046
+ # remove multiple spaces
1047
+ input_str = re.sub(r"\s\s+", " ", input_str)
1048
+ # strip leading and trailing spaces
1049
+ input_str = input_str.strip()
1050
+
1051
+ # Finally, we tokenize the processed text
1052
+ batch["labels"] = tokenizer(input_str).input_ids
1053
+ batch["labels_length"] = len(batch["labels"])
1054
+ return batch
1055
+
1056
+ vectorized_datasets = raw_datasets.map(
1057
+ prepare_dataset,
1058
+ remove_columns=next(iter(raw_datasets.values())).column_names,
1059
+ num_proc=num_workers,
1060
+ desc="preprocess train dataset",
1061
+ )
1062
+
1063
+ # filter training data with inputs longer than max_input_length
1064
+ def is_audio_in_length_range(length):
1065
+ return min_input_length < length < max_input_length
1066
+
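For context, `min_input_length` and `max_input_length` here are measured in samples, not seconds. A plausible derivation (a sketch only; the real values are set earlier in the script, and the `min_duration_in_seconds` / `max_duration_in_seconds` argument names are assumptions, not confirmed by this diff):

# hypothetical sketch: convert duration limits in seconds into sample counts
min_input_length = int(data_args.min_duration_in_seconds * feature_extractor.sampling_rate)
max_input_length = int(data_args.max_duration_in_seconds * feature_extractor.sampling_rate)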
1067
+ if training_args.do_train:
1068
+ vectorized_datasets["train"] = vectorized_datasets["train"].filter(
1069
+ is_audio_in_length_range,
1070
+ num_proc=num_workers,
1071
+ input_columns=["input_length"],
1072
+ )
1073
+
1074
+ if max_eval_input_length is not None:
1075
+ # filter training data with inputs longer than max_input_length
1076
+ def is_eval_audio_in_length_range(length):
1077
+ return min_input_length < length < max_eval_input_length
1078
+
1079
+ if training_args.do_eval:
1080
+ vectorized_datasets["eval"] = vectorized_datasets["eval"].filter(
1081
+ is_eval_audio_in_length_range,
1082
+ num_proc=num_workers,
1083
+ input_columns=["input_length"],
1084
+ )
1085
+
1086
+ if training_args.do_test:
1087
+ for split in test_split:
1088
+ vectorized_datasets[split] = vectorized_datasets[split].filter(
1089
+ is_eval_audio_in_length_range,
1090
+ num_proc=num_workers,
1091
+ input_columns=["input_length"],
1092
+ )
1093
+
1094
+ # filter data with targets shorter than min_target_length or longer than max_target_length
1095
+ def is_labels_in_length_range(length):
1096
+ return min_target_length < length < max_target_length
1097
+
1098
+ if training_args.do_train:
1099
+ vectorized_datasets["train"] = vectorized_datasets["train"].filter(
1100
+ is_labels_in_length_range,
1101
+ num_proc=num_workers,
1102
+ input_columns=["labels_length"],
1103
+ )
1104
+
1105
+ # filter data with targets of 2 tokens or fewer, i.e. <s></s> -> empty sentences
1106
+ def is_labels_greater_than_min(length):
1107
+ return length > 2
1108
+
1109
+ vectorized_datasets = vectorized_datasets.filter(
1110
+ is_labels_greater_than_min,
1111
+ num_proc=num_workers,
1112
+ input_columns=["labels_length"],
1113
+ )
1114
+
1115
+ # for large datasets it is advised to run the preprocessing on a
1116
+ # single machine first with `args.preprocessing_only` since there will most likely
1117
+ # be a timeout when running the script in distributed mode.
1118
+ # In a second step `args.preprocessing_only` can then be set to `False` to load the
1119
+ # cached dataset
1120
+ if data_args.preprocessing_only:
1121
+ cache = {k: v.cache_files for k, v in vectorized_datasets.items()}
1122
+ logger.info(f"Data preprocessing finished. Files cached at {cache}.")
1123
+ return
1124
+
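Assuming the script parses its data arguments with `HfArgumentParser` (so that `data_args.preprocessing_only` surfaces as a `--preprocessing_only` flag; the flag and the script name below are assumptions for illustration), the two-pass workflow described in the comment above would look like:

# sketch of the intended two-pass usage (flag and file name assumed, not confirmed by this diff)
#   pass 1: python run_flax_speech_recognition_seq2seq.py <args> --preprocessing_only
#   pass 2: python run_flax_speech_recognition_seq2seq.py <args>   # cached dataset is reused, training proceeds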
1125
+ # 8. Load Metrics
1126
+ wer_metric = load_metric("wer")
1127
+ cer_metric = load_metric("cer")
1128
+
1129
+ def compute_metrics(pred_ids: List[List[int]], label_ids: List[List[int]]):
1130
+ label_ids = (
1131
+ pad_to_max_length(np.array(label_ids, dtype="object"), tokenizer)
1132
+ if pad_target_to_multiple_of
1133
+ else label_ids
1134
+ )
1135
+
1136
+ padded_ids = np.where(np.asarray(label_ids) == -100, tokenizer.pad_token_id, np.asarray(label_ids))
1137
+ # we do not want to group tokens when computing the metrics
1138
+ label_str = tokenizer.batch_decode(padded_ids, skip_special_tokens=True)
1139
+
1140
+ pred_ids = np.array(pred_ids)
1141
+ num_beams = pred_ids.shape[1]
1142
+ # decode on a beam-by-beam basis
1143
+ pred_str = [
1144
+ tokenizer.batch_decode(pred_ids[:, beam, :], skip_special_tokens=True)
1145
+ for beam in reversed(range(num_beams))
1146
+ ]
1147
+ # compute word/character error rate for top beam
1148
+ wer = wer_metric.compute(predictions=pred_str[0], references=label_str)
1149
+ cer = cer_metric.compute(predictions=pred_str[0], references=label_str)
1150
+
1151
+ return {"wer": wer, "cer": cer}, pred_str, label_str
1152
+
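As a quick illustration (not part of the training script), the WER metric loaded above operates on whitespace-delimited words, so one substituted word in a two-word reference yields an error rate of 0.5:

# illustrative only: one substitution over two reference words -> WER = 0.5
example_wer = wer_metric.compute(predictions=["hello word"], references=["hello world"])
print(example_wer)  # 0.5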
1153
+ # 9. Save feature extractor, tokenizer and config
1154
+ feature_extractor.save_pretrained(training_args.output_dir)
1155
+ tokenizer.save_pretrained(training_args.output_dir)
1156
+ config.save_pretrained(training_args.output_dir)
1157
+
1158
+ processor = AutoProcessor.from_pretrained(training_args.output_dir)
1159
+
1160
+ data_collator = FlaxDataCollatorSpeechSeq2SeqWithPadding(
1161
+ processor=processor,
1162
+ decoder_start_token_id=model.config.decoder_start_token_id,
1163
+ input_padding="longest",
1164
+ target_padding="longest",
1165
+ max_target_length=max_target_length,
1166
+ pad_input_to_multiple_of=pad_input_to_multiple_of,
1167
+ pad_target_to_multiple_of=pad_target_to_multiple_of if pad_target_to_multiple_of else max_target_length,
1168
+ )
1169
+
1170
+ # Enable tensorboard only on the master node
1171
+ has_tensorboard = is_tensorboard_available()
1172
+ if has_tensorboard and jax.process_index() == 0:
1173
+ try:
1174
+ from flax.metrics.tensorboard import SummaryWriter
1175
+
1176
+ summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
1177
+ except ImportError as ie:
1178
+ has_tensorboard = False
1179
+ logger.warning(
1180
+ f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
1181
+ )
1182
+ else:
1183
+ logger.warning(
1184
+ "Unable to display metrics through TensorBoard because the package is not installed: "
1185
+ "Please run `pip install tensorboard` to enable."
1186
+ )
1187
+
1188
+ # 10. Handle the repository creation
1189
+ if training_args.push_to_hub:
1190
+ with open(os.path.join(training_args.output_dir, ".gitattributes"), "r+") as f:
1191
+ git_lfs_extensions = f.read()
1192
+ if "*.wandb" not in git_lfs_extensions:
1193
+ f.write("*.wandb filter=lfs diff=lfs merge=lfs -text")
1194
+ if training_args.hub_model_id is None:
1195
+ repo_name = get_full_repo_name(
1196
+ Path(training_args.output_dir).absolute().name, token=training_args.hub_token
1197
+ )
1198
+ else:
1199
+ repo_name = training_args.hub_model_id
1200
+ repo = Repository(training_args.output_dir, clone_from=repo_name)
1201
+
1202
+ # 11. Initialize our training
1203
+ rng = jax.random.PRNGKey(training_args.seed)
1204
+ rng, dropout_rng = jax.random.split(rng)
1205
+
1206
+ # Store some constants
1207
+ max_steps = int(training_args.max_steps)
1208
+ gradient_accumulation_steps = int(training_args.gradient_accumulation_steps)
1209
+ train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
1210
+ batch_size_per_update = train_batch_size * gradient_accumulation_steps
1211
+ per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
1212
+ eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
1213
+ to_dtype = to_bf16 if training_args.mixed_precision else to_fp32
1214
+
1215
+ if training_args.do_train:
1216
+ num_train_samples = len(vectorized_datasets["train"])
1217
+ steps_per_epoch = num_train_samples // batch_size_per_update
1218
+ if max_steps > 0:
1219
+ num_epochs = -(training_args.max_steps // -steps_per_epoch)
1220
+ total_train_steps = max_steps
1221
+ else:
1222
+ num_epochs = int(training_args.num_train_epochs)
1223
+ total_train_steps = steps_per_epoch * num_epochs
1224
+
1225
+ # Create learning rate schedule
1226
+ linear_decay_lr_schedule_fn = create_learning_rate_fn(
1227
+ total_train_steps,
1228
+ training_args.warmup_steps,
1229
+ training_args.learning_rate,
1230
+ )
1231
+
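`create_learning_rate_fn` is defined earlier in the script (not shown in this diff). The schedule shape it presumably produces - linear warm-up to the peak learning rate followed by linear decay to zero - can be sketched directly with optax (the numeric values below are hypothetical):

import optax

peak_lr, warmup_steps, total_steps = 1e-4, 500, 10_000  # hypothetical values
warmup_fn = optax.linear_schedule(init_value=0.0, end_value=peak_lr, transition_steps=warmup_steps)
decay_fn = optax.linear_schedule(init_value=peak_lr, end_value=0.0, transition_steps=total_steps - warmup_steps)
schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[warmup_steps])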
1232
+ # We use Optax's "masking" functionality to not apply weight decay
1233
+ # to bias and LayerNorm scale parameters. decay_mask_fn returns a
1234
+ # boolean mask with the same structure as the parameters.
1235
+ # The mask is True for parameters that should be decayed.
1236
+ # Note that this mask is specifically adapted for FlaxWav2Vec2 and FlaxBart.
1237
+ # For FlaxT5, one should correct the layer norm parameter naming
1238
+ # accordingly - see e.g. `run_t5_mlm_flax.py`.
1239
+ def decay_mask_fn(params):
1240
+ flat_params = traverse_util.flatten_dict(params)
1241
+ layer_norm_params = [
1242
+ (name, "scale")
1243
+ for name in ["layer_norm", "self_attn_layer_norm", "layernorm_embedding", "final_layer_norm"]
1244
+ ]
1245
+ flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_params) for path in flat_params}
1246
+ return traverse_util.unflatten_dict(flat_mask)
1247
+
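A quick sanity check of the mask on a toy parameter tree (illustrative only; the real tree comes from `model.params`):

import numpy as np

toy_params = {
    "encoder": {
        "layer_norm": {"scale": np.ones(4), "bias": np.zeros(4)},
        "dense": {"kernel": np.ones((4, 4)), "bias": np.zeros(4)},
    }
}
# kernels are decayed (True); biases and layer-norm scales are not (False)
print(decay_mask_fn(toy_params))
# -> dense/kernel: True, layer_norm/scale: False, all biases: False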
1248
+ if training_args.adafactor:
1249
+ # Create Adafactor optimizer
1250
+ optim = optax.adafactor(
1251
+ learning_rate=linear_decay_lr_schedule_fn,
1252
+ dtype_momentum=jnp.bfloat16 if training_args.mixed_precision else jnp.float32,
1253
+ weight_decay_rate=training_args.weight_decay,
1254
+ weight_decay_mask=decay_mask_fn,
1255
+ )
1256
+ else:
1257
+ # Create AdamW optimizer
1258
+ optim = optax.adamw(
1259
+ learning_rate=linear_decay_lr_schedule_fn,
1260
+ b1=training_args.adam_beta1,
1261
+ b2=training_args.adam_beta2,
1262
+ eps=training_args.adam_epsilon,
1263
+ weight_decay=training_args.weight_decay,
1264
+ mask=decay_mask_fn,
1265
+ )
1266
+ else:
1267
+ num_epochs = 0
1268
+ total_train_steps = 0
1269
+ num_train_samples = 0
1270
+ optim = None
1271
+
1272
+ # Setup train state
1273
+ state = MixedPrecisionTrainState.create(
1274
+ apply_fn=model.__call__,
1275
+ params=model.params,
1276
+ tx=optim,
1277
+ to_dtype=to_dtype,
1278
+ dropout_rng=dropout_rng,
1279
+ max_grad_norm=training_args.max_grad_norm,
1280
+ )
1281
+
1282
+ # Cross entropy loss
1283
+ def loss_fn(logits, labels):
1284
+ vocab_size = logits.shape[-1]
1285
+ # optax onehot always returns a float32 device array, need to downcast if performing mixed precision training
1286
+ onehot_targets = to_dtype(onehot(labels, vocab_size))
1287
+ loss = optax.softmax_cross_entropy(logits, onehot_targets)
1288
+ # ignore padded tokens in the loss, i.e. positions where the labels are set to -100
1289
+ padding = labels >= 0
1290
+ loss = loss * padding
1291
+ loss = loss.sum()
1292
+ num_labels = padding.sum()
1293
+ return loss, num_labels
1294
+
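A minimal illustration of the padding trick above (not part of the script): the data collator sets padded label positions to -100, so they are masked out of the summed loss.

import jax.numpy as jnp

# illustrative only: positions labelled -100 contribute nothing to the loss
labels = jnp.array([[4, 17, -100, -100]])
padding = labels >= 0       # [[True, True, False, False]]
num_labels = padding.sum()  # 2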
1295
+ # Define gradient update step fn
1296
+ def train_step(state, batch):
1297
+ # use a single rng per grad step, with or without accumulation, as the graph should be identical over one effective training batch
1298
+ dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
1299
+
1300
+ def compute_loss(params, minibatch):
1301
+ labels = minibatch.pop("labels")
1302
+ logits = state.apply_fn(
1303
+ **minibatch,
1304
+ params=params,
1305
+ dropout_rng=dropout_rng,
1306
+ freeze_feature_encoder=model_args.freeze_feature_encoder,
1307
+ train=True,
1308
+ )[0]
1309
+ loss, num_labels = loss_fn(logits, labels)
1310
+ return loss, num_labels
1311
+
1312
+ grad_fn = jax.value_and_grad(compute_loss, has_aux=True)
1313
+
1314
+ if gradient_accumulation_steps == 1:
1315
+ (loss, num_labels), grad = grad_fn(to_dtype(state.params), batch)
1316
+
1317
+ # Custom gradient accumulation
1318
+ else:
1319
+ # add a first dimension over gradient_accumulation_steps for minibatch slices
1320
+ batch = jax.tree_map(
1321
+ lambda x: x.reshape(
1322
+ gradient_accumulation_steps, training_args.per_device_train_batch_size, *x.shape[1::]
1323
+ ),
1324
+ batch,
1325
+ )
1326
+
1327
+ def accum_minibatch_step(accum_grad, minibatch):
1328
+ # compute loss, num labels and grad over minibatch and accumulate
1329
+ (loss, num_labels), grad = grad_fn(to_dtype(state.params), minibatch)
1330
+ return jax.tree_map(jnp.add, accum_grad, grad), (loss, num_labels)
1331
+
1332
+ # create an initial state for accumulating losses, num labels and gradients
1333
+ init_grad = jax.tree_map(jnp.zeros_like, to_dtype(state.params))
1334
+ # loop accum minibatch step over the number of gradient accumulation steps
1335
+ grad, (loss, num_labels) = jax.lax.scan(accum_minibatch_step, init_grad, batch)
1336
+
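# Note (illustrative, editorial sketch): jax.lax.scan threads the accumulated gradient through the
# minibatch slices; outside of jit it is conceptually equivalent to
#     accum = init_grad
#     for minibatch in minibatches:
#         accum, (step_loss, step_num_labels) = accum_minibatch_step(accum, minibatch)
# with scan additionally stacking the per-step (loss, num_labels) outputs, which is why
# `loss.sum()` and `num_labels.sum()` are taken below.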
1337
+ grad = jax.lax.psum(grad, "batch")
1338
+ loss = jax.lax.psum(loss.sum(), "batch")
1339
+ total_samples = jax.lax.psum(num_labels.sum(), "batch")
1340
+ grad = jax.tree_map(lambda g: g / total_samples, grad)
1341
+ loss = jax.tree_map(lambda l: l / total_samples, loss)
1342
+
1343
+ # update state
1344
+ new_state = state.apply_gradients(
1345
+ grads=grad,
1346
+ dropout_rng=new_dropout_rng,
1347
+ to_dtype=to_dtype,
1348
+ )
1349
+
1350
+ # compute gradient norms over all layers, total encoder, total decoder and global for detailed monitoring
1351
+ layer_grad_norm = jax.tree_map(jnp.linalg.norm, grad)
1352
+ logs = {
1353
+ "layer_grad_norm": layer_grad_norm,
1354
+ "encoder_grad_norm": jnp.linalg.norm(jax.tree_util.tree_leaves(layer_grad_norm["encoder"])),
1355
+ "decoder_grad_norm": jnp.linalg.norm(jax.tree_util.tree_leaves(layer_grad_norm["decoder"])),
1356
+ }
1357
+ logs["grad_norm"] = jnp.linalg.norm([logs["encoder_grad_norm"], logs["decoder_grad_norm"]])
1358
+
1359
+ # compute parameter norms over all layers, total encoder, total decoder and global for detailed monitoring
1360
+ layer_param_norm = jax.tree_map(jnp.linalg.norm, new_state.params)
1361
+ logs["layer_param_norm"] = layer_param_norm
1362
+ logs["encoder_param_norm"] = jnp.linalg.norm(jax.tree_util.tree_leaves(layer_param_norm["encoder"]))
1363
+ logs["decoder_param_norm"] = jnp.linalg.norm(jax.tree_util.tree_leaves(layer_param_norm["decoder"]))
1364
+ logs["param_norm"] = jnp.linalg.norm([logs["encoder_param_norm"], logs["decoder_param_norm"]])
1365
+
1366
+ metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
1367
+ metrics.update(logs)
1368
+
1369
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
1370
+ # metrics = to_fp32(metrics)
1371
+
1372
+ return new_state, metrics
1373
+
1374
+ # Define eval fn
1375
+ def eval_step(params, batch):
1376
+ labels = batch.pop("labels")
1377
+ logits = model(**batch, params=params, train=False)[0]
1378
+ loss, num_labels = loss_fn(logits, labels)
1379
+
1380
+ total_samples = jax.lax.psum(num_labels, "batch")
1381
+ loss = jax.lax.psum(loss, "batch")
1382
+ loss = jax.tree_map(lambda l: l / total_samples, loss)
1383
+
1384
+ # summarize metrics
1385
+ metrics = {"loss": loss}
1386
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
1387
+ # metrics = to_fp32(metrics)
1388
+ return metrics
1389
+
1390
+ # Define generation function
1391
+ gen_kwargs = {
1392
+ "max_length": training_args.generation_max_length,
1393
+ "num_beams": training_args.generation_num_beams,
1394
+ "length_penalty": training_args.generation_length_penalty,
1395
+ }
1396
+ final_gen_kwargs = {
1397
+ "max_length": training_args.final_generation_max_length,
1398
+ "num_beams": training_args.final_generation_num_beams,
1399
+ "length_penalty": training_args.generation_length_penalty,
1400
+ }
1401
+
1402
+ def generate_step(params, batch):
1403
+ model.params = params
1404
+ output_ids = model.generate(batch["inputs"], **gen_kwargs)
1405
+ return output_ids.sequences
1406
+
1407
+ def final_generate_step(params, batch):
1408
+ model.params = params
1409
+ output_ids = model.generate(batch["inputs"], **final_gen_kwargs)
1410
+ return output_ids.sequences
1411
+
1412
+ # Create parallel version of the train and eval step
1413
+ if training_args.do_train:
1414
+ p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
1415
+
1416
+ if training_args.do_eval or training_args.do_predict:
1417
+ p_eval_step = jax.pmap(eval_step, "batch")
1418
+
1419
+ if training_args.predict_with_generate:
1420
+ p_generate_step = jax.pmap(generate_step, "batch")
1421
+ p_final_generate_step = jax.pmap(final_generate_step, "batch")
1422
+
1423
+ def run_evaluation(step, final_step=False):
1424
+ if training_args.do_eval:
1425
+ # ======================== Evaluating ==============================
1426
+ eval_metrics = []
1427
+ eval_preds = []
1428
+ eval_ids = []
1429
+ eval_labels = []
1430
+
1431
+ # Generate eval set by sequentially sampling indices from the eval dataset and grouping by length
1432
+ eval_samples_idx = get_grouped_indices(vectorized_datasets["eval"], eval_batch_size)
1433
+ eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last_batch=False)
1434
+
1435
+ for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
1436
+ samples = [vectorized_datasets["eval"][int(idx)] for idx in batch_idx]
1437
+ batch = data_collator(samples)
1438
+ eval_ids.extend(batch.pop("input_ids"))
1439
+ labels = batch["labels"]
1440
+
1441
+ metrics = pad_shard_unpad(p_eval_step, static_return=True)(state.params, batch.data, min_device_batch=per_device_eval_batch_size)
1442
+ eval_metrics.append(metrics)
1443
+
1444
+ # generation
1445
+ if training_args.predict_with_generate:
1446
+ if not final_step:
1447
+ generated_ids = pad_shard_unpad(p_generate_step)(state.params, batch.data, min_device_batch=per_device_eval_batch_size)
1448
+ eval_preds.extend(
1449
+ jax.device_get(
1450
+ generated_ids.reshape(-1, gen_kwargs["num_beams"], gen_kwargs["max_length"])
1451
+ )
1452
+ )
1453
+ else:
1454
+ generated_ids = pad_shard_unpad(p_final_generate_step)(state.params, batch.data, min_device_batch=per_device_eval_batch_size)
1455
+ eval_preds.extend(
1456
+ jax.device_get(
1457
+ generated_ids.reshape(
1458
+ -1, final_gen_kwargs["num_beams"], final_gen_kwargs["max_length"]
1459
+ )
1460
+ )
1461
+ )
1462
+ eval_labels.extend(labels)
1463
+
1464
+ # normalize eval metrics
1465
+ eval_metrics = get_metrics(eval_metrics)
1466
+ eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
1467
+ eval_metrics = to_fp32(eval_metrics)
1468
+
1469
+ # compute error rate metric and get predicted string (for debugging)
1470
+ error_rate_desc = ""
1471
+ pred_str = []
1472
+ label_str = []
1473
+ if training_args.predict_with_generate:
1474
+ error_rate_metric, pred_str, label_str = compute_metrics(eval_preds, eval_labels)
1475
+ eval_metrics.update(error_rate_metric)
1476
+ error_rate_desc = " ".join([f"Eval {key}: {value} |" for key, value in error_rate_metric.items()])
1477
+
1478
+ # Print metrics and update progress bar
1479
+ desc = f"Step... ({step}/{total_train_steps} | Eval Loss: {eval_metrics['loss']} | {error_rate_desc})"
1480
+ epochs.write(desc)
1481
+ epochs.desc = desc
1482
+
1483
+ # Save metrics
1484
+ write_wandb_log(eval_metrics, step, prefix="eval")
1485
+ write_wandb_pred(
1486
+ pred_str,
1487
+ label_str,
1488
+ eval_ids,
1489
+ step,
1490
+ top_ids=vectorized_datasets["eval"]["input_id"] if data_args.log_first_ids else None,
1491
+ final_step=final_step,
1492
+ )
1493
+ # if has_tensorboard and jax.process_index() == 0:
1494
+ # write_eval_metric(summary_writer, eval_metrics, step, pred_str=pred_str)
1495
+
1496
+ def save_checkpoint(step):
1497
+ # save and push checkpoint to the hub
1498
+ if jax.process_index() == 0:
1499
+ params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
1500
+ model.save_pretrained(training_args.output_dir, params=params)
1501
+ tokenizer.save_pretrained(training_args.output_dir)
1502
+ if training_args.push_to_hub:
1503
+ repo.push_to_hub(commit_message=f"{wandb.run.id}: saving weights and logs of step {int(step / 1000)}k", blocking=False)
1504
+
1505
+ # Replicate the train state on each device
1506
+ state = state.replicate()
1507
+
1508
+ logger.info("***** Running training *****")
1509
+ logger.info(f" Num examples = {num_train_samples}")
1510
+ logger.info(f" Num Epochs = {num_epochs}")
1511
+ logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
1512
+ logger.info(f" Num gradient accumulation steps = {gradient_accumulation_steps}")
1513
+ logger.info(f" Total train batch size (w. parallel & distributed) = {batch_size_per_update}")
1514
+ logger.info(f" Total optimization steps = {total_train_steps}")
1515
+ logger.info(f" Gradient checkpointing: {config.encoder.gradient_checkpointing}")
1516
+ logger.info(f" Use scan: {config.encoder.use_scan}")
1517
+ logger.info(f" Fuse matmuls: {config.encoder.fuse_matmuls}")
1518
+
1519
+ train_time = cur_step = 0
1520
+ epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
1521
+ for epoch in epochs:
1522
+ if training_args.do_train:
1523
+ # ======================== Training ================================
1524
+ train_start = time.time()
1525
+
1526
+ # Create sampling rng
1527
+ rng, input_rng = jax.random.split(rng)
1528
+
1529
+ # Generate an epoch by randomly shuffling sampling indices from the train dataset and grouping by length
1530
+ train_samples_idx = get_grouped_indices(vectorized_datasets["train"], batch_size_per_update, input_rng)
1531
+ train_batch_idx = generate_batch_splits(train_samples_idx, batch_size_per_update, drop_last_batch=True)
1532
+
1533
+ # Gather the indices for creating the batch and do a training step
1534
+ for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1), 1):
1535
+ samples = [vectorized_datasets["train"][int(idx)] for idx in batch_idx]
1536
+ batch = data_collator(samples)
1537
+ batch.pop("input_ids")
1538
+ batch = shard(batch.data)
1539
+ state, train_metric = p_train_step(state, batch)
1540
+
1541
+ cur_step = epoch * (num_train_samples // batch_size_per_update) + step
1542
+
1543
+ if cur_step % training_args.logging_steps == 0:
1544
+ # Save metrics
1545
+ train_metric = unreplicate(train_metric)
1546
+ train_time += time.time() - train_start
1547
+ # need to upcast all device arrays to fp32 for wandb logging (jnp.bfloat16 not supported) -> do this here OR in train_step
1548
+ write_wandb_log(to_fp32(train_metric), cur_step, prefix="train")
1549
+ # we won't log to tensorboard for now (it is fiddly logging param and grad norms on a layer-by-layer basis)
1550
+ # if has_tensorboard and jax.process_index() == 0:
1551
+ # write_train_metric(summary_writer, train_metrics, train_time, cur_step)
1552
+
1553
+ epochs.write(
1554
+ f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']}, Gradient Norm: {train_metric['grad_norm']})"
1555
+ )
1556
+
1557
+ if cur_step % total_train_steps == 0:
1558
+ break
1559
+
1560
+ if training_args.eval_steps and cur_step % training_args.eval_steps == 0:
1561
+ # run beam search at each eval step
1562
+ run_evaluation(cur_step, final_step=False)
1563
+
1564
+ if cur_step % training_args.save_steps == 0:
1565
+ save_checkpoint(cur_step)
1566
+
1567
+ if training_args.eval_steps == 0 and (epoch + 1) != num_epochs:
1568
+ # run evaluation at the end of the epoch if eval steps are not specified
1569
+ run_evaluation(cur_step, final_step=False)
1570
+ save_checkpoint(cur_step)
1571
+
1572
+ if training_args.do_train:
1573
+ save_checkpoint(cur_step)
1574
+
1575
+ cur_step = max_steps if max_steps > 0 else cur_step # set step to max steps so that eval happens in alignment with training
1576
+
1577
+ if training_args.do_eval:
1578
+ run_evaluation(cur_step, final_step=True)
1579
+
1580
+ # TODO: collapse 'do_predict' into the run_evaluation function
1581
+ if training_args.do_predict:
1582
+ # ======================== Prediction ==============================
1583
+ for split in test_split:
1584
+ pred_metrics = []
1585
+ pred_generations = []
1586
+ pred_ids = []
1587
+ pred_labels = []
1588
+
1589
+ # Generate the prediction set by sequentially sampling indices from the test dataset and grouping by length
1590
+ pred_samples_idx = get_grouped_indices(vectorized_datasets[split], eval_batch_size)
1591
+ pred_batch_idx = generate_batch_splits(pred_samples_idx, eval_batch_size, drop_last_batch=False)
1592
+
1593
+ for i, batch_idx in enumerate(tqdm(pred_batch_idx, desc=f"Predicting {split}...", position=2)):
1594
+ samples = [vectorized_datasets[split][int(idx)] for idx in batch_idx]
1595
+ batch = data_collator(samples)
1596
+ pred_ids.extend(batch.pop("input_ids"))
1597
+ labels = batch["labels"]
1598
+
1599
+ metrics = pad_shard_unpad(p_eval_step, static_return=True)(state.params, batch.data,
1600
+ min_device_batch=per_device_eval_batch_size)
1601
+ pred_metrics.append(metrics)
1602
+
1603
+ # generation
1604
+ if training_args.predict_with_generate:
1605
+ generated_ids = pad_shard_unpad(p_final_generate_step)(state.params, batch.data, min_device_batch=per_device_eval_batch_size)
1606
+ pred_generations.extend(
1607
+ jax.device_get(
1608
+ generated_ids.reshape(-1, final_gen_kwargs["num_beams"], final_gen_kwargs["max_length"])
1609
+ )
1610
+ )
1611
+ pred_labels.extend(labels)
1612
+
1613
+ # normalize eval metrics
1614
+ pred_metrics = get_metrics(pred_metrics)
1615
+ pred_metrics = jax.tree_map(jnp.mean, pred_metrics)
1616
+ pred_metrics = to_fp32(pred_metrics)
1617
+
1618
+ # compute error rate metric and get predicted string (for debugging)
1619
+ error_rate_desc = ""
1620
+ pred_str = []
1621
+ label_str = []
1622
+ if training_args.predict_with_generate:
1623
+ error_rate_metric, pred_str, label_str = compute_metrics(pred_generations, pred_labels)
1624
+ pred_metrics.update(error_rate_metric)
1625
+ error_rate_desc = " ".join([f"{split} {key}: {value} |" for key, value in error_rate_metric.items()])
1626
+
1627
+ # Print metrics and update progress bar
1628
+ desc = f"Step... ({cur_step}/{total_train_steps} | {split} Loss: {pred_metrics['loss']} | {error_rate_desc})"
1629
+ epochs.write(desc)
1630
+ epochs.desc = desc
1631
+
1632
+ # Save metrics
1633
+ write_wandb_log(pred_metrics, cur_step, prefix=split)
1634
+ write_wandb_pred(
1635
+ pred_str,
1636
+ label_str,
1637
+ pred_ids,
1638
+ cur_step,
1639
+ prefix=split,
1640
+ top_ids=vectorized_datasets[split]["input_id"] if data_args.log_first_ids else None,
1641
+ final_step=True,
1642
+ )
1643
+
1644
+
1645
+ if __name__ == "__main__":
1646
+ main()
special_tokens_map.json ADDED
@@ -0,0 +1,15 @@
1
+ {
2
+ "bos_token": "<s>",
3
+ "cls_token": "<s>",
4
+ "eos_token": "</s>",
5
+ "mask_token": {
6
+ "content": "<mask>",
7
+ "lstrip": true,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false
11
+ },
12
+ "pad_token": "<pad>",
13
+ "sep_token": "</s>",
14
+ "unk_token": "<unk>"
15
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,16 @@
1
+ {
2
+ "add_prefix_space": false,
3
+ "bos_token": "<s>",
4
+ "cls_token": "<s>",
5
+ "eos_token": "</s>",
6
+ "errors": "replace",
7
+ "mask_token": "<mask>",
8
+ "model_max_length": 1024,
9
+ "name_or_path": "sanchit-gandhi/flax-wav2vec2-2-bart-large-scan",
10
+ "pad_token": "<pad>",
11
+ "sep_token": "</s>",
12
+ "special_tokens_map_file": null,
13
+ "tokenizer_class": "BartTokenizer",
14
+ "trim_offsets": true,
15
+ "unk_token": "<unk>"
16
+ }
vocab.json ADDED
The diff for this file is too large to render. See raw diff