ydshieh
commited on
Commit
•
89ca80d
1
Parent(s):
03d8c80
fix pixel_values being casted to int
Browse files
vit_gpt2/modeling_flax_vit_gpt2_lm.py
CHANGED
@@ -241,7 +241,7 @@ class FlaxViTGPT2LMPreTrainedModel(FlaxPreTrainedModel):
|
|
241 |
|
242 |
return self.module.apply(
|
243 |
{"params": params or self.params},
|
244 |
-
pixel_values=jnp.array(pixel_values, dtype=
|
245 |
output_attentions=output_attentions,
|
246 |
output_hidden_states=output_hidden_states,
|
247 |
return_dict=return_dict,
|
|
|
241 |
|
242 |
return self.module.apply(
|
243 |
{"params": params or self.params},
|
244 |
+
pixel_values=jnp.array(pixel_values, dtype=jnp.float32),
|
245 |
output_attentions=output_attentions,
|
246 |
output_hidden_states=output_hidden_states,
|
247 |
return_dict=return_dict,
|