esc-bencher committed on
Commit
da05288
1 Parent(s): 4d5a92d

Add training scripts and weights

Browse files
README.md ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ language:
3
+ - en
4
+ tags:
5
+ - esc
6
+ datasets:
7
+ - spgispeech
8
+ ---
9
+ To reproduce this run, execute:
10
+ ```bash
11
+ #!/usr/bin/env bash
12
+ python run_flax_speech_recognition_seq2seq.py \
13
+ --dataset_name="esc-benchmark/esc-datasets" \
14
+ --model_name_or_path="esc-benchmark/wav2vec2-aed-pretrained" \
15
+ --dataset_config_name="spgispeech" \
16
+ --output_dir="./" \
17
+ --wandb_name="wav2vec2-aed-spgispeech" \
18
+ --wandb_project="wav2vec2-aed" \
19
+ --per_device_train_batch_size="8" \
20
+ --per_device_eval_batch_size="2" \
21
+ --learning_rate="1e-4" \
22
+ --warmup_steps="500" \
23
+ --logging_steps="25" \
24
+ --max_steps="50001" \
25
+ --eval_steps="10000" \
26
+ --save_steps="10000" \
27
+ --generation_max_length="40" \
28
+ --generation_num_beams="1" \
29
+ --final_generation_max_length="225" \
30
+ --final_generation_num_beams="14" \
31
+ --generation_length_penalty="1.6" \
32
+ --overwrite_output_dir \
33
+ --gradient_checkpointing \
34
+ --freeze_feature_encoder \
35
+ --predict_with_generate \
36
+ --do_eval \
37
+ --do_train \
38
+ --do_predict \
39
+ --push_to_hub \
40
+ --use_auth_token
41
+ ```
config.json ADDED
@@ -0,0 +1,289 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "sanchit-gandhi/flax-wav2vec2-2-bart-large-scan",
3
+ "architectures": [
4
+ "SpeechEncoderDecoderModel"
5
+ ],
6
+ "decoder": {
7
+ "_name_or_path": "",
8
+ "activation_dropout": 0.1,
9
+ "activation_function": "gelu",
10
+ "add_bias_logits": false,
11
+ "add_cross_attention": true,
12
+ "add_final_layer_norm": false,
13
+ "architectures": [
14
+ "BartModel"
15
+ ],
16
+ "attention_dropout": 0.1,
17
+ "bad_words_ids": null,
18
+ "bos_token_id": 0,
19
+ "chunk_size_feed_forward": 0,
20
+ "classif_dropout": 0.1,
21
+ "classifier_dropout": 0.0,
22
+ "cross_attention_hidden_size": null,
23
+ "d_model": 1024,
24
+ "decoder_attention_heads": 16,
25
+ "decoder_ffn_dim": 4096,
26
+ "decoder_layerdrop": 0.0,
27
+ "decoder_layers": 12,
28
+ "decoder_start_token_id": 2,
29
+ "diversity_penalty": 0.0,
30
+ "do_sample": false,
31
+ "dropout": 0.1,
32
+ "early_stopping": true,
33
+ "encoder_attention_heads": 16,
34
+ "encoder_ffn_dim": 4096,
35
+ "encoder_layerdrop": 0.0,
36
+ "encoder_layers": 12,
37
+ "encoder_no_repeat_ngram_size": 0,
38
+ "eos_token_id": 2,
39
+ "exponential_decay_length_penalty": null,
40
+ "finetuning_task": null,
41
+ "forced_bos_token_id": 0,
42
+ "forced_eos_token_id": 2,
43
+ "fuse_matmuls": false,
44
+ "gradient_checkpointing": true,
45
+ "id2label": {
46
+ "0": "LABEL_0",
47
+ "1": "LABEL_1",
48
+ "2": "LABEL_2"
49
+ },
50
+ "init_std": 0.02,
51
+ "is_decoder": true,
52
+ "is_encoder_decoder": false,
53
+ "label2id": {
54
+ "LABEL_0": 0,
55
+ "LABEL_1": 1,
56
+ "LABEL_2": 2
57
+ },
58
+ "length_penalty": 1.0,
59
+ "max_length": 20,
60
+ "max_position_embeddings": 1024,
61
+ "min_length": 0,
62
+ "model_type": "bart",
63
+ "no_repeat_ngram_size": 3,
64
+ "normalize_before": false,
65
+ "num_beam_groups": 1,
66
+ "num_beams": 4,
67
+ "num_hidden_layers": 12,
68
+ "num_return_sequences": 1,
69
+ "output_attentions": false,
70
+ "output_hidden_states": false,
71
+ "output_scores": false,
72
+ "pad_token_id": 1,
73
+ "prefix": null,
74
+ "problem_type": null,
75
+ "pruned_heads": {},
76
+ "remove_invalid_values": false,
77
+ "repetition_penalty": 1.0,
78
+ "return_dict": true,
79
+ "return_dict_in_generate": false,
80
+ "scale_embedding": false,
81
+ "sep_token_id": null,
82
+ "task_specific_params": {
83
+ "summarization": {
84
+ "length_penalty": 1.0,
85
+ "max_length": 128,
86
+ "min_length": 12,
87
+ "num_beams": 4
88
+ },
89
+ "summarization_cnn": {
90
+ "length_penalty": 2.0,
91
+ "max_length": 142,
92
+ "min_length": 56,
93
+ "num_beams": 4
94
+ },
95
+ "summarization_xsum": {
96
+ "length_penalty": 1.0,
97
+ "max_length": 62,
98
+ "min_length": 11,
99
+ "num_beams": 6
100
+ }
101
+ },
102
+ "temperature": 1.0,
103
+ "tie_encoder_decoder": false,
104
+ "tie_word_embeddings": true,
105
+ "tokenizer_class": null,
106
+ "top_k": 50,
107
+ "top_p": 1.0,
108
+ "torch_dtype": "float32",
109
+ "torchscript": false,
110
+ "transformers_version": "4.21.0.dev0",
111
+ "typical_p": 1.0,
112
+ "use_bfloat16": false,
113
+ "use_cache": true,
114
+ "use_scan": true,
115
+ "vocab_size": 50265
116
+ },
117
+ "decoder_start_token_id": 0,
118
+ "encoder": {
119
+ "_name_or_path": "",
120
+ "activation_dropout": 0.1,
121
+ "adapter_kernel_size": 3,
122
+ "adapter_stride": 2,
123
+ "add_adapter": true,
124
+ "add_cross_attention": false,
125
+ "apply_spec_augment": true,
126
+ "architectures": [
127
+ "Wav2Vec2ForPreTraining"
128
+ ],
129
+ "attention_dropout": 0.1,
130
+ "bad_words_ids": null,
131
+ "bos_token_id": 1,
132
+ "chunk_size_feed_forward": 0,
133
+ "classifier_proj_size": 256,
134
+ "codevector_dim": 768,
135
+ "contrastive_logits_temperature": 0.1,
136
+ "conv_bias": true,
137
+ "conv_dim": [
138
+ 512,
139
+ 512,
140
+ 512,
141
+ 512,
142
+ 512,
143
+ 512,
144
+ 512
145
+ ],
146
+ "conv_kernel": [
147
+ 10,
148
+ 3,
149
+ 3,
150
+ 3,
151
+ 3,
152
+ 2,
153
+ 2
154
+ ],
155
+ "conv_stride": [
156
+ 5,
157
+ 2,
158
+ 2,
159
+ 2,
160
+ 2,
161
+ 2,
162
+ 2
163
+ ],
164
+ "cross_attention_hidden_size": null,
165
+ "ctc_loss_reduction": "sum",
166
+ "ctc_zero_infinity": false,
167
+ "decoder_start_token_id": null,
168
+ "diversity_loss_weight": 0.1,
169
+ "diversity_penalty": 0.0,
170
+ "do_sample": false,
171
+ "do_stable_layer_norm": true,
172
+ "early_stopping": false,
173
+ "encoder_no_repeat_ngram_size": 0,
174
+ "eos_token_id": 2,
175
+ "exponential_decay_length_penalty": null,
176
+ "feat_extract_activation": "gelu",
177
+ "feat_extract_dropout": 0.0,
178
+ "feat_extract_norm": "layer",
179
+ "feat_proj_dropout": 0.0,
180
+ "feat_quantizer_dropout": 0.0,
181
+ "final_dropout": 0.0,
182
+ "finetuning_task": null,
183
+ "forced_bos_token_id": null,
184
+ "forced_eos_token_id": null,
185
+ "fuse_matmuls": false,
186
+ "gradient_checkpointing": true,
187
+ "hidden_act": "gelu",
188
+ "hidden_dropout": 0.1,
189
+ "hidden_dropout_prob": 0.1,
190
+ "hidden_size": 1024,
191
+ "id2label": {
192
+ "0": "LABEL_0",
193
+ "1": "LABEL_1"
194
+ },
195
+ "initializer_range": 0.02,
196
+ "intermediate_size": 4096,
197
+ "is_decoder": false,
198
+ "is_encoder_decoder": false,
199
+ "label2id": {
200
+ "LABEL_0": 0,
201
+ "LABEL_1": 1
202
+ },
203
+ "layer_norm_eps": 1e-05,
204
+ "layerdrop": 0.0,
205
+ "length_penalty": 1.0,
206
+ "mask_feature_length": 10,
207
+ "mask_feature_min_masks": 0,
208
+ "mask_feature_prob": 0.0,
209
+ "mask_time_length": 10,
210
+ "mask_time_min_masks": 2,
211
+ "mask_time_prob": 0.1,
212
+ "max_length": 20,
213
+ "min_length": 0,
214
+ "model_type": "wav2vec2",
215
+ "no_repeat_ngram_size": 0,
216
+ "num_adapter_layers": 3,
217
+ "num_attention_heads": 16,
218
+ "num_beam_groups": 1,
219
+ "num_beams": 1,
220
+ "num_codevector_groups": 2,
221
+ "num_codevectors_per_group": 320,
222
+ "num_conv_pos_embedding_groups": 16,
223
+ "num_conv_pos_embeddings": 128,
224
+ "num_feat_extract_layers": 7,
225
+ "num_hidden_layers": 24,
226
+ "num_negatives": 100,
227
+ "num_return_sequences": 1,
228
+ "output_attentions": false,
229
+ "output_hidden_size": 1024,
230
+ "output_hidden_states": false,
231
+ "output_scores": false,
232
+ "pad_token_id": 0,
233
+ "prefix": null,
234
+ "problem_type": null,
235
+ "proj_codevector_dim": 768,
236
+ "pruned_heads": {},
237
+ "remove_invalid_values": false,
238
+ "repetition_penalty": 1.0,
239
+ "return_dict": true,
240
+ "return_dict_in_generate": false,
241
+ "sep_token_id": null,
242
+ "task_specific_params": null,
243
+ "tdnn_dilation": [
244
+ 1,
245
+ 2,
246
+ 3,
247
+ 1,
248
+ 1
249
+ ],
250
+ "tdnn_dim": [
251
+ 512,
252
+ 512,
253
+ 512,
254
+ 512,
255
+ 1500
256
+ ],
257
+ "tdnn_kernel": [
258
+ 5,
259
+ 3,
260
+ 3,
261
+ 1,
262
+ 1
263
+ ],
264
+ "temperature": 1.0,
265
+ "tie_encoder_decoder": false,
266
+ "tie_word_embeddings": true,
267
+ "tokenizer_class": null,
268
+ "top_k": 50,
269
+ "top_p": 1.0,
270
+ "torch_dtype": null,
271
+ "torchscript": false,
272
+ "transformers_version": "4.21.0.dev0",
273
+ "typical_p": 1.0,
274
+ "use_bfloat16": false,
275
+ "use_scan": true,
276
+ "use_weighted_layer_sum": false,
277
+ "vocab_size": 32,
278
+ "xvector_output_dim": 512
279
+ },
280
+ "eos_token_id": 2,
281
+ "is_encoder_decoder": true,
282
+ "max_length": 40,
283
+ "model_type": "speech-encoder-decoder",
284
+ "pad_token_id": 1,
285
+ "processor_class": "Wav2Vec2Processor",
286
+ "tie_word_embeddings": false,
287
+ "transformers_version": null,
288
+ "use_cache": false
289
+ }
flax_model.msgpack ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7af9662fb94f987aa7df7c25c8cbc68679a0c709e18f799cd6f85e3c3db6fd22
3
+ size 2353616717
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
models/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ from models.configuration_bart import BartConfig
2
+ from models.configuration_wav2vec2 import Wav2Vec2Config
3
+ from models.configuration_speech_encoder_decoder import SpeechEncoderDecoderConfig
4
+ from models.modeling_flax_wav2vec2 import FlaxWav2Vec2Model, FlaxWav2Vec2Module, FlaxWav2Vec2ForCTC, FlaxWav2Vec2ForCTCModule
5
+ from models.modeling_flax_bart import FlaxBartForCausalLM, FlaxBartForCausalLMModule
6
+ from models.modeling_flax_speech_encoder_decoder import FlaxSpeechEncoderDecoderModel
models/configuration_bart.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ BART model configuration"""
16
+ import warnings
17
+
18
+ from transformers.configuration_utils import PretrainedConfig
19
+ from transformers.utils import logging
20
+
21
+
22
+ logger = logging.get_logger(__name__)
23
+
24
+ BART_PRETRAINED_CONFIG_ARCHIVE_MAP = {
25
+ "facebook/bart-large": "https://huggingface.co/facebook/bart-large/resolve/main/config.json",
26
+ # See all BART models at https://huggingface.co/models?filter=bart
27
+ }
28
+
29
+
30
+ class BartConfig(PretrainedConfig):
31
+ r"""
32
+ This is the configuration class to store the configuration of a [`BartModel`]. It is used to instantiate a BART
33
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
34
+ defaults will yield a similar configuration to that of the BART
35
+ [facebook/bart-large](https://huggingface.co/facebook/bart-large) architecture.
36
+
37
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
38
+ documentation from [`PretrainedConfig`] for more information.
39
+
40
+
41
+ Args:
42
+ vocab_size (`int`, *optional*, defaults to 50265):
43
+ Vocabulary size of the BART model. Defines the number of different tokens that can be represented by the
44
+ `inputs_ids` passed when calling [`BartModel`] or [`TFBartModel`].
45
+ d_model (`int`, *optional*, defaults to 1024):
46
+ Dimensionality of the layers and the pooler layer.
47
+ encoder_layers (`int`, *optional*, defaults to 12):
48
+ Number of encoder layers.
49
+ decoder_layers (`int`, *optional*, defaults to 12):
50
+ Number of decoder layers.
51
+ encoder_attention_heads (`int`, *optional*, defaults to 16):
52
+ Number of attention heads for each attention layer in the Transformer encoder.
53
+ decoder_attention_heads (`int`, *optional*, defaults to 16):
54
+ Number of attention heads for each attention layer in the Transformer decoder.
55
+ decoder_ffn_dim (`int`, *optional*, defaults to 4096):
56
+ Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
57
+ encoder_ffn_dim (`int`, *optional*, defaults to 4096):
58
+ Dimensionality of the "intermediate" (often named feed-forward) layer in encoder.
59
+ activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
60
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
61
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
62
+ dropout (`float`, *optional*, defaults to 0.1):
63
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
64
+ attention_dropout (`float`, *optional*, defaults to 0.0):
65
+ The dropout ratio for the attention probabilities.
66
+ activation_dropout (`float`, *optional*, defaults to 0.0):
67
+ The dropout ratio for activations inside the fully connected layer.
68
+ classifier_dropout (`float`, *optional*, defaults to 0.0):
69
+ The dropout ratio for classifier.
70
+ max_position_embeddings (`int`, *optional*, defaults to 1024):
71
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
72
+ just in case (e.g., 512 or 1024 or 2048).
73
+ init_std (`float`, *optional*, defaults to 0.02):
74
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
75
+ encoder_layerdrop: (`float`, *optional*, defaults to 0.0):
76
+ The LayerDrop probability for the encoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
77
+ for more details.
78
+ decoder_layerdrop: (`float`, *optional*, defaults to 0.0):
79
+ The LayerDrop probability for the decoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
80
+ for more details.
81
+ scale_embedding (`bool`, *optional*, defaults to `False`):
82
+ Scale embeddings by multiplying by sqrt(d_model).
83
+ use_cache (`bool`, *optional*, defaults to `True`):
84
+ Whether or not the model should return the last key/values attentions (not used by all models).
85
+ num_labels: (`int`, *optional*, defaults to 3):
86
+ The number of labels to use in [`BartForSequenceClassification`].
87
+ forced_eos_token_id (`int`, *optional*, defaults to 2):
88
+ The id of the token to force as the last generated token when `max_length` is reached. Usually set to
89
+ `eos_token_id`.
90
+ use_scan (`bool`, *optional*, defaults to `False`):
91
+ Whether or not to use nn.scan in the Flax Bart attention layers.
92
+
93
+ Example:
94
+
95
+ ```python
96
+ >>> from transformers import BartModel, BartConfig
97
+
98
+ >>> # Initializing a BART facebook/bart-large style configuration
99
+ >>> configuration = BartConfig()
100
+
101
+ >>> # Initializing a model from the facebook/bart-large style configuration
102
+ >>> model = BartModel(configuration)
103
+
104
+ >>> # Accessing the model configuration
105
+ >>> configuration = model.config
106
+ ```"""
107
+ model_type = "bart"
108
+ keys_to_ignore_at_inference = ["past_key_values"]
109
+ attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}
110
+
111
+ def __init__(
112
+ self,
113
+ vocab_size=50265,
114
+ max_position_embeddings=1024,
115
+ encoder_layers=12,
116
+ encoder_ffn_dim=4096,
117
+ encoder_attention_heads=16,
118
+ decoder_layers=12,
119
+ decoder_ffn_dim=4096,
120
+ decoder_attention_heads=16,
121
+ encoder_layerdrop=0.0,
122
+ decoder_layerdrop=0.0,
123
+ activation_function="gelu",
124
+ d_model=1024,
125
+ dropout=0.1,
126
+ attention_dropout=0.0,
127
+ activation_dropout=0.0,
128
+ init_std=0.02,
129
+ classifier_dropout=0.0,
130
+ scale_embedding=False,
131
+ use_cache=True,
132
+ use_scan=False,
133
+ fuse_matmuls=False,
134
+ num_labels=3,
135
+ pad_token_id=1,
136
+ bos_token_id=0,
137
+ eos_token_id=2,
138
+ is_encoder_decoder=True,
139
+ decoder_start_token_id=2,
140
+ forced_eos_token_id=2,
141
+ **kwargs
142
+ ):
143
+ self.vocab_size = vocab_size
144
+ self.max_position_embeddings = max_position_embeddings
145
+ self.d_model = d_model
146
+ self.encoder_ffn_dim = encoder_ffn_dim
147
+ self.encoder_layers = encoder_layers
148
+ self.encoder_attention_heads = encoder_attention_heads
149
+ self.decoder_ffn_dim = decoder_ffn_dim
150
+ self.decoder_layers = decoder_layers
151
+ self.decoder_attention_heads = decoder_attention_heads
152
+ self.dropout = dropout
153
+ self.attention_dropout = attention_dropout
154
+ self.activation_dropout = activation_dropout
155
+ self.activation_function = activation_function
156
+ self.init_std = init_std
157
+ self.encoder_layerdrop = encoder_layerdrop
158
+ self.decoder_layerdrop = decoder_layerdrop
159
+ self.classifier_dropout = classifier_dropout
160
+ self.use_cache = use_cache
161
+ self.use_scan = use_scan
162
+ self.fuse_matmuls = fuse_matmuls
163
+ self.num_hidden_layers = encoder_layers
164
+ self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
165
+
166
+ super().__init__(
167
+ num_labels=num_labels,
168
+ pad_token_id=pad_token_id,
169
+ bos_token_id=bos_token_id,
170
+ eos_token_id=eos_token_id,
171
+ is_encoder_decoder=is_encoder_decoder,
172
+ decoder_start_token_id=decoder_start_token_id,
173
+ forced_eos_token_id=forced_eos_token_id,
174
+ **kwargs,
175
+ )
176
+
177
+ # ensure backward compatibility for BART CNN models
178
+ if self.forced_bos_token_id is None and kwargs.get("force_bos_token_to_be_generated", False):
179
+ self.forced_bos_token_id = self.bos_token_id
180
+ warnings.warn(
181
+ f"Please make sure the config includes `forced_bos_token_id={self.bos_token_id}` in future versions. "
182
+ "The config can simply be saved and uploaded again to be fixed."
183
+ )
models/configuration_speech_encoder_decoder.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2021 The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import copy
18
+
19
+ from transformers.configuration_utils import PretrainedConfig
20
+ from transformers.utils import logging
21
+ from models.configuration_wav2vec2 import Wav2Vec2Config
22
+ from models.configuration_bart import BartConfig
23
+ from transformers import AutoConfig
24
+
25
+
26
+ logger = logging.get_logger(__name__)
27
+
28
+
29
+ class SpeechEncoderDecoderConfig(PretrainedConfig):
30
+ r"""
31
+ [`SpeechEncoderDecoderConfig`] is the configuration class to store the configuration of a
32
+ [`SpeechEncoderDecoderModel`]. It is used to instantiate an Encoder Decoder model according to the specified
33
+ arguments, defining the encoder and decoder configs.
34
+
35
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
36
+ documentation from [`PretrainedConfig`] for more information.
37
+
38
+ Args:
39
+ kwargs (*optional*):
40
+ Dictionary of keyword arguments. Notably:
41
+
42
+ - **encoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that defines
43
+ the encoder config.
44
+ - **decoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that defines
45
+ the decoder config.
46
+
47
+ Examples:
48
+
49
+ ```python
50
+ >>> from transformers import BertConfig, Wav2Vec2Config, SpeechEncoderDecoderConfig, SpeechEncoderDecoderModel
51
+
52
+ >>> # Initializing a Wav2Vec2 & BERT style configuration
53
+ >>> config_encoder = Wav2Vec2Config()
54
+ >>> config_decoder = BertConfig()
55
+
56
+ >>> config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(config_encoder, config_decoder)
57
+
58
+ >>> # Initializing a Wav2Vec2Bert model from a Wav2Vec2 & bert-base-uncased style configurations
59
+ >>> model = SpeechEncoderDecoderModel(config=config)
60
+
61
+ >>> # Accessing the model configuration
62
+ >>> config_encoder = model.config.encoder
63
+ >>> config_decoder = model.config.decoder
64
+ >>> # set decoder config to causal lm
65
+ >>> config_decoder.is_decoder = True
66
+ >>> config_decoder.add_cross_attention = True
67
+
68
+ >>> # Saving the model, including its configuration
69
+ >>> model.save_pretrained("my-model")
70
+
71
+ >>> # loading model and config from pretrained folder
72
+ >>> encoder_decoder_config = SpeechEncoderDecoderConfig.from_pretrained("my-model")
73
+ >>> model = SpeechEncoderDecoderModel.from_pretrained("my-model", config=encoder_decoder_config)
74
+ ```"""
75
+ model_type = "speech-encoder-decoder"
76
+ is_composition = True
77
+
78
+ def __init__(self, **kwargs):
79
+ super().__init__(**kwargs)
80
+ if "encoder" not in kwargs or "decoder" not in kwargs:
81
+ raise ValueError(
82
+ f"A configuraton of type {self.model_type} cannot be instantiated because not both `encoder` and `decoder` sub-configurations are passed, but only {kwargs}"
83
+ )
84
+
85
+ encoder_config = kwargs.pop("encoder")
86
+ decoder_config = kwargs.pop("decoder")
87
+
88
+ # TODO: Load configs from AutoConfig (as done in Transformers 🤗)
89
+ self.encoder = Wav2Vec2Config(**encoder_config)
90
+ self.decoder = BartConfig(**decoder_config)
91
+ self.is_encoder_decoder = True
92
+
93
+ @classmethod
94
+ def from_encoder_decoder_configs(
95
+ cls, encoder_config: PretrainedConfig, decoder_config: PretrainedConfig, **kwargs
96
+ ) -> PretrainedConfig:
97
+ r"""
98
+ Instantiate a [`SpeechEncoderDecoderConfig`] (or a derived class) from a pre-trained encoder model
99
+ configuration and decoder model configuration.
100
+
101
+ Returns:
102
+ [`SpeechEncoderDecoderConfig`]: An instance of a configuration object
103
+ """
104
+ logger.info("Setting `config.is_decoder=True` and `config.add_cross_attention=True` for decoder_config")
105
+ decoder_config.is_decoder = True
106
+ decoder_config.add_cross_attention = True
107
+
108
+ return cls(encoder=encoder_config.to_dict(), decoder=decoder_config.to_dict(), **kwargs)
109
+
110
+ def to_dict(self):
111
+ """
112
+ Serializes this instance to a Python dictionary. Override the default *to_dict()* from *PretrainedConfig*.
113
+
114
+ Returns:
115
+ `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
116
+ """
117
+ output = copy.deepcopy(self.__dict__)
118
+ output["encoder"] = self.encoder.to_dict()
119
+ output["decoder"] = self.decoder.to_dict()
120
+ output["model_type"] = self.__class__.model_type
121
+ return output
models/configuration_wav2vec2.py ADDED
@@ -0,0 +1,344 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Wav2Vec2 model configuration"""
16
+
17
+ import functools
18
+ import operator
19
+
20
+ from transformers.configuration_utils import PretrainedConfig
21
+ from transformers.utils import logging
22
+
23
+
24
+ logger = logging.get_logger(__name__)
25
+
26
+ WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP = {
27
+ "facebook/wav2vec2-base-960h": "https://huggingface.co/facebook/wav2vec2-base-960h/resolve/main/config.json",
28
+ # See all Wav2Vec2 models at https://huggingface.co/models?filter=wav2vec2
29
+ }
30
+
31
+
32
+ class Wav2Vec2Config(PretrainedConfig):
33
+ r"""
34
+ This is the configuration class to store the configuration of a [`Wav2Vec2Model`]. It is used to instantiate a
35
+ Wav2Vec2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
36
+ with the defaults will yield a similar configuration to that of the Wav2Vec2
37
+ [facebook/wav2vec2-base-960h](https://huggingface.co/facebook/wav2vec2-base-960h) architecture.
38
+
39
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
40
+ documentation from [`PretrainedConfig`] for more information.
41
+
42
+
43
+ Args:
44
+ vocab_size (`int`, *optional*, defaults to 32):
45
+ Vocabulary size of the Wav2Vec2 model. Defines the number of different tokens that can be represented by
46
+ the `inputs_ids` passed when calling [`Wav2Vec2Model`] or [`TFWav2Vec2Model`]. Vocabulary size of the
47
+ model. Defines the different tokens that can be represented by the *inputs_ids* passed to the forward
48
+ method of [`Wav2Vec2Model`].
49
+ hidden_size (`int`, *optional*, defaults to 768):
50
+ Dimensionality of the encoder layers and the pooler layer.
51
+ num_hidden_layers (`int`, *optional*, defaults to 12):
52
+ Number of hidden layers in the Transformer encoder.
53
+ num_attention_heads (`int`, *optional*, defaults to 12):
54
+ Number of attention heads for each attention layer in the Transformer encoder.
55
+ intermediate_size (`int`, *optional*, defaults to 3072):
56
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
57
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
58
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
59
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
60
+ hidden_dropout (`float`, *optional*, defaults to 0.1):
61
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
62
+ attention_dropout (`float`, *optional*, defaults to 0.1):
63
+ The dropout ratio for the attention probabilities.
64
+ final_dropout (`float`, *optional*, defaults to 0.1):
65
+ The dropout probability for the final projection layer of [`Wav2Vec2ForCTC`].
66
+ initializer_range (`float`, *optional*, defaults to 0.02):
67
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
68
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
69
+ The epsilon used by the layer normalization layers.
70
+ feat_extract_norm (`str`, *optional*, defaults to `"group"`):
71
+ The norm to be applied to 1D convolutional layers in feature encoder. One of `"group"` for group
72
+ normalization of only the first 1D convolutional layer or `"layer"` for layer normalization of all 1D
73
+ convolutional layers.
74
+ feat_proj_dropout (`float`, *optional*, defaults to 0.0):
75
+ The dropout probability for output of the feature encoder.
76
+ feat_extract_activation (`str`, *optional*, defaults to `"gelu"`):
77
+ The non-linear activation function (function or string) in the 1D convolutional layers of the feature
78
+ extractor. If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"` are supported.
79
+ feat_quantizer_dropout (`float`, *optional*, defaults to 0.0):
80
+ The dropout probability for quantized feature encoder states.
81
+ conv_dim (`Tuple[int]`, *optional*, defaults to `(512, 512, 512, 512, 512, 512, 512)`):
82
+ A tuple of integers defining the number of input and output channels of each 1D convolutional layer in the
83
+ feature encoder. The length of *conv_dim* defines the number of 1D convolutional layers.
84
+ conv_stride (`Tuple[int]`, *optional*, defaults to `(5, 2, 2, 2, 2, 2, 2)`):
85
+ A tuple of integers defining the stride of each 1D convolutional layer in the feature encoder. The length
86
+ of *conv_stride* defines the number of convolutional layers and has to match the length of *conv_dim*.
87
+ conv_kernel (`Tuple[int]`, *optional*, defaults to `(10, 3, 3, 3, 3, 3, 3)`):
88
+ A tuple of integers defining the kernel size of each 1D convolutional layer in the feature encoder. The
89
+ length of *conv_kernel* defines the number of convolutional layers and has to match the length of
90
+ *conv_dim*.
91
+ conv_bias (`bool`, *optional*, defaults to `False`):
92
+ Whether the 1D convolutional layers have a bias.
93
+ num_conv_pos_embeddings (`int`, *optional*, defaults to 128):
94
+ Number of convolutional positional embeddings. Defines the kernel size of 1D convolutional positional
95
+ embeddings layer.
96
+ num_conv_pos_embedding_groups (`int`, *optional*, defaults to 16):
97
+ Number of groups of 1D convolutional positional embeddings layer.
98
+ do_stable_layer_norm (`bool`, *optional*, defaults to `False`):
99
+ Whether to apply *stable* layer norm architecture of the Transformer encoder. `do_stable_layer_norm is
100
+ True` corresponds to applying layer norm before the attention layer, whereas `do_stable_layer_norm is
101
+ False` corresponds to applying layer norm after the attention layer.
102
+ apply_spec_augment (`bool`, *optional*, defaults to `True`):
103
+ Whether to apply *SpecAugment* data augmentation to the outputs of the feature encoder. For reference see
104
+ [SpecAugment: A Simple Data Augmentation Method for Automatic Speech
105
+ Recognition](https://arxiv.org/abs/1904.08779).
106
+ mask_time_prob (`float`, *optional*, defaults to 0.05):
107
+ Percentage (between 0 and 1) of all feature vectors along the time axis which will be masked. The masking
108
+ procedure generates ''mask_time_prob*len(time_axis)/mask_time_length'' independent masks over the axis. If
109
+ reasoning from the probability of each feature vector to be chosen as the start of the vector span to be
110
+ masked, *mask_time_prob* should be `prob_vector_start*mask_time_length`. Note that overlap may decrease the
111
+ actual percentage of masked vectors. This is only relevant if `apply_spec_augment is True`.
112
+ mask_time_length (`int`, *optional*, defaults to 10):
113
+ Length of vector span along the time axis.
114
+ mask_time_min_masks (`int`, *optional*, defaults to 2):
115
+ The minimum number of masks of length `mask_time_length` generated along the time axis, each time step,
116
+ irrespectively of `mask_time_prob`. Only relevant if ''mask_time_prob*len(time_axis)/mask_time_length <
117
+ mask_time_min_masks''
118
+ mask_feature_prob (`float`, *optional*, defaults to 0.0):
119
+ Percentage (between 0 and 1) of all feature vectors along the feature axis which will be masked. The
120
+ masking procedure generates ''mask_feature_prob*len(feature_axis)/mask_feature_length'' independent masks over
121
+ the axis. If reasoning from the probability of each feature vector to be chosen as the start of the vector
122
+ span to be masked, *mask_feature_prob* should be `prob_vector_start*mask_feature_length`. Note that overlap
123
+ may decrease the actual percentage of masked vectors. This is only relevant if `apply_spec_augment is
124
+ True`.
125
+ mask_feature_length (`int`, *optional*, defaults to 10):
126
+ Length of vector span along the feature axis.
127
+ mask_feature_min_masks (`int`, *optional*, defaults to 0):
128
+ The minimum number of masks of length `mask_feature_length` generated along the feature axis, each time
129
+ step, irrespectively of `mask_feature_prob`. Only relevant if
130
+ ''mask_feature_prob*len(feature_axis)/mask_feature_length < mask_feature_min_masks''
131
+ num_codevectors_per_group (`int`, *optional*, defaults to 320):
132
+ Number of entries in each quantization codebook (group).
133
+ num_codevector_groups (`int`, *optional*, defaults to 2):
134
+ Number of codevector groups for product codevector quantization.
135
+ contrastive_logits_temperature (`float`, *optional*, defaults to 0.1):
136
+ The temperature *kappa* in the contrastive loss.
137
+ feat_quantizer_dropout (`float`, *optional*, defaults to 0.0):
138
+ The dropout probability for the output of the feature encoder that's used by the quantizer.
139
+ num_negatives (`int`, *optional*, defaults to 100):
140
+ Number of negative samples for the contrastive loss.
141
+ codevector_dim (`int`, *optional*, defaults to 256):
142
+ Dimensionality of the quantized feature vectors.
143
+ proj_codevector_dim (`int`, *optional*, defaults to 256):
144
+ Dimensionality of the final projection of both the quantized and the transformer features.
145
+ diversity_loss_weight (`float`, *optional*, defaults to 0.1):
146
+ The weight of the codebook diversity loss component.
147
+ ctc_loss_reduction (`str`, *optional*, defaults to `"sum"`):
148
+ Specifies the reduction to apply to the output of `torch.nn.CTCLoss`. Only relevant when training an
149
+ instance of [`Wav2Vec2ForCTC`].
150
+ ctc_zero_infinity (`bool`, *optional*, defaults to `False`):
151
+ Whether to zero infinite losses and the associated gradients of `torch.nn.CTCLoss`. Infinite losses mainly
152
+ occur when the inputs are too short to be aligned to the targets. Only relevant when training an instance
153
+ of [`Wav2Vec2ForCTC`].
154
+ use_weighted_layer_sum (`bool`, *optional*, defaults to `False`):
155
+ Whether to use a weighted average of layer outputs with learned weights. Only relevant when using an
156
+ instance of [`Wav2Vec2ForSequenceClassification`].
157
+ classifier_proj_size (`int`, *optional*, defaults to 256):
158
+ Dimensionality of the projection before token mean-pooling for classification.
159
+ tdnn_dim (`Tuple[int]`, *optional*, defaults to `(512, 512, 512, 512, 1500)`):
160
+ A tuple of integers defining the number of output channels of each 1D convolutional layer in the *TDNN*
161
+ module of the *XVector* model. The length of *tdnn_dim* defines the number of *TDNN* layers.
162
+ tdnn_kernel (`Tuple[int]`, *optional*, defaults to `(5, 3, 3, 1, 1)`):
163
+ A tuple of integers defining the kernel size of each 1D convolutional layer in the *TDNN* module of the
164
+ *XVector* model. The length of *tdnn_kernel* has to match the length of *tdnn_dim*.
165
+ tdnn_dilation (`Tuple[int]`, *optional*, defaults to `(1, 2, 3, 1, 1)`):
166
+ A tuple of integers defining the dilation factor of each 1D convolutional layer in *TDNN* module of the
167
+ *XVector* model. The length of *tdnn_dilation* has to match the length of *tdnn_dim*.
168
+ xvector_output_dim (`int`, *optional*, defaults to 512):
169
+ Dimensionality of the *XVector* embedding vectors.
170
+ add_adapter (`bool`, *optional*, defaults to `False`):
171
+ Whether a convolutional network should be stacked on top of the Wav2Vec2 Encoder. Can be very useful for
172
+ warm-starting Wav2Vec2 for SpeechEncoderDecoder models.
173
+ adapter_kernel_size (`int`, *optional*, defaults to 3):
174
+ Kernel size of the convolutional layers in the adapter network. Only relevant if `add_adapter is True`.
175
+ adapter_stride (`int`, *optional*, defaults to 2):
176
+ Stride of the convolutional layers in the adapter network. Only relevant if `add_adapter is True`.
177
+ num_adapter_layers (`int`, *optional*, defaults to 3):
178
+ Number of convolutional layers that should be used in the adapter network. Only relevant if `add_adapter is
179
+ True`.
180
+ output_hidden_size (`int`, *optional*):
181
+ Dimensionality of the encoder output layer. If not defined, this defaults to `hidden_size`. Only relevant
182
+ if `add_adapter is True`.
183
+ use_scan (`bool`, *optional*, defaults to `False`):
184
+ Whether or not to use nn.scan in the Flax Wav2Vec2 transformer layers.
185
+
186
+ Example:
187
+
188
+ ```python
189
+ >>> from transformers import Wav2Vec2Model, Wav2Vec2Config
190
+
191
+ >>> # Initializing a Wav2Vec2 facebook/wav2vec2-base-960h style configuration
192
+ >>> configuration = Wav2Vec2Config()
193
+
194
+ >>> # Initializing a model from the facebook/wav2vec2-base-960h style configuration
195
+ >>> model = Wav2Vec2Model(configuration)
196
+
197
+ >>> # Accessing the model configuration
198
+ >>> configuration = model.config
199
+ ```"""
200
+ model_type = "wav2vec2"
201
+
202
+ def __init__(
203
+ self,
204
+ vocab_size=32,
205
+ hidden_size=768,
206
+ num_hidden_layers=12,
207
+ num_attention_heads=12,
208
+ intermediate_size=3072,
209
+ hidden_act="gelu",
210
+ hidden_dropout=0.1,
211
+ activation_dropout=0.1,
212
+ attention_dropout=0.1,
213
+ feat_proj_dropout=0.0,
214
+ feat_quantizer_dropout=0.0,
215
+ final_dropout=0.1,
216
+ layerdrop=0.1,
217
+ initializer_range=0.02,
218
+ layer_norm_eps=1e-5,
219
+ feat_extract_norm="group",
220
+ feat_extract_activation="gelu",
221
+ conv_dim=(512, 512, 512, 512, 512, 512, 512),
222
+ conv_stride=(5, 2, 2, 2, 2, 2, 2),
223
+ conv_kernel=(10, 3, 3, 3, 3, 2, 2),
224
+ conv_bias=False,
225
+ num_conv_pos_embeddings=128,
226
+ num_conv_pos_embedding_groups=16,
227
+ do_stable_layer_norm=False,
228
+ apply_spec_augment=True,
229
+ mask_time_prob=0.05,
230
+ mask_time_length=10,
231
+ mask_time_min_masks=2,
232
+ mask_feature_prob=0.0,
233
+ mask_feature_length=10,
234
+ mask_feature_min_masks=0,
235
+ num_codevectors_per_group=320,
236
+ num_codevector_groups=2,
237
+ contrastive_logits_temperature=0.1,
238
+ num_negatives=100,
239
+ codevector_dim=256,
240
+ proj_codevector_dim=256,
241
+ diversity_loss_weight=0.1,
242
+ ctc_loss_reduction="sum",
243
+ ctc_zero_infinity=False,
244
+ use_weighted_layer_sum=False,
245
+ classifier_proj_size=256,
246
+ tdnn_dim=(512, 512, 512, 512, 1500),
247
+ tdnn_kernel=(5, 3, 3, 1, 1),
248
+ tdnn_dilation=(1, 2, 3, 1, 1),
249
+ xvector_output_dim=512,
250
+ pad_token_id=0,
251
+ bos_token_id=1,
252
+ eos_token_id=2,
253
+ add_adapter=False,
254
+ adapter_kernel_size=3,
255
+ adapter_stride=2,
256
+ num_adapter_layers=3,
257
+ output_hidden_size=None,
258
+ use_scan=False,
259
+ fuse_matmuls=False,
260
+ **kwargs
261
+ ):
262
+ super().__init__(**kwargs, pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id)
263
+ self.hidden_size = hidden_size
264
+ self.feat_extract_norm = feat_extract_norm
265
+ self.feat_extract_activation = feat_extract_activation
266
+ self.conv_dim = list(conv_dim)
267
+ self.conv_stride = list(conv_stride)
268
+ self.conv_kernel = list(conv_kernel)
269
+ self.conv_bias = conv_bias
270
+ self.num_conv_pos_embeddings = num_conv_pos_embeddings
271
+ self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups
272
+ self.num_feat_extract_layers = len(self.conv_dim)
273
+ self.num_hidden_layers = num_hidden_layers
274
+ self.intermediate_size = intermediate_size
275
+ self.hidden_act = hidden_act
276
+ self.num_attention_heads = num_attention_heads
277
+ self.hidden_dropout = hidden_dropout
278
+ self.attention_dropout = attention_dropout
279
+ self.activation_dropout = activation_dropout
280
+ self.feat_proj_dropout = feat_proj_dropout
281
+ self.final_dropout = final_dropout
282
+ self.layerdrop = layerdrop
283
+ self.layer_norm_eps = layer_norm_eps
284
+ self.initializer_range = initializer_range
285
+ self.vocab_size = vocab_size
286
+ self.do_stable_layer_norm = do_stable_layer_norm
287
+ self.use_weighted_layer_sum = use_weighted_layer_sum
288
+ self.use_scan = use_scan
289
+ self.fuse_matmuls = fuse_matmuls
290
+
291
+ if (
292
+ (len(self.conv_stride) != self.num_feat_extract_layers)
293
+ or (len(self.conv_kernel) != self.num_feat_extract_layers)
294
+ or (len(self.conv_dim) != self.num_feat_extract_layers)
295
+ ):
296
+ raise ValueError(
297
+ "Configuration for convolutional layers is incorrect. "
298
+ "It is required that `len(config.conv_dim)` == `len(config.conv_stride)` == `len(config.conv_kernel)`, "
299
+ f"but is `len(config.conv_dim) = {len(self.conv_dim)}`, `len(config.conv_stride) "
300
+ f"= {len(self.conv_stride)}`, `len(config.conv_kernel) = {len(self.conv_kernel)}`."
301
+ )
302
+
303
+ # fine-tuning config parameters for SpecAugment: https://arxiv.org/abs/1904.08779
304
+ self.apply_spec_augment = apply_spec_augment
305
+ self.mask_time_prob = mask_time_prob
306
+ self.mask_time_length = mask_time_length
307
+ self.mask_time_min_masks = mask_time_min_masks
308
+ self.mask_feature_prob = mask_feature_prob
309
+ self.mask_feature_length = mask_feature_length
310
+ self.mask_feature_min_masks = mask_feature_min_masks
311
+
312
+ # parameters for pretraining with codevector quantized representations
313
+ self.num_codevectors_per_group = num_codevectors_per_group
314
+ self.num_codevector_groups = num_codevector_groups
315
+ self.contrastive_logits_temperature = contrastive_logits_temperature
316
+ self.feat_quantizer_dropout = feat_quantizer_dropout
317
+ self.num_negatives = num_negatives
318
+ self.codevector_dim = codevector_dim
319
+ self.proj_codevector_dim = proj_codevector_dim
320
+ self.diversity_loss_weight = diversity_loss_weight
321
+
322
+ # ctc loss
323
+ self.ctc_loss_reduction = ctc_loss_reduction
324
+ self.ctc_zero_infinity = ctc_zero_infinity
325
+
326
+ # adapter
327
+ self.add_adapter = add_adapter
328
+ self.adapter_kernel_size = adapter_kernel_size
329
+ self.adapter_stride = adapter_stride
330
+ self.num_adapter_layers = num_adapter_layers
331
+ self.output_hidden_size = output_hidden_size or hidden_size
332
+
333
+ # SequenceClassification-specific parameter. Feel free to ignore for other classes.
334
+ self.classifier_proj_size = classifier_proj_size
335
+
336
+ # XVector-specific parameters. Feel free to ignore for other classes.
337
+ self.tdnn_dim = list(tdnn_dim)
338
+ self.tdnn_kernel = list(tdnn_kernel)
339
+ self.tdnn_dilation = list(tdnn_dilation)
340
+ self.xvector_output_dim = xvector_output_dim
341
+
342
+ @property
343
+ def inputs_to_logits_ratio(self):
344
+ return functools.reduce(operator.mul, self.conv_stride, 1)
models/modeling_flax_bart.py ADDED
@@ -0,0 +1,816 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2021 The Fairseq Authors and The Google Flax Team Authors And The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Flax Bart model."""
16
+
17
+ import math
18
+ import random
19
+ from functools import partial
20
+ from typing import Optional, Tuple
21
+
22
+ import numpy as np
23
+
24
+ import flax.linen as nn
25
+ import jax
26
+ import jax.numpy as jnp
27
+ from flax.core.frozen_dict import FrozenDict, unfreeze
28
+ from flax.linen import combine_masks, make_causal_mask
29
+ from flax.linen import partitioning as nn_partitioning
30
+ from flax.linen.attention import dot_product_attention_weights
31
+ from jax import lax
32
+ from jax.random import PRNGKey
33
+
34
+ from transformers.modeling_flax_outputs import (
35
+ FlaxBaseModelOutputWithPastAndCrossAttentions,
36
+ FlaxCausalLMOutputWithCrossAttentions,
37
+ )
38
+ from transformers.modeling_flax_utils import ACT2FN, FlaxPreTrainedModel
39
+
40
+ from models import BartConfig
41
+
42
+
43
+ scan_with_axes = nn_partitioning.scan_with_axes
44
+ remat = nn_partitioning.remat
45
+
46
+
47
+ def shift_tokens_right(input_ids: np.array, pad_token_id: int, decoder_start_token_id: int) -> np.ndarray:
48
+ """
49
+ Shift input ids one token to the right.
50
+ """
51
+ shifted_input_ids = np.zeros_like(input_ids)
52
+ shifted_input_ids[:, 1:] = input_ids[:, :-1]
53
+ shifted_input_ids[:, 0] = decoder_start_token_id
54
+
55
+ shifted_input_ids = np.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
56
+ return shifted_input_ids
57
+
58
+
59
+ class FlaxBartAttention(nn.Module):
60
+ config: BartConfig
61
+ embed_dim: int
62
+ num_heads: int
63
+ dropout: float = 0.0
64
+ causal: bool = False
65
+ bias: bool = True
66
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
67
+
68
+ def setup(self) -> None:
69
+ self.head_dim = self.embed_dim // self.num_heads
70
+ if self.head_dim * self.num_heads != self.embed_dim:
71
+ raise ValueError(
72
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
73
+ f" and `num_heads`: {self.num_heads})."
74
+ )
75
+
76
+ dense = partial(
77
+ nn.Dense,
78
+ self.embed_dim,
79
+ use_bias=self.bias,
80
+ dtype=self.dtype,
81
+ kernel_init=jax.nn.initializers.normal(self.config.init_std),
82
+ )
83
+
84
+ self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
85
+
86
+ self.fused_proj = nn.Dense(
87
+ self.embed_dim * 3,
88
+ use_bias=self.bias,
89
+ dtype=self.dtype,
90
+ kernel_init=jax.nn.initializers.normal(self.config.init_std),
91
+ )
92
+
93
+ self.fused_key_value = nn.Dense(
94
+ self.embed_dim * 2,
95
+ use_bias=self.bias,
96
+ dtype=self.dtype,
97
+ kernel_init=jax.nn.initializers.normal(self.config.init_std),
98
+ )
99
+
100
+ self.out_proj = dense()
101
+
102
+ self.dropout_layer = nn.Dropout(rate=self.dropout)
103
+
104
+ if self.causal:
105
+ self.causal_mask = make_causal_mask(
106
+ jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool"
107
+ )
108
+
109
+ def _split_heads(self, hidden_states):
110
+ return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))
111
+
112
+ def _merge_heads(self, hidden_states):
113
+ return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))
114
+
115
+ @nn.compact
116
+ def _concatenate_to_cache(self, key, value, query, attention_mask):
117
+ """
118
+ This function takes projected key, value states from a single input token and concatenates the states to cached
119
+ states from previous steps. This function is slightly adapted from the official Flax repository:
120
+ https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
121
+ """
122
+ # detect if we're initializing by absence of existing cache data.
123
+ is_initialized = self.has_variable("cache", "cached_key")
124
+ cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
125
+ cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
126
+ cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
127
+
128
+ if is_initialized:
129
+ *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
130
+ # update key, value caches with our new 1d spatial slices
131
+ cur_index = cache_index.value
132
+ indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
133
+ key = lax.dynamic_update_slice(cached_key.value, key, indices)
134
+ value = lax.dynamic_update_slice(cached_value.value, value, indices)
135
+ cached_key.value = key
136
+ cached_value.value = value
137
+ num_updated_cache_vectors = query.shape[1]
138
+ cache_index.value = cache_index.value + num_updated_cache_vectors
139
+ # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements.
140
+ pad_mask = jnp.broadcast_to(
141
+ jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
142
+ tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
143
+ )
144
+ attention_mask = combine_masks(pad_mask, attention_mask)
145
+ return key, value, attention_mask
146
+
147
+ def __call__(
148
+ self,
149
+ hidden_states: jnp.ndarray,
150
+ key_value_states: Optional[jnp.ndarray] = None,
151
+ attention_mask: Optional[jnp.ndarray] = None,
152
+ init_cache: bool = False,
153
+ deterministic: bool = True,
154
+ ) -> Tuple[jnp.ndarray]:
155
+ """Input shape: Batch x Time x Channel"""
156
+
157
+ # if key_value_states are provided this layer is used as a cross-attention layer
158
+ # for the decoder
159
+ is_cross_attention = key_value_states is not None
160
+ batch_size = hidden_states.shape[0]
161
+
162
+ if self.config.fuse_matmuls:
163
+ # get key, value proj
164
+ if is_cross_attention:
165
+ # get query proj
166
+ query_states = self.q_proj(hidden_states)
167
+ # cross_attentions
168
+ attention_states = self.fused_key_value(key_value_states)
169
+ key_states, value_states = jnp.split(attention_states, 2, axis=-1)
170
+ else:
171
+ attention_states = self.fused_proj(hidden_states)
172
+ query_states, key_states, value_states = jnp.split(attention_states, 3, axis=-1)
173
+
174
+ else:
175
+ # get query proj
176
+ query_states = self.q_proj(hidden_states)
177
+ # get key, value proj
178
+ if is_cross_attention:
179
+ # cross_attentions
180
+ key_states = self.k_proj(key_value_states)
181
+ value_states = self.v_proj(key_value_states)
182
+ else:
183
+ # self_attention
184
+ key_states = self.k_proj(hidden_states)
185
+ value_states = self.v_proj(hidden_states)
186
+
187
+ query_states = self._split_heads(query_states)
188
+ key_states = self._split_heads(key_states)
189
+ value_states = self._split_heads(value_states)
190
+
191
+ # handle cache prepare causal attention mask
192
+ if self.causal:
193
+ query_length, key_length = query_states.shape[1], key_states.shape[1]
194
+ if self.has_variable("cache", "cached_key"):
195
+ mask_shift = self.variables["cache"]["cache_index"]
196
+ max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
197
+ causal_mask = lax.dynamic_slice(
198
+ self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
199
+ )
200
+ else:
201
+ causal_mask = self.causal_mask[:, :, :query_length, :key_length]
202
+ causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
203
+
204
+ # combine masks if needed
205
+ if attention_mask is not None and self.causal:
206
+ attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
207
+ attention_mask = combine_masks(attention_mask, causal_mask)
208
+ elif self.causal:
209
+ attention_mask = causal_mask
210
+ elif attention_mask is not None:
211
+ attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
212
+
213
+ # During fast autoregressive decoding, we feed one position at a time,
214
+ # and cache the keys and values step by step.
215
+ if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
216
+ key_states, value_states, attention_mask = self._concatenate_to_cache(
217
+ key_states, value_states, query_states, attention_mask
218
+ )
219
+
220
+ # Convert the boolean attention mask to an attention bias.
221
+ if attention_mask is not None:
222
+ # attention mask in the form of attention bias
223
+ attention_bias = lax.select(
224
+ attention_mask > 0,
225
+ jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
226
+ jnp.full(attention_mask.shape, float("-inf")).astype(self.dtype),
227
+ )
228
+ else:
229
+ attention_bias = None
230
+
231
+ dropout_rng = None
232
+ if not deterministic and self.dropout > 0.0:
233
+ dropout_rng = self.make_rng("dropout")
234
+
235
+ attn_weights = dot_product_attention_weights(
236
+ query_states,
237
+ key_states,
238
+ bias=attention_bias,
239
+ dropout_rng=dropout_rng,
240
+ dropout_rate=self.dropout,
241
+ broadcast_dropout=True,
242
+ deterministic=deterministic,
243
+ dtype=self.dtype,
244
+ precision=None,
245
+ )
246
+
247
+ attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
248
+ attn_output = self._merge_heads(attn_output)
249
+ attn_output = self.out_proj(attn_output)
250
+
251
+ return attn_output, attn_weights
252
+
253
+
254
+ class FlaxBartDecoderLayer(nn.Module):
255
+ config: BartConfig
256
+ dtype: jnp.dtype = jnp.float32
257
+
258
+ def setup(self) -> None:
259
+ self.embed_dim = self.config.d_model
260
+ self.self_attn = FlaxBartAttention(
261
+ config=self.config,
262
+ embed_dim=self.embed_dim,
263
+ num_heads=self.config.decoder_attention_heads,
264
+ dropout=self.config.attention_dropout,
265
+ causal=True,
266
+ dtype=self.dtype,
267
+ )
268
+ self.dropout_layer = nn.Dropout(rate=self.config.dropout)
269
+ self.activation_fn = ACT2FN[self.config.activation_function]
270
+ self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
271
+
272
+ self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
273
+ self.encoder_attn = FlaxBartAttention(
274
+ config=self.config,
275
+ embed_dim=self.embed_dim,
276
+ num_heads=self.config.decoder_attention_heads,
277
+ dropout=self.config.attention_dropout,
278
+ dtype=self.dtype,
279
+ )
280
+ self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
281
+ self.fc1 = nn.Dense(
282
+ self.config.encoder_ffn_dim,
283
+ dtype=self.dtype,
284
+ kernel_init=jax.nn.initializers.normal(self.config.init_std),
285
+ )
286
+ self.fc2 = nn.Dense(
287
+ self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
288
+ )
289
+ self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
290
+
291
+ def __call__(
292
+ self,
293
+ hidden_states: jnp.ndarray,
294
+ attention_mask: jnp.ndarray,
295
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
296
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
297
+ init_cache: bool = False,
298
+ output_attentions: bool = True,
299
+ deterministic: bool = True,
300
+ ) -> Tuple[jnp.ndarray]:
301
+
302
+ if self.config.use_scan:
303
+ hidden_states = hidden_states[0]
304
+
305
+ residual = hidden_states
306
+
307
+ # Self Attention
308
+ hidden_states, self_attn_weights = self.self_attn(
309
+ hidden_states=hidden_states, attention_mask=attention_mask, init_cache=init_cache
310
+ )
311
+ hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
312
+ hidden_states = residual + hidden_states
313
+ hidden_states = self.self_attn_layer_norm(hidden_states)
314
+
315
+ # Cross-Attention Block
316
+ cross_attn_weights = None
317
+ if encoder_hidden_states is not None:
318
+ residual = hidden_states
319
+
320
+ hidden_states, cross_attn_weights = self.encoder_attn(
321
+ hidden_states=hidden_states,
322
+ key_value_states=encoder_hidden_states,
323
+ attention_mask=encoder_attention_mask,
324
+ )
325
+ hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
326
+ hidden_states = residual + hidden_states
327
+ hidden_states = self.encoder_attn_layer_norm(hidden_states)
328
+
329
+ # Fully Connected
330
+ residual = hidden_states
331
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
332
+ hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic)
333
+ hidden_states = self.fc2(hidden_states)
334
+ hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
335
+ hidden_states = residual + hidden_states
336
+ hidden_states = self.final_layer_norm(hidden_states)
337
+
338
+ outputs = (hidden_states,)
339
+
340
+ if output_attentions:
341
+ outputs += (self_attn_weights, cross_attn_weights)
342
+
343
+ if self.config.use_scan:
344
+ outputs = (outputs, None)
345
+
346
+ return outputs
347
+
348
+
349
+ class FlaxBartDecoderLayerCollection(nn.Module):
350
+ config: BartConfig
351
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
352
+
353
+ @nn.compact
354
+ def __call__(
355
+ self,
356
+ hidden_states,
357
+ attention_mask,
358
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
359
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
360
+ deterministic: bool = True,
361
+ init_cache: bool = False,
362
+ output_attentions: bool = False,
363
+ output_hidden_states: bool = False,
364
+ return_dict: bool = True,
365
+ ):
366
+ # decoder layers
367
+ all_hidden_states = () if output_hidden_states else None
368
+ all_self_attns = () if output_attentions else None
369
+ all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
370
+
371
+ num_decoder_layers = self.config.decoder_layers
372
+ BlockDecoderLayer = (
373
+ remat(
374
+ FlaxBartDecoderLayer,
375
+ static_argnums=(4, 5, 6),
376
+ prevent_cse=not self.config.use_scan,
377
+ )
378
+ if self.config.gradient_checkpointing
379
+ else FlaxBartDecoderLayer
380
+ )
381
+
382
+ if self.config.use_scan:
383
+ # since all decoder layers are the same, we use nn.scan directly
384
+ assert not output_attentions, "cannot use `scan` with `output_attentions` set to `True`"
385
+ assert not output_hidden_states, "cannot use `scan` with `output_hidden_states` set to `True`"
386
+ hidden_states = (hidden_states,)
387
+
388
+ # TODO: add layerdrop in checkpointed scan (note: default value for layerdrop in config is zero)
389
+ hidden_states, _ = scan_with_axes(
390
+ BlockDecoderLayer,
391
+ variable_axes={"params": 0, "cache": 0},
392
+ split_rngs={"params": True, "dropout": True},
393
+ in_axes=(nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast),
394
+ length=num_decoder_layers,
395
+ )(self.config, dtype=self.dtype, name="FlaxBartDecoderLayers")(
396
+ hidden_states,
397
+ attention_mask,
398
+ encoder_hidden_states,
399
+ encoder_attention_mask,
400
+ init_cache,
401
+ output_attentions,
402
+ deterministic,
403
+ )
404
+ hidden_states = hidden_states[0]
405
+
406
+ else:
407
+ for layer in range(num_decoder_layers):
408
+ if output_hidden_states:
409
+ all_hidden_states += (hidden_states,)
410
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
411
+ dropout_probability = random.uniform(0, 1)
412
+ if not deterministic and (dropout_probability < self.config.decoder_layerdrop):
413
+ layer_outputs = (None, None, None)
414
+ else:
415
+ layer_outputs = BlockDecoderLayer(self.config, dtype=self.dtype, name=str(layer),)(
416
+ hidden_states,
417
+ attention_mask,
418
+ encoder_hidden_states,
419
+ encoder_attention_mask,
420
+ init_cache,
421
+ output_attentions,
422
+ deterministic,
423
+ )
424
+
425
+ hidden_states = layer_outputs[0]
426
+ if output_attentions:
427
+ all_self_attns += (layer_outputs[1],)
428
+
429
+ if encoder_hidden_states is not None:
430
+ all_cross_attentions += (layer_outputs[2],)
431
+
432
+ # add hidden states from the last decoder layer
433
+ if output_hidden_states:
434
+ all_hidden_states += (hidden_states,)
435
+
436
+ outputs = [hidden_states, all_hidden_states, all_self_attns, all_cross_attentions]
437
+
438
+ if not return_dict:
439
+ return tuple(v for v in outputs if v is not None)
440
+
441
+ return FlaxBaseModelOutputWithPastAndCrossAttentions(
442
+ last_hidden_state=hidden_states,
443
+ hidden_states=all_hidden_states,
444
+ attentions=all_self_attns,
445
+ cross_attentions=all_cross_attentions,
446
+ )
447
+
448
+
449
class FlaxBartDecoder(nn.Module):
    """
    Bart decoder stack: scaled token embeddings plus learned position embeddings, followed by
    layer norm, dropout and the decoder layer collection.
    """

    config: BartConfig
    embed_tokens: nn.Embed
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        cfg = self.config

        self.dropout_layer = nn.Dropout(rate=cfg.dropout)

        self.padding_idx = cfg.pad_token_id
        self.max_target_positions = cfg.max_position_embeddings
        if cfg.scale_embedding:
            self.embed_scale = math.sqrt(cfg.d_model)
        else:
            self.embed_scale = 1.0

        # Bart shifts position ids by a fixed offset of 2 and enlarges the position
        # embedding table by the same amount. Other models don't have this hack.
        self.offset = 2
        self.embed_positions = nn.Embed(
            cfg.max_position_embeddings + self.offset,
            cfg.d_model,
            embedding_init=jax.nn.initializers.normal(cfg.init_std),
        )

        self.layers = FlaxBartDecoderLayerCollection(cfg, self.dtype)
        self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)

    def __call__(
        self,
        input_ids,
        attention_mask,
        position_ids,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
    ):
        """
        Embed `input_ids`, add position embeddings and run the decoder layers.

        Returns whatever the layer collection returns when `return_dict` is False, otherwise a
        `FlaxBaseModelOutputWithPastAndCrossAttentions` built from the collection's output.
        """
        # collapse any leading dimensions so the ids are (rows, sequence_length)
        flat_ids = input_ids.reshape(-1, input_ids.shape[-1])

        token_embeddings = self.embed_tokens(flat_ids) * self.embed_scale
        position_embeddings = self.embed_positions(position_ids + self.offset)

        hidden_states = self.layernorm_embedding(token_embeddings + position_embeddings)
        hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)

        layer_outputs = self.layers(
            hidden_states,
            attention_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            deterministic=deterministic,
            init_cache=init_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if return_dict:
            return FlaxBaseModelOutputWithPastAndCrossAttentions(
                last_hidden_state=layer_outputs.last_hidden_state,
                hidden_states=layer_outputs.hidden_states,
                attentions=layer_outputs.attentions,
                cross_attentions=layer_outputs.cross_attentions,
            )
        return layer_outputs
521
+
522
+
523
class FlaxBartDecoderPreTrainedModel(FlaxPreTrainedModel):
    """
    Base wrapper that runs a Bart decoder Flax module (`module_class`) as a standalone decoder.

    Subclasses set `module_class`. The constructor forces `config.is_decoder = True` and
    `config.is_encoder_decoder = False` on the config object it is given.
    """

    config_class = BartConfig
    base_model_prefix: str = "model"
    module_class: nn.Module = None

    def __init__(
        self,
        config: BartConfig,
        input_shape: Tuple[int] = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        **kwargs
    ):
        """
        Build the Flax module from `module_class` and hand it to `FlaxPreTrainedModel`.

        Note that `config` is modified in place (decoder-only flags are set on it). Extra
        `kwargs` are forwarded to the `module_class` constructor.
        """
        config.is_decoder = True
        config.is_encoder_decoder = False
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
        """
        Initialize the module with dummy inputs of shape `input_shape` and return its `"params"`
        collection.

        The dummy encoder hidden states have shape `input_shape + (config.d_model,)` and reuse
        the decoder attention mask as the encoder attention mask.
        """
        # init input tensors
        input_ids = jnp.zeros(input_shape, dtype="i4")
        attention_mask = jnp.ones_like(input_ids)

        batch_size, sequence_length = input_ids.shape
        position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))

        # `rng` is split so parameter initialization and dropout use separate keys
        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}
        encoder_hidden_states = jnp.zeros(input_shape + (self.config.d_model,))
        encoder_attention_mask = attention_mask
        module_init_outputs = self.module.init(
            rngs,
            input_ids,
            attention_mask,
            position_ids,
            encoder_hidden_states,
            encoder_attention_mask,
            return_dict=False,
        )
        return module_init_outputs["params"]

    def init_cache(self, batch_size, max_length):
        r"""
        Create the `"cache"` collection used for fast auto-regressive decoding.

        Args:
            batch_size (`int`):
                batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
            max_length (`int`):
                maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
                cache.

        Returns:
            The unfrozen `"cache"` collection produced by initializing the module with `init_cache=True`.
        """
        # init input variables to retrieve cache
        input_ids = jnp.ones((batch_size, max_length), dtype="i4")
        attention_mask = jnp.ones_like(input_ids, dtype="i4")
        position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)

        init_variables = self.module.init(
            jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
        )
        return unfreeze(init_variables["cache"])

    def __call__(
        self,
        input_ids: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        position_ids: Optional[jnp.ndarray] = None,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        past_key_values: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        """
        Run the decoder module.

        Args:
            input_ids (`jnp.ndarray` of shape `(target_batch_size, target_sequence_length)`):
                Indices of decoder input sequence tokens in the vocabulary.

                Indices can be obtained using [`BartTokenizer`]. See [`PreTrainedTokenizer.encode`] and
                [`PreTrainedTokenizer.__call__`] for details.

                [What are decoder input IDs?](../glossary#decoder-input-ids)
            attention_mask (`jnp.ndarray` of shape `(target_batch_size, target_sequence_length)`, *optional*):
                Mask over `input_ids`. Defaults to all ones (nothing masked) when not given.
            position_ids (`numpy.ndarray` of shape `(target_batch_size, sequence_length)`, *optional*):
                Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
                range `[0, config.max_position_embeddings - 1]`. Defaults to `0 .. sequence_length - 1` for every row.
            encoder_hidden_states (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`):
                A sequence of hidden-states at the output of the last layer of the encoder. Used in the
                cross-attention of the decoder.
            encoder_attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                Defaults to all ones when `encoder_hidden_states` is given without a mask.

                [What are attention masks?](../glossary#attention-mask)
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. Defaults to
                `config.output_attentions`.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. Defaults to
                `config.output_hidden_states`.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. Defaults to
                `config.return_dict`.
            train (`bool`, *optional*, defaults to `False`):
                The module is called with `deterministic=not train`.
            params (`dict`, *optional*):
                Parameters to use instead of `self.params`.
            past_key_values (`Dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`):
                Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast
                auto-regressive decoding. `None` or an empty dictionary means no cache is used.
            dropout_rng (`PRNGKey`, *optional*):
                Key passed to the module as the `"dropout"` rng when given.

        Returns:
            The module output. When a cache was supplied, the updated cache is added under
            `past_key_values` (dict output) or inserted at index 1 (tuple output).
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        if encoder_hidden_states is not None and encoder_attention_mask is None:
            batch_size, sequence_length = encoder_hidden_states.shape[:2]
            encoder_attention_mask = jnp.ones((batch_size, sequence_length))

        # prepare decoder inputs
        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)
        if position_ids is None:
            batch_size, sequence_length = input_ids.shape
            position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))

        # Handle any PRNG if needed
        rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}

        inputs = {"params": params or self.params}

        # If past_key_values are passed then the cache is already initialized and has to be marked as
        # mutable so that the attention modules can update it. The same flag decides below whether
        # `apply` returned an `(outputs, mutated_variables)` pair, so an empty dict can no longer take
        # the "no cache" path here and the "cache" path when unpacking.
        use_cache = bool(past_key_values)
        if use_cache:
            inputs["cache"] = past_key_values
            mutable = ["cache"]
        else:
            mutable = False

        outputs = self.module.apply(
            inputs,
            input_ids=jnp.array(input_ids, dtype="i4"),
            attention_mask=jnp.array(attention_mask, dtype="i4"),
            position_ids=jnp.array(position_ids, dtype="i4"),
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,
            mutable=mutable,
        )

        # add updated cache to model output
        if use_cache:
            outputs, mutated_variables = outputs
            updated_cache = unfreeze(mutated_variables["cache"])
            if return_dict:
                outputs["past_key_values"] = updated_cache
            else:
                outputs = outputs[:1] + (updated_cache,) + outputs[1:]

        return outputs
698
+
699
+
700
class FlaxBartDecoderWrapper(nn.Module):
    """
    Thin wrapper that owns the token embedding table and a `FlaxBartDecoder`, so pretrained
    checkpoints load correctly when the causal language model is used in combination with the
    [`EncoderDecoderModel`] framework.
    """

    config: BartConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # The embedding table is created here and handed to the decoder as its `embed_tokens` field.
        token_embedding = nn.Embed(
            num_embeddings=self.config.vocab_size,
            features=self.config.d_model,
            embedding_init=jax.nn.initializers.normal(self.config.init_std),
        )
        self.decoder = FlaxBartDecoder(config=self.config, embed_tokens=token_embedding, dtype=self.dtype)

    def __call__(self, *args, **kwargs):
        # Pure pass-through to the wrapped decoder.
        return self.decoder(*args, **kwargs)
720
+
721
+
722
class FlaxBartForCausalLMModule(nn.Module):
    """
    Bart decoder with a bias-free linear language-modeling head on top, e.g. for autoregressive
    tasks. When `config.tie_word_embeddings` is set, the head uses the transposed input embedding
    matrix as its kernel.
    """

    config: BartConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.model = FlaxBartDecoderWrapper(config=self.config, dtype=self.dtype)
        head_init = jax.nn.initializers.normal(self.config.init_std)
        self.lm_head = nn.Dense(
            self.config.vocab_size,
            use_bias=False,
            dtype=self.dtype,
            kernel_init=head_init,
        )

    def __call__(
        self,
        input_ids,
        attention_mask,
        position_ids,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
    ):
        """
        Run the decoder and project its first output to vocabulary logits.

        Returns `(lm_logits,) + decoder_outputs[1:]` when `return_dict` is False, otherwise a
        `FlaxCausalLMOutputWithCrossAttentions`.
        """
        decoder_outputs = self.model(
            input_ids,
            attention_mask,
            position_ids,
            encoder_hidden_states,
            encoder_attention_mask,
            deterministic=deterministic,
            init_cache=init_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = decoder_outputs[0]

        if not self.config.tie_word_embeddings:
            lm_logits = self.lm_head(hidden_states)
        else:
            # reuse the decoder's token embedding matrix (transposed) as the head kernel
            decoder_params = self.model.variables["params"]["decoder"]
            shared_embedding = decoder_params["embed_tokens"]["embedding"]
            lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)

        if return_dict:
            return FlaxCausalLMOutputWithCrossAttentions(
                logits=lm_logits,
                hidden_states=decoder_outputs.hidden_states,
                attentions=decoder_outputs.attentions,
                cross_attentions=decoder_outputs.cross_attentions,
            )

        return (lm_logits,) + decoder_outputs[1:]
783
+
784
+
785
class FlaxBartForCausalLM(FlaxBartDecoderPreTrainedModel):
    """Bart Decoder Model with a language modeling head on top (linear layer with weights tied to the input embeddings)
    e.g. for autoregressive tasks.
    """

    module_class = FlaxBartForCausalLMModule

    # `jnp.ndarray` is used for the annotation (as in the rest of this file): `jnp.DeviceArray`
    # is evaluated when the class body runs and no longer exists in current JAX releases.
    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.ndarray] = None):
        """
        Build the keyword arguments for the first generation step.

        Args:
            input_ids: prompt token ids of shape `(batch_size, seq_length)`.
            max_length: length used for the cache and for the static attention mask.
            attention_mask: optional mask over `input_ids`; when given, position ids are derived
                from it as `cumsum - 1` and it is written into the start of the static mask.

        Returns:
            Dict with `past_key_values`, `attention_mask` and `position_ids`.
        """
        # initializing the cache
        batch_size, seq_length = input_ids.shape

        past_key_values = self.init_cache(batch_size, max_length)
        # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
        # But since the decoder uses a causal mask, those positions are masked anyway.
        # Thus, we can create a single static attention_mask here, which is more efficient for compilation
        extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
        if attention_mask is not None:
            position_ids = attention_mask.cumsum(axis=-1) - 1
            extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
        else:
            position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))

        return {
            "past_key_values": past_key_values,
            "attention_mask": extended_attention_mask,
            "position_ids": position_ids,
        }

    def update_inputs_for_generation(self, model_outputs, model_kwargs):
        """
        Carry the updated cache into `model_kwargs` and advance `position_ids` to the next step
        (last position of each row plus one). `model_kwargs` is modified in place and returned.
        """
        model_kwargs["past_key_values"] = model_outputs.past_key_values
        model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
        return model_kwargs
models/modeling_flax_speech_encoder_decoder.py ADDED
@@ -0,0 +1,1245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Classes to support Flax Speech-Encoder-Decoder architectures"""
16
+
17
+ import os
18
+ from functools import partial
19
+ from typing import Optional, Tuple, Union, Dict
20
+
21
+ import flax
22
+ import flax.linen as nn
23
+ import jax
24
+ import jax.numpy as jnp
25
+ from flax.core.frozen_dict import FrozenDict, unfreeze
26
+ from jax import lax
27
+ from jax.random import PRNGKey
28
+ import numpy as np
29
+
30
+ from transformers.modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutputWithCrossAttentions, FlaxSeq2SeqLMOutput
31
+ from transformers.modeling_flax_utils import FlaxPreTrainedModel
32
+ from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings, ModelOutput
33
+ from transformers.generation_flax_utils import FlaxLogitsProcessorList
34
+ from models import (
35
+ FlaxWav2Vec2Model,
36
+ FlaxWav2Vec2Module,
37
+ FlaxBartForCausalLM,
38
+ FlaxBartForCausalLMModule,
39
+ BartConfig,
40
+ Wav2Vec2Config,
41
+ SpeechEncoderDecoderConfig,
42
+ )
43
+
44
+ logger = logging.get_logger(__name__)
45
+
46
+ _CONFIG_FOR_DOC = "SpeechEncoderDecoderConfig"
47
+
48
+ SPEECH_ENCODER_DECODER_START_DOCSTRING = r"""
49
+ This class can be used to initialize a speech-sequence-to-text-sequence model with any pretrained speech
50
+ autoencoding model as the encoder and any pretrained text autoregressive model as the decoder. The encoder is
51
+ loaded via [`~AutoModel.from_pretrained`] function and the decoder is loaded via
52
+ [`~AutoModelForCausalLM.from_pretrained`] function. Cross-attention layers are automatically added to the decoder
53
+ and should be fine-tuned on a downstream generative task, like summarization.
54
+
55
+ The effectiveness of initializing sequence-to-sequence models with pretrained checkpoints for sequence generation
56
+ tasks was shown in [Leveraging Pre-trained Checkpoints for Sequence Generation
57
+ Tasks](https://arxiv.org/abs/1907.12461) by Sascha Rothe, Shashi Narayan, Aliaksei Severyn. Michael Matena, Yanqi
58
+ Zhou, Wei Li, Peter J. Liu.
59
+
60
+ Additionally, in [Large-Scale Self- and Semi-Supervised Learning for Speech
61
+ Translation](https://arxiv.org/abs/2104.06678) it is shown how leveraging large pretrained speech models for speech
62
+ translation yields a significant performance improvement.
63
+
64
+ After such a Speech-Encoder-Decoder model has been trained/fine-tuned, it can be saved/loaded just like any other
65
+ models (see the examples for more information).
66
+
67
+ This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
68
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
69
+ etc.)
70
+
71
+ This model is also a Flax Linen
72
+ [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
73
+ regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior.
74
+
75
+ Parameters:
76
+ config ([`SpeechEncoderDecoderConfig`]): Model configuration class with all the parameters of the model.
77
+ Initializing with a config file does not load the weights associated with the model, only the
78
+ configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
79
+ dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
80
+ The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
81
+ `jax.numpy.bfloat16` (on TPUs).
82
+
83
+ This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
84
+ specified all the computation will be performed with the given `dtype`.
85
+
86
+ **Note that this only specifies the dtype of the computation and does not influence the dtype of model
87
+ parameters.**
88
+
89
+ If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
90
+ [`~FlaxPreTrainedModel.to_bf16`].
91
+ """
92
+
93
+ SPEECH_ENCODER_DECODER_INPUTS_DOCSTRING = r"""
94
+ Args:
95
+ inputs (`jnp.ndarray` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, feature_dim)`, *optional*):
96
+ Float values of input raw speech waveform or speech features. Values can be obtained by loading a *.flac*
97
+ or *.wav* audio file into an array of type *List[float]* or a *numpy.ndarray*, *e.g.* via the soundfile
98
+ library (*pip install soundfile*). To prepare the array into *inputs*, either the [`Wav2Vec2Processor`] or
99
+ [`Speech2TextProcessor`] should be used for padding and conversion into a tensor of type
100
+ *torch.FloatTensor*.
101
+ attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
102
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
103
+
104
+ - 1 for tokens that are **not masked**,
105
+ - 0 for tokens that are **masked**.
106
+
107
+ [What are attention masks?](../glossary#attention-mask)
108
+ decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
109
+ Indices of decoder input sequence tokens in the vocabulary.
110
+
111
+ Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and
112
+ [`PreTrainedTokenizer.__call__`] for details.
113
+
114
+ [What are input IDs?](../glossary#input-ids)
115
+
116
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
117
+ `past_key_values`).
118
+
119
+ For sequence to sequence training, `decoder_input_ids` should be provided. `decoder_input_ids` should be
120
+ created outside of the model by shifting the `labels` to the right, replacing -100 by the `pad_token_id`
121
+ and prepending them with the `decoder_start_token_id`.
122
+ decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
123
+ Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
124
+ be used by default.
125
+ decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
126
+ Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
127
+ range `[0, config.decoder.max_position_embeddings - 1]`.
128
+ output_hidden_states (`bool`, *optional*):
129
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
130
+ more detail.
131
+ return_dict (`bool`, *optional*):
132
+ If set to `True`, the model will return a [`~utils.FlaxSeq2SeqLMOutput`] instead of a plain tuple.
133
+ """
134
+
135
+ SPEECH_ENCODER_DECODER_ENCODE_INPUTS_DOCSTRING = r"""
136
+ Args:
137
+ inputs (`jnp.ndarray` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, feature_dim)`, *optional*):
138
+ Float values of input raw speech waveform or speech features. Values can be obtained by loading a *.flac*
139
+ or *.wav* audio file into an array of type *List[float]* or a *numpy.ndarray*, *e.g.* via the soundfile
140
+ library (*pip install soundfile*). To prepare the array into *inputs*, either the [`Wav2Vec2Processor`] or
141
+ [`Speech2TextProcessor`] should be used for padding and conversion into a tensor of type
142
+ *torch.FloatTensor*.
143
+ attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
144
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
145
+
146
+ - 1 for tokens that are **not masked**,
147
+ - 0 for tokens that are **masked**.
148
+
149
+ [What are attention masks?](../glossary#attention-mask)
150
+ output_attentions (`bool`, *optional*):
151
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
152
+ tensors for more detail.
153
+ output_hidden_states (`bool`, *optional*):
154
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
155
+ more detail.
156
+ return_dict (`bool`, *optional*):
157
+ If set to `True`, the model will return a [`~utils.FlaxBaseModelOutput`] instead of a plain tuple.
158
+ """
159
+
160
+ SPEECH_ENCODER_DECODER_DECODE_INPUTS_DOCSTRING = r"""
161
+ Args:
162
+ decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
163
+ Indices of decoder input sequence tokens in the vocabulary.
164
+
165
+ Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and
166
+ [`PreTrainedTokenizer.__call__`] for details.
167
+
168
+ [What are decoder input IDs?](../glossary#decoder-input-ids)
169
+
170
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
171
+ `past_key_values`).
172
+
173
+ For sequence to sequence training, `decoder_input_ids` should be provided. `decoder_input_ids` should be
174
+ created outside of the model by shifting the `labels` to the right, replacing -100 by the `pad_token_id`
175
+ and prepending them with the `decoder_start_token_id`.
176
+ encoder_outputs (`tuple(tuple(jnp.ndarray))`):
177
+ Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
178
+ `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
179
+ hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
180
+ encoder_attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
181
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
182
+
183
+ - 1 for tokens that are **not masked**,
184
+ - 0 for tokens that are **masked**.
185
+
186
+ [What are attention masks?](../glossary#attention-mask)
187
+ decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
188
+ Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
189
+ be used by default.
190
+ decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
191
+ Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
192
+ range `[0, config.decoder.max_position_embeddings - 1]`.
193
+ past_key_values (`Dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`):
194
+ Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast
195
+ auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*.
196
+ output_attentions (`bool`, *optional*):
197
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
198
+ tensors for more detail.
199
+ output_hidden_states (`bool`, *optional*):
200
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
201
+ more detail.
202
+ return_dict (`bool`, *optional*):
203
+ If set to `True`, the model will return a [`~utils.FlaxCausalLMOutputWithCrossAttentions`] instead of a
204
+ plain tuple.
205
+ """
206
+
207
@flax.struct.dataclass
class FlaxBeamSearchOutput(ModelOutput):
    """
    Flax base class for the outputs of generation using beam search.


    Args:
        sequences (`jnp.ndarray` of shape `(batch_size, max_length)`):
            The generated sequences.
        scores (`jnp.ndarray` of shape `(batch_size,)`):
            The scores (log probabilities) of the generated sequences.
    """

    sequences: jnp.ndarray = None
    scores: jnp.ndarray = None
222
+
223
+
224
@flax.struct.dataclass
class BeamSearchState:
    """
    Immutable record of the values carried between beam search steps.

    NOTE(review): the per-field notes below are inferred from the field names only; the code that
    reads and writes this state is outside this excerpt - confirm against the beam search loop.
    """

    # presumably the current decoding length / step counter
    cur_len: jnp.ndarray
    # presumably the still-active beams and their scores
    running_sequences: jnp.ndarray
    running_scores: jnp.ndarray
    # presumably the best finished beams and their scores
    sequences: jnp.ndarray
    scores: jnp.ndarray
    # presumably a per-beam flag marking finished hypotheses
    is_sent_finished: jnp.ndarray
    # keyword arguments forwarded to the model on each step
    model_kwargs: Dict[str, jnp.ndarray]
233
+
234
+
235
+
236
+
237
+ class FlaxSpeechEncoderDecoderModule(nn.Module):
238
+ config: SpeechEncoderDecoderConfig
239
+ dtype: jnp.dtype = jnp.float32
240
+
241
def setup(self):
    """
    Create the encoder and decoder submodules and, when their hidden sizes differ and the decoder
    config sets no `cross_attention_hidden_size`, a dense projection from encoder to decoder width.
    """
    # TODO: configure FlaxAutoModel mappings (required when trialling different encoder-decoder combinations)
    self.encoder = FlaxWav2Vec2Module(self.config.encoder, dtype=self.dtype)
    self.decoder = FlaxBartForCausalLMModule(self.config.decoder, dtype=self.dtype)

    enc_cfg = self.encoder.config
    dec_cfg = self.decoder.config

    # encoder outputs might need to be projected to different dimension for decoder
    widths_match = enc_cfg.hidden_size == dec_cfg.hidden_size
    if widths_match or dec_cfg.cross_attention_hidden_size is not None:
        self.enc_to_dec_proj = None
    else:
        self.enc_to_dec_proj = nn.Dense(
            dec_cfg.hidden_size,
            kernel_init=jax.nn.initializers.normal(dec_cfg.initializer_range),
            dtype=self.dtype,
        )
264
+
265
+ def _get_feat_extract_output_lengths(
266
+ self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
267
+ ):
268
+ """
269
+ Computes the output length of the convolutional layers
270
+ """
271
+
272
+ add_adapter = self.config.encoder.add_adapter if add_adapter is None else add_adapter
273
+
274
+ def _conv_out_length(input_length, kernel_size, stride):
275
+ # 1D convolutional layer output length formula taken
276
+ # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
277
+ return (input_length - kernel_size) // stride + 1
278
+
279
+ for kernel_size, stride in zip(self.config.encoder.conv_kernel, self.config.encoder.conv_stride):
280
+ input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
281
+
282
+ if add_adapter:
283
+ for _ in range(self.config.encoder.num_adapter_layers):
284
+ input_lengths = _conv_out_length(input_lengths, 1, self.config.encoder.adapter_stride)
285
+
286
+ return input_lengths
287
+
288
def _get_encoder_module(self):
    """Return the encoder submodule (`self.encoder`)."""
    return self.encoder
290
+
291
def _get_projection_module(self):
    """Return the encoder-to-decoder projection (`self.enc_to_dec_proj`), which may be `None`."""
    return self.enc_to_dec_proj
293
+
294
def _get_decoder_module(self):
    """Return the decoder submodule (`self.decoder`)."""
    return self.decoder
296
+
297
+ def __call__(
298
+ self,
299
+ inputs,
300
+ attention_mask,
301
+ decoder_input_ids,
302
+ decoder_attention_mask,
303
+ decoder_position_ids,
304
+ encoder_outputs=None,
305
+ extract_features=None,
306
+ output_attentions: bool = False,
307
+ output_hidden_states: bool = False,
308
+ output_features: bool = False,
309
+ return_dict: bool = True,
310
+ deterministic: bool = True,
311
+ freeze_feature_encoder: bool = False,
312
+ ):
313
+ if encoder_outputs is None:
314
+ encoder_outputs = self.encoder(
315
+ inputs,
316
+ attention_mask=attention_mask,
317
+ extract_features=extract_features,
318
+ output_attentions=output_attentions,
319
+ output_hidden_states=output_hidden_states,
320
+ output_features=output_features,
321
+ return_dict=return_dict,
322
+ deterministic=deterministic,
323
+ freeze_feature_encoder=freeze_feature_encoder,
324
+ )
325
+
326
+ if output_features:
327
+ return encoder_outputs
328
+
329
+ encoder_hidden_states = encoder_outputs[0]
330
+
331
+ # optionally project encoder_hidden_states
332
+ if self.enc_to_dec_proj is not None:
333
+ encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
334
+
335
+ # compute correct encoder attention mask
336
+ if attention_mask is not None:
337
+ encoder_attention_mask = self.encoder._get_feature_vector_attention_mask(
338
+ encoder_hidden_states.shape[1], attention_mask
339
+ )
340
+ else:
341
+ encoder_attention_mask = None
342
+
343
+ # flax script modeling_flax_wav2vec2.py
344
+ decoder_outputs = self.decoder(
345
+ input_ids=decoder_input_ids,
346
+ attention_mask=decoder_attention_mask,
347
+ position_ids=decoder_position_ids,
348
+ encoder_hidden_states=encoder_hidden_states,
349
+ encoder_attention_mask=encoder_attention_mask,
350
+ output_attentions=output_attentions,
351
+ output_hidden_states=output_hidden_states,
352
+ return_dict=return_dict,
353
+ deterministic=deterministic,
354
+ )
355
+
356
+ if not return_dict:
357
+ return decoder_outputs + encoder_outputs
358
+
359
+ return FlaxSeq2SeqLMOutput(
360
+ logits=decoder_outputs.logits,
361
+ decoder_hidden_states=decoder_outputs.hidden_states,
362
+ decoder_attentions=decoder_outputs.attentions,
363
+ cross_attentions=decoder_outputs.cross_attentions,
364
+ encoder_last_hidden_state=encoder_hidden_states,
365
+ encoder_hidden_states=encoder_outputs.hidden_states,
366
+ encoder_attentions=encoder_outputs.attentions,
367
+ )
368
+
369
+
370
+ @add_start_docstrings(SPEECH_ENCODER_DECODER_START_DOCSTRING)
371
+ class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel):
372
+ r"""
373
+ [`FlaxSpeechEncoderDecoderModel`] is a generic model class that will be instantiated as a transformer architecture
374
+ with the module (flax.nn.Module) of one of the base model classes of the library as encoder module and another one
375
+ as decoder module when created with the :meth*~transformers.FlaxAutoModel.from_pretrained* class method for the
376
+ encoder and :meth*~transformers.FlaxAutoModelForCausalLM.from_pretrained* class method for the decoder.
377
+ """
378
+
379
+ config_class = SpeechEncoderDecoderConfig
380
+ base_model_prefix: str = "speech_encoder_decoder"
381
+ module_class = FlaxSpeechEncoderDecoderModule
382
+
383
def __init__(
    self,
    config: SpeechEncoderDecoderConfig,
    input_shape: Optional[Tuple] = None,
    seed: int = 0,
    dtype: jnp.dtype = jnp.float32,
    _do_init: bool = True,
    **kwargs
):
    """
    Build the Flax module for `config` and hand it to the parent constructor.

    Args:
        config: Combined encoder/decoder configuration. Its `tie_word_embeddings` is forced to `False`.
        input_shape: Pair `(encoder_input_shape, decoder_input_shape)`. When `None`, an encoder input of
            shape `(1, 1024)` is used and the decoder length is derived from it with
            `_get_feat_extract_output_lengths`.
        seed: Forwarded to the parent constructor.
        dtype: Forwarded to the module and to the parent constructor.
        _do_init: Must be `True`.
        **kwargs: Forwarded to `module_class`.

    Raises:
        ValueError: If `_do_init` is falsy, or if the decoder's `cross_attention_hidden_size` is set and
            differs from the encoder's `hidden_size`.
    """
    if not _do_init:
        raise ValueError(
            "`FlaxSpeechEncoderDecoderModel` cannot be created without initializing, `_do_init` must be `True`."
        )

    # Raise ValueError or option to project enc to dec hidden_size (eg EncAdapterLayer)
    cross_attention_size = config.decoder.cross_attention_hidden_size
    if cross_attention_size is not None and cross_attention_size != config.encoder.hidden_size:
        raise ValueError(
            "If `cross_attention_hidden_size` is specified in the decoder's configuration, "
            "it has to be equal to the encoder's `hidden_size`. "
            f"Got {config.decoder.cross_attention_hidden_size} for `config.decoder.cross_attention_hidden_size` "
            f"and {config.encoder.hidden_size} for `config.encoder.hidden_size`."
        )

    # input & output embeddings must not be tied
    config.tie_word_embeddings = False
    module = self.module_class(config=config, dtype=dtype, **kwargs)

    if input_shape is None:
        # speech encoders almost always downsample the sequence length dimension
        default_encoder_length = 1024
        default_decoder_length = module._get_feat_extract_output_lengths(default_encoder_length)
        input_shape = ((1, default_encoder_length), (1, default_decoder_length))

    super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
