sanchit-gandhi (HF staff) committed on
Commit
24717ce
1 Parent(s): 6e767da

2gs2wia3: saving weights and logs of step 20k

.gitattributes CHANGED
@@ -30,4 +30,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
- *.wandb filter=lfs diff=lfs merge=lfs -text
+ *.wandb filter=lfs diff=lfs merge=lfs -text
flax_model.msgpack CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:6e1022cf390485f6c7fa8c82872726ddb949b08a6ae53bbb8cdc1469cd8073c1
+ oid sha256:e272f049f32168ca65769e2732f2f3b5ad190de2e141847ce077760ba2a76314
  size 2353616717
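The two pointer files above are Git LFS pointers: `oid` is the SHA-256 digest of the actual checkpoint and `size` its length in bytes, so only the weight values changed at this step (same size, new digest). A minimal sketch for verifying a locally pulled checkpoint against the new pointer, assuming `git lfs pull` has already replaced the pointer with the real ~2.3 GB file:

```python
import hashlib

# Hypothetical local path to the LFS-smudged checkpoint in a clone of this repo.
path = "flax_model.msgpack"

sha = hashlib.sha256()
with open(path, "rb") as f:
    # Hash in 1 MiB chunks so the ~2.3 GB file is never fully loaded into memory.
    for chunk in iter(lambda: f.read(1 << 20), b""):
        sha.update(chunk)

# Expected to match the oid of the new pointer:
# e272f049f32168ca65769e2732f2f3b5ad190de2e141847ce077760ba2a76314
print(sha.hexdigest())
```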
models/__init__.py DELETED
@@ -1,6 +0,0 @@
1
- from models.configuration_bart import BartConfig
2
- from models.configuration_wav2vec2 import Wav2Vec2Config
3
- from models.configuration_speech_encoder_decoder import SpeechEncoderDecoderConfig
4
- from models.modeling_flax_wav2vec2 import FlaxWav2Vec2Model, FlaxWav2Vec2Module, FlaxWav2Vec2ForCTC, FlaxWav2Vec2ForCTCModule
5
- from models.modeling_flax_bart import FlaxBartForCausalLM, FlaxBartForCausalLMModule
6
- from models.modeling_flax_speech_encoder_decoder import FlaxSpeechEncoderDecoderModel
 
models/__pycache__/__init__.cpython-38.pyc DELETED
Binary file (744 Bytes)
 
models/__pycache__/configuration_bart.cpython-38.pyc DELETED
Binary file (7.04 kB)
 
models/__pycache__/configuration_speech_encoder_decoder.cpython-38.pyc DELETED
Binary file (4.62 kB)
 
models/__pycache__/configuration_wav2vec2.cpython-38.pyc DELETED
Binary file (16.8 kB)
 
models/__pycache__/modeling_flax_bart.cpython-38.pyc DELETED
Binary file (21.1 kB)
 
models/__pycache__/modeling_flax_speech_encoder_decoder.cpython-38.pyc DELETED
Binary file (39.4 kB)
 
models/__pycache__/modeling_flax_wav2vec2.cpython-38.pyc DELETED
Binary file (30.6 kB)
 
models/configuration_bart.py DELETED
@@ -1,183 +0,0 @@
1
- # coding=utf-8
2
- # Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- """ BART model configuration"""
16
- import warnings
17
-
18
- from transformers.configuration_utils import PretrainedConfig
19
- from transformers.utils import logging
20
-
21
-
22
- logger = logging.get_logger(__name__)
23
-
24
- BART_PRETRAINED_CONFIG_ARCHIVE_MAP = {
25
- "facebook/bart-large": "https://huggingface.co/facebook/bart-large/resolve/main/config.json",
26
- # See all BART models at https://huggingface.co/models?filter=bart
27
- }
28
-
29
-
30
- class BartConfig(PretrainedConfig):
31
- r"""
32
- This is the configuration class to store the configuration of a [`BartModel`]. It is used to instantiate a BART
33
- model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
34
- defaults will yield a similar configuration to that of the BART
35
- [facebook/bart-large](https://huggingface.co/facebook/bart-large) architecture.
36
-
37
- Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
38
- documentation from [`PretrainedConfig`] for more information.
39
-
40
-
41
- Args:
42
- vocab_size (`int`, *optional*, defaults to 50265):
43
- Vocabulary size of the BART model. Defines the number of different tokens that can be represented by the
44
- `inputs_ids` passed when calling [`BartModel`] or [`TFBartModel`].
45
- d_model (`int`, *optional*, defaults to 1024):
46
- Dimensionality of the layers and the pooler layer.
47
- encoder_layers (`int`, *optional*, defaults to 12):
48
- Number of encoder layers.
49
- decoder_layers (`int`, *optional*, defaults to 12):
50
- Number of decoder layers.
51
- encoder_attention_heads (`int`, *optional*, defaults to 16):
52
- Number of attention heads for each attention layer in the Transformer encoder.
53
- decoder_attention_heads (`int`, *optional*, defaults to 16):
54
- Number of attention heads for each attention layer in the Transformer decoder.
55
- decoder_ffn_dim (`int`, *optional*, defaults to 4096):
56
- Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
57
- encoder_ffn_dim (`int`, *optional*, defaults to 4096):
58
- Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
59
- activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
60
- The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
61
- `"relu"`, `"silu"` and `"gelu_new"` are supported.
62
- dropout (`float`, *optional*, defaults to 0.1):
63
- The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
64
- attention_dropout (`float`, *optional*, defaults to 0.0):
65
- The dropout ratio for the attention probabilities.
66
- activation_dropout (`float`, *optional*, defaults to 0.0):
67
- The dropout ratio for activations inside the fully connected layer.
68
- classifier_dropout (`float`, *optional*, defaults to 0.0):
69
- The dropout ratio for classifier.
70
- max_position_embeddings (`int`, *optional*, defaults to 1024):
71
- The maximum sequence length that this model might ever be used with. Typically set this to something large
72
- just in case (e.g., 512 or 1024 or 2048).
73
- init_std (`float`, *optional*, defaults to 0.02):
74
- The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
75
- encoder_layerdrop: (`float`, *optional*, defaults to 0.0):
76
- The LayerDrop probability for the encoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
77
- for more details.
78
- decoder_layerdrop: (`float`, *optional*, defaults to 0.0):
79
- The LayerDrop probability for the decoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
80
- for more details.
81
- scale_embedding (`bool`, *optional*, defaults to `False`):
82
- Scale embeddings by diving by sqrt(d_model).
83
- use_cache (`bool`, *optional*, defaults to `True`):
84
- Whether or not the model should return the last key/values attentions (not used by all models).
85
- num_labels: (`int`, *optional*, defaults to 3):
86
- The number of labels to use in [`BartForSequenceClassification`].
87
- forced_eos_token_id (`int`, *optional*, defaults to 2):
88
- The id of the token to force as the last generated token when `max_length` is reached. Usually set to
89
- `eos_token_id`.
90
- use_scan (`bool`, *optional*, defaults to `False`):
91
- Whether or not to use nn.scan in the Flax Bart attention layers.
92
-
93
- Example:
94
-
95
- ```python
96
- >>> from transformers import BartModel, BartConfig
97
-
98
- >>> # Initializing a BART facebook/bart-large style configuration
99
- >>> configuration = BartConfig()
100
-
101
- >>> # Initializing a model from the facebook/bart-large style configuration
102
- >>> model = BartModel(configuration)
103
-
104
- >>> # Accessing the model configuration
105
- >>> configuration = model.config
106
- ```"""
107
- model_type = "bart"
108
- keys_to_ignore_at_inference = ["past_key_values"]
109
- attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}
110
-
111
- def __init__(
112
- self,
113
- vocab_size=50265,
114
- max_position_embeddings=1024,
115
- encoder_layers=12,
116
- encoder_ffn_dim=4096,
117
- encoder_attention_heads=16,
118
- decoder_layers=12,
119
- decoder_ffn_dim=4096,
120
- decoder_attention_heads=16,
121
- encoder_layerdrop=0.0,
122
- decoder_layerdrop=0.0,
123
- activation_function="gelu",
124
- d_model=1024,
125
- dropout=0.1,
126
- attention_dropout=0.0,
127
- activation_dropout=0.0,
128
- init_std=0.02,
129
- classifier_dropout=0.0,
130
- scale_embedding=False,
131
- use_cache=True,
132
- use_scan=False,
133
- fuse_matmuls=False,
134
- num_labels=3,
135
- pad_token_id=1,
136
- bos_token_id=0,
137
- eos_token_id=2,
138
- is_encoder_decoder=True,
139
- decoder_start_token_id=2,
140
- forced_eos_token_id=2,
141
- **kwargs
142
- ):
143
- self.vocab_size = vocab_size
144
- self.max_position_embeddings = max_position_embeddings
145
- self.d_model = d_model
146
- self.encoder_ffn_dim = encoder_ffn_dim
147
- self.encoder_layers = encoder_layers
148
- self.encoder_attention_heads = encoder_attention_heads
149
- self.decoder_ffn_dim = decoder_ffn_dim
150
- self.decoder_layers = decoder_layers
151
- self.decoder_attention_heads = decoder_attention_heads
152
- self.dropout = dropout
153
- self.attention_dropout = attention_dropout
154
- self.activation_dropout = activation_dropout
155
- self.activation_function = activation_function
156
- self.init_std = init_std
157
- self.encoder_layerdrop = encoder_layerdrop
158
- self.decoder_layerdrop = decoder_layerdrop
159
- self.classifier_dropout = classifier_dropout
160
- self.use_cache = use_cache
161
- self.use_scan = use_scan
162
- self.fuse_matmuls = fuse_matmuls
163
- self.num_hidden_layers = encoder_layers
164
- self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
165
-
166
- super().__init__(
167
- num_labels=num_labels,
168
- pad_token_id=pad_token_id,
169
- bos_token_id=bos_token_id,
170
- eos_token_id=eos_token_id,
171
- is_encoder_decoder=is_encoder_decoder,
172
- decoder_start_token_id=decoder_start_token_id,
173
- forced_eos_token_id=forced_eos_token_id,
174
- **kwargs,
175
- )
176
-
177
- # ensure backward compatibility for BART CNN models
178
- if self.forced_bos_token_id is None and kwargs.get("force_bos_token_to_be_generated", False):
179
- self.forced_bos_token_id = self.bos_token_id
180
- warnings.warn(
181
- f"Please make sure the config includes `forced_bos_token_id={self.bos_token_id}` in future versions. "
182
- "The config can simply be saved and uploaded again to be fixed."
183
- )
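The `use_scan` and `fuse_matmuls` arguments are the additions this fork makes on top of the upstream `BartConfig`; everything else mirrors `transformers`. A minimal sketch of instantiating the (now deleted) class with those flags, assuming the `models` package from an earlier revision is still importable and `transformers` is installed:

```python
from models.configuration_bart import BartConfig

# Fork-specific flags: use nn.scan over the Flax BART layers and optionally fuse the attention matmuls.
config = BartConfig(use_scan=True, fuse_matmuls=False)

print(config.use_scan, config.fuse_matmuls)   # True False
print(config.num_hidden_layers)               # mirrors encoder_layers (12 by default)
```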
 
models/configuration_speech_encoder_decoder.py DELETED
@@ -1,121 +0,0 @@
1
- # coding=utf-8
2
- # Copyright 2021 The HuggingFace Inc. team.
3
- # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
- #
5
- # Licensed under the Apache License, Version 2.0 (the "License");
6
- # you may not use this file except in compliance with the License.
7
- # You may obtain a copy of the License at
8
- #
9
- # http://www.apache.org/licenses/LICENSE-2.0
10
- #
11
- # Unless required by applicable law or agreed to in writing, software
12
- # distributed under the License is distributed on an "AS IS" BASIS,
13
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
- # See the License for the specific language governing permissions and
15
- # limitations under the License.
16
-
17
- import copy
18
-
19
- from transformers.configuration_utils import PretrainedConfig
20
- from transformers.utils import logging
21
- from models.configuration_wav2vec2 import Wav2Vec2Config
22
- from models.configuration_bart import BartConfig
23
- from transformers import AutoConfig
24
-
25
-
26
- logger = logging.get_logger(__name__)
27
-
28
-
29
- class SpeechEncoderDecoderConfig(PretrainedConfig):
30
- r"""
31
- [`SpeechEncoderDecoderConfig`] is the configuration class to store the configuration of a
32
- [`SpeechEncoderDecoderModel`]. It is used to instantiate an Encoder Decoder model according to the specified
33
- arguments, defining the encoder and decoder configs.
34
-
35
- Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
36
- documentation from [`PretrainedConfig`] for more information.
37
-
38
- Args:
39
- kwargs (*optional*):
40
- Dictionary of keyword arguments. Notably:
41
-
42
- - **encoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that defines
43
- the encoder config.
44
- - **decoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that defines
45
- the decoder config.
46
-
47
- Examples:
48
-
49
- ```python
50
- >>> from transformers import BertConfig, Wav2Vec2Config, SpeechEncoderDecoderConfig, SpeechEncoderDecoderModel
51
-
52
- >>> # Initializing a Wav2Vec2 & BERT style configuration
53
- >>> config_encoder = Wav2Vec2Config()
54
- >>> config_decoder = BertConfig()
55
-
56
- >>> config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(config_encoder, config_decoder)
57
-
58
- >>> # Initializing a Wav2Vec2Bert model from a Wav2Vec2 & bert-base-uncased style configurations
59
- >>> model = SpeechEncoderDecoderModel(config=config)
60
-
61
- >>> # Accessing the model configuration
62
- >>> config_encoder = model.config.encoder
63
- >>> config_decoder = model.config.decoder
64
- >>> # set decoder config to causal lm
65
- >>> config_decoder.is_decoder = True
66
- >>> config_decoder.add_cross_attention = True
67
-
68
- >>> # Saving the model, including its configuration
69
- >>> model.save_pretrained("my-model")
70
-
71
- >>> # loading model and config from pretrained folder
72
- >>> encoder_decoder_config = SpeechEncoderDecoderConfig.from_pretrained("my-model")
73
- >>> model = SpeechEncoderDecoderModel.from_pretrained("my-model", config=encoder_decoder_config)
74
- ```"""
75
- model_type = "speech-encoder-decoder"
76
- is_composition = True
77
-
78
- def __init__(self, **kwargs):
79
- super().__init__(**kwargs)
80
- if "encoder" not in kwargs or "decoder" not in kwargs:
81
- raise ValueError(
82
- f"A configuraton of type {self.model_type} cannot be instantiated because not both `encoder` and `decoder` sub-configurations are passed, but only {kwargs}"
83
- )
84
-
85
- encoder_config = kwargs.pop("encoder")
86
- decoder_config = kwargs.pop("decoder")
87
-
88
- # TODO: Load configs from AutoConfig (as done in Transformers 🤗)
89
- self.encoder = Wav2Vec2Config(**encoder_config)
90
- self.decoder = BartConfig(**decoder_config)
91
- self.is_encoder_decoder = True
92
-
93
- @classmethod
94
- def from_encoder_decoder_configs(
95
- cls, encoder_config: PretrainedConfig, decoder_config: PretrainedConfig, **kwargs
96
- ) -> PretrainedConfig:
97
- r"""
98
- Instantiate a [`SpeechEncoderDecoderConfig`] (or a derived class) from a pre-trained encoder model
99
- configuration and decoder model configuration.
100
-
101
- Returns:
102
- [`SpeechEncoderDecoderConfig`]: An instance of a configuration object
103
- """
104
- logger.info("Setting `config.is_decoder=True` and `config.add_cross_attention=True` for decoder_config")
105
- decoder_config.is_decoder = True
106
- decoder_config.add_cross_attention = True
107
-
108
- return cls(encoder=encoder_config.to_dict(), decoder=decoder_config.to_dict(), **kwargs)
109
-
110
- def to_dict(self):
111
- """
112
- Serializes this instance to a Python dictionary. Override the default *to_dict()* from *PretrainedConfig*.
113
-
114
- Returns:
115
- `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
116
- """
117
- output = copy.deepcopy(self.__dict__)
118
- output["encoder"] = self.encoder.to_dict()
119
- output["decoder"] = self.decoder.to_dict()
120
- output["model_type"] = self.__class__.model_type
121
- return output
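Unlike the upstream class, this version hard-codes the encoder to `Wav2Vec2Config` and the decoder to `BartConfig` instead of resolving them through `AutoConfig` (see the TODO above). A minimal sketch of composing the two fork-specific configs, again assuming the deleted `models` package is still on the Python path:

```python
from models.configuration_wav2vec2 import Wav2Vec2Config
from models.configuration_bart import BartConfig
from models.configuration_speech_encoder_decoder import SpeechEncoderDecoderConfig

encoder_config = Wav2Vec2Config(add_adapter=True)   # the adapter helps when warm-starting seq2seq models
decoder_config = BartConfig()

config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(encoder_config, decoder_config)

# The classmethod switches the decoder sub-config into causal-LM mode.
print(config.decoder.is_decoder, config.decoder.add_cross_attention)   # True True
print(config.to_dict()["model_type"])                                  # speech-encoder-decoder
```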
 
models/configuration_wav2vec2.py DELETED
@@ -1,344 +0,0 @@
1
- # coding=utf-8
2
- # Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- """ Wav2Vec2 model configuration"""
16
-
17
- import functools
18
- import operator
19
-
20
- from transformers.configuration_utils import PretrainedConfig
21
- from transformers.utils import logging
22
-
23
-
24
- logger = logging.get_logger(__name__)
25
-
26
- WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP = {
27
- "facebook/wav2vec2-base-960h": "https://huggingface.co/facebook/wav2vec2-base-960h/resolve/main/config.json",
28
- # See all Wav2Vec2 models at https://huggingface.co/models?filter=wav2vec2
29
- }
30
-
31
-
32
- class Wav2Vec2Config(PretrainedConfig):
33
- r"""
34
- This is the configuration class to store the configuration of a [`Wav2Vec2Model`]. It is used to instantiate an
35
- Wav2Vec2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
36
- with the defaults will yield a similar configuration to that of the Wav2Vec2
37
- [facebook/wav2vec2-base-960h](https://huggingface.co/facebook/wav2vec2-base-960h) architecture.
38
-
39
- Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
40
- documentation from [`PretrainedConfig`] for more information.
41
-
42
-
43
- Args:
44
- vocab_size (`int`, *optional*, defaults to 32):
45
- Vocabulary size of the Wav2Vec2 model. Defines the number of different tokens that can be represented by
46
- the `inputs_ids` passed when calling [`Wav2Vec2Model`] or [`TFWav2Vec2Model`]. Vocabulary size of the
47
- model. Defines the different tokens that can be represented by the *inputs_ids* passed to the forward
48
- method of [`Wav2Vec2Model`].
49
- hidden_size (`int`, *optional*, defaults to 768):
50
- Dimensionality of the encoder layers and the pooler layer.
51
- num_hidden_layers (`int`, *optional*, defaults to 12):
52
- Number of hidden layers in the Transformer encoder.
53
- num_attention_heads (`int`, *optional*, defaults to 12):
54
- Number of attention heads for each attention layer in the Transformer encoder.
55
- intermediate_size (`int`, *optional*, defaults to 3072):
56
- Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
57
- hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
58
- The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
59
- `"relu"`, `"selu"` and `"gelu_new"` are supported.
60
- hidden_dropout (`float`, *optional*, defaults to 0.1):
61
- The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
62
- attention_dropout (`float`, *optional*, defaults to 0.1):
63
- The dropout ratio for the attention probabilities.
64
- final_dropout (`float`, *optional*, defaults to 0.1):
65
- The dropout probability for the final projection layer of [`Wav2Vec2ForCTC`].
66
- initializer_range (`float`, *optional*, defaults to 0.02):
67
- The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
68
- layer_norm_eps (`float`, *optional*, defaults to 1e-12):
69
- The epsilon used by the layer normalization layers.
70
- feat_extract_norm (`str`, *optional*, defaults to `"group"`):
71
- The norm to be applied to 1D convolutional layers in feature encoder. One of `"group"` for group
72
- normalization of only the first 1D convolutional layer or `"layer"` for layer normalization of all 1D
73
- convolutional layers.
74
- feat_proj_dropout (`float`, *optional*, defaults to 0.0):
75
- The dropout probability for output of the feature encoder.
76
- feat_extract_activation (`str, `optional`, defaults to `"gelu"`):
77
- The non-linear activation function (function or string) in the 1D convolutional layers of the feature
78
- extractor. If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"` are supported.
79
- feat_quantizer_dropout (`float`, *optional*, defaults to 0.0):
80
- The dropout probabilitiy for quantized feature encoder states.
81
- conv_dim (`Tuple[int]`, *optional*, defaults to `(512, 512, 512, 512, 512, 512, 512)`):
82
- A tuple of integers defining the number of input and output channels of each 1D convolutional layer in the
83
- feature encoder. The length of *conv_dim* defines the number of 1D convolutional layers.
84
- conv_stride (`Tuple[int]`, *optional*, defaults to `(5, 2, 2, 2, 2, 2, 2)`):
85
- A tuple of integers defining the stride of each 1D convolutional layer in the feature encoder. The length
86
- of *conv_stride* defines the number of convolutional layers and has to match the length of *conv_dim*.
87
- conv_kernel (`Tuple[int]`, *optional*, defaults to `(10, 3, 3, 3, 3, 3, 3)`):
88
- A tuple of integers defining the kernel size of each 1D convolutional layer in the feature encoder. The
89
- length of *conv_kernel* defines the number of convolutional layers and has to match the length of
90
- *conv_dim*.
91
- conv_bias (`bool`, *optional*, defaults to `False`):
92
- Whether the 1D convolutional layers have a bias.
93
- num_conv_pos_embeddings (`int`, *optional*, defaults to 128):
94
- Number of convolutional positional embeddings. Defines the kernel size of 1D convolutional positional
95
- embeddings layer.
96
- num_conv_pos_embedding_groups (`int`, *optional*, defaults to 16):
97
- Number of groups of 1D convolutional positional embeddings layer.
98
- do_stable_layer_norm (`bool`, *optional*, defaults to `False`):
99
- Whether to apply *stable* layer norm architecture of the Transformer encoder. `do_stable_layer_norm is
100
- True` corresponds to applying layer norm before the attention layer, whereas `do_stable_layer_norm is
101
- False` corresponds to applying layer norm after the attention layer.
102
- apply_spec_augment (`bool`, *optional*, defaults to `True`):
103
- Whether to apply *SpecAugment* data augmentation to the outputs of the feature encoder. For reference see
104
- [SpecAugment: A Simple Data Augmentation Method for Automatic Speech
105
- Recognition](https://arxiv.org/abs/1904.08779).
106
- mask_time_prob (`float`, *optional*, defaults to 0.05):
107
- Percentage (between 0 and 1) of all feature vectors along the time axis which will be masked. The masking
108
- procecure generates ''mask_time_prob*len(time_axis)/mask_time_length'' independent masks over the axis. If
109
- reasoning from the propability of each feature vector to be chosen as the start of the vector span to be
110
- masked, *mask_time_prob* should be `prob_vector_start*mask_time_length`. Note that overlap may decrease the
111
- actual percentage of masked vectors. This is only relevant if `apply_spec_augment is True`.
112
- mask_time_length (`int`, *optional*, defaults to 10):
113
- Length of vector span along the time axis.
114
- mask_time_min_masks (`int`, *optional*, defaults to 2),:
115
- The minimum number of masks of length `mask_feature_length` generated along the time axis, each time step,
116
- irrespectively of `mask_feature_prob`. Only relevant if ''mask_time_prob*len(time_axis)/mask_time_length <
117
- mask_time_min_masks''
118
- mask_feature_prob (`float`, *optional*, defaults to 0.0):
119
- Percentage (between 0 and 1) of all feature vectors along the feature axis which will be masked. The
120
- masking procecure generates ''mask_feature_prob*len(feature_axis)/mask_time_length'' independent masks over
121
- the axis. If reasoning from the propability of each feature vector to be chosen as the start of the vector
122
- span to be masked, *mask_feature_prob* should be `prob_vector_start*mask_feature_length`. Note that overlap
123
- may decrease the actual percentage of masked vectors. This is only relevant if `apply_spec_augment is
124
- True`.
125
- mask_feature_length (`int`, *optional*, defaults to 10):
126
- Length of vector span along the feature axis.
127
- mask_feature_min_masks (`int`, *optional*, defaults to 0),:
128
- The minimum number of masks of length `mask_feature_length` generated along the feature axis, each time
129
- step, irrespectively of `mask_feature_prob`. Only relevant if
130
- ''mask_feature_prob*len(feature_axis)/mask_feature_length < mask_feature_min_masks''
131
- num_codevectors_per_group (`int`, *optional*, defaults to 320):
132
- Number of entries in each quantization codebook (group).
133
- num_codevector_groups (`int`, *optional*, defaults to 2):
134
- Number of codevector groups for product codevector quantization.
135
- contrastive_logits_temperature (`float`, *optional*, defaults to 0.1):
136
- The temperature *kappa* in the contrastive loss.
137
- feat_quantizer_dropout (`float`, *optional*, defaults to 0.0):
138
- The dropout probabilitiy for the output of the feature encoder that's used by the quantizer.
139
- num_negatives (`int`, *optional*, defaults to 100):
140
- Number of negative samples for the contrastive loss.
141
- codevector_dim (`int`, *optional*, defaults to 256):
142
- Dimensionality of the quantized feature vectors.
143
- proj_codevector_dim (`int`, *optional*, defaults to 256):
144
- Dimensionality of the final projection of both the quantized and the transformer features.
145
- diversity_loss_weight (`int`, *optional*, defaults to 0.1):
146
- The weight of the codebook diversity loss component.
147
- ctc_loss_reduction (`str`, *optional*, defaults to `"sum"`):
148
- Specifies the reduction to apply to the output of `torch.nn.CTCLoss`. Only relevant when training an
149
- instance of [`Wav2Vec2ForCTC`].
150
- ctc_zero_infinity (`bool`, *optional*, defaults to `False`):
151
- Whether to zero infinite losses and the associated gradients of `torch.nn.CTCLoss`. Infinite losses mainly
152
- occur when the inputs are too short to be aligned to the targets. Only relevant when training an instance
153
- of [`Wav2Vec2ForCTC`].
154
- use_weighted_layer_sum (`bool`, *optional*, defaults to `False`):
155
- Whether to use a weighted average of layer outputs with learned weights. Only relevant when using an
156
- instance of [`Wav2Vec2ForSequenceClassification`].
157
- classifier_proj_size (`int`, *optional*, defaults to 256):
158
- Dimensionality of the projection before token mean-pooling for classification.
159
- tdnn_dim (`Tuple[int]`, *optional*, defaults to `(512, 512, 512, 512, 1500)`):
160
- A tuple of integers defining the number of output channels of each 1D convolutional layer in the *TDNN*
161
- module of the *XVector* model. The length of *tdnn_dim* defines the number of *TDNN* layers.
162
- tdnn_kernel (`Tuple[int]`, *optional*, defaults to `(5, 3, 3, 1, 1)`):
163
- A tuple of integers defining the kernel size of each 1D convolutional layer in the *TDNN* module of the
164
- *XVector* model. The length of *tdnn_kernel* has to match the length of *tdnn_dim*.
165
- tdnn_dilation (`Tuple[int]`, *optional*, defaults to `(1, 2, 3, 1, 1)`):
166
- A tuple of integers defining the dilation factor of each 1D convolutional layer in *TDNN* module of the
167
- *XVector* model. The length of *tdnn_dilation* has to match the length of *tdnn_dim*.
168
- xvector_output_dim (`int`, *optional*, defaults to 512):
169
- Dimensionality of the *XVector* embedding vectors.
170
- add_adapter (`bool`, *optional*, defaults to `False`):
171
- Whether a convolutional network should be stacked on top of the Wav2Vec2 Encoder. Can be very useful for
172
- warm-starting Wav2Vec2 for SpeechEncoderDecoder models.
173
- adapter_kernel_size (`int`, *optional*, defaults to 3):
174
- Kernel size of the convolutional layers in the adapter network. Only relevant if `add_adapter is True`.
175
- adapter_stride (`int`, *optional*, defaults to 2):
176
- Stride of the convolutional layers in the adapter network. Only relevant if `add_adapter is True`.
177
- num_adapter_layers (`int`, *optional*, defaults to 3):
178
- Number of convolutional layers that should be used in the adapter network. Only relevant if `add_adapter is
179
- True`.
180
- output_hidden_size (`int`, *optional*):
181
- Dimensionality of the encoder output layer. If not defined, this defaults to *hidden-size*. Only relevant
182
- if `add_adapter is True`.
183
- use_scan (`bool`, *optional*, defaults to `False`):
184
- Whether or not to use nn.scan in the Flax Wav2Vec2 transformer layers.
185
-
186
- Example:
187
-
188
- ```python
189
- >>> from transformers import Wav2Vec2Model, Wav2Vec2Config
190
-
191
- >>> # Initializing a Wav2Vec2 facebook/wav2vec2-base-960h style configuration
192
- >>> configuration = Wav2Vec2Config()
193
-
194
- >>> # Initializing a model from the facebook/wav2vec2-base-960h style configuration
195
- >>> model = Wav2Vec2Model(configuration)
196
-
197
- >>> # Accessing the model configuration
198
- >>> configuration = model.config
199
- ```"""
200
- model_type = "wav2vec2"
201
-
202
- def __init__(
203
- self,
204
- vocab_size=32,
205
- hidden_size=768,
206
- num_hidden_layers=12,
207
- num_attention_heads=12,
208
- intermediate_size=3072,
209
- hidden_act="gelu",
210
- hidden_dropout=0.1,
211
- activation_dropout=0.1,
212
- attention_dropout=0.1,
213
- feat_proj_dropout=0.0,
214
- feat_quantizer_dropout=0.0,
215
- final_dropout=0.1,
216
- layerdrop=0.1,
217
- initializer_range=0.02,
218
- layer_norm_eps=1e-5,
219
- feat_extract_norm="group",
220
- feat_extract_activation="gelu",
221
- conv_dim=(512, 512, 512, 512, 512, 512, 512),
222
- conv_stride=(5, 2, 2, 2, 2, 2, 2),
223
- conv_kernel=(10, 3, 3, 3, 3, 2, 2),
224
- conv_bias=False,
225
- num_conv_pos_embeddings=128,
226
- num_conv_pos_embedding_groups=16,
227
- do_stable_layer_norm=False,
228
- apply_spec_augment=True,
229
- mask_time_prob=0.05,
230
- mask_time_length=10,
231
- mask_time_min_masks=2,
232
- mask_feature_prob=0.0,
233
- mask_feature_length=10,
234
- mask_feature_min_masks=0,
235
- num_codevectors_per_group=320,
236
- num_codevector_groups=2,
237
- contrastive_logits_temperature=0.1,
238
- num_negatives=100,
239
- codevector_dim=256,
240
- proj_codevector_dim=256,
241
- diversity_loss_weight=0.1,
242
- ctc_loss_reduction="sum",
243
- ctc_zero_infinity=False,
244
- use_weighted_layer_sum=False,
245
- classifier_proj_size=256,
246
- tdnn_dim=(512, 512, 512, 512, 1500),
247
- tdnn_kernel=(5, 3, 3, 1, 1),
248
- tdnn_dilation=(1, 2, 3, 1, 1),
249
- xvector_output_dim=512,
250
- pad_token_id=0,
251
- bos_token_id=1,
252
- eos_token_id=2,
253
- add_adapter=False,
254
- adapter_kernel_size=3,
255
- adapter_stride=2,
256
- num_adapter_layers=3,
257
- output_hidden_size=None,
258
- use_scan=False,
259
- fuse_matmuls=False,
260
- **kwargs
261
- ):
262
- super().__init__(**kwargs, pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id)
263
- self.hidden_size = hidden_size
264
- self.feat_extract_norm = feat_extract_norm
265
- self.feat_extract_activation = feat_extract_activation
266
- self.conv_dim = list(conv_dim)
267
- self.conv_stride = list(conv_stride)
268
- self.conv_kernel = list(conv_kernel)
269
- self.conv_bias = conv_bias
270
- self.num_conv_pos_embeddings = num_conv_pos_embeddings
271
- self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups
272
- self.num_feat_extract_layers = len(self.conv_dim)
273
- self.num_hidden_layers = num_hidden_layers
274
- self.intermediate_size = intermediate_size
275
- self.hidden_act = hidden_act
276
- self.num_attention_heads = num_attention_heads
277
- self.hidden_dropout = hidden_dropout
278
- self.attention_dropout = attention_dropout
279
- self.activation_dropout = activation_dropout
280
- self.feat_proj_dropout = feat_proj_dropout
281
- self.final_dropout = final_dropout
282
- self.layerdrop = layerdrop
283
- self.layer_norm_eps = layer_norm_eps
284
- self.initializer_range = initializer_range
285
- self.vocab_size = vocab_size
286
- self.do_stable_layer_norm = do_stable_layer_norm
287
- self.use_weighted_layer_sum = use_weighted_layer_sum
288
- self.use_scan = use_scan
289
- self.fuse_matmuls = fuse_matmuls
290
-
291
- if (
292
- (len(self.conv_stride) != self.num_feat_extract_layers)
293
- or (len(self.conv_kernel) != self.num_feat_extract_layers)
294
- or (len(self.conv_dim) != self.num_feat_extract_layers)
295
- ):
296
- raise ValueError(
297
- "Configuration for convolutional layers is incorrect. "
298
- "It is required that `len(config.conv_dim)` == `len(config.conv_stride)` == `len(config.conv_kernel)`, "
299
- f"but is `len(config.conv_dim) = {len(self.conv_dim)}`, `len(config.conv_stride) "
300
- f"= {len(self.conv_stride)}`, `len(config.conv_kernel) = {len(self.conv_kernel)}`."
301
- )
302
-
303
- # fine-tuning config parameters for SpecAugment: https://arxiv.org/abs/1904.08779
304
- self.apply_spec_augment = apply_spec_augment
305
- self.mask_time_prob = mask_time_prob
306
- self.mask_time_length = mask_time_length
307
- self.mask_time_min_masks = mask_time_min_masks
308
- self.mask_feature_prob = mask_feature_prob
309
- self.mask_feature_length = mask_feature_length
310
- self.mask_feature_min_masks = mask_feature_min_masks
311
-
312
- # parameters for pretraining with codevector quantized representations
313
- self.num_codevectors_per_group = num_codevectors_per_group
314
- self.num_codevector_groups = num_codevector_groups
315
- self.contrastive_logits_temperature = contrastive_logits_temperature
316
- self.feat_quantizer_dropout = feat_quantizer_dropout
317
- self.num_negatives = num_negatives
318
- self.codevector_dim = codevector_dim
319
- self.proj_codevector_dim = proj_codevector_dim
320
- self.diversity_loss_weight = diversity_loss_weight
321
-
322
- # ctc loss
323
- self.ctc_loss_reduction = ctc_loss_reduction
324
- self.ctc_zero_infinity = ctc_zero_infinity
325
-
326
- # adapter
327
- self.add_adapter = add_adapter
328
- self.adapter_kernel_size = adapter_kernel_size
329
- self.adapter_stride = adapter_stride
330
- self.num_adapter_layers = num_adapter_layers
331
- self.output_hidden_size = output_hidden_size or hidden_size
332
-
333
- # SequenceClassification-specific parameter. Feel free to ignore for other classes.
334
- self.classifier_proj_size = classifier_proj_size
335
-
336
- # XVector-specific parameters. Feel free to ignore for other classes.
337
- self.tdnn_dim = list(tdnn_dim)
338
- self.tdnn_kernel = list(tdnn_kernel)
339
- self.tdnn_dilation = list(tdnn_dilation)
340
- self.xvector_output_dim = xvector_output_dim
341
-
342
- @property
343
- def inputs_to_logits_ratio(self):
344
- return functools.reduce(operator.mul, self.conv_stride, 1)
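The `inputs_to_logits_ratio` property is simply the product of the feature-encoder strides, i.e. how many raw audio samples are collapsed into one encoder frame. With the default `conv_stride` above this works out to 5 * 2^6 = 320 samples, or 20 ms of 16 kHz audio per frame; a worked check:

```python
import functools
import operator

conv_stride = (5, 2, 2, 2, 2, 2, 2)   # default conv_stride from the config above

ratio = functools.reduce(operator.mul, conv_stride, 1)
print(ratio)             # 320 raw samples per encoder frame
print(ratio / 16_000)    # 0.02 s -> one frame every 20 ms at a 16 kHz sampling rate
```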
 
models/modeling_flax_bart.py DELETED
@@ -1,816 +0,0 @@
1
- # coding=utf-8
2
- # Copyright 2021 The Fairseq Authors and The Google Flax Team Authors And The HuggingFace Inc. team. All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- """ Flax Bart model."""
16
-
17
- import math
18
- import random
19
- from functools import partial
20
- from typing import Optional, Tuple
21
-
22
- import numpy as np
23
-
24
- import flax.linen as nn
25
- import jax
26
- import jax.numpy as jnp
27
- from flax.core.frozen_dict import FrozenDict, unfreeze
28
- from flax.linen import combine_masks, make_causal_mask
29
- from flax.linen import partitioning as nn_partitioning
30
- from flax.linen.attention import dot_product_attention_weights
31
- from jax import lax
32
- from jax.random import PRNGKey
33
-
34
- from transformers.modeling_flax_outputs import (
35
- FlaxBaseModelOutputWithPastAndCrossAttentions,
36
- FlaxCausalLMOutputWithCrossAttentions,
37
- )
38
- from transformers.modeling_flax_utils import ACT2FN, FlaxPreTrainedModel
39
-
40
- from models import BartConfig
41
-
42
-
43
- scan_with_axes = nn_partitioning.scan_with_axes
44
- remat = nn_partitioning.remat
45
-
46
-
47
- def shift_tokens_right(input_ids: np.array, pad_token_id: int, decoder_start_token_id: int) -> np.ndarray:
48
- """
49
- Shift input ids one token to the right.
50
- """
51
- shifted_input_ids = np.zeros_like(input_ids)
52
- shifted_input_ids[:, 1:] = input_ids[:, :-1]
53
- shifted_input_ids[:, 0] = decoder_start_token_id
54
-
55
- shifted_input_ids = np.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
56
- return shifted_input_ids
57
-
58
-
59
- class FlaxBartAttention(nn.Module):
60
- config: BartConfig
61
- embed_dim: int
62
- num_heads: int
63
- dropout: float = 0.0
64
- causal: bool = False
65
- bias: bool = True
66
- dtype: jnp.dtype = jnp.float32 # the dtype of the computation
67
-
68
- def setup(self) -> None:
69
- self.head_dim = self.embed_dim // self.num_heads
70
- if self.head_dim * self.num_heads != self.embed_dim:
71
- raise ValueError(
72
- f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
73
- f" and `num_heads`: {self.num_heads})."
74
- )
75
-
76
- dense = partial(
77
- nn.Dense,
78
- self.embed_dim,
79
- use_bias=self.bias,
80
- dtype=self.dtype,
81
- kernel_init=jax.nn.initializers.normal(self.config.init_std),
82
- )
83
-
84
- self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
85
-
86
- self.fused_proj = nn.Dense(
87
- self.embed_dim * 3,
88
- use_bias=self.bias,
89
- dtype=self.dtype,
90
- kernel_init=jax.nn.initializers.normal(self.config.init_std),
91
- )
92
-
93
- self.fused_key_value = nn.Dense(
94
- self.embed_dim * 2,
95
- use_bias=self.bias,
96
- dtype=self.dtype,
97
- kernel_init=jax.nn.initializers.normal(self.config.init_std),
98
- )
99
-
100
- self.out_proj = dense()
101
-
102
- self.dropout_layer = nn.Dropout(rate=self.dropout)
103
-
104
- if self.causal:
105
- self.causal_mask = make_causal_mask(
106
- jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool"
107
- )
108
-
109
- def _split_heads(self, hidden_states):
110
- return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))
111
-
112
- def _merge_heads(self, hidden_states):
113
- return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))
114
-
115
- @nn.compact
116
- def _concatenate_to_cache(self, key, value, query, attention_mask):
117
- """
118
- This function takes projected key, value states from a single input token and concatenates the states to cached
119
- states from previous steps. This function is slightly adapted from the official Flax repository:
120
- https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
121
- """
122
- # detect if we're initializing by absence of existing cache data.
123
- is_initialized = self.has_variable("cache", "cached_key")
124
- cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
125
- cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
126
- cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
127
-
128
- if is_initialized:
129
- *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
130
- # update key, value caches with our new 1d spatial slices
131
- cur_index = cache_index.value
132
- indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
133
- key = lax.dynamic_update_slice(cached_key.value, key, indices)
134
- value = lax.dynamic_update_slice(cached_value.value, value, indices)
135
- cached_key.value = key
136
- cached_value.value = value
137
- num_updated_cache_vectors = query.shape[1]
138
- cache_index.value = cache_index.value + num_updated_cache_vectors
139
- # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements.
140
- pad_mask = jnp.broadcast_to(
141
- jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
142
- tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
143
- )
144
- attention_mask = combine_masks(pad_mask, attention_mask)
145
- return key, value, attention_mask
146
-
147
- def __call__(
148
- self,
149
- hidden_states: jnp.ndarray,
150
- key_value_states: Optional[jnp.ndarray] = None,
151
- attention_mask: Optional[jnp.ndarray] = None,
152
- init_cache: bool = False,
153
- deterministic: bool = True,
154
- ) -> Tuple[jnp.ndarray]:
155
- """Input shape: Batch x Time x Channel"""
156
-
157
- # if key_value_states are provided this layer is used as a cross-attention layer
158
- # for the decoder
159
- is_cross_attention = key_value_states is not None
160
- batch_size = hidden_states.shape[0]
161
-
162
- if self.config.fuse_matmuls:
163
- # get key, value proj
164
- if is_cross_attention:
165
- # get query proj
166
- query_states = self.q_proj(hidden_states)
167
- # cross_attentions
168
- attention_states = self.fused_key_value(key_value_states)
169
- key_states, value_states = jnp.split(attention_states, 2, axis=-1)
170
- else:
171
- attention_states = self.fused_proj(hidden_states)
172
- query_states, key_states, value_states = jnp.split(attention_states, 3, axis=-1)
173
-
174
- else:
175
- # get query proj
176
- query_states = self.q_proj(hidden_states)
177
- # get key, value proj
178
- if is_cross_attention:
179
- # cross_attentions
180
- key_states = self.k_proj(key_value_states)
181
- value_states = self.v_proj(key_value_states)
182
- else:
183
- # self_attention
184
- key_states = self.k_proj(hidden_states)
185
- value_states = self.v_proj(hidden_states)
186
-
187
- query_states = self._split_heads(query_states)
188
- key_states = self._split_heads(key_states)
189
- value_states = self._split_heads(value_states)
190
-
191
- # handle cache prepare causal attention mask
192
- if self.causal:
193
- query_length, key_length = query_states.shape[1], key_states.shape[1]
194
- if self.has_variable("cache", "cached_key"):
195
- mask_shift = self.variables["cache"]["cache_index"]
196
- max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
197
- causal_mask = lax.dynamic_slice(
198
- self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
199
- )
200
- else:
201
- causal_mask = self.causal_mask[:, :, :query_length, :key_length]
202
- causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
203
-
204
- # combine masks if needed
205
- if attention_mask is not None and self.causal:
206
- attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
207
- attention_mask = combine_masks(attention_mask, causal_mask)
208
- elif self.causal:
209
- attention_mask = causal_mask
210
- elif attention_mask is not None:
211
- attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
212
-
213
- # During fast autoregressive decoding, we feed one position at a time,
214
- # and cache the keys and values step by step.
215
- if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
216
- key_states, value_states, attention_mask = self._concatenate_to_cache(
217
- key_states, value_states, query_states, attention_mask
218
- )
219
-
220
- # Convert the boolean attention mask to an attention bias.
221
- if attention_mask is not None:
222
- # attention mask in the form of attention bias
223
- attention_bias = lax.select(
224
- attention_mask > 0,
225
- jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
226
- jnp.full(attention_mask.shape, float("-inf")).astype(self.dtype),
227
- )
228
- else:
229
- attention_bias = None
230
-
231
- dropout_rng = None
232
- if not deterministic and self.dropout > 0.0:
233
- dropout_rng = self.make_rng("dropout")
234
-
235
- attn_weights = dot_product_attention_weights(
236
- query_states,
237
- key_states,
238
- bias=attention_bias,
239
- dropout_rng=dropout_rng,
240
- dropout_rate=self.dropout,
241
- broadcast_dropout=True,
242
- deterministic=deterministic,
243
- dtype=self.dtype,
244
- precision=None,
245
- )
246
-
247
- attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
248
- attn_output = self._merge_heads(attn_output)
249
- attn_output = self.out_proj(attn_output)
250
-
251
- return attn_output, attn_weights
252
-
253
-
254
- class FlaxBartDecoderLayer(nn.Module):
255
- config: BartConfig
256
- dtype: jnp.dtype = jnp.float32
257
-
258
- def setup(self) -> None:
259
- self.embed_dim = self.config.d_model
260
- self.self_attn = FlaxBartAttention(
261
- config=self.config,
262
- embed_dim=self.embed_dim,
263
- num_heads=self.config.decoder_attention_heads,
264
- dropout=self.config.attention_dropout,
265
- causal=True,
266
- dtype=self.dtype,
267
- )
268
- self.dropout_layer = nn.Dropout(rate=self.config.dropout)
269
- self.activation_fn = ACT2FN[self.config.activation_function]
270
- self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
271
-
272
- self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
273
- self.encoder_attn = FlaxBartAttention(
274
- config=self.config,
275
- embed_dim=self.embed_dim,
276
- num_heads=self.config.decoder_attention_heads,
277
- dropout=self.config.attention_dropout,
278
- dtype=self.dtype,
279
- )
280
- self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
281
- self.fc1 = nn.Dense(
282
- self.config.encoder_ffn_dim,
283
- dtype=self.dtype,
284
- kernel_init=jax.nn.initializers.normal(self.config.init_std),
285
- )
286
- self.fc2 = nn.Dense(
287
- self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
288
- )
289
- self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
290
-
291
- def __call__(
292
- self,
293
- hidden_states: jnp.ndarray,
294
- attention_mask: jnp.ndarray,
295
- encoder_hidden_states: Optional[jnp.ndarray] = None,
296
- encoder_attention_mask: Optional[jnp.ndarray] = None,
297
- init_cache: bool = False,
298
- output_attentions: bool = True,
299
- deterministic: bool = True,
300
- ) -> Tuple[jnp.ndarray]:
301
-
302
- if self.config.use_scan:
303
- hidden_states = hidden_states[0]
304
-
305
- residual = hidden_states
306
-
307
- # Self Attention
308
- hidden_states, self_attn_weights = self.self_attn(
309
- hidden_states=hidden_states, attention_mask=attention_mask, init_cache=init_cache
310
- )
311
- hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
312
- hidden_states = residual + hidden_states
313
- hidden_states = self.self_attn_layer_norm(hidden_states)
314
-
315
- # Cross-Attention Block
316
- cross_attn_weights = None
317
- if encoder_hidden_states is not None:
318
- residual = hidden_states
319
-
320
- hidden_states, cross_attn_weights = self.encoder_attn(
321
- hidden_states=hidden_states,
322
- key_value_states=encoder_hidden_states,
323
- attention_mask=encoder_attention_mask,
324
- )
325
- hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
326
- hidden_states = residual + hidden_states
327
- hidden_states = self.encoder_attn_layer_norm(hidden_states)
328
-
329
- # Fully Connected
330
- residual = hidden_states
331
- hidden_states = self.activation_fn(self.fc1(hidden_states))
332
- hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic)
333
- hidden_states = self.fc2(hidden_states)
334
- hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
335
- hidden_states = residual + hidden_states
336
- hidden_states = self.final_layer_norm(hidden_states)
337
-
338
- outputs = (hidden_states,)
339
-
340
- if output_attentions:
341
- outputs += (self_attn_weights, cross_attn_weights)
342
-
343
- if self.config.use_scan:
344
- outputs = (outputs, None)
345
-
346
- return outputs
347
-
348
-
349
- class FlaxBartDecoderLayerCollection(nn.Module):
350
- config: BartConfig
351
- dtype: jnp.dtype = jnp.float32 # the dtype of the computation
352
-
353
- @nn.compact
354
- def __call__(
355
- self,
356
- hidden_states,
357
- attention_mask,
358
- encoder_hidden_states: Optional[jnp.ndarray] = None,
359
- encoder_attention_mask: Optional[jnp.ndarray] = None,
360
- deterministic: bool = True,
361
- init_cache: bool = False,
362
- output_attentions: bool = False,
363
- output_hidden_states: bool = False,
364
- return_dict: bool = True,
365
- ):
366
- # decoder layers
367
- all_hidden_states = () if output_hidden_states else None
368
- all_self_attns = () if output_attentions else None
369
- all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
370
-
371
- num_decoder_layers = self.config.decoder_layers
372
- BlockDecoderLayer = (
373
- remat(
374
- FlaxBartDecoderLayer,
375
- static_argnums=(4, 5, 6),
376
- prevent_cse=not self.config.use_scan,
377
- )
378
- if self.config.gradient_checkpointing
379
- else FlaxBartDecoderLayer
380
- )
381
-
382
- if self.config.use_scan:
383
- # since all decoder layers are the same, we use nn.scan directly
384
- assert not output_attentions, "cannot use `scan` with `output_attentions` set to `True`"
385
- assert not output_hidden_states, "cannot use `scan` with `output_hidden_states` set to `True`"
386
- hidden_states = (hidden_states,)
387
-
388
- # TODO: add layerdrop in checkpointed scan (note: default value for layerdrop in config is zero)
389
- hidden_states, _ = scan_with_axes(
390
- BlockDecoderLayer,
391
- variable_axes={"params": 0, "cache": 0},
392
- split_rngs={"params": True, "dropout": True},
393
- in_axes=(nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast),
394
- length=num_decoder_layers,
395
- )(self.config, dtype=self.dtype, name="FlaxBartDecoderLayers")(
396
- hidden_states,
397
- attention_mask,
398
- encoder_hidden_states,
399
- encoder_attention_mask,
400
- init_cache,
401
- output_attentions,
402
- deterministic,
403
- )
404
- hidden_states = hidden_states[0]
405
-
406
- else:
407
- for layer in range(num_decoder_layers):
408
- if output_hidden_states:
409
- all_hidden_states += (hidden_states,)
410
- # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
411
- dropout_probability = random.uniform(0, 1)
412
- if not deterministic and (dropout_probability < self.config.decoder_layerdrop):
413
- layer_outputs = (None, None, None)
414
- else:
415
- layer_outputs = BlockDecoderLayer(self.config, dtype=self.dtype, name=str(layer),)(
416
- hidden_states,
417
- attention_mask,
418
- encoder_hidden_states,
419
- encoder_attention_mask,
420
- init_cache,
421
- output_attentions,
422
- deterministic,
423
- )
424
-
425
- hidden_states = layer_outputs[0]
426
- if output_attentions:
427
- all_self_attns += (layer_outputs[1],)
428
-
429
- if encoder_hidden_states is not None:
430
- all_cross_attentions += (layer_outputs[2],)
431
-
432
- # add hidden states from the last decoder layer
433
- if output_hidden_states:
434
- all_hidden_states += (hidden_states,)
435
-
436
- outputs = [hidden_states, all_hidden_states, all_self_attns, all_cross_attentions]
437
-
438
- if not return_dict:
439
- return tuple(v for v in outputs if v is not None)
440
-
441
- return FlaxBaseModelOutputWithPastAndCrossAttentions(
442
- last_hidden_state=hidden_states,
443
- hidden_states=all_hidden_states,
444
- attentions=all_self_attns,
445
- cross_attentions=all_cross_attentions,
446
- )
447
-
448
-
449
- class FlaxBartDecoder(nn.Module):
450
- config: BartConfig
451
- embed_tokens: nn.Embed
452
- dtype: jnp.dtype = jnp.float32 # the dtype of the computation
453
-
454
- def setup(self):
455
- self.dropout_layer = nn.Dropout(rate=self.config.dropout)
456
-
457
- embed_dim = self.config.d_model
458
- self.padding_idx = self.config.pad_token_id
459
- self.max_target_positions = self.config.max_position_embeddings
460
- self.embed_scale = math.sqrt(self.config.d_model) if self.config.scale_embedding else 1.0
461
-
462
- # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
463
- # and adjust num_embeddings appropriately. Other models don't have this hack
464
- self.offset = 2
465
- self.embed_positions = nn.Embed(
466
- self.config.max_position_embeddings + self.offset,
467
- embed_dim,
468
- embedding_init=jax.nn.initializers.normal(self.config.init_std),
469
- )
470
-
471
- self.layers = FlaxBartDecoderLayerCollection(self.config, self.dtype)
472
- self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
473
-
474
- def __call__(
475
- self,
476
- input_ids,
477
- attention_mask,
478
- position_ids,
479
- encoder_hidden_states: Optional[jnp.ndarray] = None,
480
- encoder_attention_mask: Optional[jnp.ndarray] = None,
481
- init_cache: bool = False,
482
- output_attentions: bool = False,
483
- output_hidden_states: bool = False,
484
- return_dict: bool = True,
485
- deterministic: bool = True,
486
- ):
487
- input_shape = input_ids.shape
488
- input_ids = input_ids.reshape(-1, input_shape[-1])
489
-
490
- inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
491
-
492
- # embed positions
493
- positions = self.embed_positions(position_ids + self.offset)
494
-
495
- hidden_states = inputs_embeds + positions
496
- hidden_states = self.layernorm_embedding(hidden_states)
497
-
498
- hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
499
-
500
- outputs = self.layers(
501
- hidden_states,
502
- attention_mask,
503
- encoder_hidden_states,
504
- encoder_attention_mask,
505
- deterministic=deterministic,
506
- init_cache=init_cache,
507
- output_attentions=output_attentions,
508
- output_hidden_states=output_hidden_states,
509
- return_dict=return_dict,
510
- )
511
-
512
- if not return_dict:
513
- return outputs
514
-
515
- return FlaxBaseModelOutputWithPastAndCrossAttentions(
516
- last_hidden_state=outputs.last_hidden_state,
517
- hidden_states=outputs.hidden_states,
518
- attentions=outputs.attentions,
519
- cross_attentions=outputs.cross_attentions,
520
- )
521
-
522
-
523
- class FlaxBartDecoderPreTrainedModel(FlaxPreTrainedModel):
524
- config_class = BartConfig
525
- base_model_prefix: str = "model"
526
- module_class: nn.Module = None
527
-
528
- def __init__(
529
- self,
530
- config: BartConfig,
531
- input_shape: Tuple[int] = (1, 1),
532
- seed: int = 0,
533
- dtype: jnp.dtype = jnp.float32,
534
- _do_init: bool = True,
535
- **kwargs
536
- ):
537
- config.is_decoder = True
538
- config.is_encoder_decoder = False
539
- module = self.module_class(config=config, dtype=dtype, **kwargs)
540
- super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
541
-
542
- def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
543
- # init input tensors
544
- input_ids = jnp.zeros(input_shape, dtype="i4")
545
- attention_mask = jnp.ones_like(input_ids)
546
-
547
- batch_size, sequence_length = input_ids.shape
548
- position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
549
-
550
- params_rng, dropout_rng = jax.random.split(rng)
551
- rngs = {"params": params_rng, "dropout": dropout_rng}
552
- encoder_hidden_states = jnp.zeros(input_shape + (self.config.d_model,))
553
- encoder_attention_mask = attention_mask
554
- module_init_outputs = self.module.init(
555
- rngs,
556
- input_ids,
557
- attention_mask,
558
- position_ids,
559
- encoder_hidden_states,
560
- encoder_attention_mask,
561
- return_dict=False,
562
- )
563
- return module_init_outputs["params"]
564
-
565
- def init_cache(self, batch_size, max_length):
566
- r"""
567
- Args:
568
- batch_size (`int`):
569
- batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
570
- max_length (`int`):
571
- maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
572
- cache.
573
- """
574
- # init input variables to retrieve cache
575
- input_ids = jnp.ones((batch_size, max_length), dtype="i4")
576
- attention_mask = jnp.ones_like(input_ids, dtype="i4")
577
- position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
578
-
579
- init_variables = self.module.init(
580
- jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
581
- )
582
- return unfreeze(init_variables["cache"])
583
-
584
- def __call__(
585
- self,
586
- input_ids: jnp.ndarray,
587
- attention_mask: Optional[jnp.ndarray] = None,
588
- position_ids: Optional[jnp.ndarray] = None,
589
- encoder_hidden_states: Optional[jnp.ndarray] = None,
590
- encoder_attention_mask: Optional[jnp.ndarray] = None,
591
- output_attentions: Optional[bool] = None,
592
- output_hidden_states: Optional[bool] = None,
593
- return_dict: Optional[bool] = None,
594
- train: bool = False,
595
- params: dict = None,
596
- past_key_values: dict = None,
597
- dropout_rng: PRNGKey = None,
598
- ):
599
- """
600
- Args:
601
- input_ids (`jnp.ndarray` of shape `(target_batch_size, target_sequence_length)`):
602
- Indices of decoder input sequence tokens in the vocabulary.
603
-
604
- Indices can be obtained using [`BartTokenizer`]. See [`PreTrainedTokenizer.encode`] and
605
- [`PreTrainedTokenizer.__call__`] for details.
606
-
607
- [What are decoder input IDs?](../glossary#decoder-input-ids)
608
-
609
- For translation and summarization training, `decoder_input_ids` should be provided. If no
610
- `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right
611
- for denoising pre-training following the paper.
612
- attention_mask (`jnp.ndarray` of shape `(target_batch_size, target_sequence_length)`, *optional*):
613
- Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
614
- be used by default.
615
-
616
- If you want to change padding behavior, you should modify to your needs. See diagram 1 in [the
617
- paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy.
618
- position_ids (`numpy.ndarray` of shape `(target_batch_size, sequence_length)`, *optional*):
619
- Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
620
- range `[0, config.max_position_embeddings - 1]`.
621
- encoder_hidden_states (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`):
622
- A sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
623
- encoder_attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
624
- Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
625
-
626
- - 1 for tokens that are **not masked**,
627
- - 0 for tokens that are **masked**.
628
-
629
- [What are attention masks?](../glossary#attention-mask)
630
- output_attentions (`bool`, *optional*):
631
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
632
- tensors for more detail.
633
- output_hidden_states (`bool`, *optional*):
634
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
635
- more detail.
636
- return_dict (`bool`, *optional*):
637
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
638
- past_key_values (`Dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`):
639
- Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast
640
- auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*.
641
- """
642
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
643
- output_hidden_states = (
644
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
645
- )
646
- return_dict = return_dict if return_dict is not None else self.config.return_dict
647
-
648
- if encoder_hidden_states is not None and encoder_attention_mask is None:
649
- batch_size, sequence_length = encoder_hidden_states.shape[:2]
650
- encoder_attention_mask = jnp.ones((batch_size, sequence_length))
651
-
652
- # prepare decoder inputs
653
- if attention_mask is None:
654
- attention_mask = jnp.ones_like(input_ids)
655
- if position_ids is None:
656
- batch_size, sequence_length = input_ids.shape
657
- position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
658
-
659
- # Handle any PRNG if needed
660
- rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
661
-
662
- inputs = {"params": params or self.params}
663
-
664
- # if past_key_values are passed then the cache is already initialized; a private flag init_cache has to be passed
665
- # down to ensure cache is used. It has to be made sure that cache is marked as mutable so that it can be
666
- # changed by FlaxBartAttention module
667
- if past_key_values:
668
- inputs["cache"] = past_key_values
669
- mutable = ["cache"]
670
- else:
671
- mutable = False
672
-
673
- outputs = self.module.apply(
674
- inputs,
675
- input_ids=jnp.array(input_ids, dtype="i4"),
676
- attention_mask=jnp.array(attention_mask, dtype="i4"),
677
- position_ids=jnp.array(position_ids, dtype="i4"),
678
- encoder_hidden_states=encoder_hidden_states,
679
- encoder_attention_mask=encoder_attention_mask,
680
- output_attentions=output_attentions,
681
- output_hidden_states=output_hidden_states,
682
- return_dict=return_dict,
683
- deterministic=not train,
684
- rngs=rngs,
685
- mutable=mutable,
686
- )
687
-
688
- # add updated cache to model output
689
- if past_key_values is not None and return_dict:
690
- outputs, past_key_values = outputs
691
- outputs["past_key_values"] = unfreeze(past_key_values["cache"])
692
- return outputs
693
- elif past_key_values is not None and not return_dict:
694
- outputs, past_key_values = outputs
695
- outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]
696
-
697
- return outputs
698
-
699
-
700
- class FlaxBartDecoderWrapper(nn.Module):
701
- """
702
- This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
703
- used in combination with the [`EncoderDecoderModel`] framework.
704
- """
705
-
706
- config: BartConfig
707
- dtype: jnp.dtype = jnp.float32
708
-
709
- def setup(self):
710
- embed_dim = self.config.d_model
711
- embed_tokens = nn.Embed(
712
- self.config.vocab_size,
713
- embed_dim,
714
- embedding_init=jax.nn.initializers.normal(self.config.init_std),
715
- )
716
- self.decoder = FlaxBartDecoder(config=self.config, embed_tokens=embed_tokens, dtype=self.dtype)
717
-
718
- def __call__(self, *args, **kwargs):
719
- return self.decoder(*args, **kwargs)
720
-
721
-
722
- class FlaxBartForCausalLMModule(nn.Module):
723
- """Bart Decoder Module with a language modeling head on top (linear layer with weights tied to the input embeddings)
724
- e.g. for autoregressive tasks.
725
- """
726
-
727
- config: BartConfig
728
- dtype: jnp.dtype = jnp.float32
729
-
730
- def setup(self):
731
- self.model = FlaxBartDecoderWrapper(config=self.config, dtype=self.dtype)
732
- self.lm_head = nn.Dense(
733
- self.config.vocab_size,
734
- use_bias=False,
735
- dtype=self.dtype,
736
- kernel_init=jax.nn.initializers.normal(self.config.init_std),
737
- )
738
-
739
- def __call__(
740
- self,
741
- input_ids,
742
- attention_mask,
743
- position_ids,
744
- encoder_hidden_states: Optional[jnp.ndarray] = None,
745
- encoder_attention_mask: Optional[jnp.ndarray] = None,
746
- init_cache: bool = False,
747
- output_attentions: bool = False,
748
- output_hidden_states: bool = False,
749
- return_dict: bool = True,
750
- deterministic: bool = True,
751
- ):
752
-
753
- outputs = self.model(
754
- input_ids,
755
- attention_mask,
756
- position_ids,
757
- encoder_hidden_states,
758
- encoder_attention_mask,
759
- deterministic=deterministic,
760
- init_cache=init_cache,
761
- output_attentions=output_attentions,
762
- output_hidden_states=output_hidden_states,
763
- return_dict=return_dict,
764
- )
765
-
766
- hidden_states = outputs[0]
767
-
768
- if self.config.tie_word_embeddings:
769
- shared_embedding = self.model.variables["params"]["decoder"]["embed_tokens"]["embedding"]
770
- lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
771
- else:
772
- lm_logits = self.lm_head(hidden_states)
773
-
774
- if not return_dict:
775
- return (lm_logits,) + outputs[1:]
776
-
777
- return FlaxCausalLMOutputWithCrossAttentions(
778
- logits=lm_logits,
779
- hidden_states=outputs.hidden_states,
780
- attentions=outputs.attentions,
781
- cross_attentions=outputs.cross_attentions,
782
- )
783
-
784
-
785
- class FlaxBartForCausalLM(FlaxBartDecoderPreTrainedModel):
786
- """Bart Decoder Model with a language modeling head on top (linear layer with weights tied to the input embeddings)
787
- e.g. for autoregressive tasks.
788
- """
789
-
790
- module_class = FlaxBartForCausalLMModule
791
-
792
- def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
793
- # initializing the cache
794
- batch_size, seq_length = input_ids.shape
795
-
796
- past_key_values = self.init_cache(batch_size, max_length)
797
- # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
798
- # But since the decoder uses a causal mask, those positions are masked anyway.
799
- # Thus, we can create a single static attention_mask here, which is more efficient for compilation
800
- extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
801
- if attention_mask is not None:
802
- position_ids = attention_mask.cumsum(axis=-1) - 1
803
- extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
804
- else:
805
- position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
806
-
807
- return {
808
- "past_key_values": past_key_values,
809
- "attention_mask": extended_attention_mask,
810
- "position_ids": position_ids,
811
- }
812
-
813
- def update_inputs_for_generation(self, model_outputs, model_kwargs):
814
- model_kwargs["past_key_values"] = model_outputs.past_key_values
815
- model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
816
- return model_kwargs
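As a hedged aside (not part of the deleted file): the `FlaxBartForCausalLM` defined above is a decoder-only BART with an LM head, so a standalone forward pass only needs input ids. The sketch below uses an arbitrary tiny configuration chosen purely for illustration.

```python
>>> from transformers import BartConfig
>>> import jax.numpy as jnp

>>> # tiny, randomly initialised decoder-only model; all sizes are illustrative
>>> config = BartConfig(
...     vocab_size=100, d_model=16, decoder_layers=2, decoder_attention_heads=2, decoder_ffn_dim=32
... )
>>> model = FlaxBartForCausalLM(config)
>>> input_ids = jnp.ones((1, 4), dtype="i4")
>>> logits = model(input_ids).logits  # shape (1, 4, 100)
```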
 
models/modeling_flax_speech_encoder_decoder.py DELETED
@@ -1,1245 +0,0 @@
1
- # coding=utf-8
2
- # Copyright 2022 The HuggingFace Inc. team.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- """ Classes to support Flax Speech-Encoder-Decoder architectures"""
16
-
17
- import os
18
- from functools import partial
19
- from typing import Optional, Tuple, Union, Dict
20
-
21
- import flax
22
- import flax.linen as nn
23
- import jax
24
- import jax.numpy as jnp
25
- from flax.core.frozen_dict import FrozenDict, unfreeze
26
- from jax import lax
27
- from jax.random import PRNGKey
28
- import numpy as np
29
-
30
- from transformers.modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutputWithCrossAttentions, FlaxSeq2SeqLMOutput
31
- from transformers.modeling_flax_utils import FlaxPreTrainedModel
32
- from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings, ModelOutput
33
- from transformers.generation_flax_utils import FlaxLogitsProcessorList
34
- from models import (
35
- FlaxWav2Vec2Model,
36
- FlaxWav2Vec2Module,
37
- FlaxBartForCausalLM,
38
- FlaxBartForCausalLMModule,
39
- BartConfig,
40
- Wav2Vec2Config,
41
- SpeechEncoderDecoderConfig,
42
- )
43
-
44
- logger = logging.get_logger(__name__)
45
-
46
- _CONFIG_FOR_DOC = "SpeechEncoderDecoderConfig"
47
-
48
- SPEECH_ENCODER_DECODER_START_DOCSTRING = r"""
49
- This class can be used to initialize a speech-sequence-to-text-sequence model with any pretrained speech
50
- autoencoding model as the encoder and any pretrained text autoregressive model as the decoder. The encoder is
51
- loaded via [`~AutoModel.from_pretrained`] function and the decoder is loaded via
52
- [`~AutoModelForCausalLM.from_pretrained`] function. Cross-attention layers are automatically added to the decoder
53
- and should be fine-tuned on a downstream generative task, like automatic speech recognition.
54
-
55
- The effectiveness of initializing sequence-to-sequence models with pretrained checkpoints for sequence generation
56
- tasks was shown in [Leveraging Pre-trained Checkpoints for Sequence Generation
57
- Tasks](https://arxiv.org/abs/1907.12461) by Sascha Rothe, Shashi Narayan, and Aliaksei Severyn.
58
-
59
-
60
- Additionally, in [Large-Scale Self- and Semi-Supervised Learning for Speech
61
- Translation](https://arxiv.org/abs/2104.06678) it is shown how leveraging large pretrained speech models for speech
62
- translation yields a significant performance improvement.
63
-
64
- After such a Speech-Encoder-Decoder model has been trained/fine-tuned, it can be saved/loaded just like any other
65
- model (see the examples for more information).
66
-
67
- This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
68
- library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
69
- etc.)
70
-
71
- This model is also a Flax Linen
72
- [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
73
- regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior.
74
-
75
- Parameters:
76
- config ([`SpeechEncoderDecoderConfig`]): Model configuration class with all the parameters of the model.
77
- Initializing with a config file does not load the weights associated with the model, only the
78
- configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
79
- dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
80
- The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
81
- `jax.numpy.bfloat16` (on TPUs).
82
-
83
- This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
84
- specified all the computation will be performed with the given `dtype`.
85
-
86
- **Note that this only specifies the dtype of the computation and does not influence the dtype of model
87
- parameters.**
88
-
89
- If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
90
- [`~FlaxPreTrainedModel.to_bf16`].
91
- """
92
-
93
- SPEECH_ENCODER_DECODER_INPUTS_DOCSTRING = r"""
94
- Args:
95
- inputs (`jnp.ndarray` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, feature_dim)`, *optional*):
96
- Float values of input raw speech waveform or speech features. Values can be obtained by loading a *.flac*
97
- or *.wav* audio file into an array of type *List[float]* or a *numpy.ndarray*, *e.g.* via the soundfile
98
- library (*pip install soundfile*). To prepare the array into *inputs*, either the [`Wav2Vec2Processor`] or
99
- [`Speech2TextProcessor`] should be used for padding and conversion into a tensor of type
100
- *torch.FloatTensor*.
101
- attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
102
- Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
103
-
104
- - 1 for tokens that are **not masked**,
105
- - 0 for tokens that are **masked**.
106
-
107
- [What are attention masks?](../glossary#attention-mask)
108
- decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
109
- Indices of decoder input sequence tokens in the vocabulary.
110
-
111
- Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and
112
- [`PreTrainedTokenizer.__call__`] for details.
113
-
114
- [What are input IDs?](../glossary#input-ids)
115
-
116
- If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
117
- `past_key_values`).
118
-
119
- For sequence to sequence training, `decoder_input_ids` should be provided. `decoder_input_ids` should be
120
- created outside of the model by shifting the `labels` to the right, replacing -100 by the `pad_token_id`
121
- and prepending them with the `decoder_start_token_id`.
122
- decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
123
- Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
124
- be used by default.
125
- decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
126
- Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
127
- range `[0, config.decoder.max_position_embeddings - 1]`.
128
- output_hidden_states (`bool`, *optional*):
129
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
130
- more detail.
131
- return_dict (`bool`, *optional*):
132
- If set to `True`, the model will return a [`~utils.FlaxSeq2SeqLMOutput`] instead of a plain tuple.
133
- """
134
-
135
- SPEECH_ENCODER_DECODER_ENCODE_INPUTS_DOCSTRING = r"""
136
- Args:
137
- inputs (`jnp.ndarray` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, feature_dim)`, *optional*):
138
- Float values of input raw speech waveform or speech features. Values can be obtained by loading a *.flac*
139
- or *.wav* audio file into an array of type *List[float]* or a *numpy.ndarray*, *e.g.* via the soundfile
140
- library (*pip install soundfile*). To prepare the array into *inputs*, either the [`Wav2Vec2Processor`] or
141
- [`Speech2TextProcessor`] should be used for padding and conversion into a tensor of type
142
- *torch.FloatTensor*.
143
- attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
144
- Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
145
-
146
- - 1 for tokens that are **not masked**,
147
- - 0 for tokens that are **masked**.
148
-
149
- [What are attention masks?](../glossary#attention-mask)
150
- output_attentions (`bool`, *optional*):
151
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
152
- tensors for more detail.
153
- output_hidden_states (`bool`, *optional*):
154
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
155
- more detail.
156
- return_dict (`bool`, *optional*):
157
- If set to `True`, the model will return a [`~utils.FlaxBaseModelOutput`] instead of a plain tuple.
158
- """
159
-
160
- SPEECH_ENCODER_DECODER_DECODE_INPUTS_DOCSTRING = r"""
161
- Args:
162
- decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
163
- Indices of decoder input sequence tokens in the vocabulary.
164
-
165
- Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and
166
- [`PreTrainedTokenizer.__call__`] for details.
167
-
168
- [What are decoder input IDs?](../glossary#decoder-input-ids)
169
-
170
- If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
171
- `past_key_values`).
172
-
173
- For sequence to sequence training, `decoder_input_ids` should be provided. `decoder_input_ids` should be
174
- created outside of the model by shifting the `labels` to the right, replacing -100 by the `pad_token_id`
175
- and prepending them with the `decoder_start_token_id`.
176
- encoder_outputs (`tuple(tuple(jnp.ndarray)`):
177
- Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
178
- `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
179
- hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
180
- encoder_attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
181
- Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
182
-
183
- - 1 for tokens that are **not masked**,
184
- - 0 for tokens that are **masked**.
185
-
186
- [What are attention masks?](../glossary#attention-mask)
187
- decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
188
- Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
189
- be used by default.
190
- decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
191
- Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
192
- range `[0, config.decoder.max_position_embeddings - 1]`.
193
- past_key_values (`Dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`):
194
- Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast
195
- auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*.
196
- output_attentions (`bool`, *optional*):
197
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
198
- tensors for more detail.
199
- output_hidden_states (`bool`, *optional*):
200
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
201
- more detail.
202
- return_dict (`bool`, *optional*):
203
- If set to `True`, the model will return a [`~utils.FlaxCausalLMOutputWithCrossAttentions`] instead of a
204
- plain tuple.
205
- """
206
-
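The docstrings above note that `decoder_input_ids` should be built outside the model by shifting the labels one position to the right, replacing the `-100` ignore index with `pad_token_id` and prepending `decoder_start_token_id`. A hedged sketch of that preprocessing step (the helper below is illustrative and not part of this file):

```python
>>> import numpy as np

>>> def shift_tokens_right(labels, pad_token_id, decoder_start_token_id):
...     shifted = np.zeros_like(labels)
...     shifted[:, 1:] = labels[:, :-1]
...     shifted[:, 0] = decoder_start_token_id
...     # -100 is only meaningful for the loss, not for embedding lookups
...     return np.where(shifted == -100, pad_token_id, shifted)

>>> labels = np.array([[42, 17, 2, -100]])
>>> shift_tokens_right(labels, pad_token_id=1, decoder_start_token_id=0)
array([[ 0, 42, 17,  2]])
```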
207
- @flax.struct.dataclass
208
- class FlaxBeamSearchOutput(ModelOutput):
209
- """
210
- Flax Base class for outputs of decoder-only generation models using greedy search.
211
-
212
-
213
- Args:
214
- sequences (`jnp.ndarray` of shape `(batch_size, max_length)`):
215
- The generated sequences.
216
- scores (`jnp.ndarray` of shape `(batch_size,)`):
217
- The scores (log probabilites) of the generated sequences.
218
- """
219
-
220
- sequences: jnp.ndarray = None
221
- scores: jnp.ndarray = None
222
-
223
-
224
- @flax.struct.dataclass
225
- class BeamSearchState:
226
- cur_len: jnp.ndarray
227
- running_sequences: jnp.ndarray
228
- running_scores: jnp.ndarray
229
- sequences: jnp.ndarray
230
- scores: jnp.ndarray
231
- is_sent_finished: jnp.ndarray
232
- model_kwargs: Dict[str, jnp.ndarray]
233
-
234
-
235
-
236
-
237
- class FlaxSpeechEncoderDecoderModule(nn.Module):
238
- config: SpeechEncoderDecoderConfig
239
- dtype: jnp.dtype = jnp.float32
240
-
241
- def setup(self):
242
- encoder_config = self.config.encoder
243
- decoder_config = self.config.decoder
244
-
245
- # TODO: configure FlaxAutoModel mappings (required when trialling different encoder-decoder combinations)
246
- encoder_module = FlaxWav2Vec2Module
247
- decoder_module = FlaxBartForCausalLMModule
248
-
249
- self.encoder = encoder_module(encoder_config, dtype=self.dtype)
250
- self.decoder = decoder_module(decoder_config, dtype=self.dtype)
251
-
252
- # encoder outputs might need to be projected to different dimension for decoder
253
- if (
254
- self.encoder.config.hidden_size != self.decoder.config.hidden_size
255
- and self.decoder.config.cross_attention_hidden_size is None
256
- ):
257
- self.enc_to_dec_proj = nn.Dense(
258
- self.decoder.config.hidden_size,
259
- kernel_init=jax.nn.initializers.normal(self.decoder.config.initializer_range),
260
- dtype=self.dtype,
261
- )
262
- else:
263
- self.enc_to_dec_proj = None
264
-
265
- def _get_feat_extract_output_lengths(
266
- self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
267
- ):
268
- """
269
- Computes the output length of the convolutional layers
270
- """
271
-
272
- add_adapter = self.config.encoder.add_adapter if add_adapter is None else add_adapter
273
-
274
- def _conv_out_length(input_length, kernel_size, stride):
275
- # 1D convolutional layer output length formula taken
276
- # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
277
- return (input_length - kernel_size) // stride + 1
278
-
279
- for kernel_size, stride in zip(self.config.encoder.conv_kernel, self.config.encoder.conv_stride):
280
- input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
281
-
282
- if add_adapter:
283
- for _ in range(self.config.encoder.num_adapter_layers):
284
- input_lengths = _conv_out_length(input_lengths, 1, self.config.encoder.adapter_stride)
285
-
286
- return input_lengths
287
-
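To make the downsampling concrete, here is a worked example of the formula above, assuming the default wav2vec2 feature-encoder kernel sizes and strides (these values are an assumption for illustration, not read from this repo's config):

```python
>>> def conv_out_length(length, kernel_size, stride):
...     return (length - kernel_size) // stride + 1

>>> length = 16000  # one second of 16 kHz audio
>>> for kernel, stride in zip((10, 3, 3, 3, 3, 2, 2), (5, 2, 2, 2, 2, 2, 2)):
...     length = conv_out_length(length, kernel, stride)
>>> length
49
```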
288
- def _get_encoder_module(self):
289
- return self.encoder
290
-
291
- def _get_projection_module(self):
292
- return self.enc_to_dec_proj
293
-
294
- def _get_decoder_module(self):
295
- return self.decoder
296
-
297
- def __call__(
298
- self,
299
- inputs,
300
- attention_mask,
301
- decoder_input_ids,
302
- decoder_attention_mask,
303
- decoder_position_ids,
304
- encoder_outputs=None,
305
- extract_features=None,
306
- output_attentions: bool = False,
307
- output_hidden_states: bool = False,
308
- output_features: bool = False,
309
- return_dict: bool = True,
310
- deterministic: bool = True,
311
- freeze_feature_encoder: bool = False,
312
- ):
313
- if encoder_outputs is None:
314
- encoder_outputs = self.encoder(
315
- inputs,
316
- attention_mask=attention_mask,
317
- extract_features=extract_features,
318
- output_attentions=output_attentions,
319
- output_hidden_states=output_hidden_states,
320
- output_features=output_features,
321
- return_dict=return_dict,
322
- deterministic=deterministic,
323
- freeze_feature_encoder=freeze_feature_encoder,
324
- )
325
-
326
- if output_features:
327
- return encoder_outputs
328
-
329
- encoder_hidden_states = encoder_outputs[0]
330
-
331
- # optionally project encoder_hidden_states
332
- if self.enc_to_dec_proj is not None:
333
- encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
334
-
335
- # compute correct encoder attention mask
336
- if attention_mask is not None:
337
- encoder_attention_mask = self.encoder._get_feature_vector_attention_mask(
338
- encoder_hidden_states.shape[1], attention_mask
339
- )
340
- else:
341
- encoder_attention_mask = None
342
-
343
- # flax script modeling_flax_wav2vec2.py
344
- decoder_outputs = self.decoder(
345
- input_ids=decoder_input_ids,
346
- attention_mask=decoder_attention_mask,
347
- position_ids=decoder_position_ids,
348
- encoder_hidden_states=encoder_hidden_states,
349
- encoder_attention_mask=encoder_attention_mask,
350
- output_attentions=output_attentions,
351
- output_hidden_states=output_hidden_states,
352
- return_dict=return_dict,
353
- deterministic=deterministic,
354
- )
355
-
356
- if not return_dict:
357
- return decoder_outputs + encoder_outputs
358
-
359
- return FlaxSeq2SeqLMOutput(
360
- logits=decoder_outputs.logits,
361
- decoder_hidden_states=decoder_outputs.hidden_states,
362
- decoder_attentions=decoder_outputs.attentions,
363
- cross_attentions=decoder_outputs.cross_attentions,
364
- encoder_last_hidden_state=encoder_hidden_states,
365
- encoder_hidden_states=encoder_outputs.hidden_states,
366
- encoder_attentions=encoder_outputs.attentions,
367
- )
368
-
369
-
370
- @add_start_docstrings(SPEECH_ENCODER_DECODER_START_DOCSTRING)
371
- class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel):
372
- r"""
373
- [`FlaxSpeechEncoderDecoderModel`] is a generic model class that will be instantiated as a transformer architecture
374
- with the module (flax.nn.Module) of one of the base model classes of the library as encoder module and another one
375
- as decoder module when created with the [`~transformers.FlaxAutoModel.from_pretrained`] class method for the
376
- encoder and the [`~transformers.FlaxAutoModelForCausalLM.from_pretrained`] class method for the decoder.
377
- """
378
-
379
- config_class = SpeechEncoderDecoderConfig
380
- base_model_prefix: str = "speech_encoder_decoder"
381
- module_class = FlaxSpeechEncoderDecoderModule
382
-
383
- def __init__(
384
- self,
385
- config: SpeechEncoderDecoderConfig,
386
- input_shape: Optional[Tuple] = None,
387
- seed: int = 0,
388
- dtype: jnp.dtype = jnp.float32,
389
- _do_init: bool = True,
390
- **kwargs
391
- ):
392
-
393
- if not _do_init:
394
- raise ValueError(
395
- "`FlaxSpeechEncoderDecoderModel` cannot be created without initializing, `_do_init` must be `True`."
396
- )
397
-
398
- if config.decoder.cross_attention_hidden_size is not None:
399
- # Raise ValueError or option to project enc to dec hidden_size (eg EncAdapterLayer)
400
- if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
401
- raise ValueError(
402
- "If `cross_attention_hidden_size` is specified in the decoder's configuration, "
403
- "it has to be equal to the encoder's `hidden_size`. "
404
- f"Got {config.decoder.cross_attention_hidden_size} for `config.decoder.cross_attention_hidden_size` "
405
- f"and {config.encoder.hidden_size} for `config.encoder.hidden_size`."
406
- )
407
-
408
- # make sure input & output embeddings are not tied
409
- config.tie_word_embeddings = False
410
- module = self.module_class(config=config, dtype=dtype, **kwargs)
411
-
412
- if input_shape is None:
413
- # speech encoders almost always downsample the sequence length dimension
414
- encoder_input_length = 1024
415
- decoder_input_length = module._get_feat_extract_output_lengths(encoder_input_length)
416
- input_shape = ((1, encoder_input_length), (1, decoder_input_length))
417
-
418
- super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
419
-
420
- def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
421
- encoder_input_shape, decoder_input_shape = input_shape
422
-
423
- # init input DeviceArrays
424
- inputs = jnp.zeros(encoder_input_shape, dtype="f4")
425
- attention_mask = jnp.ones_like(inputs, dtype="i4")
426
- decoder_input_ids = jnp.zeros(decoder_input_shape, dtype="i4")
427
- decoder_attention_mask = jnp.ones_like(decoder_input_ids)
428
-
429
- batch_size, sequence_length = inputs.shape
430
-
431
- decoder_batch_size, decoder_sequence_length = decoder_input_ids.shape
432
- if not decoder_batch_size == batch_size:
433
- raise ValueError(
434
- f"The inputs of encoder and decoder should have the same batch size, but got {batch_size} for encoder and {decoder_batch_size} for decoder."
435
- )
436
- decoder_position_ids = jnp.broadcast_to(
437
- jnp.arange(decoder_sequence_length)[None, :], (decoder_batch_size, decoder_sequence_length)
438
- )
439
-
440
- params_rng, dropout_rng = jax.random.split(rng)
441
- rngs = {"params": params_rng, "dropout": dropout_rng}
442
-
443
- return self.module.init(
444
- rngs,
445
- inputs,
446
- attention_mask,
447
- decoder_input_ids,
448
- decoder_attention_mask,
449
- decoder_position_ids,
450
- )["params"]
451
-
452
- def init_cache(self, batch_size, max_length, encoder_outputs):
453
- r"""
454
- Args:
455
- batch_size (`int`):
456
- batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
457
- max_length (`int`):
458
- maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
459
- cache.
460
- encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray)]`):
461
- `encoder_outputs` consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*:
462
- `attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*)
463
- is a sequence of hidden-states at the output of the last layer of the encoder. Used in the
464
- cross-attention of the decoder.
465
- """
466
- # init input variables to retrieve cache
467
- decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4")
468
- decoder_attention_mask = jnp.ones_like(decoder_input_ids)
469
- decoder_position_ids = jnp.broadcast_to(
470
- jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]), decoder_input_ids.shape
471
- )
472
-
473
- def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs):
474
- decoder_module = module._get_decoder_module()
475
- return decoder_module(
476
- input_ids=decoder_input_ids,
477
- attention_mask=decoder_attention_mask,
478
- position_ids=decoder_position_ids,
479
- **kwargs,
480
- )
481
-
482
- init_variables = self.module.init(
483
- jax.random.PRNGKey(0),
484
- decoder_input_ids=decoder_input_ids,
485
- decoder_attention_mask=decoder_attention_mask,
486
- decoder_position_ids=decoder_position_ids,
487
- encoder_hidden_states=encoder_outputs[0],
488
- init_cache=True,
489
- method=_decoder_forward, # we only need to call the decoder to init the cache
490
- )
491
- return unfreeze(init_variables["cache"])
492
-
493
- def _get_feat_extract_output_lengths(
494
- self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
495
- ):
496
- return self.module._get_feat_extract_output_lengths(input_lengths, add_adapter=add_adapter)
497
-
498
- @add_start_docstrings(SPEECH_ENCODER_DECODER_ENCODE_INPUTS_DOCSTRING)
499
- @replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=_CONFIG_FOR_DOC)
500
- def encode(
501
- self,
502
- inputs: jnp.ndarray,
503
- attention_mask: Optional[jnp.ndarray] = None,
504
- extract_features: Optional[jnp.ndarray] = None,
505
- output_attentions: Optional[bool] = None,
506
- output_hidden_states: Optional[bool] = None,
507
- output_features: Optional[bool] = None,
508
- return_dict: Optional[bool] = None,
509
- train: bool = False,
510
- freeze_feature_encoder: bool = False,
511
- params: dict = None,
512
- dropout_rng: PRNGKey = None,
513
- ):
514
- r"""
515
- Returns:
516
-
517
- Example:
518
-
519
- ```python
520
- >>> from transformers import FlaxSpeechEncoderDecoderModel
521
-
522
- >>> # initialize a wav2vec2-2-bart from pretrained wav2vec2 and bart models. Note that the cross-attention layers will be randomly initialized
523
- >>> model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
524
- ... "facebook/wav2vec2-large-lv60", "facebook/bart-large"
525
- ... )
526
-
527
- >>> inputs = jnp.ones((2, 5000), dtype=jnp.float32)
528
- >>> encoder_outputs = model.encode(inputs)
529
- ```"""
530
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
531
- output_hidden_states = (
532
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
533
- )
534
- return_dict = return_dict if return_dict is not None else self.config.return_dict
535
-
536
- if attention_mask is None:
537
- attention_mask = jnp.ones_like(inputs, dtype="i4")
538
-
539
- if extract_features is not None:
540
- extract_features = jnp.array(extract_features, dtype="f4")
541
-
542
- # Handle any PRNG if needed
543
- rngs = {}
544
- if dropout_rng is not None:
545
- rngs["dropout"] = dropout_rng
546
-
547
- def _encoder_forward(module, inputs, attention_mask, **kwargs):
548
- encode_module = module._get_encoder_module()
549
- return encode_module(inputs, attention_mask, **kwargs)
550
-
551
- outputs = self.module.apply(
552
- {"params": params or self.params},
553
- inputs=jnp.array(inputs, dtype="f4"),
554
- attention_mask=jnp.array(attention_mask, dtype="i4"),
555
- extract_features=extract_features,
556
- output_attentions=output_attentions,
557
- output_hidden_states=output_hidden_states,
558
- output_features=output_features,
559
- return_dict=return_dict,
560
- deterministic=not train,
561
- freeze_feature_encoder=freeze_feature_encoder,
562
- rngs=rngs,
563
- method=_encoder_forward,
564
- )
565
-
566
- if return_dict and not output_features:
567
- outputs = FlaxBaseModelOutput(
568
- last_hidden_state=outputs.last_hidden_state,
569
- hidden_states=outputs.hidden_states,
570
- attentions=outputs.attentions,
571
- )
572
-
573
- return outputs
574
-
575
- @add_start_docstrings(SPEECH_ENCODER_DECODER_DECODE_INPUTS_DOCSTRING)
576
- @replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
577
- def decode(
578
- self,
579
- decoder_input_ids,
580
- encoder_outputs,
581
- encoder_attention_mask: Optional[jnp.ndarray] = None,
582
- decoder_attention_mask: Optional[jnp.ndarray] = None,
583
- decoder_position_ids: Optional[jnp.ndarray] = None,
584
- past_key_values: dict = None,
585
- output_attentions: Optional[bool] = None,
586
- output_hidden_states: Optional[bool] = None,
587
- return_dict: Optional[bool] = None,
588
- train: bool = False,
589
- params: dict = None,
590
- dropout_rng: PRNGKey = None,
591
- ):
592
- r"""
593
- Returns:
594
-
595
- Example:
596
-
597
- ```python
598
- >>> from transformers import FlaxSpeechEncoderDecoderModel
599
- >>> import jax.numpy as jnp
600
-
601
- >>> # initialize a wav2vec2-2-bart from pretrained wav2vec2 and bart models. Note that the cross-attention layers will be randomly initialized
602
- >>> model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
603
- ... "facebook/wav2vec2-large-lv60", "facebook/bart-large"
604
- ... )
605
-
606
- >>> inputs = jnp.ones((2, 5000), dtype=jnp.float32)
607
- >>> encoder_outputs = model.encode(inputs)
608
-
609
- >>> decoder_start_token_id = model.config.decoder.bos_token_id
610
- >>> decoder_input_ids = jnp.ones((inputs.shape[0], 1), dtype="i4") * decoder_start_token_id
611
-
612
- >>> outputs = model.decode(decoder_input_ids, encoder_outputs)
613
- >>> logits = outputs.logits
614
- ```"""
615
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
616
- output_hidden_states = (
617
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
618
- )
619
- return_dict = return_dict if return_dict is not None else self.config.return_dict
620
-
621
- encoder_hidden_states = encoder_outputs[0]
622
- if encoder_attention_mask is None:
623
- batch_size, sequence_length = encoder_hidden_states.shape[:2]
624
- encoder_attention_mask = jnp.ones((batch_size, sequence_length))
625
-
626
- batch_size, sequence_length = decoder_input_ids.shape
627
- if decoder_attention_mask is None:
628
- decoder_attention_mask = jnp.ones((batch_size, sequence_length))
629
-
630
- if decoder_position_ids is None:
631
- if past_key_values is not None:
632
- raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.")
633
-
634
- decoder_position_ids = jnp.broadcast_to(
635
- jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
636
- )
637
-
638
- # Handle any PRNG if needed
639
- rngs = {}
640
- if dropout_rng is not None:
641
- rngs["dropout"] = dropout_rng
642
-
643
- params = {"params": params or self.params}
644
-
645
- # if past_key_values are passed then the cache is already initialized; a private flag init_cache has to be
646
- # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that
647
- # it can be changed by FlaxBartAttention module
648
- if past_key_values:
649
- params["cache"] = past_key_values
650
- mutable = ["cache"]
651
- else:
652
- mutable = False
653
-
654
- def _decoder_forward(
655
- module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, encoder_hidden_states, **kwargs
656
- ):
657
-
658
- projection_module = module._get_projection_module()
659
- decoder_module = module._get_decoder_module()
660
-
661
- # optionally project encoder_hidden_states
662
- if projection_module is not None:
663
- encoder_hidden_states = projection_module(encoder_hidden_states)
664
-
665
- return decoder_module(
666
- decoder_input_ids,
667
- decoder_attention_mask,
668
- decoder_position_ids,
669
- encoder_hidden_states,
670
- **kwargs,
671
- )
672
-
673
- outputs = self.module.apply(
674
- params,
675
- decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
676
- decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
677
- decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
678
- encoder_hidden_states=encoder_hidden_states,
679
- encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
680
- output_attentions=output_attentions,
681
- output_hidden_states=output_hidden_states,
682
- return_dict=return_dict,
683
- deterministic=not train,
684
- rngs=rngs,
685
- mutable=mutable,
686
- method=_decoder_forward,
687
- )
688
-
689
- # add updated cache to model output
690
- if past_key_values is not None and return_dict:
691
- outputs, past = outputs
692
- outputs["past_key_values"] = unfreeze(past["cache"])
693
- return outputs
694
- elif past_key_values is not None and not return_dict:
695
- outputs, past = outputs
696
- outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
697
-
698
- return outputs
699
-
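Continuing the `decode` example in the docstring above, the `past_key_values` mechanics it describes pair with `init_cache` for fast auto-regressive decoding. A hedged sketch of one cached step (token ids and lengths are illustrative):

```python
>>> past_key_values = model.init_cache(batch_size=2, max_length=8, encoder_outputs=encoder_outputs)
>>> decoder_position_ids = jnp.zeros_like(decoder_input_ids)  # positions for the first step
>>> outputs = model.decode(
...     decoder_input_ids,
...     encoder_outputs,
...     decoder_position_ids=decoder_position_ids,
...     past_key_values=past_key_values,
... )
>>> next_token_logits = outputs.logits[:, -1]
>>> past_key_values = outputs.past_key_values  # reuse on the next decoding step
```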
700
- @add_start_docstrings_to_model_forward(SPEECH_ENCODER_DECODER_INPUTS_DOCSTRING)
701
- @replace_return_docstrings(output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
702
- def __call__(
703
- self,
704
- inputs: jnp.ndarray,
705
- attention_mask: Optional[jnp.ndarray] = None,
706
- extract_features: Optional[jnp.ndarray] = None,
707
- decoder_input_ids: Optional[jnp.ndarray] = None,
708
- decoder_attention_mask: Optional[jnp.ndarray] = None,
709
- decoder_position_ids: Optional[jnp.ndarray] = None,
710
- output_attentions: Optional[bool] = None,
711
- output_hidden_states: Optional[bool] = None,
712
- output_features: Optional[bool] = None,
713
- return_dict: Optional[bool] = None,
714
- train: bool = False,
715
- freeze_feature_encoder: bool = False,
716
- params: dict = None,
717
- dropout_rng: PRNGKey = None,
718
- ):
719
- r"""
720
- Returns:
721
-
722
- Examples:
723
-
724
- ```python
725
- >>> from transformers import FlaxSpeechEncoderDecoderModel, BartTokenizer
726
-
727
- >>> # load a fine-tuned wav2vec2-2-bart model
728
- >>> model = FlaxSpeechEncoderDecoderModel.from_pretrained("patrickvonplaten/wav2vec2-2-bart-large")
729
- >>> # load output tokenizer
730
- >>> tokenizer_output = BartTokenizer.from_pretrained("facebook/bart-large")
731
-
732
- >>> inputs = jnp.ones((2, 5000), dtype=jnp.float32)
733
-
734
- >>> # use bart's special bos, pad and eos tokens
735
- >>> model.config.decoder_start_token_id = model.decoder.config.bos_token_id
736
- >>> model.config.pad_token_id = model.decoder.config.pad_token_id
737
- >>> model.config.eos_token_id = model.decoder.config.eos_token_id
738
-
739
- >>> outputs = model.generate(inputs)
740
- # Assert something? More interesting input? dtype correct?
741
- ```
742
- """
743
-
744
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
745
- output_hidden_states = (
746
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
747
- )
748
- return_dict = return_dict if return_dict is not None else self.config.return_dict
749
-
750
- # prepare encoder inputs
751
- if attention_mask is None:
752
- attention_mask = jnp.ones_like(inputs, dtype="i4")
753
-
754
- if extract_features is not None:
755
- inputs = None # we can omit passing the inputs to the model to save memory
756
- extract_features = jnp.array(extract_features, dtype="f4")
757
- else:
758
- inputs = jnp.array(inputs, dtype="f4")
759
-
760
- # prepare decoder inputs
761
- if decoder_input_ids is None:
762
- "`decoder_input_ids` cannot be `None`. For sequence to sequence training, `decoder_input_ids` must be specified as an input argument."
763
- "`decoder_input_ids` cannot be `None`. For sequence to sequence training, `decoder_position_ids` must be specified as an input argument."
764
- )
765
- if decoder_attention_mask is None:
766
- decoder_attention_mask = jnp.ones_like(decoder_input_ids)
767
- if decoder_position_ids is None:
768
- batch_size, sequence_length = decoder_input_ids.shape
769
- decoder_position_ids = jnp.broadcast_to(
770
- jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
771
- )
772
-
773
- # Handle any PRNG if needed
774
- rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
775
-
776
- return self.module.apply(
777
- {"params": params or self.params},
778
- inputs=inputs,
779
- attention_mask=jnp.array(attention_mask, dtype="i4"),
780
- extract_features=extract_features,
781
- decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
782
- decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
783
- decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
784
- output_attentions=output_attentions,
785
- output_hidden_states=output_hidden_states,
786
- output_features=output_features,
787
- return_dict=return_dict,
788
- deterministic=not train,
789
- freeze_feature_encoder=freeze_feature_encoder,
790
- rngs=rngs,
791
- )
792
-
793
- def prepare_inputs_for_generation(
794
- self,
795
- decoder_input_ids,
796
- max_length,
797
- attention_mask: Optional[jnp.DeviceArray] = None,
798
- decoder_attention_mask: Optional[jnp.DeviceArray] = None,
799
- encoder_outputs=None,
800
- **kwargs
801
- ):
802
- # initializing the cache
803
- batch_size, seq_length = decoder_input_ids.shape
804
-
805
- past_key_values = self.init_cache(batch_size, max_length, encoder_outputs)
806
- # Note that usually one would have to put 0's in the attention_mask for x > input.shape[-1] and x < cache_length.
807
- # But since the decoder uses a causal mask, those positions are masked anyway.
808
- # Thus we can create a single static attention_mask here, which is more efficient for compilation
809
- extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
810
- if decoder_attention_mask is not None:
811
- decoder_position_ids = decoder_attention_mask.cumsum(axis=-1) - 1
812
- extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0))
813
- else:
814
- decoder_position_ids = jnp.broadcast_to(
815
- jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)
816
- )
817
-
818
- return {
819
- "past_key_values": past_key_values,
820
- "encoder_outputs": encoder_outputs,
821
- "encoder_attention_mask": attention_mask,
822
- "decoder_attention_mask": extended_attention_mask,
823
- "decoder_position_ids": decoder_position_ids,
824
- }
825
-
826
- def update_inputs_for_generation(self, model_outputs, model_kwargs):
827
- model_kwargs["past_key_values"] = model_outputs.past_key_values
828
- model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1
829
- return model_kwargs
830
-
831
- @classmethod
832
- def from_encoder_decoder_pretrained(
833
- cls,
834
- encoder_pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
835
- decoder_pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
836
- *model_args,
837
- **kwargs
838
- ) -> FlaxPreTrainedModel:
839
- r"""
840
- Instantiate an encoder and a decoder from one or two base classes of the library from pretrained model
841
- checkpoints.
842
-
843
- Params:
844
- encoder_pretrained_model_name_or_path (`Union[str, os.PathLike]`, *optional*):
845
- Information necessary to initiate the encoder. Can be either:
846
-
847
- - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
848
- Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a
849
- user or organization name, like `dbmdz/bert-base-german-cased`.
850
- - A path to a *directory* containing model weights saved using
851
- [`~FlaxPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
852
-
853
- decoder_pretrained_model_name_or_path (`Union[str, os.PathLike]`, *optional*, defaults to `None`):
854
- Information necessary to initiate the decoder. Can be either:
855
-
856
- - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
857
- Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a
858
- user or organization name, like `dbmdz/bert-base-german-cased`.
859
- - A path to a *directory* containing model weights saved using
860
- [`~FlaxPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
861
-
862
- model_args (remaining positional arguments, *optional*):
863
- All remaining positional arguments will be passed to the underlying model's `__init__` method.
864
-
865
- kwargs (remaining dictionary of keyword arguments, *optional*):
866
- Can be used to update the configuration object (after it has been loaded) and to initialize the model (e.g.,
867
- `output_attentions=True`).
868
-
869
- - To update the encoder configuration, use the prefix *encoder_* for each configuration parameter.
870
- - To update the decoder configuration, use the prefix *decoder_* for each configuration parameter.
871
- - To update the parent model configuration, do not use a prefix for each configuration parameter.
872
-
873
- Behaves differently depending on whether a `config` is provided or automatically loaded.
874
-
875
- Example:
876
-
877
- ```python
878
- >>> from transformers import FlaxSpeechEncoderDecoderModel
879
-
880
- >>> # initialize a wav2vec2-2-bart from pretrained wav2vec2 and bart models. Note that the cross-attention layers will be randomly initialized
881
- >>> model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
882
- ... "facebook/wav2vec2-large-lv60", "facebook/bart-large"
883
- ... )
884
- >>> # saving model after fine-tuning
885
- >>> model.save_pretrained("./wav2vec2-2-bart-large")
886
- >>> # load fine-tuned model
887
- >>> model = FlaxSpeechEncoderDecoderModel.from_pretrained("./wav2vec2-2-bart-large")
888
- ```"""
889
-
890
- kwargs_encoder = {
891
- argument[len("encoder_") :]: value for argument, value in kwargs.items() if argument.startswith("encoder_")
892
- }
893
-
894
- kwargs_decoder = {
895
- argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
896
- }
897
-
898
- # remove encoder, decoder kwargs from kwargs
899
- for key in kwargs_encoder.keys():
900
- del kwargs["encoder_" + key]
901
- for key in kwargs_decoder.keys():
902
- del kwargs["decoder_" + key]
903
-
904
- # Load and initialize the encoder and decoder
905
- # The distinction between encoder and decoder at the model level is made
906
- # by the value of the flag `is_decoder` that we need to set correctly.
907
- encoder = kwargs_encoder.pop("model", None)
908
- if encoder is None:
909
- if encoder_pretrained_model_name_or_path is None:
910
- raise ValueError(
911
- "If `encoder_model` is not defined as an argument, an `encoder_pretrained_model_name_or_path` has "
912
- "to be defined."
913
- )
914
-
915
- if "config" not in kwargs_encoder:
916
- # TODO: AutoConfig .from_pretrained
917
- encoder_config, kwargs_encoder = Wav2Vec2Config.from_pretrained(
918
- encoder_pretrained_model_name_or_path, **kwargs_encoder, return_unused_kwargs=True
919
- )
920
- if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
921
- logger.info(
922
- f"Initializing {encoder_pretrained_model_name_or_path} as an encoder model "
923
- "from a decoder model. Cross-attention and causal mask are disabled."
924
- )
925
- encoder_config.is_decoder = False
926
- encoder_config.add_cross_attention = False
927
-
928
- kwargs_encoder["config"] = encoder_config
929
-
930
- # TODO: FlaxAutoModel .from_pretrained
931
- encoder = FlaxWav2Vec2Model.from_pretrained(
932
- encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder
933
- )
934
-
935
- decoder = kwargs_decoder.pop("model", None)
936
- if decoder is None:
937
- if decoder_pretrained_model_name_or_path is None:
938
- raise ValueError(
939
- "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has "
940
- "to be defined."
941
- )
942
-
943
- if "config" not in kwargs_decoder:
944
- # TODO: AutoConfig .from_pretrained
945
- decoder_config, kwargs_decoder = BartConfig.from_pretrained(
946
- decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True
947
- )
948
- if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
949
- logger.info(
950
- f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. "
951
- f"Cross attention layers are added to {decoder_pretrained_model_name_or_path} "
952
- f"and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for "
953
- "cross attention layers."
954
- )
955
- decoder_config.is_decoder = True
956
- decoder_config.add_cross_attention = True
957
-
958
- kwargs_decoder["config"] = decoder_config
959
-
960
- if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False:
961
- logger.warning(
962
- f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. "
963
- f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, "
964
- "make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` "
965
- "passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a "
966
- "`decoder_config` to `.from_encoder_decoder_pretrained(...)`"
967
- )
968
-
969
- # TODO: FlaxAutoModelForCausalLM .from_pretrained
970
- decoder = FlaxBartForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
971
-
972
- # instantiate config with corresponding kwargs
973
- dtype = kwargs.pop("dtype", jnp.float32)
974
- config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs)
975
-
976
- # make sure input & output word embeddings are not tied
977
- config.tie_word_embeddings = False
978
-
979
- # init model
980
- model = cls(config, dtype=dtype)
981
- model.params["encoder"] = encoder.params
982
- model.params["decoder"] = decoder.params
983
-
984
- return model
985
-
986
- def _beam_search(
987
- self,
988
- input_ids: None,
989
- max_length: Optional[int] = None,
990
- pad_token_id: Optional[int] = None,
991
- eos_token_id: Optional[int] = None,
992
- length_penalty: Optional[float] = None,
993
- early_stopping: Optional[bool] = None,
994
- logits_processor: Optional[FlaxLogitsProcessorList] = None,
995
- trace: bool = True,
996
- params: Optional[Dict[str, jnp.ndarray]] = None,
997
- model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
998
- ):
999
- """
1000
- This beam search function is heavily inspired by Flax's official example:
1001
- https://github.com/google/flax/blob/master/examples/wmt/train.py#L254
1002
- """
1003
-
1004
- def flatten_beam_dim(tensor):
1005
- """Flattens the first two dimensions of a non-scalar array."""
1006
- # ignore scalars (e.g. cache index)
1007
- if tensor.ndim == 0 or tensor.ndim == 1:
1008
- return tensor
1009
- elif tensor.ndim == 6:
1010
- return tensor.reshape(tensor.shape[:1] + (tensor.shape[1] * tensor.shape[2],) + tensor.shape[3:])
1011
- return tensor.reshape((tensor.shape[0] * tensor.shape[1],) + tensor.shape[2:])
1012
-
1013
- def unflatten_beam_dim(tensor, batch_size, num_beams):
1014
- """Unflattens the first, flat batch*beam dimension of a non-scalar array."""
1015
- # ignore scalars (e.g. cache index)
1016
- if tensor.ndim == 0 or tensor.ndim == 1:
1017
- return tensor
1018
- if tensor.ndim == 5:
1019
- return tensor.reshape(tensor.shape[:1] + (batch_size, num_beams) + tensor.shape[2:])
1020
- return tensor.reshape((batch_size, num_beams) + tensor.shape[1:])
1021
-
1022
- def gather_beams(nested, beam_indices, batch_size, new_num_beams):
1023
- """
1024
- Gathers the beam slices indexed by beam_indices into new beam array.
1025
- """
1026
- batch_indices = jnp.reshape(
1027
- jnp.arange(batch_size * new_num_beams) // new_num_beams, (batch_size, new_num_beams)
1028
- )
1029
-
1030
- def gather_fn(tensor):
1031
- # ignore scalars (e.g. cache index)
1032
- if tensor.ndim == 0 or tensor.ndim == 1:
1033
- return tensor
1034
- if tensor.ndim == 6:
1035
- return tensor[:, batch_indices, beam_indices]
1036
- return tensor[batch_indices, beam_indices]
1037
-
1038
- return jax.tree_map(gather_fn, nested)
1039
-
1040
- # init values
1041
- max_length = max_length if max_length is not None else self.config.max_length
1042
- pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
1043
- eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
1044
- length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
1045
- early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
1046
-
1047
- batch_size, num_beams, cur_len = input_ids.shape
1048
-
1049
- eos_token_id = jnp.array(eos_token_id)
1050
- pad_token_id = jnp.array(pad_token_id)
1051
- cur_len = jnp.array(cur_len)
1052
-
1053
- # per batch,beam-item holding current token in loop.
1054
- sequences = jnp.full((batch_size, num_beams, max_length), pad_token_id, dtype=jnp.int32)
1055
- running_sequences = jnp.full((batch_size, num_beams, max_length), pad_token_id, dtype=jnp.int32)
1056
- running_sequences = lax.dynamic_update_slice(sequences, input_ids, (0, 0, 0))
1057
-
1058
- # per batch,beam-item state bit indicating if sentence has finished.
1059
- is_sent_finished = jnp.zeros((batch_size, num_beams), dtype=jnp.bool_)
1060
-
1061
- # per batch,beam-item score, logprobs
1062
- running_scores = jnp.tile(jnp.array([0.0] + [np.array(-1.0e7)] * (num_beams - 1)), [batch_size, 1])
1063
- scores = jnp.ones((batch_size, num_beams)) * np.array(-1.0e7)
1064
-
1065
- # For Seq2Seq generation, we only need to use the decoder instead of the whole model in generation loop
1066
- # and pass it the `encoder_outputs`, which are part of the `model_kwargs`.
1067
- model = self.decode if self.config.is_encoder_decoder else self
1068
-
1069
- # flatten beam dim
1070
- if "encoder_outputs" in model_kwargs:
1071
- model_kwargs["encoder_outputs"]["last_hidden_state"] = flatten_beam_dim(
1072
- model_kwargs["encoder_outputs"]["last_hidden_state"]
1073
- )
1074
- if "attention_mask" in model_kwargs:
1075
- model_kwargs["attention_mask"] = flatten_beam_dim(model_kwargs["attention_mask"])
1076
-
1077
- # initialize model specific kwargs
1078
- model_kwargs = self.prepare_inputs_for_generation(flatten_beam_dim(input_ids), max_length, **model_kwargs)
1079
-
1080
- # initialize state
1081
- state = BeamSearchState(
1082
- cur_len=cur_len,
1083
- running_sequences=running_sequences,
1084
- running_scores=running_scores,
1085
- sequences=sequences,
1086
- scores=scores,
1087
- is_sent_finished=is_sent_finished,
1088
- model_kwargs=model_kwargs,
1089
- )
1090
-
1091
- def beam_search_cond_fn(state):
1092
- """beam search state termination condition fn."""
1093
-
1094
- # 1. is less than max length?
1095
- not_max_length_yet = state.cur_len < max_length
1096
-
1097
- # 2. can the new beams still improve?
1098
- best_running_score = state.running_scores[:, -1:] / (max_length**length_penalty)
1099
- worst_finished_score = jnp.where(
1100
- state.is_sent_finished, jnp.min(state.scores, axis=1, keepdims=True), np.array(-1.0e7)
1101
- )
1102
- improvement_still_possible = jnp.all(worst_finished_score < best_running_score)
1103
-
1104
- # 3. is there still a beam that has not finished?
1105
- still_open_beam = ~(jnp.all(state.is_sent_finished) & early_stopping)
1106
-
1107
- return not_max_length_yet & still_open_beam & improvement_still_possible
1108
-
1109
- def beam_search_body_fn(state, input_ids_length=1):
1110
- """beam search state update fn."""
1111
- # 1. Forward current tokens
1112
- # Collect the current position slice along length to feed the fast
1113
- # autoregressive decoder model. Flatten the beam dimension into batch
1114
- # dimension for feeding into the model.
1115
- # unflatten beam dimension
1116
- # Unflatten beam dimension in attention cache arrays
1117
- input_token = flatten_beam_dim(
1118
- lax.dynamic_slice(
1119
- state.running_sequences,
1120
- (0, 0, state.cur_len - input_ids_length),
1121
- (batch_size, num_beams, input_ids_length),
1122
- )
1123
- )
1124
- model_outputs = model(input_token, params=params, **state.model_kwargs)
1125
-
1126
- logits = unflatten_beam_dim(model_outputs.logits[:, -1], batch_size, num_beams)
1127
- cache = jax.tree_map(
1128
- lambda tensor: unflatten_beam_dim(tensor, batch_size, num_beams), model_outputs.past_key_values
1129
- )
1130
-
1131
- # adapt logits for FlaxMarianMTModel
1132
- logits = self._adapt_logits_for_beam_search(logits)
1133
-
1134
- # 2. Compute log probs
1135
- # get log probabilities from logits,
1136
- # process logits with processors (*e.g.* min_length, ...), and
1137
- # add new logprobs to existing running logprobs scores.
1138
- log_probs = jax.nn.log_softmax(logits)
1139
- log_probs = logits_processor(
1140
- flatten_beam_dim(running_sequences), flatten_beam_dim(log_probs), state.cur_len
1141
- )
1142
- log_probs = unflatten_beam_dim(log_probs, batch_size, num_beams)
1143
- log_probs = log_probs + jnp.expand_dims(state.running_scores, axis=2)
1144
- vocab_size = log_probs.shape[2]
1145
- log_probs = log_probs.reshape((batch_size, num_beams * vocab_size))
1146
-
1147
- # 3. Retrieve top-K
1148
- # Each item in batch has num_beams * vocab_size candidate sequences.
1149
- # For each item, get the top 2*k candidates with the highest log-
1150
- # probabilities. We gather the top 2*K beams here so that even if the best
1151
- # K sequences reach EOS simultaneously, we have another K sequences
1152
- # remaining to continue the live beam search.
1153
- # Gather the top 2*K scores from _all_ beams.
1154
- # Gather 2*k top beams.
1155
- # Recover the beam index by floor division.
1156
- # Recover token id by modulo division and expand Id array for broadcasting.
1157
- # Update sequences for the 2*K top-k new sequences.
1158
- beams_to_keep = 2 * num_beams
1159
- topk_log_probs, topk_indices = lax.top_k(log_probs, k=beams_to_keep)
1160
- topk_beam_indices = topk_indices // vocab_size
1161
- topk_running_sequences = gather_beams(
1162
- state.running_sequences, topk_beam_indices, batch_size, beams_to_keep
1163
- )
1164
- topk_ids = jnp.expand_dims(topk_indices % vocab_size, axis=2)
1165
- topk_sequences = lax.dynamic_update_slice(topk_running_sequences, topk_ids, (0, 0, state.cur_len))
1166
-
1167
- # 4. Check which sequences have ended
1168
- # Update current sequences:
1169
- # Did any of these sequences reach an end marker?
1170
- # To prevent these just-finished sequences from being added to the set of
1171
- # active beam search sequences, set their log probs to a very large
1172
- # negative value.
1173
- did_topk_just_finished = topk_sequences[:, :, state.cur_len] == eos_token_id
1174
- running_topk_log_probs = topk_log_probs + did_topk_just_finished * np.array(-1.0e7)
1175
- # 5. Get running sequences scores for next
1176
- # Determine the top k beam indices (from top 2*k beams) from log probs
1177
- # and gather top k beams (from top 2*k beams).
1178
- next_topk_indices = jnp.flip(lax.top_k(running_topk_log_probs, k=num_beams)[1], axis=1)
1179
- next_running_sequences, next_running_scores = gather_beams(
1180
- [topk_sequences, running_topk_log_probs], next_topk_indices, batch_size, num_beams
1181
- )
1182
-
1183
- # 6. Process topk logits
1184
- # Further process log probs:
1185
- # - add length penalty
1186
- # - make sure no scores can be added anymore if beam is full
1187
- # - make sure still running sequences cannot be chosen as finalized beam
1188
- topk_log_probs = topk_log_probs / (state.cur_len**length_penalty)
1189
- beams_in_batch_are_full = (
1190
- jnp.broadcast_to(state.is_sent_finished.all(axis=-1, keepdims=True), did_topk_just_finished.shape)
1191
- & early_stopping
1192
- )
1193
- add_penalty = ~did_topk_just_finished | beams_in_batch_are_full
1194
- topk_log_probs += add_penalty * np.array(-1.0e7)
1195
-
1196
- # 7. Get scores, sequences, is sentence finished for next.
1197
- # Combine sequences, scores, and flags along the beam dimension and compare
1198
- # new finished sequence scores to existing finished scores and select the
1199
- # best from the new set of beams
1200
- merged_sequences = jnp.concatenate([state.sequences, topk_sequences], axis=1)
1201
- merged_scores = jnp.concatenate([state.scores, topk_log_probs], axis=1)
1202
- merged_is_sent_finished = jnp.concatenate([state.is_sent_finished, did_topk_just_finished], axis=1)
1203
- topk_merged_indices = jnp.flip(lax.top_k(merged_scores, k=num_beams)[1], axis=1)
1204
- next_sequences, next_scores, next_is_sent_finished = gather_beams(
1205
- [merged_sequences, merged_scores, merged_is_sent_finished], topk_merged_indices, batch_size, num_beams
1206
- )
1207
-
1208
- # 8. Update model kwargs.
1209
- # Determine the top k beam indices from the original set of all beams.
1210
- # With these, gather the top k beam-associated caches.
1211
- next_running_indices = gather_beams(topk_beam_indices, next_topk_indices, batch_size, num_beams)
1212
- next_cache = gather_beams(cache, next_running_indices, batch_size, num_beams)
1213
- model_outputs["past_key_values"] = jax.tree_map(lambda x: flatten_beam_dim(x), next_cache)
1214
- next_model_kwargs = self.update_inputs_for_generation(model_outputs, state.model_kwargs)
1215
-
1216
- return BeamSearchState(
1217
- cur_len=state.cur_len + 1,
1218
- running_scores=next_running_scores,
1219
- running_sequences=next_running_sequences,
1220
- scores=next_scores,
1221
- sequences=next_sequences,
1222
- is_sent_finished=next_is_sent_finished,
1223
- model_kwargs=next_model_kwargs,
1224
- )
1225
-
1226
- # The very first prompt often has sequence length > 1, so run outside of `lax.while_loop` to comply with TPU
1227
- if input_ids.shape[-1] > 1:
1228
- state = partial(beam_search_body_fn, input_ids_length=input_ids.shape[-1])(state)
1229
-
1230
- if not trace:
1231
- state = self._run_loop_in_debug(beam_search_cond_fn, beam_search_body_fn, state)
1232
- else:
1233
- state = lax.while_loop(beam_search_cond_fn, beam_search_body_fn, state)
1234
-
1235
- # Account for the edge-case where there are no finished sequences for a
1236
- # particular batch item. If so, return running sequences for that batch item.
1237
- none_finished = jnp.any(state.is_sent_finished, axis=1)
1238
- sequences = jnp.where(none_finished[:, None, None], state.sequences, state.running_sequences)
1239
- scores = jnp.where(none_finished[:, None], state.scores, state.running_scores)
1240
-
1241
- # return all beams for each batch and the best score
1242
- sequences = sequences[:, :]
1243
- scores = scores[:, -1]
1244
-
1245
- return FlaxBeamSearchOutput(sequences=sequences, scores=scores)
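
For reference, the `_beam_search` method above keeps every array in a `(batch, beam, ...)` layout and flattens it to `(batch * beam, ...)` before each decoder call, then restores the beam axis afterwards. A minimal, self-contained sketch of that reshaping (hypothetical shapes, chosen only for illustration):

import jax.numpy as jnp

batch_size, num_beams, cur_len = 2, 3, 5  # hypothetical sizes, not taken from this run
sequences = jnp.zeros((batch_size, num_beams, cur_len), dtype=jnp.int32)

# flatten_beam_dim: (batch, beam, ...) -> (batch * beam, ...) before the decoder forward pass
flat = sequences.reshape((batch_size * num_beams,) + sequences.shape[2:])
assert flat.shape == (6, 5)

# unflatten_beam_dim: (batch * beam, ...) -> (batch, beam, ...) after the forward pass
unflat = flat.reshape((batch_size, num_beams) + flat.shape[1:])
assert unflat.shape == (2, 3, 5)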
 
models/modeling_flax_wav2vec2.py DELETED
@@ -1,975 +0,0 @@
1
- # coding=utf-8
2
- # Copyright 2021 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- """ Flax Wav2Vec2 model."""
16
-
17
- from functools import partial
18
- from typing import Optional, Tuple, Union
19
-
20
- import flax
21
- import flax.linen as nn
22
- import jax
23
- import jax.numpy as jnp
24
- from flax.core.frozen_dict import FrozenDict
25
- from flax.linen import partitioning as nn_partitioning
26
- from flax.linen.attention import dot_product_attention_weights
27
- from jax import lax
28
-
29
- from transformers.modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput
30
- from transformers.modeling_flax_utils import ACT2FN, FlaxPreTrainedModel
31
- from transformers.utils import ModelOutput
32
-
33
- from models import Wav2Vec2Config
34
-
35
- scan_with_axes = nn_partitioning.scan_with_axes
36
- remat = nn_partitioning.remat
37
-
38
-
39
- @flax.struct.dataclass
40
- class FlaxWav2Vec2BaseModelOutput(ModelOutput):
41
- """
42
- Output type of [`FlaxWav2Vec2Model`], with potential hidden states and attentions.
43
-
44
- Args:
45
- last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`):
46
- Sequence of hidden-states at the output of the last layer of the model.
47
- extract_features (`jnp.ndarray` of shape `(batch_size, sequence_length, last_conv_dim)`):
48
- Sequence of extracted feature vectors of the last convolutional layer of the model with `last_conv_dim`
49
- being the dimension of the last convolutional layer.
50
- hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
51
- Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
52
- `(batch_size, sequence_length, hidden_size)`.
53
-
54
- Hidden-states of the model at the output of each layer plus the initial embedding outputs.
55
- attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
56
- Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
57
- sequence_length)`.
58
-
59
- Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
60
- heads.
61
- """
62
-
63
- last_hidden_state: jnp.ndarray = None
64
- extract_features: jnp.ndarray = None
65
- hidden_states: Optional[Tuple[jnp.ndarray]] = None
66
- attentions: Optional[Tuple[jnp.ndarray]] = None
67
-
68
-
69
- WAV_2_VEC_2_START_DOCSTRING = r"""
70
- Wav2Vec2 was proposed in [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech
71
- Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael
72
- Auli.
73
-
74
- This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
75
- library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
76
- etc.)
77
-
78
- This model is also a Flax Linen
79
- [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
80
- regular Flax Module and refer to the Flax documentation for all matters related to general usage and behavior.
81
-
82
- Finally, this model supports inherent JAX features such as:
83
-
84
- - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
85
- - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
86
- - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
87
- - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
88
-
89
- Parameters:
90
- config ([`Wav2Vec2Config`]): Model configuration class with all the parameters of the model.
91
- Initializing with a config file does not load the weights associated with the model, only the
92
- configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
93
- dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
94
- The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
95
- `jax.numpy.bfloat16` (on TPUs).
96
-
97
- This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
98
- specified all the computation will be performed with the given `dtype`.
99
-
100
- **Note that this only specifies the dtype of the computation and does not influence the dtype of model
101
- parameters.**
102
-
103
- If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
104
- [`~FlaxPreTrainedModel.to_bf16`].
105
- """
106
-
107
-
108
- WAV_2_VEC_2_INPUTS_DOCSTRING = r"""
109
- Args:
110
- input_values (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
111
- Float values of input raw speech waveform. Values can be obtained by loading a *.flac* or *.wav* audio file
112
- into an array of type *List[float]* or a *numpy.ndarray*, *e.g.* via the soundfile library (*pip install
113
- soundfile*). To prepare the array into *input_values*, the [`Wav2Vec2Processor`] should be used for padding
114
- and conversion into a tensor of type *jnp.ndarray*. See [`Wav2Vec2Processor.__call__`] for details.
115
- attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
116
- Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0,
117
- 1]`:
118
-
119
- - 1 for tokens that are **not masked**,
120
- - 0 for tokens that are **masked**.
121
-
122
- [What are attention masks?](../glossary#attention-mask) .. warning:: `attention_mask` should only be passed
123
- if the corresponding processor has `config.return_attention_mask == True`. For all models whose processor
124
- has `config.return_attention_mask == False`, such as
125
- [wav2vec2-base](https://huggingface.co/facebook/wav2vec2-base-960h), `attention_mask` should **not** be
126
- passed to avoid degraded performance when doing batched inference. For such models `input_values` should
127
- simply be padded with 0 and passed without `attention_mask`. Be aware that these models also yield slightly
128
- different results depending on whether `input_values` is padded or not.
129
- mask_time_indices (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
130
- Indices to mask extracted features for contrastive loss. When in training mode, model learns to predict
131
- masked extracted features in *config.proj_codevector_dim* space.
132
- output_attentions (`bool`, *optional*):
133
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
134
- tensors for more detail.
135
- output_hidden_states (`bool`, *optional*):
136
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
137
- more detail.
138
- return_dict (`bool`, *optional*):
139
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
140
- """
141
-
142
-
143
- class FlaxWav2Vec2LayerNormConvLayer(nn.Module):
144
- config: Wav2Vec2Config
145
- layer_id: int = 0
146
- dtype: jnp.dtype = jnp.float32
147
-
148
- def setup(self):
149
- self.in_conv_dim = self.config.conv_dim[self.layer_id] if self.layer_id > 0 else 1
150
- self.out_conv_dim = self.config.conv_dim[self.layer_id]
151
-
152
- self.conv = nn.Conv(
153
- features=self.config.conv_dim[self.layer_id],
154
- kernel_size=(self.config.conv_kernel[self.layer_id],),
155
- strides=(self.config.conv_stride[self.layer_id],),
156
- use_bias=self.config.conv_bias,
157
- kernel_init=jax.nn.initializers.he_normal(),
158
- padding="VALID",
159
- dtype=self.dtype,
160
- )
161
- self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
162
- self.activation = ACT2FN[self.config.feat_extract_activation]
163
-
164
- def __call__(self, hidden_states):
165
- hidden_states = self.conv(hidden_states)
166
- hidden_states = self.layer_norm(hidden_states)
167
- hidden_states = self.activation(hidden_states)
168
- return hidden_states
169
-
170
-
171
- class FlaxConvWithWeightNorm(nn.Module):
172
- config: Wav2Vec2Config
173
- dtype: jnp.dtype = jnp.float32
174
-
175
- def setup(self):
176
- self.conv = nn.Conv(
177
- features=self.config.hidden_size,
178
- kernel_size=(self.config.num_conv_pos_embeddings,),
179
- kernel_init=jax.nn.initializers.he_normal(),
180
- padding="VALID",
181
- feature_group_count=self.config.num_conv_pos_embedding_groups,
182
- dtype=self.dtype,
183
- )
184
- weight_shape = (
185
- self.conv.features,
186
- self.conv.features // self.conv.feature_group_count,
187
- self.conv.kernel_size[0],
188
- )
189
- self.weight_v = self.param("weight_v", jax.nn.initializers.he_normal(), weight_shape)
190
- self.weight_g = self.param("weight_g", lambda _: jnp.linalg.norm(self.weight_v, axis=(0, 1))[None, None, :])
191
- self.bias = self.param("bias", jax.nn.initializers.zeros, (self.conv.features,))
192
- self.prev_padding = self.conv.kernel_size[0] // 2
193
-
194
- def _get_normed_weights(self):
195
- weight_v_norm = jnp.linalg.norm(self.weight_v, axis=(0, 1))[None, None, :]
196
- normed_weight_v = jnp.divide(self.weight_v, weight_v_norm)
197
- normed_kernel = jnp.multiply(normed_weight_v, self.weight_g)
198
- return normed_kernel
199
-
200
- def __call__(self, hidden_states):
201
- kernel = self._get_normed_weights()
202
- hidden_states = jnp.pad(hidden_states, ((0, 0), (self.prev_padding, self.prev_padding), (0, 0)))
203
- hidden_states = self.conv.apply({"params": {"kernel": kernel.T, "bias": self.bias}}, hidden_states)
204
- return hidden_states
205
-
206
-
207
- class FlaxWav2Vec2PositionalConvEmbedding(nn.Module):
208
- config: Wav2Vec2Config
209
- dtype: jnp.dtype = jnp.float32
210
-
211
- def setup(self):
212
- self.conv = FlaxConvWithWeightNorm(self.config, dtype=self.dtype)
213
- self.activation = ACT2FN[self.config.feat_extract_activation]
214
- self.num_pad_remove = 1 if self.config.num_conv_pos_embeddings % 2 == 0 else 0
215
-
216
- def __call__(self, hidden_states):
217
- hidden_states = hidden_states.transpose((0, 1, 2))
218
-
219
- hidden_states = self.conv(hidden_states)
220
-
221
- if self.num_pad_remove > 0:
222
- hidden_states = hidden_states[:, : -self.num_pad_remove, :]
223
- hidden_states = self.activation(hidden_states)
224
-
225
- hidden_states = hidden_states.transpose((0, 1, 2))
226
- return hidden_states
227
-
228
-
229
- class FlaxConvLayersCollection(nn.Module):
230
- config: Wav2Vec2Config
231
- dtype: jnp.dtype = jnp.float32
232
-
233
- def setup(self):
234
- if self.config.feat_extract_norm == "layer":
235
- # note that we can't use scan on the conv layers as they differ on a layer-by-layer basis
236
- BlockLayer = remat(FlaxWav2Vec2LayerNormConvLayer) if self.config.gradient_checkpointing else FlaxWav2Vec2LayerNormConvLayer
237
- self.layers = [
238
- BlockLayer(self.config, layer_id=i, name=str(i), dtype=self.dtype)
239
- for i in range(self.config.num_feat_extract_layers)
240
- ]
241
- elif self.config.feat_extract_norm == "group":
242
- raise NotImplementedError("At the moment only ``config.feat_extact_norm == 'layer'`` is supported")
243
- else:
244
- raise ValueError(
245
- f"`config.feat_extract_norm` is {self.config.feat_extract_norm}, but has to be one of ['group', 'layer']"
246
- )
247
-
248
- def __call__(self, hidden_states):
249
- for i, conv_layer in enumerate(self.layers):
250
- hidden_states = conv_layer(hidden_states)
251
- return hidden_states
252
-
253
-
254
- class FlaxWav2Vec2FeatureEncoder(nn.Module):
255
- """Construct the features from raw audio waveform"""
256
-
257
- config: Wav2Vec2Config
258
- dtype: jnp.dtype = jnp.float32
259
-
260
- def setup(self):
261
- self.conv_layers = FlaxConvLayersCollection(self.config, dtype=self.dtype)
262
-
263
- def __call__(self, input_values, freeze_feature_encoder=False):
264
- hidden_states = input_values[:, :, None]
265
- hidden_states = self.conv_layers(hidden_states)
266
- if freeze_feature_encoder:
267
- hidden_states = jax.lax.stop_gradient(hidden_states)
268
- return hidden_states
269
-
270
-
271
- class FlaxWav2Vec2FeatureProjection(nn.Module):
272
- config: Wav2Vec2Config
273
- dtype: jnp.dtype = jnp.float32
274
-
275
- def setup(self):
276
- self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
277
- self.projection = nn.Dense(
278
- self.config.hidden_size,
279
- kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
280
- dtype=self.dtype,
281
- )
282
- self.dropout = nn.Dropout(rate=self.config.feat_proj_dropout)
283
-
284
- def __call__(self, hidden_states, deterministic=True):
285
- norm_hidden_states = self.layer_norm(hidden_states)
286
- hidden_states = self.projection(norm_hidden_states)
287
- hidden_states = self.dropout(hidden_states, deterministic=deterministic)
288
- return hidden_states, norm_hidden_states
289
-
290
-
291
- class FlaxWav2Vec2Attention(nn.Module):
292
- config: Wav2Vec2Config
293
- embed_dim: int
294
- num_heads: int
295
- dropout: float = 0.0
296
- bias: bool = True
297
- dtype: jnp.dtype = jnp.float32 # the dtype of the computation
298
-
299
- def setup(self) -> None:
300
- self.head_dim = self.embed_dim // self.num_heads
301
- if self.head_dim * self.num_heads != self.embed_dim:
302
- raise ValueError(
303
- f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
304
- )
305
-
306
- dense = partial(
307
- nn.Dense,
308
- self.embed_dim,
309
- use_bias=self.bias,
310
- dtype=self.dtype,
311
- kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
312
- )
313
-
314
- self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
315
-
316
- self.fused_proj = nn.Dense(
317
- self.embed_dim * 3,
318
- use_bias=self.bias,
319
- dtype=self.dtype,
320
- kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
321
- )
322
-
323
- self.out_proj = dense()
324
-
325
- self.dropout_layer = nn.Dropout(rate=self.dropout)
326
-
327
- def _split_heads(self, hidden_states):
328
- return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))
329
-
330
- def _merge_heads(self, hidden_states):
331
- return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))
332
-
333
- def __call__(
334
- self,
335
- hidden_states: jnp.ndarray,
336
- key_value_states: Optional[jnp.ndarray] = None,
337
- attention_mask: Optional[jnp.ndarray] = None,
338
- deterministic: bool = True,
339
- ) -> Tuple[jnp.ndarray]:
340
- """Input shape: Batch x Time x Channel"""
341
-
342
- if self.config.fuse_matmuls:
343
- attention_states = self.fused_proj(hidden_states)
344
- query_states, key_states, value_states = jnp.split(attention_states, 3, axis=-1)
345
-
346
- else:
347
- # get query proj
348
- query_states = self.q_proj(hidden_states)
349
-
350
- key_states = self.k_proj(hidden_states)
351
- value_states = self.v_proj(hidden_states)
352
-
353
- query_states = self._split_heads(query_states)
354
- key_states = self._split_heads(key_states)
355
- value_states = self._split_heads(value_states)
356
-
357
- if attention_mask is not None:
358
- attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
359
-
360
- # Convert the boolean attention mask to an attention bias.
361
- if attention_mask is not None:
362
- # attention mask in the form of attention bias
363
- attention_bias = lax.select(
364
- attention_mask > 0,
365
- jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
366
- jnp.full(attention_mask.shape, float("-inf")).astype(self.dtype),
367
- )
368
- else:
369
- attention_bias = None
370
-
371
- dropout_rng = None
372
- if not deterministic and self.dropout > 0.0:
373
- dropout_rng = self.make_rng("dropout")
374
-
375
- attn_weights = dot_product_attention_weights(
376
- query_states,
377
- key_states,
378
- bias=attention_bias,
379
- dropout_rng=dropout_rng,
380
- dropout_rate=self.dropout,
381
- broadcast_dropout=True,
382
- deterministic=deterministic,
383
- dtype=self.dtype,
384
- precision=None,
385
- )
386
-
387
- attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
388
- attn_output = self._merge_heads(attn_output)
389
- attn_output = self.out_proj(attn_output)
390
-
391
- return attn_output, attn_weights
392
-
393
-
394
- class FlaxWav2Vec2FeedForward(nn.Module):
395
- config: Wav2Vec2Config
396
- dtype: jnp.dtype = jnp.float32
397
-
398
- def setup(self):
399
- self.intermediate_dropout = nn.Dropout(rate=self.config.activation_dropout)
400
-
401
- self.intermediate_dense = nn.Dense(
402
- self.config.intermediate_size,
403
- kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
404
- dtype=self.dtype,
405
- )
406
- if isinstance(self.config.hidden_act, str):
407
- self.intermediate_act_fn = ACT2FN[self.config.hidden_act]
408
- else:
409
- self.intermediate_act_fn = self.config.hidden_act
410
-
411
- self.output_dense = nn.Dense(
412
- self.config.hidden_size,
413
- kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
414
- dtype=self.dtype,
415
- )
416
- self.output_dropout = nn.Dropout(rate=self.config.hidden_dropout)
417
-
418
- def __call__(self, hidden_states, deterministic=True):
419
- hidden_states = self.intermediate_dense(hidden_states)
420
- hidden_states = self.intermediate_act_fn(hidden_states)
421
- hidden_states = self.intermediate_dropout(hidden_states, deterministic=deterministic)
422
-
423
- hidden_states = self.output_dense(hidden_states)
424
- hidden_states = self.output_dropout(hidden_states, deterministic=deterministic)
425
- return hidden_states
426
-
427
-
428
- class FlaxWav2Vec2EncoderLayerStableLayerNorm(nn.Module):
429
- config: Wav2Vec2Config
430
- dtype: jnp.dtype = jnp.float32
431
-
432
- def setup(self):
433
- self.attention = FlaxWav2Vec2Attention(
434
- config=self.config,
435
- embed_dim=self.config.hidden_size,
436
- num_heads=self.config.num_attention_heads,
437
- dropout=self.config.attention_dropout,
438
- dtype=self.dtype,
439
- )
440
- self.dropout = nn.Dropout(rate=self.config.hidden_dropout)
441
- self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
442
- self.feed_forward = FlaxWav2Vec2FeedForward(self.config, dtype=self.dtype)
443
- self.final_layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
444
-
445
- def __call__(self, hidden_states, attention_mask=None, deterministic=True, output_attentions=False):
446
- if self.config.use_scan:
447
- hidden_states = hidden_states[0]
448
- attn_residual = hidden_states
449
- hidden_states = self.layer_norm(hidden_states)
450
- hidden_states, attn_weights = self.attention(
451
- hidden_states, attention_mask=attention_mask, deterministic=deterministic
452
- )
453
- hidden_states = self.dropout(hidden_states, deterministic=deterministic)
454
- hidden_states = attn_residual + hidden_states
455
- hidden_states = hidden_states + self.feed_forward(
456
- self.final_layer_norm(hidden_states), deterministic=deterministic
457
- )
458
-
459
- outputs = (hidden_states,)
460
-
461
- if output_attentions:
462
- outputs += (attn_weights,)
463
-
464
- if self.config.use_scan:
465
- outputs = (outputs, None)
466
-
467
- return outputs
468
-
469
-
470
- class FlaxWav2Vec2EncoderLayerStableLayerNormCollection(nn.Module):
471
- config: Wav2Vec2Config
472
- dtype: jnp.dtype = jnp.float32
473
-
474
- @nn.compact
475
- def __call__(
476
- self,
477
- hidden_states,
478
- attention_mask=None,
479
- deterministic: bool = True,
480
- output_attentions: bool = False,
481
- output_hidden_states: bool = False,
482
- return_dict: bool = True,
483
- ):
484
- all_attentions = () if output_attentions else None
485
- all_hidden_states = () if output_hidden_states else None
486
-
487
- num_layers = self.config.num_hidden_layers
488
- BlockEncoderLayer = (
489
- remat(
490
- FlaxWav2Vec2EncoderLayerStableLayerNorm,
491
- static_argnums=(2, 3),
492
- prevent_cse=not self.config.use_scan,
493
- )
494
- if self.config.gradient_checkpointing
495
- else FlaxWav2Vec2EncoderLayerStableLayerNorm
496
- )
497
-
498
- if self.config.use_scan:
499
- # since all encoder layers are the same, we use nn.scan directly
500
- assert not output_attentions, "cannot use `scan` with `output_attentions` set to `True`"
501
- assert not output_hidden_states, "cannot use `scan` with `output_hidden_states` set to `True`"
502
- hidden_states = (hidden_states,)
503
-
504
- hidden_states, _ = scan_with_axes(
505
- BlockEncoderLayer,
506
- variable_axes={"params": 0, "cache": 0},
507
- split_rngs={"params": True, "dropout": True},
508
- in_axes=(nn.broadcast, nn.broadcast, nn.broadcast),
509
- length=num_layers,
510
- )(self.config, dtype=self.dtype, name="FlaxWav2Vec2EncoderLayers",)(
511
- hidden_states, attention_mask, deterministic, output_attentions
512
- )
513
- hidden_states = hidden_states[0]
514
-
515
- else:
516
- for layer in range(num_layers):
517
- if output_hidden_states:
518
- all_hidden_states += (hidden_states,)
519
-
520
- layer_outputs = BlockEncoderLayer(
521
- self.config,
522
- dtype=self.dtype,
523
- name=str(layer),
524
- )(hidden_states, attention_mask, deterministic, output_attentions)
525
-
526
- hidden_states = layer_outputs[0]
527
-
528
- if output_attentions:
529
- all_attentions += (layer_outputs[1],)
530
-
531
- if output_hidden_states:
532
- all_hidden_states += (hidden_states,)
533
-
534
- outputs = (hidden_states, all_hidden_states, all_attentions)
535
-
536
- if not return_dict:
537
- return tuple(v for v in outputs if v is not None)
538
-
539
- return FlaxBaseModelOutput(
540
- last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
541
- )
542
-
543
-
544
- class FlaxWav2Vec2StableLayerNormEncoder(nn.Module):
545
- config: Wav2Vec2Config
546
- dtype: jnp.dtype = jnp.float32
547
-
548
- def setup(self):
549
- self.pos_conv_embed = FlaxWav2Vec2PositionalConvEmbedding(self.config, dtype=self.dtype)
550
- self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
551
- self.dropout = nn.Dropout(rate=self.config.hidden_dropout)
552
- self.layers = FlaxWav2Vec2EncoderLayerStableLayerNormCollection(self.config, dtype=self.dtype)
553
-
554
- def __call__(
555
- self,
556
- hidden_states,
557
- attention_mask=None,
558
- deterministic=True,
559
- output_attentions=False,
560
- output_hidden_states=False,
561
- return_dict=True,
562
- ):
563
-
564
- if attention_mask is not None:
565
- # make sure padded tokens are not attended to
566
- hidden_states = jnp.where(
567
- jnp.broadcast_to(attention_mask[:, :, None], hidden_states.shape), hidden_states, 0
568
- )
569
-
570
- position_embeddings = self.pos_conv_embed(hidden_states)
571
-
572
- hidden_states = hidden_states + position_embeddings
573
- hidden_states = self.dropout(hidden_states, deterministic=deterministic)
574
-
575
- outputs = self.layers(
576
- hidden_states,
577
- attention_mask,
578
- output_attentions=output_attentions,
579
- output_hidden_states=output_hidden_states,
580
- return_dict=return_dict,
581
- )
582
-
583
- last_hidden_state = self.layer_norm(outputs[0])
584
-
585
- # update the last element in `hidden_states` after applying `layernorm` above
586
- hidden_states = None
587
- if output_hidden_states:
588
- hidden_states = outputs[1]
589
- hidden_states = hidden_states[:-1] + (last_hidden_state,)
590
-
591
- if not return_dict:
592
- outputs = (last_hidden_state, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:])
593
- return tuple(v for v in outputs if v is not None)
594
-
595
- return FlaxBaseModelOutput(
596
- last_hidden_state=last_hidden_state, hidden_states=hidden_states, attentions=outputs.attentions
597
- )
598
-
599
-
600
- class FlaxWav2Vec2Adapter(nn.Module):
601
- config: Wav2Vec2Config
602
- dtype: jnp.dtype = jnp.float32
603
-
604
- def setup(self):
605
- # hidden_states require down-projection if feature dims don't match
606
- if self.config.output_hidden_size != self.config.hidden_size:
607
- self.proj = nn.Dense(
608
- self.config.output_hidden_size,
609
- kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
610
- dtype=self.dtype,
611
- )
612
- self.proj_layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
613
- else:
614
- self.proj = self.proj_layer_norm = None
615
-
616
- self.layers = FlaxWav2Vec2AdapterLayersCollection(self.config, dtype=self.dtype)
617
-
618
- def __call__(self, hidden_states, deterministic=True):
619
- # down-project hidden_states if required
620
- if self.proj is not None and self.proj_layer_norm is not None:
621
- hidden_states = self.proj(hidden_states)
622
- hidden_states = self.proj_layer_norm(hidden_states)
623
-
624
- hidden_states = self.layers(hidden_states)
625
-
626
- return hidden_states
627
-
628
-
629
- class FlaxWav2Vec2AdapterLayer(nn.Module):
630
- config: Wav2Vec2Config
631
- dtype: jnp.dtype = jnp.float32
632
-
633
- def setup(self):
634
- self.conv = nn.Conv(
635
- features=2 * self.config.output_hidden_size,
636
- kernel_size=(self.config.adapter_kernel_size,),
637
- strides=(self.config.adapter_stride,),
638
- padding=((1, 1),),
639
- kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
640
- dtype=self.dtype,
641
- )
642
-
643
- def __call__(self, hidden_states):
644
- hidden_states = self.conv(hidden_states)
645
- hidden_states = nn.glu(hidden_states, axis=2)
646
-
647
- return hidden_states
648
-
649
-
650
- class FlaxWav2Vec2AdapterLayersCollection(nn.Module):
651
- config: Wav2Vec2Config
652
- dtype: jnp.dtype = jnp.float32
653
-
654
- def setup(self):
655
- BlockAdapterLayer = remat(FlaxWav2Vec2AdapterLayer) if self.config.gradient_checkpointing else FlaxWav2Vec2AdapterLayer
656
- self.layers = [
657
- BlockAdapterLayer(self.config, name=str(i), dtype=self.dtype)
658
- for i in range(self.config.num_adapter_layers)
659
- ]
660
-
661
- def __call__(self, hidden_states):
662
- for conv_layer in self.layers:
663
- hidden_states = conv_layer(hidden_states)
664
-
665
- return hidden_states
666
-
667
-
668
- class FlaxWav2Vec2PreTrainedModel(FlaxPreTrainedModel):
669
- """
670
- An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
671
- models.
672
- """
673
-
674
- config_class = Wav2Vec2Config
675
- base_model_prefix: str = "wav2vec2"
676
- main_input_name = "input_values"
677
- module_class: nn.Module = None
678
-
679
- def __init__(
680
- self,
681
- config: Wav2Vec2Config,
682
- input_shape: Tuple = (1, 1024),
683
- seed: int = 0,
684
- dtype: jnp.dtype = jnp.float32,
685
- _do_init: bool = True,
686
- **kwargs,
687
- ):
688
- module = self.module_class(config=config, dtype=dtype, **kwargs)
689
- super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
690
-
691
- def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
692
- # init input tensors
693
- input_values = jnp.zeros(input_shape, dtype="i4")
694
- attention_mask = jnp.ones_like(input_values)
695
- params_rng, dropout_rng = jax.random.split(rng, 2)
696
- rngs = {"params": params_rng, "dropout": dropout_rng}
697
-
698
- return self.module.init(rngs, input_values, attention_mask, return_dict=False)["params"]
699
-
700
- def __call__(
701
- self,
702
- input_values,
703
- attention_mask=None,
704
- mask_time_indices=None,
705
- extract_features=None,
706
- params: dict = None,
707
- dropout_rng: jax.random.PRNGKey = None,
708
- train: bool = False,
709
- output_attentions: Optional[bool] = None,
710
- output_hidden_states: Optional[bool] = None,
711
- output_features: Optional[bool] = None,
712
- freeze_feature_encoder: bool = False,
713
- return_dict: Optional[bool] = None,
714
- ):
715
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
716
- output_hidden_states = (
717
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
718
- )
719
- return_dict = return_dict if return_dict is not None else self.config.return_dict
720
-
721
- if attention_mask is None:
722
- batch_size, sequence_length = input_values.shape
723
- attention_mask = jnp.ones((batch_size, sequence_length))
724
-
725
- if extract_features is not None:
726
- extract_features = jnp.array(extract_features, dtype="f4")
727
-
728
- # Handle any PRNG if needed
729
- rngs = {}
730
- if dropout_rng is not None:
731
- rngs["dropout"] = dropout_rng
732
-
733
- inputs = {"params": params or self.params}
734
-
735
- return self.module.apply(
736
- inputs,
737
- jnp.array(input_values, dtype="f4"),
738
- jnp.array(attention_mask, dtype="i4"),
739
- mask_time_indices,
740
- extract_features,
741
- not train,
742
- output_attentions,
743
- output_hidden_states,
744
- output_features,
745
- freeze_feature_encoder,
746
- return_dict,
747
- rngs=rngs,
748
- )
749
-
750
- def _get_feat_extract_output_lengths(
751
- self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
752
- ):
753
- return self.module._get_feat_extract_output_lengths(input_lengths, add_adapter=add_adapter)
754
-
755
- def _get_feature_vector_attention_mask(
756
- self, feature_vector_length: int, attention_mask: jnp.ndarray, add_adapter=None
757
- ):
758
- return self.module._get_feature_vector_attention_mask(feature_vector_length, attention_mask, add_adapter=add_adapter)
759
-
760
-
761
- class FlaxWav2Vec2Module(nn.Module):
762
- config: Wav2Vec2Config
763
- dtype: jnp.dtype = jnp.float32
764
-
765
- def setup(self):
766
- self.feature_extractor = FlaxWav2Vec2FeatureEncoder(self.config, dtype=self.dtype)
767
- self.feature_projection = FlaxWav2Vec2FeatureProjection(self.config, dtype=self.dtype)
768
- self.masked_spec_embed = self.param(
769
- "masked_spec_embed", jax.nn.initializers.uniform(), (self.config.hidden_size,)
770
- )
771
-
772
- if self.config.do_stable_layer_norm:
773
- self.encoder = FlaxWav2Vec2StableLayerNormEncoder(self.config, dtype=self.dtype)
774
- else:
775
- raise NotImplementedError("``config.do_stable_layer_norm is False`` is currently not supported.")
776
-
777
- self.adapter = FlaxWav2Vec2Adapter(self.config, dtype=self.dtype) if self.config.add_adapter else None
778
-
779
- def __call__(
780
- self,
781
- input_values,
782
- attention_mask=None,
783
- mask_time_indices=None,
784
- extract_features=None,
785
- deterministic=True,
786
- output_attentions=None,
787
- output_hidden_states=None,
788
- output_features=False,
789
- freeze_feature_encoder=False,
790
- return_dict=None,
791
- ):
792
-
793
- # forward pass through the feature extractor if features not specified
794
- if extract_features is None:
795
- extract_features = self.feature_extractor(input_values, freeze_feature_encoder=freeze_feature_encoder)
796
-
797
- if output_features:
798
- return extract_features
799
-
800
- # make sure that no loss is computed on padded inputs
801
- if attention_mask is not None:
802
- # compute reduced attention_mask corresponding to feature vectors
803
- attention_mask = self._get_feature_vector_attention_mask(
804
- extract_features.shape[1], attention_mask, add_adapter=False
805
- )
806
-
807
- hidden_states, extract_features = self.feature_projection(extract_features, deterministic=deterministic)
808
- if mask_time_indices is not None: # apply SpecAugment along time axis with given indices
809
- hidden_states = jnp.where(
810
- jnp.broadcast_to(mask_time_indices[:, :, None], hidden_states.shape),
811
- jnp.broadcast_to(self.masked_spec_embed[None, None, :], hidden_states.shape),
812
- hidden_states,
813
- )
814
-
815
- encoder_outputs = self.encoder(
816
- hidden_states,
817
- attention_mask=attention_mask,
818
- deterministic=deterministic,
819
- output_attentions=output_attentions,
820
- output_hidden_states=output_hidden_states,
821
- return_dict=return_dict,
822
- )
823
-
824
- hidden_states = encoder_outputs[0]
825
-
826
- if self.adapter is not None:
827
- hidden_states = self.adapter(hidden_states)
828
-
829
- if not return_dict:
830
- return (hidden_states, extract_features) + encoder_outputs[1:]
831
-
832
- return FlaxWav2Vec2BaseModelOutput(
833
- last_hidden_state=hidden_states,
834
- extract_features=extract_features,
835
- hidden_states=encoder_outputs.hidden_states,
836
- attentions=encoder_outputs.attentions,
837
- )
838
-
839
- def _get_feat_extract_output_lengths(
840
- self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
841
- ):
842
- """
843
- Computes the output length of the convolutional layers
844
- """
845
-
846
- add_adapter = self.config.add_adapter if add_adapter is None else add_adapter
847
-
848
- def _conv_out_length(input_length, kernel_size, stride):
849
- # 1D convolutional layer output length formula taken
850
- # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
851
- return (input_length - kernel_size) // stride + 1
852
-
853
- for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
854
- input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
855
-
856
- if add_adapter:
857
- for _ in range(self.config.num_adapter_layers):
858
- input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)
859
-
860
- return input_lengths
861
-
862
- def _get_feature_vector_attention_mask(
863
- self, feature_vector_length: int, attention_mask: jnp.ndarray, add_adapter=None
864
- ):
865
- # Effectively attention_mask.sum(-1), but not inplace to be able to run
866
- # in inference mode.
867
- non_padded_lengths = attention_mask.cumsum(axis=-1)[:, -1]
868
-
869
- output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)
870
-
871
- batch_size = attention_mask.shape[0]
872
-
873
- attention_mask = jnp.zeros((batch_size, feature_vector_length), dtype=attention_mask.dtype)
874
- # these two operations make sure that all values
875
- # before the output length indices are attended to
876
- attention_mask = attention_mask.at[jnp.arange(attention_mask.shape[0]), output_lengths - 1].set(1)
877
- attention_mask = jnp.flip(jnp.flip(attention_mask, -1).cumsum(-1), -1).astype("bool")
878
- return attention_mask
879
-
880
-
881
- class FlaxWav2Vec2Model(FlaxWav2Vec2PreTrainedModel):
882
- module_class = FlaxWav2Vec2Module
883
-
884
-
885
- class FlaxWav2Vec2ForCTCModule(nn.Module):
886
- config: Wav2Vec2Config
887
- dtype: jnp.dtype = jnp.float32
888
-
889
- def setup(self):
890
- self.wav2vec2 = FlaxWav2Vec2Module(self.config, dtype=self.dtype)
891
- self.dropout = nn.Dropout(rate=self.config.final_dropout)
892
- self.lm_head = nn.Dense(
893
- self.config.vocab_size,
894
- kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
895
- dtype=self.dtype,
896
- )
897
-
898
- def __call__(
899
- self,
900
- input_values,
901
- attention_mask=None,
902
- mask_time_indices=None,
903
- extract_features=None,
904
- deterministic=True,
905
- output_attentions=None,
906
- output_hidden_states=None,
907
- output_features=False,
908
- freeze_feature_encoder=False,
909
- return_dict=None,
910
- ):
911
- outputs = self.wav2vec2(
912
- input_values,
913
- attention_mask=attention_mask,
914
- mask_time_indices=mask_time_indices,
915
- deterministic=deterministic,
916
- output_attentions=output_attentions,
917
- output_hidden_states=output_hidden_states,
918
- freeze_feature_encoder=freeze_feature_encoder,
919
- return_dict=return_dict,
920
- )
921
-
922
- hidden_states = outputs[0]
923
- hidden_states = self.dropout(hidden_states, deterministic=deterministic)
924
-
925
- logits = self.lm_head(hidden_states)
926
-
927
- if not return_dict:
928
- return (logits,) + outputs[2:]
929
-
930
- return FlaxCausalLMOutput(logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
931
-
932
- def _get_feat_extract_output_lengths(
933
- self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
934
- ):
935
- """
936
- Computes the output length of the convolutional layers
937
- """
938
-
939
- add_adapter = self.config.add_adapter if add_adapter is None else add_adapter
940
-
941
- def _conv_out_length(input_length, kernel_size, stride):
942
- # 1D convolutional layer output length formula taken
943
- # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
944
- return (input_length - kernel_size) // stride + 1
945
-
946
- for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
947
- input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
948
-
949
- if add_adapter:
950
- for _ in range(self.config.num_adapter_layers):
951
- input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)
952
-
953
- return input_lengths
954
-
955
- def _get_feature_vector_attention_mask(
956
- self, feature_vector_length: int, attention_mask: jnp.ndarray, add_adapter=None
957
- ):
958
- # Effectively attention_mask.sum(-1), but not inplace to be able to run
959
- # in inference mode.
960
- non_padded_lengths = attention_mask.cumsum(axis=-1)[:, -1]
961
-
962
- output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)
963
-
964
- batch_size = attention_mask.shape[0]
965
-
966
- attention_mask = jnp.zeros((batch_size, feature_vector_length), dtype=attention_mask.dtype)
967
- # these two operations make sure that all values
968
- # before the output length indices are attended to
969
- attention_mask = attention_mask.at[jnp.arange(attention_mask.shape[0]), output_lengths - 1].set(1)
970
- attention_mask = jnp.flip(jnp.flip(attention_mask, -1).cumsum(-1), -1).astype("bool")
971
- return attention_mask
972
-
973
-
974
- class FlaxWav2Vec2ForCTC(FlaxWav2Vec2PreTrainedModel):
975
- module_class = FlaxWav2Vec2ForCTCModule
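
The `_get_feat_extract_output_lengths` helpers above turn a raw-audio length into a number of feature frames by applying the standard no-padding 1D-convolution output formula once per convolutional layer. A short sketch of that arithmetic, assuming the default Wav2Vec2 kernel and stride values (an assumption; the real values come from this repo's Wav2Vec2Config):

# assumed defaults for illustration only; not read from this checkpoint's config
conv_kernel = (10, 3, 3, 3, 3, 2, 2)
conv_stride = (5, 2, 2, 2, 2, 2, 2)

def conv_out_length(input_length, kernel_size, stride):
    # 1D convolution output length with no padding, as used in the code above
    return (input_length - kernel_size) // stride + 1

length = 16000  # one second of 16 kHz audio
for kernel_size, stride in zip(conv_kernel, conv_stride):
    length = conv_out_length(length, kernel_size, stride)

print(length)  # 49 feature frames under the assumed configuration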