boris commited on
Commit
6c1f112
1 Parent(s): 678a62f

fix: labels array

Browse files
Files changed (1) hide show
  1. seq2seq/run_seq2seq_flax.py +1 -1
seq2seq/run_seq2seq_flax.py CHANGED
@@ -479,7 +479,7 @@ def main():
479
  # set up targets
480
  # Note: labels correspond to our target indices
481
  # decoder input ids are the same but shifted to the right with bos at the beginning (and without last token)
482
- labels = [[eval(indices) for indices in examples['encoding']]]
483
  labels = np.asarray(labels)
484
 
485
  # We need the labels, in addition to the decoder_input_ids, for the compute_loss function
 
479
  # set up targets
480
  # Note: labels correspond to our target indices
481
  # decoder input ids are the same but shifted to the right with bos at the beginning (and without last token)
482
+ labels = [eval(indices) for indices in examples['encoding']]
483
  labels = np.asarray(labels)
484
 
485
  # We need the labels, in addition to the decoder_input_ids, for the compute_loss function