419
+
420
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
    """
    Initialise the module parameters from dummy inputs.

    Args:
        rng: Key that is split into the `"params"` and `"dropout"` streams.
        input_shape: Pair `(encoder_input_shape, decoder_input_shape)`; both are unpacked as
            `(batch, length)`.

    Returns:
        The `"params"` collection produced by `self.module.init`.

    Raises:
        ValueError: If encoder and decoder shapes disagree on the batch size.
    """
    encoder_shape, decoder_shape = input_shape

    # dummy inputs: zeros for the data, ones for the masks
    inputs = jnp.zeros(encoder_shape, dtype="f4")
    attention_mask = jnp.ones_like(inputs, dtype="i4")
    decoder_input_ids = jnp.zeros(decoder_shape, dtype="i4")
    decoder_attention_mask = jnp.ones_like(decoder_input_ids)

    batch_size, _ = inputs.shape
    decoder_batch_size, decoder_length = decoder_input_ids.shape
    if decoder_batch_size != batch_size:
        raise ValueError(
            f"The inputs of encoder and decoder should have the same batch size, but got {batch_size} for encoder and {decoder_batch_size} for decoder."
        )

    # positions 0..length-1, repeated for every batch row
    position_row = jnp.arange(decoder_length)[None, :]
    decoder_position_ids = jnp.broadcast_to(position_row, (decoder_batch_size, decoder_length))

    params_rng, dropout_rng = jax.random.split(rng)
    init_variables = self.module.init(
        {"params": params_rng, "dropout": dropout_rng},
        inputs,
        attention_mask,
        decoder_input_ids,
        decoder_attention_mask,
        decoder_position_ids,
    )
    return init_variables["params"]
