Commit 24717ce
Parent(s): 6e767da

2gs2wia3: saving weights and logs of step 20k

Files changed:
- .gitattributes (+1, -1)
- flax_model.msgpack (+1, -1)
- models/__init__.py (+0, -6)
- models/__pycache__/__init__.cpython-38.pyc (+0, -0)
- models/__pycache__/configuration_bart.cpython-38.pyc (+0, -0)
- models/__pycache__/configuration_speech_encoder_decoder.cpython-38.pyc (+0, -0)
- models/__pycache__/configuration_wav2vec2.cpython-38.pyc (+0, -0)
- models/__pycache__/modeling_flax_bart.cpython-38.pyc (+0, -0)
- models/__pycache__/modeling_flax_speech_encoder_decoder.cpython-38.pyc (+0, -0)
- models/__pycache__/modeling_flax_wav2vec2.cpython-38.pyc (+0, -0)
- models/configuration_bart.py (+0, -183)
- models/configuration_speech_encoder_decoder.py (+0, -121)
- models/configuration_wav2vec2.py (+0, -344)
- models/modeling_flax_bart.py (+0, -816)
- models/modeling_flax_speech_encoder_decoder.py (+0, -1245)
- models/modeling_flax_wav2vec2.py (+0, -975)
.gitattributes CHANGED
@@ -30,4 +30,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
-*.wandb filter=lfs diff=lfs merge=lfs -text
+*.wandb filter=lfs diff=lfs merge=lfs -text
flax_model.msgpack CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:e272f049f32168ca65769e2732f2f3b5ad190de2e141847ce077760ba2a76314
 size 2353616717
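For context, this pointer file is what git-lfs checks into the repository in place of the 2,353,616,717-byte weights file; the `oid` is the SHA-256 digest of the actual object. A minimal sketch of verifying a downloaded checkpoint against the new pointer, using only the standard library (the local path is a placeholder, not part of the commit):

```python
import hashlib

# Placeholder path to the downloaded LFS object in a local checkout.
CHECKPOINT_PATH = "flax_model.msgpack"
# SHA-256 digest taken from the updated LFS pointer above.
EXPECTED_OID = "e272f049f32168ca65769e2732f2f3b5ad190de2e141847ce077760ba2a76314"

def sha256_of_file(path: str, chunk_size: int = 1 << 20) -> str:
    """Stream the file through SHA-256 so the ~2.35 GB checkpoint never sits in memory."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()

assert sha256_of_file(CHECKPOINT_PATH) == EXPECTED_OID, "checkpoint does not match the LFS pointer"
```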
models/__init__.py DELETED
@@ -1,6 +0,0 @@
-from models.configuration_bart import BartConfig
-from models.configuration_wav2vec2 import Wav2Vec2Config
-from models.configuration_speech_encoder_decoder import SpeechEncoderDecoderConfig
-from models.modeling_flax_wav2vec2 import FlaxWav2Vec2Model, FlaxWav2Vec2Module, FlaxWav2Vec2ForCTC, FlaxWav2Vec2ForCTCModule
-from models.modeling_flax_bart import FlaxBartForCausalLM, FlaxBartForCausalLMModule
-from models.modeling_flax_speech_encoder_decoder import FlaxSpeechEncoderDecoderModel
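The deleted `__init__.py` re-exported the repository's custom Flax classes under a single `models` namespace. Code pinned to the parent revision can keep importing them; a minimal sketch, assuming a checkout of parent commit 6e767da where the package still exists (the local path "." is a placeholder):

```python
# Sketch against the parent commit 6e767da, where the `models` package still exists.
from models import FlaxSpeechEncoderDecoderModel, SpeechEncoderDecoderConfig

# Placeholder path to a local checkout of that revision.
model = FlaxSpeechEncoderDecoderModel.from_pretrained(".")
assert isinstance(model.config, SpeechEncoderDecoderConfig)
```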
models/__pycache__/__init__.cpython-38.pyc DELETED
Binary file (744 Bytes)

models/__pycache__/configuration_bart.cpython-38.pyc DELETED
Binary file (7.04 kB)

models/__pycache__/configuration_speech_encoder_decoder.cpython-38.pyc DELETED
Binary file (4.62 kB)

models/__pycache__/configuration_wav2vec2.cpython-38.pyc DELETED
Binary file (16.8 kB)

models/__pycache__/modeling_flax_bart.cpython-38.pyc DELETED
Binary file (21.1 kB)

models/__pycache__/modeling_flax_speech_encoder_decoder.cpython-38.pyc DELETED
Binary file (39.4 kB)

models/__pycache__/modeling_flax_wav2vec2.cpython-38.pyc DELETED
Binary file (30.6 kB)
models/configuration_bart.py DELETED
@@ -1,183 +0,0 @@
-# coding=utf-8
-# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-""" BART model configuration"""
-import warnings
-
-from transformers.configuration_utils import PretrainedConfig
-from transformers.utils import logging
-
-
-logger = logging.get_logger(__name__)
-
-BART_PRETRAINED_CONFIG_ARCHIVE_MAP = {
-    "facebook/bart-large": "https://huggingface.co/facebook/bart-large/resolve/main/config.json",
-    # See all BART models at https://huggingface.co/models?filter=bart
-}
-
-
-class BartConfig(PretrainedConfig):
-    r"""
-    This is the configuration class to store the configuration of a [`BartModel`]. It is used to instantiate a BART
-    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
-    defaults will yield a similar configuration to that of the BART
-    [facebook/bart-large](https://huggingface.co/facebook/bart-large) architecture.
-
-    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
-    documentation from [`PretrainedConfig`] for more information.
-
-
-    Args:
-        vocab_size (`int`, *optional*, defaults to 50265):
-            Vocabulary size of the BART model. Defines the number of different tokens that can be represented by the
-            `inputs_ids` passed when calling [`BartModel`] or [`TFBartModel`].
-        d_model (`int`, *optional*, defaults to 1024):
-            Dimensionality of the layers and the pooler layer.
-        encoder_layers (`int`, *optional*, defaults to 12):
-            Number of encoder layers.
-        decoder_layers (`int`, *optional*, defaults to 12):
-            Number of decoder layers.
-        encoder_attention_heads (`int`, *optional*, defaults to 16):
-            Number of attention heads for each attention layer in the Transformer encoder.
-        decoder_attention_heads (`int`, *optional*, defaults to 16):
-            Number of attention heads for each attention layer in the Transformer decoder.
-        decoder_ffn_dim (`int`, *optional*, defaults to 4096):
-            Dimensionality of the "intermediate" (often named feed-forward) layer in the decoder.
-        encoder_ffn_dim (`int`, *optional*, defaults to 4096):
-            Dimensionality of the "intermediate" (often named feed-forward) layer in the encoder.
-        activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
-            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
-            `"relu"`, `"silu"` and `"gelu_new"` are supported.
-        dropout (`float`, *optional*, defaults to 0.1):
-            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
-        attention_dropout (`float`, *optional*, defaults to 0.0):
-            The dropout ratio for the attention probabilities.
-        activation_dropout (`float`, *optional*, defaults to 0.0):
-            The dropout ratio for activations inside the fully connected layer.
-        classifier_dropout (`float`, *optional*, defaults to 0.0):
-            The dropout ratio for the classifier.
-        max_position_embeddings (`int`, *optional*, defaults to 1024):
-            The maximum sequence length that this model might ever be used with. Typically set this to something large
-            just in case (e.g., 512 or 1024 or 2048).
-        init_std (`float`, *optional*, defaults to 0.02):
-            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
-        encoder_layerdrop (`float`, *optional*, defaults to 0.0):
-            The LayerDrop probability for the encoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
-            for more details.
-        decoder_layerdrop (`float`, *optional*, defaults to 0.0):
-            The LayerDrop probability for the decoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
-            for more details.
-        scale_embedding (`bool`, *optional*, defaults to `False`):
-            Scale embeddings by dividing by sqrt(d_model).
-        use_cache (`bool`, *optional*, defaults to `True`):
-            Whether or not the model should return the last key/values attentions (not used by all models).
-        num_labels (`int`, *optional*, defaults to 3):
-            The number of labels to use in [`BartForSequenceClassification`].
-        forced_eos_token_id (`int`, *optional*, defaults to 2):
-            The id of the token to force as the last generated token when `max_length` is reached. Usually set to
-            `eos_token_id`.
-        use_scan (`bool`, *optional*, defaults to `False`):
-            Whether or not to use nn.scan in the Flax Bart attention layers.
-
-    Example:
-
-    ```python
-    >>> from transformers import BartModel, BartConfig
-
-    >>> # Initializing a BART facebook/bart-large style configuration
-    >>> configuration = BartConfig()
-
-    >>> # Initializing a model from the facebook/bart-large style configuration
-    >>> model = BartModel(configuration)
-
-    >>> # Accessing the model configuration
-    >>> configuration = model.config
-    ```"""
-    model_type = "bart"
-    keys_to_ignore_at_inference = ["past_key_values"]
-    attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}
-
-    def __init__(
-        self,
-        vocab_size=50265,
-        max_position_embeddings=1024,
-        encoder_layers=12,
-        encoder_ffn_dim=4096,
-        encoder_attention_heads=16,
-        decoder_layers=12,
-        decoder_ffn_dim=4096,
-        decoder_attention_heads=16,
-        encoder_layerdrop=0.0,
-        decoder_layerdrop=0.0,
-        activation_function="gelu",
-        d_model=1024,
-        dropout=0.1,
-        attention_dropout=0.0,
-        activation_dropout=0.0,
-        init_std=0.02,
-        classifier_dropout=0.0,
-        scale_embedding=False,
-        use_cache=True,
-        use_scan=False,
-        fuse_matmuls=False,
-        num_labels=3,
-        pad_token_id=1,
-        bos_token_id=0,
-        eos_token_id=2,
-        is_encoder_decoder=True,
-        decoder_start_token_id=2,
-        forced_eos_token_id=2,
-        **kwargs
-    ):
-        self.vocab_size = vocab_size
-        self.max_position_embeddings = max_position_embeddings
-        self.d_model = d_model
-        self.encoder_ffn_dim = encoder_ffn_dim
-        self.encoder_layers = encoder_layers
-        self.encoder_attention_heads = encoder_attention_heads
-        self.decoder_ffn_dim = decoder_ffn_dim
-        self.decoder_layers = decoder_layers
-        self.decoder_attention_heads = decoder_attention_heads
-        self.dropout = dropout
-        self.attention_dropout = attention_dropout
-        self.activation_dropout = activation_dropout
-        self.activation_function = activation_function
-        self.init_std = init_std
-        self.encoder_layerdrop = encoder_layerdrop
-        self.decoder_layerdrop = decoder_layerdrop
-        self.classifier_dropout = classifier_dropout
-        self.use_cache = use_cache
-        self.use_scan = use_scan
-        self.fuse_matmuls = fuse_matmuls
-        self.num_hidden_layers = encoder_layers
-        self.scale_embedding = scale_embedding  # scale factor will be sqrt(d_model) if True
-
-        super().__init__(
-            num_labels=num_labels,
-            pad_token_id=pad_token_id,
-            bos_token_id=bos_token_id,
-            eos_token_id=eos_token_id,
-            is_encoder_decoder=is_encoder_decoder,
-            decoder_start_token_id=decoder_start_token_id,
-            forced_eos_token_id=forced_eos_token_id,
-            **kwargs,
-        )
-
-        # ensure backward compatibility for BART CNN models
-        if self.forced_bos_token_id is None and kwargs.get("force_bos_token_to_be_generated", False):
-            self.forced_bos_token_id = self.bos_token_id
-            warnings.warn(
-                f"Please make sure the config includes `forced_bos_token_id={self.bos_token_id}` in future versions. "
-                "The config can simply be saved and uploaded again to be fixed."
-            )
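Beyond tracking the stock `transformers` `BartConfig`, the deleted class adds two Flax-specific flags, `use_scan` and `fuse_matmuls`, and mirrors `encoder_layers` into `num_hidden_layers`. A short sketch of how they surface, again assuming a checkout of parent commit 6e767da where the `models` package still exists:

```python
from models import BartConfig  # available at parent commit 6e767da

config = BartConfig(use_scan=True, fuse_matmuls=True)
assert config.use_scan and config.fuse_matmuls
assert config.num_hidden_layers == config.encoder_layers  # mirrored in __init__ above
```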
models/configuration_speech_encoder_decoder.py DELETED
@@ -1,121 +0,0 @@
-# coding=utf-8
-# Copyright 2021 The HuggingFace Inc. team.
-# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import copy
-
-from transformers.configuration_utils import PretrainedConfig
-from transformers.utils import logging
-from models.configuration_wav2vec2 import Wav2Vec2Config
-from models.configuration_bart import BartConfig
-from transformers import AutoConfig
-
-
-logger = logging.get_logger(__name__)
-
-
-class SpeechEncoderDecoderConfig(PretrainedConfig):
-    r"""
-    [`SpeechEncoderDecoderConfig`] is the configuration class to store the configuration of a
-    [`SpeechEncoderDecoderModel`]. It is used to instantiate an Encoder Decoder model according to the specified
-    arguments, defining the encoder and decoder configs.
-
-    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
-    documentation from [`PretrainedConfig`] for more information.
-
-    Args:
-        kwargs (*optional*):
-            Dictionary of keyword arguments. Notably:
-
-                - **encoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that defines
-                  the encoder config.
-                - **decoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that defines
-                  the decoder config.
-
-    Examples:
-
-    ```python
-    >>> from transformers import BertConfig, Wav2Vec2Config, SpeechEncoderDecoderConfig, SpeechEncoderDecoderModel
-
-    >>> # Initializing a Wav2Vec2 & BERT style configuration
-    >>> config_encoder = Wav2Vec2Config()
-    >>> config_decoder = BertConfig()
-
-    >>> config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(config_encoder, config_decoder)
-
-    >>> # Initializing a Wav2Vec2Bert model from Wav2Vec2 & bert-base-uncased style configurations
-    >>> model = SpeechEncoderDecoderModel(config=config)
-
-    >>> # Accessing the model configuration
-    >>> config_encoder = model.config.encoder
-    >>> config_decoder = model.config.decoder
-    >>> # set decoder config to causal lm
-    >>> config_decoder.is_decoder = True
-    >>> config_decoder.add_cross_attention = True
-
-    >>> # Saving the model, including its configuration
-    >>> model.save_pretrained("my-model")
-
-    >>> # loading model and config from pretrained folder
-    >>> encoder_decoder_config = SpeechEncoderDecoderConfig.from_pretrained("my-model")
-    >>> model = SpeechEncoderDecoderModel.from_pretrained("my-model", config=encoder_decoder_config)
-    ```"""
-    model_type = "speech-encoder-decoder"
-    is_composition = True
-
-    def __init__(self, **kwargs):
-        super().__init__(**kwargs)
-        if "encoder" not in kwargs or "decoder" not in kwargs:
-            raise ValueError(
-                f"A configuration of type {self.model_type} cannot be instantiated because not both `encoder` and `decoder` sub-configurations are passed, but only {kwargs}"
-            )
-
-        encoder_config = kwargs.pop("encoder")
-        decoder_config = kwargs.pop("decoder")
-
-        # TODO: Load configs from AutoConfig (as done in Transformers 🤗)
-        self.encoder = Wav2Vec2Config(**encoder_config)
-        self.decoder = BartConfig(**decoder_config)
-        self.is_encoder_decoder = True
-
-    @classmethod
-    def from_encoder_decoder_configs(
-        cls, encoder_config: PretrainedConfig, decoder_config: PretrainedConfig, **kwargs
-    ) -> PretrainedConfig:
-        r"""
-        Instantiate a [`SpeechEncoderDecoderConfig`] (or a derived class) from a pre-trained encoder model
-        configuration and decoder model configuration.
-
-        Returns:
-            [`SpeechEncoderDecoderConfig`]: An instance of a configuration object
-        """
-        logger.info("Setting `config.is_decoder=True` and `config.add_cross_attention=True` for decoder_config")
-        decoder_config.is_decoder = True
-        decoder_config.add_cross_attention = True
-
-        return cls(encoder=encoder_config.to_dict(), decoder=decoder_config.to_dict(), **kwargs)
-
-    def to_dict(self):
-        """
-        Serializes this instance to a Python dictionary. Override the default *to_dict()* from *PretrainedConfig*.
-
-        Returns:
-            `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
-        """
-        output = copy.deepcopy(self.__dict__)
-        output["encoder"] = self.encoder.to_dict()
-        output["decoder"] = self.decoder.to_dict()
-        output["model_type"] = self.__class__.model_type
-        return output
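Note that this composite config requires both sub-configs and hard-codes their types to the local `Wav2Vec2Config`/`BartConfig` rather than resolving them via `AutoConfig` (per the TODO above). A sketch of the usual construction path through `from_encoder_decoder_configs`, assuming a checkout of parent commit 6e767da:

```python
from models import BartConfig, SpeechEncoderDecoderConfig, Wav2Vec2Config  # parent 6e767da

config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(Wav2Vec2Config(), BartConfig())
# from_encoder_decoder_configs flips the decoder into causal-LM mode:
assert config.decoder.is_decoder and config.decoder.add_cross_attention
assert config.is_encoder_decoder
```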
models/configuration_wav2vec2.py DELETED
@@ -1,344 +0,0 @@
-# coding=utf-8
-# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-""" Wav2Vec2 model configuration"""
-
-import functools
-import operator
-
-from transformers.configuration_utils import PretrainedConfig
-from transformers.utils import logging
-
-
-logger = logging.get_logger(__name__)
-
-WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP = {
-    "facebook/wav2vec2-base-960h": "https://huggingface.co/facebook/wav2vec2-base-960h/resolve/main/config.json",
-    # See all Wav2Vec2 models at https://huggingface.co/models?filter=wav2vec2
-}
-
-
-class Wav2Vec2Config(PretrainedConfig):
-    r"""
-    This is the configuration class to store the configuration of a [`Wav2Vec2Model`]. It is used to instantiate a
-    Wav2Vec2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
-    with the defaults will yield a similar configuration to that of the Wav2Vec2
-    [facebook/wav2vec2-base-960h](https://huggingface.co/facebook/wav2vec2-base-960h) architecture.
-
-    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
-    documentation from [`PretrainedConfig`] for more information.
-
-
-    Args:
-        vocab_size (`int`, *optional*, defaults to 32):
-            Vocabulary size of the Wav2Vec2 model. Defines the number of different tokens that can be represented by
-            the `inputs_ids` passed when calling [`Wav2Vec2Model`] or [`TFWav2Vec2Model`].
-        hidden_size (`int`, *optional*, defaults to 768):
-            Dimensionality of the encoder layers and the pooler layer.
-        num_hidden_layers (`int`, *optional*, defaults to 12):
-            Number of hidden layers in the Transformer encoder.
-        num_attention_heads (`int`, *optional*, defaults to 12):
-            Number of attention heads for each attention layer in the Transformer encoder.
-        intermediate_size (`int`, *optional*, defaults to 3072):
-            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
-        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
-            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
-            `"relu"`, `"selu"` and `"gelu_new"` are supported.
-        hidden_dropout (`float`, *optional*, defaults to 0.1):
-            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
-        attention_dropout (`float`, *optional*, defaults to 0.1):
-            The dropout ratio for the attention probabilities.
-        final_dropout (`float`, *optional*, defaults to 0.1):
-            The dropout probability for the final projection layer of [`Wav2Vec2ForCTC`].
-        initializer_range (`float`, *optional*, defaults to 0.02):
-            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
-        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
-            The epsilon used by the layer normalization layers.
-        feat_extract_norm (`str`, *optional*, defaults to `"group"`):
-            The norm to be applied to 1D convolutional layers in the feature encoder. One of `"group"` for group
-            normalization of only the first 1D convolutional layer or `"layer"` for layer normalization of all 1D
-            convolutional layers.
-        feat_proj_dropout (`float`, *optional*, defaults to 0.0):
-            The dropout probability for the output of the feature encoder.
-        feat_extract_activation (`str`, *optional*, defaults to `"gelu"`):
-            The non-linear activation function (function or string) in the 1D convolutional layers of the feature
-            extractor. If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"` are supported.
-        feat_quantizer_dropout (`float`, *optional*, defaults to 0.0):
-            The dropout probability for quantized feature encoder states.
-        conv_dim (`Tuple[int]`, *optional*, defaults to `(512, 512, 512, 512, 512, 512, 512)`):
-            A tuple of integers defining the number of input and output channels of each 1D convolutional layer in the
-            feature encoder. The length of *conv_dim* defines the number of 1D convolutional layers.
-        conv_stride (`Tuple[int]`, *optional*, defaults to `(5, 2, 2, 2, 2, 2, 2)`):
-            A tuple of integers defining the stride of each 1D convolutional layer in the feature encoder. The length
-            of *conv_stride* defines the number of convolutional layers and has to match the length of *conv_dim*.
-        conv_kernel (`Tuple[int]`, *optional*, defaults to `(10, 3, 3, 3, 3, 2, 2)`):
-            A tuple of integers defining the kernel size of each 1D convolutional layer in the feature encoder. The
-            length of *conv_kernel* defines the number of convolutional layers and has to match the length of
-            *conv_dim*.
-        conv_bias (`bool`, *optional*, defaults to `False`):
-            Whether the 1D convolutional layers have a bias.
-        num_conv_pos_embeddings (`int`, *optional*, defaults to 128):
-            Number of convolutional positional embeddings. Defines the kernel size of the 1D convolutional positional
-            embeddings layer.
-        num_conv_pos_embedding_groups (`int`, *optional*, defaults to 16):
-            Number of groups of the 1D convolutional positional embeddings layer.
-        do_stable_layer_norm (`bool`, *optional*, defaults to `False`):
-            Whether to apply the *stable* layer norm architecture of the Transformer encoder. `do_stable_layer_norm is
-            True` corresponds to applying layer norm before the attention layer, whereas `do_stable_layer_norm is
-            False` corresponds to applying layer norm after the attention layer.
-        apply_spec_augment (`bool`, *optional*, defaults to `True`):
-            Whether to apply *SpecAugment* data augmentation to the outputs of the feature encoder. For reference see
-            [SpecAugment: A Simple Data Augmentation Method for Automatic Speech
-            Recognition](https://arxiv.org/abs/1904.08779).
-        mask_time_prob (`float`, *optional*, defaults to 0.05):
-            Percentage (between 0 and 1) of all feature vectors along the time axis which will be masked. The masking
-            procedure generates `mask_time_prob*len(time_axis)/mask_time_length` independent masks over the axis. If
-            reasoning from the probability of each feature vector to be chosen as the start of the vector span to be
-            masked, *mask_time_prob* should be `prob_vector_start*mask_time_length`. Note that overlap may decrease the
-            actual percentage of masked vectors. This is only relevant if `apply_spec_augment is True`.
-        mask_time_length (`int`, *optional*, defaults to 10):
-            Length of vector span along the time axis.
-        mask_time_min_masks (`int`, *optional*, defaults to 2):
-            The minimum number of masks of length `mask_time_length` generated along the time axis, each time step,
-            irrespectively of `mask_time_prob`. Only relevant if `mask_time_prob*len(time_axis)/mask_time_length <
-            mask_time_min_masks`.
-        mask_feature_prob (`float`, *optional*, defaults to 0.0):
-            Percentage (between 0 and 1) of all feature vectors along the feature axis which will be masked. The
-            masking procedure generates `mask_feature_prob*len(feature_axis)/mask_feature_length` independent masks
-            over the axis. If reasoning from the probability of each feature vector to be chosen as the start of the
-            vector span to be masked, *mask_feature_prob* should be `prob_vector_start*mask_feature_length`. Note that
-            overlap may decrease the actual percentage of masked vectors. This is only relevant if `apply_spec_augment
-            is True`.
-        mask_feature_length (`int`, *optional*, defaults to 10):
-            Length of vector span along the feature axis.
-        mask_feature_min_masks (`int`, *optional*, defaults to 0):
-            The minimum number of masks of length `mask_feature_length` generated along the feature axis, each time
-            step, irrespectively of `mask_feature_prob`. Only relevant if
-            `mask_feature_prob*len(feature_axis)/mask_feature_length < mask_feature_min_masks`.
-        num_codevectors_per_group (`int`, *optional*, defaults to 320):
-            Number of entries in each quantization codebook (group).
-        num_codevector_groups (`int`, *optional*, defaults to 2):
-            Number of codevector groups for product codevector quantization.
-        contrastive_logits_temperature (`float`, *optional*, defaults to 0.1):
-            The temperature *kappa* in the contrastive loss.
-        feat_quantizer_dropout (`float`, *optional*, defaults to 0.0):
-            The dropout probability for the output of the feature encoder that's used by the quantizer.
-        num_negatives (`int`, *optional*, defaults to 100):
-            Number of negative samples for the contrastive loss.
-        codevector_dim (`int`, *optional*, defaults to 256):
-            Dimensionality of the quantized feature vectors.
-        proj_codevector_dim (`int`, *optional*, defaults to 256):
-            Dimensionality of the final projection of both the quantized and the transformer features.
-        diversity_loss_weight (`float`, *optional*, defaults to 0.1):
-            The weight of the codebook diversity loss component.
-        ctc_loss_reduction (`str`, *optional*, defaults to `"sum"`):
-            Specifies the reduction to apply to the output of `torch.nn.CTCLoss`. Only relevant when training an
-            instance of [`Wav2Vec2ForCTC`].
-        ctc_zero_infinity (`bool`, *optional*, defaults to `False`):
-            Whether to zero infinite losses and the associated gradients of `torch.nn.CTCLoss`. Infinite losses mainly
-            occur when the inputs are too short to be aligned to the targets. Only relevant when training an instance
-            of [`Wav2Vec2ForCTC`].
-        use_weighted_layer_sum (`bool`, *optional*, defaults to `False`):
-            Whether to use a weighted average of layer outputs with learned weights. Only relevant when using an
-            instance of [`Wav2Vec2ForSequenceClassification`].
-        classifier_proj_size (`int`, *optional*, defaults to 256):
-            Dimensionality of the projection before token mean-pooling for classification.
-        tdnn_dim (`Tuple[int]`, *optional*, defaults to `(512, 512, 512, 512, 1500)`):
-            A tuple of integers defining the number of output channels of each 1D convolutional layer in the *TDNN*
-            module of the *XVector* model. The length of *tdnn_dim* defines the number of *TDNN* layers.
-        tdnn_kernel (`Tuple[int]`, *optional*, defaults to `(5, 3, 3, 1, 1)`):
-            A tuple of integers defining the kernel size of each 1D convolutional layer in the *TDNN* module of the
-            *XVector* model. The length of *tdnn_kernel* has to match the length of *tdnn_dim*.
-        tdnn_dilation (`Tuple[int]`, *optional*, defaults to `(1, 2, 3, 1, 1)`):
-            A tuple of integers defining the dilation factor of each 1D convolutional layer in the *TDNN* module of the
-            *XVector* model. The length of *tdnn_dilation* has to match the length of *tdnn_dim*.
-        xvector_output_dim (`int`, *optional*, defaults to 512):
-            Dimensionality of the *XVector* embedding vectors.
-        add_adapter (`bool`, *optional*, defaults to `False`):
-            Whether a convolutional network should be stacked on top of the Wav2Vec2 Encoder. Can be very useful for
-            warm-starting Wav2Vec2 for SpeechEncoderDecoder models.
-        adapter_kernel_size (`int`, *optional*, defaults to 3):
-            Kernel size of the convolutional layers in the adapter network. Only relevant if `add_adapter is True`.
-        adapter_stride (`int`, *optional*, defaults to 2):
-            Stride of the convolutional layers in the adapter network. Only relevant if `add_adapter is True`.
-        num_adapter_layers (`int`, *optional*, defaults to 3):
-            Number of convolutional layers that should be used in the adapter network. Only relevant if `add_adapter is
-            True`.
-        output_hidden_size (`int`, *optional*):
-            Dimensionality of the encoder output layer. If not defined, this defaults to *hidden_size*. Only relevant
-            if `add_adapter is True`.
-        use_scan (`bool`, *optional*, defaults to `False`):
-            Whether or not to use nn.scan in the Flax Wav2Vec2 transformer layers.
-
-    Example:
-
-    ```python
-    >>> from transformers import Wav2Vec2Model, Wav2Vec2Config
-
-    >>> # Initializing a Wav2Vec2 facebook/wav2vec2-base-960h style configuration
-    >>> configuration = Wav2Vec2Config()
-
-    >>> # Initializing a model from the facebook/wav2vec2-base-960h style configuration
-    >>> model = Wav2Vec2Model(configuration)

-    >>> # Accessing the model configuration
-    >>> configuration = model.config
-    ```"""
-    model_type = "wav2vec2"
-
-    def __init__(
-        self,
-        vocab_size=32,
-        hidden_size=768,
-        num_hidden_layers=12,
-        num_attention_heads=12,
-        intermediate_size=3072,
-        hidden_act="gelu",
-        hidden_dropout=0.1,
-        activation_dropout=0.1,
-        attention_dropout=0.1,
-        feat_proj_dropout=0.0,
-        feat_quantizer_dropout=0.0,
-        final_dropout=0.1,
-        layerdrop=0.1,
-        initializer_range=0.02,
-        layer_norm_eps=1e-5,
-        feat_extract_norm="group",
-        feat_extract_activation="gelu",
-        conv_dim=(512, 512, 512, 512, 512, 512, 512),
-        conv_stride=(5, 2, 2, 2, 2, 2, 2),
-        conv_kernel=(10, 3, 3, 3, 3, 2, 2),
-        conv_bias=False,
-        num_conv_pos_embeddings=128,
-        num_conv_pos_embedding_groups=16,
-        do_stable_layer_norm=False,
-        apply_spec_augment=True,
-        mask_time_prob=0.05,
-        mask_time_length=10,
-        mask_time_min_masks=2,
-        mask_feature_prob=0.0,
-        mask_feature_length=10,
-        mask_feature_min_masks=0,
-        num_codevectors_per_group=320,
-        num_codevector_groups=2,
-        contrastive_logits_temperature=0.1,
-        num_negatives=100,
-        codevector_dim=256,
-        proj_codevector_dim=256,
-        diversity_loss_weight=0.1,
-        ctc_loss_reduction="sum",
-        ctc_zero_infinity=False,
-        use_weighted_layer_sum=False,
-        classifier_proj_size=256,
-        tdnn_dim=(512, 512, 512, 512, 1500),
-        tdnn_kernel=(5, 3, 3, 1, 1),
-        tdnn_dilation=(1, 2, 3, 1, 1),
-        xvector_output_dim=512,
-        pad_token_id=0,
-        bos_token_id=1,
-        eos_token_id=2,
-        add_adapter=False,
-        adapter_kernel_size=3,
-        adapter_stride=2,
-        num_adapter_layers=3,
-        output_hidden_size=None,
-        use_scan=False,
-        fuse_matmuls=False,
-        **kwargs
-    ):
-        super().__init__(**kwargs, pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id)
-        self.hidden_size = hidden_size
-        self.feat_extract_norm = feat_extract_norm
-        self.feat_extract_activation = feat_extract_activation
-        self.conv_dim = list(conv_dim)
-        self.conv_stride = list(conv_stride)
-        self.conv_kernel = list(conv_kernel)
-        self.conv_bias = conv_bias
-        self.num_conv_pos_embeddings = num_conv_pos_embeddings
-        self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups
-        self.num_feat_extract_layers = len(self.conv_dim)
-        self.num_hidden_layers = num_hidden_layers
-        self.intermediate_size = intermediate_size
-        self.hidden_act = hidden_act
-        self.num_attention_heads = num_attention_heads
-        self.hidden_dropout = hidden_dropout
-        self.attention_dropout = attention_dropout
-        self.activation_dropout = activation_dropout
-        self.feat_proj_dropout = feat_proj_dropout
-        self.final_dropout = final_dropout
-        self.layerdrop = layerdrop
-        self.layer_norm_eps = layer_norm_eps
-        self.initializer_range = initializer_range
-        self.vocab_size = vocab_size
-        self.do_stable_layer_norm = do_stable_layer_norm
-        self.use_weighted_layer_sum = use_weighted_layer_sum
-        self.use_scan = use_scan
-        self.fuse_matmuls = fuse_matmuls
-
-        if (
-            (len(self.conv_stride) != self.num_feat_extract_layers)
-            or (len(self.conv_kernel) != self.num_feat_extract_layers)
-            or (len(self.conv_dim) != self.num_feat_extract_layers)
-        ):
-            raise ValueError(
-                "Configuration for convolutional layers is incorrect. "
-                "It is required that `len(config.conv_dim)` == `len(config.conv_stride)` == `len(config.conv_kernel)`, "
-                f"but is `len(config.conv_dim) = {len(self.conv_dim)}`, `len(config.conv_stride) "
-                f"= {len(self.conv_stride)}`, `len(config.conv_kernel) = {len(self.conv_kernel)}`."
-            )
-
-        # fine-tuning config parameters for SpecAugment: https://arxiv.org/abs/1904.08779
-        self.apply_spec_augment = apply_spec_augment
-        self.mask_time_prob = mask_time_prob
-        self.mask_time_length = mask_time_length
-        self.mask_time_min_masks = mask_time_min_masks
-        self.mask_feature_prob = mask_feature_prob
-        self.mask_feature_length = mask_feature_length
-        self.mask_feature_min_masks = mask_feature_min_masks
-
-        # parameters for pretraining with codevector quantized representations
-        self.num_codevectors_per_group = num_codevectors_per_group
-        self.num_codevector_groups = num_codevector_groups
-        self.contrastive_logits_temperature = contrastive_logits_temperature
-        self.feat_quantizer_dropout = feat_quantizer_dropout
-        self.num_negatives = num_negatives
-        self.codevector_dim = codevector_dim
-        self.proj_codevector_dim = proj_codevector_dim
-        self.diversity_loss_weight = diversity_loss_weight
-
-        # ctc loss
-        self.ctc_loss_reduction = ctc_loss_reduction
-        self.ctc_zero_infinity = ctc_zero_infinity
-
-        # adapter
-        self.add_adapter = add_adapter
-        self.adapter_kernel_size = adapter_kernel_size
-        self.adapter_stride = adapter_stride
-        self.num_adapter_layers = num_adapter_layers
-        self.output_hidden_size = output_hidden_size or hidden_size
-
-        # SequenceClassification-specific parameter. Feel free to ignore for other classes.
-        self.classifier_proj_size = classifier_proj_size
-
-        # XVector-specific parameters. Feel free to ignore for other classes.
-        self.tdnn_dim = list(tdnn_dim)
-        self.tdnn_kernel = list(tdnn_kernel)
-        self.tdnn_dilation = list(tdnn_dilation)
-        self.xvector_output_dim = xvector_output_dim
-
-    @property
-    def inputs_to_logits_ratio(self):
-        return functools.reduce(operator.mul, self.conv_stride, 1)
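One detail worth calling out from the deleted file: `inputs_to_logits_ratio` is the product of the feature-encoder strides, i.e. the factor by which raw audio samples are downsampled before the transformer. A small sketch with the defaults above (again against parent commit 6e767da, where the `models` package still exists):

```python
from models import Wav2Vec2Config  # available at parent commit 6e767da

config = Wav2Vec2Config()  # default conv_stride = (5, 2, 2, 2, 2, 2, 2)
# 5 * 2**6 == 320 input samples per output frame (~20 ms at 16 kHz)
assert config.inputs_to_logits_ratio == 320
```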
models/modeling_flax_bart.py
DELETED
@@ -1,816 +0,0 @@
|
|
1 |
-
# coding=utf-8
|
2 |
-
# Copyright 2021 The Fairseq Authors and The Google Flax Team Authors And The HuggingFace Inc. team. All rights reserved.
|
3 |
-
#
|
4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
-
# you may not use this file except in compliance with the License.
|
6 |
-
# You may obtain a copy of the License at
|
7 |
-
#
|
8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
-
#
|
10 |
-
# Unless required by applicable law or agreed to in writing, software
|
11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
-
# See the License for the specific language governing permissions and
|
14 |
-
# limitations under the License.
|
15 |
-
""" Flax Bart model."""
|
16 |
-
|
17 |
-
import math
|
18 |
-
import random
|
19 |
-
from functools import partial
|
20 |
-
from typing import Optional, Tuple
|
21 |
-
|
22 |
-
import numpy as np
|
23 |
-
|
24 |
-
import flax.linen as nn
|
25 |
-
import jax
|
26 |
-
import jax.numpy as jnp
|
27 |
-
from flax.core.frozen_dict import FrozenDict, unfreeze
|
28 |
-
from flax.linen import combine_masks, make_causal_mask
|
29 |
-
from flax.linen import partitioning as nn_partitioning
|
30 |
-
from flax.linen.attention import dot_product_attention_weights
|
31 |
-
from jax import lax
|
32 |
-
from jax.random import PRNGKey
|
33 |
-
|
34 |
-
from transformers.modeling_flax_outputs import (
|
35 |
-
FlaxBaseModelOutputWithPastAndCrossAttentions,
|
36 |
-
FlaxCausalLMOutputWithCrossAttentions,
|
37 |
-
)
|
38 |
-
from transformers.modeling_flax_utils import ACT2FN, FlaxPreTrainedModel
|
39 |
-
|
40 |
-
from models import BartConfig
|
41 |
-
|
42 |
-
|
43 |
-
scan_with_axes = nn_partitioning.scan_with_axes
|
44 |
-
remat = nn_partitioning.remat
|
45 |
-
|
46 |
-
|
47 |
-
def shift_tokens_right(input_ids: np.array, pad_token_id: int, decoder_start_token_id: int) -> np.ndarray:
|
48 |
-
"""
|
49 |
-
Shift input ids one token to the right.
|
50 |
-
"""
|
51 |
-
shifted_input_ids = np.zeros_like(input_ids)
|
52 |
-
shifted_input_ids[:, 1:] = input_ids[:, :-1]
|
53 |
-
shifted_input_ids[:, 0] = decoder_start_token_id
|
54 |
-
|
55 |
-
shifted_input_ids = np.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
|
56 |
-
return shifted_input_ids
|
57 |
-
|
58 |
-
|
59 |
-
class FlaxBartAttention(nn.Module):
|
60 |
-
config: BartConfig
|
61 |
-
embed_dim: int
|
62 |
-
num_heads: int
|
63 |
-
dropout: float = 0.0
|
64 |
-
causal: bool = False
|
65 |
-
bias: bool = True
|
66 |
-
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
67 |
-
|
68 |
-
def setup(self) -> None:
|
69 |
-
self.head_dim = self.embed_dim // self.num_heads
|
70 |
-
if self.head_dim * self.num_heads != self.embed_dim:
|
71 |
-
raise ValueError(
|
72 |
-
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
|
73 |
-
f" and `num_heads`: {self.num_heads})."
|
74 |
-
)
|
75 |
-
|
76 |
-
dense = partial(
|
77 |
-
nn.Dense,
|
78 |
-
self.embed_dim,
|
79 |
-
use_bias=self.bias,
|
80 |
-
dtype=self.dtype,
|
81 |
-
kernel_init=jax.nn.initializers.normal(self.config.init_std),
|
82 |
-
)
|
83 |
-
|
84 |
-
self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
|
85 |
-
|
86 |
-
self.fused_proj = nn.Dense(
|
87 |
-
self.embed_dim * 3,
|
88 |
-
use_bias=self.bias,
|
89 |
-
dtype=self.dtype,
|
90 |
-
kernel_init=jax.nn.initializers.normal(self.config.init_std),
|
91 |
-
)
|
92 |
-
|
93 |
-
self.fused_key_value = nn.Dense(
|
94 |
-
self.embed_dim * 2,
|
95 |
-
use_bias=self.bias,
|
96 |
-
dtype=self.dtype,
|
97 |
-
kernel_init=jax.nn.initializers.normal(self.config.init_std),
|
98 |
-
)
|
99 |
-
|
100 |
-
self.out_proj = dense()
|
101 |
-
|
102 |
-
self.dropout_layer = nn.Dropout(rate=self.dropout)
|
103 |
-
|
104 |
-
if self.causal:
|
105 |
-
self.causal_mask = make_causal_mask(
|
106 |
-
jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool"
|
107 |
-
)
|
108 |
-
|
109 |
-
def _split_heads(self, hidden_states):
|
110 |
-
return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))
|
111 |
-
|
112 |
-
def _merge_heads(self, hidden_states):
|
113 |
-
return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))
|
114 |
-
|
115 |
-
@nn.compact
|
116 |
-
def _concatenate_to_cache(self, key, value, query, attention_mask):
|
117 |
-
"""
|
118 |
-
This function takes projected key, value states from a single input token and concatenates the states to cached
|
119 |
-
states from previous steps. This function is slighly adapted from the official Flax repository:
|
120 |
-
https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
|
121 |
-
"""
|
122 |
-
# detect if we're initializing by absence of existing cache data.
|
123 |
-
is_initialized = self.has_variable("cache", "cached_key")
|
124 |
-
cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
|
125 |
-
cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
|
126 |
-
cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
|
127 |
-
|
128 |
-
if is_initialized:
|
129 |
-
*batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
|
130 |
-
# update key, value caches with our new 1d spatial slices
|
131 |
-
cur_index = cache_index.value
|
132 |
-
indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
|
133 |
-
key = lax.dynamic_update_slice(cached_key.value, key, indices)
|
134 |
-
value = lax.dynamic_update_slice(cached_value.value, value, indices)
|
135 |
-
cached_key.value = key
|
136 |
-
cached_value.value = value
|
137 |
-
num_updated_cache_vectors = query.shape[1]
|
138 |
-
cache_index.value = cache_index.value + num_updated_cache_vectors
|
139 |
-
# causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements.
|
140 |
-
pad_mask = jnp.broadcast_to(
|
141 |
-
jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
|
142 |
-
tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
|
143 |
-
)
|
144 |
-
attention_mask = combine_masks(pad_mask, attention_mask)
|
145 |
-
return key, value, attention_mask
|
146 |
-
|
147 |
-
def __call__(
|
148 |
-
self,
|
149 |
-
hidden_states: jnp.ndarray,
|
150 |
-
key_value_states: Optional[jnp.ndarray] = None,
|
151 |
-
attention_mask: Optional[jnp.ndarray] = None,
|
152 |
-
init_cache: bool = False,
|
153 |
-
deterministic: bool = True,
|
154 |
-
) -> Tuple[jnp.ndarray]:
|
155 |
-
"""Input shape: Batch x Time x Channel"""
|
156 |
-
|
157 |
-
# if key_value_states are provided this layer is used as a cross-attention layer
|
158 |
-
# for the decoder
|
159 |
-
is_cross_attention = key_value_states is not None
|
160 |
-
batch_size = hidden_states.shape[0]
|
161 |
-
|
162 |
-
if self.config.fuse_matmuls:
|
163 |
-
# get key, value proj
|
164 |
-
if is_cross_attention:
|
165 |
-
# get query proj
|
166 |
-
query_states = self.q_proj(hidden_states)
|
167 |
-
# cross_attentions
|
168 |
-
attention_states = self.fused_key_value(key_value_states)
|
169 |
-
key_states, value_states = jnp.split(attention_states, 2, axis=-1)
|
170 |
-
else:
|
171 |
-
attention_states = self.fused_proj(hidden_states)
|
172 |
-
query_states, key_states, value_states = jnp.split(attention_states, 3, axis=-1)
|
173 |
-
|
174 |
-
else:
|
175 |
-
# get query proj
|
176 |
-
query_states = self.q_proj(hidden_states)
|
177 |
-
# get key, value proj
|
178 |
-
if is_cross_attention:
|
179 |
-
# cross_attentions
|
180 |
-
key_states = self.k_proj(key_value_states)
|
181 |
-
value_states = self.v_proj(key_value_states)
|
182 |
-
else:
|
183 |
-
# self_attention
|
184 |
-
key_states = self.k_proj(hidden_states)
|
185 |
-
value_states = self.v_proj(hidden_states)
|
186 |
-
|
187 |
-
query_states = self._split_heads(query_states)
|
188 |
-
key_states = self._split_heads(key_states)
|
189 |
-
value_states = self._split_heads(value_states)
|
190 |
-
|
191 |
-
# handle cache prepare causal attention mask
|
192 |
-
if self.causal:
|
193 |
-
query_length, key_length = query_states.shape[1], key_states.shape[1]
|
194 |
-
if self.has_variable("cache", "cached_key"):
|
195 |
-
mask_shift = self.variables["cache"]["cache_index"]
|
196 |
-
max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
|
197 |
-
causal_mask = lax.dynamic_slice(
|
198 |
-
self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
|
199 |
-
)
|
200 |
-
else:
|
201 |
-
causal_mask = self.causal_mask[:, :, :query_length, :key_length]
|
202 |
-
causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
|
203 |
-
|
204 |
-
# combine masks if needed
|
205 |
-
if attention_mask is not None and self.causal:
|
206 |
-
attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
|
207 |
-
attention_mask = combine_masks(attention_mask, causal_mask)
|
208 |
-
elif self.causal:
|
209 |
-
attention_mask = causal_mask
|
210 |
-
elif attention_mask is not None:
|
211 |
-
attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
|
212 |
-
|
213 |
-
# During fast autoregressive decoding, we feed one position at a time,
|
214 |
-
# and cache the keys and values step by step.
|
215 |
-
if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
|
216 |
-
key_states, value_states, attention_mask = self._concatenate_to_cache(
|
217 |
-
key_states, value_states, query_states, attention_mask
|
218 |
-
)
|
219 |
-
|
220 |
-
# Convert the boolean attention mask to an attention bias.
|
221 |
-
if attention_mask is not None:
|
222 |
-
# attention mask in the form of attention bias
|
223 |
-
attention_bias = lax.select(
|
224 |
-
attention_mask > 0,
|
225 |
-
jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
|
226 |
-
jnp.full(attention_mask.shape, float("-inf")).astype(self.dtype),
|
227 |
-
)
|
228 |
-
else:
|
229 |
-
attention_bias = None
|
230 |
-
|
231 |
-
dropout_rng = None
|
232 |
-
if not deterministic and self.dropout > 0.0:
|
233 |
-
dropout_rng = self.make_rng("dropout")
|
234 |
-
|
235 |
-
attn_weights = dot_product_attention_weights(
|
236 |
-
query_states,
|
237 |
-
key_states,
|
238 |
-
bias=attention_bias,
|
239 |
-
dropout_rng=dropout_rng,
|
240 |
-
dropout_rate=self.dropout,
|
241 |
-
broadcast_dropout=True,
|
242 |
-
deterministic=deterministic,
|
243 |
-
dtype=self.dtype,
|
244 |
-
precision=None,
|
245 |
-
)
|
246 |
-
|
247 |
-
attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
|
248 |
-
attn_output = self._merge_heads(attn_output)
|
249 |
-
attn_output = self.out_proj(attn_output)
|
250 |
-
|
251 |
-
return attn_output, attn_weights
|
252 |
-
|
253 |
-
|
254 |
-
class FlaxBartDecoderLayer(nn.Module):
|
255 |
-
config: BartConfig
|
256 |
-
dtype: jnp.dtype = jnp.float32
|
257 |
-
|
258 |
-
def setup(self) -> None:
|
259 |
-
self.embed_dim = self.config.d_model
|
260 |
-
self.self_attn = FlaxBartAttention(
|
261 |
-
config=self.config,
|
262 |
-
embed_dim=self.embed_dim,
|
263 |
-
num_heads=self.config.decoder_attention_heads,
|
264 |
-
dropout=self.config.attention_dropout,
|
265 |
-
causal=True,
|
266 |
-
dtype=self.dtype,
|
267 |
-
)
|
268 |
-
self.dropout_layer = nn.Dropout(rate=self.config.dropout)
|
269 |
-
self.activation_fn = ACT2FN[self.config.activation_function]
|
270 |
-
self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
|
271 |
-
|
272 |
-
self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
|
273 |
-
self.encoder_attn = FlaxBartAttention(
|
274 |
-
config=self.config,
|
275 |
-
embed_dim=self.embed_dim,
|
276 |
-
num_heads=self.config.decoder_attention_heads,
|
277 |
-
dropout=self.config.attention_dropout,
|
278 |
-
dtype=self.dtype,
|
279 |
-
)
|
280 |
-
self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
|
281 |
-
self.fc1 = nn.Dense(
|
282 |
-
self.config.encoder_ffn_dim,
|
283 |
-
dtype=self.dtype,
|
284 |
-
kernel_init=jax.nn.initializers.normal(self.config.init_std),
|
285 |
-
)
|
286 |
-
self.fc2 = nn.Dense(
|
287 |
-
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
|
288 |
-
)
|
289 |
-
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
|
290 |
-
|
291 |
-
def __call__(
|
292 |
-
self,
|
293 |
-
hidden_states: jnp.ndarray,
|
294 |
-
attention_mask: jnp.ndarray,
|
295 |
-
encoder_hidden_states: Optional[jnp.ndarray] = None,
|
296 |
-
encoder_attention_mask: Optional[jnp.ndarray] = None,
|
297 |
-
init_cache: bool = False,
|
298 |
-
output_attentions: bool = True,
|
299 |
-
deterministic: bool = True,
|
300 |
-
) -> Tuple[jnp.ndarray]:
|
301 |
-
|
302 |
-
if self.config.use_scan:
|
303 |
-
hidden_states = hidden_states[0]
|
304 |
-
|
305 |
-
residual = hidden_states
|
306 |
-
|
307 |
-
# Self Attention
|
308 |
-
hidden_states, self_attn_weights = self.self_attn(
|
309 |
-
hidden_states=hidden_states, attention_mask=attention_mask, init_cache=init_cache
|
310 |
-
)
|
311 |
-
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
|
312 |
-
hidden_states = residual + hidden_states
|
313 |
-
hidden_states = self.self_attn_layer_norm(hidden_states)
|
314 |
-
|
315 |
-
# Cross-Attention Block
|
316 |
-
cross_attn_weights = None
|
317 |
-
if encoder_hidden_states is not None:
|
318 |
-
residual = hidden_states
|
319 |
-
|
320 |
-
hidden_states, cross_attn_weights = self.encoder_attn(
|
321 |
-
hidden_states=hidden_states,
|
322 |
-
key_value_states=encoder_hidden_states,
|
323 |
-
attention_mask=encoder_attention_mask,
|
324 |
-
)
|
325 |
-
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
|
326 |
-
hidden_states = residual + hidden_states
|
327 |
-
hidden_states = self.encoder_attn_layer_norm(hidden_states)
|
328 |
-
|
329 |
-
# Fully Connected
|
330 |
-
residual = hidden_states
|
331 |
-
hidden_states = self.activation_fn(self.fc1(hidden_states))
|
332 |
-
hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic)
|
333 |
-
hidden_states = self.fc2(hidden_states)
|
334 |
-
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
|
335 |
-
hidden_states = residual + hidden_states
|
336 |
-
hidden_states = self.final_layer_norm(hidden_states)
|
337 |
-
|
338 |
-
outputs = (hidden_states,)
|
339 |
-
|
340 |
-
if output_attentions:
|
341 |
-
outputs += (self_attn_weights, cross_attn_weights)
|
342 |
-
|
343 |
-
if self.config.use_scan:
|
344 |
-
outputs = (outputs, None)
|
345 |
-
|
346 |
-
return outputs
|
347 |
-
|
348 |
-
|
349 |
-
class FlaxBartDecoderLayerCollection(nn.Module):
|
350 |
-
config: BartConfig
|
351 |
-
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
352 |
-
|
353 |
-
@nn.compact
|
354 |
-
def __call__(
|
355 |
-
self,
|
356 |
-
hidden_states,
|
357 |
-
attention_mask,
|
358 |
-
encoder_hidden_states: Optional[jnp.ndarray] = None,
|
359 |
-
encoder_attention_mask: Optional[jnp.ndarray] = None,
|
360 |
-
deterministic: bool = True,
|
361 |
-
init_cache: bool = False,
|
362 |
-
output_attentions: bool = False,
|
363 |
-
output_hidden_states: bool = False,
|
364 |
-
return_dict: bool = True,
|
365 |
-
):
|
366 |
-
# decoder layers
|
367 |
-
all_hidden_states = () if output_hidden_states else None
|
368 |
-
all_self_attns = () if output_attentions else None
|
369 |
-
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
|
370 |
-
|
371 |
-
num_decoder_layers = self.config.decoder_layers
|
372 |
-
BlockDecoderLayer = (
|
373 |
-
remat(
|
374 |
-
FlaxBartDecoderLayer,
|
375 |
-
static_argnums=(4, 5, 6),
|
376 |
-
prevent_cse=not self.config.use_scan,
|
377 |
-
)
|
378 |
-
if self.config.gradient_checkpointing
|
379 |
-
else FlaxBartDecoderLayer
|
380 |
-
)
|
381 |
-
|
382 |
-
if self.config.use_scan:
|
383 |
-
# since all decoder layers are the same, we use nn.scan directly
|
384 |
-
assert not output_attentions, "cannot use `scan` with `output_attentions` set to `True`"
|
385 |
-
assert not output_hidden_states, "cannot use `scan` with `output_hidden_states` set to `True`"
|
386 |
-
hidden_states = (hidden_states,)
|
387 |
-
|
388 |
-
# TODO: add layerdrop in checkpointed scan (note: default value for layerdrop in config is zero)
|
389 |
-
hidden_states, _ = scan_with_axes(
|
390 |
-
BlockDecoderLayer,
|
391 |
-
variable_axes={"params": 0, "cache": 0},
|
392 |
-
split_rngs={"params": True, "dropout": True},
|
393 |
-
in_axes=(nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast),
|
394 |
-
length=num_decoder_layers,
|
395 |
-
)(self.config, dtype=self.dtype, name="FlaxBartDecoderLayers")(
|
396 |
-
hidden_states,
|
397 |
-
attention_mask,
|
398 |
-
encoder_hidden_states,
|
399 |
-
encoder_attention_mask,
|
400 |
-
init_cache,
|
401 |
-
output_attentions,
|
402 |
-
deterministic,
|
403 |
-
)
|
404 |
-
hidden_states = hidden_states[0]
|
405 |
-
|
406 |
-
else:
|
407 |
-
for layer in range(num_decoder_layers):
|
408 |
-
if output_hidden_states:
|
409 |
-
all_hidden_states += (hidden_states,)
|
410 |
-
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
411 |
-
dropout_probability = random.uniform(0, 1)
|
412 |
-
if not deterministic and (dropout_probability < self.config.decoder_layerdrop):
|
413 |
-
layer_outputs = (None, None, None)
|
414 |
-
else:
|
415 |
-
layer_outputs = BlockDecoderLayer(self.config, dtype=self.dtype, name=str(layer),)(
|
416 |
-
hidden_states,
|
417 |
-
attention_mask,
|
418 |
-
encoder_hidden_states,
|
419 |
-
encoder_attention_mask,
|
420 |
-
init_cache,
|
421 |
-
output_attentions,
|
422 |
-
deterministic,
|
423 |
-
)
|
424 |
-
|
425 |
-
hidden_states = layer_outputs[0]
|
426 |
-
if output_attentions:
|
427 |
-
all_self_attns += (layer_outputs[1],)
|
428 |
-
|
429 |
-
if encoder_hidden_states is not None:
|
430 |
-
all_cross_attentions += (layer_outputs[2],)
|
431 |
-
|
432 |
-
# add hidden states from the last decoder layer
|
433 |
-
if output_hidden_states:
|
434 |
-
all_hidden_states += (hidden_states,)
|
435 |
-
|
436 |
-
outputs = [hidden_states, all_hidden_states, all_self_attns, all_cross_attentions]
|
437 |
-
|
438 |
-
if not return_dict:
|
439 |
-
return tuple(v for v in outputs if v is not None)
|
440 |
-
|
441 |
-
return FlaxBaseModelOutputWithPastAndCrossAttentions(
|
442 |
-
last_hidden_state=hidden_states,
|
443 |
-
hidden_states=all_hidden_states,
|
444 |
-
attentions=all_self_attns,
|
445 |
-
cross_attentions=all_cross_attentions,
|
446 |
-
)
|
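The layer collection above has three execution paths: plain Python unrolling, `remat`-wrapped layers for gradient checkpointing, and a scanned stack when `config.use_scan` is set. Since all decoder layers share one definition, `nn.scan` stacks their parameters along a leading axis instead of unrolling the loop. Below is a minimal, self-contained sketch of that trick using stock `flax.linen.scan` (not the repo's `scan_with_axes` helper; the module names are illustrative):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn


class Block(nn.Module):
    """One layer of the stack; the carry is the hidden state."""

    @nn.compact
    def __call__(self, carry, _):
        carry = nn.relu(nn.Dense(carry.shape[-1])(carry))
        return carry, None  # (new carry, no per-step output)


class ScannedStack(nn.Module):
    num_layers: int = 4

    @nn.compact
    def __call__(self, x):
        ScanBlock = nn.scan(
            Block,
            variable_axes={"params": 0},  # stack each layer's params on axis 0
            split_rngs={"params": True},  # independent init RNG per layer
            length=self.num_layers,
        )
        x, _ = ScanBlock()(x, None)
        return x


x = jnp.ones((2, 8))
params = ScannedStack().init(jax.random.PRNGKey(0), x)
y = ScannedStack().apply(params, x)
```

The gradient-checkpointed variant corresponds to wrapping the layer in `remat` before scanning, mirroring `BlockDecoderLayer` above.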
447 |
-
|
448 |
-
|
449 |
-
class FlaxBartDecoder(nn.Module):
|
450 |
-
config: BartConfig
|
451 |
-
embed_tokens: nn.Embed
|
452 |
-
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
453 |
-
|
454 |
-
def setup(self):
|
455 |
-
self.dropout_layer = nn.Dropout(rate=self.config.dropout)
|
456 |
-
|
457 |
-
embed_dim = self.config.d_model
|
458 |
-
self.padding_idx = self.config.pad_token_id
|
459 |
-
self.max_target_positions = self.config.max_position_embeddings
|
460 |
-
self.embed_scale = math.sqrt(self.config.d_model) if self.config.scale_embedding else 1.0
|
461 |
-
|
462 |
-
# Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
|
463 |
-
# and adjust num_embeddings appropriately. Other models don't have this hack
|
464 |
-
self.offset = 2
|
465 |
-
self.embed_positions = nn.Embed(
|
466 |
-
self.config.max_position_embeddings + self.offset,
|
467 |
-
embed_dim,
|
468 |
-
embedding_init=jax.nn.initializers.normal(self.config.init_std),
|
469 |
-
)
|
470 |
-
|
471 |
-
self.layers = FlaxBartDecoderLayerCollection(self.config, self.dtype)
|
472 |
-
self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
|
473 |
-
|
474 |
-
def __call__(
|
475 |
-
self,
|
476 |
-
input_ids,
|
477 |
-
attention_mask,
|
478 |
-
position_ids,
|
479 |
-
encoder_hidden_states: Optional[jnp.ndarray] = None,
|
480 |
-
encoder_attention_mask: Optional[jnp.ndarray] = None,
|
481 |
-
init_cache: bool = False,
|
482 |
-
output_attentions: bool = False,
|
483 |
-
output_hidden_states: bool = False,
|
484 |
-
return_dict: bool = True,
|
485 |
-
deterministic: bool = True,
|
486 |
-
):
|
487 |
-
input_shape = input_ids.shape
|
488 |
-
input_ids = input_ids.reshape(-1, input_shape[-1])
|
489 |
-
|
490 |
-
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
491 |
-
|
492 |
-
# embed positions
|
493 |
-
positions = self.embed_positions(position_ids + self.offset)
|
494 |
-
|
495 |
-
hidden_states = inputs_embeds + positions
|
496 |
-
hidden_states = self.layernorm_embedding(hidden_states)
|
497 |
-
|
498 |
-
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
|
499 |
-
|
500 |
-
outputs = self.layers(
|
501 |
-
hidden_states,
|
502 |
-
attention_mask,
|
503 |
-
encoder_hidden_states,
|
504 |
-
encoder_attention_mask,
|
505 |
-
deterministic=deterministic,
|
506 |
-
init_cache=init_cache,
|
507 |
-
output_attentions=output_attentions,
|
508 |
-
output_hidden_states=output_hidden_states,
|
509 |
-
return_dict=return_dict,
|
510 |
-
)
|
511 |
-
|
512 |
-
if not return_dict:
|
513 |
-
return outputs
|
514 |
-
|
515 |
-
return FlaxBaseModelOutputWithPastAndCrossAttentions(
|
516 |
-
last_hidden_state=outputs.last_hidden_state,
|
517 |
-
hidden_states=outputs.hidden_states,
|
518 |
-
attentions=outputs.attentions,
|
519 |
-
cross_attentions=outputs.cross_attentions,
|
520 |
-
)
|
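For reference, the offset-by-2 position table described in `setup` can be illustrated in a few lines (sizes here are arbitrary, chosen only for the sketch):

```python
import math

import jax
import jax.numpy as jnp
import flax.linen as nn

offset, max_positions, d_model = 2, 1024, 16

# Bart reserves two extra rows, so position id p is read from row p + 2.
embed_positions = nn.Embed(max_positions + offset, d_model)
position_ids = jnp.arange(6)[None, :]  # [[0, 1, 2, 3, 4, 5]]
params = embed_positions.init(jax.random.PRNGKey(0), position_ids + offset)
positions = embed_positions.apply(params, position_ids + offset)  # rows 2..7

# token embeddings are scaled by sqrt(d_model) when scale_embedding is set
embed_scale = math.sqrt(d_model)
```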
521 |
-
|
522 |
-
|
523 |
-
class FlaxBartDecoderPreTrainedModel(FlaxPreTrainedModel):
|
524 |
-
config_class = BartConfig
|
525 |
-
base_model_prefix: str = "model"
|
526 |
-
module_class: nn.Module = None
|
527 |
-
|
528 |
-
def __init__(
|
529 |
-
self,
|
530 |
-
config: BartConfig,
|
531 |
-
input_shape: Tuple[int] = (1, 1),
|
532 |
-
seed: int = 0,
|
533 |
-
dtype: jnp.dtype = jnp.float32,
|
534 |
-
_do_init: bool = True,
|
535 |
-
**kwargs
|
536 |
-
):
|
537 |
-
config.is_decoder = True
|
538 |
-
config.is_encoder_decoder = False
|
539 |
-
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
540 |
-
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
541 |
-
|
542 |
-
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
543 |
-
# init input tensors
|
544 |
-
input_ids = jnp.zeros(input_shape, dtype="i4")
|
545 |
-
attention_mask = jnp.ones_like(input_ids)
|
546 |
-
|
547 |
-
batch_size, sequence_length = input_ids.shape
|
548 |
-
position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
|
549 |
-
|
550 |
-
params_rng, dropout_rng = jax.random.split(rng)
|
551 |
-
rngs = {"params": params_rng, "dropout": dropout_rng}
|
552 |
-
encoder_hidden_states = jnp.zeros(input_shape + (self.config.d_model,))
|
553 |
-
encoder_attention_mask = attention_mask
|
554 |
-
module_init_outputs = self.module.init(
|
555 |
-
rngs,
|
556 |
-
input_ids,
|
557 |
-
attention_mask,
|
558 |
-
position_ids,
|
559 |
-
encoder_hidden_states,
|
560 |
-
encoder_attention_mask,
|
561 |
-
return_dict=False,
|
562 |
-
)
|
563 |
-
return module_init_outputs["params"]
|
564 |
-
|
565 |
-
def init_cache(self, batch_size, max_length):
|
566 |
-
r"""
|
567 |
-
Args:
|
568 |
-
batch_size (`int`):
|
569 |
-
batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
|
570 |
-
max_length (`int`):
|
571 |
-
maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
|
572 |
-
cache.
|
573 |
-
"""
|
574 |
-
# init input variables to retrieve cache
|
575 |
-
input_ids = jnp.ones((batch_size, max_length), dtype="i4")
|
576 |
-
attention_mask = jnp.ones_like(input_ids, dtype="i4")
|
577 |
-
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
|
578 |
-
|
579 |
-
init_variables = self.module.init(
|
580 |
-
jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
|
581 |
-
)
|
582 |
-
return unfreeze(init_variables["cache"])
|
583 |
-
|
584 |
-
def __call__(
|
585 |
-
self,
|
586 |
-
input_ids: jnp.ndarray,
|
587 |
-
attention_mask: Optional[jnp.ndarray] = None,
|
588 |
-
position_ids: Optional[jnp.ndarray] = None,
|
589 |
-
encoder_hidden_states: Optional[jnp.ndarray] = None,
|
590 |
-
encoder_attention_mask: Optional[jnp.ndarray] = None,
|
591 |
-
output_attentions: Optional[bool] = None,
|
592 |
-
output_hidden_states: Optional[bool] = None,
|
593 |
-
return_dict: Optional[bool] = None,
|
594 |
-
train: bool = False,
|
595 |
-
params: dict = None,
|
596 |
-
past_key_values: dict = None,
|
597 |
-
dropout_rng: PRNGKey = None,
|
598 |
-
):
|
599 |
-
"""
|
600 |
-
Args:
|
601 |
-
input_ids (`jnp.ndarray` of shape `(target_batch_size, target_sequence_length)`):
|
602 |
-
Indices of decoder input sequence tokens in the vocabulary.
|
603 |
-
|
604 |
-
Indices can be obtained using [`BartTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
605 |
-
[`PreTrainedTokenizer.__call__`] for details.
|
606 |
-
|
607 |
-
[What are decoder input IDs?](../glossary#decoder-input-ids)
|
608 |
-
|
609 |
-
For translation and summarization training, `decoder_input_ids` should be provided. If no
|
610 |
-
`decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right
|
611 |
-
for denoising pre-training following the paper.
|
612 |
-
attention_mask (`jnp.ndarray` of shape `(target_batch_size, target_sequence_length)`, *optional*):
|
613 |
-
Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
|
614 |
-
be used by default.
|
615 |
-
|
616 |
-
If you want to change padding behavior, you should modify to your needs. See diagram 1 in [the
|
617 |
-
paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy.
|
618 |
-
position_ids (`numpy.ndarray` of shape `(target_batch_size, sequence_length)`, *optional*):
|
619 |
-
Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
|
620 |
-
range `[0, config.max_position_embeddings - 1]`.
|
621 |
-
encoder_hidden_states (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`):
|
622 |
-
A sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
|
623 |
-
encoder_attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
|
624 |
-
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
625 |
-
|
626 |
-
- 1 for tokens that are **not masked**,
|
627 |
-
- 0 for tokens that are **masked**.
|
628 |
-
|
629 |
-
[What are attention masks?](../glossary#attention-mask)
|
630 |
-
output_attentions (`bool`, *optional*):
|
631 |
-
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
632 |
-
tensors for more detail.
|
633 |
-
output_hidden_states (`bool`, *optional*):
|
634 |
-
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
635 |
-
more detail.
|
636 |
-
return_dict (`bool`, *optional*):
|
637 |
-
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
638 |
-
past_key_values (`Dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`):
|
639 |
-
Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast
|
640 |
-
auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*.
|
641 |
-
"""
|
642 |
-
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
643 |
-
output_hidden_states = (
|
644 |
-
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
645 |
-
)
|
646 |
-
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
647 |
-
|
648 |
-
if encoder_hidden_states is not None and encoder_attention_mask is None:
|
649 |
-
batch_size, sequence_length = encoder_hidden_states.shape[:2]
|
650 |
-
encoder_attention_mask = jnp.ones((batch_size, sequence_length))
|
651 |
-
|
652 |
-
# prepare decoder inputs
|
653 |
-
if attention_mask is None:
|
654 |
-
attention_mask = jnp.ones_like(input_ids)
|
655 |
-
if position_ids is None:
|
656 |
-
batch_size, sequence_length = input_ids.shape
|
657 |
-
position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
|
658 |
-
|
659 |
-
# Handle any PRNG if needed
|
660 |
-
rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
|
661 |
-
|
662 |
-
inputs = {"params": params or self.params}
|
663 |
-
|
664 |
-
# if past_key_values are passed then the cache is already initialized; a private flag init_cache has to be passed
|
665 |
-
# down to ensure the cache is used. The cache has to be marked as mutable so that it can be
|
666 |
-
# changed by FlaxBartAttention module
|
667 |
-
if past_key_values:
|
668 |
-
inputs["cache"] = past_key_values
|
669 |
-
mutable = ["cache"]
|
670 |
-
else:
|
671 |
-
mutable = False
|
672 |
-
|
673 |
-
outputs = self.module.apply(
|
674 |
-
inputs,
|
675 |
-
input_ids=jnp.array(input_ids, dtype="i4"),
|
676 |
-
attention_mask=jnp.array(attention_mask, dtype="i4"),
|
677 |
-
position_ids=jnp.array(position_ids, dtype="i4"),
|
678 |
-
encoder_hidden_states=encoder_hidden_states,
|
679 |
-
encoder_attention_mask=encoder_attention_mask,
|
680 |
-
output_attentions=output_attentions,
|
681 |
-
output_hidden_states=output_hidden_states,
|
682 |
-
return_dict=return_dict,
|
683 |
-
deterministic=not train,
|
684 |
-
rngs=rngs,
|
685 |
-
mutable=mutable,
|
686 |
-
)
|
687 |
-
|
688 |
-
# add updated cache to model output
|
689 |
-
if past_key_values is not None and return_dict:
|
690 |
-
outputs, past_key_values = outputs
|
691 |
-
outputs["past_key_values"] = unfreeze(past_key_values["cache"])
|
692 |
-
return outputs
|
693 |
-
elif past_key_values is not None and not return_dict:
|
694 |
-
outputs, past_key_values = outputs
|
695 |
-
outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]
|
696 |
-
|
697 |
-
return outputs
|
698 |
-
|
699 |
-
|
700 |
-
class FlaxBartDecoderWrapper(nn.Module):
|
701 |
-
"""
|
702 |
-
This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
|
703 |
-
used in combination with the [`EncoderDecoderModel`] framework.
|
704 |
-
"""
|
705 |
-
|
706 |
-
config: BartConfig
|
707 |
-
dtype: jnp.dtype = jnp.float32
|
708 |
-
|
709 |
-
def setup(self):
|
710 |
-
embed_dim = self.config.d_model
|
711 |
-
embed_tokens = nn.Embed(
|
712 |
-
self.config.vocab_size,
|
713 |
-
embed_dim,
|
714 |
-
embedding_init=jax.nn.initializers.normal(self.config.init_std),
|
715 |
-
)
|
716 |
-
self.decoder = FlaxBartDecoder(config=self.config, embed_tokens=embed_tokens, dtype=self.dtype)
|
717 |
-
|
718 |
-
def __call__(self, *args, **kwargs):
|
719 |
-
return self.decoder(*args, **kwargs)
|
720 |
-
|
721 |
-
|
722 |
-
class FlaxBartForCausalLMModule(nn.Module):
|
723 |
-
"""Bart Decoder Module with a language modeling head on top (linear layer with weights tied to the input embeddings)
|
724 |
-
e.g. for autoregressive tasks.
|
725 |
-
"""
|
726 |
-
|
727 |
-
config: BartConfig
|
728 |
-
dtype: jnp.dtype = jnp.float32
|
729 |
-
|
730 |
-
def setup(self):
|
731 |
-
self.model = FlaxBartDecoderWrapper(config=self.config, dtype=self.dtype)
|
732 |
-
self.lm_head = nn.Dense(
|
733 |
-
self.config.vocab_size,
|
734 |
-
use_bias=False,
|
735 |
-
dtype=self.dtype,
|
736 |
-
kernel_init=jax.nn.initializers.normal(self.config.init_std),
|
737 |
-
)
|
738 |
-
|
739 |
-
def __call__(
|
740 |
-
self,
|
741 |
-
input_ids,
|
742 |
-
attention_mask,
|
743 |
-
position_ids,
|
744 |
-
encoder_hidden_states: Optional[jnp.ndarray] = None,
|
745 |
-
encoder_attention_mask: Optional[jnp.ndarray] = None,
|
746 |
-
init_cache: bool = False,
|
747 |
-
output_attentions: bool = False,
|
748 |
-
output_hidden_states: bool = False,
|
749 |
-
return_dict: bool = True,
|
750 |
-
deterministic: bool = True,
|
751 |
-
):
|
752 |
-
|
753 |
-
outputs = self.model(
|
754 |
-
input_ids,
|
755 |
-
attention_mask,
|
756 |
-
position_ids,
|
757 |
-
encoder_hidden_states,
|
758 |
-
encoder_attention_mask,
|
759 |
-
deterministic=deterministic,
|
760 |
-
init_cache=init_cache,
|
761 |
-
output_attentions=output_attentions,
|
762 |
-
output_hidden_states=output_hidden_states,
|
763 |
-
return_dict=return_dict,
|
764 |
-
)
|
765 |
-
|
766 |
-
hidden_states = outputs[0]
|
767 |
-
|
768 |
-
if self.config.tie_word_embeddings:
|
769 |
-
shared_embedding = self.model.variables["params"]["decoder"]["embed_tokens"]["embedding"]
|
770 |
-
lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
|
771 |
-
else:
|
772 |
-
lm_logits = self.lm_head(hidden_states)
|
773 |
-
|
774 |
-
if not return_dict:
|
775 |
-
return (lm_logits,) + outputs[1:]
|
776 |
-
|
777 |
-
return FlaxCausalLMOutputWithCrossAttentions(
|
778 |
-
logits=lm_logits,
|
779 |
-
hidden_states=outputs.hidden_states,
|
780 |
-
attentions=outputs.attentions,
|
781 |
-
cross_attentions=outputs.cross_attentions,
|
782 |
-
)
|
783 |
-
|
784 |
-
|
785 |
-
class FlaxBartForCausalLM(FlaxBartDecoderPreTrainedModel):
|
786 |
-
"""Bart Decoder Model with a language modeling head on top (linear layer with weights tied to the input embeddings)
|
787 |
-
e.g. for autoregressive tasks.
|
788 |
-
"""
|
789 |
-
|
790 |
-
module_class = FlaxBartForCausalLMModule
|
791 |
-
|
792 |
-
def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
|
793 |
-
# initializing the cache
|
794 |
-
batch_size, seq_length = input_ids.shape
|
795 |
-
|
796 |
-
past_key_values = self.init_cache(batch_size, max_length)
|
797 |
-
# Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
|
798 |
-
# But since the decoder uses a causal mask, those positions are masked anyway.
|
799 |
-
# Thus, we can create a single static attention_mask here, which is more efficient for compilation
|
800 |
-
extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
|
801 |
-
if attention_mask is not None:
|
802 |
-
position_ids = attention_mask.cumsum(axis=-1) - 1
|
803 |
-
extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
|
804 |
-
else:
|
805 |
-
position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
|
806 |
-
|
807 |
-
return {
|
808 |
-
"past_key_values": past_key_values,
|
809 |
-
"attention_mask": extended_attention_mask,
|
810 |
-
"position_ids": position_ids,
|
811 |
-
}
|
812 |
-
|
813 |
-
def update_inputs_for_generation(self, model_outputs, model_kwargs):
|
814 |
-
model_kwargs["past_key_values"] = model_outputs.past_key_values
|
815 |
-
model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
|
816 |
-
return model_kwargs
|
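Taken together, `init_cache`, `prepare_inputs_for_generation` and `update_inputs_for_generation` support cached autoregressive decoding. A hedged sketch of the loop they enable, assuming `model` is a loaded `FlaxBartForCausalLM` and `input_ids` a prompt batch (the greedy selection here is purely illustrative; the real `generate` adds logits processors and jit-compatible control flow):

```python
import jax.numpy as jnp


def greedy_decode(model, input_ids, max_length):
    # pre-allocates the cache and builds the static attention mask / position ids
    kwargs = model.prepare_inputs_for_generation(input_ids, max_length)
    # first pass consumes the whole prompt and fills the cache
    outputs = model(input_ids, **kwargs)
    sequences = input_ids
    for _ in range(max_length - input_ids.shape[-1]):
        next_token = jnp.argmax(outputs.logits[:, -1], axis=-1)[:, None]
        sequences = jnp.concatenate([sequences, next_token], axis=-1)
        # carry the updated cache forward and advance the position ids by one
        kwargs = model.update_inputs_for_generation(outputs, kwargs)
        outputs = model(next_token, **kwargs)
    return sequences
```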
models/modeling_flax_speech_encoder_decoder.py
DELETED
@@ -1,1245 +0,0 @@
|
|
1 |
-
# coding=utf-8
|
2 |
-
# Copyright 2022 The HuggingFace Inc. team.
|
3 |
-
#
|
4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
-
# you may not use this file except in compliance with the License.
|
6 |
-
# You may obtain a copy of the License at
|
7 |
-
#
|
8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
-
#
|
10 |
-
# Unless required by applicable law or agreed to in writing, software
|
11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
-
# See the License for the specific language governing permissions and
|
14 |
-
# limitations under the License.
|
15 |
-
""" Classes to support Flax Speech-Encoder-Decoder architectures"""
|
16 |
-
|
17 |
-
import os
|
18 |
-
from functools import partial
|
19 |
-
from typing import Optional, Tuple, Union, Dict
|
20 |
-
|
21 |
-
import flax
|
22 |
-
import flax.linen as nn
|
23 |
-
import jax
|
24 |
-
import jax.numpy as jnp
|
25 |
-
from flax.core.frozen_dict import FrozenDict, unfreeze
|
26 |
-
from jax import lax
|
27 |
-
from jax.random import PRNGKey
|
28 |
-
import numpy as np
|
29 |
-
|
30 |
-
from transformers.modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutputWithCrossAttentions, FlaxSeq2SeqLMOutput
|
31 |
-
from transformers.modeling_flax_utils import FlaxPreTrainedModel
|
32 |
-
from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings, ModelOutput
|
33 |
-
from transformers.generation_flax_utils import FlaxLogitsProcessorList
|
34 |
-
from models import (
|
35 |
-
FlaxWav2Vec2Model,
|
36 |
-
FlaxWav2Vec2Module,
|
37 |
-
FlaxBartForCausalLM,
|
38 |
-
FlaxBartForCausalLMModule,
|
39 |
-
BartConfig,
|
40 |
-
Wav2Vec2Config,
|
41 |
-
SpeechEncoderDecoderConfig,
|
42 |
-
)
|
43 |
-
|
44 |
-
logger = logging.get_logger(__name__)
|
45 |
-
|
46 |
-
_CONFIG_FOR_DOC = "SpeechEncoderDecoderConfig"
|
47 |
-
|
48 |
-
SPEECH_ENCODER_DECODER_START_DOCSTRING = r"""
|
49 |
-
This class can be used to initialize a speech-sequence-to-text-sequence model with any pretrained speech
|
50 |
-
autoencoding model as the encoder and any pretrained text autoregressive model as the decoder. The encoder is
|
51 |
-
loaded via [`~AutoModel.from_pretrained`] function and the decoder is loaded via
|
52 |
-
[`~AutoModelForCausalLM.from_pretrained`] function. Cross-attention layers are automatically added to the decoder
|
53 |
-
and should be fine-tuned on a downstream generative task, like summarization.
|
54 |
-
|
55 |
-
The effectiveness of initializing sequence-to-sequence models with pretrained checkpoints for sequence generation
|
56 |
-
tasks was shown in [Leveraging Pre-trained Checkpoints for Sequence Generation
|
57 |
-
Tasks](https://arxiv.org/abs/1907.12461) by Sascha Rothe, Shashi Narayan and
|
58 |
-
Aliaksei Severyn.
|
59 |
-
|
60 |
-
Additionally, in [Large-Scale Self- and Semi-Supervised Learning for Speech
|
61 |
-
Translation](https://arxiv.org/abs/2104.06678) it is shown how leveraging large pretrained speech models for speech
|
62 |
-
translation yields a significant performance improvement.
|
63 |
-
|
64 |
-
After such a Speech-Encoder-Decoder model has been trained/fine-tuned, it can be saved/loaded just like any other
|
65 |
-
model (see the examples for more information).
|
66 |
-
|
67 |
-
This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
|
68 |
-
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
69 |
-
etc.)
|
70 |
-
|
71 |
-
This model is also a Flax Linen
|
72 |
-
[flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
|
73 |
-
regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior.
|
74 |
-
|
75 |
-
Parameters:
|
76 |
-
config ([`SpeechEncoderDecoderConfig`]): Model configuration class with all the parameters of the model.
|
77 |
-
Initializing with a config file does not load the weights associated with the model, only the
|
78 |
-
configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
|
79 |
-
dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
|
80 |
-
The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
|
81 |
-
`jax.numpy.bfloat16` (on TPUs).
|
82 |
-
|
83 |
-
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
|
84 |
-
specified all the computation will be performed with the given `dtype`.
|
85 |
-
|
86 |
-
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
|
87 |
-
parameters.**
|
88 |
-
|
89 |
-
If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
|
90 |
-
[`~FlaxPreTrainedModel.to_bf16`].
|
91 |
-
"""
|
92 |
-
|
93 |
-
SPEECH_ENCODER_DECODER_INPUTS_DOCSTRING = r"""
|
94 |
-
Args:
|
95 |
-
inputs (`jnp.ndarray` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, feature_dim)`, *optional*):
|
96 |
-
Float values of input raw speech waveform or speech features. Values can be obtained by loading a *.flac*
|
97 |
-
or *.wav* audio file into an array of type *List[float]* or a *numpy.ndarray*, *e.g.* via the soundfile
|
98 |
-
library (*pip install soundfile*). To prepare the array into *inputs*, either the [`Wav2Vec2Processor`] or
|
99 |
-
[`Speech2TextProcessor`] should be used for padding and conversion into a tensor of type
|
100 |
-
*torch.FloatTensor*.
|
101 |
-
attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
|
102 |
-
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
103 |
-
|
104 |
-
- 1 for tokens that are **not masked**,
|
105 |
-
- 0 for tokens that are **masked**.
|
106 |
-
|
107 |
-
[What are attention masks?](../glossary#attention-mask)
|
108 |
-
decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
|
109 |
-
Indices of decoder input sequence tokens in the vocabulary.
|
110 |
-
|
111 |
-
Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
112 |
-
[`PreTrainedTokenizer.__call__`] for details.
|
113 |
-
|
114 |
-
[What are input IDs?](../glossary#input-ids)
|
115 |
-
|
116 |
-
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
|
117 |
-
`past_key_values`).
|
118 |
-
|
119 |
-
For sequence to sequence training, `decoder_input_ids` should be provided. `decoder_input_ids` should be
|
120 |
-
created outside of the model by shifting the `labels` to the right, replacing -100 by the `pad_token_id`
|
121 |
-
and prepending them with the `decoder_start_token_id`.
|
122 |
-
decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
|
123 |
-
Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
|
124 |
-
be used by default.
|
125 |
-
decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
|
126 |
-
Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
|
127 |
-
range `[0, config.decoder.max_position_embeddings - 1]`.
|
128 |
-
output_hidden_states (`bool`, *optional*):
|
129 |
-
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
130 |
-
more detail.
|
131 |
-
return_dict (`bool`, *optional*):
|
132 |
-
If set to `True`, the model will return a [`~utils.FlaxSeq2SeqLMOutput`] instead of a plain tuple.
|
133 |
-
"""
|
134 |
-
|
135 |
-
SPEECH_ENCODER_DECODER_ENCODE_INPUTS_DOCSTRING = r"""
|
136 |
-
Args:
|
137 |
-
inputs (`jnp.ndarray` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, feature_dim)`, *optional*):
|
138 |
-
Float values of input raw speech waveform or speech features. Values can be obtained by loading a *.flac*
|
139 |
-
or *.wav* audio file into an array of type *List[float]* or a *numpy.ndarray*, *e.g.* via the soundfile
|
140 |
-
library (*pip install soundfile*). To prepare the array into *inputs*, either the [`Wav2Vec2Processor`] or
|
141 |
-
[`Speech2TextProcessor`] should be used for padding and conversion into a tensor of type
|
142 |
-
*torch.FloatTensor*.
|
143 |
-
attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
|
144 |
-
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
145 |
-
|
146 |
-
- 1 for tokens that are **not masked**,
|
147 |
-
- 0 for tokens that are **masked**.
|
148 |
-
|
149 |
-
[What are attention masks?](../glossary#attention-mask)
|
150 |
-
output_attentions (`bool`, *optional*):
|
151 |
-
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
152 |
-
tensors for more detail.
|
153 |
-
output_hidden_states (`bool`, *optional*):
|
154 |
-
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
155 |
-
more detail.
|
156 |
-
return_dict (`bool`, *optional*):
|
157 |
-
If set to `True`, the model will return a [`~utils.FlaxBaseModelOutput`] instead of a plain tuple.
|
158 |
-
"""
|
159 |
-
|
160 |
-
SPEECH_ENCODER_DECODER_DECODE_INPUTS_DOCSTRING = r"""
|
161 |
-
Args:
|
162 |
-
decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
|
163 |
-
Indices of decoder input sequence tokens in the vocabulary.
|
164 |
-
|
165 |
-
Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
166 |
-
[`PreTrainedTokenizer.__call__`] for details.
|
167 |
-
|
168 |
-
[What are decoder input IDs?](../glossary#decoder-input-ids)
|
169 |
-
|
170 |
-
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
|
171 |
-
`past_key_values`).
|
172 |
-
|
173 |
-
For sequence to sequence training, `decoder_input_ids` should be provided. `decoder_input_ids` should be
|
174 |
-
created outside of the model by shifting the `labels` to the right, replacing -100 by the `pad_token_id`
|
175 |
-
and prepending them with the `decoder_start_token_id`.
|
176 |
-
encoder_outputs (`tuple(tuple(jnp.ndarray))`):
|
177 |
-
Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
|
178 |
-
`last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` (*optional*) is a sequence of
|
179 |
-
hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
|
180 |
-
encoder_attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
|
181 |
-
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
182 |
-
|
183 |
-
- 1 for tokens that are **not masked**,
|
184 |
-
- 0 for tokens that are **masked**.
|
185 |
-
|
186 |
-
[What are attention masks?](../glossary#attention-mask)
|
187 |
-
decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
|
188 |
-
Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
|
189 |
-
be used by default.
|
190 |
-
decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
|
191 |
-
Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
|
192 |
-
range `[0, config.decoder.max_position_embeddings - 1]`.
|
193 |
-
past_key_values (`Dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`):
|
194 |
-
Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast
|
195 |
-
auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*.
|
196 |
-
output_attentions (`bool`, *optional*):
|
197 |
-
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
198 |
-
tensors for more detail.
|
199 |
-
output_hidden_states (`bool`, *optional*):
|
200 |
-
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
201 |
-
more detail.
|
202 |
-
return_dict (`bool`, *optional*):
|
203 |
-
If set to `True`, the model will return a [`~utils.FlaxCausalLMOutputWithCrossAttentions`] instead of a
|
204 |
-
plain tuple.
|
205 |
-
"""
|
206 |
-
|
207 |
-
@flax.struct.dataclass
|
208 |
-
class FlaxBeamSearchOutput(ModelOutput):
|
209 |
-
"""
|
210 |
-
Flax base class for outputs of generation models using beam search.
|
211 |
-
|
212 |
-
|
213 |
-
Args:
|
214 |
-
sequences (`jnp.ndarray` of shape `(batch_size, max_length)`):
|
215 |
-
The generated sequences.
|
216 |
-
scores (`jnp.ndarray` of shape `(batch_size,)`):
|
217 |
-
The scores (log probabilities) of the generated sequences.
|
218 |
-
"""
|
219 |
-
|
220 |
-
sequences: jnp.ndarray = None
|
221 |
-
scores: jnp.ndarray = None
|
222 |
-
|
223 |
-
|
224 |
-
@flax.struct.dataclass
|
225 |
-
class BeamSearchState:
|
226 |
-
cur_len: jnp.ndarray
|
227 |
-
running_sequences: jnp.ndarray
|
228 |
-
running_scores: jnp.ndarray
|
229 |
-
sequences: jnp.ndarray
|
230 |
-
scores: jnp.ndarray
|
231 |
-
is_sent_finished: jnp.ndarray
|
232 |
-
model_kwargs: Dict[str, jnp.ndarray]
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
|
237 |
-
class FlaxSpeechEncoderDecoderModule(nn.Module):
|
238 |
-
config: SpeechEncoderDecoderConfig
|
239 |
-
dtype: jnp.dtype = jnp.float32
|
240 |
-
|
241 |
-
def setup(self):
|
242 |
-
encoder_config = self.config.encoder
|
243 |
-
decoder_config = self.config.decoder
|
244 |
-
|
245 |
-
# TODO: configure FlaxAutoModel mappings (required when trialling different encoder-decoder combinations)
|
246 |
-
encoder_module = FlaxWav2Vec2Module
|
247 |
-
decoder_module = FlaxBartForCausalLMModule
|
248 |
-
|
249 |
-
self.encoder = encoder_module(encoder_config, dtype=self.dtype)
|
250 |
-
self.decoder = decoder_module(decoder_config, dtype=self.dtype)
|
251 |
-
|
252 |
-
# encoder outputs might need to be projected to different dimension for decoder
|
253 |
-
if (
|
254 |
-
self.encoder.config.hidden_size != self.decoder.config.hidden_size
|
255 |
-
and self.decoder.config.cross_attention_hidden_size is None
|
256 |
-
):
|
257 |
-
self.enc_to_dec_proj = nn.Dense(
|
258 |
-
self.decoder.config.hidden_size,
|
259 |
-
kernel_init=jax.nn.initializers.normal(self.decoder.config.initializer_range),
|
260 |
-
dtype=self.dtype,
|
261 |
-
)
|
262 |
-
else:
|
263 |
-
self.enc_to_dec_proj = None
|
264 |
-
|
265 |
-
def _get_feat_extract_output_lengths(
|
266 |
-
self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
|
267 |
-
):
|
268 |
-
"""
|
269 |
-
Computes the output length of the convolutional layers
|
270 |
-
"""
|
271 |
-
|
272 |
-
add_adapter = self.config.encoder.add_adapter if add_adapter is None else add_adapter
|
273 |
-
|
274 |
-
def _conv_out_length(input_length, kernel_size, stride):
|
275 |
-
# 1D convolutional layer output length formula taken
|
276 |
-
# from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
|
277 |
-
return (input_length - kernel_size) // stride + 1
|
278 |
-
|
279 |
-
for kernel_size, stride in zip(self.config.encoder.conv_kernel, self.config.encoder.conv_stride):
|
280 |
-
input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
|
281 |
-
|
282 |
-
if add_adapter:
|
283 |
-
for _ in range(self.config.encoder.num_adapter_layers):
|
284 |
-
input_lengths = _conv_out_length(input_lengths, 1, self.config.encoder.adapter_stride)
|
285 |
-
|
286 |
-
return input_lengths
|
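A worked example of the output-length recurrence above, assuming wav2vec2's usual seven-layer feature encoder and no adapter (the kernel and stride values below are the customary defaults, stated here as an assumption):

```python
conv_kernel = (10, 3, 3, 3, 3, 2, 2)
conv_stride = (5, 2, 2, 2, 2, 2, 2)


def feat_extract_output_length(input_length: int) -> int:
    # apply (L - kernel) // stride + 1 once per conv layer
    for kernel_size, stride in zip(conv_kernel, conv_stride):
        input_length = (input_length - kernel_size) // stride + 1
    return input_length


# one second of 16 kHz audio collapses to 49 encoder frames (~320x downsampling)
assert feat_extract_output_length(16000) == 49
```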
287 |
-
|
288 |
-
def _get_encoder_module(self):
|
289 |
-
return self.encoder
|
290 |
-
|
291 |
-
def _get_projection_module(self):
|
292 |
-
return self.enc_to_dec_proj
|
293 |
-
|
294 |
-
def _get_decoder_module(self):
|
295 |
-
return self.decoder
|
296 |
-
|
297 |
-
def __call__(
|
298 |
-
self,
|
299 |
-
inputs,
|
300 |
-
attention_mask,
|
301 |
-
decoder_input_ids,
|
302 |
-
decoder_attention_mask,
|
303 |
-
decoder_position_ids,
|
304 |
-
encoder_outputs=None,
|
305 |
-
extract_features=None,
|
306 |
-
output_attentions: bool = False,
|
307 |
-
output_hidden_states: bool = False,
|
308 |
-
output_features: bool = False,
|
309 |
-
return_dict: bool = True,
|
310 |
-
deterministic: bool = True,
|
311 |
-
freeze_feature_encoder: bool = False,
|
312 |
-
):
|
313 |
-
if encoder_outputs is None:
|
314 |
-
encoder_outputs = self.encoder(
|
315 |
-
inputs,
|
316 |
-
attention_mask=attention_mask,
|
317 |
-
extract_features=extract_features,
|
318 |
-
output_attentions=output_attentions,
|
319 |
-
output_hidden_states=output_hidden_states,
|
320 |
-
output_features=output_features,
|
321 |
-
return_dict=return_dict,
|
322 |
-
deterministic=deterministic,
|
323 |
-
freeze_feature_encoder=freeze_feature_encoder,
|
324 |
-
)
|
325 |
-
|
326 |
-
if output_features:
|
327 |
-
return encoder_outputs
|
328 |
-
|
329 |
-
encoder_hidden_states = encoder_outputs[0]
|
330 |
-
|
331 |
-
# optionally project encoder_hidden_states
|
332 |
-
if self.enc_to_dec_proj is not None:
|
333 |
-
encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
|
334 |
-
|
335 |
-
# compute correct encoder attention mask
|
336 |
-
if attention_mask is not None:
|
337 |
-
encoder_attention_mask = self.encoder._get_feature_vector_attention_mask(
|
338 |
-
encoder_hidden_states.shape[1], attention_mask
|
339 |
-
)
|
340 |
-
else:
|
341 |
-
encoder_attention_mask = None
|
342 |
-
|
343 |
-
# flax script modeling_flax_wav2vec2.py
|
344 |
-
decoder_outputs = self.decoder(
|
345 |
-
input_ids=decoder_input_ids,
|
346 |
-
attention_mask=decoder_attention_mask,
|
347 |
-
position_ids=decoder_position_ids,
|
348 |
-
encoder_hidden_states=encoder_hidden_states,
|
349 |
-
encoder_attention_mask=encoder_attention_mask,
|
350 |
-
output_attentions=output_attentions,
|
351 |
-
output_hidden_states=output_hidden_states,
|
352 |
-
return_dict=return_dict,
|
353 |
-
deterministic=deterministic,
|
354 |
-
)
|
355 |
-
|
356 |
-
if not return_dict:
|
357 |
-
return decoder_outputs + encoder_outputs
|
358 |
-
|
359 |
-
return FlaxSeq2SeqLMOutput(
|
360 |
-
logits=decoder_outputs.logits,
|
361 |
-
decoder_hidden_states=decoder_outputs.hidden_states,
|
362 |
-
decoder_attentions=decoder_outputs.attentions,
|
363 |
-
cross_attentions=decoder_outputs.cross_attentions,
|
364 |
-
encoder_last_hidden_state=encoder_hidden_states,
|
365 |
-
encoder_hidden_states=encoder_outputs.hidden_states,
|
366 |
-
encoder_attentions=encoder_outputs.attentions,
|
367 |
-
)
|
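To make the projection rule in `setup` concrete, here is the condition it encodes, with illustrative sizes (a sketch, not taken from any specific checkpoint pairing):

```python
# a dense encoder-to-decoder projection is created only when the hidden sizes
# differ AND the decoder does not pin cross_attention_hidden_size itself
encoder_hidden_size = 1024        # e.g. a large speech encoder
decoder_hidden_size = 768         # e.g. a base-sized text decoder
cross_attention_hidden_size = None

needs_projection = (
    encoder_hidden_size != decoder_hidden_size
    and cross_attention_hidden_size is None
)
assert needs_projection  # encoder states get projected to 768 before cross-attention
```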
368 |
-
|
369 |
-
|
370 |
-
@add_start_docstrings(SPEECH_ENCODER_DECODER_START_DOCSTRING)
|
371 |
-
class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel):
|
372 |
-
r"""
|
373 |
-
[`FlaxSpeechEncoderDecoderModel`] is a generic model class that will be instantiated as a transformer architecture
|
374 |
-
with the module (flax.nn.Module) of one of the base model classes of the library as encoder module and another one
|
375 |
-
as decoder module when created with the [`~transformers.FlaxAutoModel.from_pretrained`] class method for the
|
376 |
-
encoder and the [`~transformers.FlaxAutoModelForCausalLM.from_pretrained`] class method for the decoder.
|
377 |
-
"""
|
378 |
-
|
379 |
-
config_class = SpeechEncoderDecoderConfig
|
380 |
-
base_model_prefix: str = "speech_encoder_decoder"
|
381 |
-
module_class = FlaxSpeechEncoderDecoderModule
|
382 |
-
|
383 |
-
def __init__(
|
384 |
-
self,
|
385 |
-
config: SpeechEncoderDecoderConfig,
|
386 |
-
input_shape: Optional[Tuple] = None,
|
387 |
-
seed: int = 0,
|
388 |
-
dtype: jnp.dtype = jnp.float32,
|
389 |
-
_do_init: bool = True,
|
390 |
-
**kwargs
|
391 |
-
):
|
392 |
-
|
393 |
-
if not _do_init:
|
394 |
-
raise ValueError(
|
395 |
-
"`FlaxSpeechEncoderDecoderModel` cannot be created without initializing, `_do_init` must be `True`."
|
396 |
-
)
|
397 |
-
|
398 |
-
if config.decoder.cross_attention_hidden_size is not None:
|
399 |
-
# Raise ValueError or option to project enc to dec hidden_size (eg EncAdapterLayer)
|
400 |
-
if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
|
401 |
-
raise ValueError(
|
402 |
-
"If `cross_attention_hidden_size` is specified in the decoder's configuration, "
|
403 |
-
"it has to be equal to the encoder's `hidden_size`. "
|
404 |
-
f"Got {config.decoder.cross_attention_hidden_size} for `config.decoder.cross_attention_hidden_size` "
|
405 |
-
f"and {config.encoder.hidden_size} for `config.encoder.hidden_size`."
|
406 |
-
)
|
407 |
-
|
408 |
-
# make sure input & output embeddings are not tied
|
409 |
-
config.tie_word_embeddings = False
|
410 |
-
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
411 |
-
|
412 |
-
if input_shape is None:
|
413 |
-
# speech encoders almost always downsample the sequence length dimension
|
414 |
-
encoder_input_length = 1024
|
415 |
-
decoder_input_length = module._get_feat_extract_output_lengths(encoder_input_length)
|
416 |
-
input_shape = ((1, encoder_input_length), (1, decoder_input_length))
|
417 |
-
|
418 |
-
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
419 |
-
|
420 |
-
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
421 |
-
encoder_input_shape, decoder_input_shape = input_shape
|
422 |
-
|
423 |
-
# init input DeviceArrays
|
424 |
-
inputs = jnp.zeros(encoder_input_shape, dtype="f4")
|
425 |
-
attention_mask = jnp.ones_like(inputs, dtype="i4")
|
426 |
-
decoder_input_ids = jnp.zeros(decoder_input_shape, dtype="i4")
|
427 |
-
decoder_attention_mask = jnp.ones_like(decoder_input_ids)
|
428 |
-
|
429 |
-
batch_size, sequence_length = inputs.shape
|
430 |
-
|
431 |
-
decoder_batch_size, decoder_sequence_length = decoder_input_ids.shape
|
432 |
-
if not decoder_batch_size == batch_size:
|
433 |
-
raise ValueError(
|
434 |
-
f"The inputs of encoder and decoder should have the same batch size, but got {batch_size} for encoder and {decoder_batch_size} for decoder."
|
435 |
-
)
|
436 |
-
decoder_position_ids = jnp.broadcast_to(
|
437 |
-
jnp.arange(decoder_sequence_length)[None, :], (decoder_batch_size, decoder_sequence_length)
|
438 |
-
)
|
439 |
-
|
440 |
-
params_rng, dropout_rng = jax.random.split(rng)
|
441 |
-
rngs = {"params": params_rng, "dropout": dropout_rng}
|
442 |
-
|
443 |
-
return self.module.init(
|
444 |
-
rngs,
|
445 |
-
inputs,
|
446 |
-
attention_mask,
|
447 |
-
decoder_input_ids,
|
448 |
-
decoder_attention_mask,
|
449 |
-
decoder_position_ids,
|
450 |
-
)["params"]
|
451 |
-
|
452 |
-
def init_cache(self, batch_size, max_length, encoder_outputs):
|
453 |
-
r"""
|
454 |
-
Args:
|
455 |
-
batch_size (`int`):
|
456 |
-
batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
|
457 |
-
max_length (`int`):
|
458 |
-
maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
|
459 |
-
cache.
|
460 |
-
encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray))]`):
|
461 |
-
`encoder_outputs` consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*:
|
462 |
-
`attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` (*optional*)
|
463 |
-
is a sequence of hidden-states at the output of the last layer of the encoder. Used in the
|
464 |
-
cross-attention of the decoder.
|
465 |
-
"""
|
466 |
-
# init input variables to retrieve cache
|
467 |
-
decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4")
|
468 |
-
decoder_attention_mask = jnp.ones_like(decoder_input_ids)
|
469 |
-
decoder_position_ids = jnp.broadcast_to(
|
470 |
-
jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]), decoder_input_ids.shape
|
471 |
-
)
|
472 |
-
|
473 |
-
def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs):
|
474 |
-
decoder_module = module._get_decoder_module()
|
475 |
-
return decoder_module(
|
476 |
-
input_ids=decoder_input_ids,
|
477 |
-
attention_mask=decoder_attention_mask,
|
478 |
-
position_ids=decoder_position_ids,
|
479 |
-
**kwargs,
|
480 |
-
)
|
481 |
-
|
482 |
-
init_variables = self.module.init(
|
483 |
-
jax.random.PRNGKey(0),
|
484 |
-
decoder_input_ids=decoder_input_ids,
|
485 |
-
decoder_attention_mask=decoder_attention_mask,
|
486 |
-
decoder_position_ids=decoder_position_ids,
|
487 |
-
encoder_hidden_states=encoder_outputs[0],
|
488 |
-
init_cache=True,
|
489 |
-
method=_decoder_forward, # we only need to call the decoder to init the cache
|
490 |
-
)
|
491 |
-
return unfreeze(init_variables["cache"])
|
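A hedged usage sketch of this cache plumbing, assuming `model` is a loaded `FlaxSpeechEncoderDecoderModel` with `decoder_start_token_id` set on its config; note that `decode` insists on explicit `decoder_position_ids` whenever `past_key_values` is supplied (enforced further below):

```python
import jax.numpy as jnp

inputs = jnp.ones((2, 5000), dtype=jnp.float32)
encoder_outputs = model.encode(inputs)

# pre-allocate the decoder cache against the encoder states
cache = model.init_cache(batch_size=2, max_length=32, encoder_outputs=encoder_outputs)

decoder_input_ids = jnp.full((2, 1), model.config.decoder_start_token_id, dtype="i4")
decoder_position_ids = jnp.zeros((2, 1), dtype="i4")

outputs = model.decode(
    decoder_input_ids,
    encoder_outputs,
    past_key_values=cache,
    decoder_position_ids=decoder_position_ids,
)
next_token_logits = outputs.logits[:, -1]
updated_cache = outputs.past_key_values  # fed back on the next decoding step
```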
492 |
-
|
493 |
-
def _get_feat_extract_output_lengths(
|
494 |
-
self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
|
495 |
-
):
|
496 |
-
return self.module._get_feat_extract_output_lengths(input_lengths, add_adapter=add_adapter)
|
497 |
-
|
498 |
-
@add_start_docstrings(SPEECH_ENCODER_DECODER_ENCODE_INPUTS_DOCSTRING)
|
499 |
-
@replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=_CONFIG_FOR_DOC)
|
500 |
-
def encode(
|
501 |
-
self,
|
502 |
-
inputs: jnp.ndarray,
|
503 |
-
attention_mask: Optional[jnp.ndarray] = None,
|
504 |
-
extract_features: Optional[jnp.ndarray] = None,
|
505 |
-
output_attentions: Optional[bool] = None,
|
506 |
-
output_hidden_states: Optional[bool] = None,
|
507 |
-
output_features: Optional[bool] = None,
|
508 |
-
return_dict: Optional[bool] = None,
|
509 |
-
train: bool = False,
|
510 |
-
freeze_feature_encoder: bool = False,
|
511 |
-
params: dict = None,
|
512 |
-
dropout_rng: PRNGKey = None,
|
513 |
-
):
|
514 |
-
r"""
|
515 |
-
Returns:
|
516 |
-
|
517 |
-
Example:
|
518 |
-
|
519 |
-
```python
|
520 |
-
>>> from transformers import FlaxSpeechEncoderDecoderModel
|
521 |
-
|
522 |
-
>>> # initialize a wav2vec2-2-bart from pretrained wav2vec2 and bart models. Note that the cross-attention layers will be randomly initialized
|
523 |
-
>>> model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
|
524 |
-
... "facebook/wav2vec2-large-lv60", "facebook/bart-large"
|
525 |
-
... )
|
526 |
-
|
527 |
-
>>> inputs = jnp.ones((2, 5000), dtype=jnp.float32)
|
528 |
-
>>> encoder_outputs = model.encode(inputs)
|
529 |
-
```"""
|
530 |
-
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
531 |
-
output_hidden_states = (
|
532 |
-
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
533 |
-
)
|
534 |
-
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
535 |
-
|
536 |
-
if attention_mask is None:
|
537 |
-
attention_mask = jnp.ones_like(inputs, dtype="i4")
|
538 |
-
|
539 |
-
if extract_features is not None:
|
540 |
-
extract_features = jnp.array(extract_features, dtype="f4")
|
541 |
-
|
542 |
-
# Handle any PRNG if needed
|
543 |
-
rngs = {}
|
544 |
-
if dropout_rng is not None:
|
545 |
-
rngs["dropout"] = dropout_rng
|
546 |
-
|
547 |
-
def _encoder_forward(module, inputs, attention_mask, **kwargs):
|
548 |
-
encode_module = module._get_encoder_module()
|
549 |
-
return encode_module(inputs, attention_mask, **kwargs)
|
550 |
-
|
551 |
-
outputs = self.module.apply(
|
552 |
-
{"params": params or self.params},
|
553 |
-
inputs=jnp.array(inputs, dtype="f4"),
|
554 |
-
attention_mask=jnp.array(attention_mask, dtype="i4"),
|
555 |
-
extract_features=extract_features,
|
556 |
-
output_attentions=output_attentions,
|
557 |
-
output_hidden_states=output_hidden_states,
|
558 |
-
output_features=output_features,
|
559 |
-
return_dict=return_dict,
|
560 |
-
deterministic=not train,
|
561 |
-
freeze_feature_encoder=freeze_feature_encoder,
|
562 |
-
rngs=rngs,
|
563 |
-
method=_encoder_forward,
|
564 |
-
)
|
565 |
-
|
566 |
-
if return_dict and not output_features:
|
567 |
-
outputs = FlaxBaseModelOutput(
|
568 |
-
last_hidden_state=outputs.last_hidden_state,
|
569 |
-
hidden_states=outputs.hidden_states,
|
570 |
-
attentions=outputs.attentions,
|
571 |
-
)
|
572 |
-
|
573 |
-
return outputs
|
574 |
-
|
575 |
-
@add_start_docstrings(SPEECH_ENCODER_DECODER_DECODE_INPUTS_DOCSTRING)
|
576 |
-
@replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
|
577 |
-
def decode(
|
578 |
-
self,
|
579 |
-
decoder_input_ids,
|
580 |
-
encoder_outputs,
|
581 |
-
encoder_attention_mask: Optional[jnp.ndarray] = None,
|
582 |
-
decoder_attention_mask: Optional[jnp.ndarray] = None,
|
583 |
-
decoder_position_ids: Optional[jnp.ndarray] = None,
|
584 |
-
past_key_values: dict = None,
|
585 |
-
output_attentions: Optional[bool] = None,
|
586 |
-
output_hidden_states: Optional[bool] = None,
|
587 |
-
return_dict: Optional[bool] = None,
|
588 |
-
train: bool = False,
|
589 |
-
params: dict = None,
|
590 |
-
dropout_rng: PRNGKey = None,
|
591 |
-
):
|
592 |
-
r"""
|
593 |
-
Returns:
|
594 |
-
|
595 |
-
Example:
|
596 |
-
|
597 |
-
```python
|
598 |
-
>>> from transformers import FlaxSpeechEncoderDecoderModel
|
599 |
-
>>> import jax.numpy as jnp
|
600 |
-
|
601 |
-
>>> # initialize a wav2vec2-2-bart from pretrained wav2vec2 and bart models. Note that the cross-attention layers will be randomly initialized
|
602 |
-
>>> model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
|
603 |
-
... "facebook/wav2vec2-large-lv60", "facebook/bart-large"
|
604 |
-
... )
|
605 |
-
|
606 |
-
>>> inputs = jnp.ones((2, 5000), dtype=jnp.float32)
|
607 |
-
>>> encoder_outputs = model.encode(inputs)
|
608 |
-
|
609 |
-
>>> decoder_start_token_id = model.config.decoder.bos_token_id
|
610 |
-
>>> decoder_input_ids = jnp.ones((inputs.shape[0], 1), dtype="i4") * decoder_start_token_id
|
611 |
-
|
612 |
-
>>> outputs = model.decode(decoder_input_ids, encoder_outputs)
|
613 |
-
>>> logits = outputs.logits
|
614 |
-
```"""
|
615 |
-
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
616 |
-
output_hidden_states = (
|
617 |
-
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
618 |
-
)
|
619 |
-
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
620 |
-
|
621 |
-
encoder_hidden_states = encoder_outputs[0]
|
622 |
-
if encoder_attention_mask is None:
|
623 |
-
batch_size, sequence_length = encoder_hidden_states.shape[:2]
|
624 |
-
encoder_attention_mask = jnp.ones((batch_size, sequence_length))
|
625 |
-
|
626 |
-
batch_size, sequence_length = decoder_input_ids.shape
|
627 |
-
if decoder_attention_mask is None:
|
628 |
-
decoder_attention_mask = jnp.ones((batch_size, sequence_length))
|
629 |
-
|
630 |
-
if decoder_position_ids is None:
|
631 |
-
if past_key_values is not None:
|
632 |
-
raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.")
|
633 |
-
|
634 |
-
decoder_position_ids = jnp.broadcast_to(
|
635 |
-
jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
|
636 |
-
)
|
637 |
-
|
638 |
-
# Handle any PRNG if needed
|
639 |
-
rngs = {}
|
640 |
-
if dropout_rng is not None:
|
641 |
-
rngs["dropout"] = dropout_rng
|
642 |
-
|
643 |
-
params = {"params": params or self.params}
|
644 |
-
|
645 |
-
# if past_key_values are passed then the cache is already initialized; a private flag init_cache has to be
|
646 |
-
# passed down to ensure the cache is used. The cache has to be marked as mutable so that
|
647 |
-
# it can be changed by FlaxBartAttention module
|
648 |
-
if past_key_values:
|
649 |
-
params["cache"] = past_key_values
|
650 |
-
mutable = ["cache"]
|
651 |
-
else:
|
652 |
-
mutable = False
|
653 |
-
|
654 |
-
def _decoder_forward(
|
655 |
-
module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, encoder_hidden_states, **kwargs
|
656 |
-
):
|
657 |
-
|
658 |
-
projection_module = module._get_projection_module()
|
659 |
-
decoder_module = module._get_decoder_module()
|
660 |
-
|
661 |
-
# optionally project encoder_hidden_states
|
662 |
-
if projection_module is not None:
|
663 |
-
encoder_hidden_states = projection_module(encoder_hidden_states)
|
664 |
-
|
665 |
-
return decoder_module(
|
666 |
-
decoder_input_ids,
|
667 |
-
decoder_attention_mask,
|
668 |
-
decoder_position_ids,
|
669 |
-
encoder_hidden_states,
|
670 |
-
**kwargs,
|
671 |
-
)
|
672 |
-
|
673 |
-
outputs = self.module.apply(
|
674 |
-
params,
|
675 |
-
decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
|
676 |
-
decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
|
677 |
-
decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
|
678 |
-
encoder_hidden_states=encoder_hidden_states,
|
679 |
-
encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
|
680 |
-
output_attentions=output_attentions,
|
681 |
-
output_hidden_states=output_hidden_states,
|
682 |
-
return_dict=return_dict,
|
683 |
-
deterministic=not train,
|
684 |
-
rngs=rngs,
|
685 |
-
mutable=mutable,
|
686 |
-
method=_decoder_forward,
|
687 |
-
)
|
688 |
-
|
689 |
-
# add updated cache to model output
|
690 |
-
if past_key_values is not None and return_dict:
|
691 |
-
outputs, past = outputs
|
692 |
-
outputs["past_key_values"] = unfreeze(past["cache"])
|
693 |
-
return outputs
|
694 |
-
elif past_key_values is not None and not return_dict:
|
695 |
-
outputs, past = outputs
|
696 |
-
outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
|
697 |
-
|
698 |
-
return outputs
|
699 |
-
|
700 |
-
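For orientation, a minimal sketch of how `decode` is typically driven during inference, assuming a loaded `FlaxSpeechEncoderDecoderModel` instance `model` whose `encode` method behaves as in the upstream class (variable names and shapes here are illustrative, not part of the file):

import jax.numpy as jnp

# assumed: `model` is a loaded FlaxSpeechEncoderDecoderModel
inputs = jnp.ones((1, 16000), dtype=jnp.float32)  # ~1 s of 16 kHz audio
encoder_outputs = model.encode(inputs)            # run the speech encoder once
decoder_input_ids = jnp.array([[model.config.decoder_start_token_id]], dtype=jnp.int32)

# cross-attend over the encoded audio and read off next-token logits
decoder_outputs = model.decode(decoder_input_ids, encoder_outputs)
next_token_logits = decoder_outputs.logits[:, -1]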
    @add_start_docstrings_to_model_forward(SPEECH_ENCODER_DECODER_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
    def __call__(
        self,
        inputs: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        extract_features: Optional[jnp.ndarray] = None,
        decoder_input_ids: Optional[jnp.ndarray] = None,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_position_ids: Optional[jnp.ndarray] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_features: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        freeze_feature_encoder: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        r"""
        Returns:

        Examples:

        ```python
        >>> from transformers import FlaxSpeechEncoderDecoderModel, BartTokenizer
        >>> import jax.numpy as jnp

        >>> # load a fine-tuned wav2vec2-2-bart model
        >>> model = FlaxSpeechEncoderDecoderModel.from_pretrained("patrickvonplaten/wav2vec2-2-bart-large")
        >>> # load output tokenizer
        >>> tokenizer_output = BartTokenizer.from_pretrained("facebook/bart-large")

        >>> inputs = jnp.ones((2, 5000), dtype=jnp.float32)

        >>> # use bart's special bos, pad and eos tokens
        >>> model.config.decoder_start_token_id = model.decoder.config.bos_token_id
        >>> model.config.pad_token_id = model.decoder.config.pad_token_id
        >>> model.config.eos_token_id = model.decoder.config.eos_token_id

        >>> outputs = model.generate(inputs)
        ```
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # prepare encoder inputs
        if attention_mask is None:
            attention_mask = jnp.ones_like(inputs, dtype="i4")

        if extract_features is not None:
            inputs = None  # we can omit passing the inputs to the model to save memory
            extract_features = jnp.array(extract_features, dtype="f4")
        else:
            inputs = jnp.array(inputs, dtype="f4")

        # prepare decoder inputs
        if decoder_input_ids is None:
            raise ValueError(
                "`decoder_input_ids` cannot be `None`. For sequence-to-sequence training, `decoder_input_ids` must be specified as an input argument."
            )
        if decoder_attention_mask is None:
            decoder_attention_mask = jnp.ones_like(decoder_input_ids)
        if decoder_position_ids is None:
            batch_size, sequence_length = decoder_input_ids.shape
            decoder_position_ids = jnp.broadcast_to(
                jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
            )

        # Handle any PRNG if needed
        rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}

        return self.module.apply(
            {"params": params or self.params},
            inputs=inputs,
            attention_mask=jnp.array(attention_mask, dtype="i4"),
            extract_features=extract_features,
            decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
            decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
            decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            output_features=output_features,
            return_dict=return_dict,
            deterministic=not train,
            freeze_feature_encoder=freeze_feature_encoder,
            rngs=rngs,
        )
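The `extract_features` branch above allows pre-computed convolutional features to be passed in place of the raw waveform to save memory. A hedged sketch of a training-style forward pass (shapes illustrative; `model` as in the docstring example):

import jax
import jax.numpy as jnp

inputs = jnp.ones((2, 16000), dtype=jnp.float32)
decoder_input_ids = jnp.ones((2, 8), dtype=jnp.int32)

outputs = model(
    inputs,
    decoder_input_ids=decoder_input_ids,
    freeze_feature_encoder=True,        # stop gradients through the conv feature encoder
    train=True,
    dropout_rng=jax.random.PRNGKey(0),  # needed when train=True and dropout is active
)
logits = outputs.logits                 # (2, 8, decoder_vocab_size)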
    def prepare_inputs_for_generation(
        self,
        decoder_input_ids,
        max_length,
        attention_mask: Optional[jnp.DeviceArray] = None,
        decoder_attention_mask: Optional[jnp.DeviceArray] = None,
        encoder_outputs=None,
        **kwargs
    ):
        # initializing the cache
        batch_size, seq_length = decoder_input_ids.shape

        past_key_values = self.init_cache(batch_size, max_length, encoder_outputs)
        # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
        # But since the decoder uses a causal mask, those positions are masked anyway.
        # Thus, we can create a single static attention_mask here, which is more efficient for compilation.
        extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
        if decoder_attention_mask is not None:
            decoder_position_ids = decoder_attention_mask.cumsum(axis=-1) - 1
            extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0))
        else:
            decoder_position_ids = jnp.broadcast_to(
                jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)
            )

        return {
            "past_key_values": past_key_values,
            "encoder_outputs": encoder_outputs,
            "encoder_attention_mask": attention_mask,
            "decoder_attention_mask": extended_attention_mask,
            "decoder_position_ids": decoder_position_ids,
        }

    def update_inputs_for_generation(self, model_outputs, model_kwargs):
        model_kwargs["past_key_values"] = model_outputs.past_key_values
        model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1
        return model_kwargs
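These two hooks are what `FlaxGenerationMixin`'s loops call between decoder steps: `prepare_inputs_for_generation` initializes the static cache and masks once, and `update_inputs_for_generation` advances them. Schematically (a simplified greedy loop, not the library's actual implementation; assumes a one-token prompt and `model`/`encoder_outputs` as above):

import jax.numpy as jnp

decoder_input_ids = jnp.full((1, 1), model.config.decoder_start_token_id, dtype=jnp.int32)
model_kwargs = model.prepare_inputs_for_generation(
    decoder_input_ids, max_length=32, encoder_outputs=encoder_outputs
)

sequences = decoder_input_ids
for _ in range(31):
    outputs = model.decode(sequences[:, -1:], **model_kwargs)         # one step against the cache
    next_token = jnp.argmax(outputs.logits[:, -1], axis=-1)[:, None]  # greedy pick
    sequences = jnp.concatenate([sequences, next_token], axis=-1)
    model_kwargs = model.update_inputs_for_generation(outputs, model_kwargs)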
    @classmethod
    def from_encoder_decoder_pretrained(
        cls,
        encoder_pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
        decoder_pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
        *model_args,
        **kwargs
    ) -> FlaxPreTrainedModel:
        r"""
        Instantiate an encoder and a decoder from one or two base classes of the library from pretrained model
        checkpoints.

        Params:
            encoder_pretrained_model_name_or_path (`Union[str, os.PathLike]`, *optional*):
                Information necessary to initiate the encoder. Can be either:

                - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
                  Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a
                  user or organization name, like `dbmdz/bert-base-german-cased`.
                - A path to a *directory* containing model weights saved using
                  [`~FlaxPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.

            decoder_pretrained_model_name_or_path (`Union[str, os.PathLike]`, *optional*, defaults to `None`):
                Information necessary to initiate the decoder. Can be either:

                - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
                  Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a
                  user or organization name, like `dbmdz/bert-base-german-cased`.
                - A path to a *directory* containing model weights saved using
                  [`~FlaxPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.

            model_args (remaining positional arguments, *optional*):
                All remaining positional arguments will be passed to the underlying model's `__init__` method.

            kwargs (remaining dictionary of keyword arguments, *optional*):
                Can be used to update the configuration object (after it has been loaded) and initiate the model
                (e.g., `output_attentions=True`).

                - To update the encoder configuration, use the prefix *encoder_* for each configuration parameter.
                - To update the decoder configuration, use the prefix *decoder_* for each configuration parameter.
                - To update the parent model configuration, do not use a prefix for each configuration parameter.

                Behaves differently depending on whether a `config` is provided or automatically loaded.

        Example:

        ```python
        >>> from transformers import FlaxSpeechEncoderDecoderModel

        >>> # initialize a wav2vec2-2-bart from pretrained wav2vec2 and bart models. Note that the cross-attention layers will be randomly initialized
        >>> model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
        ...     "facebook/wav2vec2-large-lv60", "facebook/bart-large"
        ... )
        >>> # saving model after fine-tuning
        >>> model.save_pretrained("./wav2vec2-2-bart-large")
        >>> # load fine-tuned model
        >>> model = FlaxSpeechEncoderDecoderModel.from_pretrained("./wav2vec2-2-bart-large")
        ```"""

        kwargs_encoder = {
            argument[len("encoder_") :]: value for argument, value in kwargs.items() if argument.startswith("encoder_")
        }

        kwargs_decoder = {
            argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
        }

        # remove encoder, decoder kwargs from kwargs
        for key in kwargs_encoder.keys():
            del kwargs["encoder_" + key]
        for key in kwargs_decoder.keys():
            del kwargs["decoder_" + key]

        # Load and initialize the encoder and decoder
        # The distinction between encoder and decoder at the model level is made
        # by the value of the flag `is_decoder` that we need to set correctly.
        encoder = kwargs_encoder.pop("model", None)
        if encoder is None:
            if encoder_pretrained_model_name_or_path is None:
                raise ValueError(
                    "If `encoder_model` is not defined as an argument, an `encoder_pretrained_model_name_or_path` has "
                    "to be defined."
                )

            if "config" not in kwargs_encoder:
                # TODO: AutoConfig .from_pretrained
                encoder_config, kwargs_encoder = Wav2Vec2Config.from_pretrained(
                    encoder_pretrained_model_name_or_path, **kwargs_encoder, return_unused_kwargs=True
                )
                if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
                    logger.info(
                        f"Initializing {encoder_pretrained_model_name_or_path} as an encoder model "
                        "from a decoder model. Cross-attention and causal mask are disabled."
                    )
                    encoder_config.is_decoder = False
                    encoder_config.add_cross_attention = False

                kwargs_encoder["config"] = encoder_config

            # TODO: FlaxAutoModel .from_pretrained
            encoder = FlaxWav2Vec2Model.from_pretrained(
                encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder
            )

        decoder = kwargs_decoder.pop("model", None)
        if decoder is None:
            if decoder_pretrained_model_name_or_path is None:
                raise ValueError(
                    "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has "
                    "to be defined."
                )

            if "config" not in kwargs_decoder:
                # TODO: AutoConfig .from_pretrained
                decoder_config, kwargs_decoder = BartConfig.from_pretrained(
                    decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True
                )
                if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
                    logger.info(
                        f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. "
                        f"Cross-attention layers are added to {decoder_pretrained_model_name_or_path} "
                        f"and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for "
                        "cross-attention layers."
                    )
                    decoder_config.is_decoder = True
                    decoder_config.add_cross_attention = True

                kwargs_decoder["config"] = decoder_config

            if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False:
                logger.warning(
                    f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. "
                    f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, "
                    "make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` "
                    "passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a "
                    "`decoder_config` to `.from_encoder_decoder_pretrained(...)`"
                )

            # TODO: FlaxAutoModelForCausalLM .from_pretrained
            decoder = FlaxBartForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)

        # instantiate config with corresponding kwargs
        dtype = kwargs.pop("dtype", jnp.float32)
        config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs)

        # make sure input & output word embeddings are not tied
        config.tie_word_embeddings = False

        # init model
        model = cls(config, dtype=dtype)
        model.params["encoder"] = encoder.params
        model.params["decoder"] = decoder.params

        return model
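As the docstring describes, `encoder_`/`decoder_`-prefixed kwargs are stripped and routed to the respective sub-config before loading, while unprefixed kwargs apply to the parent model. A hedged illustration (the specific config fields are assumptions about the underlying `Wav2Vec2Config`/`BartConfig`):

model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
    "facebook/wav2vec2-large-lv60",
    "facebook/bart-large",
    encoder_feat_proj_dropout=0.0,  # forwarded to the Wav2Vec2Config as `feat_proj_dropout`
    decoder_dropout=0.1,            # forwarded to the BartConfig as `dropout`
)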
    def _beam_search(
        self,
        input_ids: None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        length_penalty: Optional[float] = None,
        early_stopping: Optional[bool] = None,
        logits_processor: Optional[FlaxLogitsProcessorList] = None,
        trace: bool = True,
        params: Optional[Dict[str, jnp.ndarray]] = None,
        model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
    ):
        """
        This beam search function is heavily inspired by Flax's official example:
        https://github.com/google/flax/blob/master/examples/wmt/train.py#L254
        """

        def flatten_beam_dim(tensor):
            """Flattens the first two dimensions of a non-scalar array."""
            # ignore scalars (e.g. cache index)
            if tensor.ndim == 0 or tensor.ndim == 1:
                return tensor
            elif tensor.ndim == 6:
                return tensor.reshape(tensor.shape[:1] + (tensor.shape[1] * tensor.shape[2],) + tensor.shape[3:])
            return tensor.reshape((tensor.shape[0] * tensor.shape[1],) + tensor.shape[2:])

        def unflatten_beam_dim(tensor, batch_size, num_beams):
            """Unflattens the first, flat batch*beam dimension of a non-scalar array."""
            # ignore scalars (e.g. cache index)
            if tensor.ndim == 0 or tensor.ndim == 1:
                return tensor
            if tensor.ndim == 5:
                return tensor.reshape(tensor.shape[:1] + (batch_size, num_beams) + tensor.shape[2:])
            return tensor.reshape((batch_size, num_beams) + tensor.shape[1:])

        def gather_beams(nested, beam_indices, batch_size, new_num_beams):
            """
            Gathers the beam slices indexed by beam_indices into new beam array.
            """
            batch_indices = jnp.reshape(
                jnp.arange(batch_size * new_num_beams) // new_num_beams, (batch_size, new_num_beams)
            )

            def gather_fn(tensor):
                # ignore scalars (e.g. cache index)
                if tensor.ndim == 0 or tensor.ndim == 1:
                    return tensor
                if tensor.ndim == 6:
                    return tensor[:, batch_indices, beam_indices]
                return tensor[batch_indices, beam_indices]

            return jax.tree_map(gather_fn, nested)
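        # Editor's note (illustrative, not part of the original file): with batch_size=2 and
        # num_beams=3, a decoder activation of shape (2, 3, seq, dim) is flattened to
        # (6, seq, dim) before each decoder call and unflattened back afterwards:
        #     flatten_beam_dim(jnp.zeros((2, 3, 10, 8))).shape       -> (6, 10, 8)
        #     unflatten_beam_dim(jnp.zeros((6, 10, 8)), 2, 3).shape  -> (2, 3, 10, 8)
        # The ndim == 5 / ndim == 6 special cases cover scanned attention-cache arrays,
        # which carry an extra leading layer axis.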
        # init values
        max_length = max_length if max_length is not None else self.config.max_length
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
        length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
        early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping

        batch_size, num_beams, cur_len = input_ids.shape

        eos_token_id = jnp.array(eos_token_id)
        pad_token_id = jnp.array(pad_token_id)
        cur_len = jnp.array(cur_len)

        # per batch,beam-item holding current token in loop.
        sequences = jnp.full((batch_size, num_beams, max_length), pad_token_id, dtype=jnp.int32)
        running_sequences = jnp.full((batch_size, num_beams, max_length), pad_token_id, dtype=jnp.int32)
        running_sequences = lax.dynamic_update_slice(sequences, input_ids, (0, 0, 0))

        # per batch,beam-item state bit indicating if sentence has finished.
        is_sent_finished = jnp.zeros((batch_size, num_beams), dtype=jnp.bool_)

        # per batch,beam-item score, logprobs
        running_scores = jnp.tile(jnp.array([0.0] + [np.array(-1.0e7)] * (num_beams - 1)), [batch_size, 1])
        scores = jnp.ones((batch_size, num_beams)) * np.array(-1.0e7)

        # For Seq2Seq generation, we only need to use the decoder instead of the whole model in generation loop
        # and pass it the `encoder_outputs`, which are part of the `model_kwargs`.
        model = self.decode if self.config.is_encoder_decoder else self

        # flatten beam dim
        if "encoder_outputs" in model_kwargs:
            model_kwargs["encoder_outputs"]["last_hidden_state"] = flatten_beam_dim(
                model_kwargs["encoder_outputs"]["last_hidden_state"]
            )
        if "attention_mask" in model_kwargs:
            model_kwargs["attention_mask"] = flatten_beam_dim(model_kwargs["attention_mask"])

        # initialize model specific kwargs
        model_kwargs = self.prepare_inputs_for_generation(flatten_beam_dim(input_ids), max_length, **model_kwargs)

        # initialize state
        state = BeamSearchState(
            cur_len=cur_len,
            running_sequences=running_sequences,
            running_scores=running_scores,
            sequences=sequences,
            scores=scores,
            is_sent_finished=is_sent_finished,
            model_kwargs=model_kwargs,
        )

        def beam_search_cond_fn(state):
            """beam search state termination condition fn."""

            # 1. is less than max length?
            not_max_length_yet = state.cur_len < max_length

            # 2. can the new beams still improve?
            best_running_score = state.running_scores[:, -1:] / (max_length**length_penalty)
            worst_finished_score = jnp.where(
                state.is_sent_finished, jnp.min(state.scores, axis=1, keepdims=True), np.array(-1.0e7)
            )
            improvement_still_possible = jnp.all(worst_finished_score < best_running_score)

            # 3. is there still a beam that has not finished?
            still_open_beam = ~(jnp.all(state.is_sent_finished) & early_stopping)

            return not_max_length_yet & still_open_beam & improvement_still_possible

        def beam_search_body_fn(state, input_ids_length=1):
            """beam search state update fn."""
            # 1. Forward current tokens
            # Collect the current position slice along length to feed the fast
            # autoregressive decoder model. Flatten the beam dimension into batch
            # dimension for feeding into the model.
            input_token = flatten_beam_dim(
                lax.dynamic_slice(
                    state.running_sequences,
                    (0, 0, state.cur_len - input_ids_length),
                    (batch_size, num_beams, input_ids_length),
                )
            )
            model_outputs = model(input_token, params=params, **state.model_kwargs)

            # unflatten beam dimension in the logits and the attention cache arrays
            logits = unflatten_beam_dim(model_outputs.logits[:, -1], batch_size, num_beams)
            cache = jax.tree_map(
                lambda tensor: unflatten_beam_dim(tensor, batch_size, num_beams), model_outputs.past_key_values
            )

            # adapt logits for FlaxMarianMTModel
            logits = self._adapt_logits_for_beam_search(logits)

            # 2. Compute log probs
            # get log probabilities from logits,
            # process logits with processors (*e.g.* min_length, ...), and
            # add new logprobs to existing running logprobs scores.
            log_probs = jax.nn.log_softmax(logits)
            log_probs = logits_processor(
                flatten_beam_dim(state.running_sequences), flatten_beam_dim(log_probs), state.cur_len
            )
            log_probs = unflatten_beam_dim(log_probs, batch_size, num_beams)
            log_probs = log_probs + jnp.expand_dims(state.running_scores, axis=2)
            vocab_size = log_probs.shape[2]
            log_probs = log_probs.reshape((batch_size, num_beams * vocab_size))

            # 3. Retrieve top-K
            # Each item in batch has num_beams * vocab_size candidate sequences.
            # For each item, get the top 2*K candidates with the highest log-
            # probabilities. We gather the top 2*K beams here so that even if the best
            # K sequences reach EOS simultaneously, we have another K sequences
            # remaining to continue the live beam search.
            # Recover the beam index by floor division and the token id by modulo
            # division, then update sequences for the 2*K top-k new sequences.
            beams_to_keep = 2 * num_beams
            topk_log_probs, topk_indices = lax.top_k(log_probs, k=beams_to_keep)
            topk_beam_indices = topk_indices // vocab_size
            topk_running_sequences = gather_beams(
                state.running_sequences, topk_beam_indices, batch_size, beams_to_keep
            )
            topk_ids = jnp.expand_dims(topk_indices % vocab_size, axis=2)
            topk_sequences = lax.dynamic_update_slice(topk_running_sequences, topk_ids, (0, 0, state.cur_len))

            # 4. Check which sequences have ended
            # Did any of these sequences reach an end marker? To prevent the just
            # finished sequences from being added to the set of active beam search
            # sequences, set their log probs to a very large negative value.
            did_topk_just_finished = topk_sequences[:, :, state.cur_len] == eos_token_id
            running_topk_log_probs = topk_log_probs + did_topk_just_finished * np.array(-1.0e7)

            # 5. Get running sequences scores for next
            # Determine the top k beam indices (from top 2*k beams) from log probs
            # and gather top k beams (from top 2*k beams).
            next_topk_indices = jnp.flip(lax.top_k(running_topk_log_probs, k=num_beams)[1], axis=1)
            next_running_sequences, next_running_scores = gather_beams(
                [topk_sequences, running_topk_log_probs], next_topk_indices, batch_size, num_beams
            )

            # 6. Process topk logits
            # Further process log probs:
            # - add length penalty
            # - make sure no scores can be added anymore if beam is full
            # - make sure still running sequences cannot be chosen as finalized beam
            topk_log_probs = topk_log_probs / (state.cur_len**length_penalty)
            beams_in_batch_are_full = (
                jnp.broadcast_to(state.is_sent_finished.all(axis=-1, keepdims=True), did_topk_just_finished.shape)
                & early_stopping
            )
            add_penalty = ~did_topk_just_finished | beams_in_batch_are_full
            topk_log_probs += add_penalty * np.array(-1.0e7)

            # 7. Get scores, sequences, is sentence finished for next.
            # Combine sequences, scores, and flags along the beam dimension and compare
            # new finished sequence scores to existing finished scores and select the
            # best from the new set of beams
            merged_sequences = jnp.concatenate([state.sequences, topk_sequences], axis=1)
            merged_scores = jnp.concatenate([state.scores, topk_log_probs], axis=1)
            merged_is_sent_finished = jnp.concatenate([state.is_sent_finished, did_topk_just_finished], axis=1)
            topk_merged_indices = jnp.flip(lax.top_k(merged_scores, k=num_beams)[1], axis=1)
            next_sequences, next_scores, next_is_sent_finished = gather_beams(
                [merged_sequences, merged_scores, merged_is_sent_finished], topk_merged_indices, batch_size, num_beams
            )

            # 8. Update model kwargs.
            # Determine the top k beam indices from the original set of all beams.
            # With these, gather the top k beam-associated caches.
            next_running_indices = gather_beams(topk_beam_indices, next_topk_indices, batch_size, num_beams)
            next_cache = gather_beams(cache, next_running_indices, batch_size, num_beams)
            model_outputs["past_key_values"] = jax.tree_map(lambda x: flatten_beam_dim(x), next_cache)
            next_model_kwargs = self.update_inputs_for_generation(model_outputs, state.model_kwargs)

            return BeamSearchState(
                cur_len=state.cur_len + 1,
                running_scores=next_running_scores,
                running_sequences=next_running_sequences,
                scores=next_scores,
                sequences=next_sequences,
                is_sent_finished=next_is_sent_finished,
                model_kwargs=next_model_kwargs,
            )

        # The very first prompt often has sequence length > 1, so run outside of `lax.while_loop` to comply with TPU
        if input_ids.shape[-1] > 1:
            state = partial(beam_search_body_fn, input_ids_length=input_ids.shape[-1])(state)

        if not trace:
            state = self._run_loop_in_debug(beam_search_cond_fn, beam_search_body_fn, state)
        else:
            state = lax.while_loop(beam_search_cond_fn, beam_search_body_fn, state)

        # Account for the edge-case where there are no finished sequences for a
        # particular batch item. If so, return running sequences for that batch item.
        none_finished = jnp.any(state.is_sent_finished, axis=1)
        sequences = jnp.where(none_finished[:, None, None], state.sequences, state.running_sequences)
        scores = jnp.where(none_finished[:, None], state.scores, state.running_scores)

        # return all beams for each batch and the best score
        sequences = sequences[:, :]
        scores = scores[:, -1]

        return FlaxBeamSearchOutput(sequences=sequences, scores=scores)
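In use, `_beam_search` is reached through `generate` whenever `num_beams > 1`. A hedged sketch of a beam-search transcription call (assumes `model` and `tokenizer_output` loaded as in the `__call__` docstring example; note that, unlike upstream, this variant returns all beams per batch item, sorted ascending by score, so the best beam is the last one):

import jax.numpy as jnp

inputs = jnp.ones((1, 16000), dtype=jnp.float32)
generated = model.generate(inputs, num_beams=5, max_length=50, length_penalty=1.0)

best_beam = generated.sequences[:, -1]  # highest-scoring beam, see the tail of `_beam_search`
transcription = tokenizer_output.batch_decode(best_beam, skip_special_tokens=True)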
models/modeling_flax_wav2vec2.py
DELETED
@@ -1,975 +0,0 @@
# coding=utf-8
# Copyright 2021 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Flax Wav2Vec2 model."""

from functools import partial
from typing import Optional, Tuple, Union

import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from flax.linen import partitioning as nn_partitioning
from flax.linen.attention import dot_product_attention_weights
from jax import lax

from transformers.modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput
from transformers.modeling_flax_utils import ACT2FN, FlaxPreTrainedModel
from transformers.utils import ModelOutput

from models import Wav2Vec2Config

scan_with_axes = nn_partitioning.scan_with_axes
remat = nn_partitioning.remat


@flax.struct.dataclass
class FlaxWav2Vec2BaseModelOutput(ModelOutput):
    """
    Output type of [`FlaxWav2Vec2BaseModelOutput`], with potential hidden states and attentions.

    Args:
        last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        extract_features (`jnp.ndarray` of shape `(batch_size, sequence_length, last_conv_dim)`):
            Sequence of extracted feature vectors of the last convolutional layer of the model with `last_conv_dim`
            being the dimension of the last convolutional layer.
        hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    last_hidden_state: jnp.ndarray = None
    extract_features: jnp.ndarray = None
    hidden_states: Optional[Tuple[jnp.ndarray]] = None
    attentions: Optional[Tuple[jnp.ndarray]] = None


WAV_2_VEC_2_START_DOCSTRING = r"""
    Wav2Vec2 was proposed in [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech
    Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael
    Auli.

    This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.).

    This model is also a Flax Linen
    [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
    regular Flax Module and refer to the Flax documentation for all matters related to general usage and behavior.

    Finally, this model supports inherent JAX features such as:

    - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
    - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
    - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
    - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)

    Parameters:
        config ([`Wav2Vec2Config`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
        dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
            The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
            `jax.numpy.bfloat16` (on TPUs).

            This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
            specified, all the computation will be performed with the given `dtype`.

            **Note that this only specifies the dtype of the computation and does not influence the dtype of model
            parameters.**

            If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
            [`~FlaxPreTrainedModel.to_bf16`].
"""


WAV_2_VEC_2_INPUTS_DOCSTRING = r"""
    Args:
        input_values (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
            Float values of input raw speech waveform. Values can be obtained by loading a *.flac* or *.wav* audio file
            into an array of type *List[float]* or a *numpy.ndarray*, *e.g.* via the soundfile library (*pip install
            soundfile*). To prepare the array into *input_values*, the [`Wav2Vec2Processor`] should be used for padding
            and conversion into a tensor of type *jnp.ndarray*. See [`Wav2Vec2Processor.__call__`] for details.
        attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0,
            1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask) .. warning:: `attention_mask` should only be passed
            if the corresponding processor has `config.return_attention_mask == True`. For all models whose processor
            has `config.return_attention_mask == False`, such as
            [wav2vec2-base](https://huggingface.co/facebook/wav2vec2-base-960h), `attention_mask` should **not** be
            passed to avoid degraded performance when doing batched inference. For such models `input_values` should
            simply be padded with 0 and passed without `attention_mask`. Be aware that these models also yield slightly
            different results depending on whether `input_values` is padded or not.
        mask_time_indices (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
            Indices to mask extracted features for contrastive loss. When in training mode, the model learns to predict
            masked extracted features in *config.proj_codevector_dim* space.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


class FlaxWav2Vec2LayerNormConvLayer(nn.Module):
    config: Wav2Vec2Config
    layer_id: int = 0
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.in_conv_dim = self.config.conv_dim[self.layer_id] if self.layer_id > 0 else 1
        self.out_conv_dim = self.config.conv_dim[self.layer_id]

        self.conv = nn.Conv(
            features=self.config.conv_dim[self.layer_id],
            kernel_size=(self.config.conv_kernel[self.layer_id],),
            strides=(self.config.conv_stride[self.layer_id],),
            use_bias=self.config.conv_bias,
            kernel_init=jax.nn.initializers.he_normal(),
            padding="VALID",
            dtype=self.dtype,
        )
        self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        self.activation = ACT2FN[self.config.feat_extract_activation]

    def __call__(self, hidden_states):
        hidden_states = self.conv(hidden_states)
        hidden_states = self.layer_norm(hidden_states)
        hidden_states = self.activation(hidden_states)
        return hidden_states


class FlaxConvWithWeightNorm(nn.Module):
    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.conv = nn.Conv(
            features=self.config.hidden_size,
            kernel_size=(self.config.num_conv_pos_embeddings,),
            kernel_init=jax.nn.initializers.he_normal(),
            padding="VALID",
            feature_group_count=self.config.num_conv_pos_embedding_groups,
            dtype=self.dtype,
        )
        weight_shape = (
            self.conv.features,
            self.conv.features // self.conv.feature_group_count,
            self.conv.kernel_size[0],
        )
        self.weight_v = self.param("weight_v", jax.nn.initializers.he_normal(), weight_shape)
        self.weight_g = self.param("weight_g", lambda _: jnp.linalg.norm(self.weight_v, axis=(0, 1))[None, None, :])
        self.bias = self.param("bias", jax.nn.initializers.zeros, (self.conv.features,))
        self.prev_padding = self.conv.kernel_size[0] // 2

    def _get_normed_weights(self):
        weight_v_norm = jnp.linalg.norm(self.weight_v, axis=(0, 1))[None, None, :]
        normed_weight_v = jnp.divide(self.weight_v, weight_v_norm)
        normed_kernel = jnp.multiply(normed_weight_v, self.weight_g)
        return normed_kernel

    def __call__(self, hidden_states):
        kernel = self._get_normed_weights()
        hidden_states = jnp.pad(hidden_states, ((0, 0), (self.prev_padding, self.prev_padding), (0, 0)))
        hidden_states = self.conv.apply({"params": {"kernel": kernel.T, "bias": self.bias}}, hidden_states)
        return hidden_states
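`FlaxConvWithWeightNorm` re-parameterizes the convolution kernel as a direction `weight_v` and a magnitude `weight_g` (weight normalization, Salimans & Kingma 2016), rebuilding the effective kernel on every forward pass. A minimal standalone restatement of `_get_normed_weights` with illustrative shapes:

import jax.numpy as jnp

# kernel = weight_g * weight_v / ||weight_v||, with the norm taken over axes (0, 1)
weight_v = jnp.arange(24.0).reshape(2, 3, 4) + 1.0  # (features, features // groups, kernel positions)
weight_g = jnp.linalg.norm(weight_v, axis=(0, 1))[None, None, :]
kernel = weight_g * weight_v / jnp.linalg.norm(weight_v, axis=(0, 1))[None, None, :]

# at initialization weight_g equals the norm, so the kernel is just weight_v;
# training then updates direction and magnitude independently
assert jnp.allclose(kernel, weight_v)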
class FlaxWav2Vec2PositionalConvEmbedding(nn.Module):
    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.conv = FlaxConvWithWeightNorm(self.config, dtype=self.dtype)
        self.activation = ACT2FN[self.config.feat_extract_activation]
        self.num_pad_remove = 1 if self.config.num_conv_pos_embeddings % 2 == 0 else 0

    def __call__(self, hidden_states):
        hidden_states = hidden_states.transpose((0, 1, 2))

        hidden_states = self.conv(hidden_states)

        if self.num_pad_remove > 0:
            hidden_states = hidden_states[:, : -self.num_pad_remove, :]
        hidden_states = self.activation(hidden_states)

        hidden_states = hidden_states.transpose((0, 1, 2))
        return hidden_states


class FlaxConvLayersCollection(nn.Module):
    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        if self.config.feat_extract_norm == "layer":
            # note that we can't use scan on the conv layers as they differ on a layer-by-layer basis
            BlockLayer = remat(FlaxWav2Vec2LayerNormConvLayer) if self.config.gradient_checkpointing else FlaxWav2Vec2LayerNormConvLayer
            self.layers = [
                BlockLayer(self.config, layer_id=i, name=str(i), dtype=self.dtype)
                for i in range(self.config.num_feat_extract_layers)
            ]
        elif self.config.feat_extract_norm == "group":
            raise NotImplementedError("At the moment only ``config.feat_extract_norm == 'layer'`` is supported")
        else:
            raise ValueError(
                f"`config.feat_extract_norm` is {self.config.feat_extract_norm}, but has to be one of ['group', 'layer']"
            )

    def __call__(self, hidden_states):
        for conv_layer in self.layers:
            hidden_states = conv_layer(hidden_states)
        return hidden_states


class FlaxWav2Vec2FeatureEncoder(nn.Module):
    """Construct the features from raw audio waveform"""

    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.conv_layers = FlaxConvLayersCollection(self.config, dtype=self.dtype)

    def __call__(self, input_values, freeze_feature_encoder=False):
        hidden_states = input_values[:, :, None]
        hidden_states = self.conv_layers(hidden_states)
        if freeze_feature_encoder:
            hidden_states = jax.lax.stop_gradient(hidden_states)
        return hidden_states
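The `freeze_feature_encoder` flag implements the usual fine-tuning recipe of keeping the convolutional feature encoder fixed: `jax.lax.stop_gradient` leaves the forward values untouched but blocks gradients flowing back into the conv stack. A self-contained illustration of the mechanism (toy function, not the model itself):

import jax
import jax.numpy as jnp

def loss_fn(w, x, freeze):
    h = w * x                         # stand-in for the conv feature encoder
    if freeze:
        h = jax.lax.stop_gradient(h)  # same value, zero gradient w.r.t. w
    return jnp.sum(h ** 2)

x = jnp.ones(3)
print(jax.grad(loss_fn)(2.0, x, False))  # 12.0 -> encoder would be updated
print(jax.grad(loss_fn)(2.0, x, True))   # 0.0  -> encoder is frozen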
class FlaxWav2Vec2FeatureProjection(nn.Module):
    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        self.projection = nn.Dense(
            self.config.hidden_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )
        self.dropout = nn.Dropout(rate=self.config.feat_proj_dropout)

    def __call__(self, hidden_states, deterministic=True):
        norm_hidden_states = self.layer_norm(hidden_states)
        hidden_states = self.projection(norm_hidden_states)
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        return hidden_states, norm_hidden_states


class FlaxWav2Vec2Attention(nn.Module):
    config: Wav2Vec2Config
    embed_dim: int
    num_heads: int
    dropout: float = 0.0
    bias: bool = True
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self) -> None:
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
            )

        dense = partial(
            nn.Dense,
            self.embed_dim,
            use_bias=self.bias,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )

        self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()

        self.fused_proj = nn.Dense(
            self.embed_dim * 3,
            use_bias=self.bias,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )

        self.out_proj = dense()

        self.dropout_layer = nn.Dropout(rate=self.dropout)

    def _split_heads(self, hidden_states):
        return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))

    def _merge_heads(self, hidden_states):
        return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))

    def __call__(
        self,
        hidden_states: jnp.ndarray,
        key_value_states: Optional[jnp.ndarray] = None,
        attention_mask: Optional[jnp.ndarray] = None,
        deterministic: bool = True,
    ) -> Tuple[jnp.ndarray]:
        """Input shape: Batch x Time x Channel"""

        if self.config.fuse_matmuls:
            attention_states = self.fused_proj(hidden_states)
            query_states, key_states, value_states = jnp.split(attention_states, 3, axis=-1)
        else:
            # get query, key and value projections
            query_states = self.q_proj(hidden_states)
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)

        query_states = self._split_heads(query_states)
        key_states = self._split_heads(key_states)
        value_states = self._split_heads(value_states)

        if attention_mask is not None:
            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))

        # Convert the boolean attention mask to an attention bias.
        if attention_mask is not None:
            # attention mask in the form of attention bias
            attention_bias = lax.select(
                attention_mask > 0,
                jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
                jnp.full(attention_mask.shape, float("-inf")).astype(self.dtype),
            )
        else:
            attention_bias = None

        dropout_rng = None
        if not deterministic and self.dropout > 0.0:
            dropout_rng = self.make_rng("dropout")

        attn_weights = dot_product_attention_weights(
            query_states,
            key_states,
            bias=attention_bias,
            dropout_rng=dropout_rng,
            dropout_rate=self.dropout,
            broadcast_dropout=True,
            deterministic=deterministic,
            dtype=self.dtype,
            precision=None,
        )

        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
        attn_output = self._merge_heads(attn_output)
        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights
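When `config.fuse_matmuls` is enabled, the attention layer computes Q, K and V with a single `(hidden, 3 * hidden)` projection and a split, which tends to map better onto TPU matrix units than three separate matmuls. A standalone sketch of the equivalence (illustrative shapes; in the module the fused and unfused projections are distinct parameters, so exact equality only holds when the weights are tied as below):

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (2, 10, 16))  # (batch, time, hidden)
wq, wk, wv = (jax.random.normal(k, (16, 16)) for k in jax.random.split(key, 3))

# unfused: three matmuls
q, k_out, v = x @ wq, x @ wk, x @ wv

# fused: one matmul against the concatenated weight, then a split
w_fused = jnp.concatenate([wq, wk, wv], axis=-1)  # (16, 48)
q2, k2, v2 = jnp.split(x @ w_fused, 3, axis=-1)

assert jnp.allclose(q, q2, atol=1e-5) and jnp.allclose(v, v2, atol=1e-5)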
class FlaxWav2Vec2FeedForward(nn.Module):
|
395 |
-
config: Wav2Vec2Config
|
396 |
-
dtype: jnp.dtype = jnp.float32
|
397 |
-
|
398 |
-
def setup(self):
|
399 |
-
self.intermediate_dropout = nn.Dropout(rate=self.config.activation_dropout)
|
400 |
-
|
401 |
-
self.intermediate_dense = nn.Dense(
|
402 |
-
self.config.intermediate_size,
|
403 |
-
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
404 |
-
dtype=self.dtype,
|
405 |
-
)
|
406 |
-
if isinstance(self.config.hidden_act, str):
|
407 |
-
self.intermediate_act_fn = ACT2FN[self.config.hidden_act]
|
408 |
-
else:
|
409 |
-
self.intermediate_act_fn = self.config.hidden_act
|
410 |
-
|
411 |
-
self.output_dense = nn.Dense(
|
412 |
-
self.config.hidden_size,
|
413 |
-
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
414 |
-
dtype=self.dtype,
|
415 |
-
)
|
416 |
-
self.output_dropout = nn.Dropout(rate=self.config.hidden_dropout)
|
417 |
-
|
418 |
-
def __call__(self, hidden_states, deterministic=True):
|
419 |
-
hidden_states = self.intermediate_dense(hidden_states)
|
420 |
-
hidden_states = self.intermediate_act_fn(hidden_states)
|
421 |
-
hidden_states = self.intermediate_dropout(hidden_states, deterministic=deterministic)
|
422 |
-
|
423 |
-
hidden_states = self.output_dense(hidden_states)
|
424 |
-
hidden_states = self.output_dropout(hidden_states, deterministic=deterministic)
|
425 |
-
return hidden_states
|
426 |
-
|
427 |
-
|
428 |
-
class FlaxWav2Vec2EncoderLayerStableLayerNorm(nn.Module):
|
429 |
-
config: Wav2Vec2Config
|
430 |
-
dtype: jnp.dtype = jnp.float32
|
431 |
-
|
432 |
-
def setup(self):
|
433 |
-
self.attention = FlaxWav2Vec2Attention(
|
434 |
-
config=self.config,
|
435 |
-
embed_dim=self.config.hidden_size,
|
436 |
-
num_heads=self.config.num_attention_heads,
|
437 |
-
dropout=self.config.attention_dropout,
|
438 |
-
dtype=self.dtype,
|
439 |
-
)
|
440 |
-
self.dropout = nn.Dropout(rate=self.config.hidden_dropout)
|
441 |
-
self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
|
442 |
-
self.feed_forward = FlaxWav2Vec2FeedForward(self.config, dtype=self.dtype)
|
443 |
-
self.final_layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
|
444 |
-
|
445 |
-
def __call__(self, hidden_states, attention_mask=None, deterministic=True, output_attentions=False):
|
446 |
-
if self.config.use_scan:
|
447 |
-
hidden_states = hidden_states[0]
|
448 |
-
attn_residual = hidden_states
|
449 |
-
hidden_states = self.layer_norm(hidden_states)
|
450 |
-
hidden_states, attn_weights = self.attention(
|
451 |
-
hidden_states, attention_mask=attention_mask, deterministic=deterministic
|
452 |
-
)
|
453 |
-
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
|
454 |
-
hidden_states = attn_residual + hidden_states
|
455 |
-
hidden_states = hidden_states + self.feed_forward(
|
456 |
-
self.final_layer_norm(hidden_states), deterministic=deterministic
|
457 |
-
)
|
458 |
-
|
459 |
-
outputs = (hidden_states,)
|
460 |
-
|
461 |
-
if output_attentions:
|
462 |
-
outputs += (attn_weights,)
|
463 |
-
|
464 |
-
if self.config.use_scan:
|
465 |
-
outputs = (outputs, None)
|
466 |
-
|
467 |
-
return outputs
|
468 |
-
|
469 |
-
|
-class FlaxWav2Vec2EncoderLayerStableLayerNormCollection(nn.Module):
-    config: Wav2Vec2Config
-    dtype: jnp.dtype = jnp.float32
-
-    @nn.compact
-    def __call__(
-        self,
-        hidden_states,
-        attention_mask=None,
-        deterministic: bool = True,
-        output_attentions: bool = False,
-        output_hidden_states: bool = False,
-        return_dict: bool = True,
-    ):
-        all_attentions = () if output_attentions else None
-        all_hidden_states = () if output_hidden_states else None
-
-        num_layers = self.config.num_hidden_layers
-        BlockEncoderLayer = (
-            remat(
-                FlaxWav2Vec2EncoderLayerStableLayerNorm,
-                static_argnums=(2, 3),
-                prevent_cse=not self.config.use_scan,
-            )
-            if self.config.gradient_checkpointing
-            else FlaxWav2Vec2EncoderLayerStableLayerNorm
-        )
-
-        if self.config.use_scan:
-            # since all encoder layers are the same, we use nn.scan directly
-            assert not output_attentions, "cannot use `scan` with `output_attentions` set to `True`"
-            assert not output_hidden_states, "cannot use `scan` with `output_hidden_states` set to `True`"
-            hidden_states = (hidden_states,)
-
-            hidden_states, _ = scan_with_axes(
-                BlockEncoderLayer,
-                variable_axes={"params": 0, "cache": 0},
-                split_rngs={"params": True, "dropout": True},
-                in_axes=(nn.broadcast, nn.broadcast, nn.broadcast),
-                length=num_layers,
-            )(self.config, dtype=self.dtype, name="FlaxWav2Vec2EncoderLayers")(
-                hidden_states, attention_mask, deterministic, output_attentions
-            )
-            hidden_states = hidden_states[0]
-
-        else:
-            for layer in range(num_layers):
-                if output_hidden_states:
-                    all_hidden_states += (hidden_states,)
-
-                layer_outputs = BlockEncoderLayer(
-                    self.config,
-                    dtype=self.dtype,
-                    name=str(layer),
-                )(hidden_states, attention_mask, deterministic, output_attentions)
-
-                hidden_states = layer_outputs[0]
-
-                if output_attentions:
-                    all_attentions += (layer_outputs[1],)
-
-        if output_hidden_states:
-            all_hidden_states += (hidden_states,)
-
-        outputs = (hidden_states, all_hidden_states, all_attentions)
-
-        if not return_dict:
-            return tuple(v for v in outputs if v is not None)
-
-        return FlaxBaseModelOutput(
-            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
-        )
-
-
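The collection above either loops over `num_layers` independent layers or, when `config.use_scan` is set, stacks one set of layer parameters and applies it `num_layers` times with a scan; `remat` additionally recomputes activations in the backward pass when gradient checkpointing is enabled. A minimal, self-contained sketch of the scan-over-layers pattern with plain `flax.linen.scan` and toy module names (not the repo's `scan_with_axes` helper):

import jax
import jax.numpy as jnp
import flax.linen as nn

class TinyLayer(nn.Module):
    hidden: int

    @nn.compact
    def __call__(self, x, _):
        # one "layer" of the stack; the activations are the scan carry
        return nn.gelu(nn.Dense(self.hidden)(x)), None

class ScannedStack(nn.Module):
    hidden: int
    num_layers: int

    @nn.compact
    def __call__(self, x):
        stack = nn.scan(
            TinyLayer,
            variable_axes={"params": 0},  # stack per-layer params along axis 0
            split_rngs={"params": True},  # a distinct init RNG per layer
            length=self.num_layers,       # no scanned inputs, just a depth count
        )(self.hidden)
        x, _ = stack(x, None)
        return x

x = jnp.ones((2, 16, 8))
variables = ScannedStack(hidden=8, num_layers=4).init(jax.random.PRNGKey(0), x)
y = ScannedStack(hidden=8, num_layers=4).apply(variables, x)

Scanning keeps compile time flat in depth because XLA traces a single layer body; the `(outputs, None)` tuple convention in the layer class above exists to satisfy scan's `(carry, ys)` contract.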
-class FlaxWav2Vec2StableLayerNormEncoder(nn.Module):
-    config: Wav2Vec2Config
-    dtype: jnp.dtype = jnp.float32
-
-    def setup(self):
-        self.pos_conv_embed = FlaxWav2Vec2PositionalConvEmbedding(self.config, dtype=self.dtype)
-        self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
-        self.dropout = nn.Dropout(rate=self.config.hidden_dropout)
-        self.layers = FlaxWav2Vec2EncoderLayerStableLayerNormCollection(self.config, dtype=self.dtype)
-
-    def __call__(
-        self,
-        hidden_states,
-        attention_mask=None,
-        deterministic=True,
-        output_attentions=False,
-        output_hidden_states=False,
-        return_dict=True,
-    ):
-        if attention_mask is not None:
-            # make sure padded tokens are not attended to
-            hidden_states = jnp.where(
-                jnp.broadcast_to(attention_mask[:, :, None], hidden_states.shape), hidden_states, 0
-            )
-
-        position_embeddings = self.pos_conv_embed(hidden_states)
-
-        hidden_states = hidden_states + position_embeddings
-        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
-
-        outputs = self.layers(
-            hidden_states,
-            attention_mask,
-            deterministic=deterministic,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-        )
-
-        last_hidden_state = self.layer_norm(outputs[0])
-
-        # update the last element in `hidden_states` after applying `layer_norm` above
-        hidden_states = None
-        if output_hidden_states:
-            hidden_states = outputs[1]
-            hidden_states = hidden_states[:-1] + (last_hidden_state,)
-
-        if not return_dict:
-            outputs = (last_hidden_state, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:])
-            return tuple(v for v in outputs if v is not None)
-
-        return FlaxBaseModelOutput(
-            last_hidden_state=last_hidden_state, hidden_states=hidden_states, attentions=outputs.attentions
-        )
-
-
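Before adding the positional convolution embeddings, the encoder zeroes out padded time steps so they contribute nothing downstream. A small numeric illustration of the `jnp.where` broadcast used above:

import jax.numpy as jnp

hidden_states = jnp.ones((1, 4, 2))
attention_mask = jnp.array([[1, 1, 0, 0]])

# broadcast the (batch, seq) mask over the feature axis and zero padded steps
masked = jnp.where(
    jnp.broadcast_to(attention_mask[:, :, None], hidden_states.shape), hidden_states, 0
)
print(masked[0, :, 0])  # [1. 1. 0. 0.]
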
-class FlaxWav2Vec2Adapter(nn.Module):
-    config: Wav2Vec2Config
-    dtype: jnp.dtype = jnp.float32
-
-    def setup(self):
-        # hidden_states require down-projection if feature dims don't match
-        if self.config.output_hidden_size != self.config.hidden_size:
-            self.proj = nn.Dense(
-                self.config.output_hidden_size,
-                kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
-                dtype=self.dtype,
-            )
-            self.proj_layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
-        else:
-            self.proj = self.proj_layer_norm = None
-
-        self.layers = FlaxWav2Vec2AdapterLayersCollection(self.config, dtype=self.dtype)
-
-    def __call__(self, hidden_states, deterministic=True):
-        # down-project hidden_states if required
-        if self.proj is not None and self.proj_layer_norm is not None:
-            hidden_states = self.proj(hidden_states)
-            hidden_states = self.proj_layer_norm(hidden_states)
-
-        hidden_states = self.layers(hidden_states)
-
-        return hidden_states
-
-
-class FlaxWav2Vec2AdapterLayer(nn.Module):
-    config: Wav2Vec2Config
-    dtype: jnp.dtype = jnp.float32
-
-    def setup(self):
-        self.conv = nn.Conv(
-            features=2 * self.config.output_hidden_size,
-            kernel_size=(self.config.adapter_kernel_size,),
-            strides=(self.config.adapter_stride,),
-            padding=((1, 1),),
-            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
-            dtype=self.dtype,
-        )
-
-    def __call__(self, hidden_states):
-        hidden_states = self.conv(hidden_states)
-        hidden_states = nn.glu(hidden_states, axis=2)
-
-        return hidden_states
-
-
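The adapter layer doubles the channel dimension with the convolution, then halves it again with a gated linear unit, which gates one half of the channels by the sigmoid of the other. A quick shape check (the width 512 is an arbitrary illustrative value):

import jax.numpy as jnp
import flax.linen as nn

# conv output has 2 * output_hidden_size channels; glu restores the original width
x = jnp.ones((1, 10, 2 * 512))
y = nn.glu(x, axis=2)
print(y.shape)  # (1, 10, 512)
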
-class FlaxWav2Vec2AdapterLayersCollection(nn.Module):
-    config: Wav2Vec2Config
-    dtype: jnp.dtype = jnp.float32
-
-    def setup(self):
-        BlockAdapterLayer = (
-            remat(FlaxWav2Vec2AdapterLayer) if self.config.gradient_checkpointing else FlaxWav2Vec2AdapterLayer
-        )
-        self.layers = [
-            BlockAdapterLayer(self.config, name=str(i), dtype=self.dtype)
-            for i in range(self.config.num_adapter_layers)
-        ]
-
-    def __call__(self, hidden_states):
-        for conv_layer in self.layers:
-            hidden_states = conv_layer(hidden_states)
-
-        return hidden_states
-
-
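When `config.gradient_checkpointing` is set, each adapter layer is wrapped in `remat`, trading recomputation for activation memory. A minimal sketch of the wrapper on a toy block, assuming `remat` behaves like `flax.linen.remat`:

import jax
import jax.numpy as jnp
import flax.linen as nn

class Block(nn.Module):
    @nn.compact
    def __call__(self, x):
        return nn.gelu(nn.Dense(8)(x))

# recompute Block's activations during the backward pass instead of storing them
RematBlock = nn.remat(Block)

x = jnp.ones((2, 8))
variables = RematBlock().init(jax.random.PRNGKey(0), x)
y = RematBlock().apply(variables, x)
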
-class FlaxWav2Vec2PreTrainedModel(FlaxPreTrainedModel):
-    """
-    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
-    models.
-    """
-
-    config_class = Wav2Vec2Config
-    base_model_prefix: str = "wav2vec2"
-    main_input_name = "input_values"
-    module_class: nn.Module = None
-
-    def __init__(
-        self,
-        config: Wav2Vec2Config,
-        input_shape: Tuple = (1, 1024),
-        seed: int = 0,
-        dtype: jnp.dtype = jnp.float32,
-        _do_init: bool = True,
-        **kwargs,
-    ):
-        module = self.module_class(config=config, dtype=dtype, **kwargs)
-        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
-
-    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
-        # init input tensors
-        input_values = jnp.zeros(input_shape, dtype="i4")
-        attention_mask = jnp.ones_like(input_values)
-        params_rng, dropout_rng = jax.random.split(rng, 2)
-        rngs = {"params": params_rng, "dropout": dropout_rng}
-
-        return self.module.init(rngs, input_values, attention_mask, return_dict=False)["params"]
-
-    def __call__(
-        self,
-        input_values,
-        attention_mask=None,
-        mask_time_indices=None,
-        extract_features=None,
-        params: dict = None,
-        dropout_rng: jax.random.PRNGKey = None,
-        train: bool = False,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        output_features: Optional[bool] = None,
-        freeze_feature_encoder: bool = False,
-        return_dict: Optional[bool] = None,
-    ):
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        return_dict = return_dict if return_dict is not None else self.config.return_dict
-
-        if attention_mask is None:
-            batch_size, sequence_length = input_values.shape
-            attention_mask = jnp.ones((batch_size, sequence_length))
-
-        if extract_features is not None:
-            extract_features = jnp.array(extract_features, dtype="f4")
-
-        # Handle any PRNG if needed
-        rngs = {}
-        if dropout_rng is not None:
-            rngs["dropout"] = dropout_rng
-
-        inputs = {"params": params or self.params}
-
-        return self.module.apply(
-            inputs,
-            jnp.array(input_values, dtype="f4"),
-            jnp.array(attention_mask, dtype="i4"),
-            mask_time_indices,
-            extract_features,
-            not train,
-            output_attentions,
-            output_hidden_states,
-            output_features,
-            freeze_feature_encoder,
-            return_dict,
-            rngs=rngs,
-        )
-
-    def _get_feat_extract_output_lengths(
-        self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
-    ):
-        return self.module._get_feat_extract_output_lengths(input_lengths, add_adapter=add_adapter)
-
-    def _get_feature_vector_attention_mask(
-        self, feature_vector_length: int, attention_mask: jnp.ndarray, add_adapter=None
-    ):
-        return self.module._get_feature_vector_attention_mask(
-            feature_vector_length, attention_mask, add_adapter=add_adapter
-        )
-
-
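The `input_shape=(1, 1024)` default above is only used to trace parameter shapes; real inputs are raw waveforms. To make the feature encoder's downsampling concrete, here is the `_conv_out_length` recursion evaluated by hand with the stock wav2vec2 feature-encoder kernels and strides (the defaults shown below are assumptions; this repo's config values may differ):

conv_kernel = (10, 3, 3, 3, 3, 2, 2)
conv_stride = (5, 2, 2, 2, 2, 2, 2)

def conv_out_length(input_length, kernel_size, stride):
    # same formula as in the module: floor((L - k) / s) + 1
    return (input_length - kernel_size) // stride + 1

length = 16000  # one second of 16 kHz audio
for k, s in zip(conv_kernel, conv_stride):
    length = conv_out_length(length, k, s)
print(length)  # 49 frames, i.e. roughly one frame per 20 ms
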
-class FlaxWav2Vec2Module(nn.Module):
-    config: Wav2Vec2Config
-    dtype: jnp.dtype = jnp.float32
-
-    def setup(self):
-        self.feature_extractor = FlaxWav2Vec2FeatureEncoder(self.config, dtype=self.dtype)
-        self.feature_projection = FlaxWav2Vec2FeatureProjection(self.config, dtype=self.dtype)
-        self.masked_spec_embed = self.param(
-            "masked_spec_embed", jax.nn.initializers.uniform(), (self.config.hidden_size,)
-        )
-
-        if self.config.do_stable_layer_norm:
-            self.encoder = FlaxWav2Vec2StableLayerNormEncoder(self.config, dtype=self.dtype)
-        else:
-            raise NotImplementedError("``config.do_stable_layer_norm is False`` is currently not supported.")
-
-        self.adapter = FlaxWav2Vec2Adapter(self.config, dtype=self.dtype) if self.config.add_adapter else None
-
-    def __call__(
-        self,
-        input_values,
-        attention_mask=None,
-        mask_time_indices=None,
-        extract_features=None,
-        deterministic=True,
-        output_attentions=None,
-        output_hidden_states=None,
-        output_features=False,
-        freeze_feature_encoder=False,
-        return_dict=None,
-    ):
-        # forward pass through the feature extractor if features not specified
-        if extract_features is None:
-            extract_features = self.feature_extractor(input_values, freeze_feature_encoder=freeze_feature_encoder)
-
-        if output_features:
-            return extract_features
-
-        # make sure that no loss is computed on padded inputs
-        if attention_mask is not None:
-            # compute reduced attention_mask corresponding to feature vectors
-            attention_mask = self._get_feature_vector_attention_mask(
-                extract_features.shape[1], attention_mask, add_adapter=False
-            )
-
-        hidden_states, extract_features = self.feature_projection(extract_features, deterministic=deterministic)
-        if mask_time_indices is not None:  # apply SpecAugment along time axis with given indices
-            hidden_states = jnp.where(
-                jnp.broadcast_to(mask_time_indices[:, :, None], hidden_states.shape),
-                jnp.broadcast_to(self.masked_spec_embed[None, None, :], hidden_states.shape),
-                hidden_states,
-            )
-
-        encoder_outputs = self.encoder(
-            hidden_states,
-            attention_mask=attention_mask,
-            deterministic=deterministic,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-        )
-
-        hidden_states = encoder_outputs[0]
-
-        if self.adapter is not None:
-            hidden_states = self.adapter(hidden_states)
-
-        if not return_dict:
-            return (hidden_states, extract_features) + encoder_outputs[1:]
-
-        return FlaxWav2Vec2BaseModelOutput(
-            last_hidden_state=hidden_states,
-            extract_features=extract_features,
-            hidden_states=encoder_outputs.hidden_states,
-            attentions=encoder_outputs.attentions,
-        )
-
-    def _get_feat_extract_output_lengths(
-        self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
-    ):
-        """
-        Computes the output length of the convolutional layers
-        """
-        add_adapter = self.config.add_adapter if add_adapter is None else add_adapter
-
-        def _conv_out_length(input_length, kernel_size, stride):
-            # 1D convolutional layer output length formula taken
-            # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
-            return (input_length - kernel_size) // stride + 1
-
-        for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
-            input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
-
-        if add_adapter:
-            for _ in range(self.config.num_adapter_layers):
-                input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)
-
-        return input_lengths
-
-    def _get_feature_vector_attention_mask(
-        self, feature_vector_length: int, attention_mask: jnp.ndarray, add_adapter=None
-    ):
-        # effectively attention_mask.sum(-1), but not in-place so it can run in inference mode
-        non_padded_lengths = attention_mask.cumsum(axis=-1)[:, -1]
-
-        output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)
-
-        batch_size = attention_mask.shape[0]
-
-        attention_mask = jnp.zeros((batch_size, feature_vector_length), dtype=attention_mask.dtype)
-        # these two operations make sure that all values
-        # before the output length indices are attended to
-        attention_mask = attention_mask.at[jnp.arange(attention_mask.shape[0]), output_lengths - 1].set(1)
-        attention_mask = jnp.flip(jnp.flip(attention_mask, -1).cumsum(-1), -1).astype("bool")
-        return attention_mask
-
-
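`_get_feature_vector_attention_mask` rebuilds a frame-level mask from sample-level lengths without boolean indexing: it sets a single 1 at each sequence's last valid frame, then uses flip/cumsum/flip to fill in everything before it. A small numeric demonstration:

import jax.numpy as jnp

output_lengths = jnp.array([3, 5])
feature_vector_length = 6

mask = jnp.zeros((2, feature_vector_length), dtype="i4")
# place a single 1 at the last valid frame of each sequence...
mask = mask.at[jnp.arange(2), output_lengths - 1].set(1)
# ...then flip, cumsum, flip to turn on every position before it
mask = jnp.flip(jnp.flip(mask, -1).cumsum(-1), -1).astype("bool")
print(mask.astype("i4"))
# [[1 1 1 0 0 0]
#  [1 1 1 1 1 0]]
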
-class FlaxWav2Vec2Model(FlaxWav2Vec2PreTrainedModel):
-    module_class = FlaxWav2Vec2Module
-
-
-class FlaxWav2Vec2ForCTCModule(nn.Module):
-    config: Wav2Vec2Config
-    dtype: jnp.dtype = jnp.float32
-
-    def setup(self):
-        self.wav2vec2 = FlaxWav2Vec2Module(self.config, dtype=self.dtype)
-        self.dropout = nn.Dropout(rate=self.config.final_dropout)
-        self.lm_head = nn.Dense(
-            self.config.vocab_size,
-            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
-            dtype=self.dtype,
-        )
-
-    def __call__(
-        self,
-        input_values,
-        attention_mask=None,
-        mask_time_indices=None,
-        extract_features=None,
-        deterministic=True,
-        output_attentions=None,
-        output_hidden_states=None,
-        output_features=False,
-        freeze_feature_encoder=False,
-        return_dict=None,
-    ):
-        outputs = self.wav2vec2(
-            input_values,
-            attention_mask=attention_mask,
-            mask_time_indices=mask_time_indices,
-            deterministic=deterministic,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            freeze_feature_encoder=freeze_feature_encoder,
-            return_dict=return_dict,
-        )
-
-        hidden_states = outputs[0]
-        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
-
-        logits = self.lm_head(hidden_states)
-
-        if not return_dict:
-            return (logits,) + outputs[2:]
-
-        return FlaxCausalLMOutput(logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
-
-    def _get_feat_extract_output_lengths(
-        self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
-    ):
-        """
-        Computes the output length of the convolutional layers
-        """
-        add_adapter = self.config.add_adapter if add_adapter is None else add_adapter
-
-        def _conv_out_length(input_length, kernel_size, stride):
-            # 1D convolutional layer output length formula taken
-            # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
-            return (input_length - kernel_size) // stride + 1
-
-        for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
-            input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
-
-        if add_adapter:
-            for _ in range(self.config.num_adapter_layers):
-                input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)
-
-        return input_lengths
-
-    def _get_feature_vector_attention_mask(
-        self, feature_vector_length: int, attention_mask: jnp.ndarray, add_adapter=None
-    ):
-        # effectively attention_mask.sum(-1), but not in-place so it can run in inference mode
-        non_padded_lengths = attention_mask.cumsum(axis=-1)[:, -1]
-
-        output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)
-
-        batch_size = attention_mask.shape[0]
-
-        attention_mask = jnp.zeros((batch_size, feature_vector_length), dtype=attention_mask.dtype)
-        # these two operations make sure that all values
-        # before the output length indices are attended to
-        attention_mask = attention_mask.at[jnp.arange(attention_mask.shape[0]), output_lengths - 1].set(1)
-        attention_mask = jnp.flip(jnp.flip(attention_mask, -1).cumsum(-1), -1).astype("bool")
-        return attention_mask
-
-
-class FlaxWav2Vec2ForCTC(FlaxWav2Vec2PreTrainedModel):
-    module_class = FlaxWav2Vec2ForCTCModule
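For context, a hedged end-to-end sketch of how these classes fit together. The `models` import path and the config values are assumptions, `do_stable_layer_norm=True` is required by `FlaxWav2Vec2Module` above, and a real CTC decode would additionally collapse repeated ids and strip the blank token:

import jax.numpy as jnp
from models import Wav2Vec2Config, FlaxWav2Vec2ForCTC  # assumed import path

# vocab_size=32 is an arbitrary illustrative value
config = Wav2Vec2Config(vocab_size=32, do_stable_layer_norm=True)
model = FlaxWav2Vec2ForCTC(config, input_shape=(1, 16000), seed=0)

dummy_audio = jnp.zeros((1, 16000), dtype="f4")  # one second at 16 kHz
outputs = model(dummy_audio)
pred_ids = jnp.argmax(outputs.logits, axis=-1)   # greedy per-frame predictions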