ydshieh commited on
Commit
03d8c80
1 Parent(s): 9aceda3

fix decoder_position_ids in decode()

Browse files
vit_gpt2/modeling_flax_vit_gpt2_lm.py CHANGED
@@ -289,7 +289,7 @@ class FlaxViTGPT2LMPreTrainedModel(FlaxPreTrainedModel):
289
  "Make sure to provide `position_ids` when passing `past_key_values`."
290
  )
291
 
292
- position_ids = jnp.broadcast_to(
293
  jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
294
  )
295
 
 
289
  "Make sure to provide `position_ids` when passing `past_key_values`."
290
  )
291
 
292
+ decoder_position_ids = jnp.broadcast_to(
293
  jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
294
  )
295