451
+
452
def init_cache(self, batch_size, max_length, encoder_outputs):
    r"""
    Args:
        batch_size (`int`):
            batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
        max_length (`int`):
            maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
            cache.
        encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray)]`):
            `encoder_outputs` consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*:
            `attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*)
            is a sequence of hidden-states at the output of the last layer of the encoder. Used in the
            cross-attention of the decoder.

    Returns:
        The unfrozen `"cache"` collection produced by initialising the decoder with `init_cache=True`.
    """
    # dummy decoder inputs spanning the full cache length
    cache_shape = (batch_size, max_length)
    decoder_input_ids = jnp.ones(cache_shape, dtype="i4")
    decoder_attention_mask = jnp.ones_like(decoder_input_ids)
    positions = jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1])
    decoder_position_ids = jnp.broadcast_to(positions, decoder_input_ids.shape)

    def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs):
        # only the decoder has to run in order to create the cache variables
        return module._get_decoder_module()(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            position_ids=decoder_position_ids,
            **kwargs,
        )

    cache_variables = self.module.init(
        jax.random.PRNGKey(0),
        decoder_input_ids=decoder_input_ids,
        decoder_attention_mask=decoder_attention_mask,
        decoder_position_ids=decoder_position_ids,
        encoder_hidden_states=encoder_outputs[0],
        init_cache=True,
        method=_decoder_forward,
    )
    return unfreeze(cache_variables["cache"])
