valhalla committed on
Commit
6197b2f
1 Parent(s): 8f484d9

add standalone modeling file

dalle_mini/configuration_bart.py ADDED
@@ -0,0 +1,185 @@
1
+ # coding=utf-8
2
+ # Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ BART model configuration """
16
+ import warnings
17
+
18
+ from transformers.configuration_utils import PretrainedConfig
19
+ from transformers.utils import logging
20
+
21
+
22
+ logger = logging.get_logger(__name__)
23
+
24
+ BART_PRETRAINED_CONFIG_ARCHIVE_MAP = {
25
+ "facebook/bart-large": "https://huggingface.co/facebook/bart-large/resolve/main/config.json",
26
+ # See all BART models at https://huggingface.co/models?filter=bart
27
+ }
28
+
29
+
30
+ class BartConfig(PretrainedConfig):
31
+ r"""
32
+ This is the configuration class to store the configuration of a :class:`~transformers.BartModel`. It is used to
33
+ instantiate a BART model according to the specified arguments, defining the model architecture. Instantiating a
34
+ configuration with the defaults will yield a similar configuration to that of the BART `facebook/bart-large
35
+ <https://huggingface.co/facebook/bart-large>`__ architecture.
36
+
37
+ Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model
38
+ outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.
39
+
40
+
41
+ Args:
42
+ vocab_size (:obj:`int`, `optional`, defaults to 50265):
43
+ Vocabulary size of the BART model. Defines the number of different tokens that can be represented by the
44
+ :obj:`inputs_ids` passed when calling :class:`~transformers.BartModel` or
45
+ :class:`~transformers.TFBartModel`.
46
+ d_model (:obj:`int`, `optional`, defaults to 1024):
47
+ Dimensionality of the layers and the pooler layer.
48
+ encoder_layers (:obj:`int`, `optional`, defaults to 12):
49
+ Number of encoder layers.
50
+ decoder_layers (:obj:`int`, `optional`, defaults to 12):
51
+ Number of decoder layers.
52
+ encoder_attention_heads (:obj:`int`, `optional`, defaults to 16):
53
+ Number of attention heads for each attention layer in the Transformer encoder.
54
+ decoder_attention_heads (:obj:`int`, `optional`, defaults to 16):
55
+ Number of attention heads for each attention layer in the Transformer decoder.
56
+ decoder_ffn_dim (:obj:`int`, `optional`, defaults to 4096):
57
+ Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
58
+ encoder_ffn_dim (:obj:`int`, `optional`, defaults to 4096):
59
+ Dimensionality of the "intermediate" (often named feed-forward) layer in encoder.
60
+ activation_function (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"gelu"`):
61
+ The non-linear activation function (function or string) in the encoder and pooler. If string,
62
+ :obj:`"gelu"`, :obj:`"relu"`, :obj:`"silu"` and :obj:`"gelu_new"` are supported.
63
+ dropout (:obj:`float`, `optional`, defaults to 0.1):
64
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
65
+ attention_dropout (:obj:`float`, `optional`, defaults to 0.0):
66
+ The dropout ratio for the attention probabilities.
67
+ activation_dropout (:obj:`float`, `optional`, defaults to 0.0):
68
+ The dropout ratio for activations inside the fully connected layer.
69
+ classifier_dropout (:obj:`float`, `optional`, defaults to 0.0):
70
+ The dropout ratio for the classifier.
71
+ max_position_embeddings (:obj:`int`, `optional`, defaults to 1024):
72
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
73
+ just in case (e.g., 512 or 1024 or 2048).
74
+ init_std (:obj:`float`, `optional`, defaults to 0.02):
75
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
76
+ encoder_layerdrop: (:obj:`float`, `optional`, defaults to 0.0):
77
+ The LayerDrop probability for the encoder. See the `LayerDrop paper
+ <https://arxiv.org/abs/1909.11556>`__ for more details.
79
+ decoder_layerdrop: (:obj:`float`, `optional`, defaults to 0.0):
80
+ The LayerDrop probability for the decoder. See the `LayerDrop paper
+ <https://arxiv.org/abs/1909.11556>`__ for more details.
82
+ gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
83
+ If True, use gradient checkpointing to save memory at the expense of slower backward pass.
84
+ scale_embedding (:obj:`bool`, `optional`, defaults to :obj:`False`):
85
+ Scale embeddings by dividing by sqrt(d_model).
86
+ use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
87
+ Whether or not the model should return the last key/values attentions (not used by all models).
88
+ num_labels: (:obj:`int`, `optional`, defaults to 3):
89
+ The number of labels to use in :class:`~transformers.BartForSequenceClassification`.
90
+ forced_eos_token_id (:obj:`int`, `optional`, defaults to 2):
91
+ The id of the token to force as the last generated token when :obj:`max_length` is reached. Usually set to
92
+ :obj:`eos_token_id`.
93
+
94
+ Example::
95
+
96
+ >>> from transformers import BartModel, BartConfig
97
+
98
+ >>> # Initializing a BART facebook/bart-large style configuration
99
+ >>> configuration = BartConfig()
100
+
101
+ >>> # Initializing a model from the facebook/bart-large style configuration
102
+ >>> model = BartModel(configuration)
103
+
104
+ >>> # Accessing the model configuration
105
+ >>> configuration = model.config
106
+ """
107
+ model_type = "bart"
108
+ keys_to_ignore_at_inference = ["past_key_values"]
109
+ attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}
110
+
111
+ def __init__(
112
+ self,
113
+ vocab_size=50265,
114
+ decoder_vocab_size=16384 + 1, # encoded image token space + 1 for bos
115
+ max_position_embeddings=1024,
116
+ decoder_max_position_embeddings=256 + 1, # number of encoded image tokens + 1 for bos
117
+ encoder_layers=12,
118
+ encoder_ffn_dim=4096,
119
+ encoder_attention_heads=16,
120
+ decoder_layers=12,
121
+ decoder_ffn_dim=4096,
122
+ decoder_attention_heads=16,
123
+ encoder_layerdrop=0.0,
124
+ decoder_layerdrop=0.0,
125
+ activation_function="gelu",
126
+ d_model=1024,
127
+ dropout=0.1,
128
+ attention_dropout=0.0,
129
+ activation_dropout=0.0,
130
+ init_std=0.02,
131
+ classifier_dropout=0.0,
132
+ scale_embedding=False,
133
+ gradient_checkpointing=False,
134
+ use_cache=True,
135
+ num_labels=3,
136
+ pad_token_id=1,
137
+ bos_token_id=0,
138
+ eos_token_id=2,
139
+ is_encoder_decoder=True,
140
+ decoder_start_token_id=16384,
141
+ forced_eos_token_id=2,
142
+ **kwargs,
143
+ ):
144
+ self.vocab_size = vocab_size
145
+ self.decoder_vocab_size = decoder_vocab_size
146
+ self.max_position_embeddings = max_position_embeddings
147
+ self.decoder_max_position_embeddings = decoder_max_position_embeddings
148
+ self.d_model = d_model
149
+ self.encoder_ffn_dim = encoder_ffn_dim
150
+ self.encoder_layers = encoder_layers
151
+ self.encoder_attention_heads = encoder_attention_heads
152
+ self.decoder_ffn_dim = decoder_ffn_dim
153
+ self.decoder_layers = decoder_layers
154
+ self.decoder_attention_heads = decoder_attention_heads
155
+ self.dropout = dropout
156
+ self.attention_dropout = attention_dropout
157
+ self.activation_dropout = activation_dropout
158
+ self.activation_function = activation_function
159
+ self.init_std = init_std
160
+ self.encoder_layerdrop = encoder_layerdrop
161
+ self.decoder_layerdrop = decoder_layerdrop
162
+ self.classifier_dropout = classifier_dropout
163
+ self.use_cache = use_cache
164
+ self.num_hidden_layers = encoder_layers
165
+ self.gradient_checkpointing = gradient_checkpointing
166
+ self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
167
+
168
+ super().__init__(
169
+ num_labels=num_labels,
170
+ pad_token_id=pad_token_id,
171
+ bos_token_id=bos_token_id,
172
+ eos_token_id=eos_token_id,
173
+ is_encoder_decoder=is_encoder_decoder,
174
+ decoder_start_token_id=decoder_start_token_id,
175
+ forced_eos_token_id=forced_eos_token_id,
176
+ **kwargs,
177
+ )
178
+
179
+ # ensure backward compatibility for BART CNN models
180
+ if self.forced_bos_token_id is None and kwargs.get("force_bos_token_to_be_generated", False):
181
+ self.forced_bos_token_id = self.bos_token_id
182
+ warnings.warn(
183
+ f"Please make sure the config includes `forced_bos_token_id={self.bos_token_id}` in future versions."
184
+ "The config can simply be saved and uploaded again to be fixed."
185
+ )
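For reference, a minimal sketch of how this configuration might be instantiated. The flat import mirrors the one used in modeling_bart_flax.py below and assumes the dalle_mini directory is on the path; the printed values simply restate the defaults defined above (BART text vocabulary on the encoder side, the 16384-entry image-token codebook plus one bos token on the decoder side).

    # a minimal sketch; assumes the dalle_mini directory is on sys.path, matching this repo's flat imports
    from configuration_bart import BartConfig

    config = BartConfig()                             # defaults as defined above
    print(config.vocab_size)                          # 50265 text tokens for the encoder
    print(config.decoder_vocab_size)                  # 16385 = 16384 image codebook tokens + 1 for bos
    print(config.decoder_max_position_embeddings)     # 257 = 256 image tokens + 1 for bos
    print(config.decoder_start_token_id)              # 16384, the bos token for image decoding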
dalle_mini/modeling_bart_flax.py ADDED
@@ -0,0 +1,1023 @@
1
+ # coding=utf-8
2
+ # Copyright 2021 The Fairseq Authors and The Google Flax Team Authors And The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Flax Bart model. """
16
+
17
+ import math
18
+ from functools import partial
19
+ from typing import Callable, Optional, Tuple
20
+
21
+ import numpy as np
22
+
23
+ import flax.linen as nn
24
+ import jax
25
+ import jax.numpy as jnp
26
+ from flax.core.frozen_dict import FrozenDict, unfreeze
27
+ from flax.linen import combine_masks, make_causal_mask
28
+ from flax.linen.attention import dot_product_attention_weights
29
+ from jax import lax
30
+ from jax.random import PRNGKey
31
+
32
+
33
+ from transformers.modeling_flax_outputs import (
34
+ FlaxBaseModelOutput,
35
+ FlaxBaseModelOutputWithPastAndCrossAttentions,
36
+ FlaxCausalLMOutputWithCrossAttentions,
37
+ FlaxSeq2SeqLMOutput,
38
+ FlaxSeq2SeqModelOutput,
39
+ )
40
+ from transformers.modeling_flax_utils import (
41
+ ACT2FN,
42
+ FlaxPreTrainedModel,
43
+ )
44
+ from transformers.utils import logging
45
+
46
+
47
+ from configuration_bart import BartConfig
48
+
49
+
50
+ logger = logging.get_logger(__name__)
51
+
52
+
53
+ def shift_tokens_right(input_ids: np.array, pad_token_id: int, decoder_start_token_id: int) -> np.ndarray:
54
+ """
55
+ Shift input ids one token to the right.
56
+ """
57
+ shifted_input_ids = np.zeros_like(input_ids)
58
+ shifted_input_ids[:, 1:] = input_ids[:, :-1]
59
+ shifted_input_ids[:, 0] = decoder_start_token_id
60
+
61
+ shifted_input_ids = np.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
62
+ return shifted_input_ids
63
+
64
+
65
+ class FlaxBartAttention(nn.Module):
66
+ config: BartConfig
67
+ embed_dim: int
68
+ num_heads: int
69
+ dropout: float = 0.0
70
+ causal: bool = False
71
+ bias: bool = True
72
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
73
+
74
+ def setup(self) -> None:
75
+ self.head_dim = self.embed_dim // self.num_heads
76
+ assert (
77
+ self.head_dim * self.num_heads == self.embed_dim
78
+ ), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
79
+
80
+ dense = partial(
81
+ nn.Dense,
82
+ self.embed_dim,
83
+ use_bias=self.bias,
84
+ dtype=self.dtype,
85
+ kernel_init=jax.nn.initializers.normal(self.config.init_std),
86
+ )
87
+
88
+ self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
89
+ self.out_proj = dense()
90
+
91
+ self.dropout_layer = nn.Dropout(rate=self.dropout)
92
+
93
+ if self.causal:
94
+ self.causal_mask = make_causal_mask(
95
+ jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool"
96
+ )
97
+
98
+ def _split_heads(self, hidden_states):
99
+ return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))
100
+
101
+ def _merge_heads(self, hidden_states):
102
+ return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))
103
+
104
+ @nn.compact
105
+ def _concatenate_to_cache(self, key, value, query, attention_mask):
106
+ """
107
+ This function takes projected key, value states from a single input token and concatenates the states to cached
108
+ states from previous steps. This function is slightly adapted from the official Flax repository:
109
+ https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
110
+ """
111
+ # detect if we're initializing by absence of existing cache data.
112
+ is_initialized = self.has_variable("cache", "cached_key")
113
+ cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
114
+ cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
115
+ cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
116
+
117
+ if is_initialized:
118
+ *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
119
+ # update key, value caches with our new 1d spatial slices
120
+ cur_index = cache_index.value
121
+ indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
122
+ key = lax.dynamic_update_slice(cached_key.value, key, indices)
123
+ value = lax.dynamic_update_slice(cached_value.value, value, indices)
124
+ cached_key.value = key
125
+ cached_value.value = value
126
+ num_updated_cache_vectors = query.shape[1]
127
+ cache_index.value = cache_index.value + num_updated_cache_vectors
128
+ # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements.
129
+ pad_mask = jnp.broadcast_to(
130
+ jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
131
+ tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
132
+ )
133
+ attention_mask = combine_masks(pad_mask, attention_mask)
134
+ return key, value, attention_mask
135
+
136
+ def __call__(
137
+ self,
138
+ hidden_states: jnp.ndarray,
139
+ attention_mask: jnp.ndarray,
140
+ key_value_states: Optional[jnp.ndarray] = None,
141
+ init_cache: bool = False,
142
+ deterministic: bool = True,
143
+ ) -> Tuple[jnp.ndarray]:
144
+ """Input shape: Batch x Time x Channel"""
145
+
146
+ # if key_value_states are provided this layer is used as a cross-attention layer
147
+ # for the decoder
148
+ is_cross_attention = key_value_states is not None
149
+ batch_size = hidden_states.shape[0]
150
+
151
+ # get query proj
152
+ query_states = self.q_proj(hidden_states)
153
+ # get key, value proj
154
+ if is_cross_attention:
155
+ # cross_attentions
156
+ key_states = self.k_proj(key_value_states)
157
+ value_states = self.v_proj(key_value_states)
158
+ else:
159
+ # self_attention
160
+ key_states = self.k_proj(hidden_states)
161
+ value_states = self.v_proj(hidden_states)
162
+
163
+ query_states = self._split_heads(query_states)
164
+ key_states = self._split_heads(key_states)
165
+ value_states = self._split_heads(value_states)
166
+
167
+ # handle cache: prepare causal attention mask
168
+ if self.causal:
169
+ query_length, key_length = query_states.shape[1], key_states.shape[1]
170
+ if self.has_variable("cache", "cached_key"):
171
+ mask_shift = self.variables["cache"]["cache_index"]
172
+ max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
173
+ causal_mask = lax.dynamic_slice(
174
+ self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
175
+ )
176
+ else:
177
+ causal_mask = self.causal_mask[:, :, :query_length, :key_length]
178
+ causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
179
+
180
+ # combine masks if needed
181
+ if self.causal:
182
+ attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
183
+ attention_mask = combine_masks(attention_mask, causal_mask)
184
+ else:
185
+ attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
186
+
187
+ # During fast autoregressive decoding, we feed one position at a time,
188
+ # and cache the keys and values step by step.
189
+ if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
190
+ key_states, value_states, attention_mask = self._concatenate_to_cache(
191
+ key_states, value_states, query_states, attention_mask
192
+ )
193
+
194
+ # Convert the boolean attention mask to an attention bias: 0.0 where attention is allowed, -inf where it is masked.
196
+ attention_bias = lax.select(
197
+ attention_mask > 0,
198
+ jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
199
+ jnp.full(attention_mask.shape, float("-inf")).astype(self.dtype),
200
+ )
201
+
202
+ dropout_rng = None
203
+ if not deterministic and self.dropout > 0.0:
204
+ dropout_rng = self.make_rng("dropout")
205
+
206
+ attn_weights = dot_product_attention_weights(
207
+ query_states,
208
+ key_states,
209
+ bias=attention_bias,
210
+ dropout_rng=dropout_rng,
211
+ dropout_rate=self.dropout,
212
+ broadcast_dropout=True,
213
+ deterministic=deterministic,
214
+ dtype=self.dtype,
215
+ precision=None,
216
+ )
217
+
218
+ attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
219
+ attn_output = self._merge_heads(attn_output)
220
+ attn_output = self.out_proj(attn_output)
221
+
222
+ return attn_output
223
+
224
+
225
+ class FlaxBartEncoderLayer(nn.Module):
226
+ config: BartConfig
227
+ dtype: jnp.dtype = jnp.float32
228
+
229
+ def setup(self) -> None:
230
+ self.embed_dim = self.config.d_model
231
+ self.self_attn = FlaxBartAttention(
232
+ config=self.config,
233
+ embed_dim=self.embed_dim,
234
+ num_heads=self.config.encoder_attention_heads,
235
+ dropout=self.config.attention_dropout,
236
+ dtype=self.dtype,
237
+ )
238
+ self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
239
+ self.dropout_layer = nn.Dropout(rate=self.config.dropout)
240
+ self.activation_fn = ACT2FN[self.config.activation_function]
241
+ self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
242
+ self.fc1 = nn.Dense(
243
+ self.config.encoder_ffn_dim,
244
+ dtype=self.dtype,
245
+ kernel_init=jax.nn.initializers.normal(self.config.init_std),
246
+ )
247
+ self.fc2 = nn.Dense(
248
+ self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
249
+ )
250
+ self.final_layer_norm = nn.LayerNorm(dtype=self.dtype)
251
+
252
+ def __call__(
253
+ self,
254
+ hidden_states: jnp.ndarray,
255
+ attention_mask: jnp.ndarray,
256
+ deterministic: bool = True,
257
+ ) -> Tuple[jnp.ndarray]:
258
+ residual = hidden_states
259
+ hidden_states = self.self_attn(hidden_states=hidden_states, attention_mask=attention_mask)
260
+
261
+ hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
262
+ hidden_states = residual + hidden_states
263
+ hidden_states = self.self_attn_layer_norm(hidden_states)
264
+
265
+ residual = hidden_states
266
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
267
+ hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic)
268
+ hidden_states = self.fc2(hidden_states)
269
+ hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
270
+ hidden_states = residual + hidden_states
271
+ hidden_states = self.final_layer_norm(hidden_states)
272
+
273
+ return hidden_states
274
+
275
+
276
+ class FlaxBartEncoderLayerCollection(nn.Module):
277
+ config: BartConfig
278
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
279
+
280
+ def setup(self):
281
+ self.layers = [
282
+ FlaxBartEncoderLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.encoder_layers)
283
+ ]
284
+
285
+ def __call__(
286
+ self,
287
+ hidden_states,
288
+ attention_mask,
289
+ deterministic: bool = True,
290
+ ):
291
+
292
+ for encoder_layer in self.layers:
293
+ hidden_states = encoder_layer(
294
+ hidden_states,
295
+ attention_mask,
296
+ deterministic,
297
+ )
298
+
299
+ return FlaxBaseModelOutput(last_hidden_state=hidden_states)
300
+
301
+
302
+ class FlaxBartDecoderLayer(nn.Module):
303
+ config: BartConfig
304
+ dtype: jnp.dtype = jnp.float32
305
+
306
+ def setup(self) -> None:
307
+ self.embed_dim = self.config.d_model
308
+ self.self_attn = FlaxBartAttention(
309
+ config=self.config,
310
+ embed_dim=self.embed_dim,
311
+ num_heads=self.config.decoder_attention_heads,
312
+ dropout=self.config.attention_dropout,
313
+ causal=True,
314
+ dtype=self.dtype,
315
+ )
316
+ self.dropout_layer = nn.Dropout(rate=self.config.dropout)
317
+ self.activation_fn = ACT2FN[self.config.activation_function]
318
+ self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
319
+
320
+ self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
321
+ self.encoder_attn = FlaxBartAttention(
322
+ config=self.config,
323
+ embed_dim=self.embed_dim,
324
+ num_heads=self.config.decoder_attention_heads,
325
+ dropout=self.config.attention_dropout,
326
+ dtype=self.dtype,
327
+ )
328
+ self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
329
+ self.fc1 = nn.Dense(
330
+ self.config.encoder_ffn_dim,
331
+ dtype=self.dtype,
332
+ kernel_init=jax.nn.initializers.normal(self.config.init_std),
333
+ )
334
+ self.fc2 = nn.Dense(
335
+ self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
336
+ )
337
+ self.final_layer_norm = nn.LayerNorm(dtype=self.dtype)
338
+
339
+ def __call__(
340
+ self,
341
+ hidden_states: jnp.ndarray,
342
+ attention_mask: jnp.ndarray,
343
+ encoder_hidden_states: jnp.ndarray,
344
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
345
+ init_cache: bool = False,
346
+ deterministic: bool = True,
347
+ ) -> Tuple[jnp.ndarray]:
348
+ residual = hidden_states
349
+
350
+ # Self Attention
351
+ hidden_states = self.self_attn(
352
+ hidden_states=hidden_states, attention_mask=attention_mask, init_cache=init_cache
353
+ )
354
+ hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
355
+ hidden_states = residual + hidden_states
356
+ hidden_states = self.self_attn_layer_norm(hidden_states)
357
+
358
+ # Cross-Attention Block
359
+ residual = hidden_states
360
+
361
+ hidden_states = self.encoder_attn(
362
+ hidden_states=hidden_states,
363
+ key_value_states=encoder_hidden_states,
364
+ attention_mask=encoder_attention_mask,
365
+ )
366
+ hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
367
+ hidden_states = residual + hidden_states
368
+ hidden_states = self.encoder_attn_layer_norm(hidden_states)
369
+
370
+ # Fully Connected
371
+ residual = hidden_states
372
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
373
+ hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic)
374
+ hidden_states = self.fc2(hidden_states)
375
+ hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
376
+ hidden_states = residual + hidden_states
377
+ hidden_states = self.final_layer_norm(hidden_states)
378
+
379
+ return hidden_states
380
+
381
+
382
+ class FlaxBartDecoderLayerCollection(nn.Module):
383
+ config: BartConfig
384
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
385
+
386
+ def setup(self):
387
+ self.layers = [
388
+ FlaxBartDecoderLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.decoder_layers)
389
+ ]
390
+
391
+ def __call__(
392
+ self,
393
+ hidden_states,
394
+ attention_mask,
395
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
396
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
397
+ deterministic: bool = True,
398
+ init_cache: bool = False,
399
+ ):
400
+ # decoder layers
401
+ for decoder_layer in self.layers:
402
+ hidden_states = decoder_layer(
403
+ hidden_states,
404
+ attention_mask=attention_mask,
405
+ encoder_hidden_states=encoder_hidden_states,
406
+ encoder_attention_mask=encoder_attention_mask,
407
+ init_cache=init_cache,
408
+ deterministic=deterministic,
409
+ )
410
+
411
+ return FlaxBaseModelOutputWithPastAndCrossAttentions(last_hidden_state=hidden_states)
412
+
413
+
414
+ class FlaxBartEncoder(nn.Module):
415
+ config: BartConfig
416
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
417
+ embed_tokens: Optional[nn.Embed] = None
418
+
419
+ def setup(self):
420
+ self.dropout_layer = nn.Dropout(rate=self.config.dropout)
421
+
422
+ embed_dim = self.config.d_model
423
+ self.padding_idx = self.config.pad_token_id
424
+ self.max_source_positions = self.config.max_position_embeddings
425
+ self.embed_scale = math.sqrt(embed_dim) if self.config.scale_embedding else 1.0
426
+
427
+ if self.embed_tokens is None:
428
+ self.embed_tokens = nn.Embed(
429
+ self.config.vocab_size,
430
+ embed_dim,
431
+ embedding_init=jax.nn.initializers.normal(self.config.init_std),
432
+ )
433
+
434
+ # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
435
+ # and adjust num_embeddings appropriately. Other models don't have this hack
436
+ self.offset = 2
437
+ self.embed_positions = nn.Embed(
438
+ self.config.max_position_embeddings + self.offset,
439
+ embed_dim,
440
+ embedding_init=jax.nn.initializers.normal(self.config.init_std),
441
+ )
442
+ self.layers = FlaxBartEncoderLayerCollection(self.config, self.dtype)
443
+ self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype)
444
+
445
+ def __call__(
446
+ self,
447
+ input_ids,
448
+ attention_mask,
449
+ position_ids,
450
+ deterministic: bool = True,
451
+ ):
452
+ input_shape = input_ids.shape
453
+ input_ids = input_ids.reshape(-1, input_shape[-1])
454
+
455
+ inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
456
+
457
+ embed_pos = self.embed_positions(position_ids + self.offset)
458
+
459
+ hidden_states = inputs_embeds + embed_pos
460
+ hidden_states = self.layernorm_embedding(hidden_states)
461
+ hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
462
+
463
+ outputs = self.layers(hidden_states, attention_mask, deterministic=deterministic)
464
+
465
+ return FlaxBaseModelOutput(
466
+ last_hidden_state=outputs.last_hidden_state,
467
+ hidden_states=outputs.hidden_states,
468
+ attentions=outputs.attentions,
469
+ )
470
+
471
+
472
+ class FlaxBartDecoder(nn.Module):
473
+ config: BartConfig
474
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
475
+ embed_tokens: Optional[nn.Embed] = None
476
+
477
+ def setup(self):
478
+ self.dropout_layer = nn.Dropout(rate=self.config.dropout)
479
+
480
+ embed_dim = self.config.d_model
481
+ self.padding_idx = self.config.pad_token_id
482
+ self.max_target_positions = self.config.max_position_embeddings
483
+ self.embed_scale = math.sqrt(self.config.d_model) if self.config.scale_embedding else 1.0
484
+
485
+ if self.embed_tokens is None:
486
+ self.embed_tokens = nn.Embed(
487
+ self.config.vocab_size,
488
+ embed_dim,
489
+ embedding_init=jax.nn.initializers.normal(self.config.init_std),
490
+ )
491
+
492
+ # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
493
+ # and adjust num_embeddings appropriately. Other models don't have this hack
494
+ self.offset = 2
495
+ self.embed_positions = nn.Embed(
496
+ self.config.max_position_embeddings + self.offset,
497
+ embed_dim,
498
+ embedding_init=jax.nn.initializers.normal(self.config.init_std),
499
+ )
500
+
501
+ self.layers = FlaxBartDecoderLayerCollection(self.config, self.dtype)
502
+ self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype)
503
+
504
+ def __call__(
505
+ self,
506
+ input_ids,
507
+ attention_mask,
508
+ position_ids,
509
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
510
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
511
+ init_cache: bool = False,
512
+ deterministic: bool = True,
513
+ ):
514
+ input_shape = input_ids.shape
515
+ input_ids = input_ids.reshape(-1, input_shape[-1])
516
+
517
+ inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
518
+
519
+ # embed positions
520
+ positions = self.embed_positions(position_ids + self.offset)
521
+
522
+ hidden_states = inputs_embeds + positions
523
+ hidden_states = self.layernorm_embedding(hidden_states)
524
+
525
+ hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
526
+
527
+ outputs = self.layers(
528
+ hidden_states,
529
+ attention_mask,
530
+ encoder_hidden_states,
531
+ encoder_attention_mask,
532
+ deterministic=deterministic,
533
+ init_cache=init_cache,
534
+ )
535
+
536
+ return FlaxBaseModelOutputWithPastAndCrossAttentions(
537
+ last_hidden_state=outputs.last_hidden_state,
538
+ hidden_states=outputs.hidden_states,
539
+ attentions=outputs.attentions,
540
+ cross_attentions=outputs.cross_attentions,
541
+ )
542
+
543
+
544
+ class FlaxBartModule(nn.Module):
545
+ config: BartConfig
546
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
547
+
548
+ def setup(self):
549
+ self.shared = nn.Embed(
550
+ self.config.vocab_size,
551
+ self.config.d_model,
552
+ embedding_init=jax.nn.initializers.normal(self.config.init_std),
553
+ )
554
+ # a separate embedding is used for the decoder
555
+ self.decoder_embed = nn.Embed(
556
+ self.config.decoder_vocab_size,
557
+ self.config.d_model,
558
+ embedding_init=jax.nn.initializers.normal(self.config.init_std),
559
+ )
560
+
561
+ self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)
562
+ self.decoder = FlaxBartDecoder(self.config, dtype=self.dtype, embed_tokens=self.decoder_embed)
563
+
564
+ def _get_encoder_module(self):
565
+ return self.encoder
566
+
567
+ def _get_decoder_module(self):
568
+ return self.decoder
569
+
570
+ def __call__(
571
+ self,
572
+ input_ids,
573
+ attention_mask,
574
+ decoder_input_ids,
575
+ decoder_attention_mask,
576
+ position_ids,
577
+ decoder_position_ids,
578
+ output_attentions: bool = False,
579
+ output_hidden_states: bool = False,
580
+ return_dict: bool = True,
581
+ deterministic: bool = True,
582
+ ):
583
+ encoder_outputs = self.encoder(
584
+ input_ids=input_ids,
585
+ attention_mask=attention_mask,
586
+ position_ids=position_ids,
587
+ output_attentions=output_attentions,
588
+ output_hidden_states=output_hidden_states,
589
+ return_dict=return_dict,
590
+ deterministic=deterministic,
591
+ )
592
+
593
+ decoder_outputs = self.decoder(
594
+ input_ids=decoder_input_ids,
595
+ attention_mask=decoder_attention_mask,
596
+ position_ids=decoder_position_ids,
597
+ encoder_hidden_states=encoder_outputs[0],
598
+ encoder_attention_mask=attention_mask,
599
+ output_attentions=output_attentions,
600
+ output_hidden_states=output_hidden_states,
601
+ return_dict=return_dict,
602
+ deterministic=deterministic,
603
+ )
604
+
605
+ if not return_dict:
606
+ return decoder_outputs + encoder_outputs
607
+
608
+ return FlaxSeq2SeqModelOutput(
609
+ last_hidden_state=decoder_outputs.last_hidden_state,
610
+ decoder_hidden_states=decoder_outputs.hidden_states,
611
+ decoder_attentions=decoder_outputs.attentions,
612
+ cross_attentions=decoder_outputs.cross_attentions,
613
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
614
+ encoder_hidden_states=encoder_outputs.hidden_states,
615
+ encoder_attentions=encoder_outputs.attentions,
616
+ )
617
+
618
+
619
+ class FlaxBartPreTrainedModel(FlaxPreTrainedModel):
620
+ config_class = BartConfig
621
+ base_model_prefix: str = "model"
622
+ module_class: nn.Module = None
623
+
624
+ def __init__(
625
+ self,
626
+ config: BartConfig,
627
+ input_shape: Tuple[int] = (1, 1),
628
+ seed: int = 0,
629
+ dtype: jnp.dtype = jnp.float32,
630
+ **kwargs,
631
+ ):
632
+ module = self.module_class(config=config, dtype=dtype, **kwargs)
633
+ super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
634
+
635
+ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
636
+ # init input tensors
637
+ input_ids = jnp.zeros(input_shape, dtype="i4")
638
+ # make sure initialization pass will work for FlaxBartForSequenceClassificationModule
639
+ input_ids = jax.ops.index_update(input_ids, (..., -1), self.config.eos_token_id)
640
+ attention_mask = jnp.ones_like(input_ids)
641
+ decoder_input_ids = input_ids
642
+ decoder_attention_mask = jnp.ones_like(input_ids)
643
+
644
+ batch_size, sequence_length = input_ids.shape
645
+ position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
646
+ decoder_position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
647
+
648
+ params_rng, dropout_rng = jax.random.split(rng)
649
+ rngs = {"params": params_rng, "dropout": dropout_rng}
650
+
651
+ return self.module.init(
652
+ rngs,
653
+ input_ids,
654
+ attention_mask,
655
+ decoder_input_ids,
656
+ decoder_attention_mask,
657
+ position_ids,
658
+ decoder_position_ids,
659
+ )["params"]
660
+
661
+ def init_cache(self, batch_size, max_length, encoder_outputs):
662
+ r"""
663
+ Args:
664
+ batch_size (:obj:`int`):
665
+ batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
666
+ max_length (:obj:`int`):
667
+ maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
668
+ cache.
669
+ encoder_outputs (:obj:`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray))]`):
+ ``encoder_outputs`` consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`,
+ `optional`: :obj:`attentions`). :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length,
+ hidden_size)` is a sequence of hidden-states at the output of the last layer of the encoder, used in
+ the cross-attention of the decoder.
674
+ """
675
+ # init input variables to retrieve cache
676
+ decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4")
677
+ decoder_attention_mask = jnp.ones_like(decoder_input_ids)
678
+ decoder_position_ids = jnp.broadcast_to(
679
+ jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]), decoder_input_ids.shape
680
+ )
681
+
682
+ def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs):
683
+ decoder_module = module._get_decoder_module()
684
+ return decoder_module(
685
+ decoder_input_ids,
686
+ decoder_attention_mask,
687
+ decoder_position_ids,
688
+ **kwargs,
689
+ )
690
+
691
+ init_variables = self.module.init(
692
+ jax.random.PRNGKey(0),
693
+ decoder_input_ids=decoder_input_ids,
694
+ decoder_attention_mask=decoder_attention_mask,
695
+ decoder_position_ids=decoder_position_ids,
696
+ encoder_hidden_states=encoder_outputs[0],
697
+ init_cache=True,
698
+ method=_decoder_forward, # we only need to call the decoder to init the cache
699
+ )
700
+ return unfreeze(init_variables["cache"])
701
+
702
+ def encode(
703
+ self,
704
+ input_ids: jnp.ndarray,
705
+ attention_mask: Optional[jnp.ndarray] = None,
706
+ position_ids: Optional[jnp.ndarray] = None,
707
+ train: bool = False,
708
+ params: dict = None,
709
+ dropout_rng: PRNGKey = None,
710
+ ):
711
+ r"""
712
+ Returns:
713
+
714
+ Example::
715
+
716
+ >>> from transformers import BartTokenizer, FlaxBartForConditionalGeneration
717
+
718
+ >>> model = FlaxBartForConditionalGeneration.from_pretrained('facebook/bart-large-cnn')
719
+ >>> tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-cnn')
720
+
721
+ >>> text = "My friends are cool but they eat too many carbs."
722
+ >>> inputs = tokenizer(text, max_length=1024, return_tensors='jax')
723
+ >>> encoder_outputs = model.encode(**inputs)
724
+ """
725
+ if attention_mask is None:
726
+ attention_mask = jnp.ones_like(input_ids)
727
+ if position_ids is None:
728
+ batch_size, sequence_length = input_ids.shape
729
+ position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
730
+
731
+ # Handle any PRNG if needed
732
+ rngs = {}
733
+ if dropout_rng is not None:
734
+ rngs["dropout"] = dropout_rng
735
+
736
+ def _encoder_forward(module, input_ids, attention_mask, position_ids, **kwargs):
737
+ encode_module = module._get_encoder_module()
738
+ return encode_module(input_ids, attention_mask, position_ids, **kwargs)
739
+
740
+ return self.module.apply(
741
+ {"params": params or self.params},
742
+ input_ids=jnp.array(input_ids, dtype="i4"),
743
+ attention_mask=jnp.array(attention_mask, dtype="i4"),
744
+ position_ids=jnp.array(position_ids, dtype="i4"),
745
+ deterministic=not train,
746
+ rngs=rngs,
747
+ method=_encoder_forward,
748
+ )
749
+
750
+ def __call__(
751
+ self,
752
+ input_ids: jnp.ndarray,
753
+ attention_mask: Optional[jnp.ndarray] = None,
754
+ decoder_input_ids: Optional[jnp.ndarray] = None,
755
+ decoder_attention_mask: Optional[jnp.ndarray] = None,
756
+ position_ids: Optional[jnp.ndarray] = None,
757
+ decoder_position_ids: Optional[jnp.ndarray] = None,
758
+ output_attentions: Optional[bool] = None,
759
+ output_hidden_states: Optional[bool] = None,
760
+ return_dict: Optional[bool] = None,
761
+ train: bool = False,
762
+ params: dict = None,
763
+ dropout_rng: PRNGKey = None,
764
+ ):
765
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
766
+ output_hidden_states = (
767
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
768
+ )
769
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
770
+
771
+ # prepare encoder inputs
772
+ if attention_mask is None:
773
+ attention_mask = jnp.ones_like(input_ids)
774
+ if position_ids is None:
775
+ batch_size, sequence_length = input_ids.shape
776
+ position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
777
+
778
+ # prepare decoder inputs
779
+ if decoder_input_ids is None:
780
+ decoder_input_ids = shift_tokens_right(
781
+ input_ids, self.config.pad_token_id, decoder_start_token_id=self.config.decoder_start_token_id
782
+ )
783
+ if decoder_attention_mask is None:
784
+ decoder_attention_mask = jnp.ones_like(decoder_input_ids)
785
+ if decoder_position_ids is None:
786
+ batch_size, sequence_length = decoder_input_ids.shape
787
+ decoder_position_ids = jnp.broadcast_to(
788
+ jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
789
+ )
790
+
791
+ # Handle any PRNG if needed
792
+ rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
793
+
794
+ return self.module.apply(
795
+ {"params": params or self.params},
796
+ input_ids=jnp.array(input_ids, dtype="i4"),
797
+ attention_mask=jnp.array(attention_mask, dtype="i4"),
798
+ position_ids=jnp.array(position_ids, dtype="i4"),
799
+ decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
800
+ decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
801
+ decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
802
+ deterministic=not train,
803
+ rngs=rngs,
804
+ )
805
+
806
+
807
+ class FlaxBartForConditionalGenerationModule(nn.Module):
808
+ config: BartConfig
809
+ dtype: jnp.dtype = jnp.float32
810
+ bias_init: Callable[..., jnp.ndarray] = jax.nn.initializers.zeros
811
+
812
+ def setup(self):
813
+ self.model = FlaxBartModule(config=self.config, dtype=self.dtype)
814
+ self.lm_head = nn.Dense(
815
+ self.config.decoder_vocab_size,
816
+ use_bias=False,
817
+ dtype=self.dtype,
818
+ kernel_init=jax.nn.initializers.normal(self.config.init_std),
819
+ )
820
+ self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, self.config.decoder_vocab_size))
821
+
822
+ def _get_encoder_module(self):
823
+ return self.model.encoder
824
+
825
+ def _get_decoder_module(self):
826
+ return self.model.decoder
827
+
828
+ def __call__(
829
+ self,
830
+ input_ids,
831
+ attention_mask,
832
+ decoder_input_ids,
833
+ decoder_attention_mask,
834
+ position_ids,
835
+ decoder_position_ids,
836
+ deterministic: bool = True,
837
+ ):
838
+ outputs = self.model(
839
+ input_ids=input_ids,
840
+ attention_mask=attention_mask,
841
+ decoder_input_ids=decoder_input_ids,
842
+ decoder_attention_mask=decoder_attention_mask,
843
+ position_ids=position_ids,
844
+ decoder_position_ids=decoder_position_ids,
845
+ deterministic=deterministic,
846
+ )
847
+
848
+ hidden_states = outputs[0]
849
+
850
+ if self.config.tie_word_embeddings:
851
+ shared_embedding = self.model.variables["params"]["shared"]["embedding"]
852
+ lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
853
+ else:
854
+ lm_logits = self.lm_head(hidden_states)
855
+
856
+ lm_logits += self.final_logits_bias
857
+
858
+ return FlaxSeq2SeqLMOutput(
859
+ logits=lm_logits,
860
+ decoder_hidden_states=outputs.decoder_hidden_states,
861
+ decoder_attentions=outputs.decoder_attentions,
862
+ cross_attentions=outputs.cross_attentions,
863
+ encoder_last_hidden_state=outputs.encoder_last_hidden_state,
864
+ encoder_hidden_states=outputs.encoder_hidden_states,
865
+ encoder_attentions=outputs.encoder_attentions,
866
+ )
867
+
868
+
869
+ class FlaxBartForConditionalGeneration(FlaxBartPreTrainedModel):
870
+ module_class = FlaxBartForConditionalGenerationModule
871
+ dtype: jnp.dtype = jnp.float32
872
+
873
+ def decode(
874
+ self,
875
+ decoder_input_ids,
876
+ encoder_outputs,
877
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
878
+ decoder_attention_mask: Optional[jnp.ndarray] = None,
879
+ decoder_position_ids: Optional[jnp.ndarray] = None,
880
+ past_key_values: dict = None,
881
+ train: bool = False,
882
+ params: dict = None,
883
+ dropout_rng: PRNGKey = None,
884
+ ):
885
+ r"""
886
+ Returns:
887
+
888
+ Example::
889
+
890
+ >>> from transformers import BartTokenizer, FlaxBartForConditionalGeneration
891
+
892
+ >>> model = FlaxBartForConditionalGeneration.from_pretrained('facebook/bart-large-cnn')
893
+ >>> tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-cnn')
894
+
895
+ >>> text = "My friends are cool but they eat too many carbs."
896
+ >>> inputs = tokenizer(text, max_length=1024, return_tensors='jax')
897
+ >>> encoder_outputs = model.encode(**inputs)
898
+
899
+ >>> decoder_start_token_id = model.config.decoder_start_token_id
900
+ >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id
901
+
902
+ >>> outputs = model.decode(decoder_input_ids, encoder_outputs)
903
+ >>> logits = outputs.logits
904
+ """
905
+ encoder_hidden_states = encoder_outputs[0]
906
+ if encoder_attention_mask is None:
907
+ batch_size, sequence_length = encoder_hidden_states.shape[:2]
908
+ encoder_attention_mask = jnp.ones((batch_size, sequence_length))
909
+
910
+ batch_size, sequence_length = decoder_input_ids.shape
911
+ if decoder_attention_mask is None:
912
+ decoder_attention_mask = jnp.ones((batch_size, sequence_length))
913
+
914
+ if decoder_position_ids is None:
915
+ if past_key_values is not None:
916
+ raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.")
917
+
918
+ decoder_position_ids = jnp.broadcast_to(
919
+ jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
920
+ )
921
+
922
+ # Handle any PRNG if needed
923
+ rngs = {}
924
+ if dropout_rng is not None:
925
+ rngs["dropout"] = dropout_rng
926
+
927
+ inputs = {"params": params or self.params}
928
+
929
+ # if past_key_values are passed, the cache is already initialized and a private flag init_cache has to be
+ # passed down to ensure the cache is used. The cache also has to be marked as mutable so that
+ # it can be changed by the FlaxBartAttention module
932
+ if past_key_values:
933
+ inputs["cache"] = past_key_values
934
+ mutable = ["cache"]
935
+ else:
936
+ mutable = False
937
+
938
+ def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs):
939
+ decoder_module = module._get_decoder_module()
940
+ outputs = decoder_module(
941
+ decoder_input_ids,
942
+ decoder_attention_mask,
943
+ decoder_position_ids,
944
+ **kwargs,
945
+ )
946
+ hidden_states = outputs[0]
947
+
948
+ if self.config.tie_word_embeddings:
949
+ shared_embedding = module.model.variables["params"]["shared"]["embedding"]
950
+ lm_logits = module.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
951
+ else:
952
+ lm_logits = module.lm_head(hidden_states)
953
+
954
+ lm_logits += module.final_logits_bias
955
+ return lm_logits, outputs
956
+
957
+ outputs = self.module.apply(
958
+ inputs,
959
+ decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
960
+ decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
961
+ decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
962
+ encoder_hidden_states=encoder_hidden_states,
963
+ encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
964
+ deterministic=not train,
965
+ rngs=rngs,
966
+ mutable=mutable,
967
+ method=_decoder_forward,
968
+ )
969
+
970
+ if past_key_values is None:
971
+ lm_logits, decoder_outputs = outputs
972
+ else:
973
+ (lm_logits, decoder_outputs), past = outputs
974
+
975
+ outputs = FlaxCausalLMOutputWithCrossAttentions(
976
+ logits=lm_logits,
977
+ hidden_states=decoder_outputs.hidden_states,
978
+ attentions=decoder_outputs.attentions,
979
+ cross_attentions=decoder_outputs.cross_attentions,
980
+ )
981
+
982
+ # add updated cache to model output
983
+ if past_key_values is not None:
984
+ outputs["past_key_values"] = unfreeze(past["cache"])
985
+ return outputs
986
+
987
+ return outputs
988
+
989
+ def prepare_inputs_for_generation(
990
+ self,
991
+ decoder_input_ids,
992
+ max_length,
993
+ attention_mask: Optional[jnp.DeviceArray] = None,
994
+ decoder_attention_mask: Optional[jnp.DeviceArray] = None,
995
+ encoder_outputs=None,
996
+ **kwargs,
997
+ ):
998
+ # initializing the cache
999
+ batch_size, seq_length = decoder_input_ids.shape
1000
+
1001
+ past_key_values = self.init_cache(batch_size, max_length, encoder_outputs)
1002
+ # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
1003
+ # But since the decoder uses a causal mask, those positions are masked anyways.
1004
+ # Thus we can create a single static attention_mask here, which is more efficient for compilation
1005
+ extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
1006
+ if decoder_attention_mask is not None:
1007
+ position_ids = decoder_attention_mask.cumsum(axis=-1) - 1
1008
+ extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0))
1009
+ else:
1010
+ position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
1011
+
1012
+ return {
1013
+ "past_key_values": past_key_values,
1014
+ "encoder_outputs": encoder_outputs,
1015
+ "encoder_attention_mask": attention_mask,
1016
+ "decoder_attention_mask": extended_attention_mask,
1017
+ "decoder_position_ids": position_ids,
1018
+ }
1019
+
1020
+ def update_inputs_for_generation(self, model_outputs, model_kwargs):
1021
+ model_kwargs["past_key_values"] = model_outputs.past_key_values
1022
+ model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1
1023
+ return model_kwargs
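As a usage note, the shift_tokens_right helper defined at the top of this file is what __call__ uses to build decoder_input_ids when they are not passed explicitly. A minimal sketch of its behaviour, assuming the same flat import style as the file itself and the config defaults pad_token_id=1, decoder_start_token_id=16384:

    # a minimal sketch; assumes the dalle_mini directory is on sys.path, matching this repo's flat imports
    import numpy as np
    from modeling_bart_flax import shift_tokens_right

    token_ids = np.array([[523, 871, -100, 42]])      # -100 is the usual ignore index; the helper maps it to pad
    decoder_input_ids = shift_tokens_right(token_ids, pad_token_id=1, decoder_start_token_id=16384)
    print(decoder_input_ids)                          # [[16384 523 871 1]]: start token prepended,
                                                      # last token dropped, -100 replaced by pad_token_id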