ydshieh commited on
Commit
89ca80d
1 Parent(s): 03d8c80

fix pixel_values being casted to int

Browse files
vit_gpt2/modeling_flax_vit_gpt2_lm.py CHANGED
@@ -241,7 +241,7 @@ class FlaxViTGPT2LMPreTrainedModel(FlaxPreTrainedModel):
241
 
242
  return self.module.apply(
243
  {"params": params or self.params},
244
- pixel_values=jnp.array(pixel_values, dtype="i4"),
245
  output_attentions=output_attentions,
246
  output_hidden_states=output_hidden_states,
247
  return_dict=return_dict,
 
241
 
242
  return self.module.apply(
243
  {"params": params or self.params},
244
+ pixel_values=jnp.array(pixel_values, dtype=jnp.float32),
245
  output_attentions=output_attentions,
246
  output_hidden_states=output_hidden_states,
247
  return_dict=return_dict,