492
+
493
+ def _get_feat_extract_output_lengths(
494
+ self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
495
+ ):
496
+ return self.module._get_feat_extract_output_lengths(input_lengths, add_adapter=add_adapter)
497
+
498
+ @add_start_docstrings(SPEECH_ENCODER_DECODER_ENCODE_INPUTS_DOCSTRING)
499
+ @replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=_CONFIG_FOR_DOC)
500
def encode(
    self,
    inputs: jnp.ndarray,
    attention_mask: Optional[jnp.ndarray] = None,
    extract_features: Optional[jnp.ndarray] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    output_features: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    train: bool = False,
    freeze_feature_encoder: bool = False,
    params: dict = None,
    dropout_rng: PRNGKey = None,
):
    r"""
    Returns:

    Example:

    ```python
    >>> from transformers import FlaxSpeechEncoderDecoderModel
    >>> import jax.numpy as jnp

    >>> # initialize a wav2vec2-2-bart from pretrained wav2vec2 and bart models. Note that the cross-attention layers will be randomly initialized
    >>> model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
    ...     "facebook/wav2vec2-large-lv60", "facebook/bart-large"
    ... )

    >>> inputs = jnp.ones((2, 5000), dtype=jnp.float32)
    >>> encoder_outputs = model.encode(inputs)
    ```"""
    # flags left as `None` fall back to the model configuration
    if output_attentions is None:
        output_attentions = self.config.output_attentions
    if output_hidden_states is None:
        output_hidden_states = self.config.output_hidden_states
    if return_dict is None:
        return_dict = self.config.return_dict

    # without an explicit mask every input position is attended to
    if attention_mask is None:
        attention_mask = jnp.ones_like(inputs, dtype="i4")

    if extract_features is not None:
        extract_features = jnp.array(extract_features, dtype="f4")

    # a dropout stream is only provided when the caller passed a key
    rngs = {} if dropout_rng is None else {"dropout": dropout_rng}

    def _encoder_forward(module, inputs, attention_mask, **kwargs):
        # restrict the forward pass to the encoder submodule
        return module._get_encoder_module()(inputs, attention_mask, **kwargs)

    variables = {"params": params or self.params}
    outputs = self.module.apply(
        variables,
        inputs=jnp.array(inputs, dtype="f4"),
        attention_mask=jnp.array(attention_mask, dtype="i4"),
        extract_features=extract_features,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        output_features=output_features,
        return_dict=return_dict,
        deterministic=not train,
        freeze_feature_encoder=freeze_feature_encoder,
        rngs=rngs,
        method=_encoder_forward,
    )

    if not return_dict or output_features:
        return outputs

    # re-wrap so that only the base-model fields are exposed
    return FlaxBaseModelOutput(
        last_hidden_state=outputs.last_hidden_state,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
574
+
575
+ @add_start_docstrings(SPEECH_ENCODER_DECODER_DECODE_INPUTS_DOCSTRING)
576
+ @replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
577
def decode(
    self,
    decoder_input_ids,
    encoder_outputs,
    encoder_attention_mask: Optional[jnp.ndarray] = None,
    decoder_attention_mask: Optional[jnp.ndarray] = None,
    decoder_position_ids: Optional[jnp.ndarray] = None,
    past_key_values: dict = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    train: bool = False,
    params: dict = None,
    dropout_rng: PRNGKey = None,
):
    r"""
    Returns:

    Example:

    ```python
    >>> from transformers import FlaxSpeechEncoderDecoderModel
    >>> import jax.numpy as jnp

    >>> # initialize a wav2vec2-2-bart from pretrained wav2vec2 and bart models. Note that the cross-attention layers will be randomly initialized
    >>> model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
    ...     "facebook/wav2vec2-large-lv60", "facebook/bart-large"
    ... )

    >>> inputs = jnp.ones((2, 5000), dtype=jnp.float32)
    >>> encoder_outputs = model.encode(inputs)

    >>> decoder_start_token_id = model.config.decoder.bos_token_id
    >>> decoder_input_ids = jnp.ones((inputs.shape[0], 1), dtype="i4") * decoder_start_token_id

    >>> outputs = model.decode(decoder_input_ids, encoder_outputs)
    >>> logits = outputs.logits
    ```"""
    # flags left as `None` fall back to the model configuration
    if output_attentions is None:
        output_attentions = self.config.output_attentions
    if output_hidden_states is None:
        output_hidden_states = self.config.output_hidden_states
    if return_dict is None:
        return_dict = self.config.return_dict

    encoder_hidden_states = encoder_outputs[0]
    if encoder_attention_mask is None:
        # attend to every encoder position
        encoder_batch, encoder_length = encoder_hidden_states.shape[:2]
        encoder_attention_mask = jnp.ones((encoder_batch, encoder_length))

    batch_size, sequence_length = decoder_input_ids.shape
    if decoder_attention_mask is None:
        decoder_attention_mask = jnp.ones((batch_size, sequence_length))

    if decoder_position_ids is None:
        # with a cache the absolute positions cannot be inferred from the input alone
        if past_key_values is not None:
            raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.")
        position_row = jnp.arange(sequence_length)[None, :]
        decoder_position_ids = jnp.broadcast_to(position_row, (batch_size, sequence_length))

    # a dropout stream is only provided when the caller passed a key
    rngs = {} if dropout_rng is None else {"dropout": dropout_rng}

    variables = {"params": params or self.params}

    # if past_key_values are passed then cache is already initialized a private flag init_cache has to be
    # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that
    # it can be changed by FlaxBartAttention module
    mutable = False
    if past_key_values:
        variables["cache"] = past_key_values
        mutable = ["cache"]

    def _decoder_forward(
        module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, encoder_hidden_states, **kwargs
    ):
        # optionally project encoder_hidden_states before they reach the decoder
        projection_module = module._get_projection_module()
        if projection_module is not None:
            encoder_hidden_states = projection_module(encoder_hidden_states)

        return module._get_decoder_module()(
            decoder_input_ids,
            decoder_attention_mask,
            decoder_position_ids,
            encoder_hidden_states,
            **kwargs,
        )

    outputs = self.module.apply(
        variables,
        decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
        decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
        decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
        encoder_hidden_states=encoder_hidden_states,
        encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        deterministic=not train,
        rngs=rngs,
        mutable=mutable,
        method=_decoder_forward,
    )

    # add updated cache to model output
    if past_key_values is not None:
        outputs, mutated = outputs
        updated_cache = unfreeze(mutated["cache"])
        if return_dict:
            outputs["past_key_values"] = updated_cache
        else:
            outputs = outputs[:1] + (updated_cache,) + outputs[1:]

    return outputs
699
+
700
+ @add_start_docstrings_to_model_forward(SPEECH_ENCODER_DECODER_INPUTS_DOCSTRING)
701
+ @replace_return_docstrings(output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
702
+ def __call__(
703
+ self,
704
+ inputs: jnp.ndarray,
705
+ attention_mask: Optional[jnp.ndarray] = None,
706
+ extract_features: Optional[jnp.ndarray] = None,
707
+ decoder_input_ids: Optional[jnp.ndarray] = None,
708
+ decoder_attention_mask: Optional[jnp.ndarray] = None,
709
+ decoder_position_ids: Optional[jnp.ndarray] = None,
710
+ output_attentions: Optional[bool] = None,
711
+ output_hidden_states: Optional[bool] = None,
712
+ output_features: Optional[bool] = None,
713
+ return_dict: Optional[bool] = None,
714
+ train: bool = False,
715
+ freeze_feature_encoder: bool = False,
716
+ params: dict = None,
717
+ dropout_rng: PRNGKey = None,
718
+ ):
719
+ r"""
720
+ Returns:
721
+
722
+ Examples:
723
+
724
+ ```python
725
+ >>> from transformers import FlaxSpeechEncoderDecoderModel, BartTokenizer
726
+
727
+ >>> # load a fine-tuned wav2vec2-2-bart model
728
+ >>> model = FlaxSpeechEncoderDecoderModel.from_pretrained("patrickvonplaten/wav2vec2-2-bart-large")
729
+ >>> # load output tokenizer
730
+ >>> tokenizer_output = BartTokenizer.from_pretrained("facebook/bart-large")
731
+
732
+ >>> inputs = jnp.ones((2, 5000), dtype=jnp.float32)
733
+
734
+ >>> # use bart's special bos, pad and eos tokens
735
+ >>> model.config.decoder_start_token_id = model.decoder.config.bos_token_id
736
+ >>> model.config.pad_token_id = model.decoder.config.pad_token_id
737
+ >>> model.config.eos_token_id = model.decoder.config.eos_token_id
738
+
739
+ >>> outputs = model.generate(inputs)
740
+ # Assert something? More interesting input? dtype correct?
741
+ ```
742
+ """
743
+
744
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
745
+ output_hidden_states = (
746
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
747
+ )
748
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
749
+
750
+ # prepare encoder inputs
751
+ if attention_mask is None:
752
+ attention_mask = jnp.ones_like(inputs, dtype="i4")
753
+
754
+ if extract_features is not None:
755
+ inputs = None # we can omit passing the inputs to the model to save memory
756
+ extract_features = jnp.array(extract_features, dtype="f4")
757
+ else:
758
+ inputs = jnp.array(inputs, dtype="f4")
759
+
760
+ # prepare decoder inputs
761
+ if decoder_input_ids is None:
762
+ raise ValueError(
763
+ "`decoder_input_ids` cannot be `None`. For sequence to sequence training, `decoder_position_ids` must be specified as an input argument."
764
+ )
765
+ if decoder_attention_mask is None:
766
+ decoder_attention_mask = jnp.ones_like(decoder_input_ids)
767
+ if decoder_position_ids is None:
768
+ batch_size, sequence_length = decoder_input_ids.shape
769
+ decoder_position_ids = jnp.broadcast_to(
770
+ jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
771
+ )
772
+
773
+ # Handle any PRNG if needed
774
+ rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
775
+
776
+ return self.module.apply(
777
+ {"params": params or self.params},
778
+ inputs=inputs,
779
+ attention_mask=jnp.array(attention_mask, dtype="i4"),
780
+ extract_features=extract_features,
781
+ decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
782
+ decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
783
+ decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
784
+ output_attentions=output_attentions,
785
+ output_hidden_states=output_hidden_states,
786
+ output_features=output_features,
787
+ return_dict=return_dict,
788
+ deterministic=not train,
789
+ freeze_feature_encoder=freeze_feature_encoder,
790
+ rngs=rngs,
791
+ )
792
+
793
def prepare_inputs_for_generation(
    self,
    decoder_input_ids,
    max_length,
    attention_mask: Optional[jnp.ndarray] = None,
    decoder_attention_mask: Optional[jnp.ndarray] = None,
    encoder_outputs=None,
    **kwargs
):
    """
    Build the keyword arguments used for cached auto-regressive decoding.

    Args:
        decoder_input_ids: Array whose shape is unpacked as `(batch_size, seq_length)`.
        max_length: Length of the cache and of the static decoder attention mask.
        attention_mask: Passed through unchanged as `encoder_attention_mask`.
        decoder_attention_mask: Optional mask written into the start of the static all-ones mask; its
            cumulative sum (minus one) provides the position ids.
        encoder_outputs: Passed to `init_cache` and returned unchanged.
        **kwargs: Accepted and ignored.

    Returns:
        A dict with the keys `past_key_values`, `encoder_outputs`, `encoder_attention_mask`,
        `decoder_attention_mask` and `decoder_position_ids`.
    """
    # NOTE: the annotations use `jnp.ndarray` instead of `jnp.DeviceArray`; annotations are evaluated when
    # the function is defined and `jnp.DeviceArray` no longer exists in recent JAX releases.

    # initializing the cache
    batch_size, seq_length = decoder_input_ids.shape

    past_key_values = self.init_cache(batch_size, max_length, encoder_outputs)
    # Note that usually one would have to put 0's in the attention_mask for x > input.shape[-1] and x < cache_length.
    # But since the decoder uses a causal mask, those positions are masked anyways.
    # Thus we can create a single static attention_mask here, which is more efficient for compilation
    extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
    if decoder_attention_mask is not None:
        decoder_position_ids = decoder_attention_mask.cumsum(axis=-1) - 1
        extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0))
    else:
        decoder_position_ids = jnp.broadcast_to(
            jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)
        )

    return {
        "past_key_values": past_key_values,
        "encoder_outputs": encoder_outputs,
        "encoder_attention_mask": attention_mask,
        "decoder_attention_mask": extended_attention_mask,
        "decoder_position_ids": decoder_position_ids,
    }
825
+
826
def update_inputs_for_generation(self, model_outputs, model_kwargs):
    """
    Refresh `model_kwargs` in place after a decoding step and return it.

    The cache is replaced by `model_outputs.past_key_values`, and `decoder_position_ids` becomes the last
    position of every row incremented by one.
    """
    last_positions = model_kwargs["decoder_position_ids"][:, -1:]
    model_kwargs.update(
        past_key_values=model_outputs.past_key_values,
        decoder_position_ids=last_positions + 1,
    )
    return model_kwargs
830
+
831
+ @classmethod
832
def from_encoder_decoder_pretrained(
    cls,
    encoder_pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
    decoder_pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
    *model_args,
    **kwargs
) -> FlaxPreTrainedModel:
    r"""
    Instantiate an encoder and a decoder from one or two base classes of the library from pretrained model
    checkpoints.

    Params:
        encoder_pretrained_model_name_or_path (`Union[str, os.PathLike]`, *optional*):
            Information necessary to initiate the encoder. Can be either:

                - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
                  Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a
                  user or organization name, like `dbmdz/bert-base-german-cased`.
                - A path to a *directory* containing model weights saved using
                  [`~FlaxPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.

        decoder_pretrained_model_name_or_path (`Union[str, os.PathLike]`, *optional*, defaults to `None`):
            Information necessary to initiate the decoder. Can be either:

                - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
                  Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a
                  user or organization name, like `dbmdz/bert-base-german-cased`.
                - A path to a *directory* containing model weights saved using
                  [`~FlaxPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.

        model_args (remaining positional arguments, *optional*):
            All remaining positional arguments will be passed to the underlying model's `__init__` method.

        kwargs (remaining dictionary of keyword arguments, *optional*):
            Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
            `output_attentions=True`).

            - To update the encoder configuration, use the prefix *encoder_* for each configuration parameter.
            - To update the decoder configuration, use the prefix *decoder_* for each configuration parameter.
            - To update the parent model configuration, do not use a prefix for each configuration parameter.

            Behaves differently depending on whether a `config` is provided or automatically loaded.

    Example:

    ```python
    >>> from transformers import FlaxSpeechEncoderDecoderModel

    >>> # initialize a wav2vec2-2-bart from pretrained wav2vec2 and bart models. Note that the cross-attention layers will be randomly initialized
    >>> model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
    ...     "facebook/wav2vec2-large-lv60", "facebook/bart-large"
    ... )
    >>> # saving model after fine-tuning
    >>> model.save_pretrained("./wav2vec2-2-bart-large")
    >>> # load fine-tuned model
    >>> model = FlaxSpeechEncoderDecoderModel.from_pretrained("./wav2vec2-2-bart-large")
    ```"""

    def _take_prefixed(prefix):
        # move every `<prefix>name` entry out of `kwargs`, keyed by the bare `name`
        taken = {name[len(prefix) :]: kwargs[name] for name in list(kwargs) if name.startswith(prefix)}
        for bare_name in taken:
            del kwargs[prefix + bare_name]
        return taken

    def _not_set_up_as_decoder(decoder_cfg):
        return decoder_cfg.is_decoder is False or decoder_cfg.add_cross_attention is False

    kwargs_encoder = _take_prefixed("encoder_")
    kwargs_decoder = _take_prefixed("decoder_")

    # Load and initialize the encoder and decoder
    # The distinction between encoder and decoder at the model level is made
    # by the value of the flag `is_decoder` that we need to set correctly.
    encoder = kwargs_encoder.pop("model", None)
    if encoder is None:
        if encoder_pretrained_model_name_or_path is None:
            raise ValueError(
                "If `encoder_model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has "
                "to be defined."
            )

        if "config" not in kwargs_encoder:
            # TODO: AutoConfig .from_pretrained
            encoder_config, kwargs_encoder = Wav2Vec2Config.from_pretrained(
                encoder_pretrained_model_name_or_path, **kwargs_encoder, return_unused_kwargs=True
            )
            if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
                logger.info(
                    f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model "
                    "from a decoder model. Cross-attention and casual mask are disabled."
                )
                encoder_config.is_decoder = False
                encoder_config.add_cross_attention = False

            kwargs_encoder["config"] = encoder_config

        # TODO: FlaxAutoModel .from_pretrained
        encoder = FlaxWav2Vec2Model.from_pretrained(
            encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder
        )

    decoder = kwargs_decoder.pop("model", None)
    if decoder is None:
        if decoder_pretrained_model_name_or_path is None:
            raise ValueError(
                "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has "
                "to be defined."
            )

        if "config" not in kwargs_decoder:
            # TODO: AutoConfig .from_pretrained
            decoder_config, kwargs_decoder = BartConfig.from_pretrained(
                decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True
            )
            if _not_set_up_as_decoder(decoder_config):
                logger.info(
                    f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. "
                    f"Cross attention layers are added to {decoder_pretrained_model_name_or_path} "
                    f"and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for "
                    "cross attention layers."
                )
                decoder_config.is_decoder = True
                decoder_config.add_cross_attention = True

            kwargs_decoder["config"] = decoder_config

        # a user-supplied decoder config is not modified, only reported
        if _not_set_up_as_decoder(kwargs_decoder["config"]):
            logger.warning(
                f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. "
                f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, "
                "make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` "
                "passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a "
                "`decoder_config` to `.from_encoder_decoder_pretrained(...)`"
            )

        # TODO: FlaxAutoModelForCausalLM .from_pretrained
        decoder = FlaxBartForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)

    # instantiate config with corresponding kwargs
    dtype = kwargs.pop("dtype", jnp.float32)
    config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs)

    # make sure input & output word embeddings are not tied
    config.tie_word_embeddings = False

    # init model and overwrite its encoder / decoder parameters with the loaded ones
    model = cls(config, dtype=dtype)
    for part_name, part_model in (("encoder", encoder), ("decoder", decoder)):
        model.params[part_name] = part_model.params

    return model
985
+
986
+ def _beam_search(
987
+ self,
988
+ input_ids: None,
989
+ max_length: Optional[int] = None,
990
+ pad_token_id: Optional[int] = None,
991
+ eos_token_id: Optional[int] = None,
992
+ length_penalty: Optional[float] = None,
993
+ early_stopping: Optional[bool] = None,
994
+ logits_processor: Optional[FlaxLogitsProcessorList] = None,
995
+ trace: bool = True,
996
+ params: Optional[Dict[str, jnp.ndarray]] = None,
997
+ model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
998
+ ):
999
+ """
1000
+ This beam search function is heavily inspired by Flax's official example:
1001
+ https://github.com/google/flax/blob/master/examples/wmt/train.py#L254
1002
+ """
1003
+
1004
+ def flatten_beam_dim(tensor):
1005
+ """Flattens the first two dimensions of a non-scalar array."""
1006
+ # ignore scalars (e.g. cache index)
1007
+ if tensor.ndim == 0 or tensor.ndim == 1:
1008
+ return tensor
1009
+ elif tensor.ndim == 6:
1010
+ return tensor.reshape(tensor.shape[:1] + (tensor.shape[1] * tensor.shape[2],) + tensor.shape[3:])
1011
+ return tensor.reshape((tensor.shape[0] * tensor.shape[1],) + tensor.shape[2:])
1012
+
1013
+ def unflatten_beam_dim(tensor, batch_size, num_beams):
1014
+ """Unflattens the first, flat batch*beam dimension of a non-scalar array."""
1015
+ # ignore scalars (e.g. cache index)
1016
+ if tensor.ndim == 0 or tensor.ndim == 1:
1017
+ return tensor
1018
+ if tensor.ndim == 5:
1019
+ return tensor.reshape(tensor.shape[:1] + (batch_size, num_beams) + tensor.shape[2:])
1020
+ return tensor.reshape((batch_size, num_beams) + tensor.shape[1:])
1021
+
1022
+ def gather_beams(nested, beam_indices, batch_size, new_num_beams):
1023
+ """
1024
+ Gathers the beam slices indexed by beam_indices into new beam array.
1025
+ """
1026
+ batch_indices = jnp.reshape(
1027
+ jnp.arange(batch_size * new_num_beams) // new_num_beams, (batch_size, new_num_beams)
1028
+ )
1029
+
1030
+ def gather_fn(tensor):
1031
+ # ignore scalars (e.g. cache index)
1032
+ if tensor.ndim == 0 or tensor.ndim == 1:
1033
+ return tensor
1034
+ if tensor.ndim == 6:
1035
+ return tensor[:, batch_indices, beam_indices]
1036
+ return tensor[batch_indices, beam_indices]
1037
+
1038
+ return jax.tree_map(gather_fn, nested)
1039
+
1040
+ # init values
1041
+ max_length = max_length if max_length is not None else self.config.max_length
1042
+ pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
1043
+ eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
1044
+ length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
1045
+ early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
1046
+
1047
+ batch_size, num_beams, cur_len = input_ids.shape
1048
+
1049
+ eos_token_id = jnp.array(eos_token_id)
1050
+ pad_token_id = jnp.array(pad_token_id)
1051
+ cur_len = jnp.array(cur_len)
1052
+
1053
+ # per batch,beam-item holding current token in loop.
1054
+ sequences = jnp.full((batch_size, num_beams, max_length), pad_token_id, dtype=jnp.int32)
1055
+ running_sequences = jnp.full((batch_size, num_beams, max_length), pad_token_id, dtype=jnp.int32)
1056
+ running_sequences = lax.dynamic_update_slice(sequences, input_ids, (0, 0, 0))
1057
+
1058
+ # per batch,beam-item state bit indicating if sentence has finished.
1059
+ is_sent_finished = jnp.zeros((batch_size, num_beams), dtype=jnp.bool_)
1060
+
1061
+ # per batch,beam-item score, logprobs
1062
+ running_scores = jnp.tile(jnp.array([0.0] + [np.array(-1.0e7)] * (num_beams - 1)), [batch_size, 1])
1063
+ scores = jnp.ones((batch_size, num_beams)) * np.array(-1.0e7)
1064
+
1065
+ # For Seq2Seq generation, we only need to use the decoder instead of the whole model in generation loop
1066
+ # and pass it the `encoder_outputs`, which are part of the `model_kwargs`.
1067
+ model = self.decode if self.config.is_encoder_decoder else self
1068
+
1069
+ # flatten beam dim
1070
+ if "encoder_outputs" in model_kwargs:
1071
+ model_kwargs["encoder_outputs"]["last_hidden_state"] = flatten_beam_dim(
1072
+ model_kwargs["encoder_outputs"]["last_hidden_state"]
1073
+ )
1074
+ if "attention_mask" in model_kwargs:
1075
+ model_kwargs["attention_mask"] = flatten_beam_dim(model_kwargs["attention_mask"])
1076
+
1077
+ # initialize model specific kwargs
1078
+ model_kwargs = self.prepare_inputs_for_generation(flatten_beam_dim(input_ids), max_length, **model_kwargs)
1079
+
1080
+ # initialize state
1081
+ state = BeamSearchState(
1082
+ cur_len=cur_len,
1083
+ running_sequences=running_sequences,
1084
+ running_scores=running_scores,
1085
+ sequences=sequences,
1086
+ scores=scores,
1087
+ is_sent_finished=is_sent_finished,
1088
+ model_kwargs=model_kwargs,
1089
+ )
1090
+
1091
def beam_search_cond_fn(state):
    """Loop predicate for beam search: True while another decoding step should run.

    The loop continues only if all three hold:
      * `state.cur_len` is still below `max_length`,
      * not (every beam is finished while `early_stopping` is set),
      * for every batch item, the last running score, normalised by
        `max_length**length_penalty`, still exceeds the lowest stored score of its
        finished beams.
    """
    # Score the running hypothesis in the last beam slot would get at full length.
    score_ceiling = state.running_scores[:, -1:] / (max_length**length_penalty)

    # Lowest stored score per batch item; slots that are not finished yet are
    # replaced by a large negative sentinel so they never block the search.
    lowest_stored = jnp.min(state.scores, axis=1, keepdims=True)
    score_floor = jnp.where(state.is_sent_finished, lowest_stored, np.array(-1.0e7))
    can_still_improve = jnp.all(score_floor < score_ceiling)

    everything_finished = jnp.all(state.is_sent_finished) & early_stopping
    within_length_budget = state.cur_len < max_length

    return within_length_budget & ~everything_finished & can_still_improve
