boris commited on
Commit
8654dc9
1 Parent(s): 972bc8d

fix: causal_mask based on image tokens

Browse files
Files changed (1) hide show
  1. dalle_mini/model/modeling.py +3 -2
dalle_mini/model/modeling.py CHANGED
@@ -52,7 +52,7 @@ logger = logging.get_logger(__name__)
52
  class FlaxBartAttention(FlaxBartAttention):
53
  """
54
  Edits:
55
- - causal mask considers embed_dim instead of max_position_embeddings
56
  """
57
 
58
  def setup(self) -> None:
@@ -77,8 +77,9 @@ class FlaxBartAttention(FlaxBartAttention):
77
  self.dropout_layer = nn.Dropout(rate=self.dropout)
78
 
79
  if self.causal:
 
80
  self.causal_mask = make_causal_mask(
81
- jnp.ones((1, self.embed_dim), dtype="bool"), dtype="bool"
82
  )
83
 
84
 
 
52
  class FlaxBartAttention(FlaxBartAttention):
53
  """
54
  Edits:
55
+ - causal mask is used only in decoder and considers image_length + 1 (for BOS)
56
  """
57
 
58
  def setup(self) -> None:
 
77
  self.dropout_layer = nn.Dropout(rate=self.dropout)
78
 
79
  if self.causal:
80
+ # used only in decoder
81
  self.causal_mask = make_causal_mask(
82
+ jnp.ones((1, self.config.image_length + 1), dtype="bool"), dtype="bool"
83
  )
84
 
85