sanchit-gandhi committed
Commit 1ec36d9
1 Parent(s): f324b0e

Saving weights and logs of step 10k

config.json ADDED
@@ -0,0 +1,289 @@
+ {
+ "_name_or_path": "sanchit-gandhi/flax-wav2vec2-2-bart-large-scan",
+ "architectures": [
+ "SpeechEncoderDecoderModel"
+ ],
+ "decoder": {
+ "_name_or_path": "",
+ "activation_dropout": 0.1,
+ "activation_function": "gelu",
+ "add_bias_logits": false,
+ "add_cross_attention": true,
+ "add_final_layer_norm": false,
+ "architectures": [
+ "BartModel"
+ ],
+ "attention_dropout": 0.1,
+ "bad_words_ids": null,
+ "bos_token_id": 0,
+ "chunk_size_feed_forward": 0,
+ "classif_dropout": 0.1,
+ "classifier_dropout": 0.0,
+ "cross_attention_hidden_size": null,
+ "d_model": 1024,
+ "decoder_attention_heads": 16,
+ "decoder_ffn_dim": 4096,
+ "decoder_layerdrop": 0.0,
+ "decoder_layers": 12,
+ "decoder_start_token_id": 2,
+ "diversity_penalty": 0.0,
+ "do_sample": false,
+ "dropout": 0.1,
+ "early_stopping": true,
+ "encoder_attention_heads": 16,
+ "encoder_ffn_dim": 4096,
+ "encoder_layerdrop": 0.0,
+ "encoder_layers": 12,
+ "encoder_no_repeat_ngram_size": 0,
+ "eos_token_id": 2,
+ "exponential_decay_length_penalty": null,
+ "finetuning_task": null,
+ "forced_bos_token_id": 0,
+ "forced_eos_token_id": 2,
+ "fuse_matmuls": false,
+ "gradient_checkpointing": true,
+ "id2label": {
+ "0": "LABEL_0",
+ "1": "LABEL_1",
+ "2": "LABEL_2"
+ },
+ "init_std": 0.02,
+ "is_decoder": true,
+ "is_encoder_decoder": false,
+ "label2id": {
+ "LABEL_0": 0,
+ "LABEL_1": 1,
+ "LABEL_2": 2
+ },
+ "length_penalty": 1.0,
+ "max_length": 20,
+ "max_position_embeddings": 1024,
+ "min_length": 0,
+ "model_type": "bart",
+ "no_repeat_ngram_size": 3,
+ "normalize_before": false,
+ "num_beam_groups": 1,
+ "num_beams": 4,
+ "num_hidden_layers": 12,
+ "num_return_sequences": 1,
+ "output_attentions": false,
+ "output_hidden_states": false,
+ "output_scores": false,
+ "pad_token_id": 1,
+ "prefix": null,
+ "problem_type": null,
+ "pruned_heads": {},
+ "remove_invalid_values": false,
+ "repetition_penalty": 1.0,
+ "return_dict": true,
+ "return_dict_in_generate": false,
+ "scale_embedding": false,
+ "sep_token_id": null,
+ "task_specific_params": {
+ "summarization": {
+ "length_penalty": 1.0,
+ "max_length": 128,
+ "min_length": 12,
+ "num_beams": 4
+ },
+ "summarization_cnn": {
+ "length_penalty": 2.0,
+ "max_length": 142,
+ "min_length": 56,
+ "num_beams": 4
+ },
+ "summarization_xsum": {
+ "length_penalty": 1.0,
+ "max_length": 62,
+ "min_length": 11,
+ "num_beams": 6
+ }
+ },
+ "temperature": 1.0,
+ "tie_encoder_decoder": false,
+ "tie_word_embeddings": true,
+ "tokenizer_class": null,
+ "top_k": 50,
+ "top_p": 1.0,
+ "torch_dtype": "float32",
+ "torchscript": false,
+ "transformers_version": "4.18.0.dev0",
+ "typical_p": 1.0,
+ "use_bfloat16": false,
+ "use_cache": true,
+ "use_scan": true,
+ "vocab_size": 50265
+ },
+ "decoder_start_token_id": 0,
+ "encoder": {
+ "_name_or_path": "",
+ "activation_dropout": 0.1,
+ "adapter_kernel_size": 3,
+ "adapter_stride": 2,
+ "add_adapter": true,
+ "add_cross_attention": false,
+ "apply_spec_augment": true,
+ "architectures": [
+ "Wav2Vec2ForPreTraining"
+ ],
+ "attention_dropout": 0.1,
+ "bad_words_ids": null,
+ "bos_token_id": 1,
+ "chunk_size_feed_forward": 0,
+ "classifier_proj_size": 256,
+ "codevector_dim": 768,
+ "contrastive_logits_temperature": 0.1,
+ "conv_bias": true,
+ "conv_dim": [
+ 512,
+ 512,
+ 512,
+ 512,
+ 512,
+ 512,
+ 512
+ ],
+ "conv_kernel": [
+ 10,
+ 3,
+ 3,
+ 3,
+ 3,
+ 2,
+ 2
+ ],
+ "conv_stride": [
+ 5,
+ 2,
+ 2,
+ 2,
+ 2,
+ 2,
+ 2
+ ],
+ "cross_attention_hidden_size": null,
+ "ctc_loss_reduction": "sum",
+ "ctc_zero_infinity": false,
+ "decoder_start_token_id": null,
+ "diversity_loss_weight": 0.1,
+ "diversity_penalty": 0.0,
+ "do_sample": false,
+ "do_stable_layer_norm": true,
+ "early_stopping": false,
+ "encoder_no_repeat_ngram_size": 0,
+ "eos_token_id": 2,
+ "exponential_decay_length_penalty": null,
+ "feat_extract_activation": "gelu",
+ "feat_extract_dropout": 0.0,
+ "feat_extract_norm": "layer",
+ "feat_proj_dropout": 0.0,
+ "feat_quantizer_dropout": 0.0,
+ "final_dropout": 0.0,
+ "finetuning_task": null,
+ "forced_bos_token_id": null,
+ "forced_eos_token_id": null,
+ "fuse_matmuls": false,
+ "gradient_checkpointing": true,
+ "hidden_act": "gelu",
+ "hidden_dropout": 0.1,
+ "hidden_dropout_prob": 0.1,
+ "hidden_size": 1024,
+ "id2label": {
+ "0": "LABEL_0",
+ "1": "LABEL_1"
+ },
+ "initializer_range": 0.02,
+ "intermediate_size": 4096,
+ "is_decoder": false,
+ "is_encoder_decoder": false,
+ "label2id": {
+ "LABEL_0": 0,
+ "LABEL_1": 1
+ },
+ "layer_norm_eps": 1e-05,
+ "layerdrop": 0.0,
+ "length_penalty": 1.0,
+ "mask_feature_length": 10,
+ "mask_feature_min_masks": 0,
+ "mask_feature_prob": 0.0,
+ "mask_time_length": 10,
+ "mask_time_min_masks": 2,
+ "mask_time_prob": 0.1,
+ "max_length": 20,
+ "min_length": 0,
+ "model_type": "wav2vec2",
+ "no_repeat_ngram_size": 0,
+ "num_adapter_layers": 3,
+ "num_attention_heads": 16,
+ "num_beam_groups": 1,
+ "num_beams": 1,
+ "num_codevector_groups": 2,
+ "num_codevectors_per_group": 320,
+ "num_conv_pos_embedding_groups": 16,
+ "num_conv_pos_embeddings": 128,
+ "num_feat_extract_layers": 7,
+ "num_hidden_layers": 24,
+ "num_negatives": 100,
+ "num_return_sequences": 1,
+ "output_attentions": false,
+ "output_hidden_size": 1024,
+ "output_hidden_states": false,
+ "output_scores": false,
+ "pad_token_id": 0,
+ "prefix": null,
+ "problem_type": null,
+ "proj_codevector_dim": 768,
+ "pruned_heads": {},
+ "remove_invalid_values": false,
+ "repetition_penalty": 1.0,
+ "return_dict": true,
+ "return_dict_in_generate": false,
+ "sep_token_id": null,
+ "task_specific_params": null,
+ "tdnn_dilation": [
+ 1,
+ 2,
+ 3,
+ 1,
+ 1
+ ],
+ "tdnn_dim": [
+ 512,
+ 512,
+ 512,
+ 512,
+ 1500
+ ],
+ "tdnn_kernel": [
+ 5,
+ 3,
+ 3,
+ 1,
+ 1
+ ],
+ "temperature": 1.0,
+ "tie_encoder_decoder": false,
+ "tie_word_embeddings": true,
+ "tokenizer_class": null,
+ "top_k": 50,
+ "top_p": 1.0,
+ "torch_dtype": null,
+ "torchscript": false,
+ "transformers_version": "4.18.0.dev0",
+ "typical_p": 1.0,
+ "use_bfloat16": false,
+ "use_scan": true,
+ "use_weighted_layer_sum": false,
+ "vocab_size": 32,
+ "xvector_output_dim": 512
+ },
+ "eos_token_id": 2,
+ "is_encoder_decoder": true,
+ "max_length": 40,
+ "model_type": "speech-encoder-decoder",
+ "pad_token_id": 1,
+ "processor_class": "Wav2Vec2Processor",
+ "tie_word_embeddings": false,
+ "transformers_version": null,
+ "use_cache": false
+ }
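
The config nests a complete Wav2Vec2 encoder config and a BART decoder config inside a single speech-encoder-decoder config. As a minimal sketch (assuming a `transformers` install that ships the speech-encoder-decoder classes), the nested sub-configs can be inspected directly; the repo id is the `_name_or_path` above:

from transformers import AutoConfig

config = AutoConfig.from_pretrained("sanchit-gandhi/flax-wav2vec2-2-bart-large-scan")
print(config.model_type)          # speech-encoder-decoder
print(config.encoder.model_type)  # wav2vec2 (24 transformer layers, adapter enabled)
print(config.decoder.model_type)  # bart (12 decoder layers, vocab size 50265)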
flax_model.msgpack ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:dbca86a64a213cedfad4d5e3f4588587cb7509eda7944cfca791ae02675e2be6
+ size 2353616717
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
preprocessor_config.json ADDED
@@ -0,0 +1,9 @@
+ {
+ "do_normalize": true,
+ "feature_extractor_type": "Wav2Vec2FeatureExtractor",
+ "feature_size": 1,
+ "padding_side": "right",
+ "padding_value": 0.0,
+ "return_attention_mask": true,
+ "sampling_rate": 16000
+ }
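
These are the standard Wav2Vec2 feature-extraction settings: mono 16 kHz input (feature_size 1), zero-mean/unit-variance normalisation, right padding with 0.0, and an attention mask over the padded samples. A small sketch of the extractor consuming raw audio under these settings (the dummy waveform is illustrative only):

import numpy as np
from transformers import Wav2Vec2FeatureExtractor

extractor = Wav2Vec2FeatureExtractor(
    feature_size=1, sampling_rate=16000, padding_value=0.0,
    do_normalize=True, return_attention_mask=True,
)
audio = np.random.randn(16000).astype(np.float32)  # 1 s of dummy 16 kHz audio
batch = extractor([audio, audio[:8000]], sampling_rate=16000, padding=True, return_tensors="np")
print(batch.input_values.shape, batch.attention_mask.shape)  # (2, 16000) (2, 16000)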
run_cv9.sh ADDED
@@ -0,0 +1,41 @@
+ python run_flax_speech_recognition_seq2seq.py \
+ --dataset_name="mozilla-foundation/common_voice_9_0" \
+ --model_name_or_path="sanchit-gandhi/flax-wav2vec2-2-bart-large-scan" \
+ --dataset_config_name="en" \
+ --train_split_name="train" \
+ --eval_split_name="validation" \
+ --test_split_name="test" \
+ --dataset_cache_dir="/home/sanchitgandhi/cache/huggingface/datasets" \
+ --output_dir="./flax-wav2vec2-2-bart-large-cv9-baseline" \
+ --preprocessing_num_workers="1" \
+ --id_column_name="client_id" \
+ --length_column_name="input_length" \
+ --text_column_name="sentence" \
+ --overwrite_output_dir \
+ --per_device_train_batch_size="8" \
+ --per_device_eval_batch_size="4" \
+ --logging_steps="25" \
+ --max_steps="50000" \
+ --eval_steps="10000" \
+ --save_steps="10000" \
+ --gradient_checkpointing \
+ --max_duration_in_seconds="20" \
+ --max_target_length="128" \
+ --generation_max_length="40" \
+ --generation_num_beams="1" \
+ --generation_length_penalty="1.2" \
+ --final_generation_max_length="200" \
+ --final_generation_num_beams="5" \
+ --learning_rate="1e-4" \
+ --warmup_steps="500" \
+ --save_total_limit="1" \
+ --freeze_feature_encoder \
+ --predict_with_generate \
+ --do_lower_case \
+ --do_eval \
+ --do_train \
+ --do_predict \
+ --push_to_hub \
+ --use_auth_token \
+ --wandb_project="commonvoice_9_0" \
+ --wandb_name="flax-wav2vec2-2-bart-large-cv9-baseline"
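
For orientation, a rough sketch of what the batch-size flags imply; the JAX device count is an assumption (for example a single TPU v3-8 host) and is not set by the script:

per_device_bs = 8       # --per_device_train_batch_size
num_devices = 8         # assumed jax.device_count()
max_steps = 50_000      # --max_steps
global_batch = per_device_bs * num_devices  # 64 utterances per optimiser step
total_examples = global_batch * max_steps   # 3,200,000 training examples processed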
run_flax_speech_recognition_seq2seq.py ADDED
@@ -0,0 +1,1443 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2022 The HuggingFace Team All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Fine-tuning the Flax library models for sequence to sequence speech recognition.
18
+ """
19
+ # You can also adapt this script to your own sequence-to-sequence task. Pointers for this are left as comments.
20
+
21
+ import logging
22
+ import os
23
+ import sys
24
+ import time
25
+ from dataclasses import dataclass, field
26
+ from pathlib import Path
27
+ from typing import Any, Callable, Dict, List, Optional, Union
28
+
29
+ import datasets
30
+ import numpy as np
31
+ from datasets import DatasetDict, load_dataset, load_metric
32
+ from tqdm import tqdm
33
+
34
+ import flax
35
+ import jax
36
+ import jax.numpy as jnp
37
+ import optax
38
+ import transformers
39
+ import wandb as wandb
40
+ from flax import core, jax_utils, struct, traverse_util
41
+ from flax.jax_utils import unreplicate
42
+ from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
43
+ from huggingface_hub import Repository
44
+ from models.modeling_flax_speech_encoder_decoder import FlaxSpeechEncoderDecoderModel
45
+ from optax._src import linear_algebra
46
+ from transformers import (
47
+ AutoConfig,
48
+ AutoFeatureExtractor,
49
+ AutoProcessor,
50
+ AutoTokenizer,
51
+ HfArgumentParser,
52
+ Seq2SeqTrainingArguments,
53
+ is_tensorboard_available,
54
+ )
55
+ from transformers.file_utils import get_full_repo_name
56
+ from transformers.trainer_utils import get_last_checkpoint, is_main_process
57
+ from transformers.utils import check_min_version
58
+ from transformers.utils.versions import require_version
59
+
60
+
61
+ # Will error if the minimal version of Transformers is not installed. Remove at your own risk.
62
+ check_min_version("4.17.0.dev0")
63
+
64
+ require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")
65
+
66
+ logger = logging.getLogger(__name__)
67
+
68
+
69
+ @flax.struct.dataclass
70
+ class ModelArguments:
71
+ """
72
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
73
+ """
74
+
75
+ model_name_or_path: str = field(
76
+ metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
77
+ )
78
+ config_name: Optional[str] = field(
79
+ default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
80
+ )
81
+ tokenizer_name: Optional[str] = field(
82
+ default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
83
+ )
84
+ feature_extractor_name: Optional[str] = field(
85
+ default=None, metadata={"help": "feature extractor name or path if not the same as model_name"}
86
+ )
87
+ cache_dir: Optional[str] = field(
88
+ default=None,
89
+ metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
90
+ )
91
+ use_fast_tokenizer: bool = field(
92
+ default=True,
93
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
94
+ )
95
+ model_revision: str = field(
96
+ default="main",
97
+ metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
98
+ )
99
+ use_auth_token: bool = field(
100
+ default=False,
101
+ metadata={
102
+ "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
103
+ "with private models)."
104
+ },
105
+ )
106
+ freeze_feature_encoder: bool = field(
107
+ default=True, metadata={"help": "Whether to freeze the feature encoder layers of the model."}
108
+ )
109
+ hidden_dropout: float = field(
110
+ default=0.1,
111
+ metadata={
112
+ "help": "The dropout probability for all fully connected layers in the embeddings, encoder, and pooler."
113
+ },
114
+ )
115
+ encoder_add_adapter: bool = field(
116
+ default=True, metadata={"help": "Whether to add an adapter layer between the encoder and decoder."}
117
+ )
118
+
119
+
120
+ @flax.struct.dataclass
121
+ class DataTrainingArguments:
122
+ """
123
+ Arguments pertaining to what data we are going to input our model for training and eval.
124
+ """
125
+
126
+ dataset_name: str = field(
127
+ default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
128
+ )
129
+ dataset_config_name: Optional[str] = field(
130
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
131
+ )
132
+ text_column: Optional[str] = field(
133
+ default=None,
134
+ metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."},
135
+ )
136
+ dataset_cache_dir: Optional[str] = field(
137
+ default=None, metadata={"help": "Path to cache directory for saving and loading datasets"}
138
+ )
139
+ overwrite_cache: bool = field(
140
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
141
+ )
142
+ preprocessing_num_workers: Optional[int] = field(
143
+ default=None,
144
+ metadata={"help": "The number of processes to use for the preprocessing."},
145
+ )
146
+ max_train_samples: Optional[int] = field(
147
+ default=None,
148
+ metadata={
149
+ "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
150
+ "value if set."
151
+ },
152
+ )
153
+ max_eval_samples: Optional[int] = field(
154
+ default=None,
155
+ metadata={
156
+ "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
157
+ "value if set."
158
+ },
159
+ )
160
+ max_test_samples: Optional[int] = field(
161
+ default=None,
162
+ metadata={
163
+ "help": "For debugging purposes or quicker training, truncate the number of test examples to this "
164
+ "value if set."
165
+ },
166
+ )
167
+ audio_column_name: str = field(
168
+ default="audio",
169
+ metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"},
170
+ )
171
+ text_column_name: str = field(
172
+ default="text",
173
+ metadata={"help": "The name of the dataset column containing the text data. Defaults to 'text'"},
174
+ )
175
+ id_column_name: str = field(
176
+ default="id",
177
+ metadata={"help": "The name of the dataset column containing the id data. Defaults to 'id'"},
178
+ )
179
+ max_duration_in_seconds: float = field(
180
+ default=20.0,
181
+ metadata={
182
+ "help": "Truncate audio files that are longer than `max_duration_in_seconds` seconds to `max_duration_in_seconds`"
183
+ },
184
+ )
185
+ min_duration_in_seconds: float = field(
186
+ default=0.0, metadata={"help": "Filter audio files that are shorter than `min_duration_in_seconds` seconds"}
187
+ )
188
+ max_target_length: Optional[int] = field(
189
+ default=128,
190
+ metadata={
191
+ "help": "The maximum total sequence length for target text after tokenization. Sequences longer "
192
+ "than this will be truncated, sequences shorter will be padded."
193
+ },
194
+ )
195
+ min_target_length: Optional[int] = field(
196
+ default=0,
197
+ metadata={
198
+ "help": "The minimum total sequence length for target text after tokenization. Sequences shorter "
199
+ "than this will be filtered."
200
+ },
201
+ )
202
+ pad_input_to_multiple_of: Optional[int] = field(
203
+ default=24000,
204
+ metadata={
205
+ "help": "If set will pad the input sequence to a multiple of the provided value. "
206
+ "This is important to avoid triggering recompilations on TPU."
207
+ },
208
+ )
209
+ pad_target_to_multiple_of: Optional[int] = field(
210
+ default=None,
211
+ metadata={
212
+ "help": "If set will pad the target sequence to a multiple of the provided value. "
213
+ "This is important to avoid triggering recompilations on TPU. If unspecified, will default to `max_target_length`, "
214
+ " the equivalent of padding the targets to max length."
215
+ },
216
+ )
217
+ preprocessing_only: bool = field(
218
+ default=False,
219
+ metadata={
220
+ "help": "Whether to only do data preprocessing and skip training. "
221
+ "This is especially useful when data preprocessing errors out in distributed training due to timeout. "
222
+ "In this case, one should run the preprocessing in a non-distributed setup with `preprocessing_only=True` "
223
+ "so that the cached datasets can consequently be loaded in distributed training"
224
+ },
225
+ )
226
+ train_split_name: str = field(
227
+ default="train",
228
+ metadata={
229
+ "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
230
+ },
231
+ )
232
+ eval_split_name: str = field(
233
+ default="test",
234
+ metadata={
235
+ "help": "The name of the evaluation data set split to use (via the datasets library). Defaults to 'test'"
236
+ },
237
+ )
238
+ test_split_name: str = field(
239
+ default="test",
240
+ metadata={"help": "The name of the test data set split to use (via the datasets library). Defaults to 'test'"},
241
+ )
242
+ do_lower_case: bool = field(
243
+ default=True,
244
+ metadata={"help": "Whether the target text should be lower cased."},
245
+ )
246
+ wandb_project: str = field(
247
+ default="flax-speech-recognition-seq2seq",
248
+ metadata={"help": "The name of the wandb project."},
249
+ )
250
+ wandb_name: str = field(
251
+ default=None,
252
+ metadata={"help": "The name of the wandb run."},
253
+ )
254
+ wandb_job_type: str = field(
255
+ default="Seq2Seq",
256
+ metadata={"help": "The name of the wandb job type."},
257
+ )
258
+ log_first_ids: bool = field(
259
+ default=True,
260
+ metadata={
261
+ "help": "Whether to log the first ids from the dataset. Defaults to `True`. If `False`, will log the first ids returned by the grouped length sampler."
262
+ },
263
+ )
264
+
265
+
266
+ # @flax.struct.dataclass
267
+ @dataclass
268
+ class FlaxSeq2SeqTrainingArguments(Seq2SeqTrainingArguments):
269
+ precision: str = field(
270
+ default="full",
271
+ metadata={
272
+ "help": "Whether to enable mixed-precision training. If true, the optimizer is stored in half-precision (bfloat16) and computations are executed in half-precision"
273
+ "**Note that this only specifies the dtype of the computation and optimizer state. It does not influence the dtype of model parameters.**"
274
+ },
275
+ )
276
+ matmul_precision: str = field(
277
+ default="default",
278
+ metadata={
279
+ "help": "Default floating-point precision of internal computations used in TPU matrix multiplications and convolutions. "
280
+ "This configuration option controls the default precision for JAX operations that take an optional precision argument (e.g. `lax.conv_general_dilated` and `lax.dot`). "
281
+ "This configuration option does not change the behaviours of such calls with explicit precision arguments; "
282
+ "it only changes the behaviors of calls with no such argument provided. "
283
+ "One of `['highest', 'float32', 'high', 'bfloat16_3x', 'default', 'bfloat16', 'fastest', None]`."
284
+ },
285
+ )
286
+ generation_length_penalty: float = field(
287
+ default=1,
288
+ metadata={
289
+ "help": "Exponential penalty to the length. 1.0 (default) means no penalty. Set to values < 1.0 in order to encourage the model"
290
+ "to generate shorter sequences, to a value > 1.0 in order to encourage the model to produce longer sequences."
291
+ },
292
+ )
293
+ final_generation_max_length: int = field(
294
+ default=None,
295
+ metadata={
296
+ "help": "The `max_length` to use on each evaluation loop when `predict_with_generate=True`. If unspecified, will default "
297
+ "to the `max_length` value of the model configuration."
298
+ },
299
+ )
300
+ final_generation_num_beams: int = field(
301
+ default=None,
302
+ metadata={
303
+ "help": "The `num_beams` to use on each evaluation loop when `predict_with_generate=True`. If unspecified, will default "
304
+ "to the `num_beams` value of the model configuration."
305
+ },
306
+ )
307
+
308
+ def __post_init__(self):
309
+ if self.final_generation_max_length is None:
310
+ self.final_generation_max_length = self.generation_max_length
311
+ if self.final_generation_num_beams is None:
312
+ self.final_generation_num_beams = self.generation_num_beams
313
+
314
+
315
+ def to_fp32(t):
316
+ return jax.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t)
317
+
318
+
319
+ def to_bf16(t):
320
+ return jax.tree_map(lambda x: x.astype(jnp.bfloat16) if x.dtype == jnp.float32 else x, t)
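# For example, applied to a toy parameter pytree the two helpers above cast only
# the leaves that are in the other dtype and leave everything else untouched:
#   params = {"w": jnp.ones((2, 2), jnp.float32), "b": jnp.zeros(2, jnp.bfloat16)}
#   to_bf16(params)            # both leaves now bfloat16
#   to_fp32(to_bf16(params))   # and back to float32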
321
+
322
+
323
+ class MixedPrecisionTrainState(struct.PyTreeNode):
324
+ """Train state for use with a single Optax optimizer.
325
+ Adapted from flax train_state https://github.com/google/flax/blob/main/flax/training/train_state.py
326
+
327
+ Synopsis::
328
+
329
+ state = TrainState.create(
330
+ apply_fn=model.apply,
331
+ params=variables['params'],
332
+ tx=tx)
333
+ grad_fn = jax.grad(make_loss_fn(state.apply_fn))
334
+ for batch in data:
335
+ grads = grad_fn(state.params, batch)
336
+ state = state.apply_gradients(grads=grads)
337
+
338
+ Args:
339
+ step: Counter starts at 0 and is incremented by every call to
340
+ `.apply_gradients()`.
341
+ apply_fn: Usually set to `model.apply()`. Kept in this dataclass for
342
+ convenience to have a shorter params list for the `train_step()` function
343
+ in your training loop.
344
+ params: The parameters to be updated by `tx` and used by `apply_fn`.
345
+ tx: An Optax gradient transformation.
346
+ opt_state: The state for `tx`.
347
+ dropout_rng: PRNG key for stochastic operations.
348
+ bf16: Whether to use bf16 16-bit (mixed) precision training instead of 32-bit training.
349
+ """
350
+
351
+ step: int
352
+ apply_fn: Callable = struct.field(pytree_node=False)
353
+ params: core.FrozenDict[str, Any]
354
+ tx: optax.GradientTransformation = struct.field(pytree_node=False)
355
+ opt_state: optax.OptState
356
+ dropout_rng: jnp.ndarray
357
+ max_grad_norm: Optional[float] = 1.0
358
+
359
+ def apply_gradients(self, *, grads, to_dtype, **kwargs):
360
+ """Updates `step`, `params`, `opt_state` and `**kwargs` in return value.
361
+
362
+ Note that internally this function calls `.tx.update()` followed by a call
363
+ to `optax.apply_updates()` to update `params` and `opt_state`.
364
+
365
+ Args:
366
+ grads: Gradients that have the same pytree structure as `.params`.
367
+ **kwargs: Additional dataclass attributes that should be `.replace()`-ed.
368
+
369
+ Returns:
370
+ An updated instance of `self` with `step` incremented by one, `params`
371
+ and `opt_state` updated by applying `grads`, and additional attributes
372
+ replaced as specified by `kwargs`.
373
+ """
374
+
375
+ # clip gradients by global l2 norm
376
+ casted_max_grad_norm = to_dtype(self.max_grad_norm)
377
+ g_norm = linear_algebra.global_norm(grads)
378
+ g_norm = jnp.maximum(casted_max_grad_norm, g_norm)
379
+ grads = jax.tree_map(lambda t: (t / g_norm) * casted_max_grad_norm, grads)
380
+
381
+ # perform update step in fp32 and subsequently downcast optimizer states if mixed precision training
382
+ # grads and opt_state in bf16 (need to upcast), params in fp32 (leave as is)
383
+ updates, new_opt_state = self.tx.update(to_fp32(grads), to_fp32(self.opt_state), self.params)
384
+
385
+ new_params = optax.apply_updates(self.params, updates)
386
+ return self.replace(
387
+ step=self.step + 1,
388
+ params=new_params,
389
+ opt_state=to_dtype(new_opt_state),
390
+ **kwargs,
391
+ )
392
+
393
+ @classmethod
394
+ def create(cls, *, apply_fn, params, tx, to_dtype, **kwargs):
395
+ """Creates a new instance with `step=0` and initialized `opt_state`."""
396
+ # downcast optimizer state to bf16 if mixed-precision training
397
+ opt_state = tx.init(to_dtype(params)) if tx is not None else None
398
+ return cls(
399
+ step=0,
400
+ apply_fn=apply_fn,
401
+ params=params,
402
+ tx=tx,
403
+ opt_state=opt_state,
404
+ **kwargs,
405
+ )
406
+
407
+ def replicate(self):
408
+ return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
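# Note on apply_gradients above: the clipping step rescales the gradients as
#   grads <- grads * max_grad_norm / max(global_norm(grads), max_grad_norm)
# so they are left untouched when their global norm is already below
# max_grad_norm and scaled down to a global norm of exactly max_grad_norm
# otherwise; the optimizer update is then computed in float32 and the new
# optimizer state is cast back to the training dtype.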
409
+
410
+
411
+
412
+ def pad_to_max_length(data, tokenizer):
413
+ # Get lengths of each row of data
414
+ lens = np.array([len(i) for i in data])
415
+
416
+ # Mask of valid places in each row
417
+ mask = np.arange(lens.max()) < lens[:, None]
418
+
419
+ # Setup output array and put elements from data into masked positions
420
+ out = np.ones_like(mask, dtype=data.dtype) * tokenizer.pad_token_id
421
+ out[mask] = np.concatenate(data)
422
+ return out
423
+
424
+
425
+ def shift_tokens_right(label_ids: np.array, decoder_start_token_id: int) -> np.ndarray:
426
+ """
427
+ Shift label ids one token to the right.
428
+ """
429
+ shifted_label_ids = np.zeros_like(label_ids)
430
+ shifted_label_ids[:, 1:] = label_ids[:, :-1]
431
+ shifted_label_ids[:, 0] = decoder_start_token_id
432
+
433
+ return shifted_label_ids
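# For example, with decoder_start_token_id = 0:
#   shift_tokens_right(np.array([[5, 6, 2]]), 0)  ->  array([[0, 5, 6]])
# i.e. the labels, delayed by one position and prefixed with the start token,
# become the decoder_input_ids used for teacher forcing.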
434
+
435
+
436
+ @flax.struct.dataclass
437
+ class FlaxDataCollatorSpeechSeq2SeqWithPadding:
438
+ """
439
+ Data collator that will dynamically pad the inputs received.
440
+ Args:
441
+ processor ([`Wav2Vec2Processor`])
442
+ The processor used for processing the data.
443
+ decoder_start_token_id (:obj: `int`)
444
+ The begin-of-sentence token id of the decoder.
445
+ input_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
446
+ Select a strategy to pad the returned input sequences (according to the model's padding side and padding index)
447
+ among:
448
+ * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
449
+ sequence is provided).
450
+ * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
451
+ maximum acceptable input length for the model if that argument is not provided.
452
+ * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
453
+ different lengths).
454
+ target_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
455
+ Select a strategy to pad the returned target sequences (according to the model's padding side and padding index).
456
+ See above for details.
457
+ max_input_length (:obj:`float`, `optional`):
458
+ Maximum length of the ``input_values`` of the returned list and optionally padding length (see above).
459
+ max_target_length (:obj:`int`, `optional`):
460
+ Maximum length of the ``labels`` of the returned list and optionally padding length (see above).
461
+ pad_input_to_multiple_of (:obj:`int`, `optional`):
462
+ If set will pad the input sequence to a multiple of the provided value.
463
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
464
+ 7.5 (Volta).
465
+ pad_target_to_multiple_of (:obj:`int`, `optional`):
466
+ If set will pad the target sequence to a multiple of the provided value.
467
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
468
+ 7.5 (Volta).
469
+ """
470
+
471
+ processor: Any
472
+ decoder_start_token_id: int
473
+ input_padding: Union[bool, str] = "longest"
474
+ target_padding: Union[bool, str] = "max_length"
475
+ max_input_length: Optional[float] = None
476
+ max_target_length: Optional[int] = None
477
+ pad_input_to_multiple_of: Optional[int] = None
478
+ pad_target_to_multiple_of: Optional[int] = None
479
+
480
+ def __call__(self, features: List[Dict[str, Union[List[int], np.ndarray]]]) -> Dict[str, np.ndarray]:
481
+ # split inputs and labels since they have to be of different lengths and need
482
+ # different padding methods
483
+ input_features = [{"input_values": feature["input_values"]} for feature in features]
484
+ input_ids = [feature["input_id"] for feature in features]
485
+ label_features = [{"input_ids": feature["labels"]} for feature in features]
486
+
487
+ # reformat list to dict and set to pytorch format
488
+ batch = self.processor.feature_extractor.pad(
489
+ input_features,
490
+ max_length=self.max_input_length,
491
+ padding=self.input_padding,
492
+ pad_to_multiple_of=self.pad_input_to_multiple_of,
493
+ return_tensors="np",
494
+ )
495
+
496
+ labels_batch = self.processor.tokenizer.pad(
497
+ label_features,
498
+ max_length=self.max_target_length,
499
+ padding=self.target_padding,
500
+ pad_to_multiple_of=self.pad_target_to_multiple_of,
501
+ return_tensors="np",
502
+ )
503
+
504
+ # if bos token is appended in previous tokenization step,
505
+ # cut bos token here as it's appended later anyway
506
+ labels = labels_batch["input_ids"]
507
+ if (labels[:, 0] == self.decoder_start_token_id).all().item():
508
+ labels = labels[:, 1:]
509
+ labels_batch.attention_mask = labels_batch.attention_mask[:, 1:]
510
+
511
+ decoder_input_ids = shift_tokens_right(labels, self.decoder_start_token_id)
512
+
513
+ # replace padding with -100 to ignore correctly when computing the loss
514
+ labels = np.ma.array(labels, mask=np.not_equal(labels_batch.attention_mask, 1))
515
+ labels = labels.filled(fill_value=-100)
516
+
517
+ batch["inputs"] = batch.pop("input_values")
518
+ batch["input_ids"] = input_ids
519
+ batch["labels"] = labels
520
+ batch["decoder_input_ids"] = decoder_input_ids
521
+
522
+ return batch
523
+
524
+
525
+ def get_grouped_indices(
526
+ dataset, batch_size: int, rng: Optional[List[int]] = None, mega_batch_mult: Optional[int] = None
527
+ ) -> np.array:
528
+ """
529
+ Adapted from the `get_length_grouped_indices` function in the PyTorch Trainer utils file (https://github.com/huggingface/transformers/blob/main/src/transformers/trainer_pt_utils.py#L486)
530
+ Function that returns a list of indices in which each slice of `batch_size` consecutive indices correspond to elements of similar
531
+ lengths. To do this, the indices are:
532
+
533
+ - randomly permuted (if a JAX rng is specified)
534
+ - grouped in mega-batches of size `mega_batch_mult * batch_size`
535
+ - sorted by length in each mega-batch
536
+
537
+ The result is the concatenation of all mega-batches, with the batch of `batch_size` containing the element of
538
+ maximum length placed first, so that an OOM happens sooner rather than later.
539
+ """
540
+ lengths = dataset["input_length"]
541
+
542
+ # Default for mega_batch_mult: 50 or the number to get 4 megabatches, whichever is smaller.
543
+ if mega_batch_mult is None:
544
+ mega_batch_mult = min(len(lengths) // (batch_size * 4), 50)
545
+ # Just in case, for tiny datasets
546
+ if mega_batch_mult == 0:
547
+ mega_batch_mult = 1
548
+
549
+ # We need to use JAX for the random permutation as the PRNG key will be set based on the seed outside of the sampler.
550
+ num_samples = len(lengths)
551
+ indices = jax.random.permutation(rng, np.arange(num_samples)) if rng is not None else np.arange(num_samples)
552
+
553
+ megabatch_size = mega_batch_mult * batch_size
554
+ megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
555
+ megabatches = [list(sorted(megabatch, key=lambda i: lengths[i], reverse=True)) for megabatch in megabatches]
556
+
557
+ # The rest is to get the biggest batch first.
558
+ # Since each megabatch is sorted by descending length, the longest element is the first
559
+ megabatch_maximums = [lengths[megabatch[0]] for megabatch in megabatches]
560
+ max_idx = np.argmax(megabatch_maximums).item()
561
+ # Switch to put the longest batch in first position
562
+ # (note that this is different to the PT grouped sampler in which we only put the longest element in the first position, and not its batch)
563
+ megabatches[0], megabatches[max_idx] = megabatches[max_idx], megabatches[0]
564
+
565
+ megabatches = np.array([i for megabatch in megabatches for i in megabatch])
566
+
567
+ return megabatches
568
+
569
+
570
+ def generate_batch_splits(samples_idx: np.ndarray, batch_size: int) -> np.ndarray:
571
+ """Generate batches of data for a specified batch size from sample indices. If the dataset size is not divisible by
572
+ the batch size, the last incomplete batch is dropped."""
573
+ num_samples = len(samples_idx)
574
+ samples_to_remove = num_samples % batch_size
575
+
576
+ if samples_to_remove != 0:
577
+ samples_idx = samples_idx[:-samples_to_remove]
578
+ sections_split = num_samples // batch_size
579
+ return samples_idx.reshape((sections_split, batch_size))
580
+
581
+
582
+ def data_loader(dataset, batch_size, rng, sampler, collator):
583
+ samples_idx = sampler(dataset, batch_size, rng)
584
+
585
+ num_samples = len(samples_idx)
586
+ samples_to_remove = num_samples % batch_size
587
+
588
+ if samples_to_remove != 0:
589
+ samples_idx = samples_idx[:-samples_to_remove]
590
+ sections_split = num_samples // batch_size
591
+
592
+ batch_idx = np.split(samples_idx, sections_split)
593
+
594
+ for idx in batch_idx:
595
+ samples = dataset[idx]
596
+ batch = collator(samples)
597
+ batch = shard(batch.data)
598
+ yield batch
599
+
600
+
601
+ def write_train_metric(summary_writer, train_metrics, train_time, step):
602
+ summary_writer.scalar("train_time", train_time, step)
603
+
604
+ train_metrics = get_metrics(train_metrics)
605
+ for key, vals in train_metrics.items():
606
+ tag = f"train_{key}"
607
+ for i, val in enumerate(vals):
608
+ summary_writer.scalar(tag, val, step - len(vals) + i + 1)
609
+
610
+
611
+ def write_eval_metric(summary_writer, eval_metrics, step, pred_str=None):
612
+ for metric_name, value in eval_metrics.items():
613
+ summary_writer.scalar(f"eval_{metric_name}", value, step)
614
+
615
+ if pred_str is not None:
616
+ # write output actual predictions for debugging
617
+ summary_writer.text("eval_predictions", "\n".join(pred_str), step)
618
+
619
+
620
+ def write_wandb_log(metrics, step, prefix=None):
621
+ if jax.process_index() == 0:
622
+ log_metrics = {}
623
+ for k, v in metrics.items():
624
+ if "layer" in k:
625
+ log_metrics[f"{k}/"] = v
626
+ elif prefix is not None:
627
+ log_metrics[f"{prefix}/{k}"] = v
628
+ else:
629
+ log_metrics[k] = v
630
+ wandb.log(log_metrics, step)
631
+
632
+
633
+ def write_wandb_pred(pred_str, label_str, eval_ids, step, prefix="eval", top_ids=None, final_step=True):
634
+ if jax.process_index() == 0:
635
+ top_ids = top_ids if top_ids else eval_ids
636
+ num_beams = len(pred_str)
637
+ # convert str data to a wandb compatible format
638
+ str_data = []
639
+ for id in top_ids:
640
+ if id in eval_ids:
641
+ idx = eval_ids.index(id)
642
+ str_data.append([eval_ids[idx], label_str[idx]] + [pred_str[beam][idx] for beam in range(num_beams)])
643
+ columns = ["id", "label_str"] + [f"beam_{i + 1}" for i in range(num_beams)]
644
+ wandb.log(
645
+ {f"{prefix}/step_{int(step / 1000)}k": wandb.Table(columns=columns, data=str_data[:50])},
646
+ step,
647
+ )
648
+ if final_step:
649
+ str_data = np.array(str_data)
650
+ str_data = str_data[str_data[:, 1] != str_data[:, 2]]
651
+ wandb.log(
652
+ {f"{prefix}/step_{int(step / 1000)}k_incorrect": wandb.Table(columns=columns, data=str_data[:200000])}, step
653
+ )
654
+
655
+
656
+ def create_learning_rate_fn(
657
+ num_train_steps: int, num_warmup_steps: int, learning_rate: float
658
+ ) -> Callable[[int], jnp.array]:
659
+ """Returns a linear warmup, linear_decay learning rate function."""
660
+ warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
661
+ decay_fn = optax.linear_schedule(
662
+ init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
663
+ )
664
+ schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
665
+ return schedule_fn
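# With the values used in run_cv9.sh (50_000 total steps, 500 warmup steps, peak
# learning rate 1e-4) the returned schedule rises linearly from 0 to 1e-4 over
# the first 500 steps, then decays linearly back to 0 by step 50_000:
#   lr_fn = create_learning_rate_fn(50_000, 500, 1e-4)
#   lr_fn(0), lr_fn(500), lr_fn(50_000)  # ~0.0, 1e-4, ~0.0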
666
+
667
+
668
+ def main():
669
+ # 1. Parse input arguments
670
+ # See all possible arguments in src/transformers/training_args.py
671
+ # or by passing the --help flag to this script.
672
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
673
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, FlaxSeq2SeqTrainingArguments))
674
+
675
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
676
+ # If we pass only one argument to the script and it's the path to a json file,
677
+ # let's parse it to get our arguments.
678
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
679
+ else:
680
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
681
+
682
+ # 2. Setup logging
683
+ # Make one log on every process with the configuration for debugging.
684
+ logging.basicConfig(
685
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
686
+ datefmt="%m/%d/%Y %H:%M:%S",
687
+ handlers=[logging.StreamHandler(sys.stdout)],
688
+ )
689
+ # Set the verbosity to info of the Transformers logger.
690
+ # We only want one process per machine to log things on the screen.
691
+ logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
692
+ if jax.process_index() == 0:
693
+ datasets.utils.logging.set_verbosity_warning()
694
+ transformers.utils.logging.set_verbosity_info()
695
+ else:
696
+ datasets.utils.logging.set_verbosity_error()
697
+ transformers.utils.logging.set_verbosity_error()
698
+
699
+ # Set up wandb run
700
+ if jax.process_index() == 0:
701
+ wandb.init(project=data_args.wandb_project, name=data_args.wandb_name, job_type=data_args.wandb_job_type)
702
+
703
+ logger.info("Training/evaluation parameters %s", training_args)
704
+
705
+ # Set the default TPU matmul precision and display the number of devices
706
+ jax.config.update("jax_default_matmul_precision", training_args.matmul_precision)
707
+ logger.info(f"JAX devices: {jax.device_count()}, matmul precision: {training_args.matmul_precision}")
708
+
709
+ # TODO: 3. Detecting last checkpoint and eventually continue from last checkpoint
710
+ last_checkpoint = None
711
+ if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
712
+ last_checkpoint = get_last_checkpoint(training_args.output_dir)
713
+ if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
714
+ raise ValueError(
715
+ f"Output directory ({training_args.output_dir}) already exists and is not empty. "
716
+ "Use --overwrite_output_dir to overcome."
717
+ )
718
+ elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
719
+ logger.info(
720
+ f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
721
+ "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
722
+ )
723
+
724
+ # 4. Load dataset
725
+ raw_datasets = DatasetDict()
726
+
727
+ if training_args.do_train:
728
+ raw_datasets["train"] = load_dataset(
729
+ data_args.dataset_name,
730
+ data_args.dataset_config_name,
731
+ split=data_args.train_split_name,
732
+ cache_dir=data_args.dataset_cache_dir,
733
+ use_auth_token=True if model_args.use_auth_token else None,
734
+ )
735
+
736
+ if training_args.do_eval:
737
+ raw_datasets["eval"] = load_dataset(
738
+ data_args.dataset_name,
739
+ data_args.dataset_config_name,
740
+ split=data_args.eval_split_name,
741
+ cache_dir=data_args.dataset_cache_dir,
742
+ use_auth_token=True if model_args.use_auth_token else None,
743
+ )
744
+
745
+ if training_args.do_predict:
746
+ test_split = data_args.test_split_name.split("+")
747
+ for split in test_split:
748
+ raw_datasets[split] = load_dataset(
749
+ data_args.dataset_name,
750
+ data_args.dataset_config_name,
751
+ split=split,
752
+ cache_dir=data_args.dataset_cache_dir,
753
+ use_auth_token=True if model_args.use_auth_token else None,
754
+ )
755
+
756
+ if not training_args.do_train and not training_args.do_eval and not training_args.do_predict:
757
+ raise ValueError(
758
+ "Cannot disable training, evaluation and prediction all at once. At least one of "
759
+ "training, evaluation or prediction has to be done."
760
+ )
761
+
762
+ # if not training, there is no need to run multiple epochs
763
+ if not training_args.do_train:
764
+ training_args.num_train_epochs = 1
765
+
766
+ if data_args.audio_column_name not in next(iter(raw_datasets.values())).column_names:
767
+ raise ValueError(
768
+ f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'. "
769
+ "Make sure to set `--audio_column_name` to the correct audio column - one of "
770
+ f"{', '.join(next(iter(raw_datasets.values())).column_names)}."
771
+ )
772
+
773
+ if data_args.text_column_name not in next(iter(raw_datasets.values())).column_names:
774
+ raise ValueError(
775
+ f"--text_column_name {data_args.text_column_name} not found in dataset '{data_args.dataset_name}'. "
776
+ "Make sure to set `--text_column_name` to the correct text column - one of "
777
+ f"{', '.join(next(iter(raw_datasets.values())).column_names)}."
778
+ )
779
+
780
+ if data_args.id_column_name not in next(iter(raw_datasets.values())).column_names:
781
+ raise ValueError(
782
+ f"--id_column_name {data_args.id_column_name} not found in dataset '{data_args.dataset_name}'. "
783
+ "Make sure to set `--id_column_name` to the correct id column - one of "
784
+ f"{', '.join(next(iter(raw_datasets.values())).column_names)}."
785
+ )
786
+
787
+ # 5. Load pretrained model, tokenizer, and feature extractor
788
+ #
789
+ # Distributed training:
790
+ # The .from_pretrained methods guarantee that only one local process can concurrently
791
+ config = AutoConfig.from_pretrained(
792
+ model_args.config_name if model_args.config_name else model_args.model_name_or_path,
793
+ cache_dir=model_args.cache_dir,
794
+ revision=model_args.model_revision,
795
+ use_auth_token=True if model_args.use_auth_token else None,
796
+ )
797
+
798
+ # update config according to training and model args
799
+ config.encoder.update(
800
+ {
801
+ "gradient_checkpointing": training_args.gradient_checkpointing,
802
+ "hidden_dropout": model_args.hidden_dropout,
803
+ "add_adapter": model_args.encoder_add_adapter,
804
+ }
805
+ )
806
+ config.decoder.update({"gradient_checkpointing": training_args.gradient_checkpointing})
807
+
808
+ feature_extractor = AutoFeatureExtractor.from_pretrained(
809
+ model_args.feature_extractor_name if model_args.feature_extractor_name else model_args.model_name_or_path,
810
+ cache_dir=model_args.cache_dir,
811
+ revision=model_args.model_revision,
812
+ use_auth_token=True if model_args.use_auth_token else None,
813
+ )
814
+ tokenizer = AutoTokenizer.from_pretrained(
815
+ model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
816
+ cache_dir=model_args.cache_dir,
817
+ use_fast=model_args.use_fast_tokenizer,
818
+ revision=model_args.model_revision,
819
+ use_auth_token=True if model_args.use_auth_token else None,
820
+ )
821
+
822
+ if training_args.precision == "full_mixed":
823
+ dtype = jnp.bfloat16
824
+ training_args.mixed_precision = True
825
+ elif training_args.precision == "half_mixed":
826
+ dtype = jnp.bfloat16
827
+ training_args.mixed_precision = False
828
+ else:
829
+ dtype = jnp.float32
830
+ training_args.mixed_precision = False
831
+
832
+ model = FlaxSpeechEncoderDecoderModel.from_pretrained(
833
+ model_args.model_name_or_path,
834
+ config=config,
835
+ dtype=dtype,
836
+ cache_dir=model_args.cache_dir,
837
+ revision=model_args.model_revision,
838
+ use_auth_token=True if model_args.use_auth_token else None,
839
+ )
840
+
841
+ if model.config.decoder_start_token_id is None:
842
+ raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
843
+
844
+ # 6. Resample speech dataset ALWAYS
845
+ raw_datasets = raw_datasets.cast_column(
846
+ data_args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
847
+ )
848
+
849
+ # 7. Preprocessing the datasets.
850
+ # We need to read the audio files as arrays and tokenize the targets.
851
+ max_input_length = int(data_args.max_duration_in_seconds * feature_extractor.sampling_rate)
852
+ min_input_length = int(data_args.min_duration_in_seconds * feature_extractor.sampling_rate)
853
+ max_target_length = data_args.max_target_length
854
+ min_target_length = data_args.min_target_length
855
+ pad_input_to_multiple_of = data_args.pad_input_to_multiple_of
856
+ pad_target_to_multiple_of = data_args.pad_target_to_multiple_of
857
+ audio_column_name = data_args.audio_column_name
858
+ num_workers = data_args.preprocessing_num_workers
859
+ text_column_name = data_args.text_column_name
860
+ id_column_name = data_args.id_column_name
861
+ model_input_name = feature_extractor.model_input_names[0]
862
+ do_lower_case = data_args.do_lower_case
863
+ speech_disfluencies = {"{cough}": "", "{breath}": "", "{noise}": "", "{smack}": "", "{uh}": "", "{um}": "", "<sil>": "", "(1)": "", "(2)": "", "(3)": "", "(4)": "", "(5)": "", "(6)": "", " ": " ", " ": " ", " ": " ", "<s> ": "<s>", " </s>": "</s>"}
864
+
865
+ if training_args.do_train and data_args.max_train_samples is not None:
866
+ raw_datasets["train"] = raw_datasets["train"].select(range(data_args.max_train_samples))
867
+
868
+ if training_args.do_eval and data_args.max_eval_samples is not None:
869
+ raw_datasets["eval"] = raw_datasets["eval"].select(range(data_args.max_eval_samples))
870
+
871
+ if training_args.do_predict and data_args.max_test_samples is not None:
872
+ for split in test_split:
873
+ raw_datasets[split] = raw_datasets[split].select(range(data_args.max_test_samples))
874
+
875
+ # filter data where the targets are ignored in scoring
876
+ def is_target_labels(input_str):
877
+ return input_str != "ignore_time_segment_in_scoring"
878
+
879
+ if data_args.dataset_name == "LIUM/tedlium":
880
+ raw_datasets = raw_datasets.filter(
881
+ is_target_labels,
882
+ num_proc=num_workers,
883
+ input_columns=[text_column_name],
884
+ )
885
+
886
+ def prepare_dataset(batch):
887
+ # process audio
888
+ sample = batch[audio_column_name]
889
+ inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])
890
+ # process audio length
891
+ batch[model_input_name] = inputs.input_values[0]
892
+ batch["input_length"] = len(batch["input_values"])
893
+ batch["input_id"] = batch[id_column_name]
894
+
895
+ # process targets
896
+ input_str = batch[text_column_name].lower() if do_lower_case else batch[text_column_name]
897
+ if input_str.startswith('"') and input_str.endswith('"'):
898
+ input_str = input_str[1:-1]
899
+ input_str = input_str.replace('""', '"')
900
+ if data_args.dataset_name == "mozilla-foundation/common_voice_9_0":
901
+ if input_str[-1] not in ['.', '?', '!']:
902
+ input_str = input_str + '.'
903
+ batch["labels"] = tokenizer(input_str).input_ids
904
+ batch["labels_length"] = len(batch["labels"])
905
+ return batch
906
+
907
+ vectorized_datasets = raw_datasets.map(
908
+ prepare_dataset,
909
+ remove_columns=next(iter(raw_datasets.values())).column_names,
910
+ num_proc=data_args.preprocessing_num_workers,
911
+ desc="preprocess train dataset",
912
+ )
913
+
914
+ # filter data with inputs shorter than min_input_length or longer than max_input_length
915
+ def is_audio_in_length_range(length):
916
+ return length > min_input_length and length < max_input_length
917
+
918
+ vectorized_datasets = vectorized_datasets.filter(
919
+ is_audio_in_length_range,
920
+ num_proc=num_workers,
921
+ input_columns=["input_length"],
922
+ )
923
+
924
+ # filter data with targets shorter than min_target_length or longer than max_target_length
925
+ def is_labels_in_length_range(length):
926
+ return length > min_target_length and length < max_target_length
927
+
928
+ vectorized_datasets = vectorized_datasets.filter(
929
+ is_labels_in_length_range,
930
+ num_proc=num_workers,
931
+ input_columns=["labels_length"],
932
+ )
933
+
934
+ # for large datasets it is advised to run the preprocessing on a
935
+ # single machine first with `args.preprocessing_only` since there will most likely
936
+ # be a timeout when running the script in distributed mode.
937
+ # In a second step `args.preprocessing_only` can then be set to `False` to load the
938
+ # cached dataset
939
+ if data_args.preprocessing_only:
940
+ cache = {k: v.cache_files for k, v in vectorized_datasets.items()}
941
+ logger.info(f"Data preprocessing finished. Files cached at {cache}.")
942
+ return
943
+
944
+ # 8. Load Metric
945
+ metric = load_metric("wer")
946
+
947
+ def compute_metrics(pred_ids: List[List[int]], label_ids: List[List[int]]):
948
+ label_ids = pad_to_max_length(np.array(label_ids, dtype='object'), tokenizer) if pad_target_to_multiple_of else label_ids
949
+
950
+ padded_ids = np.where(np.asarray(label_ids) == -100, tokenizer.pad_token_id, np.asarray(label_ids))
951
+ # we do not want to group tokens when computing the metrics
952
+ label_str = tokenizer.batch_decode(padded_ids, skip_special_tokens=True)
953
+
954
+ pred_ids = np.array(pred_ids)
955
+ num_beams = pred_ids.shape[1]
956
+ # decode on a beam-by-beam basis
957
+ pred_str = [
958
+ tokenizer.batch_decode(pred_ids[:, beam, :], skip_special_tokens=True)
959
+ for beam in reversed(range(num_beams))
960
+ ]
961
+ # compute wer for top beam
962
+ wer = metric.compute(predictions=pred_str[0], references=label_str)
963
+
964
+ return {"wer": wer}, pred_str, label_str
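# As a standalone illustration of the metric used above, WER counts word-level
# edit operations against the reference transcription:
#   load_metric("wer").compute(predictions=["hello world"],
#                              references=["hello there world"])
#   # -> 1/3 (one deletion against a three-word reference)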
965
+
966
+ # 9. Save feature extractor, tokenizer and config
967
+ feature_extractor.save_pretrained(training_args.output_dir)
968
+ tokenizer.save_pretrained(training_args.output_dir)
969
+ config.save_pretrained(training_args.output_dir)
970
+
971
+ processor = AutoProcessor.from_pretrained(training_args.output_dir)
972
+
973
+ data_collator = FlaxDataCollatorSpeechSeq2SeqWithPadding(
974
+ processor=processor,
975
+ decoder_start_token_id=model.config.decoder_start_token_id,
976
+ input_padding="longest",
977
+ target_padding="longest",
978
+ max_target_length=max_target_length,
979
+ pad_input_to_multiple_of=pad_input_to_multiple_of,
980
+ pad_target_to_multiple_of=pad_target_to_multiple_of if pad_target_to_multiple_of else max_target_length
981
+ )
982
+
983
+ # Enable tensorboard only on the master node
984
+ has_tensorboard = is_tensorboard_available()
985
+ if has_tensorboard and jax.process_index() == 0:
986
+ try:
987
+ from flax.metrics.tensorboard import SummaryWriter
988
+
989
+ summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
990
+ except ImportError as ie:
991
+ has_tensorboard = False
992
+ logger.warning(
993
+ f"Unable to display metrics through TensorBoard because some packages are not installed: {ie}"
994
+ )
995
+ else:
996
+ logger.warning(
997
+ "Unable to display metrics through TensorBoard because the package is not installed: "
998
+ "Please run `pip install tensorboard` to enable."
999
+ )
1000
+
1001
+ # 10. Handle the repository creation
1002
+ if training_args.push_to_hub:
1003
+ if training_args.hub_model_id is None:
1004
+ repo_name = get_full_repo_name(
1005
+ Path(training_args.output_dir).absolute().name, token=training_args.hub_token
1006
+ )
1007
+ else:
1008
+ repo_name = training_args.hub_model_id
1009
+ repo = Repository(training_args.output_dir, clone_from=repo_name)
1010
+
1011
+ # 11. Initialize our training
1012
+ rng = jax.random.PRNGKey(training_args.seed)
1013
+ rng, dropout_rng = jax.random.split(rng)
1014
+
1015
+ # Store some constants
1016
+ max_steps = int(training_args.max_steps)
1017
+ gradient_accumulation_steps = int(training_args.gradient_accumulation_steps)
1018
+ train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
1019
+ batch_size_per_update = train_batch_size * gradient_accumulation_steps
1020
+ eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
1021
+ to_dtype = to_bf16 if training_args.mixed_precision else to_fp32
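+ # to_dtype is used below to cast the params (and the one-hot targets) to bfloat16 for the
+ # forward/backward pass when mixed-precision training is enabled, and to fp32 otherwise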
1022
+
1023
+ if training_args.do_train:
1024
+ num_train_samples = len(vectorized_datasets["train"])
1025
+ steps_per_epoch = num_train_samples // batch_size_per_update
1026
+ if max_steps > 0:
1027
+ num_epochs = - (training_args.max_steps // - steps_per_epoch)
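+ # (the negated floor division above is simply ceil(max_steps / steps_per_epoch))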
1028
+ total_train_steps = max_steps
1029
+ else:
1030
+ num_epochs = int(training_args.num_train_epochs)
1031
+ total_train_steps = steps_per_epoch * num_epochs
1032
+
1033
+ # Create learning rate schedule
1034
+ linear_decay_lr_schedule_fn = create_learning_rate_fn(
1035
+ total_train_steps,
1036
+ training_args.warmup_steps,
1037
+ training_args.learning_rate,
1038
+ )
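+ # presumably a linear warmup over `warmup_steps` followed by a linear decay over the remaining
+ # train steps, as the name of the returned schedule suggests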
1039
+
1040
+ # We use Optax's "masking" functionality to not apply weight decay
1041
+ # to bias and LayerNorm scale parameters. decay_mask_fn returns a
1042
+ # boolean mask with the same structure as the parameters.
1043
+ # The mask is True for parameters that should be decayed.
1044
+ # Note that this mask is specifically adapted for FlaxWav2Vec2 and FlaxBart.
1045
+ # For FlaxT5, one should correct the layer norm parameter naming
1046
+ # accordingly - see e.g. `run_t5_mlm_flax.py`.
1047
+ def decay_mask_fn(params):
1048
+ flat_params = traverse_util.flatten_dict(params)
1049
+ layer_norm_params = [
1050
+ (name, "scale")
1051
+ for name in ["layer_norm", "self_attn_layer_norm", "layernorm_embedding", "final_layer_norm"]
1052
+ ]
1053
+ flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_params) for path in flat_params}
1054
+ return traverse_util.unflatten_dict(flat_mask)
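+ # e.g. a path ending in ("self_attn_layer_norm", "scale") or in "bias" maps to False (no decay),
+ # whereas a dense kernel path ending in "kernel" maps to True (decayed)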
1055
+
1056
+ if training_args.adafactor:
1057
+ # Create Adafactor optimizer
1058
+ optim = optax.adafactor(
1059
+ learning_rate=linear_decay_lr_schedule_fn,
1060
+ dtype_momentum=jnp.bfloat16 if training_args.mixed_precision else jnp.float32,
1061
+ weight_decay_rate=training_args.weight_decay,
1062
+ weight_decay_mask=decay_mask_fn,
1063
+ )
1064
+ else:
1065
+ # Create AdamW optimizer
1066
+ optim = optax.adamw(
1067
+ learning_rate=linear_decay_lr_schedule_fn,
1068
+ b1=training_args.adam_beta1,
1069
+ b2=training_args.adam_beta2,
1070
+ eps=training_args.adam_epsilon,
1071
+ weight_decay=training_args.weight_decay,
1072
+ mask=decay_mask_fn,
1073
+ )
1074
+ else:
1075
+ num_epochs = 0
1076
+ total_train_steps = 0
1077
+ num_train_samples = 0
1078
+ optim = None
1079
+
1080
+ # Setup train state
1081
+ state = MixedPrecisionTrainState.create(
1082
+ apply_fn=model.__call__,
1083
+ params=model.params,
1084
+ tx=optim,
1085
+ to_dtype=to_dtype,
1086
+ dropout_rng=dropout_rng,
1087
+ max_grad_norm=training_args.max_grad_norm,
1088
+ )
1089
+
1090
+ # Cross entropy loss
1091
+ def loss_fn(logits, labels):
1092
+ vocab_size = logits.shape[-1]
1093
+ # optax onehot always returns a float32 device array, need to downcast if performing mixed precision training
1094
+ onehot_targets = to_dtype(onehot(labels, vocab_size))
1095
+ loss = optax.softmax_cross_entropy(logits, onehot_targets)
1096
+ # mask out padded tokens from the loss, i.e. only keep positions where labels are not set to -100
1097
+ padding = labels >= 0
1098
+ loss = loss * padding
1099
+ loss = loss.sum()
1100
+ num_labels = padding.sum()
1101
+ return loss, num_labels
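+ # note: the loss is returned as a sum (not a mean); train_step / eval_step psum both values
+ # across devices and divide by the total number of label tokens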
1102
+
1103
+ # Define gradient update step fn
1104
+ def train_step(state, batch):
1105
+ # use a single rng per grad step, with or without accumulation, as the graph should be identical over one effective training batch
1106
+ dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
1107
+
1108
+ def compute_loss(params, minibatch):
1109
+ labels = minibatch.pop("labels")
1110
+ logits = state.apply_fn(
1111
+ **minibatch,
1112
+ params=params,
1113
+ dropout_rng=dropout_rng,
1114
+ freeze_feature_encoder=model_args.freeze_feature_encoder,
1115
+ train=True,
1116
+ )[0]
1117
+ loss, num_labels = loss_fn(logits, labels)
1118
+ return loss, num_labels
1119
+
1120
+ grad_fn = jax.value_and_grad(compute_loss, has_aux=True)
1121
+
1122
+ if gradient_accumulation_steps == 1 or training_args.multisteps:
1123
+ (loss, num_labels), grad = grad_fn(to_dtype(state.params), batch)
1124
+
1125
+ # Custom gradient accumulation
1126
+ else:
1127
+ # add a first dimension over gradient_accumulation_steps for minibatch slices
1128
+ batch = jax.tree_map(
1129
+ lambda x: x.reshape(
1130
+ gradient_accumulation_steps, training_args.per_device_train_batch_size, *x.shape[1::]
1131
+ ),
1132
+ batch,
1133
+ )
1134
+
1135
+ def accum_minibatch_step(accum_grad, minibatch):
1136
+ # compute loss, num labels and grad over minibatch and accumulate
1137
+ (loss, num_labels), grad = grad_fn(to_dtype(state.params), minibatch)
1138
+ return jax.tree_map(jnp.add, accum_grad, grad), (loss, num_labels)
1139
+
1140
+ # create an initial state for accumulating losses, num labels and gradients
1141
+ init_grad = jax.tree_map(jnp.zeros_like, to_dtype(state.params))
1142
+ # loop accum minibatch step over the number of gradient accumulation steps
1143
+ grad, (loss, num_labels) = jax.lax.scan(accum_minibatch_step, init_grad, batch)
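+ # the scan carries the running gradient sum and stacks the per-minibatch (loss, num_labels),
+ # which is why both are reduced with .sum() before the cross-device psum below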
1144
+
1145
+ grad = jax.lax.psum(grad, "batch")
1146
+ loss = jax.lax.psum(loss.sum(), "batch")
1147
+ total_samples = jax.lax.psum(num_labels.sum(), "batch")
1148
+ grad = jax.tree_map(lambda g: g / total_samples, grad)
1149
+ loss = jax.tree_map(lambda l: l / total_samples, loss)
1150
+
1151
+ # update state
1152
+ new_state = state.apply_gradients(
1153
+ grads=grad,
1154
+ dropout_rng=new_dropout_rng,
1155
+ to_dtype=to_dtype,
1156
+ )
1157
+
1158
+ # compute gradient norms over all layers, total encoder, total decoder and global for detailed monitoring
1159
+ layer_grad_norm = jax.tree_map(jnp.linalg.norm, grad)
1160
+ logs = {
1161
+ "layer_grad_norm": layer_grad_norm,
1162
+ "encoder_grad_norm": jnp.linalg.norm(jax.tree_util.tree_leaves(layer_grad_norm["encoder"])),
1163
+ "decoder_grad_norm": jnp.linalg.norm(jax.tree_util.tree_leaves(layer_grad_norm["decoder"])),
1164
+ }
1165
+ logs["grad_norm"] = jnp.linalg.norm([logs["encoder_grad_norm"], logs["decoder_grad_norm"]])
1166
+
1167
+ # compute parameter norms over all layers, total encoder, total decoder and global for detailed monitoring
1168
+ layer_param_norm = jax.tree_map(jnp.linalg.norm, new_state.params)
1169
+ logs["layer_param_norm"] = layer_param_norm
1170
+ logs["encoder_param_norm"] = jnp.linalg.norm(jax.tree_util.tree_leaves(layer_param_norm["encoder"]))
1171
+ logs["decoder_param_norm"] = jnp.linalg.norm(jax.tree_util.tree_leaves(layer_param_norm["decoder"]))
1172
+ logs["param_norm"] = jnp.linalg.norm([logs["encoder_param_norm"], logs["decoder_param_norm"]])
1173
+
1174
+ metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
1175
+ metrics.update(logs)
1176
+
1177
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
1178
+ # metrics = to_fp32(metrics)
1179
+
1180
+ return new_state, metrics
1181
+
1182
+ # Define eval fn
1183
+ def eval_step(params, batch):
1184
+ labels = batch.pop("labels")
1185
+ logits = model(**batch, params=params, train=False)[0]
1186
+ loss, num_labels = loss_fn(logits, labels)
1187
+
1188
+ total_samples = jax.lax.psum(num_labels, "batch")
1189
+ loss = jax.lax.psum(loss, "batch")
1190
+ loss = jax.tree_map(lambda l: l / total_samples, loss)
1191
+
1192
+ # summarize metrics
1193
+ metrics = {"loss": loss}
1194
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
1195
+ # metrics = to_fp32(metrics)
1196
+ return metrics
1197
+
1198
+ # Define generation function
1199
+ gen_kwargs = {
1200
+ "max_length": training_args.generation_max_length,
1201
+ "num_beams": training_args.generation_num_beams,
1202
+ "length_penalty": training_args.generation_length_penalty,
1203
+ }
1204
+ final_gen_kwargs = {
1205
+ "max_length": training_args.final_generation_max_length,
1206
+ "num_beams": training_args.final_generation_num_beams,
1207
+ "length_penalty": training_args.generation_length_penalty,
1208
+ }
1209
+
1210
+ def generate_step(params, batch):
1211
+ model.params = params
1212
+ output_ids = model.generate(batch["inputs"], **gen_kwargs)
1213
+ return output_ids.sequences
1214
+
1215
+ def final_generate_step(params, batch):
1216
+ model.params = params
1217
+ output_ids = model.generate(batch["inputs"], **final_gen_kwargs)
1218
+ return output_ids.sequences
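+ # the two generation fns only differ in their kwargs: final_gen_kwargs can use a larger beam size /
+ # max length (via the final_generation_* training args) for the final evaluation and prediction passes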
1219
+
1220
+ # Create parallel version of the train and eval step
1221
+ if training_args.do_train:
1222
+ p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
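+ # donate_argnums=(0,) donates the old state's buffers to XLA so their memory can be reused for the new state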
1223
+
1224
+ if training_args.do_eval or training_args.do_predict:
1225
+ p_eval_step = jax.pmap(eval_step, "batch")
1226
+
1227
+ if training_args.predict_with_generate:
1228
+ p_generate_step = jax.pmap(generate_step, "batch")
1229
+ p_final_generate_step = jax.pmap(final_generate_step, "batch")
1230
+
1231
+ def run_evaluation(step, final_step=False):
1232
+ if training_args.do_eval:
1233
+ # ======================== Evaluating ==============================
1234
+ eval_metrics = []
1235
+ eval_preds = []
1236
+ eval_ids = []
1237
+ eval_labels = []
1238
+
1239
+ # Generate eval set by sequentially sampling indices from the eval dataset and grouping by length
1240
+ eval_samples_idx = get_grouped_indices(vectorized_datasets["eval"], eval_batch_size)
1241
+ eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
1242
+
1243
+ for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
1244
+ samples = [vectorized_datasets["eval"][int(idx)] for idx in batch_idx]
1245
+ batch = data_collator(samples)
1246
+ eval_ids.extend(batch.pop("input_ids"))
1247
+ batch = shard(batch.data)
1248
+ labels = batch["labels"]
1249
+
1250
+ metrics = p_eval_step(state.params, batch)
1251
+ eval_metrics.append(metrics)
1252
+
1253
+ # generation
1254
+ if training_args.predict_with_generate:
1255
+ if not final_step:
1256
+ generated_ids = p_generate_step(state.params, batch)
1257
+ eval_preds.extend(
1258
+ jax.device_get(generated_ids.reshape(-1, gen_kwargs["num_beams"], gen_kwargs["max_length"]))
1259
+ )
1260
+ else:
1261
+ generated_ids = p_final_generate_step(state.params, batch)
1262
+ eval_preds.extend(
1263
+ jax.device_get(
1264
+ generated_ids.reshape(-1, final_gen_kwargs["num_beams"], final_gen_kwargs["max_length"])
1265
+ )
1266
+ )
1267
+ eval_labels.extend(jax.device_get(labels.reshape(-1, labels.shape[-1])))
1268
+
1269
+ # normalize eval metrics
1270
+ eval_metrics = get_metrics(eval_metrics)
1271
+ eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
1272
+ eval_metrics = to_fp32(eval_metrics)
1273
+
1274
+ # compute WER metric and get predicted string (for debugging)
1275
+ wer_desc = ""
1276
+ pred_str = []
1277
+ label_str = []
1278
+ if training_args.predict_with_generate:
1279
+ wer_metric, pred_str, label_str = compute_metrics(eval_preds, eval_labels)
1280
+ eval_metrics.update(wer_metric)
1281
+ wer_desc = " ".join([f"Eval {key}: {value} |" for key, value in wer_metric.items()])
1282
+
1283
+ # Print metrics and update progress bar
1284
+ desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']} | {wer_desc})"
1285
+ epochs.write(desc)
1286
+ epochs.desc = desc
1287
+
1288
+ # Save metrics
1289
+ write_wandb_log(eval_metrics, step, prefix="eval")
1290
+ write_wandb_pred(
1291
+ pred_str,
1292
+ label_str,
1293
+ eval_ids,
1294
+ step,
1295
+ top_ids=vectorized_datasets["eval"]["input_id"] if data_args.log_first_ids else None,
1296
+ final_step=final_step,
1297
+ )
1298
+ # if has_tensorboard and jax.process_index() == 0:
1299
+ # write_eval_metric(summary_writer, eval_metrics, step, pred_str=pred_str)
1300
+
1301
+ def save_checkpoint(step):
1302
+ # save and push checkpoint to the hub
1303
+ if jax.process_index() == 0:
1304
+ params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
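+ # x[0] selects the first device's copy, i.e. un-replicates the pmapped params before saving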
1305
+ model.save_pretrained(training_args.output_dir, params=params)
1306
+ tokenizer.save_pretrained(training_args.output_dir)
1307
+ if training_args.push_to_hub:
1308
+ repo.push_to_hub(commit_message=f"Saving weights and logs of step {int(step / 1000)}k", blocking=False)
1309
+
1310
+ # Replicate the train state on each device
1311
+ state = state.replicate()
1312
+
1313
+ logger.info("***** Running training *****")
1314
+ logger.info(f" Num examples = {num_train_samples}")
1315
+ logger.info(f" Num Epochs = {num_epochs}")
1316
+ logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
1317
+ logger.info(f" Num gradient accumulation steps = {gradient_accumulation_steps}")
1318
+ logger.info(f" Total train batch size (w. parallel & distributed) = {batch_size_per_update}")
1319
+ logger.info(f" Total optimization steps = {total_train_steps}")
1320
+ logger.info(f" Gradient checkpointing: {config.encoder.gradient_checkpointing}")
1321
+ logger.info(f" Use scan: {config.encoder.use_scan}")
1322
+ logger.info(f" Fuse matmuls: {config.encoder.fuse_matmuls}")
1323
+
1324
+ train_time = cur_step = 0
1325
+ epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
1326
+ for epoch in epochs:
1327
+ if training_args.do_train:
1328
+ # ======================== Training ================================
1329
+ train_start = time.time()
1330
+
1331
+ # Create sampling rng
1332
+ rng, input_rng = jax.random.split(rng)
1333
+
1334
+ # Generate an epoch by randomly shuffling sampling indices from the train dataset and grouping by length
1335
+ train_samples_idx = get_grouped_indices(vectorized_datasets["train"], batch_size_per_update, input_rng)
1336
+ train_batch_idx = generate_batch_splits(train_samples_idx, batch_size_per_update)
1337
+
1338
+ # Gather the indices for creating the batch and do a training step
1339
+ for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1), 1):
1340
+ samples = [vectorized_datasets["train"][int(idx)] for idx in batch_idx]
1341
+ batch = data_collator(samples)
1342
+ batch.pop("input_ids")
1343
+ batch = shard(batch.data)
1344
+ state, train_metric = p_train_step(state, batch)
1345
+
1346
+ cur_step = epoch * (num_train_samples // batch_size_per_update) + step
1347
+
1348
+ if cur_step % training_args.logging_steps == 0:
1349
+ # Save metrics
1350
+ train_metric = unreplicate(train_metric)
1351
+ train_time += time.time() - train_start
1352
+ # need to upcast all device arrays to fp32 for wandb logging (jnp.bfloat16 not supported) -> do this here OR in train_step
1353
+ write_wandb_log(to_fp32(train_metric), cur_step, prefix="train")
1354
+ # we won't log to tensorboard for now (logging param and grad norms on a layer-by-layer basis is fiddly)
1355
+ # if has_tensorboard and jax.process_index() == 0:
1356
+ # write_train_metric(summary_writer, train_metrics, train_time, cur_step)
1357
+
1358
+ epochs.write(
1359
+ f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']}, Gradient Norm: {train_metric['grad_norm']})"
1360
+ )
1361
+
1362
+ if cur_step >= total_train_steps:
1363
+ break
1364
+
1365
+ if cur_step % training_args.eval_steps == 0:
1366
+ run_evaluation(cur_step, final_step=False)
1367
+
1368
+ if cur_step % training_args.save_steps == 0:
1369
+ save_checkpoint(cur_step)
1370
+
1371
+ if training_args.do_train:
1372
+ save_checkpoint(cur_step)
1373
+
1374
+ if training_args.do_eval:
1375
+ run_evaluation(cur_step, final_step=True)
1376
+
1377
+ # TODO: collapse 'do_predict' into the run_evaluation function
1378
+ if training_args.do_predict:
1379
+ # ======================== Prediction ==============================
1380
+ for split in test_split:
1381
+ pred_metrics = []
1382
+ pred_generations = []
1383
+ pred_ids = []
1384
+ pred_labels = []
1385
+
1386
+ # Generate the prediction set by sequentially sampling indices from the test dataset and grouping by length
1387
+ pred_samples_idx = get_grouped_indices(vectorized_datasets[split], eval_batch_size)
1388
+ pred_batch_idx = generate_batch_splits(pred_samples_idx, eval_batch_size)
1389
+
1390
+ for i, batch_idx in enumerate(tqdm(pred_batch_idx, desc=f"Predicting {split}...", position=2)):
1391
+ samples = [vectorized_datasets[split][int(idx)] for idx in batch_idx]
1392
+ batch = data_collator(samples)
1393
+ pred_ids.extend(batch.pop("input_ids"))
1394
+ batch = shard(batch.data)
1395
+ labels = batch["labels"]
1396
+
1397
+ metrics = p_eval_step(state.params, batch)
1398
+ pred_metrics.append(metrics)
1399
+
1400
+ # generation
1401
+ if training_args.predict_with_generate:
1402
+ generated_ids = p_final_generate_step(state.params, batch)
1403
+ pred_generations.extend(
1404
+ jax.device_get(
1405
+ generated_ids.reshape(-1, final_gen_kwargs["num_beams"], final_gen_kwargs["max_length"])
1406
+ )
1407
+ )
1408
+ pred_labels.extend(jax.device_get(labels.reshape(-1, labels.shape[-1])))
1409
+
1410
+ # normalize eval metrics
1411
+ pred_metrics = get_metrics(pred_metrics)
1412
+ pred_metrics = jax.tree_map(jnp.mean, pred_metrics)
1413
+ pred_metrics = to_fp32(pred_metrics)
1414
+
1415
+ # compute WER metric and get predicted string (for debugging)
1416
+ wer_desc = ""
1417
+ pred_str = []
1418
+ label_str = []
1419
+ if training_args.predict_with_generate:
1420
+ wer_metric, pred_str, label_str = compute_metrics(pred_generations, pred_labels)
1421
+ pred_metrics.update(wer_metric)
1422
+ wer_desc = " ".join([f"{split} {key}: {value} |" for key, value in wer_metric.items()])
1423
+
1424
+ # Print metrics and update progress bar
1425
+ desc = f"Epoch... ({epoch + 1}/{num_epochs} | {split} Loss: {pred_metrics['loss']} | {wer_desc})"
1426
+ epochs.write(desc)
1427
+ epochs.desc = desc
1428
+
1429
+ # Save metrics
1430
+ write_wandb_log(pred_metrics, cur_step, prefix=split)
1431
+ write_wandb_pred(
1432
+ pred_str,
1433
+ label_str,
1434
+ pred_ids,
1435
+ cur_step,
1436
+ prefix=split,
1437
+ top_ids=vectorized_datasets[split]["input_id"] if data_args.log_first_ids else None,
1438
+ final_step=True,
1439
+ )
1440
+
1441
+
1442
+ if __name__ == "__main__":
1443
+ main()
special_tokens_map.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"bos_token": "<s>", "eos_token": "</s>", "unk_token": "<unk>", "sep_token": "</s>", "pad_token": "<pad>", "cls_token": "<s>", "mask_token": {"content": "<mask>", "single_word": false, "lstrip": true, "rstrip": false, "normalized": false}}
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"errors": "replace", "bos_token": "<s>", "eos_token": "</s>", "sep_token": "</s>", "cls_token": "<s>", "unk_token": "<unk>", "pad_token": "<pad>", "mask_token": "<mask>", "add_prefix_space": false, "trim_offsets": true, "model_max_length": 1024, "special_tokens_map_file": null, "name_or_path": "sanchit-gandhi/flax-wav2vec2-2-bart-large-scan", "tokenizer_class": "BartTokenizer"}
vocab.json ADDED
The diff for this file is too large to render. See raw diff