boris committed
Commit 769d20a
Parent: 00d4661

feat: allow relative position (#156)

src/dalle_mini/model/configuration.py CHANGED
@@ -64,12 +64,14 @@ class DalleBartConfig(PretrainedFromWandbMixin, PretrainedConfig):
         use_head_scale=False,  # used in NormFormer
         use_cosine_attention=False,  # used in Swin v2
         tau_init=0.05,  # used only in cosine attention (Swin v2)
+        use_absolute_position_embeddings=True,  # default
+        use_swin_position_embeddings=False,  # used in Swin v1/v2
         use_deepnet_scaling=False,  # used in Deepnet
         use_glu=False,  # "GLU Variants Improve Transformer"
         use_alibi=False,  # Not implemented yet - from "Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation"
         sinkhorn_iters=1,  # used in SinkFormers
-        use_final_ln_encoder=False,  # final layer normalization in encoder
-        use_final_ln_decoder=False,  # final layer normalization in decoder
+        use_final_ln_encoder=True,  # final layer normalization in encoder
+        use_final_ln_decoder=True,  # final layer normalization in decoder
         # parameters that should not be necessary but could affect results
         force_ln_scale=False,  # force scale in layernorm even when followed by dense layers
         **kwargs,
@@ -98,6 +100,8 @@ class DalleBartConfig(PretrainedFromWandbMixin, PretrainedConfig):
         self.ln_positions = ln_positions
         self.use_cosine_attention = use_cosine_attention
         self.tau_init = tau_init
+        self.use_absolute_position_embeddings = use_absolute_position_embeddings
+        self.use_swin_position_embeddings = use_swin_position_embeddings
         self.use_deepnet_scaling = use_deepnet_scaling
         self.use_glu = use_glu
         self.use_alibi = use_alibi
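For orientation, a minimal usage sketch of the new flags (not part of this commit; it assumes the remaining DalleBartConfig constructor arguments keep their defaults):

# Hypothetical sketch: swap learned absolute position embeddings for
# Swin-style relative position biases.
from dalle_mini.model.configuration import DalleBartConfig

config = DalleBartConfig(
    use_absolute_position_embeddings=False,  # skip the learned absolute embeddings
    use_swin_position_embeddings=True,       # add per-head relative biases to the attention logits
    use_final_ln_encoder=True,               # new defaults, shown explicitly
    use_final_ln_decoder=True,
)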
src/dalle_mini/model/modeling.py CHANGED
@@ -25,6 +25,7 @@ import flax.linen as nn
 import jax
 import jax.numpy as jnp
 import msgpack.exceptions
+from einops import rearrange
 from flax.core.frozen_dict import unfreeze
 from flax.linen import combine_masks, make_causal_mask
 from flax.linen import partitioning as nn_partitioning
@@ -52,8 +53,6 @@ from transformers.modeling_flax_outputs import (
 from transformers.modeling_flax_utils import ACT2FN
 from transformers.models.bart.modeling_flax_bart import (
     FlaxBartAttention,
-    FlaxBartDecoder,
-    FlaxBartEncoder,
     FlaxBartForConditionalGeneration,
     FlaxBartForConditionalGenerationModule,
     FlaxBartModule,
@@ -180,6 +179,7 @@ def dot_product_attention_weights(
     key: Any,
     bias: Optional[Any] = None,
     mask: Optional[Any] = None,
+    embed_pos: Optional[Any] = None,
     broadcast_dropout: bool = True,
     dropout_rng: Optional[PRNGKey] = None,
     dropout_rate: float = 0.0,
@@ -210,6 +210,10 @@ def dot_product_attention_weights(
     if bias is not None:
         attn_weights = attn_weights + bias

+    # add relative position
+    if embed_pos is not None:
+        attn_weights = attn_weights + embed_pos
+
     # normalize the attention weights
     if causal or sinkhorn_iters == 1:
         # sinkhorn does not work for causal (leaks info of future tokens into past)
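When provided, embed_pos is broadcast-added to the raw attention logits before the softmax, exactly like the mask bias above it. A standalone shape sketch with hypothetical sizes (plain jnp, not the library code):

import jax.numpy as jnp

batch, heads, q_len, k_len = 2, 4, 8, 8                  # hypothetical sizes
attn_weights = jnp.zeros((batch, heads, q_len, k_len))   # raw q @ k^T logits
embed_pos = jnp.ones((1, heads, q_len, k_len))           # one bias per (head, query, key)

attn_weights = attn_weights + embed_pos                  # broadcasts over the batch axis
print(attn_weights.shape)                                # (2, 4, 8, 8)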
@@ -251,6 +255,8 @@ class FlaxBartAttention(FlaxBartAttention):
     """

     is_encoder: bool = False
+    q_length: int = None
+    k_length: int = None

     def setup(self) -> None:
         self.head_dim = self.embed_dim // self.num_heads
@@ -305,6 +311,15 @@ class FlaxBartAttention(FlaxBartAttention):
                 (1, self.num_heads, 1, 1),
             )

+        if self.config.use_swin_position_embeddings:
+            self.rel_bias = nn.Embed(
+                self.q_length,
+                self.k_length * self.num_heads,
+                embedding_init=deepnet_init()
+                if self.config.use_deepnet_scaling
+                else jax.nn.initializers.normal(self.config.init_std),
+            )
+
         if self.causal:
             # used only in decoder
             self.causal_mask = make_causal_mask(
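With use_swin_position_embeddings enabled, each attention module therefore owns a learned table of shape (q_length, k_length * num_heads): row i holds the biases of query position i against every key position, for every head. A standalone sketch with hypothetical sizes (plain Flax, not the DalleBart module):

import jax
import jax.numpy as jnp
import flax.linen as nn

q_length, k_length, num_heads = 8, 8, 4   # hypothetical sizes

# one row per query position, one column per (key position, head) pair
rel_bias = nn.Embed(
    q_length,
    k_length * num_heads,
    embedding_init=jax.nn.initializers.normal(0.02),
)

position_ids = jnp.arange(q_length)
params = rel_bias.init(jax.random.PRNGKey(0), position_ids)
table = rel_bias.apply(params, position_ids)
print(table.shape)  # (8, 32) == (q_length, k_length * num_heads)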
@@ -400,11 +415,21 @@ class FlaxBartAttention(FlaxBartAttention):
             key_states = key_states / (
                 jnp.linalg.norm(key_states, axis=-1, keepdims=True) + 1e-8
             )
+
+        # relative position embeddings
+        if self.config.use_swin_position_embeddings:
+            position_ids = jnp.arange(self.q_length)
+            embed_pos = self.rel_bias(position_ids)
+            embed_pos = rearrange(embed_pos, "q (k h) -> 1 h q k", h=self.num_heads)
+        else:
+            embed_pos = None
+
         attn_weights = dot_product_attention_weights(
             query_states,
             key_states,
             bias=attention_bias,
             mask=attention_mask,
+            embed_pos=embed_pos,
             dropout_rng=dropout_rng,
             dropout_rate=self.dropout,
             broadcast_dropout=True,
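The rearrange call only reshapes that (q_length, k_length * num_heads) lookup into a per-head bias of shape (1, num_heads, q_length, k_length), so it can broadcast against the (batch, heads, q, k) logits inside dot_product_attention_weights. A standalone check of the pattern with hypothetical sizes:

import jax.numpy as jnp
from einops import rearrange

q_len, k_len, heads = 8, 8, 4               # hypothetical sizes
table = jnp.zeros((q_len, k_len * heads))   # what rel_bias(position_ids) returns

embed_pos = rearrange(table, "q (k h) -> 1 h q k", h=heads)
print(embed_pos.shape)  # (1, 4, 8, 8), broadcastable over the batch axis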
@@ -593,6 +618,8 @@ class FlaxBartEncoderLayer(nn.Module):
             bias=self.config.use_bias,
             dtype=self.dtype,
             is_encoder=True,
+            q_length=self.config.max_text_length,
+            k_length=self.config.max_text_length,
         )(hidden_states=hidden_states, attention_mask=attention_mask)

         if self.config.ln_positions in ["normformer", "swinv2", "cogview"]:
@@ -699,6 +726,8 @@ class FlaxBartDecoderLayer(nn.Module):
             bias=self.config.use_bias,
             dtype=self.dtype,
             is_encoder=False,
+            q_length=self.config.image_length,
+            k_length=self.config.image_length,
         )(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
@@ -737,6 +766,8 @@ class FlaxBartDecoderLayer(nn.Module):
             bias=self.config.use_bias,
             dtype=self.dtype,
             is_encoder=False,
+            q_length=self.config.image_length,
+            k_length=self.config.max_text_length,
         )(
             hidden_states=hidden_states,
             key_value_states=encoder_hidden_states,
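Taken together, each attention block is given its static table dimensions up front: encoder self-attention relates text to text, decoder self-attention image to image, and decoder cross-attention queries image positions against text keys. A compact summary of the values passed above (a plain sketch, not code from the commit):

# (q_length, k_length) handed to each FlaxBartAttention instance
table_shapes = {
    "encoder self-attention": ("max_text_length", "max_text_length"),
    "decoder self-attention": ("image_length", "image_length"),
    "decoder cross-attention": ("image_length", "max_text_length"),
}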
@@ -953,7 +984,10 @@ class FlaxBartDecoderLayerCollection(nn.Module):
         )


-class FlaxBartEncoder(FlaxBartEncoder):
+class FlaxBartEncoder(nn.Module):
+    config: DalleBartConfig
+    embed_tokens: nn.Embed
+    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
     """
     Edits:
     - offset set to 0 (no padding token)
@@ -972,18 +1006,62 @@ class FlaxBartEncoder(FlaxBartEncoder):
         # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
         # and adjust num_embeddings appropriately. Other models don't have this hack
         self.offset = 0
-        self.embed_positions = nn.Embed(
-            self.config.max_text_length + self.offset,
-            embed_dim,
-            embedding_init=jax.nn.initializers.normal(self.config.init_std),
-        )
+        if self.config.use_absolute_position_embeddings:
+            self.embed_positions = nn.Embed(
+                self.config.max_text_length + self.offset,  # image length for BOS
+                embed_dim,
+                embedding_init=jax.nn.initializers.normal(self.config.init_std),
+            )
         self.layers = FlaxBartEncoderLayerCollection(self.config, self.dtype)
         self.layernorm_embedding = norm(
             self.config.ln_type, dtype=self.dtype, epsilon=1e-05
         )

+    def __call__(
+        self,
+        input_ids,
+        attention_mask,
+        position_ids,
+        output_attentions: bool = False,
+        output_hidden_states: bool = False,
+        return_dict: bool = True,
+        deterministic: bool = True,
+    ):
+        input_shape = input_ids.shape
+        input_ids = input_ids.reshape(-1, input_shape[-1])
+
+        hidden_states = self.embed_tokens(input_ids) * self.embed_scale
+
+        if self.config.use_absolute_position_embeddings:
+            embed_pos = self.embed_positions(position_ids + self.offset)
+            hidden_states = hidden_states + embed_pos
+
+        hidden_states = self.layernorm_embedding(hidden_states)
+        hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
+
+        outputs = self.layers(
+            hidden_states,
+            attention_mask,
+            deterministic=deterministic,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        if not return_dict:
+            return outputs
+
+        return FlaxBaseModelOutput(
+            last_hidden_state=outputs.last_hidden_state,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+

-class FlaxBartDecoder(FlaxBartDecoder):
+class FlaxBartDecoder(nn.Module):
+    config: DalleBartConfig
+    embed_tokens: nn.Embed
+    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
     """
     Edits:
     - offset set to 0 (no padding token)
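The copied-in __call__ differs from the upstream FlaxBartEncoder mainly by guarding the absolute-position add, so positions can be dropped entirely when only the relative biases are wanted. A reduced sketch of that guard (hypothetical helper names, not the module itself):

import jax.numpy as jnp

def add_positions(hidden_states, position_ids, table, use_absolute, offset=0):
    # hidden_states: (batch, seq, dim); table: (max_len, dim) position embeddings
    if use_absolute:
        hidden_states = hidden_states + table[position_ids + offset]
    return hidden_states

batch, seq, dim = 2, 8, 16
hs = jnp.zeros((batch, seq, dim))
table = jnp.ones((64, dim))
pos = jnp.broadcast_to(jnp.arange(seq), (batch, seq))
print(add_positions(hs, pos, table, True).shape)                  # (2, 8, 16)
print(bool((add_positions(hs, pos, table, False) == hs).all()))   # True: positions skipped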
@@ -1004,17 +1082,65 @@ class FlaxBartDecoder(FlaxBartDecoder):
         # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
         # and adjust num_embeddings appropriately. Other models don't have this hack
         self.offset = 0
-        self.embed_positions = nn.Embed(
-            self.config.image_length + self.offset,  # image length for BOS
-            embed_dim,
-            embedding_init=jax.nn.initializers.normal(self.config.init_std),
-        )
+        if self.config.use_absolute_position_embeddings:
+            self.embed_positions = nn.Embed(
+                self.config.image_length + self.offset,  # image length for BOS
+                embed_dim,
+                embedding_init=jax.nn.initializers.normal(self.config.init_std),
+            )

         self.layers = FlaxBartDecoderLayerCollection(self.config, self.dtype)
         self.layernorm_embedding = norm(
             self.config.ln_type, dtype=self.dtype, epsilon=1e-05
         )

+    def __call__(
+        self,
+        input_ids,
+        attention_mask,
+        position_ids,
+        encoder_hidden_states: Optional[jnp.ndarray] = None,
+        encoder_attention_mask: Optional[jnp.ndarray] = None,
+        init_cache: bool = False,
+        output_attentions: bool = False,
+        output_hidden_states: bool = False,
+        return_dict: bool = True,
+        deterministic: bool = True,
+    ):
+        input_shape = input_ids.shape
+        input_ids = input_ids.reshape(-1, input_shape[-1])
+
+        hidden_states = self.embed_tokens(input_ids) * self.embed_scale
+
+        if self.config.use_absolute_position_embeddings:
+            embed_pos = self.embed_positions(position_ids + self.offset)
+            hidden_states = hidden_states + embed_pos
+
+        hidden_states = self.layernorm_embedding(hidden_states)
+        hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
+
+        outputs = self.layers(
+            hidden_states,
+            attention_mask,
+            encoder_hidden_states,
+            encoder_attention_mask,
+            deterministic=deterministic,
+            init_cache=init_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        if not return_dict:
+            return outputs
+
+        return FlaxBaseModelOutputWithPastAndCrossAttentions(
+            last_hidden_state=outputs.last_hidden_state,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            cross_attentions=outputs.cross_attentions,
+        )
+

 class FlaxBartModule(FlaxBartModule):
     """
 
src/dalle_mini/model/partitions.py CHANGED
@@ -38,6 +38,7 @@ def _get_partition_rules():
         # embeddings
         (("embed_positions", "embedding"), P("mp", None)),
         (("embed_tokens", "embedding"), P("mp", None)),
+        (("rel_bias", "embedding"), P(None, "mp")),
         # attention
         (("(q_proj|k_proj|v_proj)", "kernel"), P(None, "mp")),
         (("out_proj", "kernel"), P("mp", None)),
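The new rule replicates the relative-bias table across its rows and shards its (k_length * num_heads) columns over the model-parallel "mp" mesh axis, in line with the other attention parameters. A toy illustration of how such a rule resolves against a parameter path (a hypothetical matcher, not the project's _get_partition_rules/set_partitions machinery, and assuming a recent jax.sharding import):

import re
from jax.sharding import PartitionSpec as P

rules = [
    (("embed_positions", "embedding"), P("mp", None)),
    (("embed_tokens", "embedding"), P("mp", None)),
    (("rel_bias", "embedding"), P(None, "mp")),  # added by this commit
]

def find_spec(param_path, rules):
    # return the spec of the first rule whose keys all match some path component
    for keys, spec in rules:
        if all(any(re.fullmatch(k, part) for part in param_path) for k in keys):
            return spec
    return None

# e.g. the bias table inside an encoder layer's self-attention module
path = ("encoder", "layers", "0", "self_attn", "rel_bias", "embedding")
print(find_spec(path, rules))  # PartitionSpec(None, 'mp'): rows replicated, columns sharded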