ydshieh commited on
Commit
4bf5c74
1 Parent(s): bfef308

fix attention_mask missing

Browse files
vit_gpt2/modeling_flax_vit_gpt2_lm.py CHANGED
@@ -386,6 +386,7 @@ class FlaxViTGPT2LMPreTrainedModel(FlaxPreTrainedModel):
386
  return self.module.apply(
387
  {"params": params or self.params},
388
  pixel_values=jnp.array(pixel_values, dtype=jnp.float32),
 
389
  decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
390
  decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
391
  decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
 
386
  return self.module.apply(
387
  {"params": params or self.params},
388
  pixel_values=jnp.array(pixel_values, dtype=jnp.float32),
389
+ attention_mask=attention_mask,
390
  decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
391
  decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
392
  decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),