ydshieh committed on
Commit
58a121d
1 Parent(s): 3ed2a5d

change to add_cross_attention

Files changed (1)
1. vit_gpt2/modeling_flax_gpt2.py +2 -2
vit_gpt2/modeling_flax_gpt2.py CHANGED
@@ -299,7 +299,7 @@ class FlaxGPT2Block(nn.Module):
 
     def setup(self):
 
-        self.only_self_attn = not self.config.getattr('with_cross_attention', False)
+        self.only_self_attn = not self.config.add_cross_attention
 
         hidden_size = self.config.hidden_size
        inner_dim = self.config.n_inner if self.config.n_inner is not None else 4 * hidden_size
@@ -412,7 +412,7 @@ class FlaxGPT2PreTrainedModel(FlaxPreTrainedModel):
         params_rng, dropout_rng = jax.random.split(rng)
         rngs = {"params": params_rng, "dropout": dropout_rng}
 
-        if self.config.getattr('with_cross_attention', False):
+        if self.config.add_cross_attention:
             encoder_hidden_states = jnp.zeros(input_shape + (self.config.n_embd,))
             encoder_attention_mask = attention_mask
         module_init_outputs = self.module.init(
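
Besides switching config fields, this change repairs a latent bug: config objects have no `.getattr` method, so the removed lines would raise `AttributeError` the first time they ran (the Python built-in is `getattr(obj, name, default)`), and `with_cross_attention` is not a standard config field anyway. By contrast, `add_cross_attention` is a regular attribute on transformers' `PretrainedConfig`, defaulting to `False`. A minimal sketch of the difference, assuming a recent `transformers` release:

```python
from transformers import GPT2Config

config = GPT2Config(add_cross_attention=True)

# The removed code called `.getattr` as a method on the config object;
# no such method exists, so this would raise AttributeError:
#   config.getattr('with_cross_attention', False)

# The Python built-in spelling would run, but `with_cross_attention` is not a
# standard config field, so the False default is returned regardless:
getattr(config, "with_cross_attention", False)  # -> False

# `add_cross_attention` is a standard PretrainedConfig attribute, so the new
# code can read it directly, as both hunks now do:
only_self_attn = not config.add_cross_attention
print(only_self_attn)  # -> False, because add_cross_attention=True above
```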