ydshieh committed on
Commit
0ac6b6e
1 Parent(s): 89ca80d

Add a remark about a serious bug I made previously

Files changed (1)
  1. vit_gpt2/modeling_flax_gpt2.py +1 -0
vit_gpt2/modeling_flax_gpt2.py CHANGED
@@ -309,6 +309,7 @@ class FlaxGPT2Block(nn.Module):
 
         if not self.only_self_attn:
             self.encoder_ln = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)
+            # [IMPORTANT] Cross attention requires ``causal=False``! This is a bug I made previously.
             self.encoder_attn = FlaxGPT2Attention(config=self.config, dtype=self.dtype, causal=False, self_attn=False)
 
         self.ln_2 = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)
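
Why ``causal=False`` matters here: a causal mask is only meaningful for self attention, where decoder position i must not see future decoder positions. In cross attention the keys come from the encoder output, which has no temporal ordering relative to the decoder, so a causal mask silently hides most encoder positions from early decoder steps. The sketch below illustrates the effect; the shapes and values are illustrative assumptions, not the repository's code.

# A minimal sketch (assumed shapes/values, not the repository's code) of what
# a causal mask does to cross-attention scores.
import jax.numpy as jnp

decoder_len, encoder_len = 4, 6

# Scores between decoder queries and encoder keys: (decoder_len, encoder_len).
# All-ones is a stand-in for real attention logits.
scores = jnp.ones((decoder_len, encoder_len))

# Causal masking is meant for self attention (query i must not see key j > i).
# Applied across encoder keys it just hides encoder positions j > i from
# decoder position i, which is meaningless for cross attention.
causal_mask = jnp.tril(jnp.ones((decoder_len, encoder_len), dtype=bool))
masked = jnp.where(causal_mask, scores, -jnp.inf)

print(masked)
# Decoder position 0 is left with a single visible encoder position, and even
# the last decoder position sees only 4 of the 6 encoder positions. With
# causal=False the full score matrix is kept, as cross attention requires.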