ydshieh
commited on
Commit
•
4bf5c74
1
Parent(s):
bfef308
fix attention_mask missing
Browse files
vit_gpt2/modeling_flax_vit_gpt2_lm.py
CHANGED
@@ -386,6 +386,7 @@ class FlaxViTGPT2LMPreTrainedModel(FlaxPreTrainedModel):
|
|
386 |
return self.module.apply(
|
387 |
{"params": params or self.params},
|
388 |
pixel_values=jnp.array(pixel_values, dtype=jnp.float32),
|
|
|
389 |
decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
|
390 |
decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
|
391 |
decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
|
|
|
386 |
return self.module.apply(
|
387 |
{"params": params or self.params},
|
388 |
pixel_values=jnp.array(pixel_values, dtype=jnp.float32),
|
389 |
+
attention_mask=attention_mask,
|
390 |
decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
|
391 |
decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
|
392 |
decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
|