Commit c11b15c
sanchit-gandhi (HF staff) committed
1 Parent(s): 6e1ee26
config.json ADDED
@@ -0,0 +1,110 @@
+ {
+   "_name_or_path": "sanchit-gandhi/flax-wav2vec2-ctc-cv9-baseline-50k",
+   "activation_dropout": 0.1,
+   "adapter_kernel_size": 3,
+   "adapter_stride": 2,
+   "add_adapter": false,
+   "apply_spec_augment": true,
+   "architectures": [
+     "Wav2Vec2ForCTC"
+   ],
+   "attention_dropout": 0.1,
+   "bos_token_id": 1,
+   "classifier_proj_size": 256,
+   "codevector_dim": 768,
+   "contrastive_logits_temperature": 0.1,
+   "conv_bias": true,
+   "conv_dim": [
+     512,
+     512,
+     512,
+     512,
+     512,
+     512,
+     512
+   ],
+   "conv_kernel": [
+     10,
+     3,
+     3,
+     3,
+     3,
+     2,
+     2
+   ],
+   "conv_stride": [
+     5,
+     2,
+     2,
+     2,
+     2,
+     2,
+     2
+   ],
+   "ctc_loss_reduction": "sum",
+   "ctc_zero_infinity": false,
+   "diversity_loss_weight": 0.1,
+   "do_stable_layer_norm": true,
+   "eos_token_id": 2,
+   "feat_extract_activation": "gelu",
+   "feat_extract_dropout": 0.0,
+   "feat_extract_norm": "layer",
+   "feat_proj_dropout": 0.0,
+   "feat_quantizer_dropout": 0.0,
+   "final_dropout": 0.0,
+   "fuse_matmuls": false,
+   "gradient_checkpointing": true,
+   "hidden_act": "gelu",
+   "hidden_dropout": 0.1,
+   "hidden_dropout_prob": 0.1,
+   "hidden_size": 1024,
+   "initializer_range": 0.02,
+   "intermediate_size": 4096,
+   "layer_norm_eps": 1e-05,
+   "layerdrop": 0.0,
+   "mask_feature_length": 10,
+   "mask_feature_min_masks": 0,
+   "mask_feature_prob": 0.0,
+   "mask_time_length": 10,
+   "mask_time_min_masks": 2,
+   "mask_time_prob": 0.1,
+   "model_type": "wav2vec2",
+   "num_adapter_layers": 3,
+   "num_attention_heads": 16,
+   "num_codevector_groups": 2,
+   "num_codevectors_per_group": 320,
+   "num_conv_pos_embedding_groups": 16,
+   "num_conv_pos_embeddings": 128,
+   "num_feat_extract_layers": 7,
+   "num_hidden_layers": 24,
+   "num_negatives": 100,
+   "output_hidden_size": 1024,
+   "pad_token_id": 0,
+   "proj_codevector_dim": 768,
+   "tdnn_dilation": [
+     1,
+     2,
+     3,
+     1,
+     1
+   ],
+   "tdnn_dim": [
+     512,
+     512,
+     512,
+     512,
+     1500
+   ],
+   "tdnn_kernel": [
+     5,
+     3,
+     3,
+     1,
+     1
+   ],
+   "transformers_version": "4.18.0.dev0",
+   "use_scan": false,
+   "use_weighted_layer_sum": false,
+   "vocab_size": 32,
+   "xvector_output_dim": 512
+ }
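The committed `config.json` is a standard 🤗 configuration file, so it can be loaded with the custom `Wav2Vec2Config` class added under `models/` below. A minimal sketch (my example, not part of the commit; it assumes the repository root is on the Python path, and the repo id is taken from `_name_or_path` above):

```python
# Minimal sketch: load the committed config.json with the custom config class.
from models.configuration_wav2vec2 import Wav2Vec2Config

config = Wav2Vec2Config.from_pretrained("sanchit-gandhi/flax-wav2vec2-ctc-cv9-baseline-50k")
print(config.model_type)         # "wav2vec2"
print(config.num_hidden_layers)  # 24
print(config.use_scan)           # False
```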
flax_model.msgpack ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0512bc1dd68b4c718e2d78943bc07ed0ecd05734e48eb2cd434b970dc477417c
+ size 1261901472
models/__init__.py ADDED
File without changes
models/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (173 Bytes).
models/__pycache__/configuration_wav2vec2.cpython-38.pyc ADDED
Binary file (16.8 kB).
models/__pycache__/modeling_flax_wav2vec2.cpython-38.pyc ADDED
Binary file (30.7 kB).
models/configuration_bart.py ADDED
@@ -0,0 +1,183 @@
+ # coding=utf-8
+ # Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """ BART model configuration"""
+ import warnings
+
+ from transformers.configuration_utils import PretrainedConfig
+ from transformers.utils import logging
+
+
+ logger = logging.get_logger(__name__)
+
+ BART_PRETRAINED_CONFIG_ARCHIVE_MAP = {
+     "facebook/bart-large": "https://huggingface.co/facebook/bart-large/resolve/main/config.json",
+     # See all BART models at https://huggingface.co/models?filter=bart
+ }
+
+
+ class BartConfig(PretrainedConfig):
+     r"""
+     This is the configuration class to store the configuration of a [`BartModel`]. It is used to instantiate a BART
+     model according to the specified arguments, defining the model architecture. Instantiating a configuration with
+     the defaults will yield a similar configuration to that of the BART
+     [facebook/bart-large](https://huggingface.co/facebook/bart-large) architecture.
+
+     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+     documentation from [`PretrainedConfig`] for more information.
+
+
+     Args:
+         vocab_size (`int`, *optional*, defaults to 50265):
+             Vocabulary size of the BART model. Defines the number of different tokens that can be represented by the
+             `inputs_ids` passed when calling [`BartModel`] or [`TFBartModel`].
+         d_model (`int`, *optional*, defaults to 1024):
+             Dimensionality of the layers and the pooler layer.
+         encoder_layers (`int`, *optional*, defaults to 12):
+             Number of encoder layers.
+         decoder_layers (`int`, *optional*, defaults to 12):
+             Number of decoder layers.
+         encoder_attention_heads (`int`, *optional*, defaults to 16):
+             Number of attention heads for each attention layer in the Transformer encoder.
+         decoder_attention_heads (`int`, *optional*, defaults to 16):
+             Number of attention heads for each attention layer in the Transformer decoder.
+         decoder_ffn_dim (`int`, *optional*, defaults to 4096):
+             Dimensionality of the "intermediate" (often named feed-forward) layer in the decoder.
+         encoder_ffn_dim (`int`, *optional*, defaults to 4096):
+             Dimensionality of the "intermediate" (often named feed-forward) layer in the encoder.
+         activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
+             The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+             `"relu"`, `"silu"` and `"gelu_new"` are supported.
+         dropout (`float`, *optional*, defaults to 0.1):
+             The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+         attention_dropout (`float`, *optional*, defaults to 0.0):
+             The dropout ratio for the attention probabilities.
+         activation_dropout (`float`, *optional*, defaults to 0.0):
+             The dropout ratio for activations inside the fully connected layer.
+         classifier_dropout (`float`, *optional*, defaults to 0.0):
+             The dropout ratio for the classifier.
+         max_position_embeddings (`int`, *optional*, defaults to 1024):
+             The maximum sequence length that this model might ever be used with. Typically set this to something large
+             just in case (e.g., 512 or 1024 or 2048).
+         init_std (`float`, *optional*, defaults to 0.02):
+             The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+         encoder_layerdrop (`float`, *optional*, defaults to 0.0):
+             The LayerDrop probability for the encoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
+             for more details.
+         decoder_layerdrop (`float`, *optional*, defaults to 0.0):
+             The LayerDrop probability for the decoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
+             for more details.
+         scale_embedding (`bool`, *optional*, defaults to `False`):
+             Scale embeddings by dividing by sqrt(d_model).
+         use_cache (`bool`, *optional*, defaults to `True`):
+             Whether or not the model should return the last key/values attentions (not used by all models).
+         num_labels (`int`, *optional*, defaults to 3):
+             The number of labels to use in [`BartForSequenceClassification`].
+         forced_eos_token_id (`int`, *optional*, defaults to 2):
+             The id of the token to force as the last generated token when `max_length` is reached. Usually set to
+             `eos_token_id`.
+         use_scan (`bool`, *optional*, defaults to `False`):
+             Whether or not to use nn.scan in the Flax Bart attention layers.
+
+     Example:
+
+     ```python
+     >>> from transformers import BartModel, BartConfig
+
+     >>> # Initializing a BART facebook/bart-large style configuration
+     >>> configuration = BartConfig()
+
+     >>> # Initializing a model from the facebook/bart-large style configuration
+     >>> model = BartModel(configuration)
+
+     >>> # Accessing the model configuration
+     >>> configuration = model.config
+     ```"""
+     model_type = "bart"
+     keys_to_ignore_at_inference = ["past_key_values"]
+     attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}
+
+     def __init__(
+         self,
+         vocab_size=50265,
+         max_position_embeddings=1024,
+         encoder_layers=12,
+         encoder_ffn_dim=4096,
+         encoder_attention_heads=16,
+         decoder_layers=12,
+         decoder_ffn_dim=4096,
+         decoder_attention_heads=16,
+         encoder_layerdrop=0.0,
+         decoder_layerdrop=0.0,
+         activation_function="gelu",
+         d_model=1024,
+         dropout=0.1,
+         attention_dropout=0.0,
+         activation_dropout=0.0,
+         init_std=0.02,
+         classifier_dropout=0.0,
+         scale_embedding=False,
+         use_cache=True,
+         use_scan=False,
+         fuse_matmuls=False,
+         num_labels=3,
+         pad_token_id=1,
+         bos_token_id=0,
+         eos_token_id=2,
+         is_encoder_decoder=True,
+         decoder_start_token_id=2,
+         forced_eos_token_id=2,
+         **kwargs
+     ):
+         self.vocab_size = vocab_size
+         self.max_position_embeddings = max_position_embeddings
+         self.d_model = d_model
+         self.encoder_ffn_dim = encoder_ffn_dim
+         self.encoder_layers = encoder_layers
+         self.encoder_attention_heads = encoder_attention_heads
+         self.decoder_ffn_dim = decoder_ffn_dim
+         self.decoder_layers = decoder_layers
+         self.decoder_attention_heads = decoder_attention_heads
+         self.dropout = dropout
+         self.attention_dropout = attention_dropout
+         self.activation_dropout = activation_dropout
+         self.activation_function = activation_function
+         self.init_std = init_std
+         self.encoder_layerdrop = encoder_layerdrop
+         self.decoder_layerdrop = decoder_layerdrop
+         self.classifier_dropout = classifier_dropout
+         self.use_cache = use_cache
+         self.use_scan = use_scan
+         self.fuse_matmuls = fuse_matmuls
+         self.num_hidden_layers = encoder_layers
+         self.scale_embedding = scale_embedding  # scale factor will be sqrt(d_model) if True
+
+         super().__init__(
+             num_labels=num_labels,
+             pad_token_id=pad_token_id,
+             bos_token_id=bos_token_id,
+             eos_token_id=eos_token_id,
+             is_encoder_decoder=is_encoder_decoder,
+             decoder_start_token_id=decoder_start_token_id,
+             forced_eos_token_id=forced_eos_token_id,
+             **kwargs,
+         )
+
+         # ensure backward compatibility for BART CNN models
+         if self.forced_bos_token_id is None and kwargs.get("force_bos_token_to_be_generated", False):
+             self.forced_bos_token_id = self.bos_token_id
+             warnings.warn(
+                 f"Please make sure the config includes `forced_bos_token_id={self.bos_token_id}` in future versions. "
+                 "The config can simply be saved and uploaded again to be fixed."
+             )
models/configuration_speech_encoder_decoder.py ADDED
@@ -0,0 +1,121 @@
+ # coding=utf-8
+ # Copyright 2021 The HuggingFace Inc. team.
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import copy
+
+ from transformers.configuration_utils import PretrainedConfig
+ from transformers.utils import logging
+ from models.configuration_wav2vec2 import Wav2Vec2Config
+ from models.configuration_bart import BartConfig
+ from transformers import AutoConfig
+
+
+ logger = logging.get_logger(__name__)
+
+
+ class SpeechEncoderDecoderConfig(PretrainedConfig):
+     r"""
+     [`SpeechEncoderDecoderConfig`] is the configuration class to store the configuration of a
+     [`SpeechEncoderDecoderModel`]. It is used to instantiate an Encoder Decoder model according to the specified
+     arguments, defining the encoder and decoder configs.
+
+     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+     documentation from [`PretrainedConfig`] for more information.
+
+     Args:
+         kwargs (*optional*):
+             Dictionary of keyword arguments. Notably:
+
+                 - **encoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that
+                   defines the encoder config.
+                 - **decoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that
+                   defines the decoder config.
+
+     Examples:
+
+     ```python
+     >>> from transformers import BertConfig, Wav2Vec2Config, SpeechEncoderDecoderConfig, SpeechEncoderDecoderModel
+
+     >>> # Initializing a Wav2Vec2 & BERT style configuration
+     >>> config_encoder = Wav2Vec2Config()
+     >>> config_decoder = BertConfig()
+
+     >>> config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(config_encoder, config_decoder)
+
+     >>> # Initializing a Wav2Vec2Bert model from a Wav2Vec2 & bert-base-uncased style configurations
+     >>> model = SpeechEncoderDecoderModel(config=config)
+
+     >>> # Accessing the model configuration
+     >>> config_encoder = model.config.encoder
+     >>> config_decoder = model.config.decoder
+     >>> # set decoder config to causal lm
+     >>> config_decoder.is_decoder = True
+     >>> config_decoder.add_cross_attention = True
+
+     >>> # Saving the model, including its configuration
+     >>> model.save_pretrained("my-model")
+
+     >>> # loading model and config from pretrained folder
+     >>> encoder_decoder_config = SpeechEncoderDecoderConfig.from_pretrained("my-model")
+     >>> model = SpeechEncoderDecoderModel.from_pretrained("my-model", config=encoder_decoder_config)
+     ```"""
+     model_type = "speech-encoder-decoder"
+     is_composition = True
+
+     def __init__(self, **kwargs):
+         super().__init__(**kwargs)
+         if "encoder" not in kwargs or "decoder" not in kwargs:
+             raise ValueError(
+                 f"A configuration of type {self.model_type} cannot be instantiated because both `encoder` and "
+                 f"`decoder` sub-configurations were not passed, only {kwargs}"
+             )
+
+         encoder_config = kwargs.pop("encoder")
+         decoder_config = kwargs.pop("decoder")
+
+         # TODO: Load configs from AutoConfig (as done in Transformers 🤗)
+         self.encoder = Wav2Vec2Config(**encoder_config)
+         self.decoder = BartConfig(**decoder_config)
+         self.is_encoder_decoder = True
+
+     @classmethod
+     def from_encoder_decoder_configs(
+         cls, encoder_config: PretrainedConfig, decoder_config: PretrainedConfig, **kwargs
+     ) -> PretrainedConfig:
+         r"""
+         Instantiate a [`SpeechEncoderDecoderConfig`] (or a derived class) from a pre-trained encoder model
+         configuration and decoder model configuration.
+
+         Returns:
+             [`SpeechEncoderDecoderConfig`]: An instance of a configuration object
+         """
+         logger.info("Setting `config.is_decoder=True` and `config.add_cross_attention=True` for decoder_config")
+         decoder_config.is_decoder = True
+         decoder_config.add_cross_attention = True
+
+         return cls(encoder=encoder_config.to_dict(), decoder=decoder_config.to_dict(), **kwargs)
+
+     def to_dict(self):
+         """
+         Serializes this instance to a Python dictionary. Override the default *to_dict()* from *PretrainedConfig*.
+
+         Returns:
+             `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
+         """
+         output = copy.deepcopy(self.__dict__)
+         output["encoder"] = self.encoder.to_dict()
+         output["decoder"] = self.decoder.to_dict()
+         output["model_type"] = self.__class__.model_type
+         return output
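For illustration, a minimal sketch (my example, not part of the commit) of composing the two local configs with `from_encoder_decoder_configs`; it mirrors the docstring example above but uses the classes committed in this change:

```python
# Minimal sketch: build a speech-encoder-decoder config from the local classes
# (run from the repository root so the `models` package is importable).
from models.configuration_bart import BartConfig
from models.configuration_speech_encoder_decoder import SpeechEncoderDecoderConfig
from models.configuration_wav2vec2 import Wav2Vec2Config

encoder_config = Wav2Vec2Config(add_adapter=True)  # adapter downsamples the encoder output
decoder_config = BartConfig()

config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(encoder_config, decoder_config)
assert config.decoder.is_decoder and config.decoder.add_cross_attention
print(config.to_dict()["model_type"])  # "speech-encoder-decoder"
```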
models/configuration_wav2vec2.py ADDED
@@ -0,0 +1,344 @@
+ # coding=utf-8
+ # Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """ Wav2Vec2 model configuration"""
+
+ import functools
+ import operator
+
+ from transformers.configuration_utils import PretrainedConfig
+ from transformers.utils import logging
+
+
+ logger = logging.get_logger(__name__)
+
+ WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP = {
+     "facebook/wav2vec2-base-960h": "https://huggingface.co/facebook/wav2vec2-base-960h/resolve/main/config.json",
+     # See all Wav2Vec2 models at https://huggingface.co/models?filter=wav2vec2
+ }
+
+
+ class Wav2Vec2Config(PretrainedConfig):
+     r"""
+     This is the configuration class to store the configuration of a [`Wav2Vec2Model`]. It is used to instantiate a
+     Wav2Vec2 model according to the specified arguments, defining the model architecture. Instantiating a
+     configuration with the defaults will yield a similar configuration to that of the Wav2Vec2
+     [facebook/wav2vec2-base-960h](https://huggingface.co/facebook/wav2vec2-base-960h) architecture.
+
+     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+     documentation from [`PretrainedConfig`] for more information.
+
+
+     Args:
+         vocab_size (`int`, *optional*, defaults to 32):
+             Vocabulary size of the Wav2Vec2 model. Defines the number of different tokens that can be represented by
+             the `inputs_ids` passed when calling [`Wav2Vec2Model`] or [`TFWav2Vec2Model`].
+         hidden_size (`int`, *optional*, defaults to 768):
+             Dimensionality of the encoder layers and the pooler layer.
+         num_hidden_layers (`int`, *optional*, defaults to 12):
+             Number of hidden layers in the Transformer encoder.
+         num_attention_heads (`int`, *optional*, defaults to 12):
+             Number of attention heads for each attention layer in the Transformer encoder.
+         intermediate_size (`int`, *optional*, defaults to 3072):
+             Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+         hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+             The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+             `"relu"`, `"selu"` and `"gelu_new"` are supported.
+         hidden_dropout (`float`, *optional*, defaults to 0.1):
+             The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+         attention_dropout (`float`, *optional*, defaults to 0.1):
+             The dropout ratio for the attention probabilities.
+         final_dropout (`float`, *optional*, defaults to 0.1):
+             The dropout probability for the final projection layer of [`Wav2Vec2ForCTC`].
+         initializer_range (`float`, *optional*, defaults to 0.02):
+             The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+         layer_norm_eps (`float`, *optional*, defaults to 1e-12):
+             The epsilon used by the layer normalization layers.
+         feat_extract_norm (`str`, *optional*, defaults to `"group"`):
+             The norm to be applied to 1D convolutional layers in the feature encoder. One of `"group"` for group
+             normalization of only the first 1D convolutional layer or `"layer"` for layer normalization of all 1D
+             convolutional layers.
+         feat_proj_dropout (`float`, *optional*, defaults to 0.0):
+             The dropout probability for the output of the feature encoder.
+         feat_extract_activation (`str`, *optional*, defaults to `"gelu"`):
+             The non-linear activation function (function or string) in the 1D convolutional layers of the feature
+             extractor. If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"` are supported.
+         feat_quantizer_dropout (`float`, *optional*, defaults to 0.0):
+             The dropout probability for the quantized feature encoder states used by the quantizer.
+         conv_dim (`Tuple[int]`, *optional*, defaults to `(512, 512, 512, 512, 512, 512, 512)`):
+             A tuple of integers defining the number of input and output channels of each 1D convolutional layer in
+             the feature encoder. The length of *conv_dim* defines the number of 1D convolutional layers.
+         conv_stride (`Tuple[int]`, *optional*, defaults to `(5, 2, 2, 2, 2, 2, 2)`):
+             A tuple of integers defining the stride of each 1D convolutional layer in the feature encoder. The length
+             of *conv_stride* defines the number of convolutional layers and has to match the length of *conv_dim*.
+         conv_kernel (`Tuple[int]`, *optional*, defaults to `(10, 3, 3, 3, 3, 2, 2)`):
+             A tuple of integers defining the kernel size of each 1D convolutional layer in the feature encoder. The
+             length of *conv_kernel* defines the number of convolutional layers and has to match the length of
+             *conv_dim*.
+         conv_bias (`bool`, *optional*, defaults to `False`):
+             Whether the 1D convolutional layers have a bias.
+         num_conv_pos_embeddings (`int`, *optional*, defaults to 128):
+             Number of convolutional positional embeddings. Defines the kernel size of the 1D convolutional positional
+             embeddings layer.
+         num_conv_pos_embedding_groups (`int`, *optional*, defaults to 16):
+             Number of groups of the 1D convolutional positional embeddings layer.
+         do_stable_layer_norm (`bool`, *optional*, defaults to `False`):
+             Whether to apply the *stable* layer norm architecture of the Transformer encoder. `do_stable_layer_norm
+             is True` corresponds to applying layer norm before the attention layer, whereas `do_stable_layer_norm is
+             False` corresponds to applying layer norm after the attention layer.
+         apply_spec_augment (`bool`, *optional*, defaults to `True`):
+             Whether to apply *SpecAugment* data augmentation to the outputs of the feature encoder. For reference see
+             [SpecAugment: A Simple Data Augmentation Method for Automatic Speech
+             Recognition](https://arxiv.org/abs/1904.08779).
+         mask_time_prob (`float`, *optional*, defaults to 0.05):
+             Percentage (between 0 and 1) of all feature vectors along the time axis which will be masked. The masking
+             procedure generates `mask_time_prob*len(time_axis)/mask_time_length` independent masks over the axis. If
+             reasoning from the probability of each feature vector to be chosen as the start of the vector span to be
+             masked, *mask_time_prob* should be `prob_vector_start*mask_time_length`. Note that overlap may decrease
+             the actual percentage of masked vectors. This is only relevant if `apply_spec_augment is True`.
+         mask_time_length (`int`, *optional*, defaults to 10):
+             Length of vector span along the time axis.
+         mask_time_min_masks (`int`, *optional*, defaults to 2):
+             The minimum number of masks of length `mask_time_length` generated along the time axis, each time step,
+             irrespective of `mask_time_prob`. Only relevant if `mask_time_prob*len(time_axis)/mask_time_length <
+             mask_time_min_masks`.
+         mask_feature_prob (`float`, *optional*, defaults to 0.0):
+             Percentage (between 0 and 1) of all feature vectors along the feature axis which will be masked. The
+             masking procedure generates `mask_feature_prob*len(feature_axis)/mask_feature_length` independent masks
+             over the axis. If reasoning from the probability of each feature vector to be chosen as the start of the
+             vector span to be masked, *mask_feature_prob* should be `prob_vector_start*mask_feature_length`. Note
+             that overlap may decrease the actual percentage of masked vectors. This is only relevant if
+             `apply_spec_augment is True`.
+         mask_feature_length (`int`, *optional*, defaults to 10):
+             Length of vector span along the feature axis.
+         mask_feature_min_masks (`int`, *optional*, defaults to 0):
+             The minimum number of masks of length `mask_feature_length` generated along the feature axis, each time
+             step, irrespective of `mask_feature_prob`. Only relevant if
+             `mask_feature_prob*len(feature_axis)/mask_feature_length < mask_feature_min_masks`.
+         num_codevectors_per_group (`int`, *optional*, defaults to 320):
+             Number of entries in each quantization codebook (group).
+         num_codevector_groups (`int`, *optional*, defaults to 2):
+             Number of codevector groups for product codevector quantization.
+         contrastive_logits_temperature (`float`, *optional*, defaults to 0.1):
+             The temperature *kappa* in the contrastive loss.
+         num_negatives (`int`, *optional*, defaults to 100):
+             Number of negative samples for the contrastive loss.
+         codevector_dim (`int`, *optional*, defaults to 256):
+             Dimensionality of the quantized feature vectors.
+         proj_codevector_dim (`int`, *optional*, defaults to 256):
+             Dimensionality of the final projection of both the quantized and the transformer features.
+         diversity_loss_weight (`float`, *optional*, defaults to 0.1):
+             The weight of the codebook diversity loss component.
+         ctc_loss_reduction (`str`, *optional*, defaults to `"sum"`):
+             Specifies the reduction to apply to the output of `torch.nn.CTCLoss`. Only relevant when training an
+             instance of [`Wav2Vec2ForCTC`].
+         ctc_zero_infinity (`bool`, *optional*, defaults to `False`):
+             Whether to zero infinite losses and the associated gradients of `torch.nn.CTCLoss`. Infinite losses
+             mainly occur when the inputs are too short to be aligned to the targets. Only relevant when training an
+             instance of [`Wav2Vec2ForCTC`].
+         use_weighted_layer_sum (`bool`, *optional*, defaults to `False`):
+             Whether to use a weighted average of layer outputs with learned weights. Only relevant when using an
+             instance of [`Wav2Vec2ForSequenceClassification`].
+         classifier_proj_size (`int`, *optional*, defaults to 256):
+             Dimensionality of the projection before token mean-pooling for classification.
+         tdnn_dim (`Tuple[int]`, *optional*, defaults to `(512, 512, 512, 512, 1500)`):
+             A tuple of integers defining the number of output channels of each 1D convolutional layer in the *TDNN*
+             module of the *XVector* model. The length of *tdnn_dim* defines the number of *TDNN* layers.
+         tdnn_kernel (`Tuple[int]`, *optional*, defaults to `(5, 3, 3, 1, 1)`):
+             A tuple of integers defining the kernel size of each 1D convolutional layer in the *TDNN* module of the
+             *XVector* model. The length of *tdnn_kernel* has to match the length of *tdnn_dim*.
+         tdnn_dilation (`Tuple[int]`, *optional*, defaults to `(1, 2, 3, 1, 1)`):
+             A tuple of integers defining the dilation factor of each 1D convolutional layer in the *TDNN* module of
+             the *XVector* model. The length of *tdnn_dilation* has to match the length of *tdnn_dim*.
+         xvector_output_dim (`int`, *optional*, defaults to 512):
+             Dimensionality of the *XVector* embedding vectors.
+         add_adapter (`bool`, *optional*, defaults to `False`):
+             Whether a convolutional network should be stacked on top of the Wav2Vec2 Encoder. Can be very useful for
+             warm-starting Wav2Vec2 for SpeechEncoderDecoder models.
+         adapter_kernel_size (`int`, *optional*, defaults to 3):
+             Kernel size of the convolutional layers in the adapter network. Only relevant if `add_adapter is True`.
+         adapter_stride (`int`, *optional*, defaults to 2):
+             Stride of the convolutional layers in the adapter network. Only relevant if `add_adapter is True`.
+         num_adapter_layers (`int`, *optional*, defaults to 3):
+             Number of convolutional layers that should be used in the adapter network. Only relevant if `add_adapter
+             is True`.
+         output_hidden_size (`int`, *optional*):
+             Dimensionality of the encoder output layer. If not defined, this defaults to *hidden_size*. Only relevant
+             if `add_adapter is True`.
+         use_scan (`bool`, *optional*, defaults to `False`):
+             Whether or not to use nn.scan in the Flax Wav2Vec2 transformer layers.
+
+     Example:
+
+     ```python
+     >>> from transformers import Wav2Vec2Model, Wav2Vec2Config
+
+     >>> # Initializing a Wav2Vec2 facebook/wav2vec2-base-960h style configuration
+     >>> configuration = Wav2Vec2Config()
+
+     >>> # Initializing a model from the facebook/wav2vec2-base-960h style configuration
+     >>> model = Wav2Vec2Model(configuration)
+
+     >>> # Accessing the model configuration
+     >>> configuration = model.config
+     ```"""
+     model_type = "wav2vec2"
+
+     def __init__(
+         self,
+         vocab_size=32,
+         hidden_size=768,
+         num_hidden_layers=12,
+         num_attention_heads=12,
+         intermediate_size=3072,
+         hidden_act="gelu",
+         hidden_dropout=0.1,
+         activation_dropout=0.1,
+         attention_dropout=0.1,
+         feat_proj_dropout=0.0,
+         feat_quantizer_dropout=0.0,
+         final_dropout=0.1,
+         layerdrop=0.1,
+         initializer_range=0.02,
+         layer_norm_eps=1e-5,
+         feat_extract_norm="group",
+         feat_extract_activation="gelu",
+         conv_dim=(512, 512, 512, 512, 512, 512, 512),
+         conv_stride=(5, 2, 2, 2, 2, 2, 2),
+         conv_kernel=(10, 3, 3, 3, 3, 2, 2),
+         conv_bias=False,
+         num_conv_pos_embeddings=128,
+         num_conv_pos_embedding_groups=16,
+         do_stable_layer_norm=False,
+         apply_spec_augment=True,
+         mask_time_prob=0.05,
+         mask_time_length=10,
+         mask_time_min_masks=2,
+         mask_feature_prob=0.0,
+         mask_feature_length=10,
+         mask_feature_min_masks=0,
+         num_codevectors_per_group=320,
+         num_codevector_groups=2,
+         contrastive_logits_temperature=0.1,
+         num_negatives=100,
+         codevector_dim=256,
+         proj_codevector_dim=256,
+         diversity_loss_weight=0.1,
+         ctc_loss_reduction="sum",
+         ctc_zero_infinity=False,
+         use_weighted_layer_sum=False,
+         classifier_proj_size=256,
+         tdnn_dim=(512, 512, 512, 512, 1500),
+         tdnn_kernel=(5, 3, 3, 1, 1),
+         tdnn_dilation=(1, 2, 3, 1, 1),
+         xvector_output_dim=512,
+         pad_token_id=0,
+         bos_token_id=1,
+         eos_token_id=2,
+         add_adapter=False,
+         adapter_kernel_size=3,
+         adapter_stride=2,
+         num_adapter_layers=3,
+         output_hidden_size=None,
+         use_scan=False,
+         fuse_matmuls=False,
+         **kwargs
+     ):
+         super().__init__(**kwargs, pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id)
+         self.hidden_size = hidden_size
+         self.feat_extract_norm = feat_extract_norm
+         self.feat_extract_activation = feat_extract_activation
+         self.conv_dim = list(conv_dim)
+         self.conv_stride = list(conv_stride)
+         self.conv_kernel = list(conv_kernel)
+         self.conv_bias = conv_bias
+         self.num_conv_pos_embeddings = num_conv_pos_embeddings
+         self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups
+         self.num_feat_extract_layers = len(self.conv_dim)
+         self.num_hidden_layers = num_hidden_layers
+         self.intermediate_size = intermediate_size
+         self.hidden_act = hidden_act
+         self.num_attention_heads = num_attention_heads
+         self.hidden_dropout = hidden_dropout
+         self.attention_dropout = attention_dropout
+         self.activation_dropout = activation_dropout
+         self.feat_proj_dropout = feat_proj_dropout
+         self.final_dropout = final_dropout
+         self.layerdrop = layerdrop
+         self.layer_norm_eps = layer_norm_eps
+         self.initializer_range = initializer_range
+         self.vocab_size = vocab_size
+         self.do_stable_layer_norm = do_stable_layer_norm
+         self.use_weighted_layer_sum = use_weighted_layer_sum
+         self.use_scan = use_scan
+         self.fuse_matmuls = fuse_matmuls
+
+         if (
+             (len(self.conv_stride) != self.num_feat_extract_layers)
+             or (len(self.conv_kernel) != self.num_feat_extract_layers)
+             or (len(self.conv_dim) != self.num_feat_extract_layers)
+         ):
+             raise ValueError(
+                 "Configuration for convolutional layers is incorrect. "
+                 "It is required that `len(config.conv_dim)` == `len(config.conv_stride)` == `len(config.conv_kernel)`, "
+                 f"but is `len(config.conv_dim) = {len(self.conv_dim)}`, `len(config.conv_stride) "
+                 f"= {len(self.conv_stride)}`, `len(config.conv_kernel) = {len(self.conv_kernel)}`."
+             )
+
+         # fine-tuning config parameters for SpecAugment: https://arxiv.org/abs/1904.08779
+         self.apply_spec_augment = apply_spec_augment
+         self.mask_time_prob = mask_time_prob
+         self.mask_time_length = mask_time_length
+         self.mask_time_min_masks = mask_time_min_masks
+         self.mask_feature_prob = mask_feature_prob
+         self.mask_feature_length = mask_feature_length
+         self.mask_feature_min_masks = mask_feature_min_masks
+
+         # parameters for pretraining with codevector quantized representations
+         self.num_codevectors_per_group = num_codevectors_per_group
+         self.num_codevector_groups = num_codevector_groups
+         self.contrastive_logits_temperature = contrastive_logits_temperature
+         self.feat_quantizer_dropout = feat_quantizer_dropout
+         self.num_negatives = num_negatives
+         self.codevector_dim = codevector_dim
+         self.proj_codevector_dim = proj_codevector_dim
+         self.diversity_loss_weight = diversity_loss_weight
+
+         # ctc loss
+         self.ctc_loss_reduction = ctc_loss_reduction
+         self.ctc_zero_infinity = ctc_zero_infinity
+
+         # adapter
+         self.add_adapter = add_adapter
+         self.adapter_kernel_size = adapter_kernel_size
+         self.adapter_stride = adapter_stride
+         self.num_adapter_layers = num_adapter_layers
+         self.output_hidden_size = output_hidden_size or hidden_size
+
+         # SequenceClassification-specific parameter. Feel free to ignore for other classes.
+         self.classifier_proj_size = classifier_proj_size
+
+         # XVector-specific parameters. Feel free to ignore for other classes.
+         self.tdnn_dim = list(tdnn_dim)
+         self.tdnn_kernel = list(tdnn_kernel)
+         self.tdnn_dilation = list(tdnn_dilation)
+         self.xvector_output_dim = xvector_output_dim
+
+     @property
+     def inputs_to_logits_ratio(self):
+         return functools.reduce(operator.mul, self.conv_stride, 1)
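As a quick sanity check on `inputs_to_logits_ratio` (my own worked example, not part of the commit): with the default strides `(5, 2, 2, 2, 2, 2, 2)` the product is 5 · 2⁶ = 320, i.e. one encoder frame per 320 input samples, which at a 16 kHz sampling rate is a 20 ms hop (50 frames per second):

```python
# Minimal sketch of what inputs_to_logits_ratio computes for the default strides.
import functools
import operator

conv_stride = (5, 2, 2, 2, 2, 2, 2)
ratio = functools.reduce(operator.mul, conv_stride, 1)
print(ratio)           # 320 input samples per output frame
print(16_000 / ratio)  # 50.0 output frames per second at 16 kHz (20 ms hop)
```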
models/modeling_flax_bart.py ADDED
@@ -0,0 +1,815 @@
1
+ # coding=utf-8
2
+ # Copyright 2021 The Fairseq Authors and The Google Flax Team Authors And The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Flax Bart model."""
16
+
17
+ import math
18
+ import random
19
+ from functools import partial
20
+ from typing import Optional, Tuple
21
+
22
+ import numpy as np
23
+
24
+ import flax.linen as nn
25
+ import jax
26
+ import jax.numpy as jnp
27
+ from flax.core.frozen_dict import FrozenDict, unfreeze
28
+ from flax.linen import combine_masks, make_causal_mask
29
+ from flax.linen import partitioning as nn_partitioning
30
+ from flax.linen.attention import dot_product_attention_weights
31
+ from jax import lax
32
+ from jax.random import PRNGKey
33
+
34
+ from transformers.modeling_flax_outputs import (
35
+ FlaxBaseModelOutputWithPastAndCrossAttentions,
36
+ FlaxCausalLMOutputWithCrossAttentions,
37
+ )
38
+ from transformers.modeling_flax_utils import ACT2FN, FlaxPreTrainedModel
39
+
40
+ from models.configuration_bart import BartConfig
41
+
42
+
43
+ scan_with_axes = nn_partitioning.scan_with_axes
44
+ remat = nn_partitioning.remat
45
+
46
+
47
+ def shift_tokens_right(input_ids: np.array, pad_token_id: int, decoder_start_token_id: int) -> np.ndarray:
48
+ """
49
+ Shift input ids one token to the right.
50
+ """
51
+ shifted_input_ids = np.zeros_like(input_ids)
52
+ shifted_input_ids[:, 1:] = input_ids[:, :-1]
53
+ shifted_input_ids[:, 0] = decoder_start_token_id
54
+
55
+ shifted_input_ids = np.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
56
+ return shifted_input_ids
57
+
58
+
59
+ class FlaxBartAttention(nn.Module):
60
+ config: BartConfig
61
+ embed_dim: int
62
+ num_heads: int
63
+ dropout: float = 0.0
64
+ causal: bool = False
65
+ bias: bool = True
66
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
67
+
68
+ def setup(self) -> None:
69
+ self.head_dim = self.embed_dim // self.num_heads
70
+ if self.head_dim * self.num_heads != self.embed_dim:
71
+ raise ValueError(
72
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
73
+ f" and `num_heads`: {self.num_heads})."
74
+ )
75
+
76
+ dense = partial(
77
+ nn.Dense,
78
+ self.embed_dim,
79
+ use_bias=self.bias,
80
+ dtype=self.dtype,
81
+ kernel_init=jax.nn.initializers.normal(self.config.init_std),
82
+ )
83
+
84
+ self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
85
+
86
+ self.fused_proj = nn.Dense(
87
+ self.embed_dim * 3,
88
+ use_bias=self.bias,
89
+ dtype=self.dtype,
90
+ kernel_init=jax.nn.initializers.normal(self.config.init_std),
91
+ )
92
+
93
+ self.fused_key_value = nn.Dense(
94
+ self.embed_dim * 2,
95
+ use_bias=self.bias,
96
+ dtype=self.dtype,
97
+ kernel_init=jax.nn.initializers.normal(self.config.init_std),
98
+ )
99
+
100
+ self.out_proj = dense()
101
+
102
+ self.dropout_layer = nn.Dropout(rate=self.dropout)
103
+
104
+ if self.causal:
105
+ self.causal_mask = make_causal_mask(
106
+ jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool"
107
+ )
108
+
109
+ def _split_heads(self, hidden_states):
110
+ return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))
111
+
112
+ def _merge_heads(self, hidden_states):
113
+ return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))
114
+
115
+ @nn.compact
116
+ def _concatenate_to_cache(self, key, value, query, attention_mask):
117
+ """
118
+ This function takes projected key, value states from a single input token and concatenates the states to cached
119
+ states from previous steps. This function is slighly adapted from the official Flax repository:
120
+ https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
121
+ """
122
+ # detect if we're initializing by absence of existing cache data.
123
+ is_initialized = self.has_variable("cache", "cached_key")
124
+ cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
125
+ cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
126
+ cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
127
+
128
+ if is_initialized:
129
+ *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
130
+ # update key, value caches with our new 1d spatial slices
131
+ cur_index = cache_index.value
132
+ indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
133
+ key = lax.dynamic_update_slice(cached_key.value, key, indices)
134
+ value = lax.dynamic_update_slice(cached_value.value, value, indices)
135
+ cached_key.value = key
136
+ cached_value.value = value
137
+ num_updated_cache_vectors = query.shape[1]
138
+ cache_index.value = cache_index.value + num_updated_cache_vectors
139
+ # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements.
140
+ pad_mask = jnp.broadcast_to(
141
+ jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
142
+ tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
143
+ )
144
+ attention_mask = combine_masks(pad_mask, attention_mask)
145
+ return key, value, attention_mask
146
+
147
+ def __call__(
148
+ self,
149
+ hidden_states: jnp.ndarray,
150
+ key_value_states: Optional[jnp.ndarray] = None,
151
+ attention_mask: Optional[jnp.ndarray] = None,
152
+ init_cache: bool = False,
153
+ deterministic: bool = True,
154
+ ) -> Tuple[jnp.ndarray]:
155
+ """Input shape: Batch x Time x Channel"""
156
+
157
+ # if key_value_states are provided this layer is used as a cross-attention layer
158
+ # for the decoder
159
+ is_cross_attention = key_value_states is not None
160
+ batch_size = hidden_states.shape[0]
161
+
162
+ if self.config.fuse_matmuls:
163
+ # get key, value proj
164
+ if is_cross_attention:
165
+ # get query proj
166
+ query_states = self.q_proj(hidden_states)
167
+ # cross_attentions
168
+ attention_states = self.fused_key_value(key_value_states)
169
+ key_states, value_states = jnp.split(attention_states, 2, axis=-1)
170
+ else:
171
+ attention_states = self.fused_proj(hidden_states)
172
+ query_states, key_states, value_states = jnp.split(attention_states, 3, axis=-1)
173
+
174
+ else:
175
+ # get query proj
176
+ query_states = self.q_proj(hidden_states)
177
+ # get key, value proj
178
+ if is_cross_attention:
179
+ # cross_attentions
180
+ key_states = self.k_proj(key_value_states)
181
+ value_states = self.v_proj(key_value_states)
182
+ else:
183
+ # self_attention
184
+ key_states = self.k_proj(hidden_states)
185
+ value_states = self.v_proj(hidden_states)
186
+
187
+ query_states = self._split_heads(query_states)
188
+ key_states = self._split_heads(key_states)
189
+ value_states = self._split_heads(value_states)
190
+
191
+ # handle cache prepare causal attention mask
192
+ if self.causal:
193
+ query_length, key_length = query_states.shape[1], key_states.shape[1]
194
+ if self.has_variable("cache", "cached_key"):
195
+ mask_shift = self.variables["cache"]["cache_index"]
196
+ max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
197
+ causal_mask = lax.dynamic_slice(
198
+ self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
199
+ )
200
+ else:
201
+ causal_mask = self.causal_mask[:, :, :query_length, :key_length]
202
+ causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
203
+
204
+ # combine masks if needed
205
+ if attention_mask is not None and self.causal:
206
+ attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
207
+ attention_mask = combine_masks(attention_mask, causal_mask)
208
+ elif self.causal:
209
+ attention_mask = causal_mask
210
+ elif attention_mask is not None:
211
+ attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
212
+
213
+ # During fast autoregressive decoding, we feed one position at a time,
214
+ # and cache the keys and values step by step.
215
+ if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
216
+ key_states, value_states, attention_mask = self._concatenate_to_cache(
217
+ key_states, value_states, query_states, attention_mask
218
+ )
219
+
220
+ # Convert the boolean attention mask to an attention bias.
221
+ if attention_mask is not None:
222
+ # attention mask in the form of attention bias
223
+ attention_bias = lax.select(
224
+ attention_mask > 0,
225
+ jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
226
+ jnp.full(attention_mask.shape, float("-inf")).astype(self.dtype),
227
+ )
228
+ else:
229
+ attention_bias = None
230
+
231
+ dropout_rng = None
232
+ if not deterministic and self.dropout > 0.0:
233
+ dropout_rng = self.make_rng("dropout")
234
+
235
+ attn_weights = dot_product_attention_weights(
236
+ query_states,
237
+ key_states,
238
+ bias=attention_bias,
239
+ dropout_rng=dropout_rng,
240
+ dropout_rate=self.dropout,
241
+ broadcast_dropout=True,
242
+ deterministic=deterministic,
243
+ dtype=self.dtype,
244
+ precision=None,
245
+ )
246
+
247
+ attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
248
+ attn_output = self._merge_heads(attn_output)
249
+ attn_output = self.out_proj(attn_output)
250
+
251
+ return attn_output, attn_weights
252
+
253
+
254
+ class FlaxBartDecoderLayer(nn.Module):
255
+ config: BartConfig
256
+ dtype: jnp.dtype = jnp.float32
257
+
258
+ def setup(self) -> None:
259
+ self.embed_dim = self.config.d_model
260
+ self.self_attn = FlaxBartAttention(
261
+ config=self.config,
262
+ embed_dim=self.embed_dim,
263
+ num_heads=self.config.decoder_attention_heads,
264
+ dropout=self.config.attention_dropout,
265
+ causal=True,
266
+ dtype=self.dtype,
267
+ )
268
+ self.dropout_layer = nn.Dropout(rate=self.config.dropout)
269
+ self.activation_fn = ACT2FN[self.config.activation_function]
270
+ self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
271
+
272
+ self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
273
+ self.encoder_attn = FlaxBartAttention(
274
+ config=self.config,
275
+ embed_dim=self.embed_dim,
276
+ num_heads=self.config.decoder_attention_heads,
277
+ dropout=self.config.attention_dropout,
278
+ dtype=self.dtype,
279
+ )
280
+ self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
281
+ self.fc1 = nn.Dense(
282
+ self.config.encoder_ffn_dim,
283
+ dtype=self.dtype,
284
+ kernel_init=jax.nn.initializers.normal(self.config.init_std),
285
+ )
286
+ self.fc2 = nn.Dense(
287
+ self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
288
+ )
289
+ self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
290
+
291
+ def __call__(
292
+ self,
293
+ hidden_states: jnp.ndarray,
294
+ attention_mask: jnp.ndarray,
295
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
296
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
297
+ init_cache: bool = False,
298
+ output_attentions: bool = True,
299
+ deterministic: bool = True,
300
+ ) -> Tuple[jnp.ndarray]:
301
+
302
+ if self.config.use_scan:
303
+ hidden_states = hidden_states[0]
304
+
305
+ residual = hidden_states
306
+
307
+ # Self Attention
308
+ hidden_states, self_attn_weights = self.self_attn(
309
+ hidden_states=hidden_states, attention_mask=attention_mask, init_cache=init_cache
310
+ )
311
+ hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
312
+ hidden_states = residual + hidden_states
313
+ hidden_states = self.self_attn_layer_norm(hidden_states)
314
+
315
+ # Cross-Attention Block
316
+ cross_attn_weights = None
317
+ if encoder_hidden_states is not None:
318
+ residual = hidden_states
319
+
320
+ hidden_states, cross_attn_weights = self.encoder_attn(
321
+ hidden_states=hidden_states,
322
+ key_value_states=encoder_hidden_states,
323
+ attention_mask=encoder_attention_mask,
324
+ )
325
+ hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
326
+ hidden_states = residual + hidden_states
327
+ hidden_states = self.encoder_attn_layer_norm(hidden_states)
328
+
329
+ # Fully Connected
330
+ residual = hidden_states
331
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
332
+ hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic)
333
+ hidden_states = self.fc2(hidden_states)
334
+ hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
335
+ hidden_states = residual + hidden_states
336
+ hidden_states = self.final_layer_norm(hidden_states)
337
+
338
+ outputs = (hidden_states,)
339
+
340
+ if output_attentions:
341
+ outputs += (self_attn_weights, cross_attn_weights)
342
+
343
+ if self.config.use_scan:
344
+ outputs = (outputs, None)
345
+
346
+ return outputs
347
+
348
+
349
+ class FlaxBartDecoderLayerCollection(nn.Module):
350
+ config: BartConfig
351
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
352
+
353
+ @nn.compact
354
+ def __call__(
355
+ self,
356
+ hidden_states,
357
+ attention_mask,
358
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
359
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
360
+ deterministic: bool = True,
361
+ init_cache: bool = False,
362
+ output_attentions: bool = False,
363
+ output_hidden_states: bool = False,
364
+ return_dict: bool = True,
365
+ ):
366
+ # decoder layers
367
+ all_hidden_states = () if output_hidden_states else None
368
+ all_self_attns = () if output_attentions else None
369
+ all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
370
+
371
+ num_decoder_layers = self.config.decoder_layers
372
+ BlockDecoderLayer = (
373
+ remat(
374
+ FlaxBartDecoderLayer,
375
+ static_argnums=(4, 5, 6),
376
+ prevent_cse=not self.config.use_scan,
377
+ )
378
+ if self.config.gradient_checkpointing
379
+ else FlaxBartDecoderLayer
380
+ )
381
+
382
+ if self.config.use_scan:
383
+ # since all decoder layers are the same, we use nn.scan directly
384
+ assert not output_attentions, "cannot use `scan` with `output_attentions` set to `True`"
385
+ assert not output_hidden_states, "cannot use `scan` with `output_hidden_states` set to `True`"
386
+ hidden_states = (hidden_states,)
387
+
388
+ # TODO: add layerdrop in checkpointed scan (note: default value for layerdrop in config is zero)
389
+ hidden_states, _ = scan_with_axes(
390
+ BlockDecoderLayer,
391
+ variable_axes={"params": 0, "cache": 0},
392
+ split_rngs={"params": True, "dropout": True},
393
+ in_axes=(nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast),
394
+ length=num_decoder_layers,
395
+ )(self.config, dtype=self.dtype, name="FlaxBartDecoderLayers")(
396
+ hidden_states,
397
+ attention_mask,
398
+ encoder_hidden_states,
399
+ encoder_attention_mask,
400
+ init_cache,
401
+ output_attentions,
402
+ deterministic,
403
+ )
404
+ hidden_states = hidden_states[0]
405
+
406
+ else:
407
+ for layer in range(num_decoder_layers):
408
+ if output_hidden_states:
409
+ all_hidden_states += (hidden_states,)
410
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
411
+ dropout_probability = random.uniform(0, 1)
412
+ if not deterministic and (dropout_probability < self.config.decoder_layerdrop):
413
+ layer_outputs = (None, None, None)
414
+ else:
415
+ layer_outputs = BlockDecoderLayer(self.config, dtype=self.dtype, name=str(layer),)(
416
+ hidden_states,
417
+ attention_mask,
418
+ encoder_hidden_states,
419
+ encoder_attention_mask,
420
+ init_cache,
421
+ output_attentions,
422
+ deterministic,
423
+ )
424
+
425
+ hidden_states = layer_outputs[0]
426
+ if output_attentions:
427
+ all_self_attns += (layer_outputs[1],)
428
+
429
+ if encoder_hidden_states is not None:
430
+ all_cross_attentions += (layer_outputs[2],)
431
+
432
+ # add hidden states from the last decoder layer
433
+ if output_hidden_states:
434
+ all_hidden_states += (hidden_states,)
435
+
436
+ outputs = [hidden_states, all_hidden_states, all_self_attns, all_cross_attentions]
437
+
438
+ if not return_dict:
439
+ return tuple(v for v in outputs if v is not None)
440
+
441
+ return FlaxBaseModelOutputWithPastAndCrossAttentions(
442
+ last_hidden_state=hidden_states,
443
+ hidden_states=all_hidden_states,
444
+ attentions=all_self_attns,
445
+ cross_attentions=all_cross_attentions,
446
+ )
447
+
448
+
449
+ class FlaxBartDecoder(nn.Module):
450
+ config: BartConfig
451
+ embed_tokens: nn.Embed
452
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
453
+
454
+ def setup(self):
455
+ self.dropout_layer = nn.Dropout(rate=self.config.dropout)
456
+
457
+ embed_dim = self.config.d_model
458
+ self.padding_idx = self.config.pad_token_id
459
+ self.max_target_positions = self.config.max_position_embeddings
460
+ self.embed_scale = math.sqrt(self.config.d_model) if self.config.scale_embedding else 1.0
461
+
462
+ # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
463
+ # and adjust num_embeddings appropriately. Other models don't have this hack
464
+ self.offset = 2
465
+ self.embed_positions = nn.Embed(
466
+ self.config.max_position_embeddings + self.offset,
467
+ embed_dim,
468
+ embedding_init=jax.nn.initializers.normal(self.config.init_std),
469
+ )
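+         # With offset = 2, position id p reads row p + 2 of the table, e.g. position 0 uses row 2;
+         # rows 0 and 1 stay reserved (a quirk BART inherited from fairseq's padding-aware
+         # positional embeddings).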
+ 
+         self.layers = FlaxBartDecoderLayerCollection(self.config, self.dtype)
+         self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
+ 
+     def __call__(
+         self,
+         input_ids,
+         attention_mask,
+         position_ids,
+         encoder_hidden_states: Optional[jnp.ndarray] = None,
+         encoder_attention_mask: Optional[jnp.ndarray] = None,
+         init_cache: bool = False,
+         output_attentions: bool = False,
+         output_hidden_states: bool = False,
+         return_dict: bool = True,
+         deterministic: bool = True,
+     ):
+         input_shape = input_ids.shape
+         input_ids = input_ids.reshape(-1, input_shape[-1])
+ 
+         inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
+ 
+         # embed positions
+         positions = self.embed_positions(position_ids + self.offset)
+ 
+         hidden_states = inputs_embeds + positions
+         hidden_states = self.layernorm_embedding(hidden_states)
+ 
+         hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
+ 
+         outputs = self.layers(
+             hidden_states,
+             attention_mask,
+             encoder_hidden_states,
+             encoder_attention_mask,
+             deterministic=deterministic,
+             init_cache=init_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+ 
+         if not return_dict:
+             return outputs
+ 
+         return FlaxBaseModelOutputWithPastAndCrossAttentions(
+             last_hidden_state=outputs.last_hidden_state,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+             cross_attentions=outputs.cross_attentions,
+         )
+ 
+ 
+ class FlaxBartDecoderPreTrainedModel(FlaxPreTrainedModel):
+     config_class = BartConfig
+     base_model_prefix: str = "model"
+     module_class: nn.Module = None
+ 
+     def __init__(
+         self,
+         config: BartConfig,
+         input_shape: Tuple[int] = (1, 1),
+         seed: int = 0,
+         dtype: jnp.dtype = jnp.float32,
+         **kwargs
+     ):
+         config.is_decoder = True
+         config.is_encoder_decoder = False
+         module = self.module_class(config=config, dtype=dtype, **kwargs)
+         super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
+ 
+     def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
+         # init input tensors
+         input_ids = jnp.zeros(input_shape, dtype="i4")
+         attention_mask = jnp.ones_like(input_ids)
+ 
+         batch_size, sequence_length = input_ids.shape
+         position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
+ 
+         params_rng, dropout_rng = jax.random.split(rng)
+         rngs = {"params": params_rng, "dropout": dropout_rng}
+         encoder_hidden_states = jnp.zeros(input_shape + (self.config.d_model,))
+         encoder_attention_mask = attention_mask
+         module_init_outputs = self.module.init(
+             rngs,
+             input_ids,
+             attention_mask,
+             position_ids,
+             encoder_hidden_states,
+             encoder_attention_mask,
+             return_dict=False,
+         )
+         return module_init_outputs["params"]
+ 
+     def init_cache(self, batch_size, max_length):
+         r"""
+         Args:
+             batch_size (`int`):
+                 batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
+             max_length (`int`):
+                 maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
+                 cache.
+         """
+         # init input variables to retrieve cache
+         input_ids = jnp.ones((batch_size, max_length), dtype="i4")
+         attention_mask = jnp.ones_like(input_ids, dtype="i4")
+         position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
+ 
+         init_variables = self.module.init(
+             jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
+         )
+         return unfreeze(init_variables["cache"])
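+ 
+     # A sketch of typical use (illustrative names, not part of this file), where `model` is an
+     # instantiated subclass such as FlaxBartForCausalLM below:
+     #     past_key_values = model.init_cache(batch_size=1, max_length=32)
+     #     outputs = model(input_ids, past_key_values=past_key_values)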
582
+
583
+ def __call__(
584
+ self,
585
+ input_ids: jnp.ndarray,
586
+ attention_mask: Optional[jnp.ndarray] = None,
587
+ position_ids: Optional[jnp.ndarray] = None,
588
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
589
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
590
+ output_attentions: Optional[bool] = None,
591
+ output_hidden_states: Optional[bool] = None,
592
+ return_dict: Optional[bool] = None,
593
+ train: bool = False,
594
+ params: dict = None,
595
+ past_key_values: dict = None,
596
+ dropout_rng: PRNGKey = None,
597
+ ):
598
+ """
599
+ Args:
600
+ input_ids (`jnp.ndarray` of shape `(target_batch_size, target_sequence_length)`):
601
+ Indices of decoder input sequence tokens in the vocabulary.
602
+
603
+ Indices can be obtained using [`BartTokenizer`]. See [`PreTrainedTokenizer.encode`] and
604
+ [`PreTrainedTokenizer.__call__`] for details.
605
+
606
+ [What are decoder input IDs?](../glossary#decoder-input-ids)
607
+
608
+ For translation and summarization training, `decoder_input_ids` should be provided. If no
609
+ `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right
610
+ for denoising pre-training following the paper.
611
+ attention_mask (`jnp.ndarray` of shape `(target_batch_size, target_sequence_length)`, *optional*):
612
+ Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
613
+ be used by default.
614
+
615
+ If you want to change padding behavior, you should modify to your needs. See diagram 1 in [the
616
+ paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy.
617
+ position_ids (`numpy.ndarray` of shape `(target_batch_size, sequence_length)`, *optional*):
618
+ Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
619
+ range `[0, config.max_position_embeddings - 1]`.
620
+ encoder_hidden_states (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`):
621
+ A sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
622
+ encoder_attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
623
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
624
+
625
+ - 1 for tokens that are **not masked**,
626
+ - 0 for tokens that are **masked**.
627
+
628
+ [What are attention masks?](../glossary#attention-mask)
629
+ output_attentions (`bool`, *optional*):
630
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
631
+ tensors for more detail.
632
+ output_hidden_states (`bool`, *optional*):
633
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
634
+ more detail.
635
+ return_dict (`bool`, *optional*):
636
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
637
+ past_key_values (`Dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`):
638
+ Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast
639
+ auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*.
640
+ """
641
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
642
+ output_hidden_states = (
643
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
644
+ )
645
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
646
+
647
+ if encoder_hidden_states is not None and encoder_attention_mask is None:
648
+ batch_size, sequence_length = encoder_hidden_states.shape[:2]
649
+ encoder_attention_mask = jnp.ones((batch_size, sequence_length))
650
+
651
+ # prepare decoder inputs
652
+ if attention_mask is None:
653
+ attention_mask = jnp.ones_like(input_ids)
654
+ if position_ids is None:
655
+ batch_size, sequence_length = input_ids.shape
656
+ position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
657
+
658
+ # Handle any PRNG if needed
659
+ rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
660
+
661
+ inputs = {"params": params or self.params}
662
+
663
+ # if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed
664
+ # down to ensure cache is used. It has to be made sure that cache is marked as mutable so that it can be
665
+ # changed by FlaxBartAttention module
666
+ if past_key_values:
667
+ inputs["cache"] = past_key_values
668
+ mutable = ["cache"]
669
+ else:
670
+ mutable = False
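+         # Flax's module.apply only lets the collections named in `mutable` change; with
+         # mutable=["cache"], apply returns (outputs, mutated_variables), which is why the updated
+         # cache is unpacked from the second element below.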
+ 
+         outputs = self.module.apply(
+             inputs,
+             input_ids=jnp.array(input_ids, dtype="i4"),
+             attention_mask=jnp.array(attention_mask, dtype="i4"),
+             position_ids=jnp.array(position_ids, dtype="i4"),
+             encoder_hidden_states=encoder_hidden_states,
+             encoder_attention_mask=encoder_attention_mask,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+             deterministic=not train,
+             rngs=rngs,
+             mutable=mutable,
+         )
+ 
+         # add updated cache to model output
+         if past_key_values is not None and return_dict:
+             outputs, past_key_values = outputs
+             outputs["past_key_values"] = unfreeze(past_key_values["cache"])
+             return outputs
+         elif past_key_values is not None and not return_dict:
+             outputs, past_key_values = outputs
+             outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]
+ 
+         return outputs
+ 
+ 
+ class FlaxBartDecoderWrapper(nn.Module):
+     """
+     This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
+     used in combination with the [`EncoderDecoderModel`] framework.
+     """
+ 
+     config: BartConfig
+     dtype: jnp.dtype = jnp.float32
+ 
+     def setup(self):
+         embed_dim = self.config.d_model
+         embed_tokens = nn.Embed(
+             self.config.vocab_size,
+             embed_dim,
+             embedding_init=jax.nn.initializers.normal(self.config.init_std),
+         )
+         self.decoder = FlaxBartDecoder(config=self.config, embed_tokens=embed_tokens, dtype=self.dtype)
+ 
+     def __call__(self, *args, **kwargs):
+         return self.decoder(*args, **kwargs)
+ 
+ 
+ class FlaxBartForCausalLMModule(nn.Module):
+     """Bart Decoder Module with a language modeling head on top (linear layer with weights tied to the input
+     embeddings), e.g. for autoregressive tasks.
+     """
+ 
+     config: BartConfig
+     dtype: jnp.dtype = jnp.float32
+ 
+     def setup(self):
+         self.model = FlaxBartDecoderWrapper(config=self.config, dtype=self.dtype)
+         self.lm_head = nn.Dense(
+             self.config.vocab_size,
+             use_bias=False,
+             dtype=self.dtype,
+             kernel_init=jax.nn.initializers.normal(self.config.init_std),
+         )
+ 
+     def __call__(
+         self,
+         input_ids,
+         attention_mask,
+         position_ids,
+         encoder_hidden_states: Optional[jnp.ndarray] = None,
+         encoder_attention_mask: Optional[jnp.ndarray] = None,
+         init_cache: bool = False,
+         output_attentions: bool = False,
+         output_hidden_states: bool = False,
+         return_dict: bool = True,
+         deterministic: bool = True,
+     ):
+         outputs = self.model(
+             input_ids,
+             attention_mask,
+             position_ids,
+             encoder_hidden_states,
+             encoder_attention_mask,
+             deterministic=deterministic,
+             init_cache=init_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+ 
+         hidden_states = outputs[0]
+ 
+         if self.config.tie_word_embeddings:
+             shared_embedding = self.model.variables["params"]["decoder"]["embed_tokens"]["embedding"]
+             lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
+         else:
+             lm_logits = self.lm_head(hidden_states)
+ 
+         if not return_dict:
+             return (lm_logits,) + outputs[1:]
+ 
+         return FlaxCausalLMOutputWithCrossAttentions(
+             logits=lm_logits,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+             cross_attentions=outputs.cross_attentions,
+         )
+ 
+ 
+ class FlaxBartForCausalLM(FlaxBartDecoderPreTrainedModel):
+     """Bart Decoder Model with a language modeling head on top (linear layer with weights tied to the input
+     embeddings), e.g. for autoregressive tasks.
+     """
+ 
+     module_class = FlaxBartForCausalLMModule
+ 
+     def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
+         # initializing the cache
+         batch_size, seq_length = input_ids.shape
+ 
+         past_key_values = self.init_cache(batch_size, max_length)
+         # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
+         # But since the decoder uses a causal mask, those positions are masked anyway.
+         # Thus, we can create a single static attention_mask here, which is more efficient for compilation
+         extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
+         if attention_mask is not None:
+             position_ids = attention_mask.cumsum(axis=-1) - 1
+             extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
+         else:
+             position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
+ 
+         return {
+             "past_key_values": past_key_values,
+             "attention_mask": extended_attention_mask,
+             "position_ids": position_ids,
+         }
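+ 
+     # Example (a sketch): with max_length=8 and an unpadded prompt of length 3, the
+     # extended_attention_mask stays all ones after dynamic_update_slice and position_ids come out
+     # as [[0, 1, 2]]; positions past the prompt are handled by the causal mask, so the static
+     # all-ones mask is safe.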
+ 
+     def update_inputs_for_generation(self, model_outputs, model_kwargs):
+         model_kwargs["past_key_values"] = model_outputs.past_key_values
+         model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
+         return model_kwargs
models/modeling_flax_speech_encoder_decoder.py ADDED
@@ -0,0 +1,1234 @@
+ # coding=utf-8
+ # Copyright 2022 The HuggingFace Inc. team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Classes to support Flax Speech-Encoder-Decoder architectures."""
+ 
+ import os
+ from functools import partial
+ from typing import Dict, Optional, Tuple, Union
+ 
+ import flax
+ import flax.linen as nn
+ import jax
+ import jax.numpy as jnp
+ import numpy as np
+ from flax.core.frozen_dict import FrozenDict, unfreeze
+ from jax import lax
+ from jax.random import PRNGKey
+ 
+ from transformers.modeling_flax_outputs import (
+     FlaxBaseModelOutput,
+     FlaxCausalLMOutputWithCrossAttentions,
+     FlaxSeq2SeqLMOutput,
+ )
+ from transformers.modeling_flax_utils import FlaxPreTrainedModel
+ from transformers.utils import (
+     ModelOutput,
+     add_start_docstrings,
+     add_start_docstrings_to_model_forward,
+     logging,
+     replace_return_docstrings,
+ )
+ from transformers.generation_flax_utils import FlaxLogitsProcessorList
+ from models.modeling_flax_wav2vec2 import FlaxWav2Vec2Model, FlaxWav2Vec2Module
+ from models.modeling_flax_bart import FlaxBartForCausalLM, FlaxBartForCausalLMModule
+ from models.configuration_bart import BartConfig
+ from models.configuration_wav2vec2 import Wav2Vec2Config
+ from models.configuration_speech_encoder_decoder import SpeechEncoderDecoderConfig
+ 
+ logger = logging.get_logger(__name__)
+ 
+ _CONFIG_FOR_DOC = "SpeechEncoderDecoderConfig"
+ 
+ SPEECH_ENCODER_DECODER_START_DOCSTRING = r"""
+     This class can be used to initialize a speech-sequence-to-text-sequence model with any pretrained speech
+     autoencoding model as the encoder and any pretrained text autoregressive model as the decoder. The encoder is
+     loaded via the [`~AutoModel.from_pretrained`] function and the decoder is loaded via the
+     [`~AutoModelForCausalLM.from_pretrained`] function. Cross-attention layers are automatically added to the decoder
+     and should be fine-tuned on a downstream generative task, like summarization.
+ 
+     The effectiveness of initializing sequence-to-sequence models with pretrained checkpoints for sequence generation
+     tasks was shown in [Leveraging Pre-trained Checkpoints for Sequence Generation
+     Tasks](https://arxiv.org/abs/1907.12461) by Sascha Rothe, Shashi Narayan, and Aliaksei Severyn.
+ 
+     Additionally, in [Large-Scale Self- and Semi-Supervised Learning for Speech
+     Translation](https://arxiv.org/abs/2104.06678) it is shown how leveraging large pretrained speech models for
+     speech translation yields a significant performance improvement.
+ 
+     After such a Speech-Encoder-Decoder model has been trained/fine-tuned, it can be saved/loaded just like any other
+     model (see the examples for more information).
+ 
+     This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
+     library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning
+     heads etc.)
+ 
+     This model is also a Flax Linen
+     [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
+     regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior.
+ 
+     Parameters:
+         config ([`SpeechEncoderDecoderConfig`]): Model configuration class with all the parameters of the model.
+             Initializing with a config file does not load the weights associated with the model, only the
+             configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
+         dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
+             The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
+             `jax.numpy.bfloat16` (on TPUs).
+ 
+             This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
+             specified, all the computation will be performed with the given `dtype`.
+ 
+             **Note that this only specifies the dtype of the computation and does not influence the dtype of model
+             parameters.**
+ 
+             If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
+             [`~FlaxPreTrainedModel.to_bf16`].
+ """
+ 
+ SPEECH_ENCODER_DECODER_INPUTS_DOCSTRING = r"""
+     Args:
+         inputs (`jnp.ndarray` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, feature_dim)`, *optional*):
+             Float values of input raw speech waveform or speech features. Values can be obtained by loading a *.flac*
+             or *.wav* audio file into an array of type *List[float]* or a *numpy.ndarray*, *e.g.* via the soundfile
+             library (*pip install soundfile*). To prepare the array into *inputs*, either the [`Wav2Vec2Processor`] or
+             [`Speech2TextProcessor`] should be used for padding and conversion into a tensor of type
+             *torch.FloatTensor*.
+         attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
+             Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+ 
+             - 1 for tokens that are **not masked**,
+             - 0 for tokens that are **masked**.
+ 
+             [What are attention masks?](../glossary#attention-mask)
+         decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
+             Indices of decoder input sequence tokens in the vocabulary.
+ 
+             Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+             [`PreTrainedTokenizer.__call__`] for details.
+ 
+             [What are input IDs?](../glossary#input-ids)
+ 
+             If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
+             `past_key_values`).
+ 
+             For sequence to sequence training, `decoder_input_ids` should be provided. `decoder_input_ids` should be
+             created outside of the model by shifting the `labels` to the right, replacing -100 by the `pad_token_id`
+             and prepending them with the `decoder_start_token_id`.
+         decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
+             Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
+             be used by default.
+         decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
+             Indices of positions of each decoder input sequence token in the position embeddings. Selected in the
+             range `[0, config.decoder.max_position_embeddings - 1]`.
+         output_hidden_states (`bool`, *optional*):
+             Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+             more detail.
+         return_dict (`bool`, *optional*):
+             If set to `True`, the model will return a [`~utils.FlaxSeq2SeqLMOutput`] instead of a plain tuple.
+ """
+ 
+ SPEECH_ENCODER_DECODER_ENCODE_INPUTS_DOCSTRING = r"""
+     Args:
+         inputs (`jnp.ndarray` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, feature_dim)`, *optional*):
+             Float values of input raw speech waveform or speech features. Values can be obtained by loading a *.flac*
+             or *.wav* audio file into an array of type *List[float]* or a *numpy.ndarray*, *e.g.* via the soundfile
+             library (*pip install soundfile*). To prepare the array into *inputs*, either the [`Wav2Vec2Processor`] or
+             [`Speech2TextProcessor`] should be used for padding and conversion into a tensor of type
+             *torch.FloatTensor*.
+         attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
+             Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+ 
+             - 1 for tokens that are **not masked**,
+             - 0 for tokens that are **masked**.
+ 
+             [What are attention masks?](../glossary#attention-mask)
+         output_attentions (`bool`, *optional*):
+             Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+             tensors for more detail.
+         output_hidden_states (`bool`, *optional*):
+             Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+             more detail.
+         return_dict (`bool`, *optional*):
+             If set to `True`, the model will return a [`~utils.FlaxBaseModelOutput`] instead of a plain tuple.
+ """
+ 
+ SPEECH_ENCODER_DECODER_DECODE_INPUTS_DOCSTRING = r"""
+     Args:
+         decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
+             Indices of decoder input sequence tokens in the vocabulary.
+ 
+             Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+             [`PreTrainedTokenizer.__call__`] for details.
+ 
+             [What are decoder input IDs?](../glossary#decoder-input-ids)
+ 
+             If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
+             `past_key_values`).
+ 
+             For sequence to sequence training, `decoder_input_ids` should be provided. `decoder_input_ids` should be
+             created outside of the model by shifting the `labels` to the right, replacing -100 by the `pad_token_id`
+             and prepending them with the `decoder_start_token_id`.
+         encoder_outputs (`tuple(tuple(jnp.ndarray))`):
+             Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`).
+             `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of hidden-states
+             at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
+         encoder_attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
+             Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+ 
+             - 1 for tokens that are **not masked**,
+             - 0 for tokens that are **masked**.
+ 
+             [What are attention masks?](../glossary#attention-mask)
+         decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
+             Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
+             be used by default.
+         decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
+             Indices of positions of each decoder input sequence token in the position embeddings. Selected in the
+             range `[0, config.decoder.max_position_embeddings - 1]`.
+         past_key_values (`Dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`):
+             Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for
+             fast auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size,
+             max_length]*.
+         output_attentions (`bool`, *optional*):
+             Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+             tensors for more detail.
+         output_hidden_states (`bool`, *optional*):
+             Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+             more detail.
+         return_dict (`bool`, *optional*):
+             If set to `True`, the model will return a [`~utils.FlaxCausalLMOutputWithCrossAttentions`] instead of a
+             plain tuple.
+ """
+ 
+ @flax.struct.dataclass
+ class FlaxBeamSearchOutput(ModelOutput):
+     """
+     Flax base class for outputs of decoder-only generation models using beam search.
+ 
+     Args:
+         sequences (`jnp.ndarray` of shape `(batch_size, max_length)`):
+             The generated sequences.
+         scores (`jnp.ndarray` of shape `(batch_size,)`):
+             The scores (log probabilities) of the generated sequences.
+     """
+ 
+     sequences: jnp.ndarray = None
+     scores: jnp.ndarray = None
+ 
+ 
+ @flax.struct.dataclass
+ class BeamSearchState:
+     cur_len: jnp.ndarray
+     running_sequences: jnp.ndarray
+     running_scores: jnp.ndarray
+     sequences: jnp.ndarray
+     scores: jnp.ndarray
+     is_sent_finished: jnp.ndarray
+     model_kwargs: Dict[str, jnp.ndarray]
+ 
+ 
+ class FlaxSpeechEncoderDecoderModule(nn.Module):
+     config: SpeechEncoderDecoderConfig
+     dtype: jnp.dtype = jnp.float32
+ 
+     def setup(self):
+         encoder_config = self.config.encoder
+         decoder_config = self.config.decoder
+ 
+         # TODO: configure FlaxAutoModel mappings (required when trialling different encoder-decoder combinations)
+         encoder_module = FlaxWav2Vec2Module
+         decoder_module = FlaxBartForCausalLMModule
+ 
+         self.encoder = encoder_module(encoder_config, dtype=self.dtype)
+         self.decoder = decoder_module(decoder_config, dtype=self.dtype)
+ 
+         # encoder outputs might need to be projected to different dimension for decoder
+         if (
+             self.encoder.config.hidden_size != self.decoder.config.hidden_size
+             and self.decoder.config.cross_attention_hidden_size is None
+         ):
+             self.enc_to_dec_proj = nn.Dense(
+                 self.decoder.config.hidden_size,
+                 kernel_init=jax.nn.initializers.normal(self.decoder.config.initializer_range),
+                 dtype=self.dtype,
+             )
+         else:
+             self.enc_to_dec_proj = None
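+         # e.g. pairing wav2vec2-large (hidden_size 1024) with bart-large (d_model 1024) needs no
+         # projection, while a wav2vec2-base encoder (hidden_size 768) would get a learned
+         # 768 -> 1024 Dense projection here.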
+ 
+     def _get_feat_extract_output_lengths(
+         self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
+     ):
+         """
+         Computes the output length of the convolutional layers
+         """
+ 
+         add_adapter = self.config.encoder.add_adapter if add_adapter is None else add_adapter
+ 
+         def _conv_out_length(input_length, kernel_size, stride):
+             # 1D convolutional layer output length formula taken
+             # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
+             return (input_length - kernel_size) // stride + 1
+ 
+         for kernel_size, stride in zip(self.config.encoder.conv_kernel, self.config.encoder.conv_stride):
+             input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
+ 
+         if add_adapter:
+             for _ in range(self.config.encoder.num_adapter_layers):
+                 input_lengths = _conv_out_length(input_lengths, 1, self.config.encoder.adapter_stride)
+ 
+         return input_lengths
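+ 
+     # Worked example (assuming the standard wav2vec2 feature extractor, kernels
+     # (10, 3, 3, 3, 3, 2, 2) with strides (5, 2, 2, 2, 2, 2, 2), and no adapter):
+     # 16000 samples (1 s of 16 kHz audio) -> 3199 -> 1599 -> 799 -> 399 -> 199 -> 99 -> 49
+     # frames, i.e. roughly one encoder frame every 20 ms.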
+ 
+     def _get_encoder_module(self):
+         return self.encoder
+ 
+     def _get_projection_module(self):
+         return self.enc_to_dec_proj
+ 
+     def _get_decoder_module(self):
+         return self.decoder
+ 
+     def __call__(
+         self,
+         inputs,
+         attention_mask,
+         decoder_input_ids,
+         decoder_attention_mask,
+         decoder_position_ids,
+         encoder_outputs=None,
+         extract_features=None,
+         output_attentions: bool = False,
+         output_hidden_states: bool = False,
+         output_features: bool = False,
+         return_dict: bool = True,
+         deterministic: bool = True,
+         freeze_feature_encoder: bool = False,
+     ):
+         if encoder_outputs is None:
+             encoder_outputs = self.encoder(
+                 inputs,
+                 attention_mask=attention_mask,
+                 extract_features=extract_features,
+                 output_attentions=output_attentions,
+                 output_hidden_states=output_hidden_states,
+                 output_features=output_features,
+                 return_dict=return_dict,
+                 deterministic=deterministic,
+                 freeze_feature_encoder=freeze_feature_encoder,
+             )
+ 
+         if output_features:
+             return encoder_outputs
+ 
+         encoder_hidden_states = encoder_outputs[0]
+ 
+         # optionally project encoder_hidden_states
+         if self.enc_to_dec_proj is not None:
+             encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
+ 
+         # compute correct encoder attention mask
+         if attention_mask is not None:
+             encoder_attention_mask = self.encoder._get_feature_vector_attention_mask(
+                 encoder_hidden_states.shape[1], attention_mask
+             )
+         else:
+             encoder_attention_mask = None
+ 
+         # flax script modeling_flax_wav2vec2.py
+         decoder_outputs = self.decoder(
+             input_ids=decoder_input_ids,
+             attention_mask=decoder_attention_mask,
+             position_ids=decoder_position_ids,
+             encoder_hidden_states=encoder_hidden_states,
+             encoder_attention_mask=encoder_attention_mask,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+             deterministic=deterministic,
+         )
+ 
+         if not return_dict:
+             return decoder_outputs + encoder_outputs
+ 
+         return FlaxSeq2SeqLMOutput(
+             logits=decoder_outputs.logits,
+             decoder_hidden_states=decoder_outputs.hidden_states,
+             decoder_attentions=decoder_outputs.attentions,
+             cross_attentions=decoder_outputs.cross_attentions,
+             encoder_last_hidden_state=encoder_hidden_states,
+             encoder_hidden_states=encoder_outputs.hidden_states,
+             encoder_attentions=encoder_outputs.attentions,
+         )
+ 
+ 
+ @add_start_docstrings(SPEECH_ENCODER_DECODER_START_DOCSTRING)
+ class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel):
+     r"""
+     [`FlaxSpeechEncoderDecoderModel`] is a generic model class that will be instantiated as a transformer architecture
+     with the module (flax.nn.Module) of one of the base model classes of the library as the encoder module and another
+     one as the decoder module when created with the [`~transformers.FlaxAutoModel.from_pretrained`] class method for
+     the encoder and the [`~transformers.FlaxAutoModelForCausalLM.from_pretrained`] class method for the decoder.
+     """
+ 
+     config_class = SpeechEncoderDecoderConfig
+     base_model_prefix: str = "speech_encoder_decoder"
+     module_class = FlaxSpeechEncoderDecoderModule
+ 
+     def __init__(
+         self,
+         config: SpeechEncoderDecoderConfig,
+         input_shape: Optional[Tuple] = None,
+         seed: int = 0,
+         dtype: jnp.dtype = jnp.float32,
+         **kwargs
+     ):
+         if config.decoder.cross_attention_hidden_size is not None:
+             # Raise ValueError or option to project enc to dec hidden_size (eg EncAdapterLayer)
+             if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
+                 raise ValueError(
+                     "If `cross_attention_hidden_size` is specified in the decoder's configuration, "
+                     "it has to be equal to the encoder's `hidden_size`. "
+                     f"Got {config.decoder.cross_attention_hidden_size} for `config.decoder.cross_attention_hidden_size` "
+                     f"and {config.encoder.hidden_size} for `config.encoder.hidden_size`."
+                 )
+ 
+         # make sure input & output embeddings are not tied
+         config.tie_word_embeddings = False
+         module = self.module_class(config=config, dtype=dtype, **kwargs)
+ 
+         if input_shape is None:
+             # speech encoders almost always downsample the sequence length dimension
+             encoder_input_length = 1024
+             decoder_input_length = module._get_feat_extract_output_lengths(encoder_input_length)
+             input_shape = ((1, encoder_input_length), (1, decoder_input_length))
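+             # e.g. with the standard wav2vec2 conv stack (see _get_feat_extract_output_lengths
+             # above), 1024 input samples reduce to a decoder length of 2, so input_shape defaults
+             # to ((1, 1024), (1, 2)).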
+ 
+         super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
+ 
+     def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
+         encoder_input_shape, decoder_input_shape = input_shape
+ 
+         # init input DeviceArrays
+         inputs = jnp.zeros(encoder_input_shape, dtype="f4")
+         attention_mask = jnp.ones_like(inputs, dtype="i4")
+         decoder_input_ids = jnp.zeros(decoder_input_shape, dtype="i4")
+         decoder_attention_mask = jnp.ones_like(decoder_input_ids)
+ 
+         batch_size, sequence_length = inputs.shape
+ 
+         decoder_batch_size, decoder_sequence_length = decoder_input_ids.shape
+         if decoder_batch_size != batch_size:
+             raise ValueError(
+                 f"The inputs of encoder and decoder should have the same batch size, but got {batch_size} for encoder and {decoder_batch_size} for decoder."
+             )
+         decoder_position_ids = jnp.broadcast_to(
+             jnp.arange(decoder_sequence_length)[None, :], (decoder_batch_size, decoder_sequence_length)
+         )
+ 
+         params_rng, dropout_rng = jax.random.split(rng)
+         rngs = {"params": params_rng, "dropout": dropout_rng}
+ 
+         return self.module.init(
+             rngs,
+             inputs,
+             attention_mask,
+             decoder_input_ids,
+             decoder_attention_mask,
+             decoder_position_ids,
+         )["params"]
+ 
+     def init_cache(self, batch_size, max_length, encoder_outputs):
+         r"""
+         Args:
+             batch_size (`int`):
+                 batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
+             max_length (`int`):
+                 maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
+                 cache.
+             encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray))]`):
+                 `encoder_outputs` consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*:
+                 `attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a
+                 sequence of hidden-states at the output of the last layer of the encoder. Used in the
+                 cross-attention of the decoder.
+         """
+         # init input variables to retrieve cache
+         decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4")
+         decoder_attention_mask = jnp.ones_like(decoder_input_ids)
+         decoder_position_ids = jnp.broadcast_to(
+             jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]), decoder_input_ids.shape
+         )
+ 
+         def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs):
+             decoder_module = module._get_decoder_module()
+             return decoder_module(
+                 input_ids=decoder_input_ids,
+                 attention_mask=decoder_attention_mask,
+                 position_ids=decoder_position_ids,
+                 **kwargs,
+             )
+ 
+         init_variables = self.module.init(
+             jax.random.PRNGKey(0),
+             decoder_input_ids=decoder_input_ids,
+             decoder_attention_mask=decoder_attention_mask,
+             decoder_position_ids=decoder_position_ids,
+             encoder_hidden_states=encoder_outputs[0],
+             init_cache=True,
+             method=_decoder_forward,  # we only need to call the decoder to init the cache
+         )
+         return unfreeze(init_variables["cache"])
+ 
+     def _get_feat_extract_output_lengths(
+         self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
+     ):
+         return self.module._get_feat_extract_output_lengths(input_lengths, add_adapter=add_adapter)
+ 
+     @add_start_docstrings(SPEECH_ENCODER_DECODER_ENCODE_INPUTS_DOCSTRING)
+     @replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=_CONFIG_FOR_DOC)
+     def encode(
+         self,
+         inputs: jnp.ndarray,
+         attention_mask: Optional[jnp.ndarray] = None,
+         extract_features: Optional[jnp.ndarray] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         output_features: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         train: bool = False,
+         freeze_feature_encoder: bool = False,
+         params: dict = None,
+         dropout_rng: PRNGKey = None,
+     ):
+         r"""
+         Returns:
+ 
+         Example:
+ 
+         ```python
+         >>> from transformers import FlaxSpeechEncoderDecoderModel
+         >>> import jax.numpy as jnp
+ 
+         >>> # initialize a wav2vec2-2-bart from pretrained wav2vec2 and bart models. Note that the cross-attention layers will be randomly initialized
+         >>> model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
+         ...     "facebook/wav2vec2-large-lv60", "facebook/bart-large"
+         ... )
+ 
+         >>> inputs = jnp.ones((2, 5000), dtype=jnp.float32)
+         >>> encoder_outputs = model.encode(inputs)
+         ```"""
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         return_dict = return_dict if return_dict is not None else self.config.return_dict
+ 
+         if attention_mask is None:
+             attention_mask = jnp.ones_like(inputs, dtype="i4")
+ 
+         if extract_features is not None:
+             extract_features = jnp.array(extract_features, dtype="f4")
+ 
+         # Handle any PRNG if needed
+         rngs = {}
+         if dropout_rng is not None:
+             rngs["dropout"] = dropout_rng
+ 
+         def _encoder_forward(module, inputs, attention_mask, **kwargs):
+             encode_module = module._get_encoder_module()
+             return encode_module(inputs, attention_mask, **kwargs)
+ 
+         outputs = self.module.apply(
+             {"params": params or self.params},
+             inputs=jnp.array(inputs, dtype="f4"),
+             attention_mask=jnp.array(attention_mask, dtype="i4"),
+             extract_features=extract_features,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             output_features=output_features,
+             return_dict=return_dict,
+             deterministic=not train,
+             freeze_feature_encoder=freeze_feature_encoder,
+             rngs=rngs,
+             method=_encoder_forward,
+         )
+ 
+         if return_dict and not output_features:
+             outputs = FlaxBaseModelOutput(
+                 last_hidden_state=outputs.last_hidden_state,
+                 hidden_states=outputs.hidden_states,
+                 attentions=outputs.attentions,
+             )
+ 
+         return outputs
+ 
+     @add_start_docstrings(SPEECH_ENCODER_DECODER_DECODE_INPUTS_DOCSTRING)
+     @replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
+     def decode(
+         self,
+         decoder_input_ids,
+         encoder_outputs,
+         encoder_attention_mask: Optional[jnp.ndarray] = None,
+         decoder_attention_mask: Optional[jnp.ndarray] = None,
+         decoder_position_ids: Optional[jnp.ndarray] = None,
+         past_key_values: dict = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         train: bool = False,
+         params: dict = None,
+         dropout_rng: PRNGKey = None,
+     ):
+         r"""
+         Returns:
+ 
+         Example:
+ 
+         ```python
+         >>> from transformers import FlaxSpeechEncoderDecoderModel
+         >>> import jax.numpy as jnp
+ 
+         >>> # initialize a wav2vec2-2-bart from pretrained wav2vec2 and bart models. Note that the cross-attention layers will be randomly initialized
+         >>> model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
+         ...     "facebook/wav2vec2-large-lv60", "facebook/bart-large"
+         ... )
+ 
+         >>> inputs = jnp.ones((2, 5000), dtype=jnp.float32)
+         >>> encoder_outputs = model.encode(inputs)
+ 
+         >>> decoder_start_token_id = model.config.decoder.bos_token_id
+         >>> decoder_input_ids = jnp.ones((inputs.shape[0], 1), dtype="i4") * decoder_start_token_id
+ 
+         >>> outputs = model.decode(decoder_input_ids, encoder_outputs)
+         >>> logits = outputs.logits
+         ```"""
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         return_dict = return_dict if return_dict is not None else self.config.return_dict
+ 
+         encoder_hidden_states = encoder_outputs[0]
+         if encoder_attention_mask is None:
+             batch_size, sequence_length = encoder_hidden_states.shape[:2]
+             encoder_attention_mask = jnp.ones((batch_size, sequence_length))
+ 
+         batch_size, sequence_length = decoder_input_ids.shape
+         if decoder_attention_mask is None:
+             decoder_attention_mask = jnp.ones((batch_size, sequence_length))
+ 
+         if decoder_position_ids is None:
+             if past_key_values is not None:
+                 raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.")
+ 
+             decoder_position_ids = jnp.broadcast_to(
+                 jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
+             )
+ 
+         # Handle any PRNG if needed
+         rngs = {}
+         if dropout_rng is not None:
+             rngs["dropout"] = dropout_rng
+ 
+         params = {"params": params or self.params}
+ 
+         # If past_key_values are passed, the cache is already initialized; a private flag, init_cache,
+         # has to be passed down to ensure the cache is used. The cache must also be marked as mutable
+         # so that it can be changed by the FlaxBartAttention module.
+         if past_key_values:
+             params["cache"] = past_key_values
+             mutable = ["cache"]
+         else:
+             mutable = False
+ 
+         def _decoder_forward(
+             module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, encoder_hidden_states, **kwargs
+         ):
+             projection_module = module._get_projection_module()
+             decoder_module = module._get_decoder_module()
+ 
+             # optionally project encoder_hidden_states
+             if projection_module is not None:
+                 encoder_hidden_states = projection_module(encoder_hidden_states)
+ 
+             return decoder_module(
+                 decoder_input_ids,
+                 decoder_attention_mask,
+                 decoder_position_ids,
+                 encoder_hidden_states,
+                 **kwargs,
+             )
+ 
+         outputs = self.module.apply(
+             params,
+             decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
+             decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
+             decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
+             encoder_hidden_states=encoder_hidden_states,
+             encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+             deterministic=not train,
+             rngs=rngs,
+             mutable=mutable,
+             method=_decoder_forward,
+         )
+ 
+         # add updated cache to model output
+         if past_key_values is not None and return_dict:
+             outputs, past = outputs
+             outputs["past_key_values"] = unfreeze(past["cache"])
+             return outputs
+         elif past_key_values is not None and not return_dict:
+             outputs, past = outputs
+             outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
+ 
+         return outputs
+ 
+     @add_start_docstrings_to_model_forward(SPEECH_ENCODER_DECODER_INPUTS_DOCSTRING)
+     @replace_return_docstrings(output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
+     def __call__(
+         self,
+         inputs: jnp.ndarray,
+         attention_mask: Optional[jnp.ndarray] = None,
+         extract_features: Optional[jnp.ndarray] = None,
+         decoder_input_ids: Optional[jnp.ndarray] = None,
+         decoder_attention_mask: Optional[jnp.ndarray] = None,
+         decoder_position_ids: Optional[jnp.ndarray] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         output_features: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         train: bool = False,
+         freeze_feature_encoder: bool = False,
+         params: dict = None,
+         dropout_rng: PRNGKey = None,
+     ):
+         r"""
+         Returns:
+ 
+         Examples:
+ 
+         ```python
+         >>> from transformers import FlaxSpeechEncoderDecoderModel, BartTokenizer
+         >>> import jax.numpy as jnp
+ 
+         >>> # load a fine-tuned wav2vec2-2-bart model
+         >>> model = FlaxSpeechEncoderDecoderModel.from_pretrained("patrickvonplaten/wav2vec2-2-bart-large")
+         >>> # load output tokenizer
+         >>> tokenizer_output = BartTokenizer.from_pretrained("facebook/bart-large")
+ 
+         >>> inputs = jnp.ones((2, 5000), dtype=jnp.float32)
+ 
+         >>> # use bart's special bos, pad and eos tokens
+         >>> model.config.decoder_start_token_id = model.decoder.config.bos_token_id
+         >>> model.config.pad_token_id = model.decoder.config.pad_token_id
+         >>> model.config.eos_token_id = model.decoder.config.eos_token_id
+ 
+         >>> outputs = model.generate(inputs)
+         # Assert something? More interesting input? dtype correct?
+         ```
+         """
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         return_dict = return_dict if return_dict is not None else self.config.return_dict
+ 
+         # prepare encoder inputs
+         if attention_mask is None:
+             attention_mask = jnp.ones_like(inputs, dtype="i4")
+ 
+         if extract_features is not None:
+             inputs = None  # we can omit passing the inputs to the model to save memory
+             extract_features = jnp.array(extract_features, dtype="f4")
+         else:
+             inputs = jnp.array(inputs, dtype="f4")
+ 
+         # prepare decoder inputs
+         if decoder_input_ids is None:
+             raise ValueError(
+                 "`decoder_input_ids` cannot be `None`. For sequence to sequence training, `decoder_input_ids` must be specified as an input argument."
+             )
+         if decoder_attention_mask is None:
+             decoder_attention_mask = jnp.ones_like(decoder_input_ids)
+         if decoder_position_ids is None:
+             batch_size, sequence_length = decoder_input_ids.shape
+             decoder_position_ids = jnp.broadcast_to(
+                 jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
+             )
+ 
+         # Handle any PRNG if needed
+         rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
+ 
+         return self.module.apply(
+             {"params": params or self.params},
+             inputs=inputs,
+             attention_mask=jnp.array(attention_mask, dtype="i4"),
+             extract_features=extract_features,
+             decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
+             decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
+             decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             output_features=output_features,
+             return_dict=return_dict,
+             deterministic=not train,
+             freeze_feature_encoder=freeze_feature_encoder,
+             rngs=rngs,
+         )
+ 
+     def prepare_inputs_for_generation(
+         self,
+         decoder_input_ids,
+         max_length,
+         attention_mask: Optional[jnp.DeviceArray] = None,
+         decoder_attention_mask: Optional[jnp.DeviceArray] = None,
+         encoder_outputs=None,
+         **kwargs
+     ):
+         # initializing the cache
+         batch_size, seq_length = decoder_input_ids.shape
+ 
+         past_key_values = self.init_cache(batch_size, max_length, encoder_outputs)
+         # Note that usually one would have to put 0's in the attention_mask for x > input.shape[-1] and x < cache_length.
+         # But since the decoder uses a causal mask, those positions are masked anyway.
+         # Thus, we can create a single static attention_mask here, which is more efficient for compilation
+         extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
+         if decoder_attention_mask is not None:
+             decoder_position_ids = decoder_attention_mask.cumsum(axis=-1) - 1
+             extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0))
+         else:
+             decoder_position_ids = jnp.broadcast_to(
+                 jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)
+             )
+ 
+         return {
+             "past_key_values": past_key_values,
+             "encoder_outputs": encoder_outputs,
+             "encoder_attention_mask": attention_mask,
+             "decoder_attention_mask": extended_attention_mask,
+             "decoder_position_ids": decoder_position_ids,
+         }
+ 
+     def update_inputs_for_generation(self, model_outputs, model_kwargs):
+         model_kwargs["past_key_values"] = model_outputs.past_key_values
+         model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1
+         return model_kwargs
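+ 
+     # Only the position of the newly generated token is needed at each step, so the last position
+     # id is kept and incremented, e.g. [[0, 1, 2]] -> [[3]] -> [[4]] on successive steps.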
+ 
+     @classmethod
+     def from_encoder_decoder_pretrained(
+         cls,
+         encoder_pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
+         decoder_pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
+         *model_args,
+         **kwargs
+     ) -> FlaxPreTrainedModel:
+         r"""
+         Instantiate an encoder and a decoder from one or two base classes of the library from pretrained model
+         checkpoints.
+ 
+         Params:
+             encoder_pretrained_model_name_or_path (`Union[str, os.PathLike]`, *optional*):
+                 Information necessary to initiate the encoder. Can be either:
+ 
+                 - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
+                   Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a
+                   user or organization name, like `dbmdz/bert-base-german-cased`.
+                 - A path to a *directory* containing model weights saved using
+                   [`~FlaxPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
+ 
+             decoder_pretrained_model_name_or_path (`Union[str, os.PathLike]`, *optional*, defaults to `None`):
+                 Information necessary to initiate the decoder. Can be either:
+ 
+                 - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
+                   Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a
+                   user or organization name, like `dbmdz/bert-base-german-cased`.
+                 - A path to a *directory* containing model weights saved using
+                   [`~FlaxPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
+ 
+             model_args (remaining positional arguments, *optional*):
+                 All remaining positional arguments will be passed to the underlying model's `__init__` method.
+ 
+             kwargs (remaining dictionary of keyword arguments, *optional*):
+                 Can be used to update the configuration object (after it has been loaded) and initiate the model
+                 (e.g., `output_attentions=True`).
+ 
+                 - To update the encoder configuration, use the prefix *encoder_* for each configuration parameter.
+                 - To update the decoder configuration, use the prefix *decoder_* for each configuration parameter.
+                 - To update the parent model configuration, do not use a prefix for each configuration parameter.
+ 
+                 Behaves differently depending on whether a `config` is provided or automatically loaded.
+ 
+         Example:
+ 
+         ```python
+         >>> from transformers import FlaxSpeechEncoderDecoderModel
+ 
+         >>> # initialize a wav2vec2-2-bart from pretrained wav2vec2 and bart models. Note that the cross-attention layers will be randomly initialized
+         >>> model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
+         ...     "facebook/wav2vec2-large-lv60", "facebook/bart-large"
+         ... )
+         >>> # saving model after fine-tuning
+         >>> model.save_pretrained("./wav2vec2-2-bart-large")
+         >>> # load fine-tuned model
+         >>> model = FlaxSpeechEncoderDecoderModel.from_pretrained("./wav2vec2-2-bart-large")
+         ```"""
+ 
+         kwargs_encoder = {
+             argument[len("encoder_") :]: value for argument, value in kwargs.items() if argument.startswith("encoder_")
+         }
+ 
+         kwargs_decoder = {
+             argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
+         }
+ 
+         # remove encoder, decoder kwargs from kwargs
+         for key in kwargs_encoder.keys():
+             del kwargs["encoder_" + key]
+         for key in kwargs_decoder.keys():
+             del kwargs["decoder_" + key]
+ 
+         # Load and initialize the encoder and decoder
+         # The distinction between encoder and decoder at the model level is made
+         # by the value of the flag `is_decoder` that we need to set correctly.
+         encoder = kwargs_encoder.pop("model", None)
+         if encoder is None:
+             if encoder_pretrained_model_name_or_path is None:
+                 raise ValueError(
+                     "If `encoder_model` is not defined as an argument, an `encoder_pretrained_model_name_or_path` has "
+                     "to be defined."
+                 )
+ 
+             if "config" not in kwargs_encoder:
+                 # TODO: AutoConfig .from_pretrained
+                 encoder_config, kwargs_encoder = Wav2Vec2Config.from_pretrained(
+                     encoder_pretrained_model_name_or_path, **kwargs_encoder, return_unused_kwargs=True
+                 )
+                 if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
+                     logger.info(
+                         f"Initializing {encoder_pretrained_model_name_or_path} as an encoder model "
+                         "from a decoder model. Cross-attention and causal mask are disabled."
+                     )
+                     encoder_config.is_decoder = False
+                     encoder_config.add_cross_attention = False
+ 
+                 kwargs_encoder["config"] = encoder_config
+ 
+             # TODO: FlaxAutoModel .from_pretrained
+             encoder = FlaxWav2Vec2Model.from_pretrained(
+                 encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder
+             )
+ 
+         decoder = kwargs_decoder.pop("model", None)
+         if decoder is None:
+             if decoder_pretrained_model_name_or_path is None:
+                 raise ValueError(
+                     "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has "
+                     "to be defined."
+                 )
+ 
+             if "config" not in kwargs_decoder:
+                 # TODO: AutoConfig .from_pretrained
+                 decoder_config, kwargs_decoder = BartConfig.from_pretrained(
+                     decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True
+                 )
+                 if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
+                     logger.info(
+                         f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. "
+                         f"Cross attention layers are added to {decoder_pretrained_model_name_or_path} "
+                         f"and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for "
+                         "cross attention layers."
+                     )
+                     decoder_config.is_decoder = True
+                     decoder_config.add_cross_attention = True
+ 
+                 kwargs_decoder["config"] = decoder_config
+ 
+             if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False:
+                 logger.warning(
+                     f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. "
+                     f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, "
+                     "make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` "
+                     "passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a "
+                     "`decoder_config` to `.from_encoder_decoder_pretrained(...)`"
+                 )
+ 
+             # TODO: FlaxAutoModelForCausalLM .from_pretrained
+             decoder = FlaxBartForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
960
+
961
+ # instantiate config with corresponding kwargs
962
+ dtype = kwargs.pop("dtype", jnp.float32)
963
+ config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs)
964
+
965
+ # make sure input & output word embeddings are not tied
966
+ config.tie_word_embeddings = False
967
+
968
+ # init model
969
+ model = cls(config, dtype=dtype)
970
+ model.params["encoder"] = encoder.params
971
+ model.params["decoder"] = decoder.params
972
+
973
+ return model
974
+
975
+ def _beam_search(
976
+ self,
977
+ input_ids: jnp.ndarray,
978
+ max_length: Optional[int] = None,
979
+ pad_token_id: Optional[int] = None,
980
+ eos_token_id: Optional[int] = None,
981
+ length_penalty: Optional[float] = None,
982
+ early_stopping: Optional[bool] = None,
983
+ logits_processor: Optional[FlaxLogitsProcessorList] = None,
984
+ trace: bool = True,
985
+ params: Optional[Dict[str, jnp.ndarray]] = None,
986
+ model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
987
+ ):
988
+ """
989
+ This beam search function is heavily inspired by Flax's official example:
990
+ https://github.com/google/flax/blob/master/examples/wmt/train.py#L254
991
+ """
992
+
993
+ def flatten_beam_dim(tensor):
994
+ """Flattens the first two dimensions of a non-scalar array."""
995
+ # ignore scalars (e.g. cache index)
996
+ if tensor.ndim == 0 or tensor.ndim == 1:
997
+ return tensor
998
+ elif tensor.ndim == 6:
999
+ return tensor.reshape(tensor.shape[:1] + (tensor.shape[1] * tensor.shape[2],) + tensor.shape[3:])
1000
+ return tensor.reshape((tensor.shape[0] * tensor.shape[1],) + tensor.shape[2:])
1001
+
1002
+ def unflatten_beam_dim(tensor, batch_size, num_beams):
1003
+ """Unflattens the first, flat batch*beam dimension of a non-scalar array."""
1004
+ # ignore scalars (e.g. cache index)
1005
+ if tensor.ndim == 0 or tensor.ndim == 1:
1006
+ return tensor
1007
+ if tensor.ndim == 5:
1008
+ return tensor.reshape(tensor.shape[:1] + (batch_size, num_beams) + tensor.shape[2:])
1009
+ return tensor.reshape((batch_size, num_beams) + tensor.shape[1:])
1010
+
1011
+ def gather_beams(nested, beam_indices, batch_size, new_num_beams):
1012
+ """
1013
+ Gathers the beam slices indexed by beam_indices into new beam array.
1014
+ """
1015
+ batch_indices = jnp.reshape(
1016
+ jnp.arange(batch_size * new_num_beams) // new_num_beams, (batch_size, new_num_beams)
1017
+ )
1018
+
1019
+ def gather_fn(tensor):
1020
+ # ignore scalars (e.g. cache index)
1021
+ if tensor.ndim == 0 or tensor.ndim == 1:
1022
+ return tensor
1023
+ if tensor.ndim == 6:
1024
+ return tensor[:, batch_indices, beam_indices]
1025
+ return tensor[batch_indices, beam_indices]
1026
+
1027
+ return jax.tree_map(gather_fn, nested)
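+ # Editor's note: for intuition, with batch_size=2 and new_num_beams=3 the
+ # `batch_indices` computed above is [[0, 0, 0], [1, 1, 1]], so the advanced index
+ # `tensor[batch_indices, beam_indices]` selects, within each batch item i, exactly
+ # the beams listed in beam_indices[i].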
1028
+
1029
+ # init values
1030
+ max_length = max_length if max_length is not None else self.config.max_length
1031
+ pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
1032
+ eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
1033
+ length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
1034
+ early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
1035
+
1036
+ batch_size, num_beams, cur_len = input_ids.shape
1037
+
1038
+ eos_token_id = jnp.array(eos_token_id)
1039
+ pad_token_id = jnp.array(pad_token_id)
1040
+ cur_len = jnp.array(cur_len)
1041
+
1042
+ # per batch,beam-item holding current token in loop.
1043
+ sequences = jnp.full((batch_size, num_beams, max_length), pad_token_id, dtype=jnp.int32)
1044
+ running_sequences = jnp.full((batch_size, num_beams, max_length), pad_token_id, dtype=jnp.int32)
1045
+ running_sequences = lax.dynamic_update_slice(sequences, input_ids, (0, 0, 0))
1046
+
1047
+ # per batch,beam-item state bit indicating if sentence has finished.
1048
+ is_sent_finished = jnp.zeros((batch_size, num_beams), dtype=jnp.bool_)
1049
+
1050
+ # per batch,beam-item score, logprobs
1051
+ running_scores = jnp.tile(jnp.array([0.0] + [np.array(-1.0e7)] * (num_beams - 1)), [batch_size, 1])
1052
+ scores = jnp.ones((batch_size, num_beams)) * np.array(-1.0e7)
1053
+
1054
+ # For Seq2Seq generation, we only need to use the decoder instead of the whole model in generation loop
1055
+ # and pass it the `encoder_outputs`, which are part of the `model_kwargs`.
1056
+ model = self.decode if self.config.is_encoder_decoder else self
1057
+
1058
+ # flatten beam dim
1059
+ if "encoder_outputs" in model_kwargs:
1060
+ model_kwargs["encoder_outputs"]["last_hidden_state"] = flatten_beam_dim(
1061
+ model_kwargs["encoder_outputs"]["last_hidden_state"]
1062
+ )
1063
+ if "attention_mask" in model_kwargs:
1064
+ model_kwargs["attention_mask"] = flatten_beam_dim(model_kwargs["attention_mask"])
1065
+
1066
+ # initialize model specific kwargs
1067
+ model_kwargs = self.prepare_inputs_for_generation(flatten_beam_dim(input_ids), max_length, **model_kwargs)
1068
+
1069
+ # initialize state
1070
+ state = BeamSearchState(
1071
+ cur_len=cur_len,
1072
+ running_sequences=running_sequences,
1073
+ running_scores=running_scores,
1074
+ sequences=sequences,
1075
+ scores=scores,
1076
+ is_sent_finished=is_sent_finished,
1077
+ model_kwargs=model_kwargs,
1078
+ )
1079
+
1080
+ def beam_search_cond_fn(state):
1081
+ """beam search state termination condition fn."""
1082
+
1083
+ # 1. is less than max length?
1084
+ not_max_length_yet = state.cur_len < max_length
1085
+
1086
+ # 2. can the new beams still improve?
1087
+ best_running_score = state.running_scores[:, -1:] / (max_length**length_penalty)
1088
+ worst_finished_score = jnp.where(
1089
+ state.is_sent_finished, jnp.min(state.scores, axis=1, keepdims=True), np.array(-1.0e7)
1090
+ )
1091
+ improvement_still_possible = jnp.all(worst_finished_score < best_running_score)
1092
+
1093
+ # 3. is there still a beam that has not finished?
1094
+ still_open_beam = ~(jnp.all(state.is_sent_finished) & early_stopping)
1095
+
1096
+ return not_max_length_yet & still_open_beam & improvement_still_possible
1097
+
1098
+ def beam_search_body_fn(state, input_ids_length=1):
1099
+ """beam search state update fn."""
1100
+ # 1. Forward current tokens
1101
+ # Collect the current position slice along length to feed the fast
1102
+ # autoregressive decoder model. Flatten the beam dimension into batch
1103
+ # dimension for feeding into the model.
1104
+ # The beam dimension is unflattened again below, for the logits and for the
1105
+ # attention cache arrays returned by the model forward pass.
1106
+ input_token = flatten_beam_dim(
1107
+ lax.dynamic_slice(
1108
+ state.running_sequences,
1109
+ (0, 0, state.cur_len - input_ids_length),
1110
+ (batch_size, num_beams, input_ids_length),
1111
+ )
1112
+ )
1113
+ model_outputs = model(input_token, params=params, **state.model_kwargs)
1114
+
1115
+ logits = unflatten_beam_dim(model_outputs.logits[:, -1], batch_size, num_beams)
1116
+ cache = jax.tree_map(
1117
+ lambda tensor: unflatten_beam_dim(tensor, batch_size, num_beams), model_outputs.past_key_values
1118
+ )
1119
+
1120
+ # adapt logits for FlaxMarianMTModel
1121
+ logits = self._adapt_logits_for_beam_search(logits)
1122
+
1123
+ # 2. Compute log probs
1124
+ # get log probabilities from logits,
1125
+ # process logits with processors (*e.g.* min_length, ...), and
1126
+ # add new logprobs to existing running logprobs scores.
1127
+ log_probs = jax.nn.log_softmax(logits)
1128
+ log_probs = logits_processor(
1129
+ flatten_beam_dim(running_sequences), flatten_beam_dim(log_probs), state.cur_len
1130
+ )
1131
+ log_probs = unflatten_beam_dim(log_probs, batch_size, num_beams)
1132
+ log_probs = log_probs + jnp.expand_dims(state.running_scores, axis=2)
1133
+ vocab_size = log_probs.shape[2]
1134
+ log_probs = log_probs.reshape((batch_size, num_beams * vocab_size))
1135
+
1136
+ # 3. Retrieve top-K
1137
+ # Each item in batch has num_beams * vocab_size candidate sequences.
1138
+ # For each item, get the top 2*k candidates with the highest log-
1139
+ # probabilities. We gather the top 2*K beams here so that even if the best
1140
+ # K sequences reach EOS simultaneously, we have another K sequences
1141
+ # remaining to continue the live beam search.
1142
+ # Gather the top 2*K scores from _all_ beams.
1143
+ # Gather 2*k top beams.
1144
+ # Recover the beam index by floor division.
1145
+ # Recover token id by modulo division and expand Id array for broadcasting.
1146
+ # Update sequences for the 2*K top-k new sequences.
1147
+ beams_to_keep = 2 * num_beams
1148
+ topk_log_probs, topk_indices = lax.top_k(log_probs, k=beams_to_keep)
1149
+ topk_beam_indices = topk_indices // vocab_size
1150
+ topk_running_sequences = gather_beams(
1151
+ state.running_sequences, topk_beam_indices, batch_size, beams_to_keep
1152
+ )
1153
+ topk_ids = jnp.expand_dims(topk_indices % vocab_size, axis=2)
1154
+ topk_sequences = lax.dynamic_update_slice(topk_running_sequences, topk_ids, (0, 0, state.cur_len))
1155
+
1156
+ # 4. Check which sequences have ended
1157
+ # Update current sequences:
1158
+ # Did any of these sequences reach an end marker?
1159
+ # To prevent these just-finished sequences from being added to the set of active
1160
+ # beam search sequences, set their log probs to a very large
1161
+ # negative value.
1162
+ did_topk_just_finished = topk_sequences[:, :, state.cur_len] == eos_token_id
1163
+ running_topk_log_probs = topk_log_probs + did_topk_just_finished * np.array(-1.0e7)
1164
+ # 5. Get running sequences scores for next
1165
+ # Determine the top k beam indices (from top 2*k beams) from log probs
1166
+ # and gather top k beams (from top 2*k beams).
1167
+ next_topk_indices = jnp.flip(lax.top_k(running_topk_log_probs, k=num_beams)[1], axis=1)
1168
+ next_running_sequences, next_running_scores = gather_beams(
1169
+ [topk_sequences, running_topk_log_probs], next_topk_indices, batch_size, num_beams
1170
+ )
1171
+
1172
+ # 6. Process topk logits
1173
+ # Further process log probs:
1174
+ # - add length penalty
1175
+ # - make sure no scores can be added anymore if beam is full
1176
+ # - make sure still running sequences cannot be chosen as finalized beam
1177
+ topk_log_probs = topk_log_probs / (state.cur_len**length_penalty)
1178
+ beams_in_batch_are_full = (
1179
+ jnp.broadcast_to(state.is_sent_finished.all(axis=-1, keepdims=True), did_topk_just_finished.shape)
1180
+ & early_stopping
1181
+ )
1182
+ add_penalty = ~did_topk_just_finished | beams_in_batch_are_full
1183
+ topk_log_probs += add_penalty * np.array(-1.0e7)
1184
+
1185
+ # 7. Get scores, sequences, is sentence finished for next.
1186
+ # Combine sequences, scores, and flags along the beam dimension and compare
1187
+ # new finished sequence scores to existing finished scores and select the
1188
+ # best from the new set of beams
1189
+ merged_sequences = jnp.concatenate([state.sequences, topk_sequences], axis=1)
1190
+ merged_scores = jnp.concatenate([state.scores, topk_log_probs], axis=1)
1191
+ merged_is_sent_finished = jnp.concatenate([state.is_sent_finished, did_topk_just_finished], axis=1)
1192
+ topk_merged_indices = jnp.flip(lax.top_k(merged_scores, k=num_beams)[1], axis=1)
1193
+ next_sequences, next_scores, next_is_sent_finished = gather_beams(
1194
+ [merged_sequences, merged_scores, merged_is_sent_finished], topk_merged_indices, batch_size, num_beams
1195
+ )
1196
+
1197
+ # 8. Update model kwargs.
1198
+ # Determine the top k beam indices from the original set of all beams.
1199
+ # With these, gather the top k beam-associated caches.
1200
+ next_running_indices = gather_beams(topk_beam_indices, next_topk_indices, batch_size, num_beams)
1201
+ next_cache = gather_beams(cache, next_running_indices, batch_size, num_beams)
1202
+ model_outputs["past_key_values"] = jax.tree_map(lambda x: flatten_beam_dim(x), next_cache)
1203
+ next_model_kwargs = self.update_inputs_for_generation(model_outputs, state.model_kwargs)
1204
+
1205
+ return BeamSearchState(
1206
+ cur_len=state.cur_len + 1,
1207
+ running_scores=next_running_scores,
1208
+ running_sequences=next_running_sequences,
1209
+ scores=next_scores,
1210
+ sequences=next_sequences,
1211
+ is_sent_finished=next_is_sent_finished,
1212
+ model_kwargs=next_model_kwargs,
1213
+ )
1214
+
1215
+ # The very first prompt often has sequence length > 1, so run outside of `lax.while_loop` to comply with TPU
1216
+ if input_ids.shape[-1] > 1:
1217
+ state = partial(beam_search_body_fn, input_ids_length=input_ids.shape[-1])(state)
1218
+
1219
+ if not trace:
1220
+ state = self._run_loop_in_debug(beam_search_cond_fn, beam_search_body_fn, state)
1221
+ else:
1222
+ state = lax.while_loop(beam_search_cond_fn, beam_search_body_fn, state)
1223
+
1224
+ # Account for the edge-case where there are no finished sequences for a
1225
+ # particular batch item. If so, return running sequences for that batch item.
1226
+ none_finished = jnp.any(state.is_sent_finished, axis=1)
1227
+ sequences = jnp.where(none_finished[:, None, None], state.sequences, state.running_sequences)
1228
+ scores = jnp.where(none_finished[:, None], state.scores, state.running_scores)
1229
+
1230
+ # return all beams for each batch (sorted with the best beam last) and the best score
1231
+ sequences = sequences[:, :]
1232
+ scores = scores[:, -1]
1233
+
1234
+ return FlaxBeamSearchOutput(sequences=sequences, scores=scores)
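Editor's note: a minimal, self-contained sketch (illustrative shapes only) of the beam-dimension bookkeeping used throughout `_beam_search` above: `flatten_beam_dim` folds `(batch, beams, ...)` into `(batch * beams, ...)` before the decoder forward pass, and `unflatten_beam_dim` restores it afterwards; the round trip is lossless.

```python
import jax.numpy as jnp

batch_size, num_beams, seq_len = 2, 3, 5
x = jnp.arange(batch_size * num_beams * seq_len).reshape(batch_size, num_beams, seq_len)

# flatten (batch, beams, ...) -> (batch * beams, ...), as done before the model call
flat = x.reshape((x.shape[0] * x.shape[1],) + x.shape[2:])  # shape (6, 5)

# unflatten back to (batch, beams, ...), as done for the logits and the cache
restored = flat.reshape((batch_size, num_beams) + flat.shape[1:])  # shape (2, 3, 5)

assert (restored == x).all()
```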
models/modeling_flax_wav2vec2.py ADDED
@@ -0,0 +1,975 @@
1
+ # coding=utf-8
2
+ # Copyright 2021 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Flax Wav2Vec2 model."""
16
+
17
+ from functools import partial
18
+ from typing import Optional, Tuple, Union
19
+
20
+ import flax
21
+ import flax.linen as nn
22
+ import jax
23
+ import jax.numpy as jnp
24
+ from flax.core.frozen_dict import FrozenDict
25
+ from flax.linen import partitioning as nn_partitioning
26
+ from flax.linen.attention import dot_product_attention_weights
27
+ from jax import lax
28
+
29
+ from transformers.modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput
30
+ from transformers.modeling_flax_utils import ACT2FN, FlaxPreTrainedModel
31
+ from transformers.utils import ModelOutput
32
+
33
+ from models.configuration_wav2vec2 import Wav2Vec2Config
34
+
35
+ scan_with_axes = nn_partitioning.scan_with_axes
36
+ remat = nn_partitioning.remat
37
+
38
+
39
+ @flax.struct.dataclass
40
+ class FlaxWav2Vec2BaseModelOutput(ModelOutput):
41
+ """
42
+ Output type of [`FlaxWav2Vec2Model`], with potential hidden states and attentions.
43
+
44
+ Args:
45
+ last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`):
46
+ Sequence of hidden-states at the output of the last layer of the model.
47
+ extract_features (`jnp.ndarray` of shape `(batch_size, sequence_length, last_conv_dim)`):
48
+ Sequence of extracted feature vectors of the last convolutional layer of the model with `last_conv_dim`
49
+ being the dimension of the last convolutional layer.
50
+ hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
51
+ Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
52
+ `(batch_size, sequence_length, hidden_size)`.
53
+
54
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
55
+ attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
56
+ Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
57
+ sequence_length)`.
58
+
59
+ Attention weights after the attention softmax, used to compute the weighted average in the self-attention
60
+ heads.
61
+ """
62
+
63
+ last_hidden_state: jnp.ndarray = None
64
+ extract_features: jnp.ndarray = None
65
+ hidden_states: Optional[Tuple[jnp.ndarray]] = None
66
+ attentions: Optional[Tuple[jnp.ndarray]] = None
67
+
68
+
69
+ WAV_2_VEC_2_START_DOCSTRING = r"""
70
+ Wav2Vec2 was proposed in [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech
71
+ Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael
72
+ Auli.
73
+
74
+ This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
75
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
76
+ etc.)
77
+
78
+ This model is also a Flax Linen
79
+ [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
80
+ regular Flax Module and refer to the Flax documentation for all matters related to general usage and behavior.
81
+
82
+ Finally, this model supports inherent JAX features such as:
83
+
84
+ - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
85
+ - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
86
+ - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
87
+ - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
88
+
89
+ Parameters:
90
+ config ([`Wav2Vec2Config`]): Model configuration class with all the parameters of the model.
91
+ Initializing with a config file does not load the weights associated with the model, only the
92
+ configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
93
+ dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
94
+ The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
95
+ `jax.numpy.bfloat16` (on TPUs).
96
+
97
+ This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
98
+ specified all the computation will be performed with the given `dtype`.
99
+
100
+ **Note that this only specifies the dtype of the computation and does not influence the dtype of model
101
+ parameters.**
102
+
103
+ If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
104
+ [`~FlaxPreTrainedModel.to_bf16`].
105
+ """
106
+
107
+
108
+ WAV_2_VEC_2_INPUTS_DOCSTRING = r"""
109
+ Args:
110
+ input_values (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
111
+ Float values of input raw speech waveform. Values can be obtained by loading a *.flac* or *.wav* audio file
112
+ into an array of type *List[float]* or a *numpy.ndarray*, *e.g.* via the soundfile library (*pip install
113
+ soundfile*). To prepare the array into *input_values*, the [`Wav2Vec2Processor`] should be used for padding
114
+ and conversion into a tensor of type *jnp.ndarray*. See [`Wav2Vec2Processor.__call__`] for details.
115
+ attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
116
+ Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0,
117
+ 1]`:
118
+
119
+ - 1 for tokens that are **not masked**,
120
+ - 0 for tokens that are **masked**.
121
+
122
+ [What are attention masks?](../glossary#attention-mask) .. warning:: `attention_mask` should only be passed
123
+ if the corresponding processor has `config.return_attention_mask == True`. For all models whose processor
124
+ has `config.return_attention_mask == False`, such as
125
+ [wav2vec2-base](https://huggingface.co/facebook/wav2vec2-base-960h), `attention_mask` should **not** be
126
+ passed to avoid degraded performance when doing batched inference. For such models `input_values` should
127
+ simply be padded with 0 and passed without `attention_mask`. Be aware that these models also yield slightly
128
+ different results depending on whether `input_values` is padded or not.
129
+ mask_time_indices (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
130
+ Indices to mask extracted features for contrastive loss. When in training mode, the model learns to predict
131
+ masked extracted features in *config.proj_codevector_dim* space.
132
+ output_attentions (`bool`, *optional*):
133
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
134
+ tensors for more detail.
135
+ output_hidden_states (`bool`, *optional*):
136
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
137
+ more detail.
138
+ return_dict (`bool`, *optional*):
139
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
140
+ """
141
+
142
+
143
+ class FlaxWav2Vec2LayerNormConvLayer(nn.Module):
144
+ config: Wav2Vec2Config
145
+ layer_id: int = 0
146
+ dtype: jnp.dtype = jnp.float32
147
+
148
+ def setup(self):
149
+ self.in_conv_dim = self.config.conv_dim[self.layer_id] if self.layer_id > 0 else 1
150
+ self.out_conv_dim = self.config.conv_dim[self.layer_id]
151
+
152
+ self.conv = nn.Conv(
153
+ features=self.config.conv_dim[self.layer_id],
154
+ kernel_size=(self.config.conv_kernel[self.layer_id],),
155
+ strides=(self.config.conv_stride[self.layer_id],),
156
+ use_bias=self.config.conv_bias,
157
+ kernel_init=jax.nn.initializers.he_normal(),
158
+ padding="VALID",
159
+ dtype=self.dtype,
160
+ )
161
+ self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
162
+ self.activation = ACT2FN[self.config.feat_extract_activation]
163
+
164
+ def __call__(self, hidden_states):
165
+ hidden_states = self.conv(hidden_states)
166
+ hidden_states = self.layer_norm(hidden_states)
167
+ hidden_states = self.activation(hidden_states)
168
+ return hidden_states
169
+
170
+
171
+ class FlaxConvWithWeightNorm(nn.Module):
172
+ config: Wav2Vec2Config
173
+ dtype: jnp.dtype = jnp.float32
174
+
175
+ def setup(self):
176
+ self.conv = nn.Conv(
177
+ features=self.config.hidden_size,
178
+ kernel_size=(self.config.num_conv_pos_embeddings,),
179
+ kernel_init=jax.nn.initializers.he_normal(),
180
+ padding="VALID",
181
+ feature_group_count=self.config.num_conv_pos_embedding_groups,
182
+ dtype=self.dtype,
183
+ )
184
+ weight_shape = (
185
+ self.conv.features,
186
+ self.conv.features // self.conv.feature_group_count,
187
+ self.conv.kernel_size[0],
188
+ )
189
+ self.weight_v = self.param("weight_v", jax.nn.initializers.he_normal(), weight_shape)
190
+ self.weight_g = self.param("weight_g", lambda _: jnp.linalg.norm(self.weight_v, axis=(0, 1))[None, None, :])
191
+ self.bias = self.param("bias", jax.nn.initializers.zeros, (self.conv.features,))
192
+ self.prev_padding = self.conv.kernel_size[0] // 2
193
+
194
+ def _get_normed_weights(self):
195
+ weight_v_norm = jnp.linalg.norm(self.weight_v, axis=(0, 1))[None, None, :]
196
+ normed_weight_v = jnp.divide(self.weight_v, weight_v_norm)
197
+ normed_kernel = jnp.multiply(normed_weight_v, self.weight_g)
198
+ return normed_kernel
199
+
200
+ def __call__(self, hidden_states):
201
+ kernel = self._get_normed_weights()
202
+ hidden_states = jnp.pad(hidden_states, ((0, 0), (self.prev_padding, self.prev_padding), (0, 0)))
203
+ hidden_states = self.conv.apply({"params": {"kernel": kernel.T, "bias": self.bias}}, hidden_states)
204
+ return hidden_states
205
+
206
+
207
+ class FlaxWav2Vec2PositionalConvEmbedding(nn.Module):
208
+ config: Wav2Vec2Config
209
+ dtype: jnp.dtype = jnp.float32
210
+
211
+ def setup(self):
212
+ self.conv = FlaxConvWithWeightNorm(self.config, dtype=self.dtype)
213
+ self.activation = ACT2FN[self.config.feat_extract_activation]
214
+ self.num_pad_remove = 1 if self.config.num_conv_pos_embeddings % 2 == 0 else 0
215
+
216
+ def __call__(self, hidden_states):
217
+ hidden_states = hidden_states.transpose((0, 1, 2))  # identity transpose: Flax convs are already channels-last
218
+
219
+ hidden_states = self.conv(hidden_states)
220
+
221
+ if self.num_pad_remove > 0:
222
+ hidden_states = hidden_states[:, : -self.num_pad_remove, :]
223
+ hidden_states = self.activation(hidden_states)
224
+
225
+ hidden_states = hidden_states.transpose((0, 1, 2))
226
+ return hidden_states
227
+
228
+
229
+ class FlaxConvLayersCollection(nn.Module):
230
+ config: Wav2Vec2Config
231
+ dtype: jnp.dtype = jnp.float32
232
+
233
+ def setup(self):
234
+ if self.config.feat_extract_norm == "layer":
235
+ # note that we can't use scan on the conv layers as they differ on a layer-by-layer basis
236
+ BlockLayer = remat(FlaxWav2Vec2LayerNormConvLayer) if self.config.gradient_checkpointing else FlaxWav2Vec2LayerNormConvLayer
237
+ self.layers = [
238
+ BlockLayer(self.config, layer_id=i, name=str(i), dtype=self.dtype)
239
+ for i in range(self.config.num_feat_extract_layers)
240
+ ]
241
+ elif self.config.feat_extract_norm == "group":
242
+ raise NotImplementedError("At the moment only ``config.feat_extact_norm == 'layer'`` is supported")
243
+ else:
244
+ raise ValueError(
245
+ f"`config.feat_extract_norm` is {self.config.feat_extract_norm}, but has to be one of ['group', 'layer']"
246
+ )
247
+
248
+ def __call__(self, hidden_states):
249
+ for i, conv_layer in enumerate(self.layers):
250
+ hidden_states = conv_layer(hidden_states)
251
+ return hidden_states
252
+
253
+
254
+ class FlaxWav2Vec2FeatureEncoder(nn.Module):
255
+ """Construct the features from raw audio waveform"""
256
+
257
+ config: Wav2Vec2Config
258
+ dtype: jnp.dtype = jnp.float32
259
+
260
+ def setup(self):
261
+ self.conv_layers = FlaxConvLayersCollection(self.config, dtype=self.dtype)
262
+
263
+ def __call__(self, input_values, freeze_feature_encoder=False):
264
+ hidden_states = input_values[:, :, None]
265
+ hidden_states = self.conv_layers(hidden_states)
266
+ if freeze_feature_encoder:
267
+ hidden_states = jax.lax.stop_gradient(hidden_states)
268
+ return hidden_states
269
+
270
+
271
+ class FlaxWav2Vec2FeatureProjection(nn.Module):
272
+ config: Wav2Vec2Config
273
+ dtype: jnp.dtype = jnp.float32
274
+
275
+ def setup(self):
276
+ self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
277
+ self.projection = nn.Dense(
278
+ self.config.hidden_size,
279
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
280
+ dtype=self.dtype,
281
+ )
282
+ self.dropout = nn.Dropout(rate=self.config.feat_proj_dropout)
283
+
284
+ def __call__(self, hidden_states, deterministic=True):
285
+ norm_hidden_states = self.layer_norm(hidden_states)
286
+ hidden_states = self.projection(norm_hidden_states)
287
+ hidden_states = self.dropout(hidden_states, deterministic=deterministic)
288
+ return hidden_states, norm_hidden_states
289
+
290
+
291
+ class FlaxWav2Vec2Attention(nn.Module):
292
+ config: Wav2Vec2Config
293
+ embed_dim: int
294
+ num_heads: int
295
+ dropout: float = 0.0
296
+ bias: bool = True
297
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
298
+
299
+ def setup(self) -> None:
300
+ self.head_dim = self.embed_dim // self.num_heads
301
+ if self.head_dim * self.num_heads != self.embed_dim:
302
+ raise ValueError(
303
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
304
+ )
305
+
306
+ dense = partial(
307
+ nn.Dense,
308
+ self.embed_dim,
309
+ use_bias=self.bias,
310
+ dtype=self.dtype,
311
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
312
+ )
313
+
314
+ self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
315
+
316
+ self.fused_proj = nn.Dense(
317
+ self.embed_dim * 3,
318
+ use_bias=self.bias,
319
+ dtype=self.dtype,
320
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
321
+ )
322
+
323
+ self.out_proj = dense()
324
+
325
+ self.dropout_layer = nn.Dropout(rate=self.dropout)
326
+
327
+ def _split_heads(self, hidden_states):
328
+ return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))
329
+
330
+ def _merge_heads(self, hidden_states):
331
+ return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))
332
+
333
+ def __call__(
334
+ self,
335
+ hidden_states: jnp.ndarray,
336
+ key_value_states: Optional[jnp.ndarray] = None,
337
+ attention_mask: Optional[jnp.ndarray] = None,
338
+ deterministic: bool = True,
339
+ ) -> Tuple[jnp.ndarray]:
340
+ """Input shape: Batch x Time x Channel"""
341
+
342
+ if self.config.fuse_matmuls:
343
+ attention_states = self.fused_proj(hidden_states)
344
+ query_states, key_states, value_states = jnp.split(attention_states, 3, axis=-1)
345
+
346
+ else:
347
+ # get query proj
348
+ query_states = self.q_proj(hidden_states)
349
+
350
+ key_states = self.k_proj(hidden_states)
351
+ value_states = self.v_proj(hidden_states)
352
+
353
+ query_states = self._split_heads(query_states)
354
+ key_states = self._split_heads(key_states)
355
+ value_states = self._split_heads(value_states)
356
+
357
+ if attention_mask is not None:
358
+ attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
359
+
360
+ # Convert the boolean attention mask to an attention bias.
361
+ if attention_mask is not None:
362
+ # attention mask in the form of attention bias
363
+ attention_bias = lax.select(
364
+ attention_mask > 0,
365
+ jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
366
+ jnp.full(attention_mask.shape, float("-inf")).astype(self.dtype),
367
+ )
368
+ else:
369
+ attention_bias = None
370
+
371
+ dropout_rng = None
372
+ if not deterministic and self.dropout > 0.0:
373
+ dropout_rng = self.make_rng("dropout")
374
+
375
+ attn_weights = dot_product_attention_weights(
376
+ query_states,
377
+ key_states,
378
+ bias=attention_bias,
379
+ dropout_rng=dropout_rng,
380
+ dropout_rate=self.dropout,
381
+ broadcast_dropout=True,
382
+ deterministic=deterministic,
383
+ dtype=self.dtype,
384
+ precision=None,
385
+ )
386
+
387
+ attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
388
+ attn_output = self._merge_heads(attn_output)
389
+ attn_output = self.out_proj(attn_output)
390
+
391
+ return attn_output, attn_weights
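+ # Editor's note: the `fuse_matmuls` branch above trades three projections for one
+ # matmul plus a split. A minimal equivalence sketch with a hypothetical weight
+ # W of shape (d, 3*d):
+ #     qkv = x @ W
+ #     q, k, v = jnp.split(qkv, 3, axis=-1)
+ # is identical to q = x @ W[:, :d], k = x @ W[:, d:2*d], v = x @ W[:, 2*d:].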
392
+
393
+
394
+ class FlaxWav2Vec2FeedForward(nn.Module):
395
+ config: Wav2Vec2Config
396
+ dtype: jnp.dtype = jnp.float32
397
+
398
+ def setup(self):
399
+ self.intermediate_dropout = nn.Dropout(rate=self.config.activation_dropout)
400
+
401
+ self.intermediate_dense = nn.Dense(
402
+ self.config.intermediate_size,
403
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
404
+ dtype=self.dtype,
405
+ )
406
+ if isinstance(self.config.hidden_act, str):
407
+ self.intermediate_act_fn = ACT2FN[self.config.hidden_act]
408
+ else:
409
+ self.intermediate_act_fn = self.config.hidden_act
410
+
411
+ self.output_dense = nn.Dense(
412
+ self.config.hidden_size,
413
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
414
+ dtype=self.dtype,
415
+ )
416
+ self.output_dropout = nn.Dropout(rate=self.config.hidden_dropout)
417
+
418
+ def __call__(self, hidden_states, deterministic=True):
419
+ hidden_states = self.intermediate_dense(hidden_states)
420
+ hidden_states = self.intermediate_act_fn(hidden_states)
421
+ hidden_states = self.intermediate_dropout(hidden_states, deterministic=deterministic)
422
+
423
+ hidden_states = self.output_dense(hidden_states)
424
+ hidden_states = self.output_dropout(hidden_states, deterministic=deterministic)
425
+ return hidden_states
426
+
427
+
428
+ class FlaxWav2Vec2EncoderLayerStableLayerNorm(nn.Module):
429
+ config: Wav2Vec2Config
430
+ dtype: jnp.dtype = jnp.float32
431
+
432
+ def setup(self):
433
+ self.attention = FlaxWav2Vec2Attention(
434
+ config=self.config,
435
+ embed_dim=self.config.hidden_size,
436
+ num_heads=self.config.num_attention_heads,
437
+ dropout=self.config.attention_dropout,
438
+ dtype=self.dtype,
439
+ )
440
+ self.dropout = nn.Dropout(rate=self.config.hidden_dropout)
441
+ self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
442
+ self.feed_forward = FlaxWav2Vec2FeedForward(self.config, dtype=self.dtype)
443
+ self.final_layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
444
+
445
+ def __call__(self, hidden_states, attention_mask=None, deterministic=True, output_attentions=False):
446
+ if self.config.use_scan:
447
+ hidden_states = hidden_states[0]
448
+ attn_residual = hidden_states
449
+ hidden_states = self.layer_norm(hidden_states)
450
+ hidden_states, attn_weights = self.attention(
451
+ hidden_states, attention_mask=attention_mask, deterministic=deterministic
452
+ )
453
+ hidden_states = self.dropout(hidden_states, deterministic=deterministic)
454
+ hidden_states = attn_residual + hidden_states
455
+ hidden_states = hidden_states + self.feed_forward(
456
+ self.final_layer_norm(hidden_states), deterministic=deterministic
457
+ )
458
+
459
+ outputs = (hidden_states,)
460
+
461
+ if output_attentions:
462
+ outputs += (attn_weights,)
463
+
464
+ if self.config.use_scan:
465
+ outputs = (outputs, None)
466
+
467
+ return outputs
468
+
469
+
470
+ class FlaxWav2Vec2EncoderLayerStableLayerNormCollection(nn.Module):
471
+ config: Wav2Vec2Config
472
+ dtype: jnp.dtype = jnp.float32
473
+
474
+ @nn.compact
475
+ def __call__(
476
+ self,
477
+ hidden_states,
478
+ attention_mask=None,
479
+ deterministic: bool = True,
480
+ output_attentions: bool = False,
481
+ output_hidden_states: bool = False,
482
+ return_dict: bool = True,
483
+ ):
484
+ all_attentions = () if output_attentions else None
485
+ all_hidden_states = () if output_hidden_states else None
486
+
487
+ num_layers = self.config.num_hidden_layers
488
+ BlockEncoderLayer = (
489
+ remat(
490
+ FlaxWav2Vec2EncoderLayerStableLayerNorm,
491
+ static_argnums=(2, 3),
492
+ prevent_cse=not self.config.use_scan,
493
+ )
494
+ if self.config.gradient_checkpointing
495
+ else FlaxWav2Vec2EncoderLayerStableLayerNorm
496
+ )
497
+
498
+ if self.config.use_scan:
499
+ # since all encoder layers are the same, we use nn.scan directly
500
+ assert not output_attentions, "cannot use `scan` with `output_attentions` set to `True`"
501
+ assert not output_hidden_states, "cannot use `scan` with `output_hidden_states` set to `True`"
502
+ hidden_states = (hidden_states,)
503
+
504
+ hidden_states, _ = scan_with_axes(
505
+ BlockEncoderLayer,
506
+ variable_axes={"params": 0, "cache": 0},
507
+ split_rngs={"params": True, "dropout": True},
508
+ in_axes=(nn.broadcast, nn.broadcast, nn.broadcast),
509
+ length=num_layers,
510
+ )(self.config, dtype=self.dtype, name="FlaxWav2Vec2EncoderLayers",)(
511
+ hidden_states, attention_mask, deterministic, output_attentions
512
+ )
513
+ hidden_states = hidden_states[0]
514
+
515
+ else:
516
+ for layer in range(num_layers):
517
+ if output_hidden_states:
518
+ all_hidden_states += (hidden_states,)
519
+
520
+ layer_outputs = BlockEncoderLayer(
521
+ self.config,
522
+ dtype=self.dtype,
523
+ name=str(layer),
524
+ )(hidden_states, attention_mask, deterministic, output_attentions)
525
+
526
+ hidden_states = layer_outputs[0]
527
+
528
+ if output_attentions:
529
+ all_attentions += (layer_outputs[1],)
530
+
531
+ if output_hidden_states:
532
+ all_hidden_states += (hidden_states,)
533
+
534
+ outputs = (hidden_states, all_hidden_states, all_attentions)
535
+
536
+ if not return_dict:
537
+ return tuple(v for v in outputs if v is not None)
538
+
539
+ return FlaxBaseModelOutput(
540
+ last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
541
+ )
542
+
543
+
544
+ class FlaxWav2Vec2StableLayerNormEncoder(nn.Module):
545
+ config: Wav2Vec2Config
546
+ dtype: jnp.dtype = jnp.float32
547
+
548
+ def setup(self):
549
+ self.pos_conv_embed = FlaxWav2Vec2PositionalConvEmbedding(self.config, dtype=self.dtype)
550
+ self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
551
+ self.dropout = nn.Dropout(rate=self.config.hidden_dropout)
552
+ self.layers = FlaxWav2Vec2EncoderLayerStableLayerNormCollection(self.config, dtype=self.dtype)
553
+
554
+ def __call__(
555
+ self,
556
+ hidden_states,
557
+ attention_mask=None,
558
+ deterministic=True,
559
+ output_attentions=False,
560
+ output_hidden_states=False,
561
+ return_dict=True,
562
+ ):
563
+
564
+ if attention_mask is not None:
565
+ # make sure padded tokens are not attended to
566
+ hidden_states = jnp.where(
567
+ jnp.broadcast_to(attention_mask[:, :, None], hidden_states.shape), hidden_states, 0
568
+ )
569
+
570
+ position_embeddings = self.pos_conv_embed(hidden_states)
571
+
572
+ hidden_states = hidden_states + position_embeddings
573
+ hidden_states = self.dropout(hidden_states, deterministic=deterministic)
574
+
575
+ outputs = self.layers(
576
+ hidden_states,
577
+ attention_mask,
578
+ output_attentions=output_attentions,
579
+ output_hidden_states=output_hidden_states,
580
+ return_dict=return_dict,
581
+ )
582
+
583
+ last_hidden_state = self.layer_norm(outputs[0])
584
+
585
+ # update the last element in `hidden_states` after applying `layernorm` above
586
+ hidden_states = None
587
+ if output_hidden_states:
588
+ hidden_states = outputs[1]
589
+ hidden_states = hidden_states[:-1] + (last_hidden_state,)
590
+
591
+ if not return_dict:
592
+ outputs = (last_hidden_state, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:])
593
+ return tuple(v for v in outputs if v is not None)
594
+
595
+ return FlaxBaseModelOutput(
596
+ last_hidden_state=last_hidden_state, hidden_states=hidden_states, attentions=outputs.attentions
597
+ )
598
+
599
+
600
+ class FlaxWav2Vec2Adapter(nn.Module):
601
+ config: Wav2Vec2Config
602
+ dtype: jnp.dtype = jnp.float32
603
+
604
+ def setup(self):
605
+ # hidden_states require down-projection if feature dims don't match
606
+ if self.config.output_hidden_size != self.config.hidden_size:
607
+ self.proj = nn.Dense(
608
+ self.config.output_hidden_size,
609
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
610
+ dtype=self.dtype,
611
+ )
612
+ self.proj_layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
613
+ else:
614
+ self.proj = self.proj_layer_norm = None
615
+
616
+ self.layers = FlaxWav2Vec2AdapterLayersCollection(self.config, dtype=self.dtype)
617
+
618
+ def __call__(self, hidden_states, deterministic=True):
619
+ # down-project hidden_states if required
620
+ if self.proj is not None and self.proj_layer_norm is not None:
621
+ hidden_states = self.proj(hidden_states)
622
+ hidden_states = self.proj_layer_norm(hidden_states)
623
+
624
+ hidden_states = self.layers(hidden_states)
625
+
626
+ return hidden_states
627
+
628
+
629
+ class FlaxWav2Vec2AdapterLayer(nn.Module):
630
+ config: Wav2Vec2Config
631
+ dtype: jnp.dtype = jnp.float32
632
+
633
+ def setup(self):
634
+ self.conv = nn.Conv(
635
+ features=2 * self.config.output_hidden_size,
636
+ kernel_size=(self.config.adapter_kernel_size,),
637
+ strides=(self.config.adapter_stride,),
638
+ padding=((1, 1),),
639
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
640
+ dtype=self.dtype,
641
+ )
642
+
643
+ def __call__(self, hidden_states):
644
+ hidden_states = self.conv(hidden_states)
645
+ hidden_states = nn.glu(hidden_states, axis=2)
646
+
647
+ return hidden_states
648
+
649
+
650
+ class FlaxWav2Vec2AdapterLayersCollection(nn.Module):
651
+ config: Wav2Vec2Config
652
+ dtype: jnp.dtype = jnp.float32
653
+
654
+ def setup(self):
655
+ BlockAdapterLayer = remat(FlaxWav2Vec2AdapterLayer) if self.config.gradient_checkpointing else FlaxWav2Vec2AdapterLayer
656
+ self.layers = [
657
+ BlockAdapterLayer(self.config, name=str(i), dtype=self.dtype)
658
+ for i in range(self.config.num_adapter_layers)
659
+ ]
660
+
661
+ def __call__(self, hidden_states):
662
+ for conv_layer in self.layers:
663
+ hidden_states = conv_layer(hidden_states)
664
+
665
+ return hidden_states
666
+
667
+
668
+ class FlaxWav2Vec2PreTrainedModel(FlaxPreTrainedModel):
669
+ """
670
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
671
+ models.
672
+ """
673
+
674
+ config_class = Wav2Vec2Config
675
+ base_model_prefix: str = "wav2vec2"
676
+ main_input_name = "input_values"
677
+ module_class: nn.Module = None
678
+
679
+ def __init__(
680
+ self,
681
+ config: Wav2Vec2Config,
682
+ input_shape: Tuple = (1, 1024),
683
+ seed: int = 0,
684
+ dtype: jnp.dtype = jnp.float32,
685
+ _do_init: bool = True,
686
+ **kwargs,
687
+ ):
688
+ module = self.module_class(config=config, dtype=dtype, **kwargs)
689
+ super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
690
+
691
+ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
692
+ # init input tensors
693
+ input_values = jnp.zeros(input_shape, dtype="i4")
694
+ attention_mask = jnp.ones_like(input_values)
695
+ params_rng, dropout_rng = jax.random.split(rng, 2)
696
+ rngs = {"params": params_rng, "dropout": dropout_rng}
697
+
698
+ return self.module.init(rngs, input_values, attention_mask, return_dict=False)["params"]
699
+
700
+ def __call__(
701
+ self,
702
+ input_values,
703
+ attention_mask=None,
704
+ mask_time_indices=None,
705
+ extract_features=None,
706
+ params: dict = None,
707
+ dropout_rng: jax.random.PRNGKey = None,
708
+ train: bool = False,
709
+ output_attentions: Optional[bool] = None,
710
+ output_hidden_states: Optional[bool] = None,
711
+ output_features: Optional[bool] = None,
712
+ freeze_feature_encoder: bool = False,
713
+ return_dict: Optional[bool] = None,
714
+ ):
715
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
716
+ output_hidden_states = (
717
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
718
+ )
719
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
720
+
721
+ if attention_mask is None:
722
+ batch_size, sequence_length = input_values.shape
723
+ attention_mask = jnp.ones((batch_size, sequence_length))
724
+
725
+ if extract_features is not None:
726
+ extract_features = jnp.array(extract_features, dtype="f4")
727
+
728
+ # Handle any PRNG if needed
729
+ rngs = {}
730
+ if dropout_rng is not None:
731
+ rngs["dropout"] = dropout_rng
732
+
733
+ inputs = {"params": params or self.params}
734
+
735
+ return self.module.apply(
736
+ inputs,
737
+ jnp.array(input_values, dtype="f4"),
738
+ jnp.array(attention_mask, dtype="i4"),
739
+ mask_time_indices,
740
+ extract_features,
741
+ not train,
742
+ output_attentions,
743
+ output_hidden_states,
744
+ output_features,
745
+ freeze_feature_encoder,
746
+ return_dict,
747
+ rngs=rngs,
748
+ )
749
+
750
+ def _get_feat_extract_output_lengths(
751
+ self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
752
+ ):
753
+ return self.module._get_feat_extract_output_lengths(input_lengths, add_adapter=add_adapter)
754
+
755
+ def _get_feature_vector_attention_mask(
756
+ self, feature_vector_length: int, attention_mask: jnp.ndarray, add_adapter=None
757
+ ):
758
+ return self.module._get_feature_vector_attention_mask(feature_vector_length, attention_mask, add_adapter=add_adapter)
759
+
760
+
761
+ class FlaxWav2Vec2Module(nn.Module):
762
+ config: Wav2Vec2Config
763
+ dtype: jnp.dtype = jnp.float32
764
+
765
+ def setup(self):
766
+ self.feature_extractor = FlaxWav2Vec2FeatureEncoder(self.config, dtype=self.dtype)
767
+ self.feature_projection = FlaxWav2Vec2FeatureProjection(self.config, dtype=self.dtype)
768
+ self.masked_spec_embed = self.param(
769
+ "masked_spec_embed", jax.nn.initializers.uniform(), (self.config.hidden_size,)
770
+ )
771
+
772
+ if self.config.do_stable_layer_norm:
773
+ self.encoder = FlaxWav2Vec2StableLayerNormEncoder(self.config, dtype=self.dtype)
774
+ else:
775
+ raise NotImplementedError("``config.do_stable_layer_norm is False`` is currently not supported.")
776
+
777
+ self.adapter = FlaxWav2Vec2Adapter(self.config, dtype=self.dtype) if self.config.add_adapter else None
778
+
779
+ def __call__(
780
+ self,
781
+ input_values,
782
+ attention_mask=None,
783
+ mask_time_indices=None,
784
+ extract_features=None,
785
+ deterministic=True,
786
+ output_attentions=None,
787
+ output_hidden_states=None,
788
+ output_features=False,
789
+ freeze_feature_encoder=False,
790
+ return_dict=None,
791
+ ):
792
+
793
+ # forward pass through the feature extractor if features not specified
794
+ if extract_features is None:
795
+ extract_features = self.feature_extractor(input_values, freeze_feature_encoder=freeze_feature_encoder)
796
+
797
+ if output_features:
798
+ return extract_features
799
+
800
+ # make sure that no loss is computed on padded inputs
801
+ if attention_mask is not None:
802
+ # compute reduced attention_mask corresponding to feature vectors
803
+ attention_mask = self._get_feature_vector_attention_mask(
804
+ extract_features.shape[1], attention_mask, add_adapter=False
805
+ )
806
+
807
+ hidden_states, extract_features = self.feature_projection(extract_features, deterministic=deterministic)
808
+ if mask_time_indices is not None: # apply SpecAugment along time axis with given indices
809
+ hidden_states = jnp.where(
810
+ jnp.broadcast_to(mask_time_indices[:, :, None], hidden_states.shape),
811
+ jnp.broadcast_to(self.masked_spec_embed[None, None, :], hidden_states.shape),
812
+ hidden_states,
813
+ )
814
+
815
+ encoder_outputs = self.encoder(
816
+ hidden_states,
817
+ attention_mask=attention_mask,
818
+ deterministic=deterministic,
819
+ output_attentions=output_attentions,
820
+ output_hidden_states=output_hidden_states,
821
+ return_dict=return_dict,
822
+ )
823
+
824
+ hidden_states = encoder_outputs[0]
825
+
826
+ if self.adapter is not None:
827
+ hidden_states = self.adapter(hidden_states)
828
+
829
+ if not return_dict:
830
+ return (hidden_states, extract_features) + encoder_outputs[1:]
831
+
832
+ return FlaxWav2Vec2BaseModelOutput(
833
+ last_hidden_state=hidden_states,
834
+ extract_features=extract_features,
835
+ hidden_states=encoder_outputs.hidden_states,
836
+ attentions=encoder_outputs.attentions,
837
+ )
838
+
839
+ def _get_feat_extract_output_lengths(
840
+ self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
841
+ ):
842
+ """
843
+ Computes the output length of the convolutional layers
844
+ """
845
+
846
+ add_adapter = self.config.add_adapter if add_adapter is None else add_adapter
847
+
848
+ def _conv_out_length(input_length, kernel_size, stride):
849
+ # 1D convolutional layer output length formula taken
850
+ # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
851
+ return (input_length - kernel_size) // stride + 1
852
+
853
+ for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
854
+ input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
855
+
856
+ if add_adapter:
857
+ for _ in range(self.config.num_adapter_layers):
858
+ input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)
859
+
860
+ return input_lengths
861
+
862
+ def _get_feature_vector_attention_mask(
863
+ self, feature_vector_length: int, attention_mask: jnp.ndarray, add_adapter=None
864
+ ):
865
+ # Effectively attention_mask.sum(-1), but not in-place, so that it can also run
866
+ # in inference mode.
867
+ non_padded_lengths = attention_mask.cumsum(axis=-1)[:, -1]
868
+
869
+ output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)
870
+
871
+ batch_size = attention_mask.shape[0]
872
+
873
+ attention_mask = jnp.zeros((batch_size, feature_vector_length), dtype=attention_mask.dtype)
874
+ # these two operations make sure that all values
876
+ # before the output length indices are attended to
876
+ attention_mask = attention_mask.at[jnp.arange(attention_mask.shape[0]), output_lengths - 1].set(1)
877
+ attention_mask = jnp.flip(jnp.flip(attention_mask, -1).cumsum(-1), -1).astype("bool")
878
+ return attention_mask
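+ # Editor's note: a worked example of the flip/cumsum/flip trick above. With
+ # feature_vector_length=5 and an output length of 2, the one-hot mask [0, 1, 0, 0, 0]
+ # becomes [0, 0, 0, 1, 0] after the first flip, [0, 0, 0, 1, 1] after cumsum, and
+ # [1, 1, 0, 0, 0] after the final flip: ones at every index up to the output length.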
879
+
880
+
881
+ class FlaxWav2Vec2Model(FlaxWav2Vec2PreTrainedModel):
882
+ module_class = FlaxWav2Vec2Module
883
+
884
+
885
+ class FlaxWav2Vec2ForCTCModule(nn.Module):
886
+ config: Wav2Vec2Config
887
+ dtype: jnp.dtype = jnp.float32
888
+
889
+ def setup(self):
890
+ self.wav2vec2 = FlaxWav2Vec2Module(self.config, dtype=self.dtype)
891
+ self.dropout = nn.Dropout(rate=self.config.final_dropout)
892
+ self.lm_head = nn.Dense(
893
+ self.config.vocab_size,
894
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
895
+ dtype=self.dtype,
896
+ )
897
+
898
+ def __call__(
899
+ self,
900
+ input_values,
901
+ attention_mask=None,
902
+ mask_time_indices=None,
903
+ extract_features=None,
904
+ deterministic=True,
905
+ output_attentions=None,
906
+ output_hidden_states=None,
907
+ output_features=False,
908
+ freeze_feature_encoder=False,
909
+ return_dict=None,
910
+ ):
911
+ outputs = self.wav2vec2(
912
+ input_values,
913
+ attention_mask=attention_mask,
914
+ mask_time_indices=mask_time_indices,
915
+ deterministic=deterministic,
916
+ output_attentions=output_attentions,
917
+ output_hidden_states=output_hidden_states,
918
+ freeze_feature_encoder=freeze_feature_encoder,
919
+ return_dict=return_dict,
920
+ )
921
+
922
+ hidden_states = outputs[0]
923
+ hidden_states = self.dropout(hidden_states, deterministic=deterministic)
924
+
925
+ logits = self.lm_head(hidden_states)
926
+
927
+ if not return_dict:
928
+ return (logits,) + outputs[2:]
929
+
930
+ return FlaxCausalLMOutput(logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
931
+
932
+ def _get_feat_extract_output_lengths(
933
+ self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
934
+ ):
935
+ """
936
+ Computes the output length of the convolutional layers
937
+ """
938
+
939
+ add_adapter = self.config.add_adapter if add_adapter is None else add_adapter
940
+
941
+ def _conv_out_length(input_length, kernel_size, stride):
942
+ # 1D convolutional layer output length formula taken
943
+ # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
944
+ return (input_length - kernel_size) // stride + 1
945
+
946
+ for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
947
+ input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
948
+
949
+ if add_adapter:
950
+ for _ in range(self.config.num_adapter_layers):
951
+ input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)
952
+
953
+ return input_lengths
954
+
955
+ def _get_feature_vector_attention_mask(
956
+ self, feature_vector_length: int, attention_mask: jnp.ndarray, add_adapter=None
957
+ ):
958
+ # Effectively attention_mask.sum(-1), but not in-place, so that it can also run
959
+ # in inference mode.
960
+ non_padded_lengths = attention_mask.cumsum(axis=-1)[:, -1]
961
+
962
+ output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)
963
+
964
+ batch_size = attention_mask.shape[0]
965
+
966
+ attention_mask = jnp.zeros((batch_size, feature_vector_length), dtype=attention_mask.dtype)
967
+ # these two operations make sure that all values
968
+ # before the output length indices are attended to
969
+ attention_mask = attention_mask.at[jnp.arange(attention_mask.shape[0]), output_lengths - 1].set(1)
970
+ attention_mask = jnp.flip(jnp.flip(attention_mask, -1).cumsum(-1), -1).astype("bool")
971
+ return attention_mask
972
+
973
+
974
+ class FlaxWav2Vec2ForCTC(FlaxWav2Vec2PreTrainedModel):
975
+ module_class = FlaxWav2Vec2ForCTCModule
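Editor's note: a standalone check of the length arithmetic in `_get_feat_extract_output_lengths` above. The kernel/stride lists below are the standard wav2vec2 feature-extractor values, stated here as an assumption for illustration rather than read from this repo.

```python
def conv_out_length(input_length, kernel_size, stride):
    # 1D conv output length with no padding, as in _conv_out_length above
    return (input_length - kernel_size) // stride + 1

conv_kernel = [10, 3, 3, 3, 3, 2, 2]  # assumed standard wav2vec2 spec
conv_stride = [5, 2, 2, 2, 2, 2, 2]

length = 16000  # one second of 16 kHz audio
for k, s in zip(conv_kernel, conv_stride):
    length = conv_out_length(length, k, s)

print(length)  # 49 frames, i.e. roughly one feature vector every 20 ms
```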
preprocessor_config.json ADDED
@@ -0,0 +1,10 @@
1
+ {
2
+ "do_normalize": true,
3
+ "feature_extractor_type": "Wav2Vec2FeatureExtractor",
4
+ "feature_size": 1,
5
+ "padding_side": "right",
6
+ "padding_value": 0.0,
7
+ "processor_class": "Wav2Vec2Processor",
8
+ "return_attention_mask": true,
9
+ "sampling_rate": 16000
10
+ }
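Editor's note: a minimal sketch of how this preprocessor config is consumed, assuming the file lives in the current working directory; the audio below is dummy data.

```python
from transformers import Wav2Vec2FeatureExtractor

# reads preprocessor_config.json from the given directory
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(".")

raw_audio = [0.0] * 16000  # one second of silence at the configured 16 kHz
inputs = feature_extractor(raw_audio, sampling_rate=16000, return_tensors="np")

print(inputs.input_values.shape)  # (1, 16000)
```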
special_tokens_map.json ADDED
@@ -0,0 +1 @@
1
+ {"bos_token": "<s>", "eos_token": "</s>", "unk_token": "<unk>", "pad_token": "<pad>"}
tokenizer_config.json ADDED
@@ -0,0 +1 @@
1
+ {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>", "pad_token": "<pad>", "do_lower_case": false, "word_delimiter_token": "|", "replace_word_delimiter_char": " ", "special_tokens_map_file": null, "name_or_path": "patrickvonplaten/wav2vec2_ctc_cv9_tokenizer", "tokenizer_class": "Wav2Vec2CTCTokenizer"}
vocab.json ADDED
@@ -0,0 +1 @@
1
+ {"<pad>": 0, "<s>": 1, "</s>": 2, "<unk>": 3, "?": 4, ".": 5, "n": 6, "e": 7, "s": 8, "h": 9, "'": 10, "l": 11, "u": 12, "!": 13, "i": 14, "g": 15, "o": 16, "c": 17, "t": 18, "b": 19, "\"": 20, "k": 21, "w": 22, "-": 23, "y": 24, "f": 26, "q": 27, "d": 28, "r": 29, "z": 30, "j": 31, "x": 32, "v": 33, ",": 34, "a": 35, "p": 36, "m": 37, "|": 25}