ydshieh committed
Commit: 0ac6b6e
Parent: 89ca80d

Make a remark about a serious bug I made previously
vit_gpt2/modeling_flax_gpt2.py
CHANGED
@@ -309,6 +309,7 @@ class FlaxGPT2Block(nn.Module):
 
         if not self.only_self_attn:
             self.encoder_ln = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)
+            # [IMPORTANT] Cross attention requires ``causal=False``! This is a bug I made previously.
             self.encoder_attn = FlaxGPT2Attention(config=self.config, dtype=self.dtype, causal=False, self_attn=False)
 
         self.ln_2 = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)
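For readers skimming the diff: a minimal standalone sketch (not part of the committed file, with hypothetical sequence lengths) of why a causal mask is wrong for cross attention. A causal mask hides positions j > i from query i, which is correct for decoder self-attention but would wrongly hide most encoder outputs from early decoder positions.

# Minimal sketch, independent of the commit; shapes are illustrative assumptions.
import jax.numpy as jnp

q_len, kv_len = 4, 9  # decoder query positions vs. encoder key/value positions

# Causal mask: query i may only attend to positions j <= i.
# Correct for self-attention, wrong for cross attention.
causal_mask = jnp.tril(jnp.ones((q_len, kv_len), dtype=bool))

# Cross attention should let every decoder position see every encoder output,
# i.e. an all-True mask (equivalently, no mask at all).
cross_mask = jnp.ones((q_len, kv_len), dtype=bool)

# With the buggy causal mask, decoder position 0 would see only the first
# encoder position and ignore the remaining kv_len - 1 encoder outputs.
print(causal_mask[0])  # [ True False False False False False False False False]
print(cross_mask[0])   # [ True  True  True  True  True  True  True  True  True]

Passing ``causal=False`` to the cross-attention ``FlaxGPT2Attention``, as the diff above does, ensures no such mask is applied there.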