ydshieh committed
Commit: 9a48a1a
Parent(s): 7262cfd

Fix project_encoder
vit_gpt2/modeling_flax_gpt2.py CHANGED
@@ -312,7 +312,8 @@ class FlaxGPT2Block(nn.Module):
         # [IMPORTANT] Cross attention requires ``causal=False``! This is a bug I made previously.
         self.cross_attn = FlaxGPT2Attention(config=self.config, dtype=self.dtype, causal=False, self_attn=False)
 
-
+        project_encoder = getattr(self.config, "project_encoder", None)
+        if project_encoder:
             self.encoder_projection_ln = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)
             self.encoder_projection_mlp = FlaxGPT2MLP(self.config, self.config.hidden_size, dtype=self.dtype)
 
@@ -352,7 +353,8 @@ class FlaxGPT2Block(nn.Module):
         cross_attn_weights = None
         if encoder_hidden_states is not None:
 
-
+            project_encoder = getattr(self.config, "project_encoder", None)
+            if project_encoder:
                 residual = encoder_hidden_states
                 encoder_hidden_states = self.encoder_projection_ln(encoder_hidden_states)
                 feed_forward_hidden_states = self.encoder_projection_mlp(
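For readers skimming the diff: the fix reads the optional `project_encoder` flag with getattr(config, "project_encoder", None) instead of plain attribute access, so configs created before the flag existed skip the encoder projection rather than raising AttributeError. The same guard appears twice, once where the projection layers are built (setup) and once where they are applied. Below is a minimal, self-contained Flax sketch of that pattern; DemoConfig, DemoBlock, and encoder_projection_dense are hypothetical stand-ins for illustration, not the repo's actual classes.

import jax
import jax.numpy as jnp
import flax.linen as nn

class DemoConfig:
    # Hypothetical stand-in for GPT2Config; older configs may not define
    # `project_encoder` at all, which is the case the commit guards against.
    hidden_size = 8
    layer_norm_epsilon = 1e-5

class DemoBlock(nn.Module):
    config: DemoConfig

    def setup(self):
        # getattr with a None default tolerates configs that predate the flag;
        # the projection layers are only created when the config opts in.
        if getattr(self.config, "project_encoder", None):
            self.encoder_projection_ln = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon)
            self.encoder_projection_dense = nn.Dense(self.config.hidden_size)

    def __call__(self, encoder_hidden_states):
        # The same guard at call time: normalize, project, and add the
        # residual only when `project_encoder` is set.
        if getattr(self.config, "project_encoder", None):
            residual = encoder_hidden_states
            hidden = self.encoder_projection_ln(encoder_hidden_states)
            encoder_hidden_states = residual + self.encoder_projection_dense(hidden)
        return encoder_hidden_states

config = DemoConfig()  # no `project_encoder` attribute
model = DemoBlock(config=config)
x = jnp.ones((1, 4, config.hidden_size))
params = model.init(jax.random.PRNGKey(0), x)
y = model.apply(params, x)  # passes through unchanged; no layers were built

The getattr(..., None) default is what makes the change backward compatible: older configs load and run with the projection silently disabled, instead of crashing when the block is constructed.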