sanchit-gandhi HF staff commited on
Commit
f74be82
1 Parent(s): 140399a

2hx8pk65: saving weights and logs of step 10k

Browse files
Files changed (46) hide show
  1. .gitattributes +1 -0
  2. config.json +291 -0
  3. flax_model.msgpack +3 -0
  4. merges.txt +0 -0
  5. models/__init__.py +6 -0
  6. models/__pycache__/__init__.cpython-38.pyc +0 -0
  7. models/__pycache__/configuration_bart.cpython-38.pyc +0 -0
  8. models/__pycache__/configuration_speech_encoder_decoder.cpython-38.pyc +0 -0
  9. models/__pycache__/configuration_wav2vec2.cpython-38.pyc +0 -0
  10. models/__pycache__/modeling_flax_bart.cpython-38.pyc +0 -0
  11. models/__pycache__/modeling_flax_speech_encoder_decoder.cpython-38.pyc +0 -0
  12. models/__pycache__/modeling_flax_wav2vec2.cpython-38.pyc +0 -0
  13. models/configuration_bart.py +183 -0
  14. models/configuration_speech_encoder_decoder.py +121 -0
  15. models/configuration_wav2vec2.py +344 -0
  16. models/modeling_flax_bart.py +816 -0
  17. models/modeling_flax_speech_encoder_decoder.py +1245 -0
  18. models/modeling_flax_wav2vec2.py +975 -0
  19. nohup.out +0 -0
  20. preprocessor_config.json +9 -0
  21. run_flax_speech_recognition_seq2seq.py +1572 -0
  22. run_librispeech.sh +39 -0
  23. special_tokens_map.json +15 -0
  24. tokenizer.json +0 -0
  25. tokenizer_config.json +16 -0
  26. vocab.json +0 -0
  27. wandb/debug-internal.log +1 -0
  28. wandb/debug.log +1 -0
  29. wandb/latest-run +1 -0
  30. wandb/run-20220828_084407-nbdgecc9/files/config.yaml +36 -0
  31. wandb/run-20220828_084407-nbdgecc9/files/output.log +110 -0
  32. wandb/run-20220828_084407-nbdgecc9/files/requirements.txt +167 -0
  33. wandb/run-20220828_084407-nbdgecc9/files/wandb-metadata.json +59 -0
  34. wandb/run-20220828_084407-nbdgecc9/files/wandb-summary.json +1 -0
  35. wandb/run-20220828_084407-nbdgecc9/logs/debug-internal.log +144 -0
  36. wandb/run-20220828_084407-nbdgecc9/logs/debug.log +131 -0
  37. wandb/run-20220828_084407-nbdgecc9/run-nbdgecc9.wandb +3 -0
  38. wandb/run-20220828_085247-2hx8pk65/files/config.yaml +28 -0
  39. wandb/run-20220828_085247-2hx8pk65/files/media/table/eval/step_10k_10000_8b44e8a00a036a18ffdf.table.json +1 -0
  40. wandb/run-20220828_085247-2hx8pk65/files/output.log +0 -0
  41. wandb/run-20220828_085247-2hx8pk65/files/requirements.txt +167 -0
  42. wandb/run-20220828_085247-2hx8pk65/files/wandb-metadata.json +59 -0
  43. wandb/run-20220828_085247-2hx8pk65/files/wandb-summary.json +1 -0
  44. wandb/run-20220828_085247-2hx8pk65/logs/debug-internal.log +0 -0
  45. wandb/run-20220828_085247-2hx8pk65/logs/debug.log +25 -0
  46. wandb/run-20220828_085247-2hx8pk65/run-2hx8pk65.wandb +3 -0
.gitattributes CHANGED
@@ -29,3 +29,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
29
  *.zip filter=lfs diff=lfs merge=lfs -text
30
  *.zst filter=lfs diff=lfs merge=lfs -text
31
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
29
  *.zip filter=lfs diff=lfs merge=lfs -text
30
  *.zst filter=lfs diff=lfs merge=lfs -text
31
  *tfevents* filter=lfs diff=lfs merge=lfs -text
32
+ *.wandb filter=lfs diff=lfs merge=lfs -text
config.json ADDED
@@ -0,0 +1,291 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "sanchit-gandhi/flax-wav2vec2-2-bart-large-scan",
3
+ "architectures": [
4
+ "SpeechEncoderDecoderModel"
5
+ ],
6
+ "decoder": {
7
+ "_name_or_path": "",
8
+ "activation_dropout": 0.2,
9
+ "activation_function": "gelu",
10
+ "add_bias_logits": false,
11
+ "add_cross_attention": true,
12
+ "add_final_layer_norm": false,
13
+ "architectures": [
14
+ "BartModel"
15
+ ],
16
+ "attention_dropout": 0.1,
17
+ "bad_words_ids": null,
18
+ "bos_token_id": 0,
19
+ "chunk_size_feed_forward": 0,
20
+ "classif_dropout": 0.1,
21
+ "classifier_dropout": 0.0,
22
+ "cross_attention_hidden_size": null,
23
+ "d_model": 1024,
24
+ "decoder_attention_heads": 16,
25
+ "decoder_ffn_dim": 4096,
26
+ "decoder_layerdrop": 0.0,
27
+ "decoder_layers": 12,
28
+ "decoder_start_token_id": 2,
29
+ "diversity_penalty": 0.0,
30
+ "do_sample": false,
31
+ "dropout": 0.2,
32
+ "early_stopping": true,
33
+ "encoder_attention_heads": 16,
34
+ "encoder_ffn_dim": 4096,
35
+ "encoder_layerdrop": 0.0,
36
+ "encoder_layers": 12,
37
+ "encoder_no_repeat_ngram_size": 0,
38
+ "eos_token_id": 2,
39
+ "exponential_decay_length_penalty": null,
40
+ "finetuning_task": null,
41
+ "forced_bos_token_id": 0,
42
+ "forced_eos_token_id": 2,
43
+ "fuse_matmuls": false,
44
+ "gradient_checkpointing": true,
45
+ "id2label": {
46
+ "0": "LABEL_0",
47
+ "1": "LABEL_1",
48
+ "2": "LABEL_2"
49
+ },
50
+ "init_std": 0.02,
51
+ "is_decoder": true,
52
+ "is_encoder_decoder": false,
53
+ "label2id": {
54
+ "LABEL_0": 0,
55
+ "LABEL_1": 1,
56
+ "LABEL_2": 2
57
+ },
58
+ "length_penalty": 1.0,
59
+ "max_length": 20,
60
+ "max_position_embeddings": 1024,
61
+ "min_length": 0,
62
+ "model_type": "bart",
63
+ "no_repeat_ngram_size": 3,
64
+ "normalize_before": false,
65
+ "num_beam_groups": 1,
66
+ "num_beams": 4,
67
+ "num_hidden_layers": 12,
68
+ "num_return_sequences": 1,
69
+ "output_attentions": false,
70
+ "output_hidden_states": false,
71
+ "output_scores": false,
72
+ "pad_token_id": 1,
73
+ "prefix": null,
74
+ "problem_type": null,
75
+ "pruned_heads": {},
76
+ "remove_invalid_values": false,
77
+ "repetition_penalty": 1.0,
78
+ "return_dict": true,
79
+ "return_dict_in_generate": false,
80
+ "scale_embedding": false,
81
+ "sep_token_id": null,
82
+ "task_specific_params": {
83
+ "summarization": {
84
+ "length_penalty": 1.0,
85
+ "max_length": 128,
86
+ "min_length": 12,
87
+ "num_beams": 4
88
+ },
89
+ "summarization_cnn": {
90
+ "length_penalty": 2.0,
91
+ "max_length": 142,
92
+ "min_length": 56,
93
+ "num_beams": 4
94
+ },
95
+ "summarization_xsum": {
96
+ "length_penalty": 1.0,
97
+ "max_length": 62,
98
+ "min_length": 11,
99
+ "num_beams": 6
100
+ }
101
+ },
102
+ "temperature": 1.0,
103
+ "tf_legacy_loss": false,
104
+ "tie_encoder_decoder": false,
105
+ "tie_word_embeddings": true,
106
+ "tokenizer_class": null,
107
+ "top_k": 50,
108
+ "top_p": 1.0,
109
+ "torch_dtype": "float32",
110
+ "torchscript": false,
111
+ "transformers_version": "4.21.0.dev0",
112
+ "typical_p": 1.0,
113
+ "use_bfloat16": false,
114
+ "use_cache": true,
115
+ "use_scan": true,
116
+ "vocab_size": 50265
117
+ },
118
+ "decoder_start_token_id": 0,
119
+ "encoder": {
120
+ "_name_or_path": "",
121
+ "activation_dropout": 0.2,
122
+ "adapter_kernel_size": 3,
123
+ "adapter_stride": 2,
124
+ "add_adapter": true,
125
+ "add_cross_attention": false,
126
+ "apply_spec_augment": true,
127
+ "architectures": [
128
+ "Wav2Vec2ForPreTraining"
129
+ ],
130
+ "attention_dropout": 0.1,
131
+ "bad_words_ids": null,
132
+ "bos_token_id": 1,
133
+ "chunk_size_feed_forward": 0,
134
+ "classifier_proj_size": 256,
135
+ "codevector_dim": 768,
136
+ "contrastive_logits_temperature": 0.1,
137
+ "conv_bias": true,
138
+ "conv_dim": [
139
+ 512,
140
+ 512,
141
+ 512,
142
+ 512,
143
+ 512,
144
+ 512,
145
+ 512
146
+ ],
147
+ "conv_kernel": [
148
+ 10,
149
+ 3,
150
+ 3,
151
+ 3,
152
+ 3,
153
+ 2,
154
+ 2
155
+ ],
156
+ "conv_stride": [
157
+ 5,
158
+ 2,
159
+ 2,
160
+ 2,
161
+ 2,
162
+ 2,
163
+ 2
164
+ ],
165
+ "cross_attention_hidden_size": null,
166
+ "ctc_loss_reduction": "sum",
167
+ "ctc_zero_infinity": false,
168
+ "decoder_start_token_id": null,
169
+ "diversity_loss_weight": 0.1,
170
+ "diversity_penalty": 0.0,
171
+ "do_sample": false,
172
+ "do_stable_layer_norm": true,
173
+ "early_stopping": false,
174
+ "encoder_no_repeat_ngram_size": 0,
175
+ "eos_token_id": 2,
176
+ "exponential_decay_length_penalty": null,
177
+ "feat_extract_activation": "gelu",
178
+ "feat_extract_dropout": 0.0,
179
+ "feat_extract_norm": "layer",
180
+ "feat_proj_dropout": 0.2,
181
+ "feat_quantizer_dropout": 0.0,
182
+ "final_dropout": 0.0,
183
+ "finetuning_task": null,
184
+ "forced_bos_token_id": null,
185
+ "forced_eos_token_id": null,
186
+ "fuse_matmuls": false,
187
+ "gradient_checkpointing": true,
188
+ "hidden_act": "gelu",
189
+ "hidden_dropout": 0.2,
190
+ "hidden_dropout_prob": 0.1,
191
+ "hidden_size": 1024,
192
+ "id2label": {
193
+ "0": "LABEL_0",
194
+ "1": "LABEL_1"
195
+ },
196
+ "initializer_range": 0.02,
197
+ "intermediate_size": 4096,
198
+ "is_decoder": false,
199
+ "is_encoder_decoder": false,
200
+ "label2id": {
201
+ "LABEL_0": 0,
202
+ "LABEL_1": 1
203
+ },
204
+ "layer_norm_eps": 1e-05,
205
+ "layerdrop": 0.0,
206
+ "length_penalty": 1.0,
207
+ "mask_feature_length": 10,
208
+ "mask_feature_min_masks": 0,
209
+ "mask_feature_prob": 0.0,
210
+ "mask_time_length": 10,
211
+ "mask_time_min_masks": 2,
212
+ "mask_time_prob": 0.1,
213
+ "max_length": 20,
214
+ "min_length": 0,
215
+ "model_type": "wav2vec2",
216
+ "no_repeat_ngram_size": 0,
217
+ "num_adapter_layers": 3,
218
+ "num_attention_heads": 16,
219
+ "num_beam_groups": 1,
220
+ "num_beams": 1,
221
+ "num_codevector_groups": 2,
222
+ "num_codevectors_per_group": 320,
223
+ "num_conv_pos_embedding_groups": 16,
224
+ "num_conv_pos_embeddings": 128,
225
+ "num_feat_extract_layers": 7,
226
+ "num_hidden_layers": 24,
227
+ "num_negatives": 100,
228
+ "num_return_sequences": 1,
229
+ "output_attentions": false,
230
+ "output_hidden_size": 1024,
231
+ "output_hidden_states": false,
232
+ "output_scores": false,
233
+ "pad_token_id": 0,
234
+ "prefix": null,
235
+ "problem_type": null,
236
+ "proj_codevector_dim": 768,
237
+ "pruned_heads": {},
238
+ "remove_invalid_values": false,
239
+ "repetition_penalty": 1.0,
240
+ "return_dict": true,
241
+ "return_dict_in_generate": false,
242
+ "sep_token_id": null,
243
+ "task_specific_params": null,
244
+ "tdnn_dilation": [
245
+ 1,
246
+ 2,
247
+ 3,
248
+ 1,
249
+ 1
250
+ ],
251
+ "tdnn_dim": [
252
+ 512,
253
+ 512,
254
+ 512,
255
+ 512,
256
+ 1500
257
+ ],
258
+ "tdnn_kernel": [
259
+ 5,
260
+ 3,
261
+ 3,
262
+ 1,
263
+ 1
264
+ ],
265
+ "temperature": 1.0,
266
+ "tf_legacy_loss": false,
267
+ "tie_encoder_decoder": false,
268
+ "tie_word_embeddings": true,
269
+ "tokenizer_class": null,
270
+ "top_k": 50,
271
+ "top_p": 1.0,
272
+ "torch_dtype": null,
273
+ "torchscript": false,
274
+ "transformers_version": "4.21.0.dev0",
275
+ "typical_p": 1.0,
276
+ "use_bfloat16": false,
277
+ "use_scan": true,
278
+ "use_weighted_layer_sum": false,
279
+ "vocab_size": 32,
280
+ "xvector_output_dim": 512
281
+ },
282
+ "eos_token_id": 2,
283
+ "is_encoder_decoder": true,
284
+ "max_length": 40,
285
+ "model_type": "speech-encoder-decoder",
286
+ "pad_token_id": 1,
287
+ "processor_class": "Wav2Vec2Processor",
288
+ "tie_word_embeddings": false,
289
+ "transformers_version": null,
290
+ "use_cache": false
291
+ }
flax_model.msgpack ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d4bbb8026d3a4c9acb651189cbf65ab582eb2284bbcae68d0c6512395b962329
3
+ size 2353616717
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
models/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ from models.configuration_bart import BartConfig
2
+ from models.configuration_wav2vec2 import Wav2Vec2Config
3
+ from models.configuration_speech_encoder_decoder import SpeechEncoderDecoderConfig
4
+ from models.modeling_flax_wav2vec2 import FlaxWav2Vec2Model, FlaxWav2Vec2Module, FlaxWav2Vec2ForCTC, FlaxWav2Vec2ForCTCModule
5
+ from models.modeling_flax_bart import FlaxBartForCausalLM, FlaxBartForCausalLMModule
6
+ from models.modeling_flax_speech_encoder_decoder import FlaxSpeechEncoderDecoderModel
models/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (762 Bytes). View file
 
models/__pycache__/configuration_bart.cpython-38.pyc ADDED
Binary file (7.06 kB). View file
 
models/__pycache__/configuration_speech_encoder_decoder.cpython-38.pyc ADDED
Binary file (4.64 kB). View file
 
models/__pycache__/configuration_wav2vec2.cpython-38.pyc ADDED
Binary file (16.8 kB). View file
 
models/__pycache__/modeling_flax_bart.cpython-38.pyc ADDED
Binary file (21.1 kB). View file
 
models/__pycache__/modeling_flax_speech_encoder_decoder.cpython-38.pyc ADDED
Binary file (39.4 kB). View file
 
models/__pycache__/modeling_flax_wav2vec2.cpython-38.pyc ADDED
Binary file (30.7 kB). View file
 
models/configuration_bart.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ BART model configuration"""
16
+ import warnings
17
+
18
+ from transformers.configuration_utils import PretrainedConfig
19
+ from transformers.utils import logging
20
+
21
+
22
+ logger = logging.get_logger(__name__)
23
+
24
+ BART_PRETRAINED_CONFIG_ARCHIVE_MAP = {
25
+ "facebook/bart-large": "https://huggingface.co/facebook/bart-large/resolve/main/config.json",
26
+ # See all BART models at https://huggingface.co/models?filter=bart
27
+ }
28
+
29
+
30
+ class BartConfig(PretrainedConfig):
31
+ r"""
32
+ This is the configuration class to store the configuration of a [`BartModel`]. It is used to instantiate a BART
33
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
34
+ defaults will yield a similar configuration to that of the BART
35
+ [facebook/bart-large](https://huggingface.co/facebook/bart-large) architecture.
36
+
37
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
38
+ documentation from [`PretrainedConfig`] for more information.
39
+
40
+
41
+ Args:
42
+ vocab_size (`int`, *optional*, defaults to 50265):
43
+ Vocabulary size of the BART model. Defines the number of different tokens that can be represented by the
44
+ `inputs_ids` passed when calling [`BartModel`] or [`TFBartModel`].
45
+ d_model (`int`, *optional*, defaults to 1024):
46
+ Dimensionality of the layers and the pooler layer.
47
+ encoder_layers (`int`, *optional*, defaults to 12):
48
+ Number of encoder layers.
49
+ decoder_layers (`int`, *optional*, defaults to 12):
50
+ Number of decoder layers.
51
+ encoder_attention_heads (`int`, *optional*, defaults to 16):
52
+ Number of attention heads for each attention layer in the Transformer encoder.
53
+ decoder_attention_heads (`int`, *optional*, defaults to 16):
54
+ Number of attention heads for each attention layer in the Transformer decoder.
55
+ decoder_ffn_dim (`int`, *optional*, defaults to 4096):
56
+ Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
57
+ encoder_ffn_dim (`int`, *optional*, defaults to 4096):
58
+ Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
59
+ activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
60
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
61
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
62
+ dropout (`float`, *optional*, defaults to 0.1):
63
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
64
+ attention_dropout (`float`, *optional*, defaults to 0.0):
65
+ The dropout ratio for the attention probabilities.
66
+ activation_dropout (`float`, *optional*, defaults to 0.0):
67
+ The dropout ratio for activations inside the fully connected layer.
68
+ classifier_dropout (`float`, *optional*, defaults to 0.0):
69
+ The dropout ratio for classifier.
70
+ max_position_embeddings (`int`, *optional*, defaults to 1024):
71
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
72
+ just in case (e.g., 512 or 1024 or 2048).
73
+ init_std (`float`, *optional*, defaults to 0.02):
74
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
75
+ encoder_layerdrop: (`float`, *optional*, defaults to 0.0):
76
+ The LayerDrop probability for the encoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
77
+ for more details.
78
+ decoder_layerdrop: (`float`, *optional*, defaults to 0.0):
79
+ The LayerDrop probability for the decoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
80
+ for more details.
81
+ scale_embedding (`bool`, *optional*, defaults to `False`):
82
+ Scale embeddings by diving by sqrt(d_model).
83
+ use_cache (`bool`, *optional*, defaults to `True`):
84
+ Whether or not the model should return the last key/values attentions (not used by all models).
85
+ num_labels: (`int`, *optional*, defaults to 3):
86
+ The number of labels to use in [`BartForSequenceClassification`].
87
+ forced_eos_token_id (`int`, *optional*, defaults to 2):
88
+ The id of the token to force as the last generated token when `max_length` is reached. Usually set to
89
+ `eos_token_id`.
90
+ use_scan (`bool`, *optional*, defaults to `False`):
91
+ Whether or not to use nn.scan in the Flax Bart attention layers.
92
+
93
+ Example:
94
+
95
+ ```python
96
+ >>> from transformers import BartModel, BartConfig
97
+
98
+ >>> # Initializing a BART facebook/bart-large style configuration
99
+ >>> configuration = BartConfig()
100
+
101
+ >>> # Initializing a model from the facebook/bart-large style configuration
102
+ >>> model = BartModel(configuration)
103
+
104
+ >>> # Accessing the model configuration
105
+ >>> configuration = model.config
106
+ ```"""
107
+ model_type = "bart"
108
+ keys_to_ignore_at_inference = ["past_key_values"]
109
+ attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}
110
+
111
+ def __init__(
112
+ self,
113
+ vocab_size=50265,
114
+ max_position_embeddings=1024,
115
+ encoder_layers=12,
116
+ encoder_ffn_dim=4096,
117
+ encoder_attention_heads=16,
118
+ decoder_layers=12,
119
+ decoder_ffn_dim=4096,
120
+ decoder_attention_heads=16,
121
+ encoder_layerdrop=0.0,
122
+ decoder_layerdrop=0.0,
123
+ activation_function="gelu",
124
+ d_model=1024,
125
+ dropout=0.1,
126
+ attention_dropout=0.0,
127
+ activation_dropout=0.0,
128
+ init_std=0.02,
129
+ classifier_dropout=0.0,
130
+ scale_embedding=False,
131
+ use_cache=True,
132
+ use_scan=False,
133
+ fuse_matmuls=False,
134
+ num_labels=3,
135
+ pad_token_id=1,
136
+ bos_token_id=0,
137
+ eos_token_id=2,
138
+ is_encoder_decoder=True,
139
+ decoder_start_token_id=2,
140
+ forced_eos_token_id=2,
141
+ **kwargs
142
+ ):
143
+ self.vocab_size = vocab_size
144
+ self.max_position_embeddings = max_position_embeddings
145
+ self.d_model = d_model
146
+ self.encoder_ffn_dim = encoder_ffn_dim
147
+ self.encoder_layers = encoder_layers
148
+ self.encoder_attention_heads = encoder_attention_heads
149
+ self.decoder_ffn_dim = decoder_ffn_dim
150
+ self.decoder_layers = decoder_layers
151
+ self.decoder_attention_heads = decoder_attention_heads
152
+ self.dropout = dropout
153
+ self.attention_dropout = attention_dropout
154
+ self.activation_dropout = activation_dropout
155
+ self.activation_function = activation_function
156
+ self.init_std = init_std
157
+ self.encoder_layerdrop = encoder_layerdrop
158
+ self.decoder_layerdrop = decoder_layerdrop
159
+ self.classifier_dropout = classifier_dropout
160
+ self.use_cache = use_cache
161
+ self.use_scan = use_scan
162
+ self.fuse_matmuls = fuse_matmuls
163
+ self.num_hidden_layers = encoder_layers
164
+ self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
165
+
166
+ super().__init__(
167
+ num_labels=num_labels,
168
+ pad_token_id=pad_token_id,
169
+ bos_token_id=bos_token_id,
170
+ eos_token_id=eos_token_id,
171
+ is_encoder_decoder=is_encoder_decoder,
172
+ decoder_start_token_id=decoder_start_token_id,
173
+ forced_eos_token_id=forced_eos_token_id,
174
+ **kwargs,
175
+ )
176
+
177
+ # ensure backward compatibility for BART CNN models
178
+ if self.forced_bos_token_id is None and kwargs.get("force_bos_token_to_be_generated", False):
179
+ self.forced_bos_token_id = self.bos_token_id
180
+ warnings.warn(
181
+ f"Please make sure the config includes `forced_bos_token_id={self.bos_token_id}` in future versions. "
182
+ "The config can simply be saved and uploaded again to be fixed."
183
+ )
models/configuration_speech_encoder_decoder.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2021 The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import copy
18
+
19
+ from transformers.configuration_utils import PretrainedConfig
20
+ from transformers.utils import logging
21
+ from models.configuration_wav2vec2 import Wav2Vec2Config
22
+ from models.configuration_bart import BartConfig
23
+ from transformers import AutoConfig
24
+
25
+
26
+ logger = logging.get_logger(__name__)
27
+
28
+
29
+ class SpeechEncoderDecoderConfig(PretrainedConfig):
30
+ r"""
31
+ [`SpeechEncoderDecoderConfig`] is the configuration class to store the configuration of a
32
+ [`SpeechEncoderDecoderModel`]. It is used to instantiate an Encoder Decoder model according to the specified
33
+ arguments, defining the encoder and decoder configs.
34
+
35
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
36
+ documentation from [`PretrainedConfig`] for more information.
37
+
38
+ Args:
39
+ kwargs (*optional*):
40
+ Dictionary of keyword arguments. Notably:
41
+
42
+ - **encoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that defines
43
+ the encoder config.
44
+ - **decoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that defines
45
+ the decoder config.
46
+
47
+ Examples:
48
+
49
+ ```python
50
+ >>> from transformers import BertConfig, Wav2Vec2Config, SpeechEncoderDecoderConfig, SpeechEncoderDecoderModel
51
+
52
+ >>> # Initializing a Wav2Vec2 & BERT style configuration
53
+ >>> config_encoder = Wav2Vec2Config()
54
+ >>> config_decoder = BertConfig()
55
+
56
+ >>> config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(config_encoder, config_decoder)
57
+
58
+ >>> # Initializing a Wav2Vec2Bert model from a Wav2Vec2 & bert-base-uncased style configurations
59
+ >>> model = SpeechEncoderDecoderModel(config=config)
60
+
61
+ >>> # Accessing the model configuration
62
+ >>> config_encoder = model.config.encoder
63
+ >>> config_decoder = model.config.decoder
64
+ >>> # set decoder config to causal lm
65
+ >>> config_decoder.is_decoder = True
66
+ >>> config_decoder.add_cross_attention = True
67
+
68
+ >>> # Saving the model, including its configuration
69
+ >>> model.save_pretrained("my-model")
70
+
71
+ >>> # loading model and config from pretrained folder
72
+ >>> encoder_decoder_config = SpeechEncoderDecoderConfig.from_pretrained("my-model")
73
+ >>> model = SpeechEncoderDecoderModel.from_pretrained("my-model", config=encoder_decoder_config)
74
+ ```"""
75
+ model_type = "speech-encoder-decoder"
76
+ is_composition = True
77
+
78
+ def __init__(self, **kwargs):
79
+ super().__init__(**kwargs)
80
+ if "encoder" not in kwargs or "decoder" not in kwargs:
81
+ raise ValueError(
82
+ f"A configuraton of type {self.model_type} cannot be instantiated because not both `encoder` and `decoder` sub-configurations are passed, but only {kwargs}"
83
+ )
84
+
85
+ encoder_config = kwargs.pop("encoder")
86
+ decoder_config = kwargs.pop("decoder")
87
+
88
+ # TODO: Load configs from AutoConfig (as done in Transformers 🤗)
89
+ self.encoder = Wav2Vec2Config(**encoder_config)
90
+ self.decoder = BartConfig(**decoder_config)
91
+ self.is_encoder_decoder = True
92
+
93
+ @classmethod
94
+ def from_encoder_decoder_configs(
95
+ cls, encoder_config: PretrainedConfig, decoder_config: PretrainedConfig, **kwargs
96
+ ) -> PretrainedConfig:
97
+ r"""
98
+ Instantiate a [`SpeechEncoderDecoderConfig`] (or a derived class) from a pre-trained encoder model
99
+ configuration and decoder model configuration.
100
+
101
+ Returns:
102
+ [`SpeechEncoderDecoderConfig`]: An instance of a configuration object
103
+ """
104
+ logger.info("Setting `config.is_decoder=True` and `config.add_cross_attention=True` for decoder_config")
105
+ decoder_config.is_decoder = True
106
+ decoder_config.add_cross_attention = True
107
+
108
+ return cls(encoder=encoder_config.to_dict(), decoder=decoder_config.to_dict(), **kwargs)
109
+
110
+ def to_dict(self):
111
+ """
112
+ Serializes this instance to a Python dictionary. Override the default *to_dict()* from *PretrainedConfig*.
113
+
114
+ Returns:
115
+ `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
116
+ """
117
+ output = copy.deepcopy(self.__dict__)
118
+ output["encoder"] = self.encoder.to_dict()
119
+ output["decoder"] = self.decoder.to_dict()
120
+ output["model_type"] = self.__class__.model_type
121
+ return output
models/configuration_wav2vec2.py ADDED
@@ -0,0 +1,344 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Wav2Vec2 model configuration"""
16
+
17
+ import functools
18
+ import operator
19
+
20
+ from transformers.configuration_utils import PretrainedConfig
21
+ from transformers.utils import logging
22
+
23
+
24
+ logger = logging.get_logger(__name__)
25
+
26
+ WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP = {
27
+ "facebook/wav2vec2-base-960h": "https://huggingface.co/facebook/wav2vec2-base-960h/resolve/main/config.json",
28
+ # See all Wav2Vec2 models at https://huggingface.co/models?filter=wav2vec2
29
+ }
30
+
31
+
32
+ class Wav2Vec2Config(PretrainedConfig):
33
+ r"""
34
+ This is the configuration class to store the configuration of a [`Wav2Vec2Model`]. It is used to instantiate an
35
+ Wav2Vec2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
36
+ with the defaults will yield a similar configuration to that of the Wav2Vec2
37
+ [facebook/wav2vec2-base-960h](https://huggingface.co/facebook/wav2vec2-base-960h) architecture.
38
+
39
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
40
+ documentation from [`PretrainedConfig`] for more information.
41
+
42
+
43
+ Args:
44
+ vocab_size (`int`, *optional*, defaults to 32):
45
+ Vocabulary size of the Wav2Vec2 model. Defines the number of different tokens that can be represented by
46
+ the `inputs_ids` passed when calling [`Wav2Vec2Model`] or [`TFWav2Vec2Model`]. Vocabulary size of the
47
+ model. Defines the different tokens that can be represented by the *inputs_ids* passed to the forward
48
+ method of [`Wav2Vec2Model`].
49
+ hidden_size (`int`, *optional*, defaults to 768):
50
+ Dimensionality of the encoder layers and the pooler layer.
51
+ num_hidden_layers (`int`, *optional*, defaults to 12):
52
+ Number of hidden layers in the Transformer encoder.
53
+ num_attention_heads (`int`, *optional*, defaults to 12):
54
+ Number of attention heads for each attention layer in the Transformer encoder.
55
+ intermediate_size (`int`, *optional*, defaults to 3072):
56
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
57
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
58
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
59
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
60
+ hidden_dropout (`float`, *optional*, defaults to 0.1):
61
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
62
+ attention_dropout (`float`, *optional*, defaults to 0.1):
63
+ The dropout ratio for the attention probabilities.
64
+ final_dropout (`float`, *optional*, defaults to 0.1):
65
+ The dropout probability for the final projection layer of [`Wav2Vec2ForCTC`].
66
+ initializer_range (`float`, *optional*, defaults to 0.02):
67
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
68
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
69
+ The epsilon used by the layer normalization layers.
70
+ feat_extract_norm (`str`, *optional*, defaults to `"group"`):
71
+ The norm to be applied to 1D convolutional layers in feature encoder. One of `"group"` for group
72
+ normalization of only the first 1D convolutional layer or `"layer"` for layer normalization of all 1D
73
+ convolutional layers.
74
+ feat_proj_dropout (`float`, *optional*, defaults to 0.0):
75
+ The dropout probability for output of the feature encoder.
76
+ feat_extract_activation (`str, `optional`, defaults to `"gelu"`):
77
+ The non-linear activation function (function or string) in the 1D convolutional layers of the feature
78
+ extractor. If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"` are supported.
79
+ feat_quantizer_dropout (`float`, *optional*, defaults to 0.0):
80
+ The dropout probabilitiy for quantized feature encoder states.
81
+ conv_dim (`Tuple[int]`, *optional*, defaults to `(512, 512, 512, 512, 512, 512, 512)`):
82
+ A tuple of integers defining the number of input and output channels of each 1D convolutional layer in the
83
+ feature encoder. The length of *conv_dim* defines the number of 1D convolutional layers.
84
+ conv_stride (`Tuple[int]`, *optional*, defaults to `(5, 2, 2, 2, 2, 2, 2)`):
85
+ A tuple of integers defining the stride of each 1D convolutional layer in the feature encoder. The length
86
+ of *conv_stride* defines the number of convolutional layers and has to match the length of *conv_dim*.
87
+ conv_kernel (`Tuple[int]`, *optional*, defaults to `(10, 3, 3, 3, 3, 3, 3)`):
88
+ A tuple of integers defining the kernel size of each 1D convolutional layer in the feature encoder. The
89
+ length of *conv_kernel* defines the number of convolutional layers and has to match the length of
90
+ *conv_dim*.
91
+ conv_bias (`bool`, *optional*, defaults to `False`):
92
+ Whether the 1D convolutional layers have a bias.
93
+ num_conv_pos_embeddings (`int`, *optional*, defaults to 128):
94
+ Number of convolutional positional embeddings. Defines the kernel size of 1D convolutional positional
95
+ embeddings layer.
96
+ num_conv_pos_embedding_groups (`int`, *optional*, defaults to 16):
97
+ Number of groups of 1D convolutional positional embeddings layer.
98
+ do_stable_layer_norm (`bool`, *optional*, defaults to `False`):
99
+ Whether to apply *stable* layer norm architecture of the Transformer encoder. `do_stable_layer_norm is
100
+ True` corresponds to applying layer norm before the attention layer, whereas `do_stable_layer_norm is
101
+ False` corresponds to applying layer norm after the attention layer.
102
+ apply_spec_augment (`bool`, *optional*, defaults to `True`):
103
+ Whether to apply *SpecAugment* data augmentation to the outputs of the feature encoder. For reference see
104
+ [SpecAugment: A Simple Data Augmentation Method for Automatic Speech
105
+ Recognition](https://arxiv.org/abs/1904.08779).
106
+ mask_time_prob (`float`, *optional*, defaults to 0.05):
107
+ Percentage (between 0 and 1) of all feature vectors along the time axis which will be masked. The masking
108
+ procecure generates ''mask_time_prob*len(time_axis)/mask_time_length'' independent masks over the axis. If
109
+ reasoning from the propability of each feature vector to be chosen as the start of the vector span to be
110
+ masked, *mask_time_prob* should be `prob_vector_start*mask_time_length`. Note that overlap may decrease the
111
+ actual percentage of masked vectors. This is only relevant if `apply_spec_augment is True`.
112
+ mask_time_length (`int`, *optional*, defaults to 10):
113
+ Length of vector span along the time axis.
114
+ mask_time_min_masks (`int`, *optional*, defaults to 2),:
115
+ The minimum number of masks of length `mask_feature_length` generated along the time axis, each time step,
116
+ irrespectively of `mask_feature_prob`. Only relevant if ''mask_time_prob*len(time_axis)/mask_time_length <
117
+ mask_time_min_masks''
118
+ mask_feature_prob (`float`, *optional*, defaults to 0.0):
119
+ Percentage (between 0 and 1) of all feature vectors along the feature axis which will be masked. The
120
+ masking procecure generates ''mask_feature_prob*len(feature_axis)/mask_time_length'' independent masks over
121
+ the axis. If reasoning from the propability of each feature vector to be chosen as the start of the vector
122
+ span to be masked, *mask_feature_prob* should be `prob_vector_start*mask_feature_length`. Note that overlap
123
+ may decrease the actual percentage of masked vectors. This is only relevant if `apply_spec_augment is
124
+ True`.
125
+ mask_feature_length (`int`, *optional*, defaults to 10):
126
+ Length of vector span along the feature axis.
127
+ mask_feature_min_masks (`int`, *optional*, defaults to 0),:
128
+ The minimum number of masks of length `mask_feature_length` generated along the feature axis, each time
129
+ step, irrespectively of `mask_feature_prob`. Only relevant if
130
+ ''mask_feature_prob*len(feature_axis)/mask_feature_length < mask_feature_min_masks''
131
+ num_codevectors_per_group (`int`, *optional*, defaults to 320):
132
+ Number of entries in each quantization codebook (group).
133
+ num_codevector_groups (`int`, *optional*, defaults to 2):
134
+ Number of codevector groups for product codevector quantization.
135
+ contrastive_logits_temperature (`float`, *optional*, defaults to 0.1):
136
+ The temperature *kappa* in the contrastive loss.
137
+ feat_quantizer_dropout (`float`, *optional*, defaults to 0.0):
138
+ The dropout probabilitiy for the output of the feature encoder that's used by the quantizer.
139
+ num_negatives (`int`, *optional*, defaults to 100):
140
+ Number of negative samples for the contrastive loss.
141
+ codevector_dim (`int`, *optional*, defaults to 256):
142
+ Dimensionality of the quantized feature vectors.
143
+ proj_codevector_dim (`int`, *optional*, defaults to 256):
144
+ Dimensionality of the final projection of both the quantized and the transformer features.
145
+ diversity_loss_weight (`int`, *optional*, defaults to 0.1):
146
+ The weight of the codebook diversity loss component.
147
+ ctc_loss_reduction (`str`, *optional*, defaults to `"sum"`):
148
+ Specifies the reduction to apply to the output of `torch.nn.CTCLoss`. Only relevant when training an
149
+ instance of [`Wav2Vec2ForCTC`].
150
+ ctc_zero_infinity (`bool`, *optional*, defaults to `False`):
151
+ Whether to zero infinite losses and the associated gradients of `torch.nn.CTCLoss`. Infinite losses mainly
152
+ occur when the inputs are too short to be aligned to the targets. Only relevant when training an instance
153
+ of [`Wav2Vec2ForCTC`].
154
+ use_weighted_layer_sum (`bool`, *optional*, defaults to `False`):
155
+ Whether to use a weighted average of layer outputs with learned weights. Only relevant when using an
156
+ instance of [`Wav2Vec2ForSequenceClassification`].
157
+ classifier_proj_size (`int`, *optional*, defaults to 256):
158
+ Dimensionality of the projection before token mean-pooling for classification.
159
+ tdnn_dim (`Tuple[int]`, *optional*, defaults to `(512, 512, 512, 512, 1500)`):
160
+ A tuple of integers defining the number of output channels of each 1D convolutional layer in the *TDNN*
161
+ module of the *XVector* model. The length of *tdnn_dim* defines the number of *TDNN* layers.
162
+ tdnn_kernel (`Tuple[int]`, *optional*, defaults to `(5, 3, 3, 1, 1)`):
163
+ A tuple of integers defining the kernel size of each 1D convolutional layer in the *TDNN* module of the
164
+ *XVector* model. The length of *tdnn_kernel* has to match the length of *tdnn_dim*.
165
+ tdnn_dilation (`Tuple[int]`, *optional*, defaults to `(1, 2, 3, 1, 1)`):
166
+ A tuple of integers defining the dilation factor of each 1D convolutional layer in *TDNN* module of the
167
+ *XVector* model. The length of *tdnn_dilation* has to match the length of *tdnn_dim*.
168
+ xvector_output_dim (`int`, *optional*, defaults to 512):
169
+ Dimensionality of the *XVector* embedding vectors.
170
+ add_adapter (`bool`, *optional*, defaults to `False`):
171
+ Whether a convolutional network should be stacked on top of the Wav2Vec2 Encoder. Can be very useful for
172
+ warm-starting Wav2Vec2 for SpeechEncoderDecoder models.
173
+ adapter_kernel_size (`int`, *optional*, defaults to 3):
174
+ Kernel size of the convolutional layers in the adapter network. Only relevant if `add_adapter is True`.
175
+ adapter_stride (`int`, *optional*, defaults to 2):
176
+ Stride of the convolutional layers in the adapter network. Only relevant if `add_adapter is True`.
177
+ num_adapter_layers (`int`, *optional*, defaults to 3):
178
+ Number of convolutional layers that should be used in the adapter network. Only relevant if `add_adapter is
179
+ True`.
180
+ output_hidden_size (`int`, *optional*):
181
+ Dimensionality of the encoder output layer. If not defined, this defaults to *hidden-size*. Only relevant
182
+ if `add_adapter is True`.
183
+ use_scan (`bool`, *optional*, defaults to `False`):
184
+ Whether or not to use nn.scan in the Flax Wav2Vec2 transformer layers.
185
+
186
+ Example:
187
+
188
+ ```python
189
+ >>> from transformers import Wav2Vec2Model, Wav2Vec2Config
190
+
191
+ >>> # Initializing a Wav2Vec2 facebook/wav2vec2-base-960h style configuration
192
+ >>> configuration = Wav2Vec2Config()
193
+
194
+ >>> # Initializing a model from the facebook/wav2vec2-base-960h style configuration
195
+ >>> model = Wav2Vec2Model(configuration)
196
+
197
+ >>> # Accessing the model configuration
198
+ >>> configuration = model.config
199
+ ```"""
200
+ model_type = "wav2vec2"
201
+
202
+ def __init__(
203
+ self,
204
+ vocab_size=32,
205
+ hidden_size=768,
206
+ num_hidden_layers=12,
207
+ num_attention_heads=12,
208
+ intermediate_size=3072,
209
+ hidden_act="gelu",
210
+ hidden_dropout=0.1,
211
+ activation_dropout=0.1,
212
+ attention_dropout=0.1,
213
+ feat_proj_dropout=0.0,
214
+ feat_quantizer_dropout=0.0,
215
+ final_dropout=0.1,
216
+ layerdrop=0.1,
217
+ initializer_range=0.02,
218
+ layer_norm_eps=1e-5,
219
+ feat_extract_norm="group",
220
+ feat_extract_activation="gelu",
221
+ conv_dim=(512, 512, 512, 512, 512, 512, 512),
222
+ conv_stride=(5, 2, 2, 2, 2, 2, 2),
223
+ conv_kernel=(10, 3, 3, 3, 3, 2, 2),
224
+ conv_bias=False,
225
+ num_conv_pos_embeddings=128,
226
+ num_conv_pos_embedding_groups=16,
227
+ do_stable_layer_norm=False,
228
+ apply_spec_augment=True,
229
+ mask_time_prob=0.05,
230
+ mask_time_length=10,
231
+ mask_time_min_masks=2,
232
+ mask_feature_prob=0.0,
233
+ mask_feature_length=10,
234
+ mask_feature_min_masks=0,
235
+ num_codevectors_per_group=320,
236
+ num_codevector_groups=2,
237
+ contrastive_logits_temperature=0.1,
238
+ num_negatives=100,
239
+ codevector_dim=256,
240
+ proj_codevector_dim=256,
241
+ diversity_loss_weight=0.1,
242
+ ctc_loss_reduction="sum",
243
+ ctc_zero_infinity=False,
244
+ use_weighted_layer_sum=False,
245
+ classifier_proj_size=256,
246
+ tdnn_dim=(512, 512, 512, 512, 1500),
247
+ tdnn_kernel=(5, 3, 3, 1, 1),
248
+ tdnn_dilation=(1, 2, 3, 1, 1),
249
+ xvector_output_dim=512,
250
+ pad_token_id=0,
251
+ bos_token_id=1,
252
+ eos_token_id=2,
253
+ add_adapter=False,
254
+ adapter_kernel_size=3,
255
+ adapter_stride=2,
256
+ num_adapter_layers=3,
257
+ output_hidden_size=None,
258
+ use_scan=False,
259
+ fuse_matmuls=False,
260
+ **kwargs
261
+ ):
262
+ super().__init__(**kwargs, pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id)
263
+ self.hidden_size = hidden_size
264
+ self.feat_extract_norm = feat_extract_norm
265
+ self.feat_extract_activation = feat_extract_activation
266
+ self.conv_dim = list(conv_dim)
267
+ self.conv_stride = list(conv_stride)
268
+ self.conv_kernel = list(conv_kernel)
269
+ self.conv_bias = conv_bias
270
+ self.num_conv_pos_embeddings = num_conv_pos_embeddings
271
+ self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups
272
+ self.num_feat_extract_layers = len(self.conv_dim)
273
+ self.num_hidden_layers = num_hidden_layers
274
+ self.intermediate_size = intermediate_size
275
+ self.hidden_act = hidden_act
276
+ self.num_attention_heads = num_attention_heads
277
+ self.hidden_dropout = hidden_dropout
278
+ self.attention_dropout = attention_dropout
279
+ self.activation_dropout = activation_dropout
280
+ self.feat_proj_dropout = feat_proj_dropout
281
+ self.final_dropout = final_dropout
282
+ self.layerdrop = layerdrop
283
+ self.layer_norm_eps = layer_norm_eps
284
+ self.initializer_range = initializer_range
285
+ self.vocab_size = vocab_size
286
+ self.do_stable_layer_norm = do_stable_layer_norm
287
+ self.use_weighted_layer_sum = use_weighted_layer_sum
288
+ self.use_scan = use_scan
289
+ self.fuse_matmuls = fuse_matmuls
290
+
291
+ if (
292
+ (len(self.conv_stride) != self.num_feat_extract_layers)
293
+ or (len(self.conv_kernel) != self.num_feat_extract_layers)
294
+ or (len(self.conv_dim) != self.num_feat_extract_layers)
295
+ ):
296
+ raise ValueError(
297
+ "Configuration for convolutional layers is incorrect. "
298
+ "It is required that `len(config.conv_dim)` == `len(config.conv_stride)` == `len(config.conv_kernel)`, "
299
+ f"but is `len(config.conv_dim) = {len(self.conv_dim)}`, `len(config.conv_stride) "
300
+ f"= {len(self.conv_stride)}`, `len(config.conv_kernel) = {len(self.conv_kernel)}`."
301
+ )
302
+
303
+ # fine-tuning config parameters for SpecAugment: https://arxiv.org/abs/1904.08779
304
+ self.apply_spec_augment = apply_spec_augment
305
+ self.mask_time_prob = mask_time_prob
306
+ self.mask_time_length = mask_time_length
307
+ self.mask_time_min_masks = mask_time_min_masks
308
+ self.mask_feature_prob = mask_feature_prob
309
+ self.mask_feature_length = mask_feature_length
310
+ self.mask_feature_min_masks = mask_feature_min_masks
311
+
312
+ # parameters for pretraining with codevector quantized representations
313
+ self.num_codevectors_per_group = num_codevectors_per_group
314
+ self.num_codevector_groups = num_codevector_groups
315
+ self.contrastive_logits_temperature = contrastive_logits_temperature
316
+ self.feat_quantizer_dropout = feat_quantizer_dropout
317
+ self.num_negatives = num_negatives
318
+ self.codevector_dim = codevector_dim
319
+ self.proj_codevector_dim = proj_codevector_dim
320
+ self.diversity_loss_weight = diversity_loss_weight
321
+
322
+ # ctc loss
323
+ self.ctc_loss_reduction = ctc_loss_reduction
324
+ self.ctc_zero_infinity = ctc_zero_infinity
325
+
326
+ # adapter
327
+ self.add_adapter = add_adapter
328
+ self.adapter_kernel_size = adapter_kernel_size
329
+ self.adapter_stride = adapter_stride
330
+ self.num_adapter_layers = num_adapter_layers
331
+ self.output_hidden_size = output_hidden_size or hidden_size
332
+
333
+ # SequenceClassification-specific parameter. Feel free to ignore for other classes.
334
+ self.classifier_proj_size = classifier_proj_size
335
+
336
+ # XVector-specific parameters. Feel free to ignore for other classes.
337
+ self.tdnn_dim = list(tdnn_dim)
338
+ self.tdnn_kernel = list(tdnn_kernel)
339
+ self.tdnn_dilation = list(tdnn_dilation)
340
+ self.xvector_output_dim = xvector_output_dim
341
+
342
+ @property
343
+ def inputs_to_logits_ratio(self):
344
+ return functools.reduce(operator.mul, self.conv_stride, 1)
models/modeling_flax_bart.py ADDED
@@ -0,0 +1,816 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2021 The Fairseq Authors and The Google Flax Team Authors And The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Flax Bart model."""
16
+
17
+ import math
18
+ import random
19
+ from functools import partial
20
+ from typing import Optional, Tuple
21
+
22
+ import numpy as np
23
+
24
+ import flax.linen as nn
25
+ import jax
26
+ import jax.numpy as jnp
27
+ from flax.core.frozen_dict import FrozenDict, unfreeze
28
+ from flax.linen import combine_masks, make_causal_mask
29
+ from flax.linen import partitioning as nn_partitioning
30
+ from flax.linen.attention import dot_product_attention_weights
31
+ from jax import lax
32
+ from jax.random import PRNGKey
33
+
34
+ from transformers.modeling_flax_outputs import (
35
+ FlaxBaseModelOutputWithPastAndCrossAttentions,
36
+ FlaxCausalLMOutputWithCrossAttentions,
37
+ )
38
+ from transformers.modeling_flax_utils import ACT2FN, FlaxPreTrainedModel
39
+
40
+ from models import BartConfig
41
+
42
+
43
+ scan_with_axes = nn_partitioning.scan_with_axes
44
+ remat = nn_partitioning.remat
45
+
46
+
47
+ def shift_tokens_right(input_ids: np.array, pad_token_id: int, decoder_start_token_id: int) -> np.ndarray:
48
+ """
49
+ Shift input ids one token to the right.
50
+ """
51
+ shifted_input_ids = np.zeros_like(input_ids)
52
+ shifted_input_ids[:, 1:] = input_ids[:, :-1]
53
+ shifted_input_ids[:, 0] = decoder_start_token_id
54
+
55
+ shifted_input_ids = np.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
56
+ return shifted_input_ids
57
+
58
+
59
+ class FlaxBartAttention(nn.Module):
60
+ config: BartConfig
61
+ embed_dim: int
62
+ num_heads: int
63
+ dropout: float = 0.0
64
+ causal: bool = False
65
+ bias: bool = True
66
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
67
+
68
+ def setup(self) -> None:
69
+ self.head_dim = self.embed_dim // self.num_heads
70
+ if self.head_dim * self.num_heads != self.embed_dim:
71
+ raise ValueError(
72
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
73
+ f" and `num_heads`: {self.num_heads})."
74
+ )
75
+
76
+ dense = partial(
77
+ nn.Dense,
78
+ self.embed_dim,
79
+ use_bias=self.bias,
80
+ dtype=self.dtype,
81
+ kernel_init=jax.nn.initializers.normal(self.config.init_std),
82
+ )
83
+
84
+ self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
85
+
86
+ self.fused_proj = nn.Dense(
87
+ self.embed_dim * 3,
88
+ use_bias=self.bias,
89
+ dtype=self.dtype,
90
+ kernel_init=jax.nn.initializers.normal(self.config.init_std),
91
+ )
92
+
93
+ self.fused_key_value = nn.Dense(
94
+ self.embed_dim * 2,
95
+ use_bias=self.bias,
96
+ dtype=self.dtype,
97
+ kernel_init=jax.nn.initializers.normal(self.config.init_std),
98
+ )
99
+
100
+ self.out_proj = dense()
101
+
102
+ self.dropout_layer = nn.Dropout(rate=self.dropout)
103
+
104
+ if self.causal:
105
+ self.causal_mask = make_causal_mask(
106
+ jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool"
107
+ )
108
+
109
+ def _split_heads(self, hidden_states):
110
+ return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))
111
+
112
+ def _merge_heads(self, hidden_states):
113
+ return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))
114
+
115
+ @nn.compact
116
+ def _concatenate_to_cache(self, key, value, query, attention_mask):
117
+ """
118
+ This function takes projected key, value states from a single input token and concatenates the states to cached
119
+ states from previous steps. This function is slighly adapted from the official Flax repository:
120
+ https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
121
+ """
122
+ # detect if we're initializing by absence of existing cache data.
123
+ is_initialized = self.has_variable("cache", "cached_key")
124
+ cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
125
+ cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
126
+ cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
127
+
128
+ if is_initialized:
129
+ *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
130
+ # update key, value caches with our new 1d spatial slices
131
+ cur_index = cache_index.value
132
+ indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
133
+ key = lax.dynamic_update_slice(cached_key.value, key, indices)
134
+ value = lax.dynamic_update_slice(cached_value.value, value, indices)
135
+ cached_key.value = key
136
+ cached_value.value = value
137
+ num_updated_cache_vectors = query.shape[1]
138
+ cache_index.value = cache_index.value + num_updated_cache_vectors
139
+ # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements.
140
+ pad_mask = jnp.broadcast_to(
141
+ jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
142
+ tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
143
+ )
144
+ attention_mask = combine_masks(pad_mask, attention_mask)
145
+ return key, value, attention_mask
146
+
147
+ def __call__(
148
+ self,
149
+ hidden_states: jnp.ndarray,
150
+ key_value_states: Optional[jnp.ndarray] = None,
151
+ attention_mask: Optional[jnp.ndarray] = None,
152
+ init_cache: bool = False,
153
+ deterministic: bool = True,
154
+ ) -> Tuple[jnp.ndarray]:
155
+ """Input shape: Batch x Time x Channel"""
156
+
157
+ # if key_value_states are provided this layer is used as a cross-attention layer
158
+ # for the decoder
159
+ is_cross_attention = key_value_states is not None
160
+ batch_size = hidden_states.shape[0]
161
+
162
+ if self.config.fuse_matmuls:
163
+ # get key, value proj
164
+ if is_cross_attention:
165
+ # get query proj
166
+ query_states = self.q_proj(hidden_states)
167
+ # cross_attentions
168
+ attention_states = self.fused_key_value(key_value_states)
169
+ key_states, value_states = jnp.split(attention_states, 2, axis=-1)
170
+ else:
171
+ attention_states = self.fused_proj(hidden_states)
172
+ query_states, key_states, value_states = jnp.split(attention_states, 3, axis=-1)
173
+
174
+ else:
175
+ # get query proj
176
+ query_states = self.q_proj(hidden_states)
177
+ # get key, value proj
178
+ if is_cross_attention:
179
+ # cross_attentions
180
+ key_states = self.k_proj(key_value_states)
181
+ value_states = self.v_proj(key_value_states)
182
+ else:
183
+ # self_attention
184
+ key_states = self.k_proj(hidden_states)
185
+ value_states = self.v_proj(hidden_states)
186
+
187
+ query_states = self._split_heads(query_states)
188
+ key_states = self._split_heads(key_states)
189
+ value_states = self._split_heads(value_states)
190
+
191
+ # handle cache prepare causal attention mask
192
+ if self.causal:
193
+ query_length, key_length = query_states.shape[1], key_states.shape[1]
194
+ if self.has_variable("cache", "cached_key"):
195
+ mask_shift = self.variables["cache"]["cache_index"]
196
+ max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
197
+ causal_mask = lax.dynamic_slice(
198
+ self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
199
+ )
200
+ else:
201
+ causal_mask = self.causal_mask[:, :, :query_length, :key_length]
202
+ causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
203
+
204
+ # combine masks if needed
205
+ if attention_mask is not None and self.causal:
206
+ attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
207
+ attention_mask = combine_masks(attention_mask, causal_mask)
208
+ elif self.causal:
209
+ attention_mask = causal_mask
210
+ elif attention_mask is not None:
211
+ attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
212
+
213
+ # During fast autoregressive decoding, we feed one position at a time,
214
+ # and cache the keys and values step by step.
215
+ if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
216
+ key_states, value_states, attention_mask = self._concatenate_to_cache(
217
+ key_states, value_states, query_states, attention_mask
218
+ )
219
+
220
+ # Convert the boolean attention mask to an attention bias.
221
+ if attention_mask is not None:
222
+ # attention mask in the form of attention bias
223
+ attention_bias = lax.select(
224
+ attention_mask > 0,
225
+ jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
226
+ jnp.full(attention_mask.shape, float("-inf")).astype(self.dtype),
227
+ )
228
+ else:
229
+ attention_bias = None
230
+
231
+ dropout_rng = None
232
+ if not deterministic and self.dropout > 0.0:
233
+ dropout_rng = self.make_rng("dropout")
234
+
235
+ attn_weights = dot_product_attention_weights(
236
+ query_states,
237
+ key_states,
238
+ bias=attention_bias,
239
+ dropout_rng=dropout_rng,
240
+ dropout_rate=self.dropout,
241
+ broadcast_dropout=True,
242
+ deterministic=deterministic,
243
+ dtype=self.dtype,
244
+ precision=None,
245
+ )
246
+
247
+ attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
248
+ attn_output = self._merge_heads(attn_output)
249
+ attn_output = self.out_proj(attn_output)
250
+
251
+ return attn_output, attn_weights
252
+
253
+
254
+ class FlaxBartDecoderLayer(nn.Module):
255
+ config: BartConfig
256
+ dtype: jnp.dtype = jnp.float32
257
+
258
+ def setup(self) -> None:
259
+ self.embed_dim = self.config.d_model
260
+ self.self_attn = FlaxBartAttention(
261
+ config=self.config,
262
+ embed_dim=self.embed_dim,
263
+ num_heads=self.config.decoder_attention_heads,
264
+ dropout=self.config.attention_dropout,
265
+ causal=True,
266
+ dtype=self.dtype,
267
+ )
268
+ self.dropout_layer = nn.Dropout(rate=self.config.dropout)
269
+ self.activation_fn = ACT2FN[self.config.activation_function]
270
+ self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
271
+
272
+ self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
273
+ self.encoder_attn = FlaxBartAttention(
274
+ config=self.config,
275
+ embed_dim=self.embed_dim,
276
+ num_heads=self.config.decoder_attention_heads,
277
+ dropout=self.config.attention_dropout,
278
+ dtype=self.dtype,
279
+ )
280
+ self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
281
+ self.fc1 = nn.Dense(
282
+ self.config.encoder_ffn_dim,
283
+ dtype=self.dtype,
284
+ kernel_init=jax.nn.initializers.normal(self.config.init_std),
285
+ )
286
+ self.fc2 = nn.Dense(
287
+ self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
288
+ )
289
+ self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
290
+
291
+ def __call__(
292
+ self,
293
+ hidden_states: jnp.ndarray,
294
+ attention_mask: jnp.ndarray,
295
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
296
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
297
+ init_cache: bool = False,
298
+ output_attentions: bool = True,
299
+ deterministic: bool = True,
300
+ ) -> Tuple[jnp.ndarray]:
301
+
302
+ if self.config.use_scan:
303
+ hidden_states = hidden_states[0]
304
+
305
+ residual = hidden_states
306
+
307
+ # Self Attention
308
+ hidden_states, self_attn_weights = self.self_attn(
309
+ hidden_states=hidden_states, attention_mask=attention_mask, init_cache=init_cache
310
+ )
311
+ hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
312
+ hidden_states = residual + hidden_states
313
+ hidden_states = self.self_attn_layer_norm(hidden_states)
314
+
315
+ # Cross-Attention Block
316
+ cross_attn_weights = None
317
+ if encoder_hidden_states is not None:
318
+ residual = hidden_states
319
+
320
+ hidden_states, cross_attn_weights = self.encoder_attn(
321
+ hidden_states=hidden_states,
322
+ key_value_states=encoder_hidden_states,
323
+ attention_mask=encoder_attention_mask,
324
+ )
325
+ hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
326
+ hidden_states = residual + hidden_states
327
+ hidden_states = self.encoder_attn_layer_norm(hidden_states)
328
+
329
+ # Fully Connected
330
+ residual = hidden_states
331
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
332
+ hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic)
333
+ hidden_states = self.fc2(hidden_states)
334
+ hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
335
+ hidden_states = residual + hidden_states
336
+ hidden_states = self.final_layer_norm(hidden_states)
337
+
338
+ outputs = (hidden_states,)
339
+
340
+ if output_attentions:
341
+ outputs += (self_attn_weights, cross_attn_weights)
342
+
343
+ if self.config.use_scan:
344
+ outputs = (outputs, None)
345
+
346
+ return outputs
347
+
348
+
349
+ class FlaxBartDecoderLayerCollection(nn.Module):
350
+ config: BartConfig
351
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
352
+
353
+ @nn.compact
354
+ def __call__(
355
+ self,
356
+ hidden_states,
357
+ attention_mask,
358
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
359
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
360
+ deterministic: bool = True,
361
+ init_cache: bool = False,
362
+ output_attentions: bool = False,
363
+ output_hidden_states: bool = False,
364
+ return_dict: bool = True,
365
+ ):
366
+ # decoder layers
367
+ all_hidden_states = () if output_hidden_states else None
368
+ all_self_attns = () if output_attentions else None
369
+ all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
370
+
371
+ num_decoder_layers = self.config.decoder_layers
372
+ BlockDecoderLayer = (
373
+ remat(
374
+ FlaxBartDecoderLayer,
375
+ static_argnums=(4, 5, 6),
376
+ prevent_cse=not self.config.use_scan,
377
+ )
378
+ if self.config.gradient_checkpointing
379
+ else FlaxBartDecoderLayer
380
+ )
381
+
382
+ if self.config.use_scan:
383
+ # since all decoder layers are the same, we use nn.scan directly
384
+ assert not output_attentions, "cannot use `scan` with `output_attentions` set to `True`"
385
+ assert not output_hidden_states, "cannot use `scan` with `output_hidden_states` set to `True`"
386
+ hidden_states = (hidden_states,)
387
+
388
+ # TODO: add layerdrop in checkpointed scan (note: default value for layerdrop in config is zero)
389
+ hidden_states, _ = scan_with_axes(
390
+ BlockDecoderLayer,
391
+ variable_axes={"params": 0, "cache": 0},
392
+ split_rngs={"params": True, "dropout": True},
393
+ in_axes=(nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast),
394
+ length=num_decoder_layers,
395
+ )(self.config, dtype=self.dtype, name="FlaxBartDecoderLayers")(
396
+ hidden_states,
397
+ attention_mask,
398
+ encoder_hidden_states,
399
+ encoder_attention_mask,
400
+ init_cache,
401
+ output_attentions,
402
+ deterministic,
403
+ )
404
+ hidden_states = hidden_states[0]
405
+
406
+ else:
407
+ for layer in range(num_decoder_layers):
408
+ if output_hidden_states:
409
+ all_hidden_states += (hidden_states,)
410
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
411
+ dropout_probability = random.uniform(0, 1)
412
+ if not deterministic and (dropout_probability < self.config.decoder_layerdrop):
413
+ layer_outputs = (None, None, None)
414
+ else:
415
+ layer_outputs = BlockDecoderLayer(self.config, dtype=self.dtype, name=str(layer),)(
416
+ hidden_states,
417
+ attention_mask,
418
+ encoder_hidden_states,
419
+ encoder_attention_mask,
420
+ init_cache,
421
+ output_attentions,
422
+ deterministic,
423
+ )
424
+
425
+ hidden_states = layer_outputs[0]
426
+ if output_attentions:
427
+ all_self_attns += (layer_outputs[1],)
428
+
429
+ if encoder_hidden_states is not None:
430
+ all_cross_attentions += (layer_outputs[2],)
431
+
432
+ # add hidden states from the last decoder layer
433
+ if output_hidden_states:
434
+ all_hidden_states += (hidden_states,)
435
+
436
+ outputs = [hidden_states, all_hidden_states, all_self_attns, all_cross_attentions]
437
+
438
+ if not return_dict:
439
+ return tuple(v for v in outputs if v is not None)
440
+
441
+ return FlaxBaseModelOutputWithPastAndCrossAttentions(
442
+ last_hidden_state=hidden_states,
443
+ hidden_states=all_hidden_states,
444
+ attentions=all_self_attns,
445
+ cross_attentions=all_cross_attentions,
446
+ )
447
+
448
+
449
+ class FlaxBartDecoder(nn.Module):
450
+ config: BartConfig
451
+ embed_tokens: nn.Embed
452
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
453
+
454
+ def setup(self):
455
+ self.dropout_layer = nn.Dropout(rate=self.config.dropout)
456
+
457
+ embed_dim = self.config.d_model
458
+ self.padding_idx = self.config.pad_token_id
459
+ self.max_target_positions = self.config.max_position_embeddings
460
+ self.embed_scale = math.sqrt(self.config.d_model) if self.config.scale_embedding else 1.0
461
+
462
+ # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
463
+ # and adjust num_embeddings appropriately. Other models don't have this hack
464
+ self.offset = 2
465
+ self.embed_positions = nn.Embed(
466
+ self.config.max_position_embeddings + self.offset,
467
+ embed_dim,
468
+ embedding_init=jax.nn.initializers.normal(self.config.init_std),
469
+ )
470
+
471
+ self.layers = FlaxBartDecoderLayerCollection(self.config, self.dtype)
472
+ self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
473
+
474
+ def __call__(
475
+ self,
476
+ input_ids,
477
+ attention_mask,
478
+ position_ids,
479
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
480
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
481
+ init_cache: bool = False,
482
+ output_attentions: bool = False,
483
+ output_hidden_states: bool = False,
484
+ return_dict: bool = True,
485
+ deterministic: bool = True,
486
+ ):
487
+ input_shape = input_ids.shape
488
+ input_ids = input_ids.reshape(-1, input_shape[-1])
489
+
490
+ inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
491
+
492
+ # embed positions
493
+ positions = self.embed_positions(position_ids + self.offset)
494
+
495
+ hidden_states = inputs_embeds + positions
496
+ hidden_states = self.layernorm_embedding(hidden_states)
497
+
498
+ hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
499
+
500
+ outputs = self.layers(
501
+ hidden_states,
502
+ attention_mask,
503
+ encoder_hidden_states,
504
+ encoder_attention_mask,
505
+ deterministic=deterministic,
506
+ init_cache=init_cache,
507
+ output_attentions=output_attentions,
508
+ output_hidden_states=output_hidden_states,
509
+ return_dict=return_dict,
510
+ )
511
+
512
+ if not return_dict:
513
+ return outputs
514
+
515
+ return FlaxBaseModelOutputWithPastAndCrossAttentions(
516
+ last_hidden_state=outputs.last_hidden_state,
517
+ hidden_states=outputs.hidden_states,
518
+ attentions=outputs.attentions,
519
+ cross_attentions=outputs.cross_attentions,
520
+ )
521
+
522
+
523
+ class FlaxBartDecoderPreTrainedModel(FlaxPreTrainedModel):
524
+ config_class = BartConfig
525
+ base_model_prefix: str = "model"
526
+ module_class: nn.Module = None
527
+
528
+ def __init__(
529
+ self,
530
+ config: BartConfig,
531
+ input_shape: Tuple[int] = (1, 1),
532
+ seed: int = 0,
533
+ dtype: jnp.dtype = jnp.float32,
534
+ _do_init: bool = True,
535
+ **kwargs
536
+ ):
537
+ config.is_decoder = True
538
+ config.is_encoder_decoder = False
539
+ module = self.module_class(config=config, dtype=dtype, **kwargs)
540
+ super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
541
+
542
+ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
543
+ # init input tensors
544
+ input_ids = jnp.zeros(input_shape, dtype="i4")
545
+ attention_mask = jnp.ones_like(input_ids)
546
+
547
+ batch_size, sequence_length = input_ids.shape
548
+ position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
549
+
550
+ params_rng, dropout_rng = jax.random.split(rng)
551
+ rngs = {"params": params_rng, "dropout": dropout_rng}
552
+ encoder_hidden_states = jnp.zeros(input_shape + (self.config.d_model,))
553
+ encoder_attention_mask = attention_mask
554
+ module_init_outputs = self.module.init(
555
+ rngs,
556
+ input_ids,
557
+ attention_mask,
558
+ position_ids,
559
+ encoder_hidden_states,
560
+ encoder_attention_mask,
561
+ return_dict=False,
562
+ )
563
+ return module_init_outputs["params"]
564
+
565
+ def init_cache(self, batch_size, max_length):
566
+ r"""
567
+ Args:
568
+ batch_size (`int`):
569
+ batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
570
+ max_length (`int`):
571
+ maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
572
+ cache.
573
+ """
574
+ # init input variables to retrieve cache
575
+ input_ids = jnp.ones((batch_size, max_length), dtype="i4")
576
+ attention_mask = jnp.ones_like(input_ids, dtype="i4")
577
+ position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
578
+
579
+ init_variables = self.module.init(
580
+ jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
581
+ )
582
+ return unfreeze(init_variables["cache"])
583
+
584
+ def __call__(
585
+ self,
586
+ input_ids: jnp.ndarray,
587
+ attention_mask: Optional[jnp.ndarray] = None,
588
+ position_ids: Optional[jnp.ndarray] = None,
589
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
590
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
591
+ output_attentions: Optional[bool] = None,
592
+ output_hidden_states: Optional[bool] = None,
593
+ return_dict: Optional[bool] = None,
594
+ train: bool = False,
595
+ params: dict = None,
596
+ past_key_values: dict = None,
597
+ dropout_rng: PRNGKey = None,
598
+ ):
599
+ """
600
+ Args:
601
+ input_ids (`jnp.ndarray` of shape `(target_batch_size, target_sequence_length)`):
602
+ Indices of decoder input sequence tokens in the vocabulary.
603
+
604
+ Indices can be obtained using [`BartTokenizer`]. See [`PreTrainedTokenizer.encode`] and
605
+ [`PreTrainedTokenizer.__call__`] for details.
606
+
607
+ [What are decoder input IDs?](../glossary#decoder-input-ids)
608
+
609
+ For translation and summarization training, `decoder_input_ids` should be provided. If no
610
+ `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right
611
+ for denoising pre-training following the paper.
612
+ attention_mask (`jnp.ndarray` of shape `(target_batch_size, target_sequence_length)`, *optional*):
613
+ Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
614
+ be used by default.
615
+
616
+ If you want to change padding behavior, you should modify to your needs. See diagram 1 in [the
617
+ paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy.
618
+ position_ids (`numpy.ndarray` of shape `(target_batch_size, sequence_length)`, *optional*):
619
+ Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
620
+ range `[0, config.max_position_embeddings - 1]`.
621
+ encoder_hidden_states (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`):
622
+ A sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
623
+ encoder_attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
624
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
625
+
626
+ - 1 for tokens that are **not masked**,
627
+ - 0 for tokens that are **masked**.
628
+
629
+ [What are attention masks?](../glossary#attention-mask)
630
+ output_attentions (`bool`, *optional*):
631
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
632
+ tensors for more detail.
633
+ output_hidden_states (`bool`, *optional*):
634
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
635
+ more detail.
636
+ return_dict (`bool`, *optional*):
637
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
638
+ past_key_values (`Dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`):
639
+ Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast
640
+ auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*.
641
+ """
642
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
643
+ output_hidden_states = (
644
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
645
+ )
646
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
647
+
648
+ if encoder_hidden_states is not None and encoder_attention_mask is None:
649
+ batch_size, sequence_length = encoder_hidden_states.shape[:2]
650
+ encoder_attention_mask = jnp.ones((batch_size, sequence_length))
651
+
652
+ # prepare decoder inputs
653
+ if attention_mask is None:
654
+ attention_mask = jnp.ones_like(input_ids)
655
+ if position_ids is None:
656
+ batch_size, sequence_length = input_ids.shape
657
+ position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
658
+
659
+ # Handle any PRNG if needed
660
+ rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
661
+
662
+ inputs = {"params": params or self.params}
663
+
664
+ # if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed
665
+ # down to ensure cache is used. It has to be made sure that cache is marked as mutable so that it can be
666
+ # changed by FlaxBartAttention module
667
+ if past_key_values:
668
+ inputs["cache"] = past_key_values
669
+ mutable = ["cache"]
670
+ else:
671
+ mutable = False
672
+
673
+ outputs = self.module.apply(
674
+ inputs,
675
+ input_ids=jnp.array(input_ids, dtype="i4"),
676
+ attention_mask=jnp.array(attention_mask, dtype="i4"),
677
+ position_ids=jnp.array(position_ids, dtype="i4"),
678
+ encoder_hidden_states=encoder_hidden_states,
679
+ encoder_attention_mask=encoder_attention_mask,
680
+ output_attentions=output_attentions,
681
+ output_hidden_states=output_hidden_states,
682
+ return_dict=return_dict,
683
+ deterministic=not train,
684
+ rngs=rngs,
685
+ mutable=mutable,
686
+ )
687
+
688
+ # add updated cache to model output
689
+ if past_key_values is not None and return_dict:
690
+ outputs, past_key_values = outputs
691
+ outputs["past_key_values"] = unfreeze(past_key_values["cache"])
692
+ return outputs
693
+ elif past_key_values is not None and not return_dict:
694
+ outputs, past_key_values = outputs
695
+ outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]
696
+
697
+ return outputs
698
+
699
+
700
+ class FlaxBartDecoderWrapper(nn.Module):
701
+ """
702
+ This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
703
+ used in combination with the [`EncoderDecoderModel`] framework.
704
+ """
705
+
706
+ config: BartConfig
707
+ dtype: jnp.dtype = jnp.float32
708
+
709
+ def setup(self):
710
+ embed_dim = self.config.d_model
711
+ embed_tokens = nn.Embed(
712
+ self.config.vocab_size,
713
+ embed_dim,
714
+ embedding_init=jax.nn.initializers.normal(self.config.init_std),
715
+ )
716
+ self.decoder = FlaxBartDecoder(config=self.config, embed_tokens=embed_tokens, dtype=self.dtype)
717
+
718
+ def __call__(self, *args, **kwargs):
719
+ return self.decoder(*args, **kwargs)
720
+
721
+
722
+ class FlaxBartForCausalLMModule(nn.Module):
723
+ """Bart Decoder Module with a language modeling head on top (linear layer with weights tied to the input embeddings)
724
+ e.g. for autoregressive tasks.
725
+ """
726
+
727
+ config: BartConfig
728
+ dtype: jnp.dtype = jnp.float32
729
+
730
+ def setup(self):
731
+ self.model = FlaxBartDecoderWrapper(config=self.config, dtype=self.dtype)
732
+ self.lm_head = nn.Dense(
733
+ self.config.vocab_size,
734
+ use_bias=False,
735
+ dtype=self.dtype,
736
+ kernel_init=jax.nn.initializers.normal(self.config.init_std),
737
+ )
738
+
739
+ def __call__(
740
+ self,
741
+ input_ids,
742
+ attention_mask,
743
+ position_ids,
744
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
745
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
746
+ init_cache: bool = False,
747
+ output_attentions: bool = False,
748
+ output_hidden_states: bool = False,
749
+ return_dict: bool = True,
750
+ deterministic: bool = True,
751
+ ):
752
+
753
+ outputs = self.model(
754
+ input_ids,
755
+ attention_mask,
756
+ position_ids,
757
+ encoder_hidden_states,
758
+ encoder_attention_mask,
759
+ deterministic=deterministic,
760
+ init_cache=init_cache,
761
+ output_attentions=output_attentions,
762
+ output_hidden_states=output_hidden_states,
763
+ return_dict=return_dict,
764
+ )
765
+
766
+ hidden_states = outputs[0]
767
+
768
+ if self.config.tie_word_embeddings:
769
+ shared_embedding = self.model.variables["params"]["decoder"]["embed_tokens"]["embedding"]
770
+ lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
771
+ else:
772
+ lm_logits = self.lm_head(hidden_states)
773
+
774
+ if not return_dict:
775
+ return (lm_logits,) + outputs[1:]
776
+
777
+ return FlaxCausalLMOutputWithCrossAttentions(
778
+ logits=lm_logits,
779
+ hidden_states=outputs.hidden_states,
780
+ attentions=outputs.attentions,
781
+ cross_attentions=outputs.cross_attentions,
782
+ )
783
+
784
+
785
+ class FlaxBartForCausalLM(FlaxBartDecoderPreTrainedModel):
786
+ """Bart Decoder Model with a language modeling head on top (linear layer with weights tied to the input embeddings)
787
+ e.g. for autoregressive tasks.
788
+ """
789
+
790
+ module_class = FlaxBartForCausalLMModule
791
+
792
+ def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
793
+ # initializing the cache
794
+ batch_size, seq_length = input_ids.shape
795
+
796
+ past_key_values = self.init_cache(batch_size, max_length)
797
+ # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
798
+ # But since the decoder uses a causal mask, those positions are masked anyway.
799
+ # Thus, we can create a single static attention_mask here, which is more efficient for compilation
800
+ extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
801
+ if attention_mask is not None:
802
+ position_ids = attention_mask.cumsum(axis=-1) - 1
803
+ extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
804
+ else:
805
+ position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
806
+
807
+ return {
808
+ "past_key_values": past_key_values,
809
+ "attention_mask": extended_attention_mask,
810
+ "position_ids": position_ids,
811
+ }
812
+
813
+ def update_inputs_for_generation(self, model_outputs, model_kwargs):
814
+ model_kwargs["past_key_values"] = model_outputs.past_key_values
815
+ model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
816
+ return model_kwargs
models/modeling_flax_speech_encoder_decoder.py ADDED
@@ -0,0 +1,1245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Classes to support Flax Speech-Encoder-Decoder architectures"""
16
+
17
+ import os
18
+ from functools import partial
19
+ from typing import Optional, Tuple, Union, Dict
20
+
21
+ import flax
22
+ import flax.linen as nn
23
+ import jax
24
+ import jax.numpy as jnp
25
+ from flax.core.frozen_dict import FrozenDict, unfreeze
26
+ from jax import lax
27
+ from jax.random import PRNGKey
28
+ import numpy as np
29
+
30
+ from transformers.modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutputWithCrossAttentions, FlaxSeq2SeqLMOutput
31
+ from transformers.modeling_flax_utils import FlaxPreTrainedModel
32
+ from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings, ModelOutput
33
+ from transformers.generation_flax_utils import FlaxLogitsProcessorList
34
+ from models import (
35
+ FlaxWav2Vec2Model,
36
+ FlaxWav2Vec2Module,
37
+ FlaxBartForCausalLM,
38
+ FlaxBartForCausalLMModule,
39
+ BartConfig,
40
+ Wav2Vec2Config,
41
+ SpeechEncoderDecoderConfig,
42
+ )
43
+
44
+ logger = logging.get_logger(__name__)
45
+
46
+ _CONFIG_FOR_DOC = "SpeechEncoderDecoderConfig"
47
+
48
+ SPEECH_ENCODER_DECODER_START_DOCSTRING = r"""
49
+ This class can be used to initialize a speech-sequence-to-text-sequence model with any pretrained speech
50
+ autoencoding model as the encoder and any pretrained text autoregressive model as the decoder. The encoder is
51
+ loaded via [`~AutoModel.from_pretrained`] function and the decoder is loaded via
52
+ [`~AutoModelForCausalLM.from_pretrained`] function. Cross-attention layers are automatically added to the decoder
53
+ and should be fine-tuned on a downstream generative task, like summarization.
54
+
55
+ The effectiveness of initializing sequence-to-sequence models with pretrained checkpoints for sequence generation
56
+ tasks was shown in [Leveraging Pre-trained Checkpoints for Sequence Generation
57
+ Tasks](https://arxiv.org/abs/1907.12461) by Sascha Rothe, Shashi Narayan, Aliaksei Severyn. Michael Matena, Yanqi
58
+ Zhou, Wei Li, Peter J. Liu.
59
+
60
+ Additionally, in [Large-Scale Self- and Semi-Supervised Learning for Speech
61
+ Translation](https://arxiv.org/abs/2104.06678) it is shown how leveraging large pretrained speech models for speech
62
+ translation yields a significant performance improvement.
63
+
64
+ After such an Speech-Encoder Decoder model has been trained/fine-tuned, it can be saved/loaded just like any other
65
+ models (see the examples for more information).
66
+
67
+ This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
68
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
69
+ etc.)
70
+
71
+ This model is also a Flax Linen
72
+ [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
73
+ regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior.
74
+
75
+ Parameters:
76
+ config ([`SpeechEncoderDecoderConfig`]): Model configuration class with all the parameters of the model.
77
+ Initializing with a config file does not load the weights associated with the model, only the
78
+ configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
79
+ dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
80
+ The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
81
+ `jax.numpy.bfloat16` (on TPUs).
82
+
83
+ This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
84
+ specified all the computation will be performed with the given `dtype`.
85
+
86
+ **Note that this only specifies the dtype of the computation and does not influence the dtype of model
87
+ parameters.**
88
+
89
+ If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
90
+ [`~FlaxPreTrainedModel.to_bf16`].
91
+ """
92
+
93
+ SPEECH_ENCODER_DECODER_INPUTS_DOCSTRING = r"""
94
+ Args:
95
+ inputs (`jnp.ndarray` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, feature_dim)`, *optional*):
96
+ Float values of input raw speech waveform or speech features. Values can be obtained by loading a *.flac*
97
+ or *.wav* audio file into an array of type *List[float]* or a *numpy.ndarray*, *e.g.* via the soundfile
98
+ library (*pip install soundfile*). To prepare the array into *inputs*, either the [`Wav2Vec2Processor`] or
99
+ [`Speech2TextProcessor`] should be used for padding and conversion into a tensor of type
100
+ *torch.FloatTensor*.
101
+ attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
102
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
103
+
104
+ - 1 for tokens that are **not masked**,
105
+ - 0 for tokens that are **masked**.
106
+
107
+ [What are attention masks?](../glossary#attention-mask)
108
+ decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
109
+ Indices of decoder input sequence tokens in the vocabulary.
110
+
111
+ Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and
112
+ [`PreTrainedTokenizer.__call__`] for details.
113
+
114
+ [What are input IDs?](../glossary#input-ids)
115
+
116
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
117
+ `past_key_values`).
118
+
119
+ For sequence to sequence training, `decoder_input_ids` should be provided. `decoder_input_ids` should be
120
+ created outside of the model by shifting the `labels` to the right, replacing -100 by the `pad_token_id`
121
+ and prepending them with the `decoder_start_token_id`.
122
+ decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
123
+ Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
124
+ be used by default.
125
+ decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
126
+ Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
127
+ range `[0, config.decoder.max_position_embeddings - 1]`.
128
+ output_hidden_states (`bool`, *optional*):
129
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
130
+ more detail.
131
+ return_dict (`bool`, *optional*):
132
+ If set to `True`, the model will return a [`~utils.FlaxSeq2SeqLMOutput`] instead of a plain tuple.
133
+ """
134
+
135
+ SPEECH_ENCODER_DECODER_ENCODE_INPUTS_DOCSTRING = r"""
136
+ Args:
137
+ inputs (`jnp.ndarray` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, feature_dim)`, *optional*):
138
+ Float values of input raw speech waveform or speech features. Values can be obtained by loading a *.flac*
139
+ or *.wav* audio file into an array of type *List[float]* or a *numpy.ndarray*, *e.g.* via the soundfile
140
+ library (*pip install soundfile*). To prepare the array into *inputs*, either the [`Wav2Vec2Processor`] or
141
+ [`Speech2TextProcessor`] should be used for padding and conversion into a tensor of type
142
+ *torch.FloatTensor*.
143
+ attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
144
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
145
+
146
+ - 1 for tokens that are **not masked**,
147
+ - 0 for tokens that are **masked**.
148
+
149
+ [What are attention masks?](../glossary#attention-mask)
150
+ output_attentions (`bool`, *optional*):
151
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
152
+ tensors for more detail.
153
+ output_hidden_states (`bool`, *optional*):
154
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
155
+ more detail.
156
+ return_dict (`bool`, *optional*):
157
+ If set to `True`, the model will return a [`~utils.FlaxBaseModelOutput`] instead of a plain tuple.
158
+ """
159
+
160
+ SPEECH_ENCODER_DECODER_DECODE_INPUTS_DOCSTRING = r"""
161
+ Args:
162
+ decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
163
+ Indices of decoder input sequence tokens in the vocabulary.
164
+
165
+ Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and
166
+ [`PreTrainedTokenizer.__call__`] for details.
167
+
168
+ [What are decoder input IDs?](../glossary#decoder-input-ids)
169
+
170
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
171
+ `past_key_values`).
172
+
173
+ For sequence to sequence training, `decoder_input_ids` should be provided. `decoder_input_ids` should be
174
+ created outside of the model by shifting the `labels` to the right, replacing -100 by the `pad_token_id`
175
+ and prepending them with the `decoder_start_token_id`.
176
+ encoder_outputs (`tuple(tuple(jnp.ndarray)`):
177
+ Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
178
+ `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
179
+ hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
180
+ encoder_attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
181
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
182
+
183
+ - 1 for tokens that are **not masked**,
184
+ - 0 for tokens that are **masked**.
185
+
186
+ [What are attention masks?](../glossary#attention-mask)
187
+ decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
188
+ Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
189
+ be used by default.
190
+ decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
191
+ Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
192
+ range `[0, config.decoder.max_position_embeddings - 1]`.
193
+ past_key_values (`Dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`):
194
+ Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast
195
+ auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*.
196
+ output_attentions (`bool`, *optional*):
197
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
198
+ tensors for more detail.
199
+ output_hidden_states (`bool`, *optional*):
200
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
201
+ more detail.
202
+ return_dict (`bool`, *optional*):
203
+ If set to `True`, the model will return a [`~utils.FlaxCausalLMOutputWithCrossAttentions`] instead of a
204
+ plain tuple.
205
+ """
206
+
207
+ @flax.struct.dataclass
208
+ class FlaxBeamSearchOutput(ModelOutput):
209
+ """
210
+ Flax Base class for outputs of decoder-only generation models using greedy search.
211
+
212
+
213
+ Args:
214
+ sequences (`jnp.ndarray` of shape `(batch_size, max_length)`):
215
+ The generated sequences.
216
+ scores (`jnp.ndarray` of shape `(batch_size,)`):
217
+ The scores (log probabilites) of the generated sequences.
218
+ """
219
+
220
+ sequences: jnp.ndarray = None
221
+ scores: jnp.ndarray = None
222
+
223
+
224
+ @flax.struct.dataclass
225
+ class BeamSearchState:
226
+ cur_len: jnp.ndarray
227
+ running_sequences: jnp.ndarray
228
+ running_scores: jnp.ndarray
229
+ sequences: jnp.ndarray
230
+ scores: jnp.ndarray
231
+ is_sent_finished: jnp.ndarray
232
+ model_kwargs: Dict[str, jnp.ndarray]
233
+
234
+
235
+
236
+
237
+ class FlaxSpeechEncoderDecoderModule(nn.Module):
238
+ config: SpeechEncoderDecoderConfig
239
+ dtype: jnp.dtype = jnp.float32
240
+
241
+ def setup(self):
242
+ encoder_config = self.config.encoder
243
+ decoder_config = self.config.decoder
244
+
245
+ # TODO: configure FlaxAutoModel mappings (required when trialling different encoder-decoder combinations)
246
+ encoder_module = FlaxWav2Vec2Module
247
+ decoder_module = FlaxBartForCausalLMModule
248
+
249
+ self.encoder = encoder_module(encoder_config, dtype=self.dtype)
250
+ self.decoder = decoder_module(decoder_config, dtype=self.dtype)
251
+
252
+ # encoder outputs might need to be projected to different dimension for decoder
253
+ if (
254
+ self.encoder.config.hidden_size != self.decoder.config.hidden_size
255
+ and self.decoder.config.cross_attention_hidden_size is None
256
+ ):
257
+ self.enc_to_dec_proj = nn.Dense(
258
+ self.decoder.config.hidden_size,
259
+ kernel_init=jax.nn.initializers.normal(self.decoder.config.initializer_range),
260
+ dtype=self.dtype,
261
+ )
262
+ else:
263
+ self.enc_to_dec_proj = None
264
+
265
+ def _get_feat_extract_output_lengths(
266
+ self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
267
+ ):
268
+ """
269
+ Computes the output length of the convolutional layers
270
+ """
271
+
272
+ add_adapter = self.config.encoder.add_adapter if add_adapter is None else add_adapter
273
+
274
+ def _conv_out_length(input_length, kernel_size, stride):
275
+ # 1D convolutional layer output length formula taken
276
+ # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
277
+ return (input_length - kernel_size) // stride + 1
278
+
279
+ for kernel_size, stride in zip(self.config.encoder.conv_kernel, self.config.encoder.conv_stride):
280
+ input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
281
+
282
+ if add_adapter:
283
+ for _ in range(self.config.encoder.num_adapter_layers):
284
+ input_lengths = _conv_out_length(input_lengths, 1, self.config.encoder.adapter_stride)
285
+
286
+ return input_lengths
287
+
288
+ def _get_encoder_module(self):
289
+ return self.encoder
290
+
291
+ def _get_projection_module(self):
292
+ return self.enc_to_dec_proj
293
+
294
+ def _get_decoder_module(self):
295
+ return self.decoder
296
+
297
+ def __call__(
298
+ self,
299
+ inputs,
300
+ attention_mask,
301
+ decoder_input_ids,
302
+ decoder_attention_mask,
303
+ decoder_position_ids,
304
+ encoder_outputs=None,
305
+ extract_features=None,
306
+ output_attentions: bool = False,
307
+ output_hidden_states: bool = False,
308
+ output_features: bool = False,
309
+ return_dict: bool = True,
310
+ deterministic: bool = True,
311
+ freeze_feature_encoder: bool = False,
312
+ ):
313
+ if encoder_outputs is None:
314
+ encoder_outputs = self.encoder(
315
+ inputs,
316
+ attention_mask=attention_mask,
317
+ extract_features=extract_features,
318
+ output_attentions=output_attentions,
319
+ output_hidden_states=output_hidden_states,
320
+ output_features=output_features,
321
+ return_dict=return_dict,
322
+ deterministic=deterministic,
323
+ freeze_feature_encoder=freeze_feature_encoder,
324
+ )
325
+
326
+ if output_features:
327
+ return encoder_outputs
328
+
329
+ encoder_hidden_states = encoder_outputs[0]
330
+
331
+ # optionally project encoder_hidden_states
332
+ if self.enc_to_dec_proj is not None:
333
+ encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
334
+
335
+ # compute correct encoder attention mask
336
+ if attention_mask is not None:
337
+ encoder_attention_mask = self.encoder._get_feature_vector_attention_mask(
338
+ encoder_hidden_states.shape[1], attention_mask
339
+ )
340
+ else:
341
+ encoder_attention_mask = None
342
+
343
+ # flax script modeling_flax_wav2vec2.py
344
+ decoder_outputs = self.decoder(
345
+ input_ids=decoder_input_ids,
346
+ attention_mask=decoder_attention_mask,
347
+ position_ids=decoder_position_ids,
348
+ encoder_hidden_states=encoder_hidden_states,
349
+ encoder_attention_mask=encoder_attention_mask,
350
+ output_attentions=output_attentions,
351
+ output_hidden_states=output_hidden_states,
352
+ return_dict=return_dict,
353
+ deterministic=deterministic,
354
+ )
355
+
356
+ if not return_dict:
357
+ return decoder_outputs + encoder_outputs
358
+
359
+ return FlaxSeq2SeqLMOutput(
360
+ logits=decoder_outputs.logits,
361
+ decoder_hidden_states=decoder_outputs.hidden_states,
362
+ decoder_attentions=decoder_outputs.attentions,
363
+ cross_attentions=decoder_outputs.cross_attentions,
364
+ encoder_last_hidden_state=encoder_hidden_states,
365
+ encoder_hidden_states=encoder_outputs.hidden_states,
366
+ encoder_attentions=encoder_outputs.attentions,
367
+ )
368
+
369
+
370
+ @add_start_docstrings(SPEECH_ENCODER_DECODER_START_DOCSTRING)
371
+ class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel):
372
+ r"""
373
+ [`FlaxSpeechEncoderDecoderModel`] is a generic model class that will be instantiated as a transformer architecture
374
+ with the module (flax.nn.Module) of one of the base model classes of the library as encoder module and another one
375
+ as decoder module when created with the :meth*~transformers.FlaxAutoModel.from_pretrained* class method for the
376
+ encoder and :meth*~transformers.FlaxAutoModelForCausalLM.from_pretrained* class method for the decoder.
377
+ """
378
+
379
+ config_class = SpeechEncoderDecoderConfig
380
+ base_model_prefix: str = "speech_encoder_decoder"
381
+ module_class = FlaxSpeechEncoderDecoderModule
382
+
383
+ def __init__(
384
+ self,
385
+ config: SpeechEncoderDecoderConfig,
386
+ input_shape: Optional[Tuple] = None,
387
+ seed: int = 0,
388
+ dtype: jnp.dtype = jnp.float32,
389
+ _do_init: bool = True,
390
+ **kwargs
391
+ ):
392
+
393
+ if not _do_init:
394
+ raise ValueError(
395
+ "`FlaxSpeechEncoderDecoderModel` cannot be created without initializing, `_do_init` must be `True`."
396
+ )
397
+
398
+ if config.decoder.cross_attention_hidden_size is not None:
399
+ # Raise ValueError or option to project enc to dec hidden_size (eg EncAdapterLayer)
400
+ if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
401
+ raise ValueError(
402
+ "If `cross_attention_hidden_size` is specified in the decoder's configuration, "
403
+ "it has to be equal to the encoder's `hidden_size`. "
404
+ f"Got {config.decoder.cross_attention_hidden_size} for `config.decoder.cross_attention_hidden_size` "
405
+ f"and {config.encoder.hidden_size} for `config.encoder.hidden_size`."
406
+ )
407
+
408
+ # make sure input & output embeddings are not tied
409
+ config.tie_word_embeddings = False
410
+ module = self.module_class(config=config, dtype=dtype, **kwargs)
411
+
412
+ if input_shape is None:
413
+ # speech encoders almost always downsample the sequence length dimension
414
+ encoder_input_length = 1024
415
+ decoder_input_length = module._get_feat_extract_output_lengths(encoder_input_length)
416
+ input_shape = ((1, encoder_input_length), (1, decoder_input_length))
417
+
418
+ super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
419
+
420
+ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
421
+ encoder_input_shape, decoder_input_shape = input_shape
422
+
423
+ # init input DeviceArrays
424
+ inputs = jnp.zeros(encoder_input_shape, dtype="f4")
425
+ attention_mask = jnp.ones_like(inputs, dtype="i4")
426
+ decoder_input_ids = jnp.zeros(decoder_input_shape, dtype="i4")
427
+ decoder_attention_mask = jnp.ones_like(decoder_input_ids)
428
+
429
+ batch_size, sequence_length = inputs.shape
430
+
431
+ decoder_batch_size, decoder_sequence_length = decoder_input_ids.shape
432
+ if not decoder_batch_size == batch_size:
433
+ raise ValueError(
434
+ f"The inputs of encoder and decoder should have the same batch size, but got {batch_size} for encoder and {decoder_batch_size} for decoder."
435
+ )
436
+ decoder_position_ids = jnp.broadcast_to(
437
+ jnp.arange(decoder_sequence_length)[None, :], (decoder_batch_size, decoder_sequence_length)
438
+ )
439
+
440
+ params_rng, dropout_rng = jax.random.split(rng)
441
+ rngs = {"params": params_rng, "dropout": dropout_rng}
442
+
443
+ return self.module.init(
444
+ rngs,
445
+ inputs,
446
+ attention_mask,
447
+ decoder_input_ids,
448
+ decoder_attention_mask,
449
+ decoder_position_ids,
450
+ )["params"]
451
+
452
+ def init_cache(self, batch_size, max_length, encoder_outputs):
453
+ r"""
454
+ Args:
455
+ batch_size (`int`):
456
+ batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
457
+ max_length (`int`):
458
+ maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
459
+ cache.
460
+ encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray)]`):
461
+ `encoder_outputs` consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*:
462
+ `attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*)
463
+ is a sequence of hidden-states at the output of the last layer of the encoder. Used in the
464
+ cross-attention of the decoder.
465
+ """
466
+ # init input variables to retrieve cache
467
+ decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4")
468
+ decoder_attention_mask = jnp.ones_like(decoder_input_ids)
469
+ decoder_position_ids = jnp.broadcast_to(
470
+ jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]), decoder_input_ids.shape
471
+ )
472
+
473
+ def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs):
474
+ decoder_module = module._get_decoder_module()
475
+ return decoder_module(
476
+ input_ids=decoder_input_ids,
477
+ attention_mask=decoder_attention_mask,
478
+ position_ids=decoder_position_ids,
479
+ **kwargs,
480
+ )
481
+
482
+ init_variables = self.module.init(
483
+ jax.random.PRNGKey(0),
484
+ decoder_input_ids=decoder_input_ids,
485
+ decoder_attention_mask=decoder_attention_mask,
486
+ decoder_position_ids=decoder_position_ids,
487
+ encoder_hidden_states=encoder_outputs[0],
488
+ init_cache=True,
489
+ method=_decoder_forward, # we only need to call the decoder to init the cache
490
+ )
491
+ return unfreeze(init_variables["cache"])
492
+
493
+ def _get_feat_extract_output_lengths(
494
+ self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
495
+ ):
496
+ return self.module._get_feat_extract_output_lengths(input_lengths, add_adapter=add_adapter)
497
+
498
+ @add_start_docstrings(SPEECH_ENCODER_DECODER_ENCODE_INPUTS_DOCSTRING)
499
+ @replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=_CONFIG_FOR_DOC)
500
+ def encode(
501
+ self,
502
+ inputs: jnp.ndarray,
503
+ attention_mask: Optional[jnp.ndarray] = None,
504
+ extract_features: Optional[jnp.ndarray] = None,
505
+ output_attentions: Optional[bool] = None,
506
+ output_hidden_states: Optional[bool] = None,
507
+ output_features: Optional[bool] = None,
508
+ return_dict: Optional[bool] = None,
509
+ train: bool = False,
510
+ freeze_feature_encoder: bool = False,
511
+ params: dict = None,
512
+ dropout_rng: PRNGKey = None,
513
+ ):
514
+ r"""
515
+ Returns:
516
+
517
+ Example:
518
+
519
+ ```python
520
+ >>> from transformers import FlaxSpeechEncoderDecoderModel
521
+
522
+ >>> # initialize a wav2vec2-2-bart from pretrained wav2vec2 and bart models. Note that the cross-attention layers will be randomly initialized
523
+ >>> model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
524
+ ... "facebook/wav2vec2-large-lv60", "facebook/bart-large"
525
+ ... )
526
+
527
+ >>> inputs = jnp.ones((2, 5000), dtype=jnp.float32)
528
+ >>> encoder_outputs = model.encode(inputs)
529
+ ```"""
530
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
531
+ output_hidden_states = (
532
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
533
+ )
534
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
535
+
536
+ if attention_mask is None:
537
+ attention_mask = jnp.ones_like(inputs, dtype="i4")
538
+
539
+ if extract_features is not None:
540
+ extract_features = jnp.array(extract_features, dtype="f4")
541
+
542
+ # Handle any PRNG if needed
543
+ rngs = {}
544
+ if dropout_rng is not None:
545
+ rngs["dropout"] = dropout_rng
546
+
547
+ def _encoder_forward(module, inputs, attention_mask, **kwargs):
548
+ encode_module = module._get_encoder_module()
549
+ return encode_module(inputs, attention_mask, **kwargs)
550
+
551
+ outputs = self.module.apply(
552
+ {"params": params or self.params},
553
+ inputs=jnp.array(inputs, dtype="f4"),
554
+ attention_mask=jnp.array(attention_mask, dtype="i4"),
555
+ extract_features=extract_features,
556
+ output_attentions=output_attentions,
557
+ output_hidden_states=output_hidden_states,
558
+ output_features=output_features,
559
+ return_dict=return_dict,
560
+ deterministic=not train,
561
+ freeze_feature_encoder=freeze_feature_encoder,
562
+ rngs=rngs,
563
+ method=_encoder_forward,
564
+ )
565
+
566
+ if return_dict and not output_features:
567
+ outputs = FlaxBaseModelOutput(
568
+ last_hidden_state=outputs.last_hidden_state,
569
+ hidden_states=outputs.hidden_states,
570
+ attentions=outputs.attentions,
571
+ )
572
+
573
+ return outputs
574
+
575
+ @add_start_docstrings(SPEECH_ENCODER_DECODER_DECODE_INPUTS_DOCSTRING)
576
+ @replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
577
+ def decode(
578
+ self,
579
+ decoder_input_ids,
580
+ encoder_outputs,
581
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
582
+ decoder_attention_mask: Optional[jnp.ndarray] = None,
583
+ decoder_position_ids: Optional[jnp.ndarray] = None,
584
+ past_key_values: dict = None,
585
+ output_attentions: Optional[bool] = None,
586
+ output_hidden_states: Optional[bool] = None,
587
+ return_dict: Optional[bool] = None,
588
+ train: bool = False,
589
+ params: dict = None,
590
+ dropout_rng: PRNGKey = None,
591
+ ):
592
+ r"""
593
+ Returns:
594
+
595
+ Example:
596
+
597
+ ```python
598
+ >>> from transformers import FlaxSpeechEncoderDecoderModel
599
+ >>> import jax.numpy as jnp
600
+
601
+ >>> # initialize a wav2vec2-2-bart from pretrained wav2vec2 and bart models. Note that the cross-attention layers will be randomly initialized
602
+ >>> model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
603
+ ... "facebook/wav2vec2-large-lv60", "facebook/bart-large"
604
+ ... )
605
+
606
+ >>> inputs = jnp.ones((2, 5000), dtype=jnp.float32)
607
+ >>> encoder_outputs = model.encode(inputs)
608
+
609
+ >>> decoder_start_token_id = model.config.decoder.bos_token_id
610
+ >>> decoder_input_ids = jnp.ones((inputs.shape[0], 1), dtype="i4") * decoder_start_token_id
611
+
612
+ >>> outputs = model.decode(decoder_input_ids, encoder_outputs)
613
+ >>> logits = outputs.logits
614
+ ```"""
615
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
616
+ output_hidden_states = (
617
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
618
+ )
619
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
620
+
621
+ encoder_hidden_states = encoder_outputs[0]
622
+ if encoder_attention_mask is None:
623
+ batch_size, sequence_length = encoder_hidden_states.shape[:2]
624
+ encoder_attention_mask = jnp.ones((batch_size, sequence_length))
625
+
626
+ batch_size, sequence_length = decoder_input_ids.shape
627
+ if decoder_attention_mask is None:
628
+ decoder_attention_mask = jnp.ones((batch_size, sequence_length))
629
+
630
+ if decoder_position_ids is None:
631
+ if past_key_values is not None:
632
+ raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.")
633
+
634
+ decoder_position_ids = jnp.broadcast_to(
635
+ jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
636
+ )
637
+
638
+ # Handle any PRNG if needed
639
+ rngs = {}
640
+ if dropout_rng is not None:
641
+ rngs["dropout"] = dropout_rng
642
+
643
+ params = {"params": params or self.params}
644
+
645
+ # if past_key_values are passed then cache is already initialized a private flag init_cache has to be
646
+ # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that
647
+ # it can be changed by FlaxBartAttention module
648
+ if past_key_values:
649
+ params["cache"] = past_key_values
650
+ mutable = ["cache"]
651
+ else:
652
+ mutable = False
653
+
654
+ def _decoder_forward(
655
+ module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, encoder_hidden_states, **kwargs
656
+ ):
657
+
658
+ projection_module = module._get_projection_module()
659
+ decoder_module = module._get_decoder_module()
660
+
661
+ # optionally project encoder_hidden_states
662
+ if projection_module is not None:
663
+ encoder_hidden_states = projection_module(encoder_hidden_states)
664
+
665
+ return decoder_module(
666
+ decoder_input_ids,
667
+ decoder_attention_mask,
668
+ decoder_position_ids,
669
+ encoder_hidden_states,
670
+ **kwargs,
671
+ )
672
+
673
+ outputs = self.module.apply(
674
+ params,
675
+ decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
676
+ decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
677
+ decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
678
+ encoder_hidden_states=encoder_hidden_states,
679
+ encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
680
+ output_attentions=output_attentions,
681
+ output_hidden_states=output_hidden_states,
682
+ return_dict=return_dict,
683
+ deterministic=not train,
684
+ rngs=rngs,
685
+ mutable=mutable,
686
+ method=_decoder_forward,
687
+ )
688
+
689
+ # add updated cache to model output
690
+ if past_key_values is not None and return_dict:
691
+ outputs, past = outputs
692
+ outputs["past_key_values"] = unfreeze(past["cache"])
693
+ return outputs
694
+ elif past_key_values is not None and not return_dict:
695
+ outputs, past = outputs
696
+ outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
697
+
698
+ return outputs
699
+
700
+ @add_start_docstrings_to_model_forward(SPEECH_ENCODER_DECODER_INPUTS_DOCSTRING)
701
+ @replace_return_docstrings(output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
702
+ def __call__(
703
+ self,
704
+ inputs: jnp.ndarray,
705
+ attention_mask: Optional[jnp.ndarray] = None,
706
+ extract_features: Optional[jnp.ndarray] = None,
707
+ decoder_input_ids: Optional[jnp.ndarray] = None,
708
+ decoder_attention_mask: Optional[jnp.ndarray] = None,
709
+ decoder_position_ids: Optional[jnp.ndarray] = None,
710
+ output_attentions: Optional[bool] = None,
711
+ output_hidden_states: Optional[bool] = None,
712
+ output_features: Optional[bool] = None,
713
+ return_dict: Optional[bool] = None,
714
+ train: bool = False,
715
+ freeze_feature_encoder: bool = False,
716
+ params: dict = None,
717
+ dropout_rng: PRNGKey = None,
718
+ ):
719
+ r"""
720
+ Returns:
721
+
722
+ Examples:
723
+
724
+ ```python
725
+ >>> from transformers import FlaxSpeechEncoderDecoderModel, BartTokenizer
726
+
727
+ >>> # load a fine-tuned wav2vec2-2-bart model
728
+ >>> model = FlaxSpeechEncoderDecoderModel.from_pretrained("patrickvonplaten/wav2vec2-2-bart-large")
729
+ >>> # load output tokenizer
730
+ >>> tokenizer_output = BartTokenizer.from_pretrained("facebook/bart-large")
731
+
732
+ >>> inputs = jnp.ones((2, 5000), dtype=jnp.float32)
733
+
734
+ >>> # use bart's special bos, pad and eos tokens
735
+ >>> model.config.decoder_start_token_id = model.decoder.config.bos_token_id
736
+ >>> model.config.pad_token_id = model.decoder.config.pad_token_id
737
+ >>> model.config.eos_token_id = model.decoder.config.eos_token_id
738
+
739
+ >>> outputs = model.generate(inputs)
740
+ # Assert something? More interesting input? dtype correct?
741
+ ```
742
+ """
743
+
744
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
745
+ output_hidden_states = (
746
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
747
+ )
748
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
749
+
750
+ # prepare encoder inputs
751
+ if attention_mask is None:
752
+ attention_mask = jnp.ones_like(inputs, dtype="i4")
753
+
754
+ if extract_features is not None:
755
+ inputs = None # we can omit passing the inputs to the model to save memory
756
+ extract_features = jnp.array(extract_features, dtype="f4")
757
+ else:
758
+ inputs = jnp.array(inputs, dtype="f4")
759
+
760
+ # prepare decoder inputs
761
+ if decoder_input_ids is None:
762
+ raise ValueError(
763
+ "`decoder_input_ids` cannot be `None`. For sequence to sequence training, `decoder_position_ids` must be specified as an input argument."
764
+ )
765
+ if decoder_attention_mask is None:
766
+ decoder_attention_mask = jnp.ones_like(decoder_input_ids)
767
+ if decoder_position_ids is None:
768
+ batch_size, sequence_length = decoder_input_ids.shape
769
+ decoder_position_ids = jnp.broadcast_to(
770
+ jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
771
+ )
772
+
773
+ # Handle any PRNG if needed
774
+ rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
775
+
776
+ return self.module.apply(
777
+ {"params": params or self.params},
778
+ inputs=inputs,
779
+ attention_mask=jnp.array(attention_mask, dtype="i4"),
780
+ extract_features=extract_features,
781
+ decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
782
+ decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
783
+ decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
784
+ output_attentions=output_attentions,
785
+ output_hidden_states=output_hidden_states,
786
+ output_features=output_features,
787
+ return_dict=return_dict,
788
+ deterministic=not train,
789
+ freeze_feature_encoder=freeze_feature_encoder,
790
+ rngs=rngs,
791
+ )
792
+
793
+ def prepare_inputs_for_generation(
794
+ self,
795
+ decoder_input_ids,
796
+ max_length,
797
+ attention_mask: Optional[jnp.DeviceArray] = None,
798
+ decoder_attention_mask: Optional[jnp.DeviceArray] = None,
799
+ encoder_outputs=None,
800
+ **kwargs
801
+ ):
802
+ # initializing the cache
803
+ batch_size, seq_length = decoder_input_ids.shape
804
+
805
+ past_key_values = self.init_cache(batch_size, max_length, encoder_outputs)
806
+ # Note that usually one would have to put 0's in the attention_mask for x > input.shape[-1] and x < cache_length.
807
+ # But since the decoder uses a causal mask, those positions are masked anyways.
808
+ # Thus we can create a single static attention_mask here, which is more efficient for compilation
809
+ extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
810
+ if decoder_attention_mask is not None:
811
+ decoder_position_ids = decoder_attention_mask.cumsum(axis=-1) - 1
812
+ extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0))
813
+ else:
814
+ decoder_position_ids = jnp.broadcast_to(
815
+ jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)
816
+ )
817
+
818
+ return {
819
+ "past_key_values": past_key_values,
820
+ "encoder_outputs": encoder_outputs,
821
+ "encoder_attention_mask": attention_mask,
822
+ "decoder_attention_mask": extended_attention_mask,
823
+ "decoder_position_ids": decoder_position_ids,
824
+ }
825
+
826
+ def update_inputs_for_generation(self, model_outputs, model_kwargs):
827
+ model_kwargs["past_key_values"] = model_outputs.past_key_values
828
+ model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1
829
+ return model_kwargs
830
+
831
+ @classmethod
832
+ def from_encoder_decoder_pretrained(
833
+ cls,
834
+ encoder_pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
835
+ decoder_pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
836
+ *model_args,
837
+ **kwargs
838
+ ) -> FlaxPreTrainedModel:
839
+ r"""
840
+ Instantiate an encoder and a decoder from one or two base classes of the library from pretrained model
841
+ checkpoints.
842
+
843
+ Params:
844
+ encoder_pretrained_model_name_or_path (`Union[str, os.PathLike]`, *optional*):
845
+ Information necessary to initiate the encoder. Can be either:
846
+
847
+ - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
848
+ Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a
849
+ user or organization name, like `dbmdz/bert-base-german-cased`.
850
+ - A path to a *directory* containing model weights saved using
851
+ [`~FlaxPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
852
+
853
+ decoder_pretrained_model_name_or_path (`Union[str, os.PathLike]`, *optional*, defaults to `None`):
854
+ Information necessary to initiate the decoder. Can be either:
855
+
856
+ - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
857
+ Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a
858
+ user or organization name, like `dbmdz/bert-base-german-cased`.
859
+ - A path to a *directory* containing model weights saved using
860
+ [`~FlaxPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
861
+
862
+ model_args (remaining positional arguments, *optional*):
863
+ All remaning positional arguments will be passed to the underlying model's `__init__` method.
864
+
865
+ kwargs (remaining dictionary of keyword arguments, *optional*):
866
+ Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
867
+ `output_attentions=True`).
868
+
869
+ - To update the encoder configuration, use the prefix *encoder_* for each configuration parameter.
870
+ - To update the decoder configuration, use the prefix *decoder_* for each configuration parameter.
871
+ - To update the parent model configuration, do not use a prefix for each configuration parameter.
872
+
873
+ Behaves differently depending on whether a `config` is provided or automatically loaded.
874
+
875
+ Example:
876
+
877
+ ```python
878
+ >>> from transformers import FlaxSpeechEncoderDecoderModel
879
+
880
+ >>> # initialize a wav2vec2-2-bart from pretrained wav2vec2 and bart models. Note that the cross-attention layers will be randomly initialized
881
+ >>> model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
882
+ ... "facebook/wav2vec2-large-lv60", "facebook/bart-large"
883
+ ... )
884
+ >>> # saving model after fine-tuning
885
+ >>> model.save_pretrained("./wav2vec2-2-bart-large")
886
+ >>> # load fine-tuned model
887
+ >>> model = FlaxSpeechEncoderDecoderModel.from_pretrained("./wav2vec2-2-bart-large")
888
+ ```"""
889
+
890
+ kwargs_encoder = {
891
+ argument[len("encoder_") :]: value for argument, value in kwargs.items() if argument.startswith("encoder_")
892
+ }
893
+
894
+ kwargs_decoder = {
895
+ argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
896
+ }
897
+
898
+ # remove encoder, decoder kwargs from kwargs
899
+ for key in kwargs_encoder.keys():
900
+ del kwargs["encoder_" + key]
901
+ for key in kwargs_decoder.keys():
902
+ del kwargs["decoder_" + key]
903
+
904
+ # Load and initialize the encoder and decoder
905
+ # The distinction between encoder and decoder at the model level is made
906
+ # by the value of the flag `is_decoder` that we need to set correctly.
907
+ encoder = kwargs_encoder.pop("model", None)
908
+ if encoder is None:
909
+ if encoder_pretrained_model_name_or_path is None:
910
+ raise ValueError(
911
+ "If `encoder_model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has "
912
+ "to be defined."
913
+ )
914
+
915
+ if "config" not in kwargs_encoder:
916
+ # TODO: AutoConfig .from_pretrained
917
+ encoder_config, kwargs_encoder = Wav2Vec2Config.from_pretrained(
918
+ encoder_pretrained_model_name_or_path, **kwargs_encoder, return_unused_kwargs=True
919
+ )
920
+ if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
921
+ logger.info(
922
+ f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model "
923
+ "from a decoder model. Cross-attention and casual mask are disabled."
924
+ )
925
+ encoder_config.is_decoder = False
926
+ encoder_config.add_cross_attention = False
927
+
928
+ kwargs_encoder["config"] = encoder_config
929
+
930
+ # TODO: FlaxAutoModel .from_pretrained
931
+ encoder = FlaxWav2Vec2Model.from_pretrained(
932
+ encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder
933
+ )
934
+
935
+ decoder = kwargs_decoder.pop("model", None)
936
+ if decoder is None:
937
+ if decoder_pretrained_model_name_or_path is None:
938
+ raise ValueError(
939
+ "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has "
940
+ "to be defined."
941
+ )
942
+
943
+ if "config" not in kwargs_decoder:
944
+ # TODO: AutoConfig .from_pretrained
945
+ decoder_config, kwargs_decoder = BartConfig.from_pretrained(
946
+ decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True
947
+ )
948
+ if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
949
+ logger.info(
950
+ f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. "
951
+ f"Cross attention layers are added to {decoder_pretrained_model_name_or_path} "
952
+ f"and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for "
953
+ "cross attention layers."
954
+ )
955
+ decoder_config.is_decoder = True
956
+ decoder_config.add_cross_attention = True
957
+
958
+ kwargs_decoder["config"] = decoder_config
959
+
960
+ if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False:
961
+ logger.warning(
962
+ f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. "
963
+ f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, "
964
+ "make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` "
965
+ "passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a "
966
+ "`decoder_config` to `.from_encoder_decoder_pretrained(...)`"
967
+ )
968
+
969
+ # TODO: FlaxAutoModelForCausalLM .from_pretrained
970
+ decoder = FlaxBartForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
971
+
972
+ # instantiate config with corresponding kwargs
973
+ dtype = kwargs.pop("dtype", jnp.float32)
974
+ config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs)
975
+
976
+ # make sure input & output word embeddings are not tied
977
+ config.tie_word_embeddings = False
978
+
979
+ # init model
980
+ model = cls(config, dtype=dtype)
981
+ model.params["encoder"] = encoder.params
982
+ model.params["decoder"] = decoder.params
983
+
984
+ return model
985
+
986
+ def _beam_search(
987
+ self,
988
+ input_ids: None,
989
+ max_length: Optional[int] = None,
990
+ pad_token_id: Optional[int] = None,
991
+ eos_token_id: Optional[int] = None,
992
+ length_penalty: Optional[float] = None,
993
+ early_stopping: Optional[bool] = None,
994
+ logits_processor: Optional[FlaxLogitsProcessorList] = None,
995
+ trace: bool = True,
996
+ params: Optional[Dict[str, jnp.ndarray]] = None,
997
+ model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
998
+ ):
999
+ """
1000
+ This beam search function is heavily inspired by Flax's official example:
1001
+ https://github.com/google/flax/blob/master/examples/wmt/train.py#L254
1002
+ """
1003
+
1004
+ def flatten_beam_dim(tensor):
1005
+ """Flattens the first two dimensions of a non-scalar array."""
1006
+ # ignore scalars (e.g. cache index)
1007
+ if tensor.ndim == 0 or tensor.ndim == 1:
1008
+ return tensor
1009
+ elif tensor.ndim == 6:
1010
+ return tensor.reshape(tensor.shape[:1] + (tensor.shape[1] * tensor.shape[2],) + tensor.shape[3:])
1011
+ return tensor.reshape((tensor.shape[0] * tensor.shape[1],) + tensor.shape[2:])
1012
+
1013
+ def unflatten_beam_dim(tensor, batch_size, num_beams):
1014
+ """Unflattens the first, flat batch*beam dimension of a non-scalar array."""
1015
+ # ignore scalars (e.g. cache index)
1016
+ if tensor.ndim == 0 or tensor.ndim == 1:
1017
+ return tensor
1018
+ if tensor.ndim == 5:
1019
+ return tensor.reshape(tensor.shape[:1] + (batch_size, num_beams) + tensor.shape[2:])
1020
+ return tensor.reshape((batch_size, num_beams) + tensor.shape[1:])
1021
+
1022
+ def gather_beams(nested, beam_indices, batch_size, new_num_beams):
1023
+ """
1024
+ Gathers the beam slices indexed by beam_indices into new beam array.
1025
+ """
1026
+ batch_indices = jnp.reshape(
1027
+ jnp.arange(batch_size * new_num_beams) // new_num_beams, (batch_size, new_num_beams)
1028
+ )
1029
+
1030
+ def gather_fn(tensor):
1031
+ # ignore scalars (e.g. cache index)
1032
+ if tensor.ndim == 0 or tensor.ndim == 1:
1033
+ return tensor
1034
+ if tensor.ndim == 6:
1035
+ return tensor[:, batch_indices, beam_indices]
1036
+ return tensor[batch_indices, beam_indices]
1037
+
1038
+ return jax.tree_map(gather_fn, nested)
1039
+
1040
+ # init values
1041
+ max_length = max_length if max_length is not None else self.config.max_length
1042
+ pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
1043
+ eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
1044
+ length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
1045
+ early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
1046
+
1047
+ batch_size, num_beams, cur_len = input_ids.shape
1048
+
1049
+ eos_token_id = jnp.array(eos_token_id)
1050
+ pad_token_id = jnp.array(pad_token_id)
1051
+ cur_len = jnp.array(cur_len)
1052
+
1053
+ # per batch,beam-item holding current token in loop.
1054
+ sequences = jnp.full((batch_size, num_beams, max_length), pad_token_id, dtype=jnp.int32)
1055
+ running_sequences = jnp.full((batch_size, num_beams, max_length), pad_token_id, dtype=jnp.int32)
1056
+ running_sequences = lax.dynamic_update_slice(sequences, input_ids, (0, 0, 0))
1057
+
1058
+ # per batch,beam-item state bit indicating if sentence has finished.
1059
+ is_sent_finished = jnp.zeros((batch_size, num_beams), dtype=jnp.bool_)
1060
+
1061
+ # per batch,beam-item score, logprobs
1062
+ running_scores = jnp.tile(jnp.array([0.0] + [np.array(-1.0e7)] * (num_beams - 1)), [batch_size, 1])
1063
+ scores = jnp.ones((batch_size, num_beams)) * np.array(-1.0e7)
1064
+
1065
+ # For Seq2Seq generation, we only need to use the decoder instead of the whole model in generation loop
1066
+ # and pass it the `encoder_outputs`, which are part of the `model_kwargs`.
1067
+ model = self.decode if self.config.is_encoder_decoder else self
1068
+
1069
+ # flatten beam dim
1070
+ if "encoder_outputs" in model_kwargs:
1071
+ model_kwargs["encoder_outputs"]["last_hidden_state"] = flatten_beam_dim(
1072
+ model_kwargs["encoder_outputs"]["last_hidden_state"]
1073
+ )
1074
+ if "attention_mask" in model_kwargs:
1075
+ model_kwargs["attention_mask"] = flatten_beam_dim(model_kwargs["attention_mask"])
1076
+
1077
+ # initialize model specific kwargs
1078
+ model_kwargs = self.prepare_inputs_for_generation(flatten_beam_dim(input_ids), max_length, **model_kwargs)
1079
+
1080
+ # initialize state
1081
+ state = BeamSearchState(
1082
+ cur_len=cur_len,
1083
+ running_sequences=running_sequences,
1084
+ running_scores=running_scores,
1085
+ sequences=sequences,
1086
+ scores=scores,
1087
+ is_sent_finished=is_sent_finished,
1088
+ model_kwargs=model_kwargs,
1089
+ )
1090
+
1091
+ def beam_search_cond_fn(state):
1092
+ """beam search state termination condition fn."""
1093
+
1094
+ # 1. is less than max length?
1095
+ not_max_length_yet = state.cur_len < max_length
1096
+
1097
+ # 2. can the new beams still improve?
1098
+ best_running_score = state.running_scores[:, -1:] / (max_length**length_penalty)
1099
+ worst_finished_score = jnp.where(
1100
+ state.is_sent_finished, jnp.min(state.scores, axis=1, keepdims=True), np.array(-1.0e7)
1101
+ )
1102
+ improvement_still_possible = jnp.all(worst_finished_score < best_running_score)
1103
+
1104
+ # 3. is there still a beam that has not finished?
1105
+ still_open_beam = ~(jnp.all(state.is_sent_finished) & early_stopping)
1106
+
1107
+ return not_max_length_yet & still_open_beam & improvement_still_possible
1108
+
1109
+ def beam_search_body_fn(state, input_ids_length=1):
1110
+ """beam search state update fn."""
1111
+ # 1. Forward current tokens
1112
+ # Collect the current position slice along length to feed the fast
1113
+ # autoregressive decoder model. Flatten the beam dimension into batch
1114
+ # dimension for feeding into the model.
1115
+ # unflatten beam dimension
1116
+ # Unflatten beam dimension in attention cache arrays
1117
+ input_token = flatten_beam_dim(
1118
+ lax.dynamic_slice(
1119
+ state.running_sequences,
1120
+ (0, 0, state.cur_len - input_ids_length),
1121
+ (batch_size, num_beams, input_ids_length),
1122
+ )
1123
+ )
1124
+ model_outputs = model(input_token, params=params, **state.model_kwargs)
1125
+
1126
+ logits = unflatten_beam_dim(model_outputs.logits[:, -1], batch_size, num_beams)
1127
+ cache = jax.tree_map(
1128
+ lambda tensor: unflatten_beam_dim(tensor, batch_size, num_beams), model_outputs.past_key_values
1129
+ )
1130
+
1131
+ # adapt logits for FlaxMarianMTModel
1132
+ logits = self._adapt_logits_for_beam_search(logits)
1133
+
1134
+ # 2. Compute log probs
1135
+ # get log probabilities from logits,
1136
+ # process logits with processors (*e.g.* min_length, ...), and
1137
+ # add new logprobs to existing running logprobs scores.
1138
+ log_probs = jax.nn.log_softmax(logits)
1139
+ log_probs = logits_processor(
1140
+ flatten_beam_dim(running_sequences), flatten_beam_dim(log_probs), state.cur_len
1141
+ )
1142
+ log_probs = unflatten_beam_dim(log_probs, batch_size, num_beams)
1143
+ log_probs = log_probs + jnp.expand_dims(state.running_scores, axis=2)
1144
+ vocab_size = log_probs.shape[2]
1145
+ log_probs = log_probs.reshape((batch_size, num_beams * vocab_size))
1146
+
1147
+ # 3. Retrieve top-K
1148
+ # Each item in batch has num_beams * vocab_size candidate sequences.
1149
+ # For each item, get the top 2*k candidates with the highest log-
1150
+ # probabilities. We gather the top 2*K beams here so that even if the best
1151
+ # K sequences reach EOS simultaneously, we have another K sequences
1152
+ # remaining to continue the live beam search.
1153
+ # Gather the top 2*K scores from _all_ beams.
1154
+ # Gather 2*k top beams.
1155
+ # Recover the beam index by floor division.
1156
+ # Recover token id by modulo division and expand Id array for broadcasting.
1157
+ # Update sequences for the 2*K top-k new sequences.
1158
+ beams_to_keep = 2 * num_beams
1159
+ topk_log_probs, topk_indices = lax.top_k(log_probs, k=beams_to_keep)
1160
+ topk_beam_indices = topk_indices // vocab_size
1161
+ topk_running_sequences = gather_beams(
1162
+ state.running_sequences, topk_beam_indices, batch_size, beams_to_keep
1163
+ )
1164
+ topk_ids = jnp.expand_dims(topk_indices % vocab_size, axis=2)
1165
+ topk_sequences = lax.dynamic_update_slice(topk_running_sequences, topk_ids, (0, 0, state.cur_len))
1166
+
1167
+ # 4. Check which sequences have ended
1168
+ # Update current sequences:
1169
+ # Did any of these sequences reach an end marker?
1170
+ # To prevent these just finished sequences from being added to the current sequences
1171
+ # set of active beam search sequences, set their log probs to a very large
1172
+ # negative value.
1173
+ did_topk_just_finished = topk_sequences[:, :, state.cur_len] == eos_token_id
1174
+ running_topk_log_probs = topk_log_probs + did_topk_just_finished * np.array(-1.0e7)
1175
+ # 5. Get running sequences scores for next
1176
+ # Determine the top k beam indices (from top 2*k beams) from log probs
1177
+ # and gather top k beams (from top 2*k beams).
1178
+ next_topk_indices = jnp.flip(lax.top_k(running_topk_log_probs, k=num_beams)[1], axis=1)
1179
+ next_running_sequences, next_running_scores = gather_beams(
1180
+ [topk_sequences, running_topk_log_probs], next_topk_indices, batch_size, num_beams
1181
+ )
1182
+
1183
+ # 6. Process topk logits
1184
+ # Further process log probs:
1185
+ # - add length penalty
1186
+ # - make sure no scores can be added anymore if beam is full
1187
+ # - make sure still running sequences cannot be chosen as finalized beam
1188
+ topk_log_probs = topk_log_probs / (state.cur_len**length_penalty)
1189
+ beams_in_batch_are_full = (
1190
+ jnp.broadcast_to(state.is_sent_finished.all(axis=-1, keepdims=True), did_topk_just_finished.shape)
1191
+ & early_stopping
1192
+ )
1193
+ add_penalty = ~did_topk_just_finished | beams_in_batch_are_full
1194
+ topk_log_probs += add_penalty * np.array(-1.0e7)
1195
+
1196
+ # 7. Get scores, sequences, is sentence finished for next.
1197
+ # Combine sequences, scores, and flags along the beam dimension and compare
1198
+ # new finished sequence scores to existing finished scores and select the
1199
+ # best from the new set of beams
1200
+ merged_sequences = jnp.concatenate([state.sequences, topk_sequences], axis=1)
1201
+ merged_scores = jnp.concatenate([state.scores, topk_log_probs], axis=1)
1202
+ merged_is_sent_finished = jnp.concatenate([state.is_sent_finished, did_topk_just_finished], axis=1)
1203
+ topk_merged_indices = jnp.flip(lax.top_k(merged_scores, k=num_beams)[1], axis=1)
1204
+ next_sequences, next_scores, next_is_sent_finished = gather_beams(
1205
+ [merged_sequences, merged_scores, merged_is_sent_finished], topk_merged_indices, batch_size, num_beams
1206
+ )
1207
+
1208
+ # 8. Update model kwargs.
1209
+ # Determine the top k beam indices from the original set of all beams.
1210
+ # With these, gather the top k beam-associated caches.
1211
+ next_running_indices = gather_beams(topk_beam_indices, next_topk_indices, batch_size, num_beams)
1212
+ next_cache = gather_beams(cache, next_running_indices, batch_size, num_beams)
1213
+ model_outputs["past_key_values"] = jax.tree_map(lambda x: flatten_beam_dim(x), next_cache)
1214
+ next_model_kwargs = self.update_inputs_for_generation(model_outputs, state.model_kwargs)
1215
+
1216
+ return BeamSearchState(
1217
+ cur_len=state.cur_len + 1,
1218
+ running_scores=next_running_scores,
1219
+ running_sequences=next_running_sequences,
1220
+ scores=next_scores,
1221
+ sequences=next_sequences,
1222
+ is_sent_finished=next_is_sent_finished,
1223
+ model_kwargs=next_model_kwargs,
1224
+ )
1225
+
1226
+ # The very first prompt often has sequence length > 1, so run outside of `lax.while_loop` to comply with TPU
1227
+ if input_ids.shape[-1] > 1:
1228
+ state = partial(beam_search_body_fn, input_ids_length=input_ids.shape[-1])(state)
1229
+
1230
+ if not trace:
1231
+ state = self._run_loop_in_debug(beam_search_cond_fn, beam_search_body_fn, state)
1232
+ else:
1233
+ state = lax.while_loop(beam_search_cond_fn, beam_search_body_fn, state)
1234
+
1235
+ # Account for the edge-case where there are no finished sequences for a
1236
+ # particular batch item. If so, return running sequences for that batch item.
1237
+ none_finished = jnp.any(state.is_sent_finished, axis=1)
1238
+ sequences = jnp.where(none_finished[:, None, None], state.sequences, state.running_sequences)
1239
+ scores = jnp.where(none_finished[:, None], state.scores, state.running_scores)
1240
+
1241
+ # return all beams for each batch and the best score
1242
+ sequences = sequences[:, :]
1243
+ scores = scores[:, -1]
1244
+
1245
+ return FlaxBeamSearchOutput(sequences=sequences, scores=scores)
models/modeling_flax_wav2vec2.py ADDED
@@ -0,0 +1,975 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2021 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Flax Wav2Vec2 model."""
16
+
17
+ from functools import partial
18
+ from typing import Optional, Tuple, Union
19
+
20
+ import flax
21
+ import flax.linen as nn
22
+ import jax
23
+ import jax.numpy as jnp
24
+ from flax.core.frozen_dict import FrozenDict
25
+ from flax.linen import partitioning as nn_partitioning
26
+ from flax.linen.attention import dot_product_attention_weights
27
+ from jax import lax
28
+
29
+ from transformers.modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput
30
+ from transformers.modeling_flax_utils import ACT2FN, FlaxPreTrainedModel
31
+ from transformers.utils import ModelOutput
32
+
33
+ from models import Wav2Vec2Config
34
+
35
+ scan_with_axes = nn_partitioning.scan_with_axes
36
+ remat = nn_partitioning.remat
37
+
38
+
39
+ @flax.struct.dataclass
40
+ class FlaxWav2Vec2BaseModelOutput(ModelOutput):
41
+ """
42
+ Output type of [`FlaxWav2Vec2BaseModelOutput`], with potential hidden states and attentions.
43
+
44
+ Args:
45
+ last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`):
46
+ Sequence of hidden-states at the output of the last layer of the model.
47
+ extract_features (`jnp.ndarray` of shape `(batch_size, sequence_length, last_conv_dim)`):
48
+ Sequence of extracted feature vectors of the last convolutional layer of the model with `last_conv_dim`
49
+ being the dimension of the last convolutional layer.
50
+ hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
51
+ Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
52
+ `(batch_size, sequence_length, hidden_size)`.
53
+
54
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
55
+ attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
56
+ Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
57
+ sequence_length)`.
58
+
59
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
60
+ heads.
61
+ """
62
+
63
+ last_hidden_state: jnp.ndarray = None
64
+ extract_features: jnp.ndarray = None
65
+ hidden_states: Optional[Tuple[jnp.ndarray]] = None
66
+ attentions: Optional[Tuple[jnp.ndarray]] = None
67
+
68
+
69
+ WAV_2_VEC_2_START_DOCSTRING = r"""
70
+ Wav2Vec2 was proposed in [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech
71
+ Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael
72
+ Auli.
73
+
74
+ This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
75
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
76
+ etc.)
77
+
78
+ This model is also a Flax Linen
79
+ [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
80
+ regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior.
81
+
82
+ Finally, this model supports inherent JAX features such as:
83
+
84
+ - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
85
+ - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
86
+ - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
87
+ - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
88
+
89
+ Parameters:
90
+ config ([`Wav2Vec2Config`]): Model configuration class with all the parameters of the model.
91
+ Initializing with a config file does not load the weights associated with the model, only the
92
+ configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
93
+ dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
94
+ The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
95
+ `jax.numpy.bfloat16` (on TPUs).
96
+
97
+ This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
98
+ specified all the computation will be performed with the given `dtype`.
99
+
100
+ **Note that this only specifies the dtype of the computation and does not influence the dtype of model
101
+ parameters.**
102
+
103
+ If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
104
+ [`~FlaxPreTrainedModel.to_bf16`].
105
+ """
106
+
107
+
108
+ WAV_2_VEC_2_INPUTS_DOCSTRING = r"""
109
+ Args:
110
+ input_values (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
111
+ Float values of input raw speech waveform. Values can be obtained by loading a *.flac* or *.wav* audio file
112
+ into an array of type *List[float]* or a *numpy.ndarray*, *e.g.* via the soundfile library (*pip install
113
+ soundfile*). To prepare the array into *input_values*, the [`Wav2Vec2Processor`] should be used for padding
114
+ and conversion into a tensor of type *jnp.ndarray*. See [`Wav2Vec2Processor.__call__`] for details.
115
+ attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
116
+ Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0,
117
+ 1]`:
118
+
119
+ - 1 for tokens that are **not masked**,
120
+ - 0 for tokens that are **masked**.
121
+
122
+ [What are attention masks?](../glossary#attention-mask) .. warning:: `attention_mask` should only be passed
123
+ if the corresponding processor has `config.return_attention_mask == True`. For all models whose processor
124
+ has `config.return_attention_mask == False`, such as
125
+ [wav2vec2-base](https://huggingface.co/facebook/wav2vec2-base-960h), `attention_mask` should **not** be
126
+ passed to avoid degraded performance when doing batched inference. For such models `input_values` should
127
+ simply be padded with 0 and passed without `attention_mask`. Be aware that these models also yield slightly
128
+ different results depending on whether `input_values` is padded or not.
129
+ mask_time_indices (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
130
+ Indices to mask extracted features for contrastive loss. When in training mode, model learns to predict
131
+ masked extracted features in *config.proj_codevector_dim* space.
132
+ output_attentions (`bool`, *optional*):
133
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
134
+ tensors for more detail.
135
+ output_hidden_states (`bool`, *optional*):
136
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
137
+ more detail.
138
+ return_dict (`bool`, *optional*):
139
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
140
+ """
141
+
142
+
143
+ class FlaxWav2Vec2LayerNormConvLayer(nn.Module):
144
+ config: Wav2Vec2Config
145
+ layer_id: int = 0
146
+ dtype: jnp.dtype = jnp.float32
147
+
148
+ def setup(self):
149
+ self.in_conv_dim = self.config.conv_dim[self.layer_id] if self.layer_id > 0 else 1
150
+ self.out_conv_dim = self.config.conv_dim[self.layer_id]
151
+
152
+ self.conv = nn.Conv(
153
+ features=self.config.conv_dim[self.layer_id],
154
+ kernel_size=(self.config.conv_kernel[self.layer_id],),
155
+ strides=(self.config.conv_stride[self.layer_id],),
156
+ use_bias=self.config.conv_bias,
157
+ kernel_init=jax.nn.initializers.he_normal(),
158
+ padding="VALID",
159
+ dtype=self.dtype,
160
+ )
161
+ self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
162
+ self.activation = ACT2FN[self.config.feat_extract_activation]
163
+
164
+ def __call__(self, hidden_states):
165
+ hidden_states = self.conv(hidden_states)
166
+ hidden_states = self.layer_norm(hidden_states)
167
+ hidden_states = self.activation(hidden_states)
168
+ return hidden_states
169
+
170
+
171
+ class FlaxConvWithWeightNorm(nn.Module):
172
+ config: Wav2Vec2Config
173
+ dtype: jnp.dtype = jnp.float32
174
+
175
+ def setup(self):
176
+ self.conv = nn.Conv(
177
+ features=self.config.hidden_size,
178
+ kernel_size=(self.config.num_conv_pos_embeddings,),
179
+ kernel_init=jax.nn.initializers.he_normal(),
180
+ padding="VALID",
181
+ feature_group_count=self.config.num_conv_pos_embedding_groups,
182
+ dtype=self.dtype,
183
+ )
184
+ weight_shape = (
185
+ self.conv.features,
186
+ self.conv.features // self.conv.feature_group_count,
187
+ self.conv.kernel_size[0],
188
+ )
189
+ self.weight_v = self.param("weight_v", jax.nn.initializers.he_normal(), weight_shape)
190
+ self.weight_g = self.param("weight_g", lambda _: jnp.linalg.norm(self.weight_v, axis=(0, 1))[None, None, :])
191
+ self.bias = self.param("bias", jax.nn.initializers.zeros, (self.conv.features,))
192
+ self.prev_padding = self.conv.kernel_size[0] // 2
193
+
194
+ def _get_normed_weights(self):
195
+ weight_v_norm = jnp.linalg.norm(self.weight_v, axis=(0, 1))[None, None, :]
196
+ normed_weight_v = jnp.divide(self.weight_v, weight_v_norm)
197
+ normed_kernel = jnp.multiply(normed_weight_v, self.weight_g)
198
+ return normed_kernel
199
+
200
+ def __call__(self, hidden_states):
201
+ kernel = self._get_normed_weights()
202
+ hidden_states = jnp.pad(hidden_states, ((0, 0), (self.prev_padding, self.prev_padding), (0, 0)))
203
+ hidden_states = self.conv.apply({"params": {"kernel": kernel.T, "bias": self.bias}}, hidden_states)
204
+ return hidden_states
205
+
206
+
207
+ class FlaxWav2Vec2PositionalConvEmbedding(nn.Module):
208
+ config: Wav2Vec2Config
209
+ dtype: jnp.dtype = jnp.float32
210
+
211
+ def setup(self):
212
+ self.conv = FlaxConvWithWeightNorm(self.config, dtype=self.dtype)
213
+ self.activation = ACT2FN[self.config.feat_extract_activation]
214
+ self.num_pad_remove = 1 if self.config.num_conv_pos_embeddings % 2 == 0 else 0
215
+
216
+ def __call__(self, hidden_states):
217
+ hidden_states = hidden_states.transpose((0, 1, 2))
218
+
219
+ hidden_states = self.conv(hidden_states)
220
+
221
+ if self.num_pad_remove > 0:
222
+ hidden_states = hidden_states[:, : -self.num_pad_remove, :]
223
+ hidden_states = self.activation(hidden_states)
224
+
225
+ hidden_states = hidden_states.transpose((0, 1, 2))
226
+ return hidden_states
227
+
228
+
229
+ class FlaxConvLayersCollection(nn.Module):
230
+ config: Wav2Vec2Config
231
+ dtype: jnp.dtype = jnp.float32
232
+
233
+ def setup(self):
234
+ if self.config.feat_extract_norm == "layer":
235
+ # note that we can't use scan on the conv layers as they differ on a layer-by-layer basis
236
+ BlockLayer = remat(FlaxWav2Vec2LayerNormConvLayer) if self.config.gradient_checkpointing else FlaxWav2Vec2LayerNormConvLayer
237
+ self.layers = [
238
+ BlockLayer(self.config, layer_id=i, name=str(i), dtype=self.dtype)
239
+ for i in range(self.config.num_feat_extract_layers)
240
+ ]
241
+ elif self.config.feat_extract_norm == "group":
242
+ raise NotImplementedError("At the moment only ``config.feat_extact_norm == 'layer'`` is supported")
243
+ else:
244
+ raise ValueError(
245
+ f"`config.feat_extract_norm` is {self.config.feat_extract_norm}, but has to be one of ['group', 'layer']"
246
+ )
247
+
248
+ def __call__(self, hidden_states):
249
+ for i, conv_layer in enumerate(self.layers):
250
+ hidden_states = conv_layer(hidden_states)
251
+ return hidden_states
252
+
253
+
254
+ class FlaxWav2Vec2FeatureEncoder(nn.Module):
255
+ """Construct the features from raw audio waveform"""
256
+
257
+ config: Wav2Vec2Config
258
+ dtype: jnp.dtype = jnp.float32
259
+
260
+ def setup(self):
261
+ self.conv_layers = FlaxConvLayersCollection(self.config, dtype=self.dtype)
262
+
263
+ def __call__(self, input_values, freeze_feature_encoder=False):
264
+ hidden_states = input_values[:, :, None]
265
+ hidden_states = self.conv_layers(hidden_states)
266
+ if freeze_feature_encoder:
267
+ hidden_states = jax.lax.stop_gradient(hidden_states)
268
+ return hidden_states
269
+
270
+
271
+ class FlaxWav2Vec2FeatureProjection(nn.Module):
272
+ config: Wav2Vec2Config
273
+ dtype: jnp.dtype = jnp.float32
274
+
275
+ def setup(self):
276
+ self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
277
+ self.projection = nn.Dense(
278
+ self.config.hidden_size,
279
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
280
+ dtype=self.dtype,
281
+ )
282
+ self.dropout = nn.Dropout(rate=self.config.feat_proj_dropout)
283
+
284
+ def __call__(self, hidden_states, deterministic=True):
285
+ norm_hidden_states = self.layer_norm(hidden_states)
286
+ hidden_states = self.projection(norm_hidden_states)
287
+ hidden_states = self.dropout(hidden_states, deterministic=deterministic)
288
+ return hidden_states, norm_hidden_states
289
+
290
+
291
+ class FlaxWav2Vec2Attention(nn.Module):
292
+ config: Wav2Vec2Config
293
+ embed_dim: int
294
+ num_heads: int
295
+ dropout: float = 0.0
296
+ bias: bool = True
297
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
298
+
299
+ def setup(self) -> None:
300
+ self.head_dim = self.embed_dim // self.num_heads
301
+ if self.head_dim * self.num_heads != self.embed_dim:
302
+ raise ValueError(
303
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
304
+ )
305
+
306
+ dense = partial(
307
+ nn.Dense,
308
+ self.embed_dim,
309
+ use_bias=self.bias,
310
+ dtype=self.dtype,
311
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
312
+ )
313
+
314
+ self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
315
+
316
+ self.fused_proj = nn.Dense(
317
+ self.embed_dim * 3,
318
+ use_bias=self.bias,
319
+ dtype=self.dtype,
320
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
321
+ )
322
+
323
+ self.out_proj = dense()
324
+
325
+ self.dropout_layer = nn.Dropout(rate=self.dropout)
326
+
327
+ def _split_heads(self, hidden_states):
328
+ return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))
329
+
330
+ def _merge_heads(self, hidden_states):
331
+ return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))
332
+
333
+ def __call__(
334
+ self,
335
+ hidden_states: jnp.ndarray,
336
+ key_value_states: Optional[jnp.ndarray] = None,
337
+ attention_mask: Optional[jnp.ndarray] = None,
338
+ deterministic: bool = True,
339
+ ) -> Tuple[jnp.ndarray]:
340
+ """Input shape: Batch x Time x Channel"""
341
+
342
+ if self.config.fuse_matmuls:
343
+ attention_states = self.fused_proj(hidden_states)
344
+ query_states, key_states, value_states = jnp.split(attention_states, 3, axis=-1)
345
+
346
+ else:
347
+ # get query proj
348
+ query_states = self.q_proj(hidden_states)
349
+
350
+ key_states = self.k_proj(hidden_states)
351
+ value_states = self.v_proj(hidden_states)
352
+
353
+ query_states = self._split_heads(query_states)
354
+ key_states = self._split_heads(key_states)
355
+ value_states = self._split_heads(value_states)
356
+
357
+ if attention_mask is not None:
358
+ attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
359
+
360
+ # Convert the boolean attention mask to an attention bias.
361
+ if attention_mask is not None:
362
+ # attention mask in the form of attention bias
363
+ attention_bias = lax.select(
364
+ attention_mask > 0,
365
+ jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
366
+ jnp.full(attention_mask.shape, float("-inf")).astype(self.dtype),
367
+ )
368
+ else:
369
+ attention_bias = None
370
+
371
+ dropout_rng = None
372
+ if not deterministic and self.dropout > 0.0:
373
+ dropout_rng = self.make_rng("dropout")
374
+
375
+ attn_weights = dot_product_attention_weights(
376
+ query_states,
377
+ key_states,
378
+ bias=attention_bias,
379
+ dropout_rng=dropout_rng,
380
+ dropout_rate=self.dropout,
381
+ broadcast_dropout=True,
382
+ deterministic=deterministic,
383
+ dtype=self.dtype,
384
+ precision=None,
385
+ )
386
+
387
+ attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
388
+ attn_output = self._merge_heads(attn_output)
389
+ attn_output = self.out_proj(attn_output)
390
+
391
+ return attn_output, attn_weights
392
+
393
+
394
+ class FlaxWav2Vec2FeedForward(nn.Module):
395
+ config: Wav2Vec2Config
396
+ dtype: jnp.dtype = jnp.float32
397
+
398
+ def setup(self):
399
+ self.intermediate_dropout = nn.Dropout(rate=self.config.activation_dropout)
400
+
401
+ self.intermediate_dense = nn.Dense(
402
+ self.config.intermediate_size,
403
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
404
+ dtype=self.dtype,
405
+ )
406
+ if isinstance(self.config.hidden_act, str):
407
+ self.intermediate_act_fn = ACT2FN[self.config.hidden_act]
408
+ else:
409
+ self.intermediate_act_fn = self.config.hidden_act
410
+
411
+ self.output_dense = nn.Dense(
412
+ self.config.hidden_size,
413
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
414
+ dtype=self.dtype,
415
+ )
416
+ self.output_dropout = nn.Dropout(rate=self.config.hidden_dropout)
417
+
418
+ def __call__(self, hidden_states, deterministic=True):
419
+ hidden_states = self.intermediate_dense(hidden_states)
420
+ hidden_states = self.intermediate_act_fn(hidden_states)
421
+ hidden_states = self.intermediate_dropout(hidden_states, deterministic=deterministic)
422
+
423
+ hidden_states = self.output_dense(hidden_states)
424
+ hidden_states = self.output_dropout(hidden_states, deterministic=deterministic)
425
+ return hidden_states
426
+
427
+
428
+ class FlaxWav2Vec2EncoderLayerStableLayerNorm(nn.Module):
429
+ config: Wav2Vec2Config
430
+ dtype: jnp.dtype = jnp.float32
431
+
432
+ def setup(self):
433
+ self.attention = FlaxWav2Vec2Attention(
434
+ config=self.config,
435
+ embed_dim=self.config.hidden_size,
436
+ num_heads=self.config.num_attention_heads,
437
+ dropout=self.config.attention_dropout,
438
+ dtype=self.dtype,
439
+ )
440
+ self.dropout = nn.Dropout(rate=self.config.hidden_dropout)
441
+ self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
442
+ self.feed_forward = FlaxWav2Vec2FeedForward(self.config, dtype=self.dtype)
443
+ self.final_layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
444
+
445
+ def __call__(self, hidden_states, attention_mask=None, deterministic=True, output_attentions=False):
446
+ if self.config.use_scan:
447
+ hidden_states = hidden_states[0]
448
+ attn_residual = hidden_states
449
+ hidden_states = self.layer_norm(hidden_states)
450
+ hidden_states, attn_weights = self.attention(
451
+ hidden_states, attention_mask=attention_mask, deterministic=deterministic
452
+ )
453
+ hidden_states = self.dropout(hidden_states, deterministic=deterministic)
454
+ hidden_states = attn_residual + hidden_states
455
+ hidden_states = hidden_states + self.feed_forward(
456
+ self.final_layer_norm(hidden_states), deterministic=deterministic
457
+ )
458
+
459
+ outputs = (hidden_states,)
460
+
461
+ if output_attentions:
462
+ outputs += (attn_weights,)
463
+
464
+ if self.config.use_scan:
465
+ outputs = (outputs, None)
466
+
467
+ return outputs
468
+
469
+
470
+ class FlaxWav2Vec2EncoderLayerStableLayerNormCollection(nn.Module):
471
+ config: Wav2Vec2Config
472
+ dtype: jnp.dtype = jnp.float32
473
+
474
+ @nn.compact
475
+ def __call__(
476
+ self,
477
+ hidden_states,
478
+ attention_mask=None,
479
+ deterministic: bool = True,
480
+ output_attentions: bool = False,
481
+ output_hidden_states: bool = False,
482
+ return_dict: bool = True,
483
+ ):
484
+ all_attentions = () if output_attentions else None
485
+ all_hidden_states = () if output_hidden_states else None
486
+
487
+ num_layers = self.config.num_hidden_layers
488
+ BlockEncoderLayer = (
489
+ remat(
490
+ FlaxWav2Vec2EncoderLayerStableLayerNorm,
491
+ static_argnums=(2, 3),
492
+ prevent_cse=not self.config.use_scan,
493
+ )
494
+ if self.config.gradient_checkpointing
495
+ else FlaxWav2Vec2EncoderLayerStableLayerNorm
496
+ )
497
+
498
+ if self.config.use_scan:
499
+ # since all decoder layers are the same, we use nn.scan directly
500
+ assert not output_attentions, "cannot use `scan` with `output_attentions` set to `True`"
501
+ assert not output_hidden_states, "cannot use `scan` with `output_hidden_states` set to `True`"
502
+ hidden_states = (hidden_states,)
503
+
504
+ hidden_states, _ = scan_with_axes(
505
+ BlockEncoderLayer,
506
+ variable_axes={"params": 0, "cache": 0},
507
+ split_rngs={"params": True, "dropout": True},
508
+ in_axes=(nn.broadcast, nn.broadcast, nn.broadcast),
509
+ length=num_layers,
510
+ )(self.config, dtype=self.dtype, name="FlaxWav2Vec2EncoderLayers",)(
511
+ hidden_states, attention_mask, deterministic, output_attentions
512
+ )
513
+ hidden_states = hidden_states[0]
514
+
515
+ else:
516
+ for layer in range(num_layers):
517
+ if output_hidden_states:
518
+ all_hidden_states += (hidden_states,)
519
+
520
+ layer_outputs = BlockEncoderLayer(
521
+ self.config,
522
+ dtype=self.dtype,
523
+ name=str(layer),
524
+ )(hidden_states, attention_mask, deterministic, output_attentions)
525
+
526
+ hidden_states = layer_outputs[0]
527
+
528
+ if output_attentions:
529
+ all_attentions += (layer_outputs[1],)
530
+
531
+ if output_hidden_states:
532
+ all_hidden_states += (hidden_states,)
533
+
534
+ outputs = (hidden_states, all_hidden_states, all_attentions)
535
+
536
+ if not return_dict:
537
+ return tuple(v for v in outputs if v is not None)
538
+
539
+ return FlaxBaseModelOutput(
540
+ last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
541
+ )
542
+
543
+
544
+ class FlaxWav2Vec2StableLayerNormEncoder(nn.Module):
545
+ config: Wav2Vec2Config
546
+ dtype: jnp.dtype = jnp.float32
547
+
548
+ def setup(self):
549
+ self.pos_conv_embed = FlaxWav2Vec2PositionalConvEmbedding(self.config, dtype=self.dtype)
550
+ self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
551
+ self.dropout = nn.Dropout(rate=self.config.hidden_dropout)
552
+ self.layers = FlaxWav2Vec2EncoderLayerStableLayerNormCollection(self.config, dtype=self.dtype)
553
+
554
+ def __call__(
555
+ self,
556
+ hidden_states,
557
+ attention_mask=None,
558
+ deterministic=True,
559
+ output_attentions=False,
560
+ output_hidden_states=False,
561
+ return_dict=True,
562
+ ):
563
+
564
+ if attention_mask is not None:
565
+ # make sure padded tokens are not attended to
566
+ hidden_states = jnp.where(
567
+ jnp.broadcast_to(attention_mask[:, :, None], hidden_states.shape), hidden_states, 0
568
+ )
569
+
570
+ position_embeddings = self.pos_conv_embed(hidden_states)
571
+
572
+ hidden_states = hidden_states + position_embeddings
573
+ hidden_states = self.dropout(hidden_states, deterministic=deterministic)
574
+
575
+ outputs = self.layers(
576
+ hidden_states,
577
+ attention_mask,
578
+ output_attentions=output_attentions,
579
+ output_hidden_states=output_hidden_states,
580
+ return_dict=return_dict,
581
+ )
582
+
583
+ last_hidden_state = self.layer_norm(outputs[0])
584
+
585
+ # update the last element in `hidden_states` after applying `layernorm` above
586
+ hidden_states = None
587
+ if output_hidden_states:
588
+ hidden_states = outputs[1]
589
+ hidden_states = hidden_states[:-1] + (last_hidden_state,)
590
+
591
+ if not return_dict:
592
+ outputs = (last_hidden_state, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:])
593
+ return tuple(v for v in outputs if v is not None)
594
+
595
+ return FlaxBaseModelOutput(
596
+ last_hidden_state=last_hidden_state, hidden_states=hidden_states, attentions=outputs.attentions
597
+ )
598
+
599
+
600
+ class FlaxWav2Vec2Adapter(nn.Module):
601
+ config: Wav2Vec2Config
602
+ dtype: jnp.dtype = jnp.float32
603
+
604
+ def setup(self):
605
+ # hidden_states require down-projection if feature dims don't match
606
+ if self.config.output_hidden_size != self.config.hidden_size:
607
+ self.proj = nn.Dense(
608
+ self.config.output_hidden_size,
609
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
610
+ dtype=self.dtype,
611
+ )
612
+ self.proj_layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
613
+ else:
614
+ self.proj = self.proj_layer_norm = None
615
+
616
+ self.layers = FlaxWav2Vec2AdapterLayersCollection(self.config, dtype=self.dtype)
617
+
618
+ def __call__(self, hidden_states, deterministic=True):
619
+ # down-project hidden_states if required
620
+ if self.proj is not None and self.proj_layer_norm is not None:
621
+ hidden_states = self.proj(hidden_states)
622
+ hidden_states = self.proj_layer_norm(hidden_states)
623
+
624
+ hidden_states = self.layers(hidden_states)
625
+
626
+ return hidden_states
627
+
628
+
629
+ class FlaxWav2Vec2AdapterLayer(nn.Module):
630
+ config: Wav2Vec2Config
631
+ dtype: jnp.dtype = jnp.float32
632
+
633
+ def setup(self):
634
+ self.conv = nn.Conv(
635
+ features=2 * self.config.output_hidden_size,
636
+ kernel_size=(self.config.adapter_kernel_size,),
637
+ strides=(self.config.adapter_stride,),
638
+ padding=((1, 1),),
639
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
640
+ dtype=self.dtype,
641
+ )
642
+
643
+ def __call__(self, hidden_states):
644
+ hidden_states = self.conv(hidden_states)
645
+ hidden_states = nn.glu(hidden_states, axis=2)
646
+
647
+ return hidden_states
648
+
649
+
650
+ class FlaxWav2Vec2AdapterLayersCollection(nn.Module):
651
+ config: Wav2Vec2Config
652
+ dtype: jnp.dtype = jnp.float32
653
+
654
+ def setup(self):
655
+ BlockAdapterLayer = remat(FlaxWav2Vec2AdapterLayer) if self.config.gradient_checkpointing else FlaxWav2Vec2AdapterLayer
656
+ self.layers = [
657
+ BlockAdapterLayer(self.config, name=str(i), dtype=self.dtype)
658
+ for i in range(self.config.num_adapter_layers)
659
+ ]
660
+
661
+ def __call__(self, hidden_states):
662
+ for conv_layer in self.layers:
663
+ hidden_states = conv_layer(hidden_states)
664
+
665
+ return hidden_states
666
+
667
+
668
+ class FlaxWav2Vec2PreTrainedModel(FlaxPreTrainedModel):
669
+ """
670
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
671
+ models.
672
+ """
673
+
674
+ config_class = Wav2Vec2Config
675
+ base_model_prefix: str = "wav2vec2"
676
+ main_input_name = "input_values"
677
+ module_class: nn.Module = None
678
+
679
+ def __init__(
680
+ self,
681
+ config: Wav2Vec2Config,
682
+ input_shape: Tuple = (1, 1024),
683
+ seed: int = 0,
684
+ dtype: jnp.dtype = jnp.float32,
685
+ _do_init: bool = True,
686
+ **kwargs,
687
+ ):
688
+ module = self.module_class(config=config, dtype=dtype, **kwargs)
689
+ super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
690
+
691
+ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
692
+ # init input tensors
693
+ input_values = jnp.zeros(input_shape, dtype="i4")
694
+ attention_mask = jnp.ones_like(input_values)
695
+ params_rng, dropout_rng = jax.random.split(rng, 2)
696
+ rngs = {"params": params_rng, "dropout": dropout_rng}
697
+
698
+ return self.module.init(rngs, input_values, attention_mask, return_dict=False)["params"]
699
+
700
+ def __call__(
701
+ self,
702
+ input_values,
703
+ attention_mask=None,
704
+ mask_time_indices=None,
705
+ extract_features=None,
706
+ params: dict = None,
707
+ dropout_rng: jax.random.PRNGKey = None,
708
+ train: bool = False,
709
+ output_attentions: Optional[bool] = None,
710
+ output_hidden_states: Optional[bool] = None,
711
+ output_features: Optional[bool] = None,
712
+ freeze_feature_encoder: bool = False,
713
+ return_dict: Optional[bool] = None,
714
+ ):
715
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
716
+ output_hidden_states = (
717
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
718
+ )
719
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
720
+
721
+ if attention_mask is None:
722
+ batch_size, sequence_length = input_values.shape
723
+ attention_mask = jnp.ones((batch_size, sequence_length))
724
+
725
+ if extract_features is not None:
726
+ extract_features = jnp.array(extract_features, dtype="f4")
727
+
728
+ # Handle any PRNG if needed
729
+ rngs = {}
730
+ if dropout_rng is not None:
731
+ rngs["dropout"] = dropout_rng
732
+
733
+ inputs = {"params": params or self.params}
734
+
735
+ return self.module.apply(
736
+ inputs,
737
+ jnp.array(input_values, dtype="f4"),
738
+ jnp.array(attention_mask, dtype="i4"),
739
+ mask_time_indices,
740
+ extract_features,
741
+ not train,
742
+ output_attentions,
743
+ output_hidden_states,
744
+ output_features,
745
+ freeze_feature_encoder,
746
+ return_dict,
747
+ rngs=rngs,
748
+ )
749
+
750
+ def _get_feat_extract_output_lengths(
751
+ self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
752
+ ):
753
+ return self.module._get_feat_extract_output_lengths(input_lengths, add_adapter=add_adapter)
754
+
755
+ def _get_feature_vector_attention_mask(
756
+ self, feature_vector_length: int, attention_mask: jnp.ndarray, add_adapter=None
757
+ ):
758
+ return self.module._get_feature_vector_attention_mask(feature_vector_length, attention_mask, add_adapter=add_adapter)
759
+
760
+
761
+ class FlaxWav2Vec2Module(nn.Module):
762
+ config: Wav2Vec2Config
763
+ dtype: jnp.dtype = jnp.float32
764
+
765
+ def setup(self):
766
+ self.feature_extractor = FlaxWav2Vec2FeatureEncoder(self.config, dtype=self.dtype)
767
+ self.feature_projection = FlaxWav2Vec2FeatureProjection(self.config, dtype=self.dtype)
768
+ self.masked_spec_embed = self.param(
769
+ "masked_spec_embed", jax.nn.initializers.uniform(), (self.config.hidden_size,)
770
+ )
771
+
772
+ if self.config.do_stable_layer_norm:
773
+ self.encoder = FlaxWav2Vec2StableLayerNormEncoder(self.config, dtype=self.dtype)
774
+ else:
775
+ raise NotImplementedError("``config.do_stable_layer_norm is False`` is currently not supported.")
776
+
777
+ self.adapter = FlaxWav2Vec2Adapter(self.config, dtype=self.dtype) if self.config.add_adapter else None
778
+
779
+ def __call__(
780
+ self,
781
+ input_values,
782
+ attention_mask=None,
783
+ mask_time_indices=None,
784
+ extract_features=None,
785
+ deterministic=True,
786
+ output_attentions=None,
787
+ output_hidden_states=None,
788
+ output_features=False,
789
+ freeze_feature_encoder=False,
790
+ return_dict=None,
791
+ ):
792
+
793
+ # forward pass through the feature extractor if features not specified
794
+ if extract_features is None:
795
+ extract_features = self.feature_extractor(input_values, freeze_feature_encoder=freeze_feature_encoder)
796
+
797
+ if output_features:
798
+ return extract_features
799
+
800
+ # make sure that no loss is computed on padded inputs
801
+ if attention_mask is not None:
802
+ # compute reduced attention_mask corresponding to feature vectors
803
+ attention_mask = self._get_feature_vector_attention_mask(
804
+ extract_features.shape[1], attention_mask, add_adapter=False
805
+ )
806
+
807
+ hidden_states, extract_features = self.feature_projection(extract_features, deterministic=deterministic)
808
+ if mask_time_indices is not None: # apply SpecAugment along time axis with given indices
809
+ hidden_states = jnp.where(
810
+ jnp.broadcast_to(mask_time_indices[:, :, None], hidden_states.shape),
811
+ jnp.broadcast_to(self.masked_spec_embed[None, None, :], hidden_states.shape),
812
+ hidden_states,
813
+ )
814
+
815
+ encoder_outputs = self.encoder(
816
+ hidden_states,
817
+ attention_mask=attention_mask,
818
+ deterministic=deterministic,
819
+ output_attentions=output_attentions,
820
+ output_hidden_states=output_hidden_states,
821
+ return_dict=return_dict,
822
+ )
823
+
824
+ hidden_states = encoder_outputs[0]
825
+
826
+ if self.adapter is not None:
827
+ hidden_states = self.adapter(hidden_states)
828
+
829
+ if not return_dict:
830
+ return (hidden_states, extract_features) + encoder_outputs[1:]
831
+
832
+ return FlaxWav2Vec2BaseModelOutput(
833
+ last_hidden_state=hidden_states,
834
+ extract_features=extract_features,
835
+ hidden_states=encoder_outputs.hidden_states,
836
+ attentions=encoder_outputs.attentions,
837
+ )
838
+
839
+ def _get_feat_extract_output_lengths(
840
+ self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
841
+ ):
842
+ """
843
+ Computes the output length of the convolutional layers
844
+ """
845
+
846
+ add_adapter = self.config.add_adapter if add_adapter is None else add_adapter
847
+
848
+ def _conv_out_length(input_length, kernel_size, stride):
849
+ # 1D convolutional layer output length formula taken
850
+ # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
851
+ return (input_length - kernel_size) // stride + 1
852
+
853
+ for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
854
+ input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
855
+
856
+ if add_adapter:
857
+ for _ in range(self.config.num_adapter_layers):
858
+ input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)
859
+
860
+ return input_lengths
861
+
862
+ def _get_feature_vector_attention_mask(
863
+ self, feature_vector_length: int, attention_mask: jnp.ndarray, add_adapter=None
864
+ ):
865
+ # Effectively attention_mask.sum(-1), but not inplace to be able to run
866
+ # on inference mode.
867
+ non_padded_lengths = attention_mask.cumsum(axis=-1)[:, -1]
868
+
869
+ output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)
870
+
871
+ batch_size = attention_mask.shape[0]
872
+
873
+ attention_mask = jnp.zeros((batch_size, feature_vector_length), dtype=attention_mask.dtype)
874
+ # these two operations makes sure that all values
875
+ # before the output lengths indices are attended to
876
+ attention_mask = attention_mask.at[jnp.arange(attention_mask.shape[0]), output_lengths - 1].set(1)
877
+ attention_mask = jnp.flip(jnp.flip(attention_mask, -1).cumsum(-1), -1).astype("bool")
878
+ return attention_mask
879
+
880
+
881
+ class FlaxWav2Vec2Model(FlaxWav2Vec2PreTrainedModel):
882
+ module_class = FlaxWav2Vec2Module
883
+
884
+
885
+ class FlaxWav2Vec2ForCTCModule(nn.Module):
886
+ config: Wav2Vec2Config
887
+ dtype: jnp.dtype = jnp.float32
888
+
889
+ def setup(self):
890
+ self.wav2vec2 = FlaxWav2Vec2Module(self.config, dtype=self.dtype)
891
+ self.dropout = nn.Dropout(rate=self.config.final_dropout)
892
+ self.lm_head = nn.Dense(
893
+ self.config.vocab_size,
894
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
895
+ dtype=self.dtype,
896
+ )
897
+
898
+ def __call__(
899
+ self,
900
+ input_values,
901
+ attention_mask=None,
902
+ mask_time_indices=None,
903
+ extract_features=None,
904
+ deterministic=True,
905
+ output_attentions=None,
906
+ output_hidden_states=None,
907
+ output_features=False,
908
+ freeze_feature_encoder=False,
909
+ return_dict=None,
910
+ ):
911
+ outputs = self.wav2vec2(
912
+ input_values,
913
+ attention_mask=attention_mask,
914
+ mask_time_indices=mask_time_indices,
915
+ deterministic=deterministic,
916
+ output_attentions=output_attentions,
917
+ output_hidden_states=output_hidden_states,
918
+ freeze_feature_encoder=freeze_feature_encoder,
919
+ return_dict=return_dict,
920
+ )
921
+
922
+ hidden_states = outputs[0]
923
+ hidden_states = self.dropout(hidden_states, deterministic=deterministic)
924
+
925
+ logits = self.lm_head(hidden_states)
926
+
927
+ if not return_dict:
928
+ return (logits,) + outputs[2:]
929
+
930
+ return FlaxCausalLMOutput(logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
931
+
932
+ def _get_feat_extract_output_lengths(
933
+ self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
934
+ ):
935
+ """
936
+ Computes the output length of the convolutional layers
937
+ """
938
+
939
+ add_adapter = self.config.add_adapter if add_adapter is None else add_adapter
940
+
941
+ def _conv_out_length(input_length, kernel_size, stride):
942
+ # 1D convolutional layer output length formula taken
943
+ # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
944
+ return (input_length - kernel_size) // stride + 1
945
+
946
+ for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
947
+ input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
948
+
949
+ if add_adapter:
950
+ for _ in range(self.config.num_adapter_layers):
951
+ input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)
952
+
953
+ return input_lengths
954
+
955
+ def _get_feature_vector_attention_mask(
956
+ self, feature_vector_length: int, attention_mask: jnp.ndarray, add_adapter=None
957
+ ):
958
+ # Effectively attention_mask.sum(-1), but not inplace to be able to run
959
+ # on inference mode.
960
+ non_padded_lengths = attention_mask.cumsum(axis=-1)[:, -1]
961
+
962
+ output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)
963
+
964
+ batch_size = attention_mask.shape[0]
965
+
966
+ attention_mask = jnp.zeros((batch_size, feature_vector_length), dtype=attention_mask.dtype)
967
+ # these two operations makes sure that all values
968
+ # before the output lengths indices are attended to
969
+ attention_mask = attention_mask.at[jnp.arange(attention_mask.shape[0]), output_lengths - 1].set(1)
970
+ attention_mask = jnp.flip(jnp.flip(attention_mask, -1).cumsum(-1), -1).astype("bool")
971
+ return attention_mask
972
+
973
+
974
+ class FlaxWav2Vec2ForCTC(FlaxWav2Vec2PreTrainedModel):
975
+ module_class = FlaxWav2Vec2ForCTCModule
nohup.out ADDED
The diff for this file is too large to render. See raw diff
 
preprocessor_config.json ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "do_normalize": true,
3
+ "feature_extractor_type": "Wav2Vec2FeatureExtractor",
4
+ "feature_size": 1,
5
+ "padding_side": "right",
6
+ "padding_value": 0.0,
7
+ "return_attention_mask": true,
8
+ "sampling_rate": 16000
9
+ }
run_flax_speech_recognition_seq2seq.py ADDED
@@ -0,0 +1,1572 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2022 The HuggingFace Team All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Fine-tuning the Flax library models for sequence to sequence speech recognition.
18
+ """
19
+ # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
20
+
21
+ import logging
22
+ import math
23
+ import os
24
+ import re
25
+ import sys
26
+ import time
27
+ from dataclasses import dataclass, field
28
+ from pathlib import Path
29
+ from typing import Any, Callable, Dict, List, Optional, Union
30
+
31
+ import datasets
32
+ import numpy as np
33
+ from datasets import DatasetDict, load_dataset, load_metric
34
+ from tqdm import tqdm
35
+
36
+ import flax
37
+ import jax
38
+ import jax.numpy as jnp
39
+ import optax
40
+ import transformers
41
+ import wandb as wandb
42
+ from flax import core, jax_utils, struct, traverse_util
43
+ from flax.jax_utils import pad_shard_unpad, unreplicate
44
+ from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
45
+ from huggingface_hub import Repository
46
+ from models import FlaxSpeechEncoderDecoderModel
47
+ from optax._src import linear_algebra
48
+ from transformers import (
49
+ AutoConfig,
50
+ AutoFeatureExtractor,
51
+ AutoProcessor,
52
+ AutoTokenizer,
53
+ HfArgumentParser,
54
+ Seq2SeqTrainingArguments,
55
+ is_tensorboard_available,
56
+ )
57
+ from transformers.file_utils import get_full_repo_name
58
+ from transformers.trainer_utils import get_last_checkpoint
59
+ from transformers.utils import check_min_version
60
+ from transformers.utils.versions import require_version
61
+
62
+
63
+ # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
64
+ check_min_version("4.17.0.dev0")
65
+
66
+ require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")
67
+
68
+ logger = logging.getLogger(__name__)
69
+
70
+
71
+ @flax.struct.dataclass
72
+ class ModelArguments:
73
+ """
74
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
75
+ """
76
+
77
+ model_name_or_path: str = field(
78
+ metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
79
+ )
80
+ config_name: Optional[str] = field(
81
+ default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
82
+ )
83
+ tokenizer_name: Optional[str] = field(
84
+ default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
85
+ )
86
+ feature_extractor_name: Optional[str] = field(
87
+ default=None, metadata={"help": "feature extractor name or path if not the same as model_name"}
88
+ )
89
+ cache_dir: Optional[str] = field(
90
+ default=None,
91
+ metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
92
+ )
93
+ use_fast_tokenizer: bool = field(
94
+ default=True,
95
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
96
+ )
97
+ model_revision: str = field(
98
+ default="main",
99
+ metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
100
+ )
101
+ use_auth_token: bool = field(
102
+ default=False,
103
+ metadata={
104
+ "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
105
+ "with private models)."
106
+ },
107
+ )
108
+ freeze_feature_encoder: bool = field(
109
+ default=True, metadata={"help": "Whether to freeze the feature encoder layers of the model."}
110
+ )
111
+ activation_dropout: float = field(
112
+ default=0.1,
113
+ metadata={
114
+ "help": "The hidden activation dropout probability in the embeddings, encoder, and pooler."
115
+ },
116
+ )
117
+ hidden_dropout: float = field(
118
+ default=0.1,
119
+ metadata={
120
+ "help": "The dropout probability for all fully connected layers in the embeddings, encoder, and pooler."
121
+ },
122
+ )
123
+ feat_proj_dropout: float = field(
124
+ default=0.0,
125
+ metadata={
126
+ "help": "The feat proj dropout probability for feature encoder representations."
127
+ },
128
+ )
129
+ mask_time_prob: float = field(
130
+ default=0.1,
131
+ metadata={
132
+ "help": "The spec aug dropout probability for feature encoder representations."
133
+ },
134
+ )
135
+ encoder_add_adapter: bool = field(
136
+ default=True, metadata={"help": "Whether to add an adapter layer between the encoder and decoder."}
137
+ )
138
+
139
+
140
+ @flax.struct.dataclass
141
+ class DataTrainingArguments:
142
+ """
143
+ Arguments pertaining to what data we are going to input our model for training and eval.
144
+ """
145
+
146
+ dataset_name: str = field(
147
+ default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
148
+ )
149
+ dataset_config_name: Optional[str] = field(
150
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
151
+ )
152
+ text_column: Optional[str] = field(
153
+ default=None,
154
+ metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."},
155
+ )
156
+ dataset_cache_dir: Optional[str] = field(
157
+ default=None, metadata={"help": "Path to cache directory for saving and loading datasets"}
158
+ )
159
+ overwrite_cache: bool = field(
160
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
161
+ )
162
+ preprocessing_num_workers: Optional[int] = field(
163
+ default=None,
164
+ metadata={"help": "The number of processes to use for the preprocessing."},
165
+ )
166
+ max_train_samples: Optional[int] = field(
167
+ default=None,
168
+ metadata={
169
+ "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
170
+ "value if set."
171
+ },
172
+ )
173
+ max_eval_samples: Optional[int] = field(
174
+ default=None,
175
+ metadata={
176
+ "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
177
+ "value if set."
178
+ },
179
+ )
180
+ max_test_samples: Optional[int] = field(
181
+ default=None,
182
+ metadata={
183
+ "help": "For debugging purposes or quicker training, truncate the number of test examples to this "
184
+ "value if set."
185
+ },
186
+ )
187
+ audio_column_name: str = field(
188
+ default="audio",
189
+ metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"},
190
+ )
191
+ text_column_name: str = field(
192
+ default="text",
193
+ metadata={"help": "The name of the dataset column containing the text data. Defaults to 'text'"},
194
+ )
195
+ id_column_name: str = field(
196
+ default="id",
197
+ metadata={"help": "The name of the dataset column containing the id data. Defaults to 'id'"},
198
+ )
199
+ max_duration_in_seconds: float = field(
200
+ default=20.0,
201
+ metadata={
202
+ "help": "Filter audio files in the training set that are longer than `max_duration_in_seconds` seconds"
203
+ },
204
+ )
205
+ min_duration_in_seconds: float = field(
206
+ default=0.0, metadata={"help": "Filter audio files in the training set that are shorter than `min_duration_in_seconds` seconds"}
207
+ )
208
+ max_target_length: Optional[int] = field(
209
+ default=128,
210
+ metadata={
211
+ "help": "The maximum total sequence length for target text after tokenization. Sequences longer "
212
+ "than this will be truncated, sequences shorter will be padded."
213
+ },
214
+ )
215
+ min_target_length: Optional[int] = field(
216
+ default=0,
217
+ metadata={
218
+ "help": "The minimum total sequence length for target text after tokenization. Sequences shorter "
219
+ "than this will be filtered."
220
+ },
221
+ )
222
+ pad_input_to_multiple_of: Optional[int] = field(
223
+ default=24000,
224
+ metadata={
225
+ "help": "If set will pad the input sequence to a multiple of the provided value. "
226
+ "This is important to avoid triggering recompilations on TPU."
227
+ },
228
+ )
229
+ pad_target_to_multiple_of: Optional[int] = field(
230
+ default=None,
231
+ metadata={
232
+ "help": "If set will pad the target sequence to a multiple of the provided value. "
233
+ "This is important to avoid triggering recompilations on TPU. If unspecified, will default to `max_target_length`, "
234
+ " the equivalent of padding the targets to max length."
235
+ },
236
+ )
237
+ preprocessing_only: bool = field(
238
+ default=False,
239
+ metadata={
240
+ "help": "Whether to only do data preprocessing and skip training. "
241
+ "This is especially useful when data preprocessing errors out in distributed training due to timeout. "
242
+ "In this case, one should run the preprocessing in a non-distributed setup with `preprocessing_only=True` "
243
+ "so that the cached datasets can consequently be loaded in distributed training"
244
+ },
245
+ )
246
+ train_split_name: str = field(
247
+ default="train",
248
+ metadata={
249
+ "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
250
+ },
251
+ )
252
+ eval_split_name: str = field(
253
+ default="validation",
254
+ metadata={
255
+ "help": "The name of the evaluation data set split to use (via the datasets library). Defaults to 'validation'"
256
+ },
257
+ )
258
+ test_split_name: str = field(
259
+ default="test",
260
+ metadata={"help": "The name of the test data set split to use (via the datasets library). Defaults to 'test'"},
261
+ )
262
+ do_lower_case: bool = field(
263
+ default=True,
264
+ metadata={"help": "Whether the target text should be lower cased."},
265
+ )
266
+ wandb_project: str = field(
267
+ default="flax-speech-recognition-seq2seq",
268
+ metadata={"help": "The name of the wandb project."},
269
+ )
270
+ wandb_name: str = field(
271
+ default=None,
272
+ metadata={"help": "The name of the wandb run."},
273
+ )
274
+ wandb_job_type: str = field(
275
+ default="Seq2Seq",
276
+ metadata={"help": "The name of the wandb job type."},
277
+ )
278
+ log_first_ids: bool = field(
279
+ default=True,
280
+ metadata={
281
+ "help": "Whether to log the first id's from the dataset. Defaults to `True`. If `False`, will log the first id's returned by the grouped length sampler."
282
+ },
283
+ )
284
+
285
+
286
+ # @flax.struct.dataclass
287
+ @dataclass
288
+ class FlaxSeq2SeqTrainingArguments(Seq2SeqTrainingArguments):
289
+ precision: str = field(
290
+ default="full",
291
+ metadata={
292
+ "help": "Whether to enable mixed-precision training. If true, the optimizer is stored in half-precision (bfloat16) and computations are executed in half-precision"
293
+ "**Note that this only specifies the dtype of the computation and optimizer state. It does not influence the dtype of model parameters.**"
294
+ },
295
+ )
296
+ matmul_precision: str = field(
297
+ default="default",
298
+ metadata={
299
+ "help": "Default floating-point precision of internal computations used in TPU matrix multiplications and convolutions. "
300
+ "This configuration option controls the default precision for JAX operations that take an optional precision argument (e.g. `lax.conv_general_dilated` and `lax.dot`). "
301
+ "This configuration option does not change the behaviours of such calls with explicit precision arguments; "
302
+ "it only changes the behaviors of calls with no such argument provided. "
303
+ "One of `['highest', 'float32', 'high', 'bfloat16_3x', 'default', 'bfloat16', 'fastest', None]`."
304
+ },
305
+ )
306
+ generation_length_penalty: float = field(
307
+ default=1,
308
+ metadata={
309
+ "help": "Exponential penalty to the length. 1.0 (default) means no penalty. Set to values < 1.0 in order to encourage the model"
310
+ "to generate shorter sequences, to a value > 1.0 in order to encourage the model to produce longer sequences."
311
+ },
312
+ )
313
+ final_generation_max_length: int = field(
314
+ default=None,
315
+ metadata={
316
+ "help": "The `max_length` to use on each evaluation loop when `predict_with_generate=True`. If unspecified, will default "
317
+ "to the `max_length` value of the model configuration."
318
+ },
319
+ )
320
+ final_generation_num_beams: int = field(
321
+ default=None,
322
+ metadata={
323
+ "help": "The `num_beams` to use on each evaluation loop when `predict_with_generate=True`. If unspecified, will default "
324
+ "to the `num_beams` value of the model configuration."
325
+ },
326
+ )
327
+
328
+ def __post_init__(self):
329
+ if self.final_generation_max_length is None:
330
+ self.final_generation_max_length = self.generation_max_length
331
+ if self.final_generation_num_beams is None:
332
+ self.final_generation_num_beams = self.generation_num_beams
333
+
334
+
335
+ def to_fp32(t):
336
+ return jax.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t)
337
+
338
+
339
+ def to_bf16(t):
340
+ return jax.tree_map(lambda x: x.astype(jnp.bfloat16) if x.dtype == jnp.float32 else x, t)
341
+
342
+
343
+ class MixedPrecisionTrainState(struct.PyTreeNode):
344
+ """Train state for use with a single Optax optimizer.
345
+ Adapted from flax train_state https://github.com/google/flax/blob/main/flax/training/train_state.py
346
+
347
+ Synopsis::
348
+
349
+ state = TrainState.create(
350
+ apply_fn=model.apply,
351
+ params=variables['params'],
352
+ tx=tx)
353
+ grad_fn = jax.grad(make_loss_fn(state.apply_fn))
354
+ for batch in data:
355
+ grads = grad_fn(state.params, batch)
356
+ state = state.apply_gradients(grads=grads)
357
+
358
+ Args:
359
+ step: Counter starts at 0 and is incremented by every call to
360
+ `.apply_gradients()`.
361
+ apply_fn: Usually set to `model.apply()`. Kept in this dataclass for
362
+ convenience to have a shorter params list for the `train_step()` function
363
+ in your training loop.
364
+ params: The parameters to be updated by `tx` and used by `apply_fn`.
365
+ tx: An Optax gradient transformation.
366
+ opt_state: The state for `tx`.
367
+ dropout_rng: PRNG key for stochastic operations.
368
+ bf16: Whether to use bf16 16-bit (mixed) precision training instead of 32-bit training.
369
+ """
370
+
371
+ step: int
372
+ apply_fn: Callable = struct.field(pytree_node=False)
373
+ params: core.FrozenDict[str, Any]
374
+ tx: optax.GradientTransformation = struct.field(pytree_node=False)
375
+ opt_state: optax.OptState
376
+ dropout_rng: jnp.ndarray
377
+ max_grad_norm: Optional[float] = 1.0
378
+
379
+ def apply_gradients(self, *, grads, to_dtype, **kwargs):
380
+ """Updates `step`, `params`, `opt_state` and `**kwargs` in return value.
381
+
382
+ Note that internally this function calls `.tx.update()` followed by a call
383
+ to `optax.apply_updates()` to update `params` and `opt_state`.
384
+
385
+ Args:
386
+ grads: Gradients that have the same pytree structure as `.params`.
387
+ **kwargs: Additional dataclass attributes that should be `.replace()`-ed.
388
+
389
+ Returns:
390
+ An updated instance of `self` with `step` incremented by one, `params`
391
+ and `opt_state` updated by applying `grads`, and additional attributes
392
+ replaced as specified by `kwargs`.
393
+ """
394
+
395
+ # clip gradients by global l2 norm
396
+ casted_max_grad_norm = to_dtype(self.max_grad_norm)
397
+ g_norm = linear_algebra.global_norm(grads)
398
+ g_norm = jnp.maximum(casted_max_grad_norm, g_norm)
399
+ grads = jax.tree_map(lambda t: (t / g_norm) * casted_max_grad_norm, grads)
400
+
401
+ # perform update step in fp32 and subsequently downcast optimizer states if mixed precision training
402
+ # grads and opt_state in bf16 (need to upcast), params in fp32 (leave as is)
403
+ updates, new_opt_state = self.tx.update(to_fp32(grads), to_fp32(self.opt_state), self.params)
404
+
405
+ new_params = optax.apply_updates(self.params, updates)
406
+ return self.replace(
407
+ step=self.step + 1,
408
+ params=new_params,
409
+ opt_state=to_dtype(new_opt_state),
410
+ **kwargs,
411
+ )
412
+
413
+ @classmethod
414
+ def create(cls, *, apply_fn, params, tx, to_dtype, **kwargs):
415
+ """Creates a new instance with `step=0` and initialized `opt_state`."""
416
+ # downcast optimizer state to bf16 if mixed-precision training
417
+ opt_state = tx.init(to_dtype(params)) if tx is not None else None
418
+ return cls(
419
+ step=0,
420
+ apply_fn=apply_fn,
421
+ params=params,
422
+ tx=tx,
423
+ opt_state=opt_state,
424
+ **kwargs,
425
+ )
426
+
427
+ def replicate(self):
428
+ return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
429
+
430
+
431
+ def pad_to_max_length(data, tokenizer):
432
+ # Get lengths of each row of data
433
+ lens = np.array([len(i) for i in data])
434
+
435
+ # Mask of valid places in each row
436
+ mask = np.arange(lens.max()) < lens[:, None]
437
+
438
+ # Setup output array and put elements from data into masked positions
439
+ out = np.ones_like(mask, dtype=data.dtype) * tokenizer.pad_token_id
440
+ out[mask] = np.concatenate(data)
441
+ return out
442
+
443
+
444
+ def shift_tokens_right(label_ids: np.array, decoder_start_token_id: int) -> np.ndarray:
445
+ """
446
+ Shift label ids one token to the right.
447
+ """
448
+ shifted_label_ids = np.zeros_like(label_ids)
449
+ shifted_label_ids[:, 1:] = label_ids[:, :-1]
450
+ shifted_label_ids[:, 0] = decoder_start_token_id
451
+
452
+ return shifted_label_ids
453
+
454
+
455
+ @flax.struct.dataclass
456
+ class FlaxDataCollatorSpeechSeq2SeqWithPadding:
457
+ """
458
+ Data collator that will dynamically pad the inputs received.
459
+ Args:
460
+ processor ([`Wav2Vec2Processor`])
461
+ The processor used for proccessing the data.
462
+ decoder_start_token_id (:obj: `int`)
463
+ The begin-of-sentence of the decoder.
464
+ input_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
465
+ Select a strategy to pad the returned input sequences (according to the model's padding side and padding index)
466
+ among:
467
+ * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
468
+ sequence if provided).
469
+ * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
470
+ maximum acceptable input length for the model if that argument is not provided.
471
+ * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
472
+ different lengths).
473
+ target_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
474
+ Select a strategy to pad the returned target sequences (according to the model's padding side and padding index).
475
+ See above for details.
476
+ max_input_length (:obj:`float`, `optional`):
477
+ Maximum length of the ``input_values`` of the returned list and optionally padding length (see above).
478
+ max_target_length (:obj:`int`, `optional`):
479
+ Maximum length of the ``labels`` of the returned list and optionally padding length (see above).
480
+ pad_input_to_multiple_of (:obj:`int`, `optional`):
481
+ If set will pad the input sequence to a multiple of the provided value.
482
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
483
+ 7.5 (Volta).
484
+ pad_target_to_multiple_of (:obj:`int`, `optional`):
485
+ If set will pad the target sequence to a multiple of the provided value.
486
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
487
+ 7.5 (Volta).
488
+ """
489
+
490
+ processor: Any
491
+ decoder_start_token_id: int
492
+ input_padding: Union[bool, str] = "longest"
493
+ target_padding: Union[bool, str] = "max_length"
494
+ max_input_length: Optional[float] = None
495
+ max_target_length: Optional[int] = None
496
+ pad_input_to_multiple_of: Optional[int] = None
497
+ pad_target_to_multiple_of: Optional[int] = None
498
+
499
+ def __call__(self, features: List[Dict[str, Union[List[int], np.ndarray]]]) -> Dict[str, np.ndarray]:
500
+ # split inputs and labels since they have to be of different lengths and need
501
+ # different padding methods
502
+ input_features = [{"input_values": feature["input_values"]} for feature in features]
503
+ input_ids = [feature["input_id"] for feature in features]
504
+ label_features = [{"input_ids": feature["labels"]} for feature in features]
505
+
506
+ # reformat list to dict and set to pytorch format
507
+ batch = self.processor.feature_extractor.pad(
508
+ input_features,
509
+ max_length=self.max_input_length,
510
+ padding=self.input_padding,
511
+ pad_to_multiple_of=self.pad_input_to_multiple_of,
512
+ return_tensors="np",
513
+ )
514
+
515
+ labels_batch = self.processor.tokenizer.pad(
516
+ label_features,
517
+ max_length=self.max_target_length,
518
+ padding=self.target_padding,
519
+ pad_to_multiple_of=self.pad_target_to_multiple_of,
520
+ return_tensors="np",
521
+ )
522
+
523
+ # if bos token is appended in previous tokenization step,
524
+ # cut bos token here as it's append later anyways
525
+ labels = labels_batch["input_ids"]
526
+ if (labels[:, 0] == self.decoder_start_token_id).all().item():
527
+ labels = labels[:, 1:]
528
+ labels_batch.attention_mask = labels_batch.attention_mask[:, 1:]
529
+
530
+ decoder_input_ids = shift_tokens_right(labels, self.decoder_start_token_id)
531
+
532
+ # replace padding with -100 to ignore correctly when computing the loss
533
+ labels = np.ma.array(labels, mask=np.not_equal(labels_batch.attention_mask, 1))
534
+ labels = labels.filled(fill_value=-100)
535
+
536
+ batch["inputs"] = batch.pop("input_values")
537
+ batch["input_ids"] = input_ids
538
+ batch["labels"] = labels
539
+ batch["decoder_input_ids"] = decoder_input_ids
540
+
541
+ return batch
542
+
543
+
544
+ def get_grouped_indices(
545
+ dataset, batch_size: int, rng: Optional[List[int]] = None, mega_batch_mult: Optional[int] = None
546
+ ) -> np.array:
547
+ """
548
+ Adapted from the `get_length_grouped_indices` function in the PyTorch Trainer utils file (https://github.com/huggingface/transformers/blob/main/src/transformers/trainer_pt_utils.py#L486)
549
+ Function that returns a list of indices in which each slice of `batch_size` consecutive indices correspond to elements of similar
550
+ lengths. To do this, the indices are:
551
+
552
+ - randomly permuted (if a JAX rng is specified)
553
+ - grouped in mega-batches of size `mega_batch_mult * batch_size`
554
+ - sorted by length in each mega-batch
555
+
556
+ The result is the concatenation of all mega-batches, with the batch of `batch_size` containing the element of
557
+ maximum length placed first, so that an OOM happens sooner rather than later.
558
+ """
559
+ lengths = dataset["input_length"]
560
+
561
+ # Default for mega_batch_mult: 50 or the number to get 4 megabatches, whichever is smaller.
562
+ if mega_batch_mult is None:
563
+ mega_batch_mult = min(len(lengths) // (batch_size * 4), 50)
564
+ # Just in case, for tiny datasets
565
+ if mega_batch_mult == 0:
566
+ mega_batch_mult = 1
567
+
568
+ # We need to use JAX for the random permutation as the PRNG key will be set based on the seed outside of the sampler.
569
+ num_samples = len(lengths)
570
+ indices = jax.random.permutation(rng, np.arange(num_samples)) if rng is not None else np.arange(num_samples)
571
+ indices = np.asarray(indices)
572
+
573
+ megabatch_size = mega_batch_mult * batch_size
574
+ megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
575
+ megabatches = [list(sorted(megabatch, key=lambda i: lengths[i], reverse=True)) for megabatch in megabatches]
576
+
577
+ # The rest is to get the biggest batch first.
578
+ # Since each megabatch is sorted by descending length, the longest element is the first
579
+ megabatch_maximums = [lengths[megabatch[0]] for megabatch in megabatches]
580
+ max_idx = np.argmax(megabatch_maximums).item()
581
+ # Switch to put the longest batch in first position
582
+ # (note that this is different to the PT grouped sampler in which we only put the longest element in the first position, and not its batch)
583
+ megabatches[0], megabatches[max_idx] = megabatches[max_idx], megabatches[0]
584
+
585
+ megabatches = np.array([i for megabatch in megabatches for i in megabatch])
586
+
587
+ return megabatches
588
+
589
+
590
+ def generate_batch_splits(samples_idx: np.ndarray, batch_size: int, drop_last_batch=True) -> np.ndarray:
591
+ """Generate batches of data for a specified batch size from sample indices. If the dataset size is not divisible by
592
+ the batch size and `drop_last` is `True`, the last incomplete batch is dropped. Else, it is returned."""
593
+ num_samples = len(samples_idx)
594
+ if drop_last_batch:
595
+ samples_to_remove = num_samples % batch_size
596
+ if samples_to_remove != 0:
597
+ samples_idx = samples_idx[:-samples_to_remove]
598
+ sections_split = num_samples // batch_size
599
+ samples_idx = samples_idx.reshape((sections_split, batch_size))
600
+ else:
601
+ sections_split = math.ceil(num_samples / batch_size)
602
+ samples_idx = np.array_split(samples_idx, sections_split)
603
+ return samples_idx
604
+
605
+
606
+ def write_train_metric(summary_writer, train_metrics, train_time, step):
607
+ summary_writer.scalar("train_time", train_time, step)
608
+
609
+ train_metrics = get_metrics(train_metrics)
610
+ for key, vals in train_metrics.items():
611
+ tag = f"train_{key}"
612
+ for i, val in enumerate(vals):
613
+ summary_writer.scalar(tag, val, step - len(vals) + i + 1)
614
+
615
+
616
+ def write_eval_metric(summary_writer, eval_metrics, step, pred_str=None):
617
+ for metric_name, value in eval_metrics.items():
618
+ summary_writer.scalar(f"eval_{metric_name}", value, step)
619
+
620
+ if pred_str is not None:
621
+ # write output actual predictions for debugging
622
+ summary_writer.text("eval_predictions", "\n".join(pred_str), step)
623
+
624
+
625
+ def write_wandb_log(metrics, step, prefix=None):
626
+ if jax.process_index() == 0:
627
+ log_metrics = {}
628
+ for k, v in metrics.items():
629
+ if "layer" in k:
630
+ log_metrics[f"{k}/"] = v
631
+ elif prefix is not None:
632
+ log_metrics[f"{prefix}/{k}"] = v
633
+ else:
634
+ log_metrics[k] = v
635
+ wandb.log(log_metrics, step)
636
+
637
+
638
+ def write_wandb_pred(pred_str, label_str, eval_ids, step, prefix="eval", top_ids=None, final_step=True):
639
+ if jax.process_index() == 0:
640
+ top_ids = top_ids if top_ids else eval_ids
641
+ num_beams = len(pred_str)
642
+ # convert str data to a wandb compatible format
643
+ str_data = []
644
+ for id in top_ids:
645
+ if id in eval_ids:
646
+ idx = eval_ids.index(id)
647
+ str_data.append([eval_ids[idx], label_str[idx]] + [pred_str[beam][idx] for beam in range(num_beams)])
648
+ columns = ["id", "label_str"] + [f"beam_{i + 1}" for i in range(num_beams)]
649
+ wandb.log(
650
+ {f"{prefix}/step_{int(step / 1000)}k": wandb.Table(columns=columns, data=str_data[:50])},
651
+ step,
652
+ )
653
+ if final_step:
654
+ str_data = np.array(str_data)
655
+ wandb.log(
656
+ {f"{prefix}/step_{int(step / 1000)}k_all": wandb.Table(columns=columns, data=str_data[:200000])},
657
+ step,
658
+ )
659
+ str_data = str_data[str_data[:, 1] != str_data[:, 2]]
660
+ wandb.log(
661
+ {f"{prefix}/step_{int(step / 1000)}k_incorrect": wandb.Table(columns=columns, data=str_data[:200000])},
662
+ step,
663
+ )
664
+
665
+
666
+ def create_learning_rate_fn(
667
+ num_train_steps: int, num_warmup_steps: int, learning_rate: float
668
+ ) -> Callable[[int], jnp.array]:
669
+ """Returns a linear warmup, linear_decay learning rate function."""
670
+ warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
671
+ decay_fn = optax.linear_schedule(
672
+ init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
673
+ )
674
+ schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
675
+ return schedule_fn
676
+
677
+
678
+ def main():
679
+ # 1. Parse input arguments
680
+ # See all possible arguments in src/transformers/training_args.py
681
+ # or by passing the --help flag to this script.
682
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
683
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, FlaxSeq2SeqTrainingArguments))
684
+
685
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
686
+ # If we pass only one argument to the script and it's the path to a json file,
687
+ # let's parse it to get our arguments.
688
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
689
+ else:
690
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
691
+
692
+ # 2. Setup logging
693
+ # Make one log on every process with the configuration for debugging.
694
+ logging.basicConfig(
695
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
696
+ datefmt="%m/%d/%Y %H:%M:%S",
697
+ handlers=[logging.StreamHandler(sys.stdout)],
698
+ )
699
+ # Set the verbosity to info of the Transformers logger.
700
+ # We only want one process per machine to log things on the screen.
701
+ logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
702
+ if jax.process_index() == 0:
703
+ datasets.utils.logging.set_verbosity_warning()
704
+ transformers.utils.logging.set_verbosity_info()
705
+ else:
706
+ datasets.utils.logging.set_verbosity_error()
707
+ transformers.utils.logging.set_verbosity_error()
708
+
709
+ # Set up wandb run
710
+ if jax.process_index() == 0:
711
+ wandb.init(project=data_args.wandb_project, name=data_args.wandb_name, job_type=data_args.wandb_job_type)
712
+
713
+ logger.info("Training/evaluation parameters %s", training_args)
714
+
715
+ # Set the default TPU matmul precision and display the number of devices
716
+ jax.config.update("jax_default_matmul_precision", training_args.matmul_precision)
717
+ logger.info(f"JAX devices: {jax.device_count()}, matmul precision: {training_args.matmul_precision}")
718
+
719
+ # TODO: 3. Detecting last checkpoint and eventually continue from last checkpoint
720
+ last_checkpoint = None
721
+ if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
722
+ last_checkpoint = get_last_checkpoint(training_args.output_dir)
723
+ if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
724
+ raise ValueError(
725
+ f"Output directory ({training_args.output_dir}) already exists and is not empty. "
726
+ "Use --overwrite_output_dir to overcome."
727
+ )
728
+ elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
729
+ logger.info(
730
+ f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
731
+ "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
732
+ )
733
+
734
+ # 4. Load dataset
735
+ raw_datasets = DatasetDict()
736
+
737
+ if training_args.do_train:
738
+ raw_datasets["train"] = load_dataset(
739
+ data_args.dataset_name,
740
+ data_args.dataset_config_name,
741
+ split=data_args.train_split_name,
742
+ cache_dir=data_args.dataset_cache_dir,
743
+ use_auth_token=True if model_args.use_auth_token else None,
744
+ )
745
+
746
+ if training_args.do_eval:
747
+ raw_datasets["eval"] = load_dataset(
748
+ data_args.dataset_name,
749
+ data_args.dataset_config_name,
750
+ split=data_args.eval_split_name,
751
+ cache_dir=data_args.dataset_cache_dir,
752
+ use_auth_token=True if model_args.use_auth_token else None,
753
+ )
754
+
755
+ if training_args.do_predict:
756
+ test_split = data_args.test_split_name.split("+")
757
+ for split in test_split:
758
+ raw_datasets[split] = load_dataset(
759
+ data_args.dataset_name,
760
+ data_args.dataset_config_name,
761
+ split=split,
762
+ cache_dir=data_args.dataset_cache_dir,
763
+ use_auth_token=True if model_args.use_auth_token else None,
764
+ )
765
+
766
+ if not training_args.do_train and not training_args.do_eval and not training_args.do_predict:
767
+ raise ValueError(
768
+ "Cannot not train, not do evaluation and not do prediction. At least one of "
769
+ "training, evaluation or prediction has to be done."
770
+ )
771
+
772
+ # if not training, there is no need to run multiple epochs
773
+ if not training_args.do_train:
774
+ training_args.num_train_epochs = 1
775
+
776
+ if data_args.audio_column_name not in next(iter(raw_datasets.values())).column_names:
777
+ raise ValueError(
778
+ f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'. "
779
+ "Make sure to set `--audio_column_name` to the correct audio column - one of "
780
+ f"{', '.join(next(iter(raw_datasets.values())).column_names)}."
781
+ )
782
+
783
+ if data_args.text_column_name not in next(iter(raw_datasets.values())).column_names:
784
+ raise ValueError(
785
+ f"--text_column_name {data_args.text_column_name} not found in dataset '{data_args.dataset_name}'. "
786
+ "Make sure to set `--text_column_name` to the correct text column - one of "
787
+ f"{', '.join(next(iter(raw_datasets.values())).column_names)}."
788
+ )
789
+
790
+ if data_args.log_first_ids and data_args.id_column_name not in next(iter(raw_datasets.values())).column_names:
791
+ raise ValueError(
792
+ f"--id_column_name {data_args.id_column_name} not found in dataset '{data_args.dataset_name}'. "
793
+ "Make sure to set `--id_column_name` to the correct id column - one of "
794
+ f"{', '.join(next(iter(raw_datasets.values())).column_names)}."
795
+ )
796
+
797
+ # 5. Load pretrained model, tokenizer, and feature extractor
798
+ #
799
+ # Distributed training:
800
+ # The .from_pretrained methods guarantee that only one local process can concurrently
801
+ config = AutoConfig.from_pretrained(
802
+ model_args.config_name if model_args.config_name else model_args.model_name_or_path,
803
+ cache_dir=model_args.cache_dir,
804
+ revision=model_args.model_revision,
805
+ use_auth_token=True if model_args.use_auth_token else None,
806
+ )
807
+
808
+ # update config according to training and model args
809
+ config.encoder.update(
810
+ {
811
+ "gradient_checkpointing": training_args.gradient_checkpointing,
812
+ "hidden_dropout": model_args.hidden_dropout,
813
+ "activation_dropout": model_args.activation_dropout,
814
+ "feat_proj_dropout": model_args.feat_proj_dropout,
815
+ "mask_time_prob": model_args.mask_time_prob,
816
+ "add_adapter": model_args.encoder_add_adapter,
817
+ }
818
+ )
819
+ config.decoder.update(
820
+ {
821
+ "gradient_checkpointing": training_args.gradient_checkpointing,
822
+ "dropout": model_args.hidden_dropout,
823
+ "activation_dropout": model_args.activation_dropout,
824
+ }
825
+ )
826
+
827
+ feature_extractor = AutoFeatureExtractor.from_pretrained(
828
+ model_args.feature_extractor_name if model_args.feature_extractor_name else model_args.model_name_or_path,
829
+ cache_dir=model_args.cache_dir,
830
+ revision=model_args.model_revision,
831
+ use_auth_token=True if model_args.use_auth_token else None,
832
+ )
833
+ tokenizer = AutoTokenizer.from_pretrained(
834
+ model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
835
+ cache_dir=model_args.cache_dir,
836
+ use_fast=model_args.use_fast_tokenizer,
837
+ revision=model_args.model_revision,
838
+ use_auth_token=True if model_args.use_auth_token else None,
839
+ )
840
+
841
+ if training_args.precision == "full_mixed":
842
+ dtype = jnp.bfloat16
843
+ training_args.mixed_precision = True
844
+ elif training_args.precision == "half_mixed":
845
+ dtype = jnp.bfloat16
846
+ training_args.mixed_precision = False
847
+ else:
848
+ dtype = jnp.float32
849
+ training_args.mixed_precision = False
850
+
851
+ model = FlaxSpeechEncoderDecoderModel.from_pretrained(
852
+ model_args.model_name_or_path,
853
+ config=config,
854
+ dtype=dtype,
855
+ cache_dir=model_args.cache_dir,
856
+ revision=model_args.model_revision,
857
+ use_auth_token=True if model_args.use_auth_token else None,
858
+ )
859
+
860
+ if model.config.decoder_start_token_id is None:
861
+ raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
862
+
863
+ # 6. Resample speech dataset ALWAYS
864
+ raw_datasets = raw_datasets.cast_column(
865
+ data_args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
866
+ )
867
+
868
+ # 7. Preprocessing the datasets.
869
+ # We need to read the audio files as arrays and tokenize the targets.
870
+ max_input_length = int(data_args.max_duration_in_seconds * feature_extractor.sampling_rate)
871
+ min_input_length = int(data_args.min_duration_in_seconds * feature_extractor.sampling_rate)
872
+ max_target_length = data_args.max_target_length
873
+ min_target_length = data_args.min_target_length
874
+ pad_input_to_multiple_of = data_args.pad_input_to_multiple_of
875
+ pad_target_to_multiple_of = data_args.pad_target_to_multiple_of
876
+ audio_column_name = data_args.audio_column_name
877
+ num_workers = data_args.preprocessing_num_workers
878
+ text_column_name = data_args.text_column_name
879
+ id_column_name = data_args.id_column_name
880
+ model_input_name = feature_extractor.model_input_names[0]
881
+ do_lower_case = data_args.do_lower_case
882
+ log_first_ids = data_args.log_first_ids
883
+ dataset_name = data_args.dataset_name
884
+ tedlium_contractions = [" 's", " 't", " 're", " 've", " 'm", " 'll", " 'd", " 'clock", " 'all"]
885
+ gigaspeech_punctuation = {" <comma>": ",", " <period>": ".", " <questionmark>": "?", " <exclamationpoint>": "!"}
886
+ gigaspeech_disfluencies = ["<other>", "<sil>"]
887
+ swb_disfluencies = ["[noise]", "[laughter]", "[silence]", "<a_aside>", "<b_aside>", "<e_aside>", "[laughter-",
888
+ "[vocalized-noise]", "_1"]
889
+ swb_punctuations = ["{", "}", "[", "]-", "]"]
890
+ earnings_disfluencies = ["<crosstalk>", "<affirmative>", "<inaudible>", "inaudible", "<laugh>"]
891
+ ignore_segments = ["ignore_time_segment_in_scoring", "<noise>", "<music>", "[noise]", "[laughter]", "[silence]",
892
+ "[vocalized-noise]", "<crosstalk>", "<affirmative>", "<inaudible>", "<laugh>", "<other>", "<sil>", ""]
893
+
894
+ if training_args.do_train and data_args.max_train_samples is not None:
895
+ raw_datasets["train"] = raw_datasets["train"].select(range(data_args.max_train_samples))
896
+
897
+ if training_args.do_eval and data_args.max_eval_samples is not None:
898
+ raw_datasets["eval"] = raw_datasets["eval"].select(range(data_args.max_eval_samples))
899
+
900
+ if training_args.do_predict and data_args.max_test_samples is not None:
901
+ for split in test_split:
902
+ raw_datasets[split] = raw_datasets[split].select(range(data_args.max_eval_samples))
903
+
904
+ # filter data where the targets are ignored in scoring
905
+ def is_target_labels(input_str):
906
+ return input_str.lower() not in ignore_segments
907
+
908
+ raw_datasets = raw_datasets.filter(
909
+ is_target_labels,
910
+ num_proc=num_workers,
911
+ input_columns=[text_column_name],
912
+ desc="filtering data where the targets are ignored in scoring",
913
+ )
914
+
915
+ def prepare_dataset(batch):
916
+ # Pre-process audio
917
+ try:
918
+ sample = batch[audio_column_name]
919
+ except ValueError:
920
+ # E22: some samples are empty (no audio). Reading the empty audio array will trigger
921
+ # a soundfile ValueError. For now, we'll manually set these arrays to a zero array.
922
+ # They will be filtered in the subsequent filtering stage and so are
923
+ # explicitly ignored during training.
924
+ sample = {"array": np.array([0.]), "sampling_rate": feature_extractor.sampling_rate}
925
+
926
+ # normalise audio (mean, std) to (0, 1)
927
+ inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])
928
+ # process audio length
929
+ batch[model_input_name] = inputs.input_values[0]
930
+ batch["input_length"] = len(batch["input_values"])
931
+ batch["input_id"] = batch[id_column_name] if log_first_ids else None
932
+
933
+ # 'Error correction' of targets
934
+ input_str = batch[text_column_name].lower() if do_lower_case else batch[text_column_name]
935
+
936
+ # LibriSpeech ASR
937
+ if dataset_name == "librispeech_asr":
938
+ pass # no error correction necessary
939
+
940
+ # VoxPopuli
941
+ if dataset_name == "google/xtreme_s":
942
+ pass # no error correction necessary
943
+
944
+ # Common Voice 9
945
+ if dataset_name == "mozilla-foundation/common_voice_9_0":
946
+ if input_str.startswith('"') and input_str.endswith('"'):
947
+ # we can remove trailing quotation marks as they do not affect the transcription
948
+ input_str = input_str[1:-1]
949
+ # replace double quotation marks with single
950
+ input_str = input_str.replace('""', '"')
951
+
952
+ # TED-LIUM (Release 3)
953
+ if dataset_name == "LIUM/tedlium":
954
+ # delete the <unk> token from the text
955
+ input_str = input_str.replace("<unk>", "")
956
+ # replace spaced apostrophes with un-spaced (it 's -> it's)
957
+ for contraction in tedlium_contractions:
958
+ input_str = input_str.replace(contraction, contraction[1:])
959
+
960
+ # GigaSpeech
961
+ if dataset_name == "speechcolab/gigaspeech":
962
+ for disfluency in gigaspeech_disfluencies:
963
+ input_str = input_str.replace(disfluency, "")
964
+ # convert spelled out punctuation to symbolic form
965
+ for punctuation, replacement in gigaspeech_punctuation.items():
966
+ input_str = input_str.replace(punctuation, replacement)
967
+
968
+ # SWB: hide the path to the private HF dataset
969
+ if "switchboard" in dataset_name:
970
+ for disfluency in swb_disfluencies:
971
+ input_str = input_str.replace(disfluency, "")
972
+ # remove parenthesised text (test data only)
973
+ input_str = re.sub("[\(].*?[\)]", "", input_str)
974
+ for punctuation in swb_punctuations:
975
+ input_str = input_str.replace(punctuation, "")
976
+ # replace anomalous words with their correct transcriptions
977
+ split_str = input_str.split("/")
978
+ if len(split_str) > 1:
979
+ input_str = " ".join(
980
+ [" ".join([" ".join(i.split(" ")[:-1]) for i in split_str])] + [split_str[-1].split(" ")[-1]])
981
+
982
+ # Earnings 22: still figuring out best segmenting method. Thus, dataset name subject to change
983
+ if "earnings22" in dataset_name:
984
+ for disfluency in earnings_disfluencies:
985
+ input_str = input_str.replace(disfluency, "")
986
+
987
+ # SPGISpeech
988
+ if dataset_name == "kensho/spgispeech":
989
+ pass # no error correction necessary
990
+
991
+ # JIWER compliance (for WER/CER calc.)
992
+ # remove multiple spaces
993
+ input_str = re.sub(r"\s\s+", " ", input_str)
994
+ # strip trailing spaces
995
+ input_str = input_str.strip()
996
+
997
+ # Finally, we tokenize the processed text
998
+ batch["labels"] = tokenizer(input_str).input_ids
999
+ batch["labels_length"] = len(batch["labels"])
1000
+ return batch
1001
+
1002
+ vectorized_datasets = raw_datasets.map(
1003
+ prepare_dataset,
1004
+ remove_columns=next(iter(raw_datasets.values())).column_names,
1005
+ num_proc=num_workers,
1006
+ desc="preprocess train dataset",
1007
+ )
1008
+
1009
+ # filter training data with inputs longer than max_input_length
1010
+ def is_audio_in_length_range(length):
1011
+ return length > min_input_length and length < max_input_length
1012
+
1013
+ if training_args.do_train:
1014
+ vectorized_datasets["train"] = vectorized_datasets["train"].filter(
1015
+ is_audio_in_length_range,
1016
+ num_proc=num_workers,
1017
+ input_columns=["input_length"],
1018
+ )
1019
+
1020
+ # filter data with targets shorter than min_target_length or longer than max_target_length
1021
+ def is_labels_in_length_range(length):
1022
+ return length > min_target_length and length < max_target_length
1023
+
1024
+ if training_args.do_train:
1025
+ vectorized_datasets["train"] = vectorized_datasets["train"].filter(
1026
+ is_labels_in_length_range,
1027
+ num_proc=num_workers,
1028
+ input_columns=["labels_length"],
1029
+ )
1030
+
1031
+ # filter data with targets shorter than 2 tokens: <s></s> -> empty sentences
1032
+ def is_labels_greater_than_min(length):
1033
+ return length > 2
1034
+
1035
+ vectorized_datasets = vectorized_datasets.filter(
1036
+ is_labels_greater_than_min,
1037
+ num_proc=num_workers,
1038
+ input_columns=["labels_length"],
1039
+ )
1040
+
1041
+ # for large datasets it is advised to run the preprocessing on a
1042
+ # single machine first with `args.preprocessing_only` since there will mostly likely
1043
+ # be a timeout when running the script in distributed mode.
1044
+ # In a second step `args.preprocessing_only` can then be set to `False` to load the
1045
+ # cached dataset
1046
+ if data_args.preprocessing_only:
1047
+ cache = {k: v.cache_files for k, v in vectorized_datasets.items()}
1048
+ logger.info(f"Data preprocessing finished. Files cached at {cache}.")
1049
+ return
1050
+
1051
+ # 8. Load Metrics
1052
+ wer_metric = load_metric("wer")
1053
+ cer_metric = load_metric("cer")
1054
+
1055
+ def compute_metrics(pred_ids: List[List[int]], label_ids: List[List[int]]):
1056
+ label_ids = (
1057
+ pad_to_max_length(np.array(label_ids, dtype="object"), tokenizer)
1058
+ if pad_target_to_multiple_of
1059
+ else label_ids
1060
+ )
1061
+
1062
+ padded_ids = np.where(np.asarray(label_ids) == -100, tokenizer.pad_token_id, np.asarray(label_ids))
1063
+ # we do not want to group tokens when computing the metrics
1064
+ label_str = tokenizer.batch_decode(padded_ids, skip_special_tokens=True)
1065
+
1066
+ pred_ids = np.array(pred_ids)
1067
+ num_beams = pred_ids.shape[1]
1068
+ # decode on a beam-by-beam basis
1069
+ pred_str = [
1070
+ tokenizer.batch_decode(pred_ids[:, beam, :], skip_special_tokens=True)
1071
+ for beam in reversed(range(num_beams))
1072
+ ]
1073
+ # compute word/character error rate for top beam
1074
+ wer = wer_metric.compute(predictions=pred_str[0], references=label_str)
1075
+ cer = cer_metric.compute(predictions=pred_str[0], references=label_str)
1076
+
1077
+ return {"wer": wer, "cer": cer}, pred_str, label_str
1078
+
1079
+ # 9. Save feature extractor, tokenizer and config
1080
+ feature_extractor.save_pretrained(training_args.output_dir)
1081
+ tokenizer.save_pretrained(training_args.output_dir)
1082
+ config.save_pretrained(training_args.output_dir)
1083
+
1084
+ processor = AutoProcessor.from_pretrained(training_args.output_dir)
1085
+
1086
+ data_collator = FlaxDataCollatorSpeechSeq2SeqWithPadding(
1087
+ processor=processor,
1088
+ decoder_start_token_id=model.config.decoder_start_token_id,
1089
+ input_padding="longest",
1090
+ target_padding="longest",
1091
+ max_target_length=max_target_length,
1092
+ pad_input_to_multiple_of=pad_input_to_multiple_of,
1093
+ pad_target_to_multiple_of=pad_target_to_multiple_of if pad_target_to_multiple_of else max_target_length,
1094
+ )
1095
+
1096
+ # Enable tensorboard only on the master node
1097
+ has_tensorboard = is_tensorboard_available()
1098
+ if has_tensorboard and jax.process_index() == 0:
1099
+ try:
1100
+ from flax.metrics.tensorboard import SummaryWriter
1101
+
1102
+ summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
1103
+ except ImportError as ie:
1104
+ has_tensorboard = False
1105
+ logger.warning(
1106
+ f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
1107
+ )
1108
+ else:
1109
+ logger.warning(
1110
+ "Unable to display metrics through TensorBoard because the package is not installed: "
1111
+ "Please run `pip install tensorboard` to enable."
1112
+ )
1113
+
1114
+ # 10. Handle the repository creation
1115
+ if training_args.push_to_hub:
1116
+ with open(os.path.join(training_args.output_dir, ".gitattributes"), "r+") as f:
1117
+ git_lfs_extensions = f.read()
1118
+ if "*.wandb" not in git_lfs_extensions:
1119
+ f.write("*.wandb filter=lfs diff=lfs merge=lfs -text")
1120
+ if training_args.hub_model_id is None:
1121
+ repo_name = get_full_repo_name(
1122
+ Path(training_args.output_dir).absolute().name, token=training_args.hub_token
1123
+ )
1124
+ else:
1125
+ repo_name = training_args.hub_model_id
1126
+ repo = Repository(training_args.output_dir, clone_from=repo_name)
1127
+
1128
+ # 11. Initialize our training
1129
+ rng = jax.random.PRNGKey(training_args.seed)
1130
+ rng, dropout_rng = jax.random.split(rng)
1131
+
1132
+ # Store some constants
1133
+ max_steps = int(training_args.max_steps)
1134
+ gradient_accumulation_steps = int(training_args.gradient_accumulation_steps)
1135
+ train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
1136
+ batch_size_per_update = train_batch_size * gradient_accumulation_steps
1137
+ per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
1138
+ eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
1139
+ to_dtype = to_bf16 if training_args.mixed_precision else to_fp32
1140
+
1141
+ if training_args.do_train:
1142
+ num_train_samples = len(vectorized_datasets["train"])
1143
+ steps_per_epoch = num_train_samples // batch_size_per_update
1144
+ if max_steps > 0:
1145
+ num_epochs = -(training_args.max_steps // -steps_per_epoch)
1146
+ total_train_steps = max_steps
1147
+ else:
1148
+ num_epochs = int(training_args.num_train_epochs)
1149
+ total_train_steps = steps_per_epoch * num_epochs
1150
+
1151
+ # Create learning rate schedule
1152
+ linear_decay_lr_schedule_fn = create_learning_rate_fn(
1153
+ total_train_steps,
1154
+ training_args.warmup_steps,
1155
+ training_args.learning_rate,
1156
+ )
1157
+
1158
+ # We use Optax's "masking" functionality to not apply weight decay
1159
+ # to bias and LayerNorm scale parameters. decay_mask_fn returns a
1160
+ # mask boolean with the same structure as the parameters.
1161
+ # The mask is True for parameters that should be decayed.
1162
+ # Note that this mask is specifically adapted for FlaxWav2Vec2 and FlaxBart.
1163
+ # For FlaxT5, one should correct the layer norm parameter naming
1164
+ # accordingly - see `run_t5_mlm_flax.py` e.g.
1165
+ def decay_mask_fn(params):
1166
+ flat_params = traverse_util.flatten_dict(params)
1167
+ layer_norm_params = [
1168
+ (name, "scale")
1169
+ for name in ["layer_norm", "self_attn_layer_norm", "layernorm_embedding", "final_layer_norm"]
1170
+ ]
1171
+ flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_params) for path in flat_params}
1172
+ return traverse_util.unflatten_dict(flat_mask)
1173
+
1174
+ if training_args.adafactor:
1175
+ # Create Adafactor optimizer
1176
+ optim = optax.adafactor(
1177
+ learning_rate=linear_decay_lr_schedule_fn,
1178
+ dtype_momentum=jnp.bfloat16 if training_args.mixed_precision else jnp.float32,
1179
+ weight_decay_rate=training_args.weight_decay,
1180
+ weight_decay_mask=decay_mask_fn,
1181
+ )
1182
+ else:
1183
+ # Create AdamW optimizer
1184
+ optim = optax.adamw(
1185
+ learning_rate=linear_decay_lr_schedule_fn,
1186
+ b1=training_args.adam_beta1,
1187
+ b2=training_args.adam_beta2,
1188
+ eps=training_args.adam_epsilon,
1189
+ weight_decay=training_args.weight_decay,
1190
+ mask=decay_mask_fn,
1191
+ )
1192
+ else:
1193
+ num_epochs = 0
1194
+ total_train_steps = 0
1195
+ num_train_samples = 0
1196
+ optim = None
1197
+
1198
+ # Setup train state
1199
+ state = MixedPrecisionTrainState.create(
1200
+ apply_fn=model.__call__,
1201
+ params=model.params,
1202
+ tx=optim,
1203
+ to_dtype=to_dtype,
1204
+ dropout_rng=dropout_rng,
1205
+ max_grad_norm=training_args.max_grad_norm,
1206
+ )
1207
+
1208
+ # Cross entropy loss
1209
+ def loss_fn(logits, labels):
1210
+ vocab_size = logits.shape[-1]
1211
+ # optax onehot always returns a float32 device array, need to downcast if performing mixed precision training
1212
+ onehot_targets = to_dtype(onehot(labels, vocab_size))
1213
+ loss = optax.softmax_cross_entropy(logits, onehot_targets)
1214
+ # ignore padded tokens from loss, i.e. where labels are not set to -100
1215
+ padding = labels >= 0
1216
+ loss = loss * padding
1217
+ loss = loss.sum()
1218
+ num_labels = padding.sum()
1219
+ return loss, num_labels
1220
+
1221
+ # Define gradient update step fn
1222
+ def train_step(state, batch):
1223
+ # only one single rng per grad step, with or without accumulation, as the graph should be identical over one effective training batch
1224
+ dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
1225
+
1226
+ def compute_loss(params, minibatch):
1227
+ labels = minibatch.pop("labels")
1228
+ logits = state.apply_fn(
1229
+ **minibatch,
1230
+ params=params,
1231
+ dropout_rng=dropout_rng,
1232
+ freeze_feature_encoder=model_args.freeze_feature_encoder,
1233
+ train=True,
1234
+ )[0]
1235
+ loss, num_labels = loss_fn(logits, labels)
1236
+ return loss, num_labels
1237
+
1238
+ grad_fn = jax.value_and_grad(compute_loss, has_aux=True)
1239
+
1240
+ if gradient_accumulation_steps == 1:
1241
+ (loss, num_labels), grad = grad_fn(to_dtype(state.params), batch)
1242
+
1243
+ # Custom gradient accumulation
1244
+ else:
1245
+ # add a first dimension over gradient_accumulation_steps for minibatch slices
1246
+ batch = jax.tree_map(
1247
+ lambda x: x.reshape(
1248
+ gradient_accumulation_steps, training_args.per_device_train_batch_size, *x.shape[1::]
1249
+ ),
1250
+ batch,
1251
+ )
1252
+
1253
+ def accum_minibatch_step(accum_grad, minibatch):
1254
+ # compute loss, num labels and grad over minibatch and accumulate
1255
+ (loss, num_labels), grad = grad_fn(to_dtype(state.params), minibatch)
1256
+ return jax.tree_map(jnp.add, accum_grad, grad), (loss, num_labels)
1257
+
1258
+ # create an initial state for accumulating losses, num labels and gradients
1259
+ init_grad = jax.tree_map(jnp.zeros_like, to_dtype(state.params))
1260
+ # loop accum minibatch step over the number of gradient accumulation steps
1261
+ grad, (loss, num_labels) = jax.lax.scan(accum_minibatch_step, init_grad, batch)
1262
+
1263
+ grad = jax.lax.psum(grad, "batch")
1264
+ loss = jax.lax.psum(loss.sum(), "batch")
1265
+ total_samples = jax.lax.psum(num_labels.sum(), "batch")
1266
+ grad = jax.tree_map(lambda g: g / total_samples, grad)
1267
+ loss = jax.tree_map(lambda l: l / total_samples, loss)
1268
+
1269
+ # update state
1270
+ new_state = state.apply_gradients(
1271
+ grads=grad,
1272
+ dropout_rng=new_dropout_rng,
1273
+ to_dtype=to_dtype,
1274
+ )
1275
+
1276
+ # compute gradient norms over all layers, total encoder, total decoder and global for detailed monitoring
1277
+ layer_grad_norm = jax.tree_map(jnp.linalg.norm, grad)
1278
+ logs = {
1279
+ "layer_grad_norm": layer_grad_norm,
1280
+ "encoder_grad_norm": jnp.linalg.norm(jax.tree_util.tree_leaves(layer_grad_norm["encoder"])),
1281
+ "decoder_grad_norm": jnp.linalg.norm(jax.tree_util.tree_leaves(layer_grad_norm["decoder"])),
1282
+ }
1283
+ logs["grad_norm"] = jnp.linalg.norm([logs["encoder_grad_norm"], logs["decoder_grad_norm"]])
1284
+
1285
+ # compute parameter norms over all layers, total encoder, total decoder and global for detailed monitoring
1286
+ layer_param_norm = jax.tree_map(jnp.linalg.norm, new_state.params)
1287
+ logs["layer_param_norm"] = layer_param_norm
1288
+ logs["encoder_param_norm"] = jnp.linalg.norm(jax.tree_util.tree_leaves(layer_param_norm["encoder"]))
1289
+ logs["decoder_param_norm"] = jnp.linalg.norm(jax.tree_util.tree_leaves(layer_param_norm["decoder"]))
1290
+ logs["param_norm"] = jnp.linalg.norm([logs["encoder_param_norm"], logs["decoder_param_norm"]])
1291
+
1292
+ metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
1293
+ metrics.update(logs)
1294
+
1295
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
1296
+ # metrics = to_fp32(metrics)
1297
+
1298
+ return new_state, metrics
1299
+
1300
+ # Define eval fn
1301
+ def eval_step(params, batch):
1302
+ labels = batch.pop("labels")
1303
+ logits = model(**batch, params=params, train=False)[0]
1304
+ loss, num_labels = loss_fn(logits, labels)
1305
+
1306
+ total_samples = jax.lax.psum(num_labels, "batch")
1307
+ loss = jax.lax.psum(loss, "batch")
1308
+ loss = jax.tree_map(lambda l: l / total_samples, loss)
1309
+
1310
+ # summarize metrics
1311
+ metrics = {"loss": loss}
1312
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
1313
+ # metrics = to_fp32(metrics)
1314
+ return metrics
1315
+
1316
+ # Define generation function
1317
+ gen_kwargs = {
1318
+ "max_length": training_args.generation_max_length,
1319
+ "num_beams": training_args.generation_num_beams,
1320
+ "length_penalty": training_args.generation_length_penalty,
1321
+ }
1322
+ final_gen_kwargs = {
1323
+ "max_length": training_args.final_generation_max_length,
1324
+ "num_beams": training_args.final_generation_num_beams,
1325
+ "length_penalty": training_args.generation_length_penalty,
1326
+ }
1327
+
1328
+ def generate_step(params, batch):
1329
+ model.params = params
1330
+ output_ids = model.generate(batch["inputs"], **gen_kwargs)
1331
+ return output_ids.sequences
1332
+
1333
+ def final_generate_step(params, batch):
1334
+ model.params = params
1335
+ output_ids = model.generate(batch["inputs"], **final_gen_kwargs)
1336
+ return output_ids.sequences
1337
+
1338
+ # Create parallel version of the train and eval step
1339
+ if training_args.do_train:
1340
+ p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
1341
+
1342
+ if training_args.do_eval or training_args.do_predict:
1343
+ p_eval_step = jax.pmap(eval_step, "batch")
1344
+
1345
+ if training_args.predict_with_generate:
1346
+ p_generate_step = jax.pmap(generate_step, "batch")
1347
+ p_final_generate_step = jax.pmap(final_generate_step, "batch")
1348
+
1349
+ def run_evaluation(step, final_step=False):
1350
+ if training_args.do_eval:
1351
+ # ======================== Evaluating ==============================
1352
+ eval_metrics = []
1353
+ eval_preds = []
1354
+ eval_ids = []
1355
+ eval_labels = []
1356
+
1357
+ # Generate eval set by sequentially sampling indices from the eval dataset and grouping by length
1358
+ eval_samples_idx = get_grouped_indices(vectorized_datasets["eval"], eval_batch_size)
1359
+ eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last_batch=False)
1360
+
1361
+ for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
1362
+ samples = [vectorized_datasets["eval"][int(idx)] for idx in batch_idx]
1363
+ batch = data_collator(samples)
1364
+ eval_ids.extend(batch.pop("input_ids"))
1365
+ labels = batch["labels"]
1366
+
1367
+ metrics = pad_shard_unpad(p_eval_step, static_return=True)(state.params, batch.data, min_device_batch=per_device_eval_batch_size)
1368
+ eval_metrics.append(metrics)
1369
+
1370
+ # generation
1371
+ if training_args.predict_with_generate:
1372
+ if not final_step:
1373
+ generated_ids = pad_shard_unpad(p_generate_step)(state.params, batch.data, min_device_batch=per_device_eval_batch_size)
1374
+ eval_preds.extend(
1375
+ jax.device_get(
1376
+ generated_ids.reshape(-1, gen_kwargs["num_beams"], gen_kwargs["max_length"])
1377
+ )
1378
+ )
1379
+ else:
1380
+ generated_ids = pad_shard_unpad(p_final_generate_step)(state.params, batch.data, min_device_batch=per_device_eval_batch_size)
1381
+ eval_preds.extend(
1382
+ jax.device_get(
1383
+ generated_ids.reshape(
1384
+ -1, final_gen_kwargs["num_beams"], final_gen_kwargs["max_length"]
1385
+ )
1386
+ )
1387
+ )
1388
+ eval_labels.extend(labels)
1389
+
1390
+ # normalize eval metrics
1391
+ eval_metrics = get_metrics(eval_metrics)
1392
+ eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
1393
+ eval_metrics = to_fp32(eval_metrics)
1394
+
1395
+ # compute error rate metric and get predicted string (for debugging)
1396
+ error_rate_desc = ""
1397
+ pred_str = []
1398
+ label_str = []
1399
+ if training_args.predict_with_generate:
1400
+ error_rate_metric, pred_str, label_str = compute_metrics(eval_preds, eval_labels)
1401
+ eval_metrics.update(error_rate_metric)
1402
+ error_rate_desc = " ".join([f"Eval {key}: {value} |" for key, value in error_rate_metric.items()])
1403
+
1404
+ # Print metrics and update progress bar
1405
+ desc = f"Step... ({step}/{total_train_steps} | Eval Loss: {eval_metrics['loss']} | {error_rate_desc})"
1406
+ epochs.write(desc)
1407
+ epochs.desc = desc
1408
+
1409
+ # Save metrics
1410
+ write_wandb_log(eval_metrics, step, prefix="eval")
1411
+ write_wandb_pred(
1412
+ pred_str,
1413
+ label_str,
1414
+ eval_ids,
1415
+ step,
1416
+ top_ids=vectorized_datasets["eval"]["input_id"] if data_args.log_first_ids else None,
1417
+ final_step=final_step,
1418
+ )
1419
+ # if has_tensorboard and jax.process_index() == 0:
1420
+ # write_eval_metric(summary_writer, eval_metrics, step, pred_str=pred_str)
1421
+
1422
+ def save_checkpoint(step):
1423
+ # save and push checkpoint to the hub
1424
+ if jax.process_index() == 0:
1425
+ params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
1426
+ model.save_pretrained(training_args.output_dir, params=params)
1427
+ tokenizer.save_pretrained(training_args.output_dir)
1428
+ if training_args.push_to_hub:
1429
+ repo.push_to_hub(commit_message=f"{wandb.run.id}: saving weights and logs of step {int(step / 1000)}k", blocking=False)
1430
+
1431
+ # Replicate the train state on each device
1432
+ state = state.replicate()
1433
+
1434
+ logger.info("***** Running training *****")
1435
+ logger.info(f" Num examples = {num_train_samples}")
1436
+ logger.info(f" Num Epochs = {num_epochs}")
1437
+ logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
1438
+ logger.info(f" Num gradient accumulation steps = {gradient_accumulation_steps}")
1439
+ logger.info(f" Total train batch size (w. parallel & distributed) = {batch_size_per_update}")
1440
+ logger.info(f" Total optimization steps = {total_train_steps}")
1441
+ logger.info(f" Gradient checkpointing: {config.encoder.gradient_checkpointing}")
1442
+ logger.info(f" Use scan: {config.encoder.use_scan}")
1443
+ logger.info(f" Fuse matmuls: {config.encoder.fuse_matmuls}")
1444
+
1445
+ train_time = cur_step = 0
1446
+ epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
1447
+ for epoch in epochs:
1448
+ if training_args.do_train:
1449
+ # ======================== Training ================================
1450
+ train_start = time.time()
1451
+
1452
+ # Create sampling rng
1453
+ rng, input_rng = jax.random.split(rng)
1454
+
1455
+ # Generate an epoch by randomly shuffling sampling indices from the train dataset and grouping by length
1456
+ train_samples_idx = get_grouped_indices(vectorized_datasets["train"], batch_size_per_update, input_rng)
1457
+ train_batch_idx = generate_batch_splits(train_samples_idx, batch_size_per_update, drop_last_batch=True)
1458
+
1459
+ # Gather the indices for creating the batch and do a training step
1460
+ for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1), 1):
1461
+ samples = [vectorized_datasets["train"][int(idx)] for idx in batch_idx]
1462
+ batch = data_collator(samples)
1463
+ batch.pop("input_ids")
1464
+ batch = shard(batch.data)
1465
+ state, train_metric = p_train_step(state, batch)
1466
+
1467
+ cur_step = epoch * (num_train_samples // batch_size_per_update) + step
1468
+
1469
+ if cur_step % training_args.logging_steps == 0:
1470
+ # Save metrics
1471
+ train_metric = unreplicate(train_metric)
1472
+ train_time += time.time() - train_start
1473
+ # need to upcast all device arrays to fp32 for wandb logging (jnp.bfloat16 not supported) -> do this here OR in train_step
1474
+ write_wandb_log(to_fp32(train_metric), cur_step, prefix="train")
1475
+ # we won't log to tensorboard for now (it is fiddly logging param and grad norms on a layer-by-layer basis)
1476
+ # if has_tensorboard and jax.process_index() == 0:
1477
+ # write_train_metric(summary_writer, train_metrics, train_time, cur_step)
1478
+
1479
+ epochs.write(
1480
+ f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']}, Gradient Norm: {train_metric['grad_norm']})"
1481
+ )
1482
+
1483
+ if cur_step % total_train_steps == 0:
1484
+ break
1485
+
1486
+ if training_args.eval_steps and cur_step % training_args.eval_steps == 0:
1487
+ # run beam search at each eval step
1488
+ run_evaluation(cur_step, final_step=False)
1489
+
1490
+ if cur_step % training_args.save_steps == 0:
1491
+ save_checkpoint(cur_step)
1492
+
1493
+ if training_args.eval_steps == 0 and (epoch + 1) != num_epochs:
1494
+ # run evaluation at the end of the epoch if eval steps are not specified
1495
+ run_evaluation(cur_step, final_step=False)
1496
+ save_checkpoint(cur_step)
1497
+
1498
+ if training_args.do_train:
1499
+ save_checkpoint(cur_step)
1500
+
1501
+ cur_step = max_steps if max_steps > 0 else cur_step # set step to max steps so that eval happens in alignment with training
1502
+
1503
+ if training_args.do_eval:
1504
+ run_evaluation(cur_step, final_step=True)
1505
+
1506
+ # TODO: collapse 'do_predict' into the run_evaluation function
1507
+ if training_args.do_predict:
1508
+ # ======================== Prediction ==============================
1509
+ for split in test_split:
1510
+ pred_metrics = []
1511
+ pred_generations = []
1512
+ pred_ids = []
1513
+ pred_labels = []
1514
+
1515
+ # Generate eval set by sequentially sampling indices from the eval dataset and grouping by length
1516
+ pred_samples_idx = get_grouped_indices(vectorized_datasets[split], eval_batch_size)
1517
+ pred_batch_idx = generate_batch_splits(pred_samples_idx, eval_batch_size, drop_last_batch=False)
1518
+
1519
+ for i, batch_idx in enumerate(tqdm(pred_batch_idx, desc=f"Predicting {split}...", position=2)):
1520
+ samples = [vectorized_datasets[split][int(idx)] for idx in batch_idx]
1521
+ batch = data_collator(samples)
1522
+ pred_ids.extend(batch.pop("input_ids"))
1523
+ labels = batch["labels"]
1524
+
1525
+ metrics = pad_shard_unpad(p_eval_step, static_return=True)(state.params, batch.data,
1526
+ min_device_batch=per_device_eval_batch_size)
1527
+ pred_metrics.append(metrics)
1528
+
1529
+ # generation
1530
+ if training_args.predict_with_generate:
1531
+ generated_ids = pad_shard_unpad(p_final_generate_step)(state.params, batch.data, min_device_batch=per_device_eval_batch_size)
1532
+ pred_generations.extend(
1533
+ jax.device_get(
1534
+ generated_ids.reshape(-1, final_gen_kwargs["num_beams"], final_gen_kwargs["max_length"])
1535
+ )
1536
+ )
1537
+ pred_labels.extend(labels)
1538
+
1539
+ # normalize eval metrics
1540
+ pred_metrics = get_metrics(pred_metrics)
1541
+ pred_metrics = jax.tree_map(jnp.mean, pred_metrics)
1542
+ pred_metrics = to_fp32(pred_metrics)
1543
+
1544
+ # compute error rate metric and get predicted string (for debugging)
1545
+ error_rate_desc = ""
1546
+ pred_str = []
1547
+ label_str = []
1548
+ if training_args.predict_with_generate:
1549
+ error_rate_metric, pred_str, label_str = compute_metrics(pred_generations, pred_labels)
1550
+ pred_metrics.update(error_rate_metric)
1551
+ error_rate_desc = " ".join([f"{split} {key}: {value} |" for key, value in error_rate_metric.items()])
1552
+
1553
+ # Print metrics and update progress bar
1554
+ desc = f"Step... ({cur_step}/{total_train_steps} | {split} Loss: {pred_metrics['loss']} | {error_rate_desc})"
1555
+ epochs.write(desc)
1556
+ epochs.desc = desc
1557
+
1558
+ # Save metrics
1559
+ write_wandb_log(pred_metrics, cur_step, prefix=split)
1560
+ write_wandb_pred(
1561
+ pred_str,
1562
+ label_str,
1563
+ pred_ids,
1564
+ cur_step,
1565
+ prefix=split,
1566
+ top_ids=vectorized_datasets[split]["input_id"] if data_args.log_first_ids else None,
1567
+ final_step=True,
1568
+ )
1569
+
1570
+
1571
+ if __name__ == "__main__":
1572
+ main()
run_librispeech.sh ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ python run_flax_speech_recognition_seq2seq.py \
3
+ --dataset_name="librispeech_asr" \
4
+ --model_name_or_path="sanchit-gandhi/flax-wav2vec2-2-bart-large-scan" \
5
+ --dataset_config_name="all" \
6
+ --train_split_name="train.clean.100+train.clean.360+train.other.500" \
7
+ --eval_split_name="validation.clean" \
8
+ --test_split_name="validation.other+test.clean+test.other" \
9
+ --text_column_name="text" \
10
+ --id_column_name="id" \
11
+ --output_dir="./" \
12
+ --wandb_project="librispeech_960h" \
13
+ --wandb_name="flax-wav2vec2-2-bart-large-ls-960h-black-box" \
14
+ --dataset_cache_dir="/home/sanchitgandhi/cache/huggingface/datasets" \
15
+ --per_device_train_batch_size="8" \
16
+ --per_device_eval_batch_size="4" \
17
+ --learning_rate="1e-4" \
18
+ --warmup_steps="500" \
19
+ --logging_steps="25" \
20
+ --max_steps="50000" \
21
+ --eval_steps="10000" \
22
+ --save_steps="10000" \
23
+ --generation_max_length="200" \
24
+ --generation_num_beams="5" \
25
+ --generation_length_penalty="1.2" \
26
+ --hidden_dropout="0.2" \
27
+ --activation_dropout="0.2" \
28
+ --feat_proj_dropout="0.2" \
29
+ --overwrite_output_dir \
30
+ --gradient_checkpointing \
31
+ --freeze_feature_encoder \
32
+ --predict_with_generate \
33
+ --do_lower_case \
34
+ --do_eval \
35
+ --do_train \
36
+ --do_predict \
37
+ --push_to_hub \
38
+ --use_auth_token
39
+
special_tokens_map.json ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": "<s>",
3
+ "cls_token": "<s>",
4
+ "eos_token": "</s>",
5
+ "mask_token": {
6
+ "content": "<mask>",
7
+ "lstrip": true,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false
11
+ },
12
+ "pad_token": "<pad>",
13
+ "sep_token": "</s>",
14
+ "unk_token": "<unk>"
15
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": false,
3
+ "bos_token": "<s>",
4
+ "cls_token": "<s>",
5
+ "eos_token": "</s>",
6
+ "errors": "replace",
7
+ "mask_token": "<mask>",
8
+ "model_max_length": 1024,
9
+ "name_or_path": "sanchit-gandhi/flax-wav2vec2-2-bart-large-scan",
10
+ "pad_token": "<pad>",
11
+ "sep_token": "</s>",
12
+ "special_tokens_map_file": null,
13
+ "tokenizer_class": "BartTokenizer",
14
+ "trim_offsets": true,
15
+ "unk_token": "<unk>"
16
+ }
vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
wandb/debug-internal.log ADDED
@@ -0,0 +1 @@
 
 
1
+ run-20220828_085247-2hx8pk65/logs/debug-internal.log
wandb/debug.log ADDED
@@ -0,0 +1 @@
 
 
1
+ run-20220828_085247-2hx8pk65/logs/debug.log
wandb/latest-run ADDED
@@ -0,0 +1 @@
 
 
1
+ run-20220828_085247-2hx8pk65
wandb/run-20220828_084407-nbdgecc9/files/config.yaml ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ wandb_version: 1
2
+
3
+ _wandb:
4
+ desc: null
5
+ value:
6
+ cli_version: 0.12.15
7
+ framework: huggingface
8
+ huggingface_version: 4.21.0.dev0
9
+ is_jupyter_run: false
10
+ is_kaggle_kernel: false
11
+ python_version: 3.8.10
12
+ start_time: 1661676247
13
+ t:
14
+ 1:
15
+ - 1
16
+ - 11
17
+ - 12
18
+ - 45
19
+ - 49
20
+ - 51
21
+ - 55
22
+ 2:
23
+ - 1
24
+ - 11
25
+ - 12
26
+ - 45
27
+ - 49
28
+ - 51
29
+ - 55
30
+ 3:
31
+ - 13
32
+ 4: 3.8.10
33
+ 5: 0.12.15
34
+ 6: 4.21.0.dev0
35
+ 8:
36
+ - 5
wandb/run-20220828_084407-nbdgecc9/files/output.log ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ INFO:__main__:Training/evaluation parameters FlaxSeq2SeqTrainingArguments(
2
+ _n_gpu=-1,
3
+ adafactor=False,
4
+ adam_beta1=0.9,
5
+ adam_beta2=0.999,
6
+ adam_epsilon=1e-08,
7
+ auto_find_batch_size=False,
8
+ bf16=False,
9
+ bf16_full_eval=False,
10
+ data_seed=None,
11
+ dataloader_drop_last=False,
12
+ dataloader_num_workers=0,
13
+ dataloader_pin_memory=True,
14
+ ddp_bucket_cap_mb=None,
15
+ ddp_find_unused_parameters=None,
16
+ debug=,
17
+ deepspeed=None,
18
+ disable_tqdm=None,
19
+ do_eval=True,
20
+ do_predict=True,
21
+ do_train=True,
22
+ eval_accumulation_steps=None,
23
+ eval_delay=0,
24
+ eval_steps=10000,
25
+ evaluation_strategy=no,
26
+ final_generation_max_length=200,
27
+ final_generation_num_beams=5,
28
+ fp16=False,
29
+ fp16_backend=auto,
30
+ fp16_full_eval=False,
31
+ fp16_opt_level=O1,
32
+ fsdp=,
33
+ fsdp_min_num_params=0,
34
+ fsdp_transformer_layer_cls_to_wrap=None,
35
+ full_determinism=False,
36
+ generation_length_penalty=1.2,
37
+ generation_max_length=200,
38
+ generation_num_beams=5,
39
+ gradient_accumulation_steps=1,
40
+ gradient_checkpointing=True,
41
+ greater_is_better=None,
42
+ group_by_length=False,
43
+ half_precision_backend=auto,
44
+ hub_model_id=None,
45
+ hub_private_repo=False,
46
+ hub_strategy=every_save,
47
+ hub_token=<HUB_TOKEN>,
48
+ ignore_data_skip=False,
49
+ include_inputs_for_metrics=False,
50
+ jit_mode_eval=False,
51
+ label_names=None,
52
+ label_smoothing_factor=0.0,
53
+ learning_rate=0.0001,
54
+ length_column_name=length,
55
+ load_best_model_at_end=False,
56
+ local_rank=-1,
57
+ log_level=passive,
58
+ log_level_replica=passive,
59
+ log_on_each_node=True,
60
+ logging_dir=None,
61
+ logging_first_step=False,
62
+ logging_nan_inf_filter=True,
63
+ logging_steps=25,
64
+ logging_strategy=steps,
65
+ lr_scheduler_type=linear,
66
+ matmul_precision=default,
67
+ max_grad_norm=1.0,
68
+ max_steps=50000,
69
+ metric_for_best_model=None,
70
+ mp_parameters=,
71
+ no_cuda=False,
72
+ num_train_epochs=3.0,
73
+ optim=adamw_hf,
74
+ output_dir=./,
75
+ overwrite_output_dir=True,
76
+ past_index=-1,
77
+ per_device_eval_batch_size=4,
78
+ per_device_train_batch_size=8,
79
+ precision=full,
80
+ predict_with_generate=True,
81
+ prediction_loss_only=False,
82
+ push_to_hub=True,
83
+ push_to_hub_model_id=None,
84
+ push_to_hub_organization=None,
85
+ push_to_hub_token=<PUSH_TO_HUB_TOKEN>,
86
+ ray_scope=last,
87
+ remove_unused_columns=True,
88
+ report_to=None,
89
+ resume_from_checkpoint=None,
90
+ run_name=None,
91
+ save_on_each_node=False,
92
+ save_steps=10000,
93
+ save_strategy=steps,
94
+ save_total_limit=None,
95
+ seed=42,
96
+ sharded_ddp=,
97
+ skip_memory_metrics=True,
98
+ sortish_sampler=False,
99
+ tf32=None,
100
+ torchdynamo=None,
101
+ tpu_metrics_debug=False,
102
+ tpu_num_cores=None,
103
+ use_ipex=False,
104
+ use_legacy_prediction_loop=False,
105
+ warmup_ratio=0.0,
106
+ warmup_steps=500,
107
+ weight_decay=0.0,
108
+ xpu_backend=None,
109
+ )
110
+ INFO:__main__:JAX devices: 8, matmul precision: default
wandb/run-20220828_084407-nbdgecc9/files/requirements.txt ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ absl-py==1.0.0
2
+ aiohttp==3.8.1
3
+ aiosignal==1.2.0
4
+ anyio==3.5.0
5
+ appdirs==1.4.4
6
+ argon2-cffi-bindings==21.2.0
7
+ argon2-cffi==21.3.0
8
+ asttokens==2.0.5
9
+ async-timeout==4.0.2
10
+ attrs==21.4.0
11
+ audioread==2.1.9
12
+ babel==2.10.1
13
+ backcall==0.2.0
14
+ beautifulsoup4==4.11.1
15
+ bleach==5.0.0
16
+ certifi==2021.10.8
17
+ cffi==1.15.0
18
+ charset-normalizer==2.0.12
19
+ chex==0.1.3
20
+ click==8.1.3
21
+ colorama==0.4.5
22
+ commonmark==0.9.1
23
+ cycler==0.11.0
24
+ datasets==2.4.1.dev0
25
+ debugpy==1.6.0
26
+ decorator==5.1.1
27
+ defusedxml==0.7.1
28
+ dill==0.3.4
29
+ dm-tree==0.1.7
30
+ docker-pycreds==0.4.0
31
+ entrypoints==0.4
32
+ etils==0.6.0
33
+ executing==0.8.3
34
+ fastjsonschema==2.15.3
35
+ filelock==3.6.0
36
+ flatbuffers==2.0
37
+ flax==0.5.3
38
+ fonttools==4.33.3
39
+ frozenlist==1.3.0
40
+ fsspec==2022.3.0
41
+ gitdb==4.0.9
42
+ gitpython==3.1.27
43
+ huggingface-hub==0.5.1
44
+ idna==3.3
45
+ ijson==3.1.4
46
+ importlib-metadata==4.11.3
47
+ importlib-resources==5.7.1
48
+ iniconfig==1.1.1
49
+ ipdb==0.13.9
50
+ ipykernel==6.13.0
51
+ ipython-genutils==0.2.0
52
+ ipython==8.3.0
53
+ jax==0.3.15
54
+ jaxlib==0.3.15
55
+ jedi==0.18.1
56
+ jinja2==3.1.2
57
+ jiwer==2.3.0
58
+ joblib==1.1.0
59
+ json5==0.9.6
60
+ jsonschema==4.4.0
61
+ jupyter-client==7.3.0
62
+ jupyter-core==4.10.0
63
+ jupyter-server==1.17.0
64
+ jupyterlab-pygments==0.2.2
65
+ jupyterlab-server==2.13.0
66
+ jupyterlab==3.4.0
67
+ kiwisolver==1.4.2
68
+ librosa==0.9.1
69
+ libtpu-nightly==0.1.dev20220722
70
+ llvmlite==0.38.0
71
+ markupsafe==2.1.1
72
+ matplotlib-inline==0.1.3
73
+ matplotlib==3.5.1
74
+ mistune==0.8.4
75
+ msgpack==1.0.3
76
+ multidict==6.0.2
77
+ multiprocess==0.70.12.2
78
+ nbclassic==0.3.7
79
+ nbclient==0.6.2
80
+ nbconvert==6.5.0
81
+ nbformat==5.4.0
82
+ nest-asyncio==1.5.5
83
+ nltk==3.7
84
+ notebook-shim==0.1.0
85
+ notebook==6.4.11
86
+ numba==0.55.1
87
+ numpy==1.21.0
88
+ opt-einsum==3.3.0
89
+ optax==0.1.2
90
+ packaging==21.3
91
+ pandas==1.4.2
92
+ pandocfilters==1.5.0
93
+ parso==0.8.3
94
+ pathtools==0.1.2
95
+ pexpect==4.8.0
96
+ pickleshare==0.7.5
97
+ pillow==9.1.0
98
+ pip==20.0.2
99
+ pkg-resources==0.0.0
100
+ pluggy==1.0.0
101
+ pooch==1.6.0
102
+ prometheus-client==0.14.1
103
+ promise==2.3
104
+ prompt-toolkit==3.0.29
105
+ protobuf==3.20.1
106
+ psutil==5.9.0
107
+ ptyprocess==0.7.0
108
+ pure-eval==0.2.2
109
+ py==1.11.0
110
+ pyarrow==7.0.0
111
+ pycparser==2.21
112
+ pycryptodome==3.14.1
113
+ pygments==2.12.0
114
+ pyparsing==3.0.8
115
+ pyrsistent==0.18.1
116
+ pytest==7.1.2
117
+ python-dateutil==2.8.2
118
+ python-levenshtein==0.12.2
119
+ pytz==2022.1
120
+ pyyaml==6.0
121
+ pyzmq==22.3.0
122
+ regex==2022.4.24
123
+ requests==2.27.1
124
+ resampy==0.2.2
125
+ responses==0.18.0
126
+ rich==11.1.0
127
+ rouge-score==0.1.2
128
+ sacremoses==0.0.49
129
+ scikit-learn==1.0.2
130
+ scipy==1.8.0
131
+ send2trash==1.8.0
132
+ sentry-sdk==1.5.10
133
+ seqeval==1.2.2
134
+ setproctitle==1.2.3
135
+ setuptools==44.0.0
136
+ shortuuid==1.0.8
137
+ six==1.16.0
138
+ smmap==5.0.0
139
+ sniffio==1.2.0
140
+ soundfile==0.10.3.post1
141
+ soupsieve==2.3.2.post1
142
+ speechcolab==0.0.6a0
143
+ stack-data==0.2.0
144
+ tensorstore==0.1.21
145
+ terminado==0.13.3
146
+ threadpoolctl==3.1.0
147
+ tinycss2==1.1.1
148
+ tokenizers==0.12.1
149
+ toml==0.10.2
150
+ tomli==2.0.1
151
+ toolz==0.11.2
152
+ torch==1.11.0+cpu
153
+ torchaudio==0.11.0+cpu
154
+ tornado==6.1
155
+ tqdm==4.64.0
156
+ traitlets==5.1.1
157
+ transformers==4.21.0.dev0
158
+ typing-extensions==4.2.0
159
+ urllib3==1.26.9
160
+ wandb==0.12.15
161
+ wcwidth==0.2.5
162
+ webencodings==0.5.1
163
+ websocket-client==1.3.2
164
+ wheel==0.37.1
165
+ xxhash==3.0.0
166
+ yarl==1.7.2
167
+ zipp==3.8.0
wandb/run-20220828_084407-nbdgecc9/files/wandb-metadata.json ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "os": "Linux-5.11.0-1028-gcp-x86_64-with-glibc2.29",
3
+ "python": "3.8.10",
4
+ "heartbeatAt": "2022-08-28T08:44:08.435675",
5
+ "startedAt": "2022-08-28T08:44:07.234991",
6
+ "docker": null,
7
+ "cpu_count": 96,
8
+ "cuda": null,
9
+ "args": [
10
+ "--dataset_name=librispeech_asr",
11
+ "--model_name_or_path=sanchit-gandhi/flax-wav2vec2-2-bart-large-scan",
12
+ "--dataset_config_name=all",
13
+ "--train_split_name=train.clean.100+train.clean.360+train.other.500",
14
+ "--eval_split_name=validation.clean",
15
+ "--test_split_name=validation.other+test.clean+test.other",
16
+ "--text_column_name=text",
17
+ "--id_column_name=id",
18
+ "--output_dir=./",
19
+ "--wandb_project=librispeech_960h",
20
+ "--wandb_name=flax-wav2vec2-2-bart-large-ls-960h-black-box",
21
+ "--dataset_cache_dir=/home/sanchitgandhi/cache/huggingface/datasets",
22
+ "--per_device_train_batch_size=8",
23
+ "--per_device_eval_batch_size=4",
24
+ "--learning_rate=1e-4",
25
+ "--warmup_steps=500",
26
+ "--logging_steps=25",
27
+ "--max_steps=50000",
28
+ "--eval_steps=10000",
29
+ "--save_steps=10000",
30
+ "--generation_max_length=200",
31
+ "--generation_num_beams=5",
32
+ "--generation_length_penalty=1.2",
33
+ "--hidden_dropout=0.2",
34
+ "--activation_dropout=0.2",
35
+ "--feat_proj_dropout=0.2",
36
+ "--overwrite_output_dir",
37
+ "--gradient_checkpointing",
38
+ "--freeze_feature_encoder",
39
+ "--predict_with_generate",
40
+ "--do_lower_case",
41
+ "--do_eval",
42
+ "--do_train",
43
+ "--do_predict",
44
+ "--push_to_hub",
45
+ "--use_auth_token"
46
+ ],
47
+ "state": "running",
48
+ "program": "run_flax_speech_recognition_seq2seq.py",
49
+ "codePath": "run_flax_speech_recognition_seq2seq.py",
50
+ "git": {
51
+ "remote": "https://huggingface.co/sanchit-gandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box",
52
+ "commit": "140399a622e2a82685fa4b9727f3d970b8bef9e0"
53
+ },
54
+ "email": "sanchit@huggingface.co",
55
+ "root": "/home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box",
56
+ "host": "t1v-n-5966b949-w-0",
57
+ "username": "sanchitgandhi",
58
+ "executable": "/home/sanchitgandhi/hf/bin/python"
59
+ }
wandb/run-20220828_084407-nbdgecc9/files/wandb-summary.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"_wandb": {"runtime": 3}}
wandb/run-20220828_084407-nbdgecc9/logs/debug-internal.log ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 2022-08-28 08:44:08,160 INFO MainThread:52894 [internal.py:wandb_internal():90] W&B internal server running at pid: 52894, started at: 2022-08-28 08:44:08.159804
2
+ 2022-08-28 08:44:08,162 INFO WriterThread:52894 [datastore.py:open_for_write():75] open: /home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/wandb/run-20220828_084407-nbdgecc9/run-nbdgecc9.wandb
3
+ 2022-08-28 08:44:08,163 DEBUG SenderThread:52894 [sender.py:send():232] send: header
4
+ 2022-08-28 08:44:08,163 DEBUG SenderThread:52894 [sender.py:send():232] send: run
5
+ 2022-08-28 08:44:08,326 DEBUG HandlerThread:52894 [handler.py:handle_request():141] handle_request: check_version
6
+ 2022-08-28 08:44:08,390 INFO SenderThread:52894 [dir_watcher.py:__init__():166] watching files in: /home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/wandb/run-20220828_084407-nbdgecc9/files
7
+ 2022-08-28 08:44:08,390 INFO SenderThread:52894 [sender.py:_start_run_threads():811] run started: nbdgecc9 with start time 1661676247
8
+ 2022-08-28 08:44:08,390 DEBUG SenderThread:52894 [sender.py:send():232] send: summary
9
+ 2022-08-28 08:44:08,391 INFO SenderThread:52894 [sender.py:_save_file():946] saving file wandb-summary.json with policy end
10
+ 2022-08-28 08:44:08,391 DEBUG SenderThread:52894 [sender.py:send_request():246] send_request: check_version
11
+ 2022-08-28 08:44:08,434 DEBUG HandlerThread:52894 [handler.py:handle_request():141] handle_request: run_start
12
+ 2022-08-28 08:44:08,435 DEBUG HandlerThread:52894 [meta.py:__init__():35] meta init
13
+ 2022-08-28 08:44:08,435 DEBUG HandlerThread:52894 [meta.py:__init__():49] meta init done
14
+ 2022-08-28 08:44:08,435 DEBUG HandlerThread:52894 [meta.py:probe():209] probe
15
+ 2022-08-28 08:44:08,436 DEBUG HandlerThread:52894 [meta.py:_setup_git():199] setup git
16
+ 2022-08-28 08:44:08,456 DEBUG HandlerThread:52894 [meta.py:_setup_git():206] setup git done
17
+ 2022-08-28 08:44:08,456 DEBUG HandlerThread:52894 [meta.py:_save_pip():53] save pip
18
+ 2022-08-28 08:44:08,456 DEBUG HandlerThread:52894 [meta.py:_save_pip():67] save pip done
19
+ 2022-08-28 08:44:08,456 DEBUG HandlerThread:52894 [meta.py:probe():247] probe done
20
+ 2022-08-28 08:44:08,480 DEBUG SenderThread:52894 [sender.py:send():232] send: files
21
+ 2022-08-28 08:44:08,480 INFO SenderThread:52894 [sender.py:_save_file():946] saving file wandb-metadata.json with policy now
22
+ 2022-08-28 08:44:08,485 DEBUG HandlerThread:52894 [handler.py:handle_request():141] handle_request: stop_status
23
+ 2022-08-28 08:44:08,485 DEBUG SenderThread:52894 [sender.py:send_request():246] send_request: stop_status
24
+ 2022-08-28 08:44:08,623 DEBUG SenderThread:52894 [sender.py:send():232] send: telemetry
25
+ 2022-08-28 08:44:08,935 INFO Thread-11 :52894 [upload_job.py:push():137] Uploaded file /tmp/tmpos_hhp45wandb/3f0zop6c-wandb-metadata.json
26
+ 2022-08-28 08:44:09,392 INFO Thread-7 :52894 [dir_watcher.py:_on_file_created():213] file/dir created: /home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/wandb/run-20220828_084407-nbdgecc9/files/output.log
27
+ 2022-08-28 08:44:09,392 INFO Thread-7 :52894 [dir_watcher.py:_on_file_created():213] file/dir created: /home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/wandb/run-20220828_084407-nbdgecc9/files/requirements.txt
28
+ 2022-08-28 08:44:09,392 INFO Thread-7 :52894 [dir_watcher.py:_on_file_created():213] file/dir created: /home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/wandb/run-20220828_084407-nbdgecc9/files/wandb-summary.json
29
+ 2022-08-28 08:44:09,393 INFO Thread-7 :52894 [dir_watcher.py:_on_file_created():213] file/dir created: /home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/wandb/run-20220828_084407-nbdgecc9/files/wandb-metadata.json
30
+ 2022-08-28 08:44:09,690 DEBUG SenderThread:52894 [sender.py:send():232] send: telemetry
31
+ 2022-08-28 08:44:11,392 INFO Thread-7 :52894 [dir_watcher.py:_on_file_modified():226] file/dir modified: /home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/wandb/run-20220828_084407-nbdgecc9/files/output.log
32
+ 2022-08-28 08:44:12,001 DEBUG HandlerThread:52894 [handler.py:handle_request():141] handle_request: poll_exit
33
+ 2022-08-28 08:44:12,002 DEBUG SenderThread:52894 [sender.py:send():232] send: exit
34
+ 2022-08-28 08:44:12,002 INFO SenderThread:52894 [sender.py:send_exit():368] handling exit code: 1
35
+ 2022-08-28 08:44:12,002 INFO SenderThread:52894 [sender.py:send_exit():370] handling runtime: 3
36
+ 2022-08-28 08:44:12,002 INFO SenderThread:52894 [sender.py:_save_file():946] saving file wandb-summary.json with policy end
37
+ 2022-08-28 08:44:12,002 INFO SenderThread:52894 [sender.py:send_exit():376] send defer
38
+ 2022-08-28 08:44:12,003 DEBUG SenderThread:52894 [sender.py:send_request():246] send_request: poll_exit
39
+ 2022-08-28 08:44:12,003 DEBUG HandlerThread:52894 [handler.py:handle_request():141] handle_request: defer
40
+ 2022-08-28 08:44:12,003 INFO HandlerThread:52894 [handler.py:handle_request_defer():164] handle defer: 0
41
+ 2022-08-28 08:44:12,003 DEBUG SenderThread:52894 [sender.py:send_request():246] send_request: defer
42
+ 2022-08-28 08:44:12,003 INFO SenderThread:52894 [sender.py:send_request_defer():385] handle sender defer: 0
43
+ 2022-08-28 08:44:12,003 INFO SenderThread:52894 [sender.py:transition_state():389] send defer: 1
44
+ 2022-08-28 08:44:12,004 DEBUG HandlerThread:52894 [handler.py:handle_request():141] handle_request: defer
45
+ 2022-08-28 08:44:12,004 INFO HandlerThread:52894 [handler.py:handle_request_defer():164] handle defer: 1
46
+ 2022-08-28 08:44:12,050 DEBUG SenderThread:52894 [sender.py:send_request():246] send_request: defer
47
+ 2022-08-28 08:44:12,050 INFO SenderThread:52894 [sender.py:send_request_defer():385] handle sender defer: 1
48
+ 2022-08-28 08:44:12,050 INFO SenderThread:52894 [sender.py:transition_state():389] send defer: 2
49
+ 2022-08-28 08:44:12,050 DEBUG SenderThread:52894 [sender.py:send():232] send: stats
50
+ 2022-08-28 08:44:12,051 DEBUG HandlerThread:52894 [handler.py:handle_request():141] handle_request: defer
51
+ 2022-08-28 08:44:12,051 INFO HandlerThread:52894 [handler.py:handle_request_defer():164] handle defer: 2
52
+ 2022-08-28 08:44:12,051 DEBUG SenderThread:52894 [sender.py:send_request():246] send_request: defer
53
+ 2022-08-28 08:44:12,051 INFO SenderThread:52894 [sender.py:send_request_defer():385] handle sender defer: 2
54
+ 2022-08-28 08:44:12,051 INFO SenderThread:52894 [sender.py:transition_state():389] send defer: 3
55
+ 2022-08-28 08:44:12,051 DEBUG HandlerThread:52894 [handler.py:handle_request():141] handle_request: defer
56
+ 2022-08-28 08:44:12,051 INFO HandlerThread:52894 [handler.py:handle_request_defer():164] handle defer: 3
57
+ 2022-08-28 08:44:12,051 DEBUG SenderThread:52894 [sender.py:send_request():246] send_request: defer
58
+ 2022-08-28 08:44:12,051 INFO SenderThread:52894 [sender.py:send_request_defer():385] handle sender defer: 3
59
+ 2022-08-28 08:44:12,051 INFO SenderThread:52894 [sender.py:transition_state():389] send defer: 4
60
+ 2022-08-28 08:44:12,051 DEBUG HandlerThread:52894 [handler.py:handle_request():141] handle_request: defer
61
+ 2022-08-28 08:44:12,052 INFO HandlerThread:52894 [handler.py:handle_request_defer():164] handle defer: 4
62
+ 2022-08-28 08:44:12,052 DEBUG SenderThread:52894 [sender.py:send():232] send: summary
63
+ 2022-08-28 08:44:12,052 INFO SenderThread:52894 [sender.py:_save_file():946] saving file wandb-summary.json with policy end
64
+ 2022-08-28 08:44:12,052 DEBUG SenderThread:52894 [sender.py:send_request():246] send_request: defer
65
+ 2022-08-28 08:44:12,052 INFO SenderThread:52894 [sender.py:send_request_defer():385] handle sender defer: 4
66
+ 2022-08-28 08:44:12,052 INFO SenderThread:52894 [sender.py:transition_state():389] send defer: 5
67
+ 2022-08-28 08:44:12,052 DEBUG HandlerThread:52894 [handler.py:handle_request():141] handle_request: defer
68
+ 2022-08-28 08:44:12,052 INFO HandlerThread:52894 [handler.py:handle_request_defer():164] handle defer: 5
69
+ 2022-08-28 08:44:12,052 DEBUG SenderThread:52894 [sender.py:send_request():246] send_request: defer
70
+ 2022-08-28 08:44:12,052 INFO SenderThread:52894 [sender.py:send_request_defer():385] handle sender defer: 5
71
+ 2022-08-28 08:44:12,104 DEBUG HandlerThread:52894 [handler.py:handle_request():141] handle_request: poll_exit
72
+ 2022-08-28 08:44:12,199 INFO SenderThread:52894 [sender.py:transition_state():389] send defer: 6
73
+ 2022-08-28 08:44:12,200 DEBUG SenderThread:52894 [sender.py:send_request():246] send_request: poll_exit
74
+ 2022-08-28 08:44:12,200 DEBUG HandlerThread:52894 [handler.py:handle_request():141] handle_request: defer
75
+ 2022-08-28 08:44:12,200 INFO HandlerThread:52894 [handler.py:handle_request_defer():164] handle defer: 6
76
+ 2022-08-28 08:44:12,200 DEBUG SenderThread:52894 [sender.py:send_request():246] send_request: defer
77
+ 2022-08-28 08:44:12,200 INFO SenderThread:52894 [sender.py:send_request_defer():385] handle sender defer: 6
78
+ 2022-08-28 08:44:12,200 INFO SenderThread:52894 [dir_watcher.py:finish():279] shutting down directory watcher
79
+ 2022-08-28 08:44:12,301 DEBUG HandlerThread:52894 [handler.py:handle_request():141] handle_request: poll_exit
80
+ 2022-08-28 08:44:12,392 INFO Thread-7 :52894 [dir_watcher.py:_on_file_modified():226] file/dir modified: /home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/wandb/run-20220828_084407-nbdgecc9/files/output.log
81
+ 2022-08-28 08:44:12,392 INFO SenderThread:52894 [dir_watcher.py:_on_file_modified():226] file/dir modified: /home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/wandb/run-20220828_084407-nbdgecc9/files/wandb-summary.json
82
+ 2022-08-28 08:44:12,393 INFO SenderThread:52894 [dir_watcher.py:_on_file_modified():226] file/dir modified: /home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/wandb/run-20220828_084407-nbdgecc9/files/config.yaml
83
+ 2022-08-28 08:44:12,393 INFO SenderThread:52894 [dir_watcher.py:finish():309] scan: /home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/wandb/run-20220828_084407-nbdgecc9/files
84
+ 2022-08-28 08:44:12,393 INFO SenderThread:52894 [dir_watcher.py:finish():323] scan save: /home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/wandb/run-20220828_084407-nbdgecc9/files/output.log output.log
85
+ 2022-08-28 08:44:12,393 INFO SenderThread:52894 [dir_watcher.py:finish():323] scan save: /home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/wandb/run-20220828_084407-nbdgecc9/files/wandb-metadata.json wandb-metadata.json
86
+ 2022-08-28 08:44:12,393 INFO SenderThread:52894 [dir_watcher.py:finish():323] scan save: /home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/wandb/run-20220828_084407-nbdgecc9/files/config.yaml config.yaml
87
+ 2022-08-28 08:44:12,393 INFO SenderThread:52894 [dir_watcher.py:finish():323] scan save: /home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/wandb/run-20220828_084407-nbdgecc9/files/requirements.txt requirements.txt
88
+ 2022-08-28 08:44:12,393 INFO SenderThread:52894 [dir_watcher.py:finish():323] scan save: /home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/wandb/run-20220828_084407-nbdgecc9/files/wandb-summary.json wandb-summary.json
89
+ 2022-08-28 08:44:12,396 INFO SenderThread:52894 [sender.py:transition_state():389] send defer: 7
90
+ 2022-08-28 08:44:12,396 DEBUG SenderThread:52894 [sender.py:send_request():246] send_request: poll_exit
91
+ 2022-08-28 08:44:12,404 DEBUG HandlerThread:52894 [handler.py:handle_request():141] handle_request: defer
92
+ 2022-08-28 08:44:12,404 INFO HandlerThread:52894 [handler.py:handle_request_defer():164] handle defer: 7
93
+ 2022-08-28 08:44:12,404 DEBUG SenderThread:52894 [sender.py:send_request():246] send_request: defer
94
+ 2022-08-28 08:44:12,405 INFO SenderThread:52894 [sender.py:send_request_defer():385] handle sender defer: 7
95
+ 2022-08-28 08:44:12,405 INFO SenderThread:52894 [file_pusher.py:finish():145] shutting down file pusher
96
+ 2022-08-28 08:44:12,501 DEBUG HandlerThread:52894 [handler.py:handle_request():141] handle_request: poll_exit
97
+ 2022-08-28 08:44:12,501 DEBUG SenderThread:52894 [sender.py:send_request():246] send_request: poll_exit
98
+ 2022-08-28 08:44:12,603 DEBUG HandlerThread:52894 [handler.py:handle_request():141] handle_request: poll_exit
99
+ 2022-08-28 08:44:12,603 DEBUG SenderThread:52894 [sender.py:send_request():246] send_request: poll_exit
100
+ 2022-08-28 08:44:12,704 DEBUG HandlerThread:52894 [handler.py:handle_request():141] handle_request: poll_exit
101
+ 2022-08-28 08:44:12,705 DEBUG SenderThread:52894 [sender.py:send_request():246] send_request: poll_exit
102
+ 2022-08-28 08:44:12,806 DEBUG HandlerThread:52894 [handler.py:handle_request():141] handle_request: poll_exit
103
+ 2022-08-28 08:44:12,806 DEBUG SenderThread:52894 [sender.py:send_request():246] send_request: poll_exit
104
+ 2022-08-28 08:44:12,860 INFO Thread-14 :52894 [upload_job.py:push():137] Uploaded file /home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/wandb/run-20220828_084407-nbdgecc9/files/requirements.txt
105
+ 2022-08-28 08:44:12,865 INFO Thread-12 :52894 [upload_job.py:push():137] Uploaded file /home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/wandb/run-20220828_084407-nbdgecc9/files/output.log
106
+ 2022-08-28 08:44:12,866 INFO Thread-15 :52894 [upload_job.py:push():137] Uploaded file /home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/wandb/run-20220828_084407-nbdgecc9/files/wandb-summary.json
107
+ 2022-08-28 08:44:12,908 DEBUG HandlerThread:52894 [handler.py:handle_request():141] handle_request: poll_exit
108
+ 2022-08-28 08:44:12,908 DEBUG SenderThread:52894 [sender.py:send_request():246] send_request: poll_exit
109
+ 2022-08-28 08:44:12,949 INFO Thread-13 :52894 [upload_job.py:push():137] Uploaded file /home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/wandb/run-20220828_084407-nbdgecc9/files/config.yaml
110
+ 2022-08-28 08:44:13,009 DEBUG HandlerThread:52894 [handler.py:handle_request():141] handle_request: poll_exit
111
+ 2022-08-28 08:44:13,009 DEBUG SenderThread:52894 [sender.py:send_request():246] send_request: poll_exit
112
+ 2022-08-28 08:44:13,111 DEBUG HandlerThread:52894 [handler.py:handle_request():141] handle_request: poll_exit
113
+ 2022-08-28 08:44:13,111 DEBUG SenderThread:52894 [sender.py:send_request():246] send_request: poll_exit
114
+ 2022-08-28 08:44:13,149 INFO Thread-6 :52894 [sender.py:transition_state():389] send defer: 8
115
+ 2022-08-28 08:44:13,149 DEBUG HandlerThread:52894 [handler.py:handle_request():141] handle_request: defer
116
+ 2022-08-28 08:44:13,149 INFO HandlerThread:52894 [handler.py:handle_request_defer():164] handle defer: 8
117
+ 2022-08-28 08:44:13,150 DEBUG SenderThread:52894 [sender.py:send_request():246] send_request: defer
118
+ 2022-08-28 08:44:13,150 INFO SenderThread:52894 [sender.py:send_request_defer():385] handle sender defer: 8
119
+ 2022-08-28 08:44:13,213 DEBUG HandlerThread:52894 [handler.py:handle_request():141] handle_request: poll_exit
120
+ 2022-08-28 08:44:13,272 INFO SenderThread:52894 [sender.py:transition_state():389] send defer: 9
121
+ 2022-08-28 08:44:13,272 DEBUG SenderThread:52894 [sender.py:send_request():246] send_request: poll_exit
122
+ 2022-08-28 08:44:13,273 DEBUG HandlerThread:52894 [handler.py:handle_request():141] handle_request: defer
123
+ 2022-08-28 08:44:13,273 INFO HandlerThread:52894 [handler.py:handle_request_defer():164] handle defer: 9
124
+ 2022-08-28 08:44:13,273 DEBUG SenderThread:52894 [sender.py:send_request():246] send_request: defer
125
+ 2022-08-28 08:44:13,273 INFO SenderThread:52894 [sender.py:send_request_defer():385] handle sender defer: 9
126
+ 2022-08-28 08:44:13,273 INFO SenderThread:52894 [sender.py:transition_state():389] send defer: 10
127
+ 2022-08-28 08:44:13,274 DEBUG SenderThread:52894 [sender.py:send():232] send: final
128
+ 2022-08-28 08:44:13,274 DEBUG HandlerThread:52894 [handler.py:handle_request():141] handle_request: defer
129
+ 2022-08-28 08:44:13,274 INFO HandlerThread:52894 [handler.py:handle_request_defer():164] handle defer: 10
130
+ 2022-08-28 08:44:13,274 DEBUG SenderThread:52894 [sender.py:send():232] send: footer
131
+ 2022-08-28 08:44:13,274 DEBUG SenderThread:52894 [sender.py:send_request():246] send_request: defer
132
+ 2022-08-28 08:44:13,274 INFO SenderThread:52894 [sender.py:send_request_defer():385] handle sender defer: 10
133
+ 2022-08-28 08:44:13,374 DEBUG HandlerThread:52894 [handler.py:handle_request():141] handle_request: poll_exit
134
+ 2022-08-28 08:44:13,374 DEBUG SenderThread:52894 [sender.py:send_request():246] send_request: poll_exit
135
+ 2022-08-28 08:44:13,375 INFO SenderThread:52894 [file_pusher.py:join():150] waiting for file pusher
136
+ 2022-08-28 08:44:13,731 DEBUG HandlerThread:52894 [handler.py:handle_request():141] handle_request: sampled_history
137
+ 2022-08-28 08:44:13,732 DEBUG HandlerThread:52894 [handler.py:handle_request():141] handle_request: get_summary
138
+ 2022-08-28 08:44:13,732 DEBUG HandlerThread:52894 [handler.py:handle_request():141] handle_request: shutdown
139
+ 2022-08-28 08:44:13,732 INFO HandlerThread:52894 [handler.py:finish():806] shutting down handler
140
+ 2022-08-28 08:44:14,274 INFO WriterThread:52894 [datastore.py:close():279] close: /home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/wandb/run-20220828_084407-nbdgecc9/run-nbdgecc9.wandb
141
+ 2022-08-28 08:44:14,629 INFO SenderThread:52894 [sender.py:finish():1106] shutting down sender
142
+ 2022-08-28 08:44:14,629 INFO SenderThread:52894 [file_pusher.py:finish():145] shutting down file pusher
143
+ 2022-08-28 08:44:14,630 INFO SenderThread:52894 [file_pusher.py:join():150] waiting for file pusher
144
+ 2022-08-28 08:44:14,632 INFO MainThread:52894 [internal.py:handle_exit():80] Internal process exited
wandb/run-20220828_084407-nbdgecc9/logs/debug.log ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 2022-08-28 08:44:07,236 INFO MainThread:51732 [wandb_setup.py:_flush():75] Loading settings from /home/sanchitgandhi/.config/wandb/settings
2
+ 2022-08-28 08:44:07,236 INFO MainThread:51732 [wandb_setup.py:_flush():75] Loading settings from /home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/wandb/settings
3
+ 2022-08-28 08:44:07,236 INFO MainThread:51732 [wandb_setup.py:_flush():75] Loading settings from environment variables: {}
4
+ 2022-08-28 08:44:07,236 INFO MainThread:51732 [wandb_setup.py:_flush():75] Inferring run settings from compute environment: {'program_relpath': 'run_flax_speech_recognition_seq2seq.py', 'program': 'run_flax_speech_recognition_seq2seq.py'}
5
+ 2022-08-28 08:44:07,236 INFO MainThread:51732 [wandb_init.py:_log_setup():437] Logging user logs to /home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/wandb/run-20220828_084407-nbdgecc9/logs/debug.log
6
+ 2022-08-28 08:44:07,236 INFO MainThread:51732 [wandb_init.py:_log_setup():438] Logging internal logs to /home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/wandb/run-20220828_084407-nbdgecc9/logs/debug-internal.log
7
+ 2022-08-28 08:44:07,237 INFO MainThread:51732 [wandb_init.py:init():471] calling init triggers
8
+ 2022-08-28 08:44:07,237 INFO MainThread:51732 [wandb_init.py:init():474] wandb.init called with sweep_config: {}
9
+ config: {}
10
+ 2022-08-28 08:44:07,237 INFO MainThread:51732 [wandb_init.py:init():524] starting backend
11
+ 2022-08-28 08:44:07,237 INFO MainThread:51732 [backend.py:_multiprocessing_setup():97] multiprocessing start_methods=fork,spawn,forkserver, using: spawn
12
+ 2022-08-28 08:44:07,348 INFO MainThread:51732 [backend.py:ensure_launched():217] starting backend process...
13
+ 2022-08-28 08:44:07,379 INFO MainThread:51732 [backend.py:ensure_launched():222] started backend process with pid: 52894
14
+ 2022-08-28 08:44:07,381 INFO MainThread:51732 [wandb_init.py:init():533] backend started and connected
15
+ 2022-08-28 08:44:07,392 INFO MainThread:51732 [wandb_init.py:init():597] updated telemetry
16
+ 2022-08-28 08:44:07,454 INFO MainThread:51732 [wandb_init.py:init():628] communicating run to backend with 30 second timeout
17
+ 2022-08-28 08:44:08,326 INFO MainThread:51732 [wandb_run.py:_on_init():1923] communicating current version
18
+ 2022-08-28 08:44:08,426 INFO MainThread:51732 [wandb_run.py:_on_init():1927] got version response upgrade_message: "wandb version 0.13.2 is available! To upgrade, please run:\n $ pip install wandb --upgrade"
19
+
20
+ 2022-08-28 08:44:08,426 INFO MainThread:51732 [wandb_init.py:init():659] starting run threads in backend
21
+ 2022-08-28 08:44:08,485 INFO MainThread:51732 [wandb_run.py:_console_start():1897] atexit reg
22
+ 2022-08-28 08:44:08,485 INFO MainThread:51732 [wandb_run.py:_redirect():1770] redirect: SettingsConsole.REDIRECT
23
+ 2022-08-28 08:44:08,485 INFO MainThread:51732 [wandb_run.py:_redirect():1775] Redirecting console.
24
+ 2022-08-28 08:44:08,487 INFO MainThread:51732 [wandb_run.py:_redirect():1831] Redirects installed.
25
+ 2022-08-28 08:44:08,488 INFO MainThread:51732 [wandb_init.py:init():684] run started, returning control to user process
26
+ 2022-08-28 08:44:09,687 INFO MainThread:51732 [wandb_run.py:_atexit_cleanup():1866] got exitcode: 1
27
+ 2022-08-28 08:44:09,689 INFO MainThread:51732 [wandb_run.py:_restore():1838] restore
28
+ 2022-08-28 08:44:12,003 INFO MainThread:51732 [wandb_run.py:_on_finish():1995] got exit ret: file_counts {
29
+ wandb_count: 1
30
+ }
31
+ pusher_stats {
32
+ uploaded_bytes: 2233
33
+ total_bytes: 2233
34
+ }
35
+
36
+ 2022-08-28 08:44:12,200 INFO MainThread:51732 [wandb_run.py:_on_finish():1995] got exit ret: file_counts {
37
+ wandb_count: 1
38
+ }
39
+ pusher_stats {
40
+ uploaded_bytes: 2233
41
+ total_bytes: 2233
42
+ }
43
+
44
+ 2022-08-28 08:44:12,400 INFO MainThread:51732 [wandb_run.py:_on_finish():1995] got exit ret: file_counts {
45
+ wandb_count: 4
46
+ }
47
+ pusher_stats {
48
+ uploaded_bytes: 2233
49
+ total_bytes: 8131
50
+ }
51
+
52
+ 2022-08-28 08:44:12,501 INFO MainThread:51732 [wandb_run.py:_on_finish():1995] got exit ret: file_counts {
53
+ wandb_count: 5
54
+ }
55
+ pusher_stats {
56
+ uploaded_bytes: 2233
57
+ total_bytes: 8157
58
+ }
59
+
60
+ 2022-08-28 08:44:12,603 INFO MainThread:51732 [wandb_run.py:_on_finish():1995] got exit ret: file_counts {
61
+ wandb_count: 5
62
+ }
63
+ pusher_stats {
64
+ uploaded_bytes: 8157
65
+ total_bytes: 8157
66
+ }
67
+
68
+ 2022-08-28 08:44:12,705 INFO MainThread:51732 [wandb_run.py:_on_finish():1995] got exit ret: file_counts {
69
+ wandb_count: 5
70
+ }
71
+ pusher_stats {
72
+ uploaded_bytes: 8157
73
+ total_bytes: 8157
74
+ }
75
+
76
+ 2022-08-28 08:44:12,807 INFO MainThread:51732 [wandb_run.py:_on_finish():1995] got exit ret: file_counts {
77
+ wandb_count: 5
78
+ }
79
+ pusher_stats {
80
+ uploaded_bytes: 8157
81
+ total_bytes: 8157
82
+ }
83
+
84
+ 2022-08-28 08:44:12,908 INFO MainThread:51732 [wandb_run.py:_on_finish():1995] got exit ret: file_counts {
85
+ wandb_count: 5
86
+ }
87
+ pusher_stats {
88
+ uploaded_bytes: 8157
89
+ total_bytes: 8157
90
+ }
91
+
92
+ 2022-08-28 08:44:13,010 INFO MainThread:51732 [wandb_run.py:_on_finish():1995] got exit ret: file_counts {
93
+ wandb_count: 5
94
+ }
95
+ pusher_stats {
96
+ uploaded_bytes: 8157
97
+ total_bytes: 8157
98
+ }
99
+
100
+ 2022-08-28 08:44:13,112 INFO MainThread:51732 [wandb_run.py:_on_finish():1995] got exit ret: file_counts {
101
+ wandb_count: 5
102
+ }
103
+ pusher_stats {
104
+ uploaded_bytes: 8157
105
+ total_bytes: 8157
106
+ }
107
+
108
+ 2022-08-28 08:44:13,273 INFO MainThread:51732 [wandb_run.py:_on_finish():1995] got exit ret: file_counts {
109
+ wandb_count: 5
110
+ }
111
+ pusher_stats {
112
+ uploaded_bytes: 8157
113
+ total_bytes: 8157
114
+ }
115
+
116
+ 2022-08-28 08:44:13,630 INFO MainThread:51732 [wandb_run.py:_on_finish():1995] got exit ret: done: true
117
+ exit_result {
118
+ }
119
+ file_counts {
120
+ wandb_count: 5
121
+ }
122
+ pusher_stats {
123
+ uploaded_bytes: 8157
124
+ total_bytes: 8157
125
+ }
126
+ local_info {
127
+ }
128
+
129
+ 2022-08-28 08:44:14,787 INFO MainThread:51732 [wandb_run.py:_footer_history_summary_info():3102] rendering history
130
+ 2022-08-28 08:44:14,787 INFO MainThread:51732 [wandb_run.py:_footer_history_summary_info():3134] rendering summary
131
+ 2022-08-28 08:44:14,789 INFO MainThread:51732 [wandb_run.py:_footer_sync_info():3057] logging synced files
wandb/run-20220828_084407-nbdgecc9/run-nbdgecc9.wandb ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cbe55913f815bc0f117700949409b7d3cb181dfad65b7969044117f11f40af4d
3
+ size 3379
wandb/run-20220828_085247-2hx8pk65/files/config.yaml ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ wandb_version: 1
2
+
3
+ _wandb:
4
+ desc: null
5
+ value:
6
+ cli_version: 0.12.15
7
+ framework: huggingface
8
+ huggingface_version: 4.21.0.dev0
9
+ is_jupyter_run: false
10
+ is_kaggle_kernel: false
11
+ python_version: 3.8.10
12
+ start_time: 1661676767
13
+ t:
14
+ 1:
15
+ - 1
16
+ - 11
17
+ - 12
18
+ - 45
19
+ - 49
20
+ - 51
21
+ - 55
22
+ 3:
23
+ - 13
24
+ 4: 3.8.10
25
+ 5: 0.12.15
26
+ 6: 4.21.0.dev0
27
+ 8:
28
+ - 5
wandb/run-20220828_085247-2hx8pk65/files/media/table/eval/step_10k_10000_8b44e8a00a036a18ffdf.table.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"columns": ["id", "label_str", "beam_1", "beam_2", "beam_3", "beam_4", "beam_5"], "data": [["2277-149896-0000", "he was in a fevered state of mind owing to the blight his wife's action threatened to cast upon his entire future", "he was in a fevered state of mind owing to the blight his action threatened to cast upon his entire future", "he was in a fevered state of mind owing to the blight his action threatened to cast upon his entire future", "he was in a fevered state of mind owing to the blight his action threatened to cast upon this entire future", "he was in a fevered state of mind owing to the blight his action threatened to cast upon his entire future", "he was in a fevered state of mind owing to the bight his action threatened to cast upon his entire future"], ["2277-149896-0001", "he would have to pay her the money which she would now regularly demand or there would be trouble it did not matter what he did", "he would have to pay her the money which she would now regularly demand or there would be trouble it did not matter what he did", "he would have to pay her the money which he would now regularly demand or there would be trouble it did not matter what he did", " he would have to pay her the money which she would now regularly demand or there would be trouble it did not matter what he did", "he would have to pay her the money which she'd now regularly demand or there would be trouble it did not matter what he did", "he would have to pay her the money which she could now regularly demand or there would be trouble it did not matter what he did"], ["2277-149896-0002", "hurstwood walked the floor mentally arranging the chief points of his situation", "hurstwood walked to the floor mentally arranging the chief points of his situation", "hirschwood walked to the floor mentally arranging the chief points of his situation", "herstwood walked to the floor mentally arranging the chief points of his situation", "hurstwood walked the floor mentally arranging the chief points of his situation", "hilstwood walked to the floor mentally arranging the chief points of his situation"], ["2277-149896-0003", "he also thought of his managerial position", "he also thought of his managerial position", "he also thought of this managerial position", " he also thought of his managerial position", "he also thought his managerial position", "here also thought of his managerial position"], ["2277-149896-0004", "how would the papers talk about it", "how would the papers talk about it", "how could the papers talk about it", "how'd the papers talk about it", "how did the papers talk about it", "how would the papers talk about it yes"], ["2277-149896-0005", "many little wrinkles gathered between his eyes as he contemplated this and his brow moistened", "many little wrinkles gathered between his eyes as he contemplated this and his brow moistened", "many little wrinkles gathered between his eyes as he contemplated this this and his brow moistened", "many little wrinkles gathered between his eyes as he contemplated this his brow moistened", "many little wrinkles gathered between his eyes as he contemplated this in his brow moistened", "many little wrinkles gathered between his eyes as he contemplated this and this brow moistened"], ["2277-149896-0006", "he could arrange that satisfactorily for carrie would be glad to wait if necessary", "he could arrange that satisfactorily for carrie would be glad to wait if necessary", "he could arrange that satisfactorily for carey would be glad to wait if necessary", "he could arrange that satisfactorily for carry would be glad to wait if necessary", "he could arrange this satisfactorily for carrie would be glad to wait if necessary", "he could arrange the satisfactorily for carrie would be glad to wait if necessary"], ["2277-149896-0007", "he would see how things turned out to morrow and then he would talk to her they were going to meet as usual", "he would see how things turned out tomorrow and then he would talk to her they were going to meet as usual", "he would see how things turned out to morrow and then he would talk to her they were going to meet as usual", "he would see how things turned out today and then he would talk to her they were going to meet as usual", "he would see how things turn out tomorrow and then he would talk to her they were going to meet as usual", "he could see how things turned out tomorrow and then he would talk to her they were going to meet as usual"], ["2277-149896-0008", "for some reason he felt as if something might come that way and was relieved when all the envelopes had been scanned and nothing suspicious noticed", "for some reason he felt as if something might come that way and was relieved when all the envelopes had been scanned and nothing suspicious noticed", "for some reason he felt as if something might come that way and was relieved when all the envelops had been scanned and nothing suspicious noticed", "for some reason he felt as if something might come this way and was relieved when all the envelopes had been scanned and nothing suspicious noticed", "from some reason he felt as if something might come that way and was relieved when all the envelopes had been scanned and nothing suspicious noticed", "for some reason he felt as if nothing might come that way and was relieved when all the envelopes had been scanned and nothing suspicious noticed"], ["2277-149896-0009", "while the danger had not lessened it had not as yet materialised and with him no news was good news", "while the danger had not lessened it had not as yet materialized and with him no news was good news", "while the danger had not lessened it had not as yet materialised and with him no news was good news", "while danger had not lessened it had not as yet materialized and with him no news was good news", "whilst the danger had not lessened it had not as yet materialized and with him no news was good news", "while the danger had not lessen it had not as yet materialized and with him no news was good news"], ["2277-149896-0010", "so little did he consider drouet that it never once occurred to him to worry about his finding out", "so little did he consider drouet that it never once occurred to him to worry about his finding out", "so little did he consider drue that it never once occurred to him to worry about his finding out", "so little did he consider drua that it never once occurred to him to worry about his finding out", "so little did he consider drura that it never once occurred to him to worry about his finding out", "so little did he consider druecca that it never once occurred to him to worry about his finding out"], ["2277-149896-0011", "he grew restless as he ruminated and then decided that perhaps it was nothing", "he grew restless as he ruminated and then decided that perhaps it was nothing", "he grew restless as he ruminated and then decided that perhaps it was nothing", "he grew restless as he ruminated and then decide that perhaps it was nothing", " he grew restless as he ruminated and then decided that perhaps it was nothing", "he grew restless as he ruminating and then decided that perhaps it was nothing"], ["2277-149896-0012", "she had not been able to get away this morning", "she had not been able to get away this morning", "he had not been able to get away this morning", " she had not been able to get away this morning", "the she had not been able to get away this morning", "the had not been able to get away this morning"], ["2277-149896-0013", "he would get one to day it would probably be on his desk when he got back he would look for it at once", "he would get one today it would probably be on his desk when he got back he would look for it at once", "he would get one to day it would probably be on his desk when he got back he would look for it at once", "he would get one tomorrow it would probably be on his desk when he got back he would look for it at once", "he could get one to day it would probably be on his desk when he got back he would look for it at once", "he could get one today it would probably be on his desk when he got back he would look for it at once"], ["2277-149896-0014", "after a time he gave up waiting and drearily headed for the madison car", "after a time he gave up waiting and drearily headed for the madison car", "after a time he gave up waiting and drearily headed for the mattinson car", "after a time he gave up waiting and drearily headed for the mattison car", "after a time he gave up waiting and drearily headed for the madeison car", "after a time he gave up waiting and drearily headed for the madezons car"], ["2277-149896-0015", "he went in and examined his letters but there was nothing from carrie", "he went in and examined his letters but there was nothing from carrie", "he went in and examined his letters but there was nothing from carey", "he went in and examined his letters but there was nothing from carry", "he went in and examined his letters but there was nothing from kerry", "he went in and examined his letters but there was nothing from cary"], ["2277-149896-0016", "fortunately there was nothing from his wife either", "fortunately there was nothing from his wife either", "fortunately there was nothing from his wife either", " fortunately there was nothing from his wife either", "fortunately there was nothing from this wife either", "forfortunately there was nothing from his wife either"], ["2277-149896-0017", "at one thirty he went to rector's for lunch and when he returned a messenger was waiting for him", "at one thirty he went to rector's for lunch and when he returned a messenger was waiting for him", "at one thirty he went to rector's for lunch and when he returned a messenger was waiting for him", "at one thirty he went to rectors for lunch and when he returned a messenger was waiting for him", "at one thirty he went to rectory's for lunch and when he returned a messenger was waiting for him", " at one thirty he went to rector's for lunch and when he returned a messenger was waiting for him"], ["2277-149896-0018", "his first impulse was to write but four words in reply go to the devil", "his first impulse was to write but four words in reply go to the devil", " his first impulse was to write but four words in reply go to the devil", "his first impulse was to write but four words in reply go to the devil", "his first impulse was to write but four words and reply go to the devil", "his first impulses was to write but four words in reply go to the devil"], ["2277-149896-0019", "but he compromised by telling the boy that there would be no reply", "but he compromised by telling the boy that there would be no reply", "but hecompromised by telling the boy that there would be no reply", "but he comprised by telling the boy that there would be no reply", "but he compromise by telling the boy that there would be no reply", " but he compromised by telling the boy that there would be no reply"], ["2277-149896-0020", "then he sat down in his chair and gazed without seeing contemplating the result of his work", "then he sat down in his chair and gazed without seeing contemplating the result of his work", "then he sat down in his chair and gazed without seeing contemplating the result of this work", " then he sat down in his chair and gazed without seeing contemplating the result of his work", "than he sat down in his chair and gazed without seeing contemplating the result of his work", "then he sat down in his chair and gazed without seeing contemplating the results of his work"], ["2277-149896-0021", "what would she do about that the confounded wretch", "what would she do about that the confounded wretch", "what would you do about that the confounded wretch", " what would she do about that the confounded wretch", "what could she do about that the confounded wretch", "but what would she do about that the confounded wretch"], ["2277-149896-0022", "later however his old discretion asserted itself", "later however his old discretion asserted itself", " later however his old discretion asserted itself", "later however his old discretion ascertained itself", "later however this old discretion asserted itself", "late however his old discretion asserted itself"], ["2277-149896-0023", "something had to be done a climax was near and she would not sit idle", "something had to be done a climax was near and she would not sit idle", " something had to be done a climax was near and she would not sit idle", "something had to be done the climax was near and she would not sit idle", "anything had to be done a climax was near and she would not sit idle", "nothing had to be done a climax was near and she would not sit idle"], ["2277-149896-0024", "he knew her well enough to know that when she had decided upon a plan she would follow it up", "he knew her well enough to know that when she had decided upon a plan she would follow it up", "he knew her well enough to know that when she had decided upon the plan she would follow it up", "he knew her well enough to know that when she decided upon a plan she would follow it up", "he knew her well enough to know that when she had decided upon a plan she should follow it up", "he knew her well enough to know when she had decided upon a plan she would follow it up"], ["2277-149896-0025", "he arose from his chair and went and looked out into the street", "he arose from his chair and went and looked out into the street", "he arose from his chair and went to looked out into the street", "he arose from his chair and went into the street", "he arose from his chair and went and looked out into this street", "he arose from his chair and went and looked out in the street"], ["2277-149896-0026", "the long drizzle had begun pedestrians had turned up collars and trousers at the bottom", "the long drizzle had begun pedestrians had turned up collars and trousers at the bottom", "the long drizzle had begun petersians had turned up collars and trousers at the bottom", "the long drizzle had begun pedestrians had turned up collars and trousers at the bottom", " the long drizzle had begun pedestrians had turned up collars and trousers at the bottom", "long drizzle had begun pedestrians had turned up collars and trousers at the bottom"], ["2277-149896-0027", "hurstwood almost exclaimed out loud at the insistency of this thing", "hurstwood almost exclaimed out loud at the insistency of this thing", "hirschwood almost exclaimed out loud at the insistency of this thing", "hurstwood almost exclaimed aloud at the insistency of this thing", "hilstwood almost exclaimed out loud at the insistency of this thing", "hurstwood almost exclaimed out loud at the insincerity of this thing"], ["2277-149896-0028", "he put on his hat and looked around for his umbrella", "he put on his hat and looked around for his umbrella", "he put on his hat and looked round for his umbrella", "he put on his hat and looked around for this umbrella", " he put on his hat and looked around for his umbrella", "he put on his hat and looked about for his umbrella"], ["2277-149896-0029", "he would have some arrangement of this thing", "he would have some arrangement of this thing", "he will have some arrangement of this thing", "he'd have some arrangement of this thing", "he could have some arrangement of this thing", " he would have some arrangement of this thing"], ["2277-149896-0030", "he began to wish that he had compromised in some way or other that he had sent the money perhaps he could do it up here", "he began to wish that he had compromised in some way or other that he had sent the money perhaps he could do it up here", "he began to wish that he had compromised in some way or another that he had sent the money perhaps he could do it up here", "he began to wish that he had comprised in some way or other that he had sent the money perhaps he could do it up here", "he began to wish he had compromised in some way or other that he had sent the money perhaps he could do it up here", "he began to wish that he had coordinated in some way or other that he had sent the money perhaps he could do it up here"], ["2277-149896-0031", "he would go in and see anyhow he would have no row", "he would go in and see anyhow he would have no row", "he would go in and see anyhow he would have no rowl", "he would go in and see anyhow he would have no rue", "he could go in and see anyhow he would have no row", "he would go in and see anyhow he would have no raoul"], ["2277-149896-0032", "by the time he reached his own street he was keenly alive to the difficulties of his situation and wished over and over that some solution would offer itself that he could see his way out", "by the time he reached his own street he was keenly alive to the difficulties of his situation and wished over and over that some solution would offer itself that he could see his way out", "by the time he reached his own street he was keenly alive to the difficulties of this situation and wished over and over that some solution would offer itself that he could see his way out", " by the time he reached his own street he was keenly alive to the difficulties of his situation and wished over and over that some solution would offer itself that he could see his way out", "by the time he reached his own street he was keenly alive to the troubles of his situation and wished over and over that some solution would offer itself that he could see his way out", "by the time he reached his own street he was keenly alive to the difficulties of his situation and wished over and over that some solution would offer itself that he could see this way out"], ["2277-149896-0033", "then he rang the bell no answer", "then he rang the bell no answer", "than he rang the bell no answer", " then he rang the bell no answer", "this he rang the bell no answer", "there he rang the bell no answer"], ["2277-149896-0034", "he rang again this time harder still no answer", "he rang again this time harder still no answer", "he wrang again this time harder still no answer", "he ring again this time harder still no answer", "he rang again this time harder still no answer", "he ringed again this time harder still no answer"], ["2277-149897-0000", "when hurstwood got back to his office again he was in a greater quandary than ever", "when hurstwood got back to his office again he was in a greater quandary than ever", "when hurstwood got back to his office again he was in a greater quondary than ever", "when hurstwood got back to his office again he was in a greater quandy than ever", "when hurstwood got back to his office again he was in a greater quadry than ever", "when hurstwood got back to his office again he was in a greater quorum than ever"], ["2277-149897-0001", "he could hardly realise how it had all come about", "he could hardly realize how it had all come about", "he could hardly realise how it had all come about", "he could hardly realize how it has all come about", "he could hardly realize how it had also come about", " he could hardly realize how it had all come about"], ["2277-149897-0002", "no letter had come no word of any kind and yet here it was late in the evening and she had agreed to meet him that morning", "no letter had come no word of any kind and yet here it was late in the evening and she had agreed to meet him that morning", "no letter had come no word of any kind and yet here it was late in the evening and she had agreed to meet him this morning", " no letter had come no word of any kind and yet here it was late in the evening and she had agreed to meet him that morning", "no letter had come no word of any kind and yet here it was late in the evening and she had agreed to meet him at morning", "no letter had come no word of any kind and yet here it was late in the evening and she had agreed to meet him the morning"], ["2277-149897-0003", "he saw that in the excitement of recent events he had not formulated a plan upon that score", "he saw that in the excitement of recent events he had not formulated a plan upon that score", "he saw that in the excitement of recent events he had not formulated a plan upon the score", "he saw that in the excitement of recent events he had not formulated a plan upon this score", "he saw that in the excitement of recent events he had not communicated a plan upon that score", "he saw that in the excited of recent events he had not formulated a plan upon that score"], ["2277-149897-0004", "he was getting some vague comfort out of a good cigar but it was no panacea for the ill which affected him", "he was getting some vague comfort out of a good cigar but it was no panacea for the ill which affected him", "he was getting some vague comfort out of a good cigar but it was no panegas for the ill which affected him", "he was getting some vague comfort out of a good cigar but it was no panatia for the ill which affected him", "he was getting some vague comfort out of a good cigar but it was no pennesia for the ill which affected him", "he was getting some vague comfort out of a good cigar but it was no panacea for the illness which affected him"], ["2277-149897-0005", "it was with great opposition after two or three hours of the most urgent mental affirmation and denial that at last he got an envelope placed in it the requested amount and slowly sealed it up", "it was with great opposition after two or three hours of the most urgent mental affirmation and denial that at last he got an envelope placed in it the requested amount and slowly sealed it up", "it was with great opposition after two or three hours of the most urgent mental affirmation and denial that at last he got an envelope placed in its requested amount and slowly sealed it up", "it was with great opposition after two or three hours of the most urgent mental affirmative and denial that at last he got an envelope placed in it the requested amount and slowly sealed it up", " it was with great opposition after two or three hours of the most urgent mental affirmation and denial that at last he got an envelope placed in it the requested amount and slowly sealed it up", "it was with great opposition after two or three hours of the most urgent mental affirmation and declaration that at last he got an envelope placed in it the requested amount and slowly sealed it up"], ["2277-149897-0006", "then he called harry the boy of all work around the place", "then he called harry the boy of all work around the place", "then he called harry the boy of all work round the place", " then he called harry the boy of all work around the place", "now he called harry the boy of all work around the place", "this he called harry the boy of all work around the place"], ["2277-149897-0007", "you take this to this address he said handing him the envelope and give it to missus hurstwood yes sir said the boy", "you take this to this address he said handing him the envelope and give it to missus hurstwood yes sir said the boy", " you take this to this address he said handing him the envelope and give it to missus hurstwood yes sir said the boy", "you take this to this addressed he said handing him the envelope and give it to missus hurstwood yes sir said the boy", "you take this to his address he said handing him the envelope and give it to missus hurstwood yes sir said the boy", "you take this to this address he said handing him the envelope and give it to missus hurstwood yes sir said the boy"], ["2277-149897-0008", "any answer i guess not", "any answer i guess not", "any answer i guess not i guess not", "any answer i guess not a guess not", "any answer i guessed not", "any answer i guess not the"], ["2277-149897-0009", "the boy hastened away and the manager fell to his musings", "the boy hastened away and the manager fell to his musings", "the boy hasted away and the manager fell to his musings", "the boy hastily away and the manager fell to his musings", " the boy hastened away and the manager fell to his musings", "the boy hastened away and the manager fell into his musings"], ["2277-149897-0010", "he was beaten for to night and he might just as well make the best of it", "he was beaten for to night and he might just as well make the best of it", "he was beaten for tonight and he might just as well make the best of it", "he was beaten for tomorrow and he might just as well make the best of it", "he was beaten for today and he might just as well make the best of it", "he was beaten for to night and he might just as well make the best of it"], ["2277-149897-0011", "she would take the envelope and know that she had triumphed", "she would take the envelope and know that she had triumphed", " she would take the envelope and know that she had triumphed", "he would take the envelope and know that she had triumphed", "the would take the envelope and know that she had triumphed", "we would take the envelope and know that she had triumphed"], ["2277-149897-0012", "if he only had that letter back he wouldn't send it", "if he only had that letter back he wouldn't send it", "if he only had that letter back he won't send it", "if he only had that letter back he couldn't send it", "if he only had that letter back he didn't send it", "if he only had that letter back he wuzn't send it"], ["2277-149897-0013", "for relief he arose and joined in conversation with a few friends who were drinking", "for relief he arose and joined in the conversation with a few friends who were drinking", "for relief he arose and joined in the conversation with the few friends who were drinking", "for relief he arose in the conversation with a few friends who were drinking", "for relief he arose in the conversation with the few friends who were drinking", "for relief he arose and joined in a conversation with a few friends who were drinking"], ["2277-149897-0014", "all the time his thoughts would run out to his home and see the scene being therein enacted", "all the time his thoughts would run out to his home and see the scene being therein enacted", "all this time his thoughts would run out to his home and see the scene being therein enacted", "all the time his thoughts would run out to his home and see the scene being therein enacted", " all the time his thoughts would run out to his home and see the scene being therein enacted", "all the time his thoughts would run out to his home and see the scene being herein enacted"]]}
wandb/run-20220828_085247-2hx8pk65/files/output.log ADDED
The diff for this file is too large to render. See raw diff
 
wandb/run-20220828_085247-2hx8pk65/files/requirements.txt ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ absl-py==1.0.0
2
+ aiohttp==3.8.1
3
+ aiosignal==1.2.0
4
+ anyio==3.5.0
5
+ appdirs==1.4.4
6
+ argon2-cffi-bindings==21.2.0
7
+ argon2-cffi==21.3.0
8
+ asttokens==2.0.5
9
+ async-timeout==4.0.2
10
+ attrs==21.4.0
11
+ audioread==2.1.9
12
+ babel==2.10.1
13
+ backcall==0.2.0
14
+ beautifulsoup4==4.11.1
15
+ bleach==5.0.0
16
+ certifi==2021.10.8
17
+ cffi==1.15.0
18
+ charset-normalizer==2.0.12
19
+ chex==0.1.3
20
+ click==8.1.3
21
+ colorama==0.4.5
22
+ commonmark==0.9.1
23
+ cycler==0.11.0
24
+ datasets==2.4.1.dev0
25
+ debugpy==1.6.0
26
+ decorator==5.1.1
27
+ defusedxml==0.7.1
28
+ dill==0.3.4
29
+ dm-tree==0.1.7
30
+ docker-pycreds==0.4.0
31
+ entrypoints==0.4
32
+ etils==0.6.0
33
+ executing==0.8.3
34
+ fastjsonschema==2.15.3
35
+ filelock==3.6.0
36
+ flatbuffers==2.0
37
+ flax==0.5.3
38
+ fonttools==4.33.3
39
+ frozenlist==1.3.0
40
+ fsspec==2022.3.0
41
+ gitdb==4.0.9
42
+ gitpython==3.1.27
43
+ huggingface-hub==0.5.1
44
+ idna==3.3
45
+ ijson==3.1.4
46
+ importlib-metadata==4.11.3
47
+ importlib-resources==5.7.1
48
+ iniconfig==1.1.1
49
+ ipdb==0.13.9
50
+ ipykernel==6.13.0
51
+ ipython-genutils==0.2.0
52
+ ipython==8.3.0
53
+ jax==0.3.15
54
+ jaxlib==0.3.15
55
+ jedi==0.18.1
56
+ jinja2==3.1.2
57
+ jiwer==2.3.0
58
+ joblib==1.1.0
59
+ json5==0.9.6
60
+ jsonschema==4.4.0
61
+ jupyter-client==7.3.0
62
+ jupyter-core==4.10.0
63
+ jupyter-server==1.17.0
64
+ jupyterlab-pygments==0.2.2
65
+ jupyterlab-server==2.13.0
66
+ jupyterlab==3.4.0
67
+ kiwisolver==1.4.2
68
+ librosa==0.9.1
69
+ libtpu-nightly==0.1.dev20220722
70
+ llvmlite==0.38.0
71
+ markupsafe==2.1.1
72
+ matplotlib-inline==0.1.3
73
+ matplotlib==3.5.1
74
+ mistune==0.8.4
75
+ msgpack==1.0.3
76
+ multidict==6.0.2
77
+ multiprocess==0.70.12.2
78
+ nbclassic==0.3.7
79
+ nbclient==0.6.2
80
+ nbconvert==6.5.0
81
+ nbformat==5.4.0
82
+ nest-asyncio==1.5.5
83
+ nltk==3.7
84
+ notebook-shim==0.1.0
85
+ notebook==6.4.11
86
+ numba==0.55.1
87
+ numpy==1.21.0
88
+ opt-einsum==3.3.0
89
+ optax==0.1.2
90
+ packaging==21.3
91
+ pandas==1.4.2
92
+ pandocfilters==1.5.0
93
+ parso==0.8.3
94
+ pathtools==0.1.2
95
+ pexpect==4.8.0
96
+ pickleshare==0.7.5
97
+ pillow==9.1.0
98
+ pip==20.0.2
99
+ pkg-resources==0.0.0
100
+ pluggy==1.0.0
101
+ pooch==1.6.0
102
+ prometheus-client==0.14.1
103
+ promise==2.3
104
+ prompt-toolkit==3.0.29
105
+ protobuf==3.20.1
106
+ psutil==5.9.0
107
+ ptyprocess==0.7.0
108
+ pure-eval==0.2.2
109
+ py==1.11.0
110
+ pyarrow==7.0.0
111
+ pycparser==2.21
112
+ pycryptodome==3.14.1
113
+ pygments==2.12.0
114
+ pyparsing==3.0.8
115
+ pyrsistent==0.18.1
116
+ pytest==7.1.2
117
+ python-dateutil==2.8.2
118
+ python-levenshtein==0.12.2
119
+ pytz==2022.1
120
+ pyyaml==6.0
121
+ pyzmq==22.3.0
122
+ regex==2022.4.24
123
+ requests==2.27.1
124
+ resampy==0.2.2
125
+ responses==0.18.0
126
+ rich==11.1.0
127
+ rouge-score==0.1.2
128
+ sacremoses==0.0.49
129
+ scikit-learn==1.0.2
130
+ scipy==1.8.0
131
+ send2trash==1.8.0
132
+ sentry-sdk==1.5.10
133
+ seqeval==1.2.2
134
+ setproctitle==1.2.3
135
+ setuptools==44.0.0
136
+ shortuuid==1.0.8
137
+ six==1.16.0
138
+ smmap==5.0.0
139
+ sniffio==1.2.0
140
+ soundfile==0.10.3.post1
141
+ soupsieve==2.3.2.post1
142
+ speechcolab==0.0.6a0
143
+ stack-data==0.2.0
144
+ tensorstore==0.1.21
145
+ terminado==0.13.3
146
+ threadpoolctl==3.1.0
147
+ tinycss2==1.1.1
148
+ tokenizers==0.12.1
149
+ toml==0.10.2
150
+ tomli==2.0.1
151
+ toolz==0.11.2
152
+ torch==1.11.0+cpu
153
+ torchaudio==0.11.0+cpu
154
+ tornado==6.1
155
+ tqdm==4.64.0
156
+ traitlets==5.1.1
157
+ transformers==4.21.0.dev0
158
+ typing-extensions==4.2.0
159
+ urllib3==1.26.9
160
+ wandb==0.12.15
161
+ wcwidth==0.2.5
162
+ webencodings==0.5.1
163
+ websocket-client==1.3.2
164
+ wheel==0.37.1
165
+ xxhash==3.0.0
166
+ yarl==1.7.2
167
+ zipp==3.8.0
wandb/run-20220828_085247-2hx8pk65/files/wandb-metadata.json ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "os": "Linux-5.11.0-1028-gcp-x86_64-with-glibc2.29",
3
+ "python": "3.8.10",
4
+ "heartbeatAt": "2022-08-28T08:52:48.553677",
5
+ "startedAt": "2022-08-28T08:52:47.513374",
6
+ "docker": null,
7
+ "cpu_count": 96,
8
+ "cuda": null,
9
+ "args": [
10
+ "--dataset_name=librispeech_asr",
11
+ "--model_name_or_path=sanchit-gandhi/flax-wav2vec2-2-bart-large-scan",
12
+ "--dataset_config_name=all",
13
+ "--train_split_name=train.clean.100+train.clean.360+train.other.500",
14
+ "--eval_split_name=validation.clean",
15
+ "--test_split_name=validation.other+test.clean+test.other",
16
+ "--text_column_name=text",
17
+ "--id_column_name=id",
18
+ "--output_dir=./",
19
+ "--wandb_project=librispeech_960h",
20
+ "--wandb_name=flax-wav2vec2-2-bart-large-ls-960h-black-box",
21
+ "--dataset_cache_dir=/home/sanchitgandhi/cache/huggingface/datasets",
22
+ "--per_device_train_batch_size=8",
23
+ "--per_device_eval_batch_size=4",
24
+ "--learning_rate=1e-4",
25
+ "--warmup_steps=500",
26
+ "--logging_steps=25",
27
+ "--max_steps=50000",
28
+ "--eval_steps=10000",
29
+ "--save_steps=10000",
30
+ "--generation_max_length=200",
31
+ "--generation_num_beams=5",
32
+ "--generation_length_penalty=1.2",
33
+ "--hidden_dropout=0.2",
34
+ "--activation_dropout=0.2",
35
+ "--feat_proj_dropout=0.2",
36
+ "--overwrite_output_dir",
37
+ "--gradient_checkpointing",
38
+ "--freeze_feature_encoder",
39
+ "--predict_with_generate",
40
+ "--do_lower_case",
41
+ "--do_eval",
42
+ "--do_train",
43
+ "--do_predict",
44
+ "--push_to_hub",
45
+ "--use_auth_token"
46
+ ],
47
+ "state": "running",
48
+ "program": "run_flax_speech_recognition_seq2seq.py",
49
+ "codePath": "run_flax_speech_recognition_seq2seq.py",
50
+ "git": {
51
+ "remote": "https://huggingface.co/sanchit-gandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box",
52
+ "commit": "140399a622e2a82685fa4b9727f3d970b8bef9e0"
53
+ },
54
+ "email": "sanchit@huggingface.co",
55
+ "root": "/home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box",
56
+ "host": "t1v-n-5966b949-w-0",
57
+ "username": "sanchitgandhi",
58
+ "executable": "/home/sanchitgandhi/hf/bin/python"
59
+ }
wandb/run-20220828_085247-2hx8pk65/files/wandb-summary.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"train/decoder_grad_norm": 0.5876523852348328, "train/decoder_param_norm": 1057.45703125, "train/encoder_grad_norm": 0.38440409302711487, "train/encoder_param_norm": 2316.3564453125, "train/grad_norm": 0.7022120952606201, "layer_grad_norm/": {"decoder": {"model": {"decoder": {"embed_positions": {"embedding": 0.10323784500360489}, "embed_tokens": {"embedding": 0.16808316111564636}, "layernorm_embedding": {"bias": 0.03703528642654419, "scale": 0.060806743800640106}, "layers": {"FlaxBartDecoderLayers": {"encoder_attn": {"k_proj": {"bias": 1.75027107616188e-05, "kernel": 0.030463965609669685}, "out_proj": {"bias": 0.024376848712563515, "kernel": 0.08760593086481094}, "q_proj": {"bias": 0.0016024636570364237, "kernel": 0.034829143434762955}, "v_proj": {"bias": 0.04787713289260864, "kernel": 0.07169140875339508}}, "encoder_attn_layer_norm": {"bias": 0.03529948368668556, "scale": 0.0380270853638649}, "fc1": {"bias": 0.013248836621642113, "kernel": 0.33658137917518616}, "fc2": {"bias": 0.030859898775815964, "kernel": 0.2677602767944336}, "final_layer_norm": {"bias": 0.1120176762342453, "scale": 0.05825764685869217}, "self_attn": {"k_proj": {"bias": 6.563532224390656e-06, "kernel": 0.047542572021484375}, "out_proj": {"bias": 0.068998321890831, "kernel": 0.15063460171222687}, "q_proj": {"bias": 0.003958633169531822, "kernel": 0.05425203591585159}, "v_proj": {"bias": 0.07329808175563812, "kernel": 0.198069229722023}}, "self_attn_layer_norm": {"bias": 0.023308640345931053, "scale": 0.030806636437773705}}}}}}, "encoder": {"adapter": {"layers": {"0": {"conv": {"bias": 0.04864540696144104, "kernel": 0.133722722530365}}, "1": {"conv": {"bias": 0.04470941796898842, "kernel": 0.09400613605976105}}, "2": {"conv": {"bias": 0.05692768096923828, "kernel": 0.1417897492647171}}}}, "encoder": {"layer_norm": {"bias": 0.16896693408489227, "scale": 0.08190205693244934}, "layers": {"FlaxWav2Vec2EncoderLayers": {"attention": {"k_proj": {"bias": 5.699832854588749e-06, "kernel": 0.03451818600296974}, "out_proj": {"bias": 0.004949449095875025, "kernel": 0.0711507499217987}, "q_proj": {"bias": 0.006232084706425667, "kernel": 0.03630899265408516}, "v_proj": {"bias": 0.021894006058573723, "kernel": 0.0699479877948761}}, "feed_forward": {"intermediate_dense": {"bias": 0.010628663003444672, "kernel": 0.08824677765369415}, "output_dense": {"bias": 0.0046577295288443565, "kernel": 0.07864432781934738}}, "final_layer_norm": {"bias": 0.053700175136327744, "scale": 0.06233147531747818}, "layer_norm": {"bias": 0.09289932250976562, "scale": 0.07505689561367035}}}, "pos_conv_embed": {"conv": {"bias": 0.001811191556043923, "weight_g": 0.04629991203546524, "weight_v": 0.05902065336704254}}}, "feature_extractor": {"conv_layers": {"0": {"conv": {"bias": 0.0, "kernel": 0.0}, "layer_norm": {"bias": 0.0, "scale": 0.0}}, "1": {"conv": {"bias": 0.0, "kernel": 0.0}, "layer_norm": {"bias": 0.0, "scale": 0.0}}, "2": {"conv": {"bias": 0.0, "kernel": 0.0}, "layer_norm": {"bias": 0.0, "scale": 0.0}}, "3": {"conv": {"bias": 0.0, "kernel": 0.0}, "layer_norm": {"bias": 0.0, "scale": 0.0}}, "4": {"conv": {"bias": 0.0, "kernel": 0.0}, "layer_norm": {"bias": 0.0, "scale": 0.0}}, "5": {"conv": {"bias": 0.0, "kernel": 0.0}, "layer_norm": {"bias": 0.0, "scale": 0.0}}, "6": {"conv": {"bias": 0.0, "kernel": 0.0}, "layer_norm": {"bias": 0.0, "scale": 0.0}}}}, "feature_projection": {"layer_norm": {"bias": 0.01040860079228878, "scale": 0.009696024470031261}, "projection": {"bias": 0.002452271291986108, "kernel": 0.06397733092308044}}, "masked_spec_embed": 0.0}}, "layer_param_norm/": {"decoder": {"model": {"decoder": {"embed_positions": {"embedding": 58.57985305786133}, "embed_tokens": {"embedding": 628.9428100585938}, "layernorm_embedding": {"bias": 2.4099645614624023, "scale": 13.944293022155762}, "layers": {"FlaxBartDecoderLayers": {"encoder_attn": {"k_proj": {"bias": 47.96258544921875, "kernel": 330.1817932128906}, "out_proj": {"bias": 6.197176456451416, "kernel": 226.72259521484375}, "q_proj": {"bias": 20.796918869018555, "kernel": 337.1412658691406}, "v_proj": {"bias": 3.727905035018921, "kernel": 230.9994354248047}}, "encoder_attn_layer_norm": {"bias": 10.427277565002441, "scale": 56.38846206665039}, "fc1": {"bias": 25.47351837158203, "kernel": 339.21954345703125}, "fc2": {"bias": 7.897115707397461, "kernel": 243.82398986816406}, "final_layer_norm": {"bias": 4.000784873962402, "scale": 63.70562744140625}, "self_attn": {"k_proj": {"bias": 59.513954162597656, "kernel": 278.91595458984375}, "out_proj": {"bias": 3.8339650630950928, "kernel": 131.7364501953125}, "q_proj": {"bias": 32.09528732299805, "kernel": 282.0332336425781}, "v_proj": {"bias": 2.626418352127075, "kernel": 140.15884399414062}}, "self_attn_layer_norm": {"bias": 8.851421356201172, "scale": 84.72929382324219}}}}}}, "encoder": {"adapter": {"layers": {"0": {"conv": {"bias": 0.5224539637565613, "kernel": 58.06698226928711}}, "1": {"conv": {"bias": 0.6238547563552856, "kernel": 55.76792907714844}}, "2": {"conv": {"bias": 0.8834269046783447, "kernel": 55.83806610107422}}}}, "encoder": {"layer_norm": {"bias": 0.2885725498199463, "scale": 4.501636505126953}, "layers": {"FlaxWav2Vec2EncoderLayers": {"attention": {"k_proj": {"bias": 19.359642028808594, "kernel": 551.2367553710938}, "out_proj": {"bias": 16.819419860839844, "kernel": 703.838134765625}, "q_proj": {"bias": 40.78517532348633, "kernel": 543.7529907226562}, "v_proj": {"bias": 15.60958194732666, "kernel": 695.4569091796875}}, "feed_forward": {"intermediate_dense": {"bias": 24.515138626098633, "kernel": 1373.99365234375}, "output_dense": {"bias": 20.76974868774414, "kernel": 1299.6435546875}}, "final_layer_norm": {"bias": 32.476783752441406, "scale": 141.65736389160156}, "layer_norm": {"bias": 7.329699516296387, "scale": 45.53441619873047}}}, "pos_conv_embed": {"conv": {"bias": 15.283638954162598, "weight_g": 21.029205322265625, "weight_v": 212.9462127685547}}}, "feature_extractor": {"conv_layers": {"0": {"conv": {"bias": 0.5982058644294739, "kernel": 8.08896541595459}, "layer_norm": {"bias": 10.069783210754395, "scale": 10.451257705688477}}, "1": {"conv": {"bias": 4.74075174331665, "kernel": 90.8435287475586}, "layer_norm": {"bias": 6.922820091247559, "scale": 19.5467586517334}}, "2": {"conv": {"bias": 6.7732415199279785, "kernel": 146.13897705078125}, "layer_norm": {"bias": 9.044225692749023, "scale": 19.424888610839844}}, "3": {"conv": {"bias": 5.224758148193359, "kernel": 159.10508728027344}, "layer_norm": {"bias": 8.319666862487793, "scale": 17.64743423461914}}, "4": {"conv": {"bias": 4.434978008270264, "kernel": 157.35813903808594}, "layer_norm": {"bias": 9.193974494934082, "scale": 15.562357902526855}}, "5": {"conv": {"bias": 5.297643661499023, "kernel": 131.1835174560547}, "layer_norm": {"bias": 10.735219955444336, "scale": 13.812533378601074}}, "6": {"conv": {"bias": 5.615579128265381, "kernel": 136.41822814941406}, "layer_norm": {"bias": 12.515308380126953, "scale": 11.152680397033691}}}}, "feature_projection": {"layer_norm": {"bias": 9.422893524169922, "scale": 27.84585189819336}, "projection": {"bias": 4.289161682128906, "kernel": 88.30554962158203}}, "masked_spec_embed": 26.247730255126953}}, "train/learning_rate": 8.086059824563563e-05, "train/loss": 0.1043805480003357, "train/param_norm": 2546.3154296875, "_timestamp": 1661727380, "_runtime": 50613, "_step": 9975}
wandb/run-20220828_085247-2hx8pk65/logs/debug-internal.log ADDED
The diff for this file is too large to render. See raw diff
 
wandb/run-20220828_085247-2hx8pk65/logs/debug.log ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 2022-08-28 08:52:47,515 INFO MainThread:53859 [wandb_setup.py:_flush():75] Loading settings from /home/sanchitgandhi/.config/wandb/settings
2
+ 2022-08-28 08:52:47,515 INFO MainThread:53859 [wandb_setup.py:_flush():75] Loading settings from /home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/wandb/settings
3
+ 2022-08-28 08:52:47,515 INFO MainThread:53859 [wandb_setup.py:_flush():75] Loading settings from environment variables: {}
4
+ 2022-08-28 08:52:47,515 INFO MainThread:53859 [wandb_setup.py:_flush():75] Inferring run settings from compute environment: {'program_relpath': 'run_flax_speech_recognition_seq2seq.py', 'program': 'run_flax_speech_recognition_seq2seq.py'}
5
+ 2022-08-28 08:52:47,515 INFO MainThread:53859 [wandb_init.py:_log_setup():437] Logging user logs to /home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/wandb/run-20220828_085247-2hx8pk65/logs/debug.log
6
+ 2022-08-28 08:52:47,515 INFO MainThread:53859 [wandb_init.py:_log_setup():438] Logging internal logs to /home/sanchitgandhi/flax-wav2vec2-2-bart-large-ls-960h-black-box/wandb/run-20220828_085247-2hx8pk65/logs/debug-internal.log
7
+ 2022-08-28 08:52:47,515 INFO MainThread:53859 [wandb_init.py:init():471] calling init triggers
8
+ 2022-08-28 08:52:47,515 INFO MainThread:53859 [wandb_init.py:init():474] wandb.init called with sweep_config: {}
9
+ config: {}
10
+ 2022-08-28 08:52:47,515 INFO MainThread:53859 [wandb_init.py:init():524] starting backend
11
+ 2022-08-28 08:52:47,515 INFO MainThread:53859 [backend.py:_multiprocessing_setup():97] multiprocessing start_methods=fork,spawn,forkserver, using: spawn
12
+ 2022-08-28 08:52:47,546 INFO MainThread:53859 [backend.py:ensure_launched():217] starting backend process...
13
+ 2022-08-28 08:52:47,572 INFO MainThread:53859 [backend.py:ensure_launched():222] started backend process with pid: 54989
14
+ 2022-08-28 08:52:47,574 INFO MainThread:53859 [wandb_init.py:init():533] backend started and connected
15
+ 2022-08-28 08:52:47,585 INFO MainThread:53859 [wandb_init.py:init():597] updated telemetry
16
+ 2022-08-28 08:52:47,649 INFO MainThread:53859 [wandb_init.py:init():628] communicating run to backend with 30 second timeout
17
+ 2022-08-28 08:52:48,479 INFO MainThread:53859 [wandb_run.py:_on_init():1923] communicating current version
18
+ 2022-08-28 08:52:48,543 INFO MainThread:53859 [wandb_run.py:_on_init():1927] got version response upgrade_message: "wandb version 0.13.2 is available! To upgrade, please run:\n $ pip install wandb --upgrade"
19
+
20
+ 2022-08-28 08:52:48,543 INFO MainThread:53859 [wandb_init.py:init():659] starting run threads in backend
21
+ 2022-08-28 08:52:48,582 INFO MainThread:53859 [wandb_run.py:_console_start():1897] atexit reg
22
+ 2022-08-28 08:52:48,582 INFO MainThread:53859 [wandb_run.py:_redirect():1770] redirect: SettingsConsole.REDIRECT
23
+ 2022-08-28 08:52:48,583 INFO MainThread:53859 [wandb_run.py:_redirect():1775] Redirecting console.
24
+ 2022-08-28 08:52:48,585 INFO MainThread:53859 [wandb_run.py:_redirect():1831] Redirects installed.
25
+ 2022-08-28 08:52:48,585 INFO MainThread:53859 [wandb_init.py:init():684] run started, returning control to user process
wandb/run-20220828_085247-2hx8pk65/run-2hx8pk65.wandb ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:094e92de49c7288ddfac32754880e9359cb30d1406e2d3bdff46b108a8c651aa
3
+ size 4469804