boris commited on
Commit
945d86c
1 Parent(s): 6c1f112

fix: should be converted to array

Browse files
Files changed (1) hide show
  1. seq2seq/run_seq2seq_flax.py +1 -1
seq2seq/run_seq2seq_flax.py CHANGED
@@ -490,7 +490,7 @@ def main():
490
  jnp.array(labels), config.pad_token_id, config.decoder_start_token_id
491
  )
492
 
493
- model_inputs["decoder_input_ids"] = decoder_input_ids
494
 
495
  return model_inputs
496
 
 
490
  jnp.array(labels), config.pad_token_id, config.decoder_start_token_id
491
  )
492
 
493
+ model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids)
494
 
495
  return model_inputs
496