boris committed
Commit 6f1f2d9
1 Parent(s): a96f4dc
dalle_mini/model/__init__.py CHANGED
@@ -1,2 +1,2 @@
 from .configuration import DalleBartConfig
-from .modeling import DalleBartForConditionalGeneration
+from .modeling import DalleBartForConditionalGeneration
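Note: `__init__.py` simply re-exports the two public classes. A minimal consumption sketch follows; the checkpoint identifier is a placeholder (not something this commit defines), and `from_pretrained` comes from the inherited `FlaxPreTrainedModel` API:

```python
from dalle_mini.model import DalleBartConfig, DalleBartForConditionalGeneration

# Placeholder path/id: substitute a DalleBart checkpoint you actually have.
model = DalleBartForConditionalGeneration.from_pretrained("path/to/dalle-bart-checkpoint")

assert isinstance(model.config, DalleBartConfig)
print(model.num_params)  # total parameter count, via the `num_params` property in modeling.py below
```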
dalle_mini/model/configuration.py CHANGED
@@ -18,7 +18,6 @@ import warnings
 from transformers.configuration_utils import PretrainedConfig
 from transformers.utils import logging
 
-
 logger = logging.get_logger(__name__)
 
 
@@ -88,7 +87,10 @@ class DalleBartConfig(PretrainedConfig):
     """
     model_type = "dallebart"
     keys_to_ignore_at_inference = ["past_key_values"]
-    attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}
+    attribute_map = {
+        "num_attention_heads": "encoder_attention_heads",
+        "hidden_size": "d_model",
+    }
 
     def __init__(
         self,
@@ -118,7 +120,7 @@ class DalleBartConfig(PretrainedConfig):
         num_labels=3,
         is_encoder_decoder=True,
         forced_eos_token_id=None,
-        tie_word_embeddings=False,  # don't tie for scaling reasons and due to different modalities and sizes
+        tie_word_embeddings=False,  # don't tie for scaling reasons and due to different modalities and sizes
         **kwargs,
     ):
         self.normalize_text = normalize_text
@@ -144,18 +146,27 @@ class DalleBartConfig(PretrainedConfig):
         self.use_cache = use_cache
         self.num_hidden_layers = encoder_layers
         self.gradient_checkpointing = gradient_checkpointing
-        self.scale_embedding = scale_embedding  # scale factor will be sqrt(d_model) if True
+        self.scale_embedding = (
+            scale_embedding  # scale factor will be sqrt(d_model) if True
+        )
         self.decoder_start_token_id = image_vocab_size  # BOS appended to vocab
         self.min_length = image_length + 1
         self.max_length = image_length + 1
 
         # remove keys we are about to set to prevent errors
-        for k in ['bos_token_id', 'eos_token_id', 'pad_token_id', 'decoder_start_token_id', 'forced_eos_token_id']:
+        for k in [
+            "bos_token_id",
+            "eos_token_id",
+            "pad_token_id",
+            "decoder_start_token_id",
+            "forced_eos_token_id",
+        ]:
             kwargs.pop(k, None)
 
         super().__init__(
             num_labels=num_labels,
-            pad_token_id=image_vocab_size + 1,  # needed to avoid errors during generation (converted to jnp.array)
+            pad_token_id=image_vocab_size
+            + 1,  # needed to avoid errors during generation (converted to jnp.array)
             bos_token_id=image_vocab_size + 1,  # set to unreachable values
             eos_token_id=image_vocab_size + 1,
             is_encoder_decoder=is_encoder_decoder,
@@ -166,7 +177,9 @@ class DalleBartConfig(PretrainedConfig):
         )
 
         # ensure backward compatibility for BART CNN models
-        if self.forced_bos_token_id is None and kwargs.get("force_bos_token_to_be_generated", False):
+        if self.forced_bos_token_id is None and kwargs.get(
+            "force_bos_token_to_be_generated", False
+        ):
             self.forced_bos_token_id = self.bos_token_id
             warnings.warn(
                 f"Please make sure the config includes `forced_bos_token_id={self.bos_token_id}` in future versions."
dalle_mini/model/modeling.py CHANGED
@@ -18,19 +18,16 @@ import math
 from functools import partial
 from typing import Callable, Optional, Tuple
 
-import numpy as np
-
 import flax.linen as nn
 import jax
 import jax.numpy as jnp
+import numpy as np
 from flax.core.frozen_dict import FrozenDict, unfreeze
-from flax.traverse_util import flatten_dict
 from flax.linen import combine_masks, make_causal_mask
 from flax.linen.attention import dot_product_attention_weights
+from flax.traverse_util import flatten_dict
 from jax import lax
 from jax.random import PRNGKey
-
-
 from transformers.modeling_flax_outputs import (
     FlaxBaseModelOutput,
     FlaxBaseModelOutputWithPastAndCrossAttentions,
@@ -38,20 +35,17 @@ from transformers.modeling_flax_outputs import (
     FlaxSeq2SeqLMOutput,
     FlaxSeq2SeqModelOutput,
 )
-from transformers.modeling_flax_utils import (
-    ACT2FN,
-    FlaxPreTrainedModel,
-)
+from transformers.modeling_flax_utils import ACT2FN, FlaxPreTrainedModel
 from transformers.utils import logging
 
-
 from .configuration import DalleBartConfig
 
-
 logger = logging.get_logger(__name__)
 
 
-def shift_tokens_right(input_ids: np.array, pad_token_id: int, decoder_start_token_id: int) -> np.ndarray:
+def shift_tokens_right(
+    input_ids: np.array, pad_token_id: int, decoder_start_token_id: int
+) -> np.ndarray:
     """
     Shift input ids one token to the right.
     """
@@ -59,7 +53,9 @@ def shift_tokens_right(input_ids: np.array, pad_token_id: int, decoder_start_tok
     shifted_input_ids[:, 1:] = input_ids[:, :-1]
     shifted_input_ids[:, 0] = decoder_start_token_id
 
-    shifted_input_ids = np.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
+    shifted_input_ids = np.where(
+        shifted_input_ids == -100, pad_token_id, shifted_input_ids
+    )
     return shifted_input_ids
 
 
@@ -97,7 +93,9 @@ class FlaxBartAttention(nn.Module):
         )
 
     def _split_heads(self, hidden_states):
-        return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))
+        return hidden_states.reshape(
+            hidden_states.shape[:2] + (self.num_heads, self.head_dim)
+        )
 
     def _merge_heads(self, hidden_states):
         return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))
@@ -111,9 +109,15 @@ class FlaxBartAttention(nn.Module):
         """
        # detect if we're initializing by absence of existing cache data.
         is_initialized = self.has_variable("cache", "cached_key")
-        cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
-        cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
-        cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
+        cached_key = self.variable(
+            "cache", "cached_key", jnp.zeros, key.shape, key.dtype
+        )
+        cached_value = self.variable(
+            "cache", "cached_value", jnp.zeros, value.shape, value.dtype
+        )
+        cache_index = self.variable(
+            "cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)
+        )
 
         if is_initialized:
             *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
@@ -172,15 +176,21 @@ class FlaxBartAttention(nn.Module):
                 mask_shift = self.variables["cache"]["cache_index"]
                 max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
                 causal_mask = lax.dynamic_slice(
-                    self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
+                    self.causal_mask,
+                    (0, 0, mask_shift, 0),
+                    (1, 1, query_length, max_decoder_length),
                 )
             else:
                 causal_mask = self.causal_mask[:, :, :query_length, :key_length]
-            causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
+            causal_mask = jnp.broadcast_to(
+                causal_mask, (batch_size,) + causal_mask.shape[1:]
+            )
 
         # combine masks if needed
         if self.causal:
-            attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
+            attention_mask = jnp.broadcast_to(
+                jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape
+            )
             attention_mask = combine_masks(attention_mask, causal_mask)
         else:
             attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
@@ -261,7 +271,9 @@ class FlaxBartEncoderLayer(nn.Module):
         deterministic: bool = True,
     ) -> Tuple[jnp.ndarray]:
         residual = hidden_states
-        hidden_states = self.self_attn(hidden_states=hidden_states, attention_mask=attention_mask)
+        hidden_states = self.self_attn(
+            hidden_states=hidden_states, attention_mask=attention_mask
+        )
 
         hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
         hidden_states = residual + hidden_states
@@ -269,7 +281,9 @@ class FlaxBartEncoderLayer(nn.Module):
 
         residual = hidden_states
         hidden_states = self.activation_fn(self.fc1(hidden_states))
-        hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic)
+        hidden_states = self.activation_dropout_layer(
+            hidden_states, deterministic=deterministic
+        )
         hidden_states = self.fc2(hidden_states)
         hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
         hidden_states = residual + hidden_states
@@ -283,9 +297,14 @@ class FlaxBartEncoderLayerCollection(nn.Module):
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
 
     def setup(self):
-        layer_module = nn.remat(FlaxBartEncoderLayer) if self.config.gradient_checkpointing else FlaxBartEncoderLayer
+        layer_module = (
+            nn.remat(FlaxBartEncoderLayer)
+            if self.config.gradient_checkpointing
+            else FlaxBartEncoderLayer
+        )
         self.layers = [
-            layer_module(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.encoder_layers)
+            layer_module(self.config, name=str(i), dtype=self.dtype)
+            for i in range(self.config.encoder_layers)
         ]
 
     def __call__(
@@ -359,7 +378,9 @@ class FlaxBartDecoderLayer(nn.Module):
 
         # Self Attention
         hidden_states = self.self_attn(
-            hidden_states=hidden_states, attention_mask=attention_mask, init_cache=init_cache
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            init_cache=init_cache,
         )
         hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
         hidden_states = residual + hidden_states
@@ -380,7 +401,9 @@ class FlaxBartDecoderLayer(nn.Module):
         # Fully Connected
         residual = hidden_states
         hidden_states = self.activation_fn(self.fc1(hidden_states))
-        hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic)
+        hidden_states = self.activation_dropout_layer(
+            hidden_states, deterministic=deterministic
+        )
         hidden_states = self.fc2(hidden_states)
         hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
         hidden_states = residual + hidden_states
@@ -394,9 +417,14 @@ class FlaxBartDecoderLayerCollection(nn.Module):
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
 
     def setup(self):
-        layer_module = nn.remat(FlaxBartDecoderLayer) if self.config.gradient_checkpointing else FlaxBartDecoderLayer
+        layer_module = (
+            nn.remat(FlaxBartDecoderLayer)
+            if self.config.gradient_checkpointing
+            else FlaxBartDecoderLayer
+        )
         self.layers = [
-            layer_module(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.decoder_layers)
+            layer_module(self.config, name=str(i), dtype=self.dtype)
+            for i in range(self.config.decoder_layers)
         ]
 
     def __call__(
@@ -419,7 +447,9 @@ class FlaxBartDecoderLayerCollection(nn.Module):
                 deterministic=deterministic,
             )
 
-        return FlaxBaseModelOutputWithPastAndCrossAttentions(last_hidden_state=hidden_states)
+        return FlaxBaseModelOutputWithPastAndCrossAttentions(
+            last_hidden_state=hidden_states
+        )
 
 
 class DalleBartEncoder(nn.Module):
@@ -470,7 +500,9 @@ class DalleBartEncoder(nn.Module):
         hidden_states = self.layernorm_embedding(hidden_states)
         hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
 
-        outputs = self.layers(hidden_states, attention_mask, deterministic=deterministic)
+        outputs = self.layers(
+            hidden_states, attention_mask, deterministic=deterministic
+        )
 
         return FlaxBaseModelOutput(
             last_hidden_state=outputs.last_hidden_state,
@@ -488,7 +520,9 @@ class DalleBartDecoder(nn.Module):
 
         embed_dim = self.config.d_model
         self.padding_idx = self.config.pad_token_id
-        self.embed_scale = math.sqrt(self.config.d_model) if self.config.scale_embedding else 1.0
+        self.embed_scale = (
+            math.sqrt(self.config.d_model) if self.config.scale_embedding else 1.0
+        )
 
         self.embed_tokens = nn.Embed(
             self.config.image_vocab_size + 1,  # image vocab size + 1 for BOS
@@ -619,11 +653,15 @@ class DalleBartPreTrainedModel(FlaxPreTrainedModel):
         **kwargs,
     ):
         module = self.module_class(config=config, dtype=dtype)
-        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, **kwargs)
+        super().__init__(
+            config, module, input_shape=input_shape, seed=seed, dtype=dtype, **kwargs
+        )
 
     @property
     def num_params(self):
-        num_params = jax.tree_map(lambda param: param.size, flatten_dict(unfreeze(self.params))).values()
+        num_params = jax.tree_map(
+            lambda param: param.size, flatten_dict(unfreeze(self.params))
+        ).values()
         return sum(list(num_params))
 
     def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
@@ -636,8 +674,12 @@ class DalleBartPreTrainedModel(FlaxPreTrainedModel):
         decoder_attention_mask = jnp.ones_like(input_ids)
 
         batch_size, sequence_length = input_ids.shape
-        position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
-        decoder_position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
+        position_ids = jnp.broadcast_to(
+            jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
+        )
+        decoder_position_ids = jnp.broadcast_to(
+            jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
+        )
 
         params_rng, dropout_rng = jax.random.split(rng)
         rngs = {"params": params_rng, "dropout": dropout_rng}
@@ -670,10 +712,17 @@ class DalleBartPreTrainedModel(FlaxPreTrainedModel):
         decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4")
         decoder_attention_mask = jnp.ones_like(decoder_input_ids)
         decoder_position_ids = jnp.broadcast_to(
-            jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]), decoder_input_ids.shape
+            jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]),
+            decoder_input_ids.shape,
         )
 
-        def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs):
+        def _decoder_forward(
+            module,
+            decoder_input_ids,
+            decoder_attention_mask,
+            decoder_position_ids,
+            **kwargs,
+        ):
             decoder_module = module._get_decoder_module()
             return decoder_module(
                 decoder_input_ids,
@@ -720,7 +769,9 @@ class DalleBartPreTrainedModel(FlaxPreTrainedModel):
             attention_mask = jnp.ones_like(input_ids)
         if position_ids is None:
             batch_size, sequence_length = input_ids.shape
-            position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
+            position_ids = jnp.broadcast_to(
+                jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
+            )
 
         # Handle any PRNG if needed
         rngs = {}
@@ -754,19 +805,25 @@ class DalleBartPreTrainedModel(FlaxPreTrainedModel):
         params: dict = None,
         dropout_rng: PRNGKey = None,
     ):
-        return_dict = return_dict if return_dict is not None else self.config.return_dict
+        return_dict = (
+            return_dict if return_dict is not None else self.config.return_dict
+        )
 
         # prepare encoder inputs
         if attention_mask is None:
             attention_mask = jnp.ones_like(input_ids)
         if position_ids is None:
             batch_size, sequence_length = input_ids.shape
-            position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
+            position_ids = jnp.broadcast_to(
+                jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
+            )
 
         # prepare decoder inputs
         if decoder_input_ids is None:
             decoder_input_ids = shift_tokens_right(
-                input_ids, self.config.pad_token_id, decoder_start_token_id=self.config.decoder_start_token_id
+                input_ids,
+                self.config.pad_token_id,
+                decoder_start_token_id=self.config.decoder_start_token_id,
            )
         if decoder_attention_mask is None:
             decoder_attention_mask = jnp.ones_like(decoder_input_ids)
@@ -839,7 +896,9 @@ class DalleBartForConditionalGenerationModule(nn.Module):
 
         if self.config.tie_word_embeddings:
             shared_embedding = self.model.variables["params"]["shared"]["embedding"]
-            lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
+            lm_logits = self.lm_head.apply(
+                {"params": {"kernel": shared_embedding.T}}, hidden_states
+            )
         else:
             lm_logits = self.lm_head(hidden_states)
 
@@ -901,7 +960,9 @@ class DalleBartForConditionalGeneration(DalleBartPreTrainedModel):
 
         if decoder_position_ids is None:
             if past_key_values is not None:
-                raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.")
+                raise ValueError(
+                    "Make sure to provide `decoder_position_ids` when passing `past_key_values`."
+                )
 
             decoder_position_ids = jnp.broadcast_to(
                 jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
@@ -923,7 +984,13 @@ class DalleBartForConditionalGeneration(DalleBartPreTrainedModel):
         else:
             mutable = False
 
-        def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs):
+        def _decoder_forward(
+            module,
+            decoder_input_ids,
+            decoder_attention_mask,
+            decoder_position_ids,
+            **kwargs,
+        ):
             decoder_module = module._get_decoder_module()
             outputs = decoder_module(
                 decoder_input_ids,
@@ -934,8 +1001,12 @@ class DalleBartForConditionalGeneration(DalleBartPreTrainedModel):
             hidden_states = outputs[0]
 
             if self.config.tie_word_embeddings:
-                shared_embedding = module.model.variables["params"]["shared"]["embedding"]
-                lm_logits = module.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
+                shared_embedding = module.model.variables["params"]["shared"][
+                    "embedding"
+                ]
+                lm_logits = module.lm_head.apply(
+                    {"params": {"kernel": shared_embedding.T}}, hidden_states
+                )
             else:
                 lm_logits = module.lm_head(hidden_states)
 
@@ -993,9 +1064,13 @@ class DalleBartForConditionalGeneration(DalleBartPreTrainedModel):
         extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
         if decoder_attention_mask is not None:
            position_ids = decoder_attention_mask.cumsum(axis=-1) - 1
-            extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0))
+            extended_attention_mask = lax.dynamic_update_slice(
+                extended_attention_mask, decoder_attention_mask, (0, 0)
+            )
        else:
-            position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
+            position_ids = jnp.broadcast_to(
+                jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)
+            )
 
         return {
             "past_key_values": past_key_values,
@@ -1007,5 +1082,7 @@ class DalleBartForConditionalGeneration(DalleBartPreTrainedModel):
 
     def update_inputs_for_generation(self, model_outputs, model_kwargs):
         model_kwargs["past_key_values"] = model_outputs.past_key_values
-        model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1
+        model_kwargs["decoder_position_ids"] = (
+            model_kwargs["decoder_position_ids"][:, -1:] + 1
+        )
         return model_kwargs
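Note: among the reformatted functions, `shift_tokens_right` is the one piece of pure NumPy logic, so a standalone sketch may help sanity-check the reflow. The `np.zeros_like` initialization is assumed, since it sits outside the diff context:

```python
import numpy as np

def shift_tokens_right(
    input_ids: np.ndarray, pad_token_id: int, decoder_start_token_id: int
) -> np.ndarray:
    # Assumed initialization (not visible in the hunk above).
    shifted_input_ids = np.zeros_like(input_ids)
    shifted_input_ids[:, 1:] = input_ids[:, :-1]
    shifted_input_ids[:, 0] = decoder_start_token_id
    # Label positions masked with -100 become padding in the decoder input.
    shifted_input_ids = np.where(
        shifted_input_ids == -100, pad_token_id, shifted_input_ids
    )
    return shifted_input_ids

labels = np.array([[5, 6, -100, -100]])
print(shift_tokens_right(labels, pad_token_id=1, decoder_start_token_id=0))
# [[0 5 6 1]]
```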
dalle_mini/model/partitions.py CHANGED
@@ -4,7 +4,6 @@ from flax.core.frozen_dict import freeze
 from flax.traverse_util import flatten_dict, unflatten_dict
 from jax.experimental import PartitionSpec as P
 
-
 # utils adapted from https://github.com/google-research/google-research/blob/master/flax_models/t5x/partitions.py
 # Sentinels
 _unmatched = object()