ydshieh committed
Commit 9a48a1a
1 Parent(s): 7262cfd

Fix project_encoder

Files changed (1)
  1. vit_gpt2/modeling_flax_gpt2.py +4 -2
vit_gpt2/modeling_flax_gpt2.py CHANGED
@@ -312,7 +312,8 @@ class FlaxGPT2Block(nn.Module):
         # [IMPORTANT] Cross attention requires ``causal=False``! This is a bug I made previously.
         self.cross_attn = FlaxGPT2Attention(config=self.config, dtype=self.dtype, causal=False, self_attn=False)
 
-        if self.config.project_encoder:
+        project_encoder = getattr(self.config, "project_encoder", None)
+        if project_encoder:
             self.encoder_projection_ln = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)
             self.encoder_projection_mlp = FlaxGPT2MLP(self.config, self.config.hidden_size, dtype=self.dtype)
 
@@ -352,7 +353,8 @@ class FlaxGPT2Block(nn.Module):
         cross_attn_weights = None
         if encoder_hidden_states is not None:
 
-            if self.config.project_encoder:
+            project_encoder = getattr(self.config, "project_encoder", None)
+            if project_encoder:
                 residual = encoder_hidden_states
                 encoder_hidden_states = self.encoder_projection_ln(encoder_hidden_states)
                 feed_forward_hidden_states = self.encoder_projection_mlp(
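
The change swaps a direct attribute access for `getattr(self.config, "project_encoder", None)`, so the block still works with config objects that do not define `project_encoder` at all. A minimal sketch of the behaviour difference, using a hypothetical `SimpleConfig` stand-in rather than the repository's actual GPT-2 config class:

# Minimal sketch (not the repository's code) of why the commit replaces
# `self.config.project_encoder` with `getattr(self.config, "project_encoder", None)`.
# `SimpleConfig` is a hypothetical stand-in for a config created before the
# `project_encoder` option existed, so the attribute may be absent.

class SimpleConfig:
    def __init__(self, **kwargs):
        for key, value in kwargs.items():
            setattr(self, key, value)

old_style_config = SimpleConfig(hidden_size=768)                        # no `project_encoder`
new_style_config = SimpleConfig(hidden_size=768, project_encoder=True)  # option enabled

# Direct attribute access fails on configs that predate the option.
try:
    _ = old_style_config.project_encoder
except AttributeError:
    print("AttributeError: direct access breaks older configs")

# `getattr` with a default degrades gracefully to "feature disabled".
for config in (old_style_config, new_style_config):
    project_encoder = getattr(config, "project_encoder", None)
    if project_encoder:
        print("projection layers would be set up / applied")
    else:
        print("projection step is skipped")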