1108
+
1109
def beam_search_body_fn(state, input_ids_length=1):
    """Beam search state update fn: run one decoding step and return the next state.

    Args:
        state: current `BeamSearchState`.
        input_ids_length: number of trailing tokens of `state.running_sequences` fed to
            the model in this step (1 inside the loop; the prompt length for the
            first call made outside of it).

    Returns:
        A new `BeamSearchState` with `cur_len` advanced by one.
    """
    # 1. Forward current tokens
    # Collect the current position slice along length to feed the fast
    # autoregressive decoder model. Flatten the beam dimension into the batch
    # dimension for feeding into the model.
    input_token = flatten_beam_dim(
        lax.dynamic_slice(
            state.running_sequences,
            (0, 0, state.cur_len - input_ids_length),
            (batch_size, num_beams, input_ids_length),
        )
    )
    model_outputs = model(input_token, params=params, **state.model_kwargs)

    # Unflatten the beam dimension of the logits and of the attention cache arrays.
    # (`jax.tree_util.tree_map` replaces the removed `jax.tree_map` alias.)
    logits = unflatten_beam_dim(model_outputs.logits[:, -1], batch_size, num_beams)
    cache = jax.tree_util.tree_map(
        lambda tensor: unflatten_beam_dim(tensor, batch_size, num_beams), model_outputs.past_key_values
    )

    # adapt logits for FlaxMarianMTModel
    logits = self._adapt_logits_for_beam_search(logits)

    # 2. Compute log probs
    # get log probabilities from logits,
    # process logits with processors (*e.g.* min_length, ...), and
    # add new logprobs to existing running logprobs scores.
    log_probs = jax.nn.log_softmax(logits)
    # The processors must see the tokens generated so far, i.e. the loop-carried
    # `state.running_sequences`, not the closed-over initial buffer which only
    # ever contains the prompt.
    log_probs = logits_processor(
        flatten_beam_dim(state.running_sequences), flatten_beam_dim(log_probs), state.cur_len
    )
    log_probs = unflatten_beam_dim(log_probs, batch_size, num_beams)
    log_probs = log_probs + jnp.expand_dims(state.running_scores, axis=2)
    vocab_size = log_probs.shape[2]
    log_probs = log_probs.reshape((batch_size, num_beams * vocab_size))

    # 3. Retrieve top-K
    # Each item in batch has num_beams * vocab_size candidate sequences.
    # For each item, get the top 2*k candidates with the highest log-
    # probabilities. We gather the top 2*K beams here so that even if the best
    # K sequences reach EOS simultaneously, we have another K sequences
    # remaining to continue the live beam search.
    # Recover the beam index by floor division and the token id by modulo
    # division, then write the new token into the gathered sequences.
    beams_to_keep = 2 * num_beams
    topk_log_probs, topk_indices = lax.top_k(log_probs, k=beams_to_keep)
    topk_beam_indices = topk_indices // vocab_size
    topk_running_sequences = gather_beams(
        state.running_sequences, topk_beam_indices, batch_size, beams_to_keep
    )
    topk_ids = jnp.expand_dims(topk_indices % vocab_size, axis=2)
    topk_sequences = lax.dynamic_update_slice(topk_running_sequences, topk_ids, (0, 0, state.cur_len))

    # 4. Check which sequences have ended
    # Did any of these sequences reach an end marker?
    # To prevent these just finished sequences from being added to the current
    # set of active beam search sequences, set their log probs to a very large
    # negative value.
    did_topk_just_finished = topk_sequences[:, :, state.cur_len] == eos_token_id
    running_topk_log_probs = topk_log_probs + did_topk_just_finished * np.array(-1.0e7)

    # 5. Get running sequences scores for next
    # Determine the top k beam indices (from top 2*k beams) from log probs
    # and gather top k beams (from top 2*k beams). The flip puts the best
    # candidate in the last beam slot.
    next_topk_indices = jnp.flip(lax.top_k(running_topk_log_probs, k=num_beams)[1], axis=1)
    next_running_sequences, next_running_scores = gather_beams(
        [topk_sequences, running_topk_log_probs], next_topk_indices, batch_size, num_beams
    )

    # 6. Process topk logits
    # Further process log probs:
    # - add length penalty
    # - make sure no scores can be added anymore if beam is full
    # - make sure still running sequences cannot be chosen as finalized beam
    topk_log_probs = topk_log_probs / (state.cur_len**length_penalty)
    beams_in_batch_are_full = (
        jnp.broadcast_to(state.is_sent_finished.all(axis=-1, keepdims=True), did_topk_just_finished.shape)
        & early_stopping
    )
    add_penalty = ~did_topk_just_finished | beams_in_batch_are_full
    topk_log_probs += add_penalty * np.array(-1.0e7)

    # 7. Get scores, sequences, is sentence finished for next.
    # Combine sequences, scores, and flags along the beam dimension and compare
    # new finished sequence scores to existing finished scores and select the
    # best from the new set of beams
    merged_sequences = jnp.concatenate([state.sequences, topk_sequences], axis=1)
    merged_scores = jnp.concatenate([state.scores, topk_log_probs], axis=1)
    merged_is_sent_finished = jnp.concatenate([state.is_sent_finished, did_topk_just_finished], axis=1)
    topk_merged_indices = jnp.flip(lax.top_k(merged_scores, k=num_beams)[1], axis=1)
    next_sequences, next_scores, next_is_sent_finished = gather_beams(
        [merged_sequences, merged_scores, merged_is_sent_finished], topk_merged_indices, batch_size, num_beams
    )

    # 8. Update model kwargs.
    # Determine the top k beam indices from the original set of all beams.
    # With these, gather the top k beam-associated caches and flatten their
    # beam dimension again for the next model call.
    next_running_indices = gather_beams(topk_beam_indices, next_topk_indices, batch_size, num_beams)
    next_cache = gather_beams(cache, next_running_indices, batch_size, num_beams)
    model_outputs["past_key_values"] = jax.tree_util.tree_map(flatten_beam_dim, next_cache)
    next_model_kwargs = self.update_inputs_for_generation(model_outputs, state.model_kwargs)

    return BeamSearchState(
        cur_len=state.cur_len + 1,
        running_scores=next_running_scores,
        running_sequences=next_running_sequences,
        scores=next_scores,
        sequences=next_sequences,
        is_sent_finished=next_is_sent_finished,
        model_kwargs=next_model_kwargs,
    )
1225
+
1226
+ # The very first prompt often has sequence length > 1, so run outside of `lax.while_loop` to comply with TPU
1227
+ if input_ids.shape[-1] > 1:
1228
+ state = partial(beam_search_body_fn, input_ids_length=input_ids.shape[-1])(state)
1229
+
1230
+ if not trace:
1231
+ state = self._run_loop_in_debug(beam_search_cond_fn, beam_search_body_fn, state)
1232
+ else:
1233
+ state = lax.while_loop(beam_search_cond_fn, beam_search_body_fn, state)
1234
+
1235
+ # Account for the edge-case where there are no finished sequences for a
1236
+ # particular batch item. If so, return running sequences for that batch item.
1237
+ none_finished = jnp.any(state.is_sent_finished, axis=1)
1238
+ sequences = jnp.where(none_finished[:, None, None], state.sequences, state.running_sequences)
1239
+ scores = jnp.where(none_finished[:, None], state.scores, state.running_scores)
1240
+
1241
+ # return all beams for each batch and the best score
1242
+ sequences = sequences[:, :]
1243
+ scores = scores[:, -1]
1244
+
1245
+ return FlaxBeamSearchOutput(sequences=sequences, scores=scores)
models/modeling_flax_wav2vec2.py ADDED
@@ -0,0 +1,975 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2021 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Flax Wav2Vec2 model."""
16
+
17
+ from functools import partial
18
+ from typing import Optional, Tuple, Union
19
+
20
+ import flax
21
+ import flax.linen as nn
22
+ import jax
23
+ import jax.numpy as jnp
24
+ from flax.core.frozen_dict import FrozenDict
25
+ from flax.linen import partitioning as nn_partitioning
26
+ from flax.linen.attention import dot_product_attention_weights
27
+ from jax import lax
28
+
29
+ from transformers.modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput
30
+ from transformers.modeling_flax_utils import ACT2FN, FlaxPreTrainedModel
31
+ from transformers.utils import ModelOutput
32
+
33
+ from models import Wav2Vec2Config
34
+
35
# Short module-level aliases for the flax partitioning helpers: `remat` wraps layer
# classes when `config.gradient_checkpointing` is set, `scan_with_axes` runs the
# encoder layers as a scan when `config.use_scan` is set.
scan_with_axes = nn_partitioning.scan_with_axes
remat = nn_partitioning.remat
37
+
38
+
39
@flax.struct.dataclass
class FlaxWav2Vec2BaseModelOutput(ModelOutput):
    """
    Output container for Flax Wav2Vec2 base models, with potential hidden states and attentions.

    Args:
        last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        extract_features (`jnp.ndarray` of shape `(batch_size, sequence_length, last_conv_dim)`):
            Sequence of extracted feature vectors of the last convolutional layer of the model with `last_conv_dim`
            being the dimension of the last convolutional layer.
        hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the
            self-attention heads.
    """

    # All fields default to None so that an output can be built with any subset populated.
    last_hidden_state: jnp.ndarray = None
    extract_features: jnp.ndarray = None
    hidden_states: Optional[Tuple[jnp.ndarray]] = None
    attentions: Optional[Tuple[jnp.ndarray]] = None
67
+
68
+
69
+ WAV_2_VEC_2_START_DOCSTRING = r"""
70
+ Wav2Vec2 was proposed in [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech
71
+ Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael
72
+ Auli.
73
+
74
+ This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
75
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
76
+ etc.)
77
+
78
+ This model is also a Flax Linen
79
+ [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
80
+ regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior.
81
+
82
+ Finally, this model supports inherent JAX features such as:
83
+
84
+ - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
85
+ - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
86
+ - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
87
+ - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
88
+
89
+ Parameters:
90
+ config ([`Wav2Vec2Config`]): Model configuration class with all the parameters of the model.
91
+ Initializing with a config file does not load the weights associated with the model, only the
92
+ configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
93
+ dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
94
+ The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
95
+ `jax.numpy.bfloat16` (on TPUs).
96
+
97
+ This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
98
+ specified all the computation will be performed with the given `dtype`.
99
+
100
+ **Note that this only specifies the dtype of the computation and does not influence the dtype of model
101
+ parameters.**
102
+
103
+ If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
104
+ [`~FlaxPreTrainedModel.to_bf16`].
105
+ """
106
+
107
+
108
+ WAV_2_VEC_2_INPUTS_DOCSTRING = r"""
109
+ Args:
110
+ input_values (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
111
+ Float values of input raw speech waveform. Values can be obtained by loading a *.flac* or *.wav* audio file
112
+ into an array of type *List[float]* or a *numpy.ndarray*, *e.g.* via the soundfile library (*pip install
113
+ soundfile*). To prepare the array into *input_values*, the [`Wav2Vec2Processor`] should be used for padding
114
+ and conversion into a tensor of type *jnp.ndarray*. See [`Wav2Vec2Processor.__call__`] for details.
115
+ attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
116
+ Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0,
117
+ 1]`:
118
+
119
+ - 1 for tokens that are **not masked**,
120
+ - 0 for tokens that are **masked**.
121
+
122
+ [What are attention masks?](../glossary#attention-mask) .. warning:: `attention_mask` should only be passed
123
+ if the corresponding processor has `config.return_attention_mask == True`. For all models whose processor
124
+ has `config.return_attention_mask == False`, such as
125
+ [wav2vec2-base](https://huggingface.co/facebook/wav2vec2-base-960h), `attention_mask` should **not** be
126
+ passed to avoid degraded performance when doing batched inference. For such models `input_values` should
127
+ simply be padded with 0 and passed without `attention_mask`. Be aware that these models also yield slightly
128
+ different results depending on whether `input_values` is padded or not.
129
+ mask_time_indices (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
130
+ Indices to mask extracted features for contrastive loss. When in training mode, model learns to predict
131
+ masked extracted features in *config.proj_codevector_dim* space.
132
+ output_attentions (`bool`, *optional*):
133
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
134
+ tensors for more detail.
135
+ output_hidden_states (`bool`, *optional*):
136
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
137
+ more detail.
138
+ return_dict (`bool`, *optional*):
139
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
140
+ """
141
+
142
+
143
class FlaxWav2Vec2LayerNormConvLayer(nn.Module):
    """One feature-extractor stage: 1D convolution -> layer norm -> activation.

    Attributes:
        config: model configuration; `conv_dim`, `conv_kernel`, `conv_stride`, `conv_bias`,
            `layer_norm_eps` and `feat_extract_activation` are read from it.
        layer_id: index of this stage, used to pick the per-layer entries of the
            `conv_*` configuration lists.
        dtype: dtype of the computation.
    """

    config: Wav2Vec2Config
    layer_id: int = 0
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # The input width of stage i is the output width of stage i - 1; the first
        # stage sees a single channel. (`nn.Conv` infers its input width, so
        # this attribute is informational only.)
        self.in_conv_dim = self.config.conv_dim[self.layer_id - 1] if self.layer_id > 0 else 1
        self.out_conv_dim = self.config.conv_dim[self.layer_id]

        self.conv = nn.Conv(
            features=self.config.conv_dim[self.layer_id],
            kernel_size=(self.config.conv_kernel[self.layer_id],),
            strides=(self.config.conv_stride[self.layer_id],),
            use_bias=self.config.conv_bias,
            kernel_init=jax.nn.initializers.he_normal(),
            padding="VALID",
            dtype=self.dtype,
        )
        self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        self.activation = ACT2FN[self.config.feat_extract_activation]

    def __call__(self, hidden_states):
        """Apply convolution, layer norm and activation, in that order."""
        hidden_states = self.conv(hidden_states)
        hidden_states = self.layer_norm(hidden_states)
        hidden_states = self.activation(hidden_states)
        return hidden_states
169
+
170
+
171
class FlaxConvWithWeightNorm(nn.Module):
    """Grouped 1D convolution whose kernel is re-parameterised as direction (`weight_v`) times gain (`weight_g`).

    The kernel actually applied is `weight_v / ||weight_v|| * weight_g`, where the norm
    is taken over the first two axes of `weight_v`, separately for each kernel position.
    """

    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        out_features = self.config.hidden_size
        groups = self.config.num_conv_pos_embedding_groups
        kernel_width = self.config.num_conv_pos_embeddings

        # This conv is only used as a stateless applier; its kernel and bias are
        # supplied explicitly in `__call__`.
        self.conv = nn.Conv(
            features=out_features,
            kernel_size=(kernel_width,),
            kernel_init=jax.nn.initializers.he_normal(),
            padding="VALID",
            feature_group_count=groups,
            dtype=self.dtype,
        )

        self.weight_v = self.param(
            "weight_v",
            jax.nn.initializers.he_normal(),
            (out_features, out_features // groups, kernel_width),
        )
        # The gain starts out as the norm of the direction, so the initial effective
        # kernel equals `weight_v`.
        self.weight_g = self.param("weight_g", lambda _: jnp.linalg.norm(self.weight_v, axis=(0, 1))[None, None, :])
        self.bias = self.param("bias", jax.nn.initializers.zeros, (out_features,))
        # symmetric padding applied by hand on the time axis before the VALID convolution
        self.prev_padding = kernel_width // 2

    def _get_normed_weights(self):
        """Return the effective kernel: unit-norm direction scaled by the learned gain."""
        direction_norm = jnp.linalg.norm(self.weight_v, axis=(0, 1))[None, None, :]
        unit_direction = jnp.divide(self.weight_v, direction_norm)
        return jnp.multiply(unit_direction, self.weight_g)

    def __call__(self, hidden_states):
        """Pad the second axis on both sides and convolve with the weight-normalised kernel."""
        pad = self.prev_padding
        padded = jnp.pad(hidden_states, ((0, 0), (pad, pad), (0, 0)))
        conv_params = {"kernel": self._get_normed_weights().T, "bias": self.bias}
        return self.conv.apply({"params": conv_params}, padded)
205
+
206
+
207
class FlaxWav2Vec2PositionalConvEmbedding(nn.Module):
    """Convolutional positional embedding: weight-normalised conv, optional trim, activation."""

    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.conv = FlaxConvWithWeightNorm(self.config, dtype=self.dtype)
        self.activation = ACT2FN[self.config.feat_extract_activation]
        # an even kernel width makes the padded convolution one step too long
        self.num_pad_remove = 1 if self.config.num_conv_pos_embeddings % 2 == 0 else 0

    def __call__(self, hidden_states):
        """Return the activated convolution of `hidden_states`.

        The previous implementation wrapped this in two `transpose((0, 1, 2))` calls.
        That permutation is the identity, so they were no-ops and have been dropped;
        the data stays in its original axis order throughout.
        """
        hidden_states = self.conv(hidden_states)

        if self.num_pad_remove > 0:
            # drop the surplus trailing step(s) along the second axis
            hidden_states = hidden_states[:, : -self.num_pad_remove, :]
        hidden_states = self.activation(hidden_states)

        return hidden_states
227
+
228
+
229
class FlaxConvLayersCollection(nn.Module):
    """Stack of feature-extractor convolution layers applied one after another.

    Raises (from `setup`):
        NotImplementedError: if `config.feat_extract_norm == "group"`.
        ValueError: if `config.feat_extract_norm` is neither `"group"` nor `"layer"`.
    """

    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        if self.config.feat_extract_norm == "layer":
            # note that we can't use scan on the conv layers as they differ on a layer-by-layer basis
            BlockLayer = (
                remat(FlaxWav2Vec2LayerNormConvLayer)
                if self.config.gradient_checkpointing
                else FlaxWav2Vec2LayerNormConvLayer
            )
            self.layers = [
                BlockLayer(self.config, layer_id=i, name=str(i), dtype=self.dtype)
                for i in range(self.config.num_feat_extract_layers)
            ]
        elif self.config.feat_extract_norm == "group":
            # message now names the real config attribute (was misspelled `feat_extact_norm`)
            raise NotImplementedError("At the moment only ``config.feat_extract_norm == 'layer'`` is supported")
        else:
            raise ValueError(
                f"`config.feat_extract_norm` is {self.config.feat_extract_norm}, but has to be one of ['group', 'layer']"
            )

    def __call__(self, hidden_states):
        """Feed `hidden_states` through every layer in order and return the result."""
        for conv_layer in self.layers:
            hidden_states = conv_layer(hidden_states)
        return hidden_states
252
+
253
+
254
class FlaxWav2Vec2FeatureEncoder(nn.Module):
    """Turns a raw audio waveform into a sequence of convolutional features."""

    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.conv_layers = FlaxConvLayersCollection(self.config, dtype=self.dtype)

    def __call__(self, input_values, freeze_feature_encoder=False):
        """Run the convolution stack over `input_values`.

        Args:
            input_values: waveform array; a trailing axis of size one is appended
                before the convolutions.
            freeze_feature_encoder: when truthy, gradients are stopped at the output,
                so the convolution stack receives no gradient.

        Returns:
            The extracted features.
        """
        features = self.conv_layers(input_values[:, :, None])
        if not freeze_feature_encoder:
            return features
        return jax.lax.stop_gradient(features)
269
+
270
+
271
class FlaxWav2Vec2FeatureProjection(nn.Module):
    """Layer-normalises the extracted features and projects them to `config.hidden_size`."""

    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        cfg = self.config
        self.layer_norm = nn.LayerNorm(epsilon=cfg.layer_norm_eps, dtype=self.dtype)
        self.projection = nn.Dense(
            cfg.hidden_size,
            kernel_init=jax.nn.initializers.normal(cfg.initializer_range),
            dtype=self.dtype,
        )
        self.dropout = nn.Dropout(rate=cfg.feat_proj_dropout)

    def __call__(self, hidden_states, deterministic=True):
        """Return `(projected, normalised)`.

        `normalised` is the layer-normed input; `projected` is its dense projection
        followed by dropout (dropout is inactive when `deterministic` is True).
        """
        normalised = self.layer_norm(hidden_states)
        projected = self.dropout(self.projection(normalised), deterministic=deterministic)
        return projected, normalised
289
+
290
+
291
class FlaxWav2Vec2Attention(nn.Module):
    """Multi-head self-attention with optional fused query/key/value projection.

    Attributes:
        config: model configuration; `initializer_range` and `fuse_matmuls` are read from it.
        embed_dim: model width; must be divisible by `num_heads`.
        num_heads: number of attention heads.
        dropout: dropout rate applied to the attention weights.
        bias: whether the dense projections use a bias term.
        dtype: dtype of the computation.
    """

    config: Wav2Vec2Config
    embed_dim: int
    num_heads: int
    dropout: float = 0.0
    bias: bool = True
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self) -> None:
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
            )

        dense = partial(
            nn.Dense,
            self.embed_dim,
            use_bias=self.bias,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )

        # Both the separate and the fused projections are declared; `__call__` picks
        # one of them depending on `config.fuse_matmuls`.
        self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()

        self.fused_proj = nn.Dense(
            self.embed_dim * 3,
            use_bias=self.bias,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )

        self.out_proj = dense()

        self.dropout_layer = nn.Dropout(rate=self.dropout)

    def _split_heads(self, hidden_states):
        """Reshape `(..., embed_dim)` into `(..., num_heads, head_dim)`, keeping the first two axes."""
        return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))

    def _merge_heads(self, hidden_states):
        """Inverse of `_split_heads`: collapse the head axes back into `embed_dim`."""
        return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))

    def __call__(
        self,
        hidden_states: jnp.ndarray,
        key_value_states: Optional[jnp.ndarray] = None,
        attention_mask: Optional[jnp.ndarray] = None,
        deterministic: bool = True,
    ) -> Tuple[jnp.ndarray]:
        """Input shape: Batch x Time x Channel

        Args:
            hidden_states: input sequence; queries, keys and values are all projected from it.
            key_value_states: accepted for signature compatibility; not used by this implementation.
            attention_mask: optional mask; positions with a value `> 0` are attended to,
                all others are suppressed.
            deterministic: when False and `dropout > 0`, dropout is applied to the attention weights.

        Returns:
            `(attn_output, attn_weights)`.
        """

        if self.config.fuse_matmuls:
            # one matmul producing q, k and v side by side on the last axis
            attention_states = self.fused_proj(hidden_states)
            query_states, key_states, value_states = jnp.split(attention_states, 3, axis=-1)

        else:
            query_states = self.q_proj(hidden_states)
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)

        query_states = self._split_heads(query_states)
        key_states = self._split_heads(key_states)
        value_states = self._split_heads(value_states)

        if attention_mask is not None:
            # insert two singleton axes so the mask broadcasts over heads and query positions
            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))

            # Convert the mask into an additive attention bias. The most negative
            # finite value of the dtype is used instead of -inf: a row in which every
            # key is masked would otherwise softmax to NaN, while unmasked rows are
            # unaffected because the masked weights still underflow to zero.
            attention_bias = lax.select(
                attention_mask > 0,
                jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
            )
        else:
            attention_bias = None

        dropout_rng = None
        if not deterministic and self.dropout > 0.0:
            dropout_rng = self.make_rng("dropout")

        attn_weights = dot_product_attention_weights(
            query_states,
            key_states,
            bias=attention_bias,
            dropout_rng=dropout_rng,
            dropout_rate=self.dropout,
            broadcast_dropout=True,
            deterministic=deterministic,
            dtype=self.dtype,
            precision=None,
        )

        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
        attn_output = self._merge_heads(attn_output)
        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights
392
+
393
+
394
class FlaxWav2Vec2FeedForward(nn.Module):
    """Position-wise feed-forward block: dense -> activation -> dropout -> dense -> dropout."""

    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        cfg = self.config

        self.intermediate_dropout = nn.Dropout(rate=cfg.activation_dropout)
        self.intermediate_dense = nn.Dense(
            cfg.intermediate_size,
            kernel_init=jax.nn.initializers.normal(cfg.initializer_range),
            dtype=self.dtype,
        )

        # `hidden_act` may be given either as a name to look up or as the callable itself
        activation = cfg.hidden_act
        self.intermediate_act_fn = ACT2FN[activation] if isinstance(activation, str) else activation

        self.output_dense = nn.Dense(
            cfg.hidden_size,
            kernel_init=jax.nn.initializers.normal(cfg.initializer_range),
            dtype=self.dtype,
        )
        self.output_dropout = nn.Dropout(rate=cfg.hidden_dropout)

    def __call__(self, hidden_states, deterministic=True):
        """Apply the feed-forward block; both dropouts are inactive when `deterministic` is True."""
        expanded = self.intermediate_act_fn(self.intermediate_dense(hidden_states))
        expanded = self.intermediate_dropout(expanded, deterministic=deterministic)

        contracted = self.output_dense(expanded)
        return self.output_dropout(contracted, deterministic=deterministic)
426
+
427
+
428
class FlaxWav2Vec2EncoderLayerStableLayerNorm(nn.Module):
    """Pre-layer-norm transformer encoder layer: self-attention and feed-forward, each with a residual."""

    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        cfg = self.config
        self.attention = FlaxWav2Vec2Attention(
            config=cfg,
            embed_dim=cfg.hidden_size,
            num_heads=cfg.num_attention_heads,
            dropout=cfg.attention_dropout,
            dtype=self.dtype,
        )
        self.dropout = nn.Dropout(rate=cfg.hidden_dropout)
        self.layer_norm = nn.LayerNorm(epsilon=cfg.layer_norm_eps, dtype=self.dtype)
        self.feed_forward = FlaxWav2Vec2FeedForward(cfg, dtype=self.dtype)
        self.final_layer_norm = nn.LayerNorm(epsilon=cfg.layer_norm_eps, dtype=self.dtype)

    def __call__(self, hidden_states, attention_mask=None, deterministic=True, output_attentions=False):
        """Run one encoder layer.

        When `config.use_scan` is set, `hidden_states` arrives as a tuple whose first
        element is the actual array, and the result is returned as `(outputs, None)`.

        Returns:
            `(hidden_states,)`, extended with the attention weights when
            `output_attentions` is True (wrapped as described above under scan).
        """
        scanned = self.config.use_scan
        if scanned:
            # unpack the carry tuple
            hidden_states = hidden_states[0]

        # self-attention sub-block: norm -> attention -> dropout -> residual add
        residual = hidden_states
        attended, attn_weights = self.attention(
            self.layer_norm(hidden_states), attention_mask=attention_mask, deterministic=deterministic
        )
        attended = self.dropout(attended, deterministic=deterministic)
        hidden_states = residual + attended

        # feed-forward sub-block: norm -> feed-forward -> residual add
        hidden_states = hidden_states + self.feed_forward(
            self.final_layer_norm(hidden_states), deterministic=deterministic
        )

        outputs = (hidden_states, attn_weights) if output_attentions else (hidden_states,)
        return (outputs, None) if scanned else outputs
468
+
469
+
470
class FlaxWav2Vec2EncoderLayerStableLayerNormCollection(nn.Module):
    """
    Stack of `config.num_hidden_layers` encoder layers.

    The layers are either unrolled in a Python loop (each named by its index) or, when `config.use_scan` is set, run
    through a single scanned layer named "FlaxWav2Vec2EncoderLayers". If `config.gradient_checkpointing` is set the
    layer class is wrapped in `remat` first.
    """

    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(
        self,
        hidden_states,
        attention_mask=None,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        all_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None
        num_layers = self.config.num_hidden_layers

        # optionally wrap the layer class for gradient checkpointing
        if self.config.gradient_checkpointing:
            layer_cls = remat(
                FlaxWav2Vec2EncoderLayerStableLayerNorm,
                static_argnums=(2, 3),
                prevent_cse=not self.config.use_scan,
            )
        else:
            layer_cls = FlaxWav2Vec2EncoderLayerStableLayerNorm

        if self.config.use_scan:
            # all encoder layers share the same structure, so they can be scanned over;
            # intermediate outputs are not collected in this mode
            assert not output_attentions, "cannot use `scan` with `output_attentions` set to `True`"
            assert not output_hidden_states, "cannot use `scan` with `output_hidden_states` set to `True`"

            scanned_layers = scan_with_axes(
                layer_cls,
                variable_axes={"params": 0, "cache": 0},
                split_rngs={"params": True, "dropout": True},
                in_axes=(nn.broadcast, nn.broadcast, nn.broadcast),
                length=num_layers,
            )(self.config, dtype=self.dtype, name="FlaxWav2Vec2EncoderLayers")

            # the carry is a 1-tuple wrapping the hidden states
            carry, _ = scanned_layers((hidden_states,), attention_mask, deterministic, output_attentions)
            hidden_states = carry[0]
        else:
            for layer_idx in range(num_layers):
                if output_hidden_states:
                    # record the input of every layer
                    all_hidden_states += (hidden_states,)

                encoder_layer = layer_cls(self.config, dtype=self.dtype, name=str(layer_idx))
                layer_outputs = encoder_layer(hidden_states, attention_mask, deterministic, output_attentions)
                hidden_states = layer_outputs[0]

                if output_attentions:
                    all_attentions += (layer_outputs[1],)

        if output_hidden_states:
            # record the output of the last layer
            all_hidden_states += (hidden_states,)

        if return_dict:
            return FlaxBaseModelOutput(
                last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
            )

        outputs = (hidden_states, all_hidden_states, all_attentions)
        return tuple(v for v in outputs if v is not None)
542
+
543
+
544
class FlaxWav2Vec2StableLayerNormEncoder(nn.Module):
    """
    Encoder that adds convolutional positional embeddings to its input, applies dropout, runs the stack of encoder
    layers and applies a final layer norm to the last hidden state.

    Args (of `__call__`):
        hidden_states: input features to encode.
        attention_mask: optional mask; positions where it is falsy are zeroed before encoding.
        deterministic: when `False`, dropout is active both here and inside the encoder layers.
        output_attentions / output_hidden_states: forwarded to the layer stack to collect intermediate outputs.
        return_dict: return a `FlaxBaseModelOutput` instead of a plain tuple.

    Returns:
        A `FlaxBaseModelOutput`, or a tuple of its non-`None` fields when `return_dict` is falsy.
    """

    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.pos_conv_embed = FlaxWav2Vec2PositionalConvEmbedding(self.config, dtype=self.dtype)
        self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout)
        self.layers = FlaxWav2Vec2EncoderLayerStableLayerNormCollection(self.config, dtype=self.dtype)

    def __call__(
        self,
        hidden_states,
        attention_mask=None,
        deterministic=True,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    ):

        if attention_mask is not None:
            # zero the hidden states at every position the mask marks as padding
            hidden_states = jnp.where(
                jnp.broadcast_to(attention_mask[:, :, None], hidden_states.shape), hidden_states, 0
            )

        position_embeddings = self.pos_conv_embed(hidden_states)

        hidden_states = hidden_states + position_embeddings
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)

        # `deterministic` must be forwarded explicitly: the layer stack defaults to
        # `deterministic=True`, which would silently disable dropout in every encoder layer
        outputs = self.layers(
            hidden_states,
            attention_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = self.layer_norm(outputs[0])

        # replace the last element in `hidden_states` with the layer-normed version computed above
        hidden_states = None
        if output_hidden_states:
            hidden_states = outputs[1]
            hidden_states = hidden_states[:-1] + (last_hidden_state,)

        if not return_dict:
            # skip the entries of `outputs` that were re-built above
            outputs = (last_hidden_state, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:])
            return tuple(v for v in outputs if v is not None)

        return FlaxBaseModelOutput(
            last_hidden_state=last_hidden_state, hidden_states=hidden_states, attentions=outputs.attentions
        )
598
+
599
+
600
class FlaxWav2Vec2Adapter(nn.Module):
    """
    Adapter applied on top of the encoder output: an optional dense projection + layer norm (only created when
    `config.output_hidden_size` differs from `config.hidden_size`) followed by the adapter layer stack.
    """

    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        needs_projection = self.config.output_hidden_size != self.config.hidden_size

        if needs_projection:
            # feature dims differ, so the hidden states are projected to `output_hidden_size` first
            self.proj = nn.Dense(
                self.config.output_hidden_size,
                kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
                dtype=self.dtype,
            )
            self.proj_layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        else:
            self.proj = None
            self.proj_layer_norm = None

        self.layers = FlaxWav2Vec2AdapterLayersCollection(self.config, dtype=self.dtype)

    def __call__(self, hidden_states, deterministic=True):
        # `deterministic` is accepted for interface compatibility but is not used here
        has_projection = not (self.proj is None or self.proj_layer_norm is None)
        if has_projection:
            projected = self.proj(hidden_states)
            hidden_states = self.proj_layer_norm(projected)

        return self.layers(hidden_states)
627
+
628
+
629
class FlaxWav2Vec2AdapterLayer(nn.Module):
    """
    Single adapter layer: a strided 1D convolution producing `2 * config.output_hidden_size` features, followed by
    a gated linear unit over axis 2.
    """

    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        cfg = self.config
        self.conv = nn.Conv(
            features=2 * cfg.output_hidden_size,
            kernel_size=(cfg.adapter_kernel_size,),
            strides=(cfg.adapter_stride,),
            padding=((1, 1),),
            kernel_init=jax.nn.initializers.normal(cfg.initializer_range),
            dtype=self.dtype,
        )

    def __call__(self, hidden_states):
        conv_output = self.conv(hidden_states)
        # GLU over axis 2 of the convolution output
        return nn.glu(conv_output, axis=2)
648
+
649
+
650
class FlaxWav2Vec2AdapterLayersCollection(nn.Module):
    """
    Sequential stack of `config.num_adapter_layers` adapter layers, each named by its index. The layer class is
    wrapped in `remat` when `config.gradient_checkpointing` is set.
    """

    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        if self.config.gradient_checkpointing:
            layer_cls = remat(FlaxWav2Vec2AdapterLayer)
        else:
            layer_cls = FlaxWav2Vec2AdapterLayer

        adapter_layers = []
        for idx in range(self.config.num_adapter_layers):
            adapter_layers.append(layer_cls(self.config, name=str(idx), dtype=self.dtype))
        self.layers = adapter_layers

    def __call__(self, hidden_states):
        # apply the adapter layers one after another
        for adapter_layer in self.layers:
            hidden_states = adapter_layer(hidden_states)

        return hidden_states
666
+
667
+
668
class FlaxWav2Vec2PreTrainedModel(FlaxPreTrainedModel):
    """
    Abstract base class that wires a Flax module (`module_class`) into the pretrained-model interface: it builds the
    module, initialises its weights and exposes a `__call__` that applies the module to the inputs.
    """

    config_class = Wav2Vec2Config
    base_model_prefix: str = "wav2vec2"
    main_input_name = "input_values"
    module_class: nn.Module = None

    def __init__(
        self,
        config: Wav2Vec2Config,
        input_shape: Tuple = (1, 1024),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        **kwargs,
    ):
        # extra keyword arguments are forwarded to the module constructor
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
        """Initialise the module on all-zero dummy inputs and return its "params" collection."""
        dummy_inputs = jnp.zeros(input_shape, dtype="i4")
        dummy_mask = jnp.ones_like(dummy_inputs)

        params_rng, dropout_rng = jax.random.split(rng, 2)
        init_rngs = {"params": params_rng, "dropout": dropout_rng}

        variables = self.module.init(init_rngs, dummy_inputs, dummy_mask, return_dict=False)
        return variables["params"]

    def __call__(
        self,
        input_values,
        attention_mask=None,
        mask_time_indices=None,
        extract_features=None,
        params: dict = None,
        dropout_rng: jax.random.PRNGKey = None,
        train: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_features: Optional[bool] = None,
        freeze_feature_encoder: bool = False,
        return_dict: Optional[bool] = None,
    ):
        """
        Apply the wrapped module.

        Output flags left as `None` fall back to the values stored on `self.config`. A missing `attention_mask` is
        replaced by an all-ones mask matching `input_values`. `params` defaults to `self.params`, and `train`
        is passed to the module inverted, as its `deterministic` argument.
        """
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.return_dict

        if attention_mask is None:
            batch_size, sequence_length = input_values.shape
            attention_mask = jnp.ones((batch_size, sequence_length))

        if extract_features is not None:
            extract_features = jnp.array(extract_features, dtype="f4")

        # only supply a dropout PRNG stream when one was given
        rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}

        variables = {"params": params or self.params}
        deterministic = not train

        return self.module.apply(
            variables,
            jnp.array(input_values, dtype="f4"),
            jnp.array(attention_mask, dtype="i4"),
            mask_time_indices,
            extract_features,
            deterministic,
            output_attentions,
            output_hidden_states,
            output_features,
            freeze_feature_encoder,
            return_dict,
            rngs=rngs,
        )

    def _get_feat_extract_output_lengths(
        self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
    ):
        # thin delegation to the module's implementation
        return self.module._get_feat_extract_output_lengths(input_lengths, add_adapter=add_adapter)

    def _get_feature_vector_attention_mask(
        self, feature_vector_length: int, attention_mask: jnp.ndarray, add_adapter=None
    ):
        # thin delegation to the module's implementation
        return self.module._get_feature_vector_attention_mask(
            feature_vector_length, attention_mask, add_adapter=add_adapter
        )
759
+
760
+
761
class FlaxWav2Vec2Module(nn.Module):
    """
    Base Wav2Vec2 module: feature encoder -> feature projection -> optional time masking -> encoder -> optional
    adapter. Only the stable-layer-norm encoder is available; `setup` raises otherwise.
    """

    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        cfg = self.config
        self.feature_extractor = FlaxWav2Vec2FeatureEncoder(cfg, dtype=self.dtype)
        self.feature_projection = FlaxWav2Vec2FeatureProjection(cfg, dtype=self.dtype)
        self.masked_spec_embed = self.param(
            "masked_spec_embed", jax.nn.initializers.uniform(), (cfg.hidden_size,)
        )

        if not cfg.do_stable_layer_norm:
            raise NotImplementedError("``config.do_stable_layer_norm is False`` is currently not supported.")
        self.encoder = FlaxWav2Vec2StableLayerNormEncoder(cfg, dtype=self.dtype)

        if cfg.add_adapter:
            self.adapter = FlaxWav2Vec2Adapter(cfg, dtype=self.dtype)
        else:
            self.adapter = None

    def __call__(
        self,
        input_values,
        attention_mask=None,
        mask_time_indices=None,
        extract_features=None,
        deterministic=True,
        output_attentions=None,
        output_hidden_states=None,
        output_features=False,
        freeze_feature_encoder=False,
        return_dict=None,
    ):
        # run the feature encoder only when no pre-computed features were supplied
        if extract_features is None:
            extract_features = self.feature_extractor(input_values, freeze_feature_encoder=freeze_feature_encoder)

        # short-circuit: the caller only wants the extracted features
        if output_features:
            return extract_features

        if attention_mask is not None:
            # reduce the input-level mask to the length of the feature vectors
            attention_mask = self._get_feature_vector_attention_mask(
                extract_features.shape[1], attention_mask, add_adapter=False
            )

        hidden_states, extract_features = self.feature_projection(extract_features, deterministic=deterministic)

        if mask_time_indices is not None:
            # replace the hidden states at the given time indices with the learned mask embedding
            target_shape = hidden_states.shape
            hidden_states = jnp.where(
                jnp.broadcast_to(mask_time_indices[:, :, None], target_shape),
                jnp.broadcast_to(self.masked_spec_embed[None, None, :], target_shape),
                hidden_states,
            )

        encoder_outputs = self.encoder(
            hidden_states,
            attention_mask=attention_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = encoder_outputs[0]
        if self.adapter is not None:
            hidden_states = self.adapter(hidden_states)

        if return_dict:
            return FlaxWav2Vec2BaseModelOutput(
                last_hidden_state=hidden_states,
                extract_features=extract_features,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            )

        return (hidden_states, extract_features) + encoder_outputs[1:]

    def _get_feat_extract_output_lengths(
        self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
    ):
        """
        Computes the output length of the convolutional layers (and of the adapter layers when enabled).
        """
        if add_adapter is None:
            add_adapter = self.config.add_adapter

        # un-padded 1D convolution output length: (length - kernel_size) // stride + 1, see
        # https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
        for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
            input_lengths = (input_lengths - kernel_size) // stride + 1

        if add_adapter:
            # the adapter layers are accounted for with an effective kernel size of 1
            for _ in range(self.config.num_adapter_layers):
                input_lengths = (input_lengths - 1) // self.config.adapter_stride + 1

        return input_lengths

    def _get_feature_vector_attention_mask(
        self, feature_vector_length: int, attention_mask: jnp.ndarray, add_adapter=None
    ):
        """Build a boolean mask of width `feature_vector_length` that is true up to each row's output length."""
        # last entry of the cumulative sum == number of non-padded input positions per row
        non_padded_lengths = attention_mask.cumsum(axis=-1)[:, -1]
        output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)

        num_rows = attention_mask.shape[0]
        feature_mask = jnp.zeros((num_rows, feature_vector_length), dtype=attention_mask.dtype)

        # mark the last valid position of each row, then back-fill everything before it
        # via a reversed cumulative sum
        row_indices = jnp.arange(feature_mask.shape[0])
        feature_mask = feature_mask.at[row_indices, output_lengths - 1].set(1)
        feature_mask = jnp.flip(jnp.flip(feature_mask, -1).cumsum(-1), -1).astype("bool")
        return feature_mask
879
+
880
+
881
class FlaxWav2Vec2Model(FlaxWav2Vec2PreTrainedModel):
    # Pretrained-model wrapper around the bare `FlaxWav2Vec2Module`.
    module_class = FlaxWav2Vec2Module
883
+
884
+
885
class FlaxWav2Vec2ForCTCModule(nn.Module):
    """
    Wav2Vec2 base module followed by dropout and a dense head projecting to `config.vocab_size` logits.

    Args (of `__call__`):
        input_values: raw inputs for the base module.
        attention_mask, mask_time_indices, deterministic, output_attentions, output_hidden_states,
        freeze_feature_encoder, return_dict: forwarded unchanged to the base module.
        extract_features: optional pre-computed features; forwarded so that the base module can skip its
            feature encoder.
        output_features: accepted for signature compatibility with the base module but not acted on here.
            NOTE(review): confirm whether callers expect this head to honour it.

    Returns:
        `FlaxCausalLMOutput`, or `(logits,) + outputs[2:]` when `return_dict` is falsy.
    """

    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.wav2vec2 = FlaxWav2Vec2Module(self.config, dtype=self.dtype)
        self.dropout = nn.Dropout(rate=self.config.final_dropout)
        self.lm_head = nn.Dense(
            self.config.vocab_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )

    def __call__(
        self,
        input_values,
        attention_mask=None,
        mask_time_indices=None,
        extract_features=None,
        deterministic=True,
        output_attentions=None,
        output_hidden_states=None,
        output_features=False,
        freeze_feature_encoder=False,
        return_dict=None,
    ):
        # `extract_features` was previously dropped here, so pre-computed features were
        # silently ignored and recomputed from `input_values`
        outputs = self.wav2vec2(
            input_values,
            attention_mask=attention_mask,
            mask_time_indices=mask_time_indices,
            extract_features=extract_features,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            freeze_feature_encoder=freeze_feature_encoder,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)

        logits = self.lm_head(hidden_states)

        if not return_dict:
            # skip the base module's last hidden state and extracted features
            return (logits,) + outputs[2:]

        return FlaxCausalLMOutput(logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)

    def _get_feat_extract_output_lengths(
        self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
    ):
        """
        Computes the output length of the convolutional layers
        """

        add_adapter = self.config.add_adapter if add_adapter is None else add_adapter

        def _conv_out_length(input_length, kernel_size, stride):
            # 1D convolutional layer output length formula taken
            # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
            return (input_length - kernel_size) // stride + 1

        for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
            input_lengths = _conv_out_length(input_lengths, kernel_size, stride)

        if add_adapter:
            # adapter layers are accounted for with an effective kernel size of 1
            for _ in range(self.config.num_adapter_layers):
                input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)

        return input_lengths

    def _get_feature_vector_attention_mask(
        self, feature_vector_length: int, attention_mask: jnp.ndarray, add_adapter=None
    ):
        """Build a boolean mask of width `feature_vector_length` that is true up to each row's output length."""
        # last entry of the cumulative sum == number of non-padded positions per row
        non_padded_lengths = attention_mask.cumsum(axis=-1)[:, -1]

        output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)

        batch_size = attention_mask.shape[0]

        attention_mask = jnp.zeros((batch_size, feature_vector_length), dtype=attention_mask.dtype)
        # these two operations make sure that all values
        # before the output lengths indices are attended to
        attention_mask = attention_mask.at[jnp.arange(attention_mask.shape[0]), output_lengths - 1].set(1)
        attention_mask = jnp.flip(jnp.flip(attention_mask, -1).cumsum(-1), -1).astype("bool")
        return attention_mask
972
+
973
+
974
class FlaxWav2Vec2ForCTC(FlaxWav2Vec2PreTrainedModel):
    # Pretrained-model wrapper around `FlaxWav2Vec2ForCTCModule`.
    module_class = FlaxWav2Vec2ForCTCModule
preprocessor_config.json ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "do_normalize": true,
3
+ "feature_extractor_type": "Wav2Vec2FeatureExtractor",
4
+ "feature_size": 1,
5
+ "padding_side": "right",
6
+ "padding_value": 0.0,
7
+ "return_attention_mask": true,
8
+ "sampling_rate": 16000
9
+ }
run_flax_speech_recognition_seq2seq.py ADDED
@@ -0,0 +1,1490 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """
16
+ Fine-tuning the Flax library models for sequence to sequence speech recognition.
17
+ """
18
+ # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
19
+
20
+ import logging
21
+ import math
22
+ import os
23
+ import sys
24
+ import time
25
+ from dataclasses import dataclass, field
26
+ from pathlib import Path
27
+ from typing import Any, Callable, Dict, List, Optional, Union
28
+
29
+ import datasets
30
+ import numpy as np
31
+ from datasets import DatasetDict, load_dataset, load_metric
32
+ from tqdm import tqdm
33
+
34
+ import flax
35
+ import jax
36
+ import jax.numpy as jnp
37
+ import optax
38
+ import transformers
39
+ import wandb as wandb
40
+ from flax import core, jax_utils, struct, traverse_util
41
+ from flax.jax_utils import pad_shard_unpad, unreplicate
42
+ from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
43
+ from huggingface_hub import Repository
44
+ from models import FlaxSpeechEncoderDecoderModel
45
+ from optax._src import linear_algebra
46
+ from transformers import (
47
+ AutoConfig,
48
+ AutoFeatureExtractor,
49
+ AutoProcessor,
50
+ AutoTokenizer,
51
+ HfArgumentParser,
52
+ Seq2SeqTrainingArguments,
53
+ is_tensorboard_available,
54
+ )
55
+ from transformers.file_utils import get_full_repo_name
56
+ from transformers.trainer_utils import get_last_checkpoint
57
+ from transformers.utils import check_min_version
58
+ from transformers.utils.versions import require_version
59
+
60
+
61
+ # Will error if the minimal version of Transformers is not installed. Remove at your own risk.
62
+ check_min_version("4.17.0.dev0")
63
+
64
+ require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")
65
+
66
+ logger = logging.getLogger(__name__)
67
+
68
+
69
@flax.struct.dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """

    # --- checkpoint / hub locations ---
    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    config_name: Optional[str] = field(
        default=None,
        metadata={"help": "Pretrained config name or path if not the same as model_name"},
    )
    tokenizer_name: Optional[str] = field(
        default=None,
        metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"},
    )
    feature_extractor_name: Optional[str] = field(
        default=None,
        metadata={"help": "feature extractor name or path if not the same as model_name"},
    )
    cache_dir: Optional[str] = field(
        default=None, metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"}
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    use_auth_token: bool = field(
        default=False,
        metadata={
            "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
            "with private models)."
        },
    )

    # --- model behaviour / regularisation ---
    freeze_feature_encoder: bool = field(
        default=True,
        metadata={"help": "Whether to freeze the feature encoder layers of the model."},
    )
    activation_dropout: float = field(
        default=0.1,
        metadata={"help": "The hidden activation dropout probability in the embeddings, encoder, and pooler."},
    )
    hidden_dropout: float = field(
        default=0.1,
        metadata={
            "help": "The dropout probability for all fully connected layers in the embeddings, encoder, and pooler."
        },
    )
    feat_proj_dropout: float = field(
        default=0.0,
        metadata={"help": "The feat proj dropout probability for feature encoder representations."},
    )
    mask_time_prob: float = field(
        default=0.1,
        metadata={"help": "The spec aug dropout probability for feature encoder representations."},
    )
    encoder_add_adapter: bool = field(
        default=True,
        metadata={"help": "Whether to add an adapter layer between the encoder and decoder."},
    )
136
+
137
+
138
@flax.struct.dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.
    """

    # Fields whose default is `None` are annotated `Optional[...]` so the annotation matches the default.
    dataset_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    # NOTE(review): the "(for summarization)" wording looks inherited from another script - confirm and reword.
    text_column: Optional[str] = field(
        default=None,
        metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."},
    )
    dataset_cache_dir: Optional[str] = field(
        default=None, metadata={"help": "Path to cache directory for saving and loading datasets"}
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
            "value if set."
        },
    )
    max_test_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of test examples to this "
            "value if set."
        },
    )
    audio_column_name: str = field(
        default="audio",
        metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"},
    )
    text_column_name: str = field(
        default="text",
        metadata={"help": "The name of the dataset column containing the text data. Defaults to 'text'"},
    )
    id_column_name: str = field(
        default="id",
        metadata={"help": "The name of the dataset column containing the id data. Defaults to 'id'"},
    )
    max_duration_in_seconds: float = field(
        default=20.0,
        metadata={
            "help": "Filter audio files in the training set that are longer than `max_duration_in_seconds` seconds"
        },
    )
    min_duration_in_seconds: float = field(
        default=0.0, metadata={"help": "Filter audio files in the training set that are shorter than `min_duration_in_seconds` seconds"}
    )
    # The help text previously named `max_duration_in_seconds`, which is the training-set option.
    max_eval_duration_in_seconds: Optional[float] = field(
        default=None,
        metadata={
            "help": "Filter audio files in the eval/test set that are longer than `max_eval_duration_in_seconds` seconds"
        },
    )
    max_target_length: Optional[int] = field(
        default=128,
        metadata={
            "help": "The maximum total sequence length for target text after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded."
        },
    )
    min_target_length: Optional[int] = field(
        default=0,
        metadata={
            "help": "The minimum total sequence length for target text after tokenization. Sequences shorter "
            "than this will be filtered."
        },
    )
    pad_input_to_multiple_of: Optional[int] = field(
        default=24000,
        metadata={
            "help": "If set will pad the input sequence to a multiple of the provided value. "
            "This is important to avoid triggering recompilations on TPU."
        },
    )
    pad_target_to_multiple_of: Optional[int] = field(
        default=None,
        metadata={
            "help": "If set will pad the target sequence to a multiple of the provided value. "
            "This is important to avoid triggering recompilations on TPU. If unspecified, will default to `max_target_length`, "
            " the equivalent of padding the targets to max length."
        },
    )
    preprocessing_only: bool = field(
        default=False,
        metadata={
            "help": "Whether to only do data preprocessing and skip training. "
            "This is especially useful when data preprocessing errors out in distributed training due to timeout. "
            "In this case, one should run the preprocessing in a non-distributed setup with `preprocessing_only=True` "
            "so that the cached datasets can consequently be loaded in distributed training"
        },
    )
    train_split_name: str = field(
        default="train",
        metadata={
            "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
        },
    )
    eval_split_name: str = field(
        default="validation",
        metadata={
            "help": "The name of the evaluation data set split to use (via the datasets library). Defaults to 'validation'"
        },
    )
    test_split_name: str = field(
        default="test",
        metadata={"help": "The name of the test data set split to use (via the datasets library). Defaults to 'test'"},
    )
    wandb_project: str = field(
        default="flax-speech-recognition-seq2seq",
        metadata={"help": "The name of the wandb project."},
    )
    wandb_name: Optional[str] = field(
        default=None,
        metadata={"help": "The name of the wandb run."},
    )
    wandb_job_type: str = field(
        default="Seq2Seq",
        metadata={"help": "The name of the wandb job type."},
    )
    log_first_ids: bool = field(
        default=True,
        metadata={
            "help": "Whether to log the first id's from the dataset. Defaults to `True`. If `False`, will log the first id's returned by the grouped length sampler."
        },
    )
284
+
285
+
286
+ # @flax.struct.dataclass
287
@dataclass
class FlaxSeq2SeqTrainingArguments(Seq2SeqTrainingArguments):
    """Seq2Seq training arguments extended with the precision and generation options used by this script.

    `final_generation_max_length` and `final_generation_num_beams` fall back to
    `generation_max_length` / `generation_num_beams` in `__post_init__` when left unset.
    """

    # `main()` recognises "full_mixed" and "half_mixed" (both select bfloat16 computation);
    # any other value, including the default "full", selects float32.
    precision: str = field(
        default="full",
        metadata={
            "help": "Whether to enable mixed-precision training. If true, the optimizer is stored in half-precision (bfloat16) and computations are executed in half-precision. "
            "**Note that this only specifies the dtype of the computation and optimizer state. It does not influence the dtype of model parameters.**"
        },
    )
    matmul_precision: str = field(
        default="default",
        metadata={
            "help": "Default floating-point precision of internal computations used in TPU matrix multiplications and convolutions. "
            "This configuration option controls the default precision for JAX operations that take an optional precision argument (e.g. `lax.conv_general_dilated` and `lax.dot`). "
            "This configuration option does not change the behaviours of such calls with explicit precision arguments; "
            "it only changes the behaviors of calls with no such argument provided. "
            "One of `['highest', 'float32', 'high', 'bfloat16_3x', 'default', 'bfloat16', 'fastest', None]`."
        },
    )
    generation_length_penalty: float = field(
        default=1,
        metadata={
            "help": "Exponential penalty to the length. 1.0 (default) means no penalty. Set to values < 1.0 in order to encourage the model "
            "to generate shorter sequences, to a value > 1.0 in order to encourage the model to produce longer sequences."
        },
    )
    final_generation_max_length: int = field(
        default=None,
        metadata={
            "help": "The `max_length` to use on each evaluation loop when `predict_with_generate=True`. If unspecified, will default "
            "to the `max_length` value of the model configuration."
        },
    )
    final_generation_num_beams: int = field(
        default=None,
        metadata={
            "help": "The `num_beams` to use on each evaluation loop when `predict_with_generate=True`. If unspecified, will default "
            "to the `num_beams` value of the model configuration."
        },
    )

    def __post_init__(self):
        """Fill unset `final_generation_*` options from their per-evaluation counterparts."""
        # NOTE(review): the parent `__post_init__` is not invoked here - confirm this is intentional.
        if self.final_generation_max_length is None:
            self.final_generation_max_length = self.generation_max_length
        if self.final_generation_num_beams is None:
            self.final_generation_num_beams = self.generation_num_beams
333
+
334
+
335
def to_fp32(t):
    """Cast every bfloat16 leaf of pytree `t` to float32.

    Args:
        t: A pytree whose leaves expose `.dtype` and `.astype` (e.g. JAX or NumPy arrays).

    Returns:
        A pytree with the same structure; leaves of any other dtype are passed through unchanged.
    """
    # `jax.tree_map` is deprecated and absent from newer JAX releases; `jax.tree_util.tree_map`
    # is the stable equivalent.
    return jax.tree_util.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t)
337
+
338
+
339
def to_bf16(t):
    """Cast every float32 leaf of pytree `t` to bfloat16.

    Args:
        t: A pytree whose leaves expose `.dtype` and `.astype` (e.g. JAX or NumPy arrays).

    Returns:
        A pytree with the same structure; leaves of any other dtype are passed through unchanged.
    """
    # `jax.tree_map` is deprecated and absent from newer JAX releases; `jax.tree_util.tree_map`
    # is the stable equivalent.
    return jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16) if x.dtype == jnp.float32 else x, t)
341
+
342
+
343
class MixedPrecisionTrainState(struct.PyTreeNode):
    """Train state for use with a single Optax optimizer whose state may be kept in a reduced dtype.
    Adapted from flax train_state https://github.com/google/flax/blob/main/flax/training/train_state.py

    Synopsis::

        state = MixedPrecisionTrainState.create(
            apply_fn=model.apply,
            params=variables['params'],
            tx=tx,
            to_dtype=to_dtype,
            dropout_rng=dropout_rng)
        grad_fn = jax.grad(make_loss_fn(state.apply_fn))
        for batch in data:
            grads = grad_fn(state.params, batch)
            state = state.apply_gradients(grads=grads, to_dtype=to_dtype)

    Args:
        step: Counter starts at 0 and is incremented by every call to
            `.apply_gradients()`.
        apply_fn: Usually set to `model.apply()`. Kept in this dataclass for
            convenience to have a shorter params list for the `train_step()` function
            in your training loop.
        params: The parameters to be updated by `tx` and used by `apply_fn`.
        tx: An Optax gradient transformation.
        opt_state: The state for `tx`.
        dropout_rng: PRNG key for stochastic operations.
        max_grad_norm: Global l2 norm that gradients are clipped to in `.apply_gradients()`.
    """

    step: int
    apply_fn: Callable = struct.field(pytree_node=False)
    params: core.FrozenDict[str, Any]
    tx: optax.GradientTransformation = struct.field(pytree_node=False)
    opt_state: optax.OptState
    dropout_rng: jnp.ndarray
    max_grad_norm: Optional[float] = 1.0

    def apply_gradients(self, *, grads, to_dtype, **kwargs):
        """Updates `step`, `params`, `opt_state` and `**kwargs` in return value.

        Note that internally this function calls `.tx.update()` followed by a call
        to `optax.apply_updates()` to update `params` and `opt_state`.

        Args:
            grads: Gradients that have the same pytree structure as `.params`.
            to_dtype: Callable applied to `max_grad_norm` before clipping and to the new optimizer
                state before it is stored (presumably `to_bf16` or `to_fp32` - confirm against caller).
            **kwargs: Additional dataclass attributes that should be `.replace()`-ed.

        Returns:
            An updated instance of `self` with `step` incremented by one, `params`
            and `opt_state` updated by applying `grads`, and additional attributes
            replaced as specified by `kwargs`.
        """

        # clip gradients by global l2 norm: the divisor is never smaller than the threshold,
        # so gradients whose norm is already below it are left unscaled
        casted_max_grad_norm = to_dtype(self.max_grad_norm)
        g_norm = linear_algebra.global_norm(grads)
        g_norm = jnp.maximum(casted_max_grad_norm, g_norm)
        # `jax.tree_map` is deprecated and absent from newer JAX releases
        grads = jax.tree_util.tree_map(lambda t: (t / g_norm) * casted_max_grad_norm, grads)

        # perform update step in fp32 and subsequently downcast optimizer states if mixed precision training
        # grads and opt_state in bf16 (need to upcast), params in fp32 (leave as is)
        updates, new_opt_state = self.tx.update(to_fp32(grads), to_fp32(self.opt_state), self.params)

        new_params = optax.apply_updates(self.params, updates)
        return self.replace(
            step=self.step + 1,
            params=new_params,
            opt_state=to_dtype(new_opt_state),
            **kwargs,
        )

    @classmethod
    def create(cls, *, apply_fn, params, tx, to_dtype, **kwargs):
        """Creates a new instance with `step=0` and initialized `opt_state`.

        Args:
            apply_fn: Stored unchanged on the new state.
            params: Model parameters; stored unchanged on the new state.
            tx: Optax gradient transformation, or `None` to skip optimizer-state initialisation.
            to_dtype: Callable applied to `params` before they are handed to `tx.init`.
            **kwargs: Remaining dataclass fields (e.g. `dropout_rng`, `max_grad_norm`).
        """
        # downcast optimizer state to bf16 if mixed-precision training
        opt_state = tx.init(to_dtype(params)) if tx is not None else None
        return cls(
            step=0,
            apply_fn=apply_fn,
            params=params,
            tx=tx,
            opt_state=opt_state,
            **kwargs,
        )

    def replicate(self):
        """Replicate the state via `jax_utils.replicate`, replacing `dropout_rng` with a sharded key."""
        return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
429
+
430
+
431
def pad_to_max_length(data, tokenizer):
    """Right-pad ragged rows of `data` with `tokenizer.pad_token_id` into one rectangular array.

    The output width equals the length of the longest row and its dtype is `data.dtype`.
    """
    # One length per row
    row_lengths = np.array([len(row) for row in data])

    # True wherever a real (non-padding) element belongs
    valid = np.arange(row_lengths.max())[None, :] < row_lengths[:, None]

    # Start from an all-padding array, then scatter the flattened rows into the valid slots
    padded = np.ones_like(valid, dtype=data.dtype) * tokenizer.pad_token_id
    padded[valid] = np.concatenate(data)
    return padded
442
+
443
+
444
def shift_tokens_right(label_ids: np.array, decoder_start_token_id: int) -> np.ndarray:
    """
    Build decoder inputs from labels: prepend `decoder_start_token_id` to every row and drop the last token.
    """
    shifted = np.zeros_like(label_ids)
    # column 0 holds the start token, the remaining columns hold the labels moved one step right
    shifted[:, 0] = decoder_start_token_id
    shifted[:, 1:] = label_ids[:, :-1]
    return shifted
453
+
454
+
455
@flax.struct.dataclass
class FlaxDataCollatorSpeechSeq2SeqWithPadding:
    """
    Data collator that pads a list of pre-processed speech examples into a single NumPy batch.

    Args:
        processor ([`Wav2Vec2Processor`])
            Provides the `feature_extractor` used to pad the audio inputs and the `tokenizer`
            used to pad the label sequences.
        decoder_start_token_id (:obj: `int`)
            Token id placed at the start of every decoder input sequence.
        input_padding (:obj:`bool` or :obj:`str`, `optional`, defaults to :obj:`"longest"`):
            Padding strategy forwarded to `feature_extractor.pad`.
        target_padding (:obj:`bool` or :obj:`str`, `optional`, defaults to :obj:`"max_length"`):
            Padding strategy forwarded to `tokenizer.pad`.
        max_input_length (:obj:`float`, `optional`):
            `max_length` forwarded to `feature_extractor.pad`.
        max_target_length (:obj:`int`, `optional`):
            `max_length` forwarded to `tokenizer.pad`.
        pad_input_to_multiple_of (:obj:`int`, `optional`):
            If set, the padded input length is rounded up to a multiple of this value.
        pad_target_to_multiple_of (:obj:`int`, `optional`):
            If set, the padded target length is rounded up to a multiple of this value.
    """

    processor: Any
    decoder_start_token_id: int
    input_padding: Union[bool, str] = "longest"
    target_padding: Union[bool, str] = "max_length"
    max_input_length: Optional[float] = None
    max_target_length: Optional[int] = None
    pad_input_to_multiple_of: Optional[int] = None
    pad_target_to_multiple_of: Optional[int] = None

    def __call__(self, features: List[Dict[str, Union[List[int], np.ndarray]]]) -> Dict[str, np.ndarray]:
        # audio and text have different lengths and padding rules, so they are collected separately
        audio_inputs = [{"input_values": example["input_values"]} for example in features]
        sample_ids = [example["input_id"] for example in features]
        text_targets = [{"input_ids": example["labels"]} for example in features]

        # pad the audio to a common length and return NumPy arrays
        batch = self.processor.feature_extractor.pad(
            audio_inputs,
            padding=self.input_padding,
            max_length=self.max_input_length,
            pad_to_multiple_of=self.pad_input_to_multiple_of,
            return_tensors="np",
        )

        # pad the token ids to a common length and return NumPy arrays
        padded_targets = self.processor.tokenizer.pad(
            text_targets,
            padding=self.target_padding,
            max_length=self.max_target_length,
            pad_to_multiple_of=self.pad_target_to_multiple_of,
            return_tensors="np",
        )

        labels = padded_targets["input_ids"]
        target_mask = padded_targets.attention_mask

        # when every row already starts with the decoder start token, drop that column:
        # the start token is re-inserted by `shift_tokens_right` below
        every_row_has_bos = (labels[:, 0] == self.decoder_start_token_id).all().item()
        if every_row_has_bos:
            labels = labels[:, 1:]
            target_mask = target_mask[:, 1:]

        decoder_input_ids = shift_tokens_right(labels, self.decoder_start_token_id)

        # padded positions become -100 so that they are ignored when computing the loss
        labels = np.ma.array(labels, mask=np.not_equal(target_mask, 1)).filled(fill_value=-100)

        batch["inputs"] = batch.pop("input_values")
        batch["input_ids"] = sample_ids
        batch["labels"] = labels
        batch["decoder_input_ids"] = decoder_input_ids

        return batch
542
+
543
+
544
def get_grouped_indices(
    dataset, batch_size: int, rng: Optional[List[int]] = None, mega_batch_mult: Optional[int] = None
) -> np.array:
    """
    Adapted from the `get_length_grouped_indices` function in the PyTorch Trainer utils file (https://github.com/huggingface/transformers/blob/main/src/transformers/trainer_pt_utils.py#L486)
    Return dataset indices ordered so that each run of `batch_size` consecutive indices refers to examples of
    similar `input_length`. The indices are:

    - randomly permuted (only if a JAX rng is given)
    - cut into mega-batches of `mega_batch_mult * batch_size` indices
    - sorted by descending length inside each mega-batch

    The mega-batch holding the overall longest example is then swapped into first position, so that an OOM
    happens sooner rather than later.
    """
    lengths = dataset["input_length"]
    num_samples = len(lengths)

    # Default for mega_batch_mult: 50 or the number to get 4 megabatches, whichever is smaller
    # (and never 0, which would happen for tiny datasets).
    if mega_batch_mult is None:
        mega_batch_mult = min(num_samples // (batch_size * 4), 50) or 1

    # JAX does the shuffling because the PRNG key is derived from the seed outside of this sampler.
    if rng is None:
        order = np.arange(num_samples)
    else:
        order = jax.random.permutation(rng, np.arange(num_samples))
    order = np.asarray(order)

    chunk_size = mega_batch_mult * batch_size
    groups = []
    for start in range(0, num_samples, chunk_size):
        members = order[start : start + chunk_size].tolist()
        members.sort(key=lambda i: lengths[i], reverse=True)
        groups.append(members)

    # Every group is sorted longest-first, so its first member carries the group maximum.
    group_maxima = [lengths[group[0]] for group in groups]
    longest_group = np.argmax(group_maxima).item()
    # Move the whole group with the longest example to the front
    # (the PT grouped sampler moves only the single longest element instead).
    groups[0], groups[longest_group] = groups[longest_group], groups[0]

    return np.array([index for group in groups for index in group])
588
+
589
+
590
def generate_batch_splits(samples_idx: np.ndarray, batch_size: int, drop_last_batch=True) -> np.ndarray:
    """Split sample indices into batches.

    With `drop_last_batch=True` the trailing `len(samples_idx) % batch_size` indices are discarded and a
    2-D array of shape `(num_batches, batch_size)` is returned. Otherwise the indices are passed to
    `np.array_split` with `ceil(len(samples_idx) / batch_size)` sections and the resulting list is returned."""
    total = len(samples_idx)

    if not drop_last_batch:
        return np.array_split(samples_idx, math.ceil(total / batch_size))

    leftover = total % batch_size
    if leftover != 0:
        samples_idx = samples_idx[:-leftover]
    return samples_idx.reshape((total // batch_size, batch_size))
604
+
605
+
606
def write_train_metric(summary_writer, train_metrics, train_time, step):
    """Write the elapsed training time and the per-step training metrics to `summary_writer`.

    The values of each metric are written to consecutive steps ending at `step`.
    """
    summary_writer.scalar("train_time", train_time, step)

    stacked_metrics = get_metrics(train_metrics)
    for name, values in stacked_metrics.items():
        num_values = len(values)
        for position, value in enumerate(values, start=1):
            summary_writer.scalar(f"train_{name}", value, step - num_values + position)
614
+
615
+
616
def write_eval_metric(summary_writer, eval_metrics, step, pred_str=None):
    """Write every evaluation metric as an `eval_<name>` scalar and, if given, the predictions as text."""
    for name in eval_metrics:
        summary_writer.scalar(f"eval_{name}", eval_metrics[name], step)

    if pred_str is None:
        return

    # the raw predictions are written out as well, for debugging
    joined_predictions = "\n".join(pred_str)
    summary_writer.text("eval_predictions", joined_predictions, step)
623
+
624
+
625
def write_wandb_log(metrics, step, prefix=None):
    """Log `metrics` to wandb at `step`; only the process with index 0 logs anything.

    Keys containing "layer" are logged as `"<key>/"`; other keys are logged as `"<prefix>/<key>"`
    when a prefix is given and unchanged otherwise.
    """
    if jax.process_index() != 0:
        return

    def _log_key(name):
        if "layer" in name:
            return f"{name}/"
        if prefix is not None:
            return f"{prefix}/{name}"
        return name

    wandb.log({_log_key(name): value for name, value in metrics.items()}, step)
636
+
637
+
638
def write_wandb_pred(pred_str, label_str, eval_ids, step, prefix="eval", top_ids=None, final_step=True):
    """Log prediction tables to wandb; only the process with index 0 logs anything.

    Args:
        pred_str: One list of decoded predictions per beam, each aligned with `eval_ids`.
        label_str: Reference transcriptions aligned with `eval_ids`.
        eval_ids: Ids of the evaluated samples (assumed hashable, e.g. str or int).
        step: Training step; used as the wandb step and in the table names.
        prefix: Namespace placed in front of the table names.
        top_ids: Ids selecting which rows to log and in which order; ids missing from `eval_ids`
            are skipped. Falls back to `eval_ids` when empty or `None`.
        final_step: When true, additionally log the full table (`..._all`) and the rows whose
            first-beam prediction differs from the label (`..._incorrect`).
    """
    if jax.process_index() != 0:
        return

    top_ids = top_ids if top_ids else eval_ids
    num_beams = len(pred_str)

    # Map each id to its first position once, instead of a linear membership test plus `.index()`
    # per row; `setdefault` keeps the first occurrence, matching `list.index` semantics.
    first_position = {}
    for position, sample_id in enumerate(eval_ids):
        first_position.setdefault(sample_id, position)

    # convert str data to a wandb compatible format
    str_data = []
    for sample_id in top_ids:
        idx = first_position.get(sample_id)
        if idx is None:
            continue
        str_data.append([eval_ids[idx], label_str[idx]] + [pred_str[beam][idx] for beam in range(num_beams)])

    columns = ["id", "label_str"] + [f"beam_{i + 1}" for i in range(num_beams)]
    wandb.log(
        {f"{prefix}/step_{int(step / 1000)}k": wandb.Table(columns=columns, data=str_data[:50])},
        step,
    )
    if final_step:
        str_data = np.array(str_data)
        wandb.log(
            {f"{prefix}/step_{int(step / 1000)}k_all": wandb.Table(columns=columns, data=str_data[:200000])},
            step,
        )
        # An empty table converts to a 1-D array, which cannot be column-indexed; it is logged as is.
        if str_data.ndim == 2:
            str_data = str_data[str_data[:, 1] != str_data[:, 2]]
        wandb.log(
            {f"{prefix}/step_{int(step / 1000)}k_incorrect": wandb.Table(columns=columns, data=str_data[:200000])},
            step,
        )
664
+
665
+
666
def create_learning_rate_fn(
    num_train_steps: int, num_warmup_steps: int, learning_rate: float
) -> Callable[[int], jnp.array]:
    """Returns a schedule that rises linearly from 0 to `learning_rate` over the warmup steps and then
    falls linearly back to 0 over the remaining training steps."""
    num_decay_steps = num_train_steps - num_warmup_steps
    ramp_up = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
    ramp_down = optax.linear_schedule(init_value=learning_rate, end_value=0, transition_steps=num_decay_steps)
    return optax.join_schedules(schedules=[ramp_up, ramp_down], boundaries=[num_warmup_steps])
676
+
677
+
678
+ def main():
679
+ # 1. Parse input arguments
680
+ # See all possible arguments in src/transformers/training_args.py
681
+ # or by passing the --help flag to this script.
682
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
683
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, FlaxSeq2SeqTrainingArguments))
684
+
685
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
686
+ # If we pass only one argument to the script and it's the path to a json file,
687
+ # let's parse it to get our arguments.
688
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
689
+ else:
690
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
691
+
692
+ # 2. Setup logging
693
+ # Make one log on every process with the configuration for debugging.
694
+ logging.basicConfig(
695
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
696
+ datefmt="%m/%d/%Y %H:%M:%S",
697
+ handlers=[logging.StreamHandler(sys.stdout)],
698
+ )
699
+ # Set the verbosity to info of the Transformers logger.
700
+ # We only want one process per machine to log things on the screen.
701
+ logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
702
+ if jax.process_index() == 0:
703
+ datasets.utils.logging.set_verbosity_warning()
704
+ transformers.utils.logging.set_verbosity_info()
705
+ else:
706
+ datasets.utils.logging.set_verbosity_error()
707
+ transformers.utils.logging.set_verbosity_error()
708
+
709
+ # Set up wandb run
710
+ if jax.process_index() == 0:
711
+ wandb.init(project=data_args.wandb_project, name=data_args.wandb_name, job_type=data_args.wandb_job_type)
712
+
713
+ logger.info("Training/evaluation parameters %s", training_args)
714
+
715
+ # Set the default TPU matmul precision and display the number of devices
716
+ jax.config.update("jax_default_matmul_precision", training_args.matmul_precision)
717
+ logger.info(f"JAX devices: {jax.device_count()}, matmul precision: {training_args.matmul_precision}")
718
+
719
+ # TODO: 3. Detecting last checkpoint and eventually continue from last checkpoint
720
+ last_checkpoint = None
721
+ if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
722
+ last_checkpoint = get_last_checkpoint(training_args.output_dir)
723
+ if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
724
+ raise ValueError(
725
+ f"Output directory ({training_args.output_dir}) already exists and is not empty. "
726
+ "Use --overwrite_output_dir to overcome."
727
+ )
728
+ elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
729
+ logger.info(
730
+ f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
731
+ "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
732
+ )
733
+
734
+ # 4. Load dataset
735
+ raw_datasets = DatasetDict()
736
+
737
+ if training_args.do_train:
738
+ raw_datasets["train"] = load_dataset(
739
+ data_args.dataset_name,
740
+ data_args.dataset_config_name,
741
+ split=data_args.train_split_name,
742
+ cache_dir=data_args.dataset_cache_dir,
743
+ use_auth_token=True if model_args.use_auth_token else None,
744
+ )
745
+
746
+ if training_args.do_eval:
747
+ raw_datasets["eval"] = load_dataset(
748
+ data_args.dataset_name,
749
+ data_args.dataset_config_name,
750
+ split=data_args.eval_split_name,
751
+ cache_dir=data_args.dataset_cache_dir,
752
+ use_auth_token=True if model_args.use_auth_token else None,
753
+ )
754
+
755
+ if training_args.do_predict:
756
+ test_split = data_args.test_split_name.split("+")
757
+ for split in test_split:
758
+ raw_datasets[split] = load_dataset(
759
+ data_args.dataset_name,
760
+ data_args.dataset_config_name,
761
+ split=split,
762
+ cache_dir=data_args.dataset_cache_dir,
763
+ use_auth_token=True if model_args.use_auth_token else None,
764
+ )
765
+
766
+ if not training_args.do_train and not training_args.do_eval and not training_args.do_predict:
767
+ raise ValueError(
768
+ "Cannot not train, not do evaluation and not do prediction. At least one of "
769
+ "training, evaluation or prediction has to be done."
770
+ )
771
+
772
+ # if not training, there is no need to run multiple epochs
773
+ if not training_args.do_train:
774
+ training_args.num_train_epochs = 1
775
+
776
+ if data_args.audio_column_name not in next(iter(raw_datasets.values())).column_names:
777
+ raise ValueError(
778
+ f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'. "
779
+ "Make sure to set `--audio_column_name` to the correct audio column - one of "
780
+ f"{', '.join(next(iter(raw_datasets.values())).column_names)}."
781
+ )
782
+
783
+ if data_args.text_column_name not in next(iter(raw_datasets.values())).column_names:
784
+ raise ValueError(
785
+ f"--text_column_name {data_args.text_column_name} not found in dataset '{data_args.dataset_name}'. "
786
+ "Make sure to set `--text_column_name` to the correct text column - one of "
787
+ f"{', '.join(next(iter(raw_datasets.values())).column_names)}."
788
+ )
789
+
790
+ if data_args.log_first_ids and data_args.id_column_name not in next(iter(raw_datasets.values())).column_names:
791
+ raise ValueError(
792
+ f"--id_column_name {data_args.id_column_name} not found in dataset '{data_args.dataset_name}'. "
793
+ "Make sure to set `--id_column_name` to the correct id column - one of "
794
+ f"{', '.join(next(iter(raw_datasets.values())).column_names)}."
795
+ )
796
+
797
+ # 5. Load pretrained model, tokenizer, and feature extractor
798
+ #
799
+ # Distributed training:
800
+ # The .from_pretrained methods guarantee that only one local process can concurrently
801
+ config = AutoConfig.from_pretrained(
802
+ model_args.config_name if model_args.config_name else model_args.model_name_or_path,
803
+ cache_dir=model_args.cache_dir,
804
+ revision=model_args.model_revision,
805
+ use_auth_token=True if model_args.use_auth_token else None,
806
+ )
807
+
808
+ # update config according to training and model args
809
+ config.encoder.update(
810
+ {
811
+ "gradient_checkpointing": training_args.gradient_checkpointing,
812
+ "hidden_dropout": model_args.hidden_dropout,
813
+ "activation_dropout": model_args.activation_dropout,
814
+ "feat_proj_dropout": model_args.feat_proj_dropout,
815
+ "mask_time_prob": model_args.mask_time_prob,
816
+ "add_adapter": model_args.encoder_add_adapter,
817
+ }
818
+ )
819
+ config.decoder.update(
820
+ {
821
+ "gradient_checkpointing": training_args.gradient_checkpointing,
822
+ "dropout": model_args.hidden_dropout,
823
+ "activation_dropout": model_args.activation_dropout,
824
+ }
825
+ )
826
+
827
+ feature_extractor = AutoFeatureExtractor.from_pretrained(
828
+ model_args.feature_extractor_name if model_args.feature_extractor_name else model_args.model_name_or_path,
829
+ cache_dir=model_args.cache_dir,
830
+ revision=model_args.model_revision,
831
+ use_auth_token=True if model_args.use_auth_token else None,
832
+ )
833
+ tokenizer = AutoTokenizer.from_pretrained(
834
+ model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
835
+ cache_dir=model_args.cache_dir,
836
+ use_fast=model_args.use_fast_tokenizer,
837
+ revision=model_args.model_revision,
838
+ use_auth_token=True if model_args.use_auth_token else None,
839
+ )
840
+
841
+ if training_args.precision == "full_mixed":
842
+ dtype = jnp.bfloat16
843
+ training_args.mixed_precision = True
844
+ elif training_args.precision == "half_mixed":
845
+ dtype = jnp.bfloat16
846
+ training_args.mixed_precision = False
847
+ else:
848
+ dtype = jnp.float32
849
+ training_args.mixed_precision = False
850
+
851
+ model = FlaxSpeechEncoderDecoderModel.from_pretrained(
852
+ model_args.model_name_or_path,
853
+ config=config,
854
+ dtype=dtype,
855
+ cache_dir=model_args.cache_dir,
856
+ revision=model_args.model_revision,
857
+ use_auth_token=True if model_args.use_auth_token else None,
858
+ )
859
+
860
+ if model.config.decoder_start_token_id is None:
861
+ raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
862
+
863
+ # 6. Resample speech dataset ALWAYS
864
+ raw_datasets = raw_datasets.cast_column(data_args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate))
865
+
866
+ # 7. Preprocessing the datasets.
867
+ # We need to read the audio files as arrays and tokenize the targets.
868
+ max_input_length = int(data_args.max_duration_in_seconds * feature_extractor.sampling_rate)
869
+ min_input_length = int(data_args.min_duration_in_seconds * feature_extractor.sampling_rate)
870
+ max_eval_input_length = int(data_args.max_eval_duration_in_seconds * feature_extractor.sampling_rate) if data_args.max_eval_duration_in_seconds else None
871
+ max_target_length = data_args.max_target_length
872
+ min_target_length = data_args.min_target_length
873
+ pad_input_to_multiple_of = data_args.pad_input_to_multiple_of
874
+ pad_target_to_multiple_of = data_args.pad_target_to_multiple_of
875
+ audio_column_name = data_args.audio_column_name
876
+ num_workers = data_args.preprocessing_num_workers
877
+ text_column_name = data_args.text_column_name
878
+ id_column_name = data_args.id_column_name
879
+ model_input_name = feature_extractor.model_input_names[0]
880
+ log_first_ids = data_args.log_first_ids
881
+
882
+ if training_args.do_train and data_args.max_train_samples is not None:
883
+ raw_datasets["train"] = raw_datasets["train"].select(range(data_args.max_train_samples))
884
+
885
+ if training_args.do_eval and data_args.max_eval_samples is not None:
886
+ raw_datasets["eval"] = raw_datasets["eval"].select(range(data_args.max_eval_samples))
887
+
888
+ if training_args.do_predict and data_args.max_test_samples is not None:
889
+ for split in test_split:
890
+ raw_datasets[split] = raw_datasets[split].select(range(data_args.max_eval_samples))
891
+
892
+
893
+ def prepare_dataset(batch):
894
+ # Pre-process audio
895
+ sample = batch[audio_column_name]
896
+
897
+ # normalise audio (mean, std) to (0, 1)
898
+ inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])
899
+ # process audio length
900
+ batch[model_input_name] = inputs.input_values[0]
901
+ batch["input_length"] = len(batch["input_values"])
902
+ batch["input_id"] = batch[id_column_name] if log_first_ids else None
903
+
904
+ input_str = batch[text_column_name]
905
+ # Finally, we tokenize the processed text
906
+ batch["labels"] = tokenizer(input_str).input_ids
907
+ batch["labels_length"] = len(batch["labels"])
908
+ return batch
909
+
910
+ vectorized_datasets = raw_datasets.map(
911
+ prepare_dataset,
912
+ remove_columns=next(iter(raw_datasets.values())).column_names,
913
+ num_proc=num_workers,
914
+ desc="preprocess train dataset",
915
+ )
916
+
917
+ # filter training data with inputs longer than max_input_length
918
+ def is_audio_in_length_range(length):
919
+ return min_input_length < length < max_input_length
920
+
921
+ if training_args.do_train:
922
+ vectorized_datasets["train"] = vectorized_datasets["train"].filter(
923
+ is_audio_in_length_range,
924
+ num_proc=num_workers,
925
+ input_columns=["input_length"],
926
+ )
927
+
928
+ if max_eval_input_length is not None:
929
+ # filter training data with inputs longer than max_input_length
930
+ def is_eval_audio_in_length_range(length):
931
+ return min_input_length < length < max_eval_input_length
932
+
933
+ if training_args.do_eval:
934
+ vectorized_datasets["eval"] = vectorized_datasets["eval"].filter(
935
+ is_eval_audio_in_length_range,
936
+ num_proc=num_workers,
937
+ input_columns=["input_length"],
938
+ )
939
+
940
+ if training_args.do_test:
941
+ for split in test_split:
942
+ vectorized_datasets[split] = vectorized_datasets[split].filter(
943
+ is_eval_audio_in_length_range,
944
+ num_proc=num_workers,
945
+ input_columns=["input_length"],
946
+ )
947
+
948
+ # filter data with targets shorter than min_target_length or longer than max_target_length
949
+ def is_labels_in_length_range(length):
950
+ return min_target_length < length < max_target_length
951
+
952
+ if training_args.do_train:
953
+ vectorized_datasets["train"] = vectorized_datasets["train"].filter(
954
+ is_labels_in_length_range,
955
+ num_proc=num_workers,
956
+ input_columns=["labels_length"],
957
+ )
958
+
959
+ # for large datasets it is advised to run the preprocessing on a
960
+ # single machine first with `args.preprocessing_only` since there will mostly likely
961
+ # be a timeout when running the script in distributed mode.
962
+ # In a second step `args.preprocessing_only` can then be set to `False` to load the
963
+ # cached dataset
964
+ if data_args.preprocessing_only:
965
+ cache = {k: v.cache_files for k, v in vectorized_datasets.items()}
966
+ logger.info(f"Data preprocessing finished. Files cached at {cache}.")
967
+ return
968
+
969
+ # 8. Load Metrics
970
+ wer_metric = load_metric("wer")
971
+ cer_metric = load_metric("cer")
972
+
973
def compute_metrics(pred_ids: List[List[int]], label_ids: List[List[int]]):
    """Decode predictions and references, then score the first decoded beam.

    Args:
        pred_ids: generated token ids; once converted to an array they are
            indexed as ``[sample, beam, token]``.
        label_ids: reference token ids in which ``-100`` marks positions that
            must be replaced by the pad token before decoding.

    Returns:
        A tuple ``(metrics, pred_str, label_str)``. ``metrics`` is a dict with
        the keys ``"wer"`` and ``"cer"``; ``pred_str`` holds one list of decoded
        strings per beam, ordered from the highest beam index down to beam 0;
        ``label_str`` is the list of decoded references.
    """
    # labels only go through `pad_to_max_length` when target padding to a multiple is configured
    if pad_target_to_multiple_of:
        label_ids = pad_to_max_length(np.array(label_ids, dtype="object"), tokenizer)

    # swap the -100 ignore marker for the pad token so the tokenizer can decode the labels
    label_array = np.asarray(label_ids)
    padded_ids = np.where(label_array == -100, tokenizer.pad_token_id, label_array)
    # we do not want to group tokens when computing the metrics
    label_str = tokenizer.batch_decode(padded_ids, skip_special_tokens=True)

    pred_ids = np.array(pred_ids)
    num_beams = pred_ids.shape[1]
    # decode beam by beam, walking the beam axis backwards so that the last
    # beam index ends up at position 0 (the entry treated as the top beam below)
    pred_str = []
    for beam in range(num_beams - 1, -1, -1):
        pred_str.append(tokenizer.batch_decode(pred_ids[:, beam, :], skip_special_tokens=True))

    # word/character error rate is computed for the top beam only
    top_beam_str = pred_str[0]
    wer = wer_metric.compute(predictions=top_beam_str, references=label_str)
    cer = cer_metric.compute(predictions=top_beam_str, references=label_str)

    return {"wer": wer, "cer": cer}, pred_str, label_str
996
+
997
+ # 9. Save feature extractor, tokenizer and config
998
+ feature_extractor.save_pretrained(training_args.output_dir)
999
+ tokenizer.save_pretrained(training_args.output_dir)
1000
+ config.save_pretrained(training_args.output_dir)
1001
+
1002
+ processor = AutoProcessor.from_pretrained(training_args.output_dir)
1003
+
1004
+ data_collator = FlaxDataCollatorSpeechSeq2SeqWithPadding(
1005
+ processor=processor,
1006
+ decoder_start_token_id=model.config.decoder_start_token_id,
1007
+ input_padding="longest",
1008
+ target_padding="longest",
1009
+ max_target_length=max_target_length,
1010
+ pad_input_to_multiple_of=pad_input_to_multiple_of,
1011
+ pad_target_to_multiple_of=pad_target_to_multiple_of if pad_target_to_multiple_of else max_target_length,
1012
+ )
1013
+
1014
+ # Enable tensorboard only on the master node
1015
+ has_tensorboard = is_tensorboard_available()
1016
+ if has_tensorboard and jax.process_index() == 0:
1017
+ try:
1018
+ from flax.metrics.tensorboard import SummaryWriter
1019
+
1020
+ summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
1021
+ except ImportError as ie:
1022
+ has_tensorboard = False
1023
+ logger.warning(
1024
+ f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
1025
+ )
1026
+ else:
1027
+ logger.warning(
1028
+ "Unable to display metrics through TensorBoard because the package is not installed: "
1029
+ "Please run `pip install tensorboard` to enable."
1030
+ )
1031
+
1032
+ # 10. Handle the repository creation
1033
+ if training_args.push_to_hub:
1034
+ with open(os.path.join(training_args.output_dir, ".gitattributes"), "r+") as f:
1035
+ git_lfs_extensions = f.read()
1036
+ if "*.wandb" not in git_lfs_extensions:
1037
+ f.write("*.wandb filter=lfs diff=lfs merge=lfs -text")
1038
+ if training_args.hub_model_id is None:
1039
+ repo_name = get_full_repo_name(
1040
+ Path(training_args.output_dir).absolute().name, token=training_args.hub_token
1041
+ )
1042
+ else:
1043
+ repo_name = training_args.hub_model_id
1044
+ repo = Repository(training_args.output_dir, clone_from=repo_name)
1045
+
1046
+ # 11. Initialize our training
1047
+ rng = jax.random.PRNGKey(training_args.seed)
1048
+ rng, dropout_rng = jax.random.split(rng)
1049
+
1050
+ # Store some constants
1051
+ max_steps = int(training_args.max_steps)
1052
+ gradient_accumulation_steps = int(training_args.gradient_accumulation_steps)
1053
+ train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
1054
+ batch_size_per_update = train_batch_size * gradient_accumulation_steps
1055
+ per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
1056
+ eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
1057
+ to_dtype = to_bf16 if training_args.mixed_precision else to_fp32
1058
+
1059
+ if training_args.do_train:
1060
+ num_train_samples = len(vectorized_datasets["train"])
1061
+ steps_per_epoch = num_train_samples // batch_size_per_update
1062
+ if max_steps > 0:
1063
+ num_epochs = -(training_args.max_steps // -steps_per_epoch)
1064
+ total_train_steps = max_steps
1065
+ else:
1066
+ num_epochs = int(training_args.num_train_epochs)
1067
+ total_train_steps = steps_per_epoch * num_epochs
1068
+
1069
+ # Create learning rate schedule
1070
+ linear_decay_lr_schedule_fn = create_learning_rate_fn(
1071
+ total_train_steps,
1072
+ training_args.warmup_steps,
1073
+ training_args.learning_rate,
1074
+ )
1075
+
1076
+ # We use Optax's "masking" functionality to not apply weight decay
1077
+ # to bias and LayerNorm scale parameters. decay_mask_fn returns a
1078
+ # mask boolean with the same structure as the parameters.
1079
+ # The mask is True for parameters that should be decayed.
1080
+ # Note that this mask is specifically adapted for FlaxWav2Vec2 and FlaxBart.
1081
+ # For FlaxT5, one should correct the layer norm parameter naming
1082
+ # accordingly - see `run_t5_mlm_flax.py` e.g.
1083
def decay_mask_fn(params):
    """Build the weight-decay mask handed to the optimizer.

    Args:
        params: nested parameter dict.

    Returns:
        A nested dict with the same structure as ``params`` whose leaves are
        booleans: True when the parameter should be decayed, i.e. when it is
        neither named ``bias`` nor the ``scale`` of one of the layer-norm
        modules listed below.
    """
    no_decay_suffixes = {
        (norm_name, "scale")
        for norm_name in ("layer_norm", "self_attn_layer_norm", "layernorm_embedding", "final_layer_norm")
    }

    flat_mask = {}
    for path in traverse_util.flatten_dict(params):
        is_bias = path[-1] == "bias"
        is_layer_norm_scale = path[-2:] in no_decay_suffixes
        flat_mask[path] = not is_bias and not is_layer_norm_scale

    return traverse_util.unflatten_dict(flat_mask)
1091
+
1092
+ if training_args.adafactor:
1093
+ # Create Adafactor optimizer
1094
+ optim = optax.adafactor(
1095
+ learning_rate=linear_decay_lr_schedule_fn,
1096
+ dtype_momentum=jnp.bfloat16 if training_args.mixed_precision else jnp.float32,
1097
+ weight_decay_rate=training_args.weight_decay,
1098
+ weight_decay_mask=decay_mask_fn,
1099
+ )
1100
+ else:
1101
+ # Create AdamW optimizer
1102
+ optim = optax.adamw(
1103
+ learning_rate=linear_decay_lr_schedule_fn,
1104
+ b1=training_args.adam_beta1,
1105
+ b2=training_args.adam_beta2,
1106
+ eps=training_args.adam_epsilon,
1107
+ weight_decay=training_args.weight_decay,
1108
+ mask=decay_mask_fn,
1109
+ )
1110
+ else:
1111
+ num_epochs = 0
1112
+ total_train_steps = 0
1113
+ num_train_samples = 0
1114
+ optim = None
1115
+
1116
+ # Setup train state
1117
+ state = MixedPrecisionTrainState.create(
1118
+ apply_fn=model.__call__,
1119
+ params=model.params,
1120
+ tx=optim,
1121
+ to_dtype=to_dtype,
1122
+ dropout_rng=dropout_rng,
1123
+ max_grad_norm=training_args.max_grad_norm,
1124
+ )
1125
+
1126
+ # Cross entropy loss
1127
def loss_fn(logits, labels):
    """Summed cross-entropy over the positions that carry a real label.

    Args:
        logits: model outputs; the size of the last axis is used as the
            vocabulary size.
        labels: target token ids; negative ids (the -100 padding marker) are
            excluded from the loss.

    Returns:
        A tuple ``(loss, num_labels)``: the *summed* (not averaged) loss over the
        kept positions and the number of kept positions. Callers divide the
        two to obtain a mean.
    """
    num_classes = logits.shape[-1]
    # the one-hot targets are cast with `to_dtype`: per the original note they come
    # back as float32 and need downcasting when training in mixed precision
    onehot_targets = to_dtype(onehot(labels, num_classes))
    token_loss = optax.softmax_cross_entropy(logits, onehot_targets)

    # keep only the positions with a non-negative label; padded positions
    # (label -100) contribute nothing to either the loss or the count
    keep = labels >= 0
    return (token_loss * keep).sum(), keep.sum()
1138
+
1139
+ # Define gradient update step fn
1140
def train_step(state, batch):
    """Run one optimizer update; meant to be executed under ``jax.pmap`` with axis name "batch".

    Args:
        state: current train state (provides ``params``, ``dropout_rng``,
            ``apply_fn``, ``step`` and ``apply_gradients``).
        batch: dict of model inputs for this device, including ``"labels"``.

    Returns:
        A tuple ``(new_state, metrics)``. ``metrics`` contains the normalised
        loss, the learning rate evaluated at the incoming ``state.step``, and
        gradient/parameter norms per layer, per encoder/decoder and global.
    """
    # only one single rng per grad step, with or without accumulation, as the graph
    # should be identical over one effective training batch
    dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)

    def compute_loss(params, minibatch):
        # labels are removed from the minibatch so the rest can be splatted into the model call
        labels = minibatch.pop("labels")
        logits = state.apply_fn(
            **minibatch,
            params=params,
            dropout_rng=dropout_rng,
            freeze_feature_encoder=model_args.freeze_feature_encoder,
            train=True,
        )[0]
        loss, num_labels = loss_fn(logits, labels)
        # num_labels is returned as the auxiliary output of value_and_grad
        return loss, num_labels

    grad_fn = jax.value_and_grad(compute_loss, has_aux=True)

    if gradient_accumulation_steps == 1:
        (loss, num_labels), grad = grad_fn(to_dtype(state.params), batch)
    else:
        # Custom gradient accumulation:
        # add a first dimension over gradient_accumulation_steps for minibatch slices
        batch = jax.tree_util.tree_map(
            lambda x: x.reshape(
                gradient_accumulation_steps, training_args.per_device_train_batch_size, *x.shape[1::]
            ),
            batch,
        )

        def accum_minibatch_step(accum_grad, minibatch):
            # compute loss, num labels and grad over minibatch and accumulate
            (loss, num_labels), grad = grad_fn(to_dtype(state.params), minibatch)
            return jax.tree_util.tree_map(jnp.add, accum_grad, grad), (loss, num_labels)

        # create an initial state for accumulating losses, num labels and gradients
        init_grad = jax.tree_util.tree_map(jnp.zeros_like, to_dtype(state.params))
        # loop accum minibatch step over the number of gradient accumulation steps
        grad, (loss, num_labels) = jax.lax.scan(accum_minibatch_step, init_grad, batch)

    # sum gradients, loss and label counts over the "batch" axis, then normalise
    # by the total number of labels so that loss and grads are per-label means
    grad = jax.lax.psum(grad, "batch")
    loss = jax.lax.psum(loss.sum(), "batch")
    total_labels = jax.lax.psum(num_labels.sum(), "batch")
    grad = jax.tree_util.tree_map(lambda g: g / total_labels, grad)
    loss = loss / total_labels

    # update state
    new_state = state.apply_gradients(
        grads=grad,
        dropout_rng=new_dropout_rng,
        to_dtype=to_dtype,
    )

    # compute gradient norms over all layers, total encoder, total decoder and global for detailed monitoring
    layer_grad_norm = jax.tree_util.tree_map(jnp.linalg.norm, grad)
    logs = {
        "layer_grad_norm": layer_grad_norm,
        "encoder_grad_norm": jnp.linalg.norm(jax.tree_util.tree_leaves(layer_grad_norm["encoder"])),
        "decoder_grad_norm": jnp.linalg.norm(jax.tree_util.tree_leaves(layer_grad_norm["decoder"])),
    }
    logs["grad_norm"] = jnp.linalg.norm([logs["encoder_grad_norm"], logs["decoder_grad_norm"]])

    # compute parameter norms over all layers, total encoder, total decoder and global for detailed monitoring
    layer_param_norm = jax.tree_util.tree_map(jnp.linalg.norm, new_state.params)
    logs["layer_param_norm"] = layer_param_norm
    logs["encoder_param_norm"] = jnp.linalg.norm(jax.tree_util.tree_leaves(layer_param_norm["encoder"]))
    logs["decoder_param_norm"] = jnp.linalg.norm(jax.tree_util.tree_leaves(layer_param_norm["decoder"]))
    logs["param_norm"] = jnp.linalg.norm([logs["encoder_param_norm"], logs["decoder_param_norm"]])

    # the learning rate is looked up with the step counter of the incoming state
    metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
    metrics.update(logs)

    metrics = jax.lax.pmean(metrics, axis_name="batch")
    # metrics = to_fp32(metrics)

    return new_state, metrics
1217
+
1218
+ # Define eval fn
1219
def eval_step(params, batch):
    """Compute the evaluation loss; meant to be executed under ``jax.pmap`` with axis name "batch".

    Args:
        params: model parameters to evaluate with.
        batch: dict of model inputs for this device, including ``"labels"``.

    Returns:
        A dict ``{"loss": ...}`` holding the summed loss divided by the total
        number of labels, both reduced over the "batch" axis.
    """
    # labels are removed from the batch so the rest can be splatted into the model call
    labels = batch.pop("labels")
    logits = model(**batch, params=params, train=False)[0]
    loss, num_labels = loss_fn(logits, labels)

    # normalise the summed loss by the label count gathered over the "batch" axis
    total_labels = jax.lax.psum(num_labels, "batch")
    loss = jax.lax.psum(loss, "batch") / total_labels

    # summarize metrics
    metrics = {"loss": loss}
    metrics = jax.lax.pmean(metrics, axis_name="batch")
    # metrics = to_fp32(metrics)
    return metrics
1233
+
1234
+ # Define generation function
1235
+ gen_kwargs = {
1236
+ "max_length": training_args.generation_max_length,
1237
+ "num_beams": training_args.generation_num_beams,
1238
+ "length_penalty": training_args.generation_length_penalty,
1239
+ }
1240
+ final_gen_kwargs = {
1241
+ "max_length": training_args.final_generation_max_length,
1242
+ "num_beams": training_args.final_generation_num_beams,
1243
+ "length_penalty": training_args.generation_length_penalty,
1244
+ }
1245
+
1246
def generate_step(params, batch):
    """Generate token sequences for ``batch["inputs"]`` using ``gen_kwargs``.

    Args:
        params: parameters to generate with; they are assigned to
            ``model.params`` before generation.
        batch: dict containing the model inputs under ``"inputs"``.

    Returns:
        The ``sequences`` attribute of the generation output.
    """
    model.params = params
    return model.generate(batch["inputs"], **gen_kwargs).sequences
1250
+
1251
def final_generate_step(params, batch):
    """Generate token sequences for ``batch["inputs"]`` using ``final_gen_kwargs``.

    Args:
        params: parameters to generate with; they are assigned to
            ``model.params`` before generation.
        batch: dict containing the model inputs under ``"inputs"``.

    Returns:
        The ``sequences`` attribute of the generation output.
    """
    model.params = params
    return model.generate(batch["inputs"], **final_gen_kwargs).sequences
1255
+
1256
+ # Create parallel version of the train and eval step
1257
+ if training_args.do_train:
1258
+ p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
1259
+
1260
+ if training_args.do_eval or training_args.do_predict:
1261
+ p_eval_step = jax.pmap(eval_step, "batch")
1262
+
1263
+ if training_args.predict_with_generate:
1264
+ p_generate_step = jax.pmap(generate_step, "batch")
1265
+ p_final_generate_step = jax.pmap(final_generate_step, "batch")
1266
+
1267
def run_evaluation(step, final_step=False):
    """Evaluate on the "eval" split and log the results; a no-op unless ``do_eval`` is set.

    Args:
        step: training step the results are reported under.
        final_step: when True, generation uses ``p_final_generate_step`` with
            ``final_gen_kwargs``; otherwise ``p_generate_step`` with ``gen_kwargs``.
    """
    if not training_args.do_eval:
        return

    # ======================== Evaluating ==============================
    eval_metrics = []
    eval_preds = []
    eval_ids = []
    eval_labels = []

    # Generate eval set by sequentially sampling indices from the eval dataset and grouping by length
    eval_samples_idx = get_grouped_indices(vectorized_datasets["eval"], eval_batch_size)
    eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last_batch=False)

    for batch_idx in tqdm(eval_batch_idx, desc="Evaluating ...", position=2):
        samples = [vectorized_datasets["eval"][int(idx)] for idx in batch_idx]
        batch = data_collator(samples)
        eval_ids.extend(batch.pop("input_ids"))
        labels = batch["labels"]

        metrics = pad_shard_unpad(p_eval_step, static_return=True)(
            state.params, batch.data, min_device_batch=per_device_eval_batch_size
        )
        eval_metrics.append(metrics)

        # generation
        if training_args.predict_with_generate:
            # the generate functions only exist when predict_with_generate is set,
            # so they are looked up inside this branch
            if final_step:
                generate_fn, active_gen_kwargs = p_final_generate_step, final_gen_kwargs
            else:
                generate_fn, active_gen_kwargs = p_generate_step, gen_kwargs

            generated_ids = pad_shard_unpad(generate_fn)(
                state.params, batch.data, min_device_batch=per_device_eval_batch_size
            )
            # reshape to (samples, beams, tokens), the layout compute_metrics indexes into
            eval_preds.extend(
                jax.device_get(
                    generated_ids.reshape(-1, active_gen_kwargs["num_beams"], active_gen_kwargs["max_length"])
                )
            )
            eval_labels.extend(labels)

    # normalize eval metrics
    eval_metrics = get_metrics(eval_metrics)
    eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)
    eval_metrics = to_fp32(eval_metrics)

    # compute error rate metric and get predicted string (for debugging)
    error_rate_desc = ""
    pred_str = []
    label_str = []
    if training_args.predict_with_generate:
        error_rate_metric, pred_str, label_str = compute_metrics(eval_preds, eval_labels)
        eval_metrics.update(error_rate_metric)
        error_rate_desc = " ".join([f"Eval {key}: {value} |" for key, value in error_rate_metric.items()])

    # Print metrics and update progress bar
    desc = f"Step... ({step}/{total_train_steps} | Eval Loss: {eval_metrics['loss']} | {error_rate_desc})"
    epochs.write(desc)
    epochs.desc = desc

    # Save metrics
    write_wandb_log(eval_metrics, step, prefix="eval")
    write_wandb_pred(
        pred_str,
        label_str,
        eval_ids,
        step,
        top_ids=vectorized_datasets["eval"]["input_id"] if data_args.log_first_ids else None,
        final_step=final_step,
    )
    # if has_tensorboard and jax.process_index() == 0:
    #     write_eval_metric(summary_writer, eval_metrics, step, pred_str=pred_str)
1339
+
1340
def save_checkpoint(step):
    """Save model weights and tokenizer to the output directory and optionally push them to the Hub.

    Only the process with index 0 does any work.

    Args:
        step: training step used to label the commit message (reported in
            thousands, truncated to an integer).
    """
    if jax.process_index() != 0:
        return

    # the train state is replicated before training starts, so index 0 along the
    # leading axis of every leaf selects a single copy, which is then fetched to host
    params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
    model.save_pretrained(training_args.output_dir, params=params)
    tokenizer.save_pretrained(training_args.output_dir)

    if training_args.push_to_hub:
        repo.push_to_hub(
            commit_message=f"{wandb.run.id}: saving weights and logs of step {int(step / 1000)}k",
            blocking=False,
        )
1348
+
1349
+ # Replicate the train state on each device
1350
+ state = state.replicate()
1351
+
1352
+ logger.info("***** Running training *****")
1353
+ logger.info(f" Num examples = {num_train_samples}")
1354
+ logger.info(f" Num Epochs = {num_epochs}")
1355
+ logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
1356
+ logger.info(f" Num gradient accumulation steps = {gradient_accumulation_steps}")
1357
+ logger.info(f" Total train batch size (w. parallel & distributed) = {batch_size_per_update}")
1358
+ logger.info(f" Total optimization steps = {total_train_steps}")
1359
+ logger.info(f" Gradient checkpointing: {config.encoder.gradient_checkpointing}")
1360
+ logger.info(f" Use scan: {config.encoder.use_scan}")
1361
+ logger.info(f" Fuse matmuls: {config.encoder.fuse_matmuls}")
1362
+
1363
+ train_time = cur_step = 0
1364
+ epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
1365
+ for epoch in epochs:
1366
+ if training_args.do_train:
1367
+ # ======================== Training ================================
1368
+ train_start = time.time()
1369
+
1370
+ # Create sampling rng
1371
+ rng, input_rng = jax.random.split(rng)
1372
+
1373
+ # Generate an epoch by randomly shuffling sampling indices from the train dataset and grouping by length
1374
+ train_samples_idx = get_grouped_indices(vectorized_datasets["train"], batch_size_per_update, input_rng)
1375
+ train_batch_idx = generate_batch_splits(train_samples_idx, batch_size_per_update, drop_last_batch=True)
1376
+
1377
+ # Gather the indices for creating the batch and do a training step
1378
+ for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1), 1):
1379
+ samples = [vectorized_datasets["train"][int(idx)] for idx in batch_idx]
1380
+ batch = data_collator(samples)
1381
+ batch.pop("input_ids")
1382
+ batch = shard(batch.data)
1383
+ state, train_metric = p_train_step(state, batch)
1384
+
1385
+ cur_step = epoch * (num_train_samples // batch_size_per_update) + step
1386
+
1387
+ if cur_step % training_args.logging_steps == 0:
1388
+ # Save metrics
1389
+ train_metric = unreplicate(train_metric)
1390
+ train_time += time.time() - train_start
1391
+ # need to upcast all device arrays to fp32 for wandb logging (jnp.bfloat16 not supported) -> do this here OR in train_step
1392
+ write_wandb_log(to_fp32(train_metric), cur_step, prefix="train")
1393
+ # we won't log to tensorboard for now (it is fiddly logging param and grad norms on a layer-by-layer basis)
1394
+ # if has_tensorboard and jax.process_index() == 0:
1395
+ # write_train_metric(summary_writer, train_metrics, train_time, cur_step)
1396
+
1397
+ epochs.write(
1398
+ f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']}, Gradient Norm: {train_metric['grad_norm']})"
1399
+ )
1400
+
1401
+ if cur_step % total_train_steps == 0:
1402
+ break
1403
+
1404
+ if training_args.eval_steps and cur_step % training_args.eval_steps == 0:
1405
+ # run beam search at each eval step
1406
+ run_evaluation(cur_step, final_step=False)
1407
+
1408
+ if cur_step % training_args.save_steps == 0:
1409
+ save_checkpoint(cur_step)
1410
+
1411
+ if training_args.eval_steps == 0 and (epoch + 1) != num_epochs:
1412
+ # run evaluation at the end of the epoch if eval steps are not specified
1413
+ run_evaluation(cur_step, final_step=False)
1414
+ save_checkpoint(cur_step)
1415
+
1416
+ if training_args.do_train:
1417
+ save_checkpoint(cur_step)
1418
+
1419
+ cur_step = max_steps if max_steps > 0 else cur_step # set step to max steps so that eval happens in alignment with training
1420
+
1421
+ if training_args.do_eval:
1422
+ run_evaluation(cur_step, final_step=True)
1423
+
1424
+ # TODO: collapse 'do_predict' into the run_evaluation function
1425
+ if training_args.do_predict:
1426
+ # ======================== Prediction ==============================
1427
+ for split in test_split:
1428
+ pred_metrics = []
1429
+ pred_generations = []
1430
+ pred_ids = []
1431
+ pred_labels = []
1432
+
1433
+ # Generate prediction set by sequentially sampling indices from the current test split and grouping by length
1434
+ pred_samples_idx = get_grouped_indices(vectorized_datasets[split], eval_batch_size)
1435
+ pred_batch_idx = generate_batch_splits(pred_samples_idx, eval_batch_size, drop_last_batch=False)
1436
+
1437
+ for i, batch_idx in enumerate(tqdm(pred_batch_idx, desc=f"Predicting {split}...", position=2)):
1438
+ samples = [vectorized_datasets[split][int(idx)] for idx in batch_idx]
1439
+ batch = data_collator(samples)
1440
+ pred_ids.extend(batch.pop("input_ids"))
1441
+ labels = batch["labels"]
1442
+
1443
+ metrics = pad_shard_unpad(p_eval_step, static_return=True)(state.params, batch.data,
1444
+ min_device_batch=per_device_eval_batch_size)
1445
+ pred_metrics.append(metrics)
1446
+
1447
+ # generation
1448
+ if training_args.predict_with_generate:
1449
+ generated_ids = pad_shard_unpad(p_final_generate_step)(state.params, batch.data, min_device_batch=per_device_eval_batch_size)
1450
+ pred_generations.extend(
1451
+ jax.device_get(
1452
+ generated_ids.reshape(-1, final_gen_kwargs["num_beams"], final_gen_kwargs["max_length"])
1453
+ )
1454
+ )
1455
+ pred_labels.extend(labels)
1456
+
1457
+ # normalize eval metrics
1458
+ pred_metrics = get_metrics(pred_metrics)
1459
+ pred_metrics = jax.tree_map(jnp.mean, pred_metrics)
1460
+ pred_metrics = to_fp32(pred_metrics)
1461
+
1462
+ # compute error rate metric and get predicted string (for debugging)
1463
+ error_rate_desc = ""
1464
+ pred_str = []
1465
+ label_str = []
1466
+ if training_args.predict_with_generate:
1467
+ error_rate_metric, pred_str, label_str = compute_metrics(pred_generations, pred_labels)
1468
+ pred_metrics.update(error_rate_metric)
1469
+ error_rate_desc = " ".join([f"{split} {key}: {value} |" for key, value in error_rate_metric.items()])
1470
+
1471
+ # Print metrics and update progress bar
1472
+ desc = f"Step... ({cur_step}/{total_train_steps} | {split} Loss: {pred_metrics['loss']} | {error_rate_desc})"
1473
+ epochs.write(desc)
1474
+ epochs.desc = desc
1475
+
1476
+ # Save metrics
1477
+ write_wandb_log(pred_metrics, cur_step, prefix=split)
1478
+ write_wandb_pred(
1479
+ pred_str,
1480
+ label_str,
1481
+ pred_ids,
1482
+ cur_step,
1483
+ prefix=split,
1484
+ top_ids=vectorized_datasets[split]["input_id"] if data_args.log_first_ids else None,
1485
+ final_step=True,
1486
+ )
1487
+
1488
+
1489
+ if __name__ == "__main__":
1490
+ main()
run_spgispeech.sh ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Launches run_flax_speech_recognition_seq2seq.py with the flags listed below.
# The flags are collected in an array so each one sits on its own line
# without needing backslash line continuations.

train_args=(
  --dataset_name="esc-benchmark/esc-datasets"
  --model_name_or_path="esc-benchmark/wav2vec2-aed-pretrained"
  --dataset_config_name="spgispeech"
  --output_dir="./"
  --wandb_name="wav2vec2-aed-spgispeech"
  --wandb_project="wav2vec2-aed"
  --per_device_train_batch_size="8"
  --per_device_eval_batch_size="2"
  --learning_rate="1e-4"
  --warmup_steps="500"
  --logging_steps="25"
  --max_steps="50001"
  --eval_steps="10000"
  --save_steps="10000"
  --generation_max_length="40"
  --generation_num_beams="1"
  --final_generation_max_length="225"
  --final_generation_num_beams="14"
  --generation_length_penalty="1.6"
  --overwrite_output_dir
  --gradient_checkpointing
  --freeze_feature_encoder
  --predict_with_generate
  --do_eval
  --do_train
  --do_predict
  --push_to_hub
  --use_auth_token
)

python run_flax_speech_recognition_seq2seq.py "${train_args[@]}"
special_tokens_map.json ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": "<s>",
3
+ "cls_token": "<s>",
4
+ "eos_token": "</s>",
5
+ "mask_token": {
6
+ "content": "<mask>",
7
+ "lstrip": true,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false
11
+ },
12
+ "pad_token": "<pad>",
13
+ "sep_token": "</s>",
14
+ "unk_token": "<unk>"
15
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": false,
3
+ "bos_token": "<s>",
4
+ "cls_token": "<s>",
5
+ "eos_token": "</s>",
6
+ "errors": "replace",
7
+ "mask_token": "<mask>",
8
+ "model_max_length": 1024,
9
+ "name_or_path": "sanchit-gandhi/flax-wav2vec2-2-bart-large-scan",
10
+ "pad_token": "<pad>",
11
+ "sep_token": "</s>",
12
+ "special_tokens_map_file": null,
13
+ "tokenizer_class": "BartTokenizer",
14
+ "trim_offsets": true,
15
+ "unk_token": "<unk>"
16
+ }
vocab.json ADDED
The diff for this file is too large to render. See raw diff