Pedro Cuenca committed on
Commit df3c7bd
1 Parent(s): a841a4c

Preprocessing: return "labels", "decoder_input_ids" and "decoder_attention_mask".

All fields are required later on to compute the loss.

Note that labels and decoder_input_ids are the same in our case. I'm not
sure that's correct, but shifting right the decoder_inputs would lose
the last token.
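To make the trade-off concrete, here is a minimal numpy sketch; `shift_tokens_right` below is only modeled on the same-length shift used by the Flax seq2seq models in transformers, and all token ids are made-up placeholders:

import numpy as np

def shift_tokens_right(input_ids: np.ndarray, pad_token_id: int, decoder_start_token_id: int) -> np.ndarray:
    # Same-length shift: prepend the decoder start token and drop the last token.
    shifted = np.zeros_like(input_ids)
    shifted[:, 1:] = input_ids[:, :-1]
    shifted[:, 0] = decoder_start_token_id
    # Mirror the helpers' handling of -100 label placeholders.
    return np.where(shifted == -100, pad_token_id, shifted)

# Placeholder data: one caption with 256 image-token ids; BOS/pad ids are invented for the demo.
labels = np.arange(1, 257).reshape(1, -1)                                  # shape (1, 256)
shifted = shift_tokens_right(labels, pad_token_id=0, decoder_start_token_id=999)
prepended = np.concatenate([np.full((1, 1), 999), labels], axis=1)

print(shifted.shape)    # (1, 256) -> the last image token is gone
print(prepended.shape)  # (1, 257) -> BOS + all 256 tokens, as done in this commit

Shifting keeps the length fixed at 256 but drops the final image token, while prepending the start token keeps every token at the cost of a length of 257.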

Files changed (1)
  1. seq2seq/run_seq2seq_flax.py +16 -5
seq2seq/run_seq2seq_flax.py CHANGED
@@ -458,19 +458,30 @@ def main():
         )
 
         # set up targets
-        model_inputs["labels"] = [eval(indices) for indices in examples['encoding']]
+        # Note: we prepend the bos token instead of doing `shift_tokens_right` because the latter
+        # removes the last token, and we know we don't need padding. In our case, labels
+        # has a length of exactly 1 + 256, while shifting would produce 256 tokens.
+        labels = [[config.decoder_start_token_id] + eval(indices) for indices in examples['encoding']]
+        labels = np.asarray(labels)
+
+        # We need the labels, in addition to the decoder_input_ids, for the compute_loss function
+        # In our case, they are the same as decoder_input_ids. Is that correct?
+        model_inputs["labels"] = labels
 
         # TODO: if data processing prevents correct compilation, we will:
         # - have data saved in JSONL (to avoid `eval` which is needed here to convert string "[2]" to list[int])
         # - use below `shift_tokens_right_fn`
-        decoder_input_ids = shift_tokens_right_fn(
-            jnp.array(labels["input_ids"]), config.pad_token_id, config.decoder_start_token_id
-        )
+        # In our case, this prepends the bos token and removes the last one
+        # decoder_input_ids = shift_tokens_right_fn(
+        #     jnp.array(labels), config.pad_token_id, config.decoder_start_token_id
+        # )
 
-        model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids)
+        model_inputs["decoder_input_ids"] = labels
 
         # We need decoder_attention_mask so we can ignore pad tokens from loss
         # TODO: I don't believe we need "decoder_attention_mask" in this case because all labels have same length
+        # However, we need to provide a mask or modify the compute_loss function, which relies on having one
+        model_inputs["decoder_attention_mask"] = np.ones(labels.shape)
         #model_inputs["decoder_attention_mask"] = labels["attention_mask"]
 
         return model_inputs
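For context on the all-ones mask added above: the loss computation expects a `decoder_attention_mask`, so one is supplied even though no token is padded. Below is a hypothetical sketch of such a mask-weighted cross-entropy in the style of the Hugging Face Flax seq2seq examples; `compute_loss`, the toy shapes and token ids are assumptions, not the actual function in run_seq2seq_flax.py.

import jax
import jax.numpy as jnp
import optax

def compute_loss(logits, labels, decoder_attention_mask):
    # Hypothetical masked cross-entropy: positions with mask 0 contribute nothing,
    # and the mask normalizes the mean over the remaining tokens.
    onehot = jax.nn.one_hot(labels, num_classes=logits.shape[-1])
    loss = optax.softmax_cross_entropy(logits, onehot)
    return (loss * decoder_attention_mask).sum() / decoder_attention_mask.sum()

# Toy shapes: batch of 2, sequence of 5, vocabulary of 7 (all made up).
labels = jnp.array([[1, 2, 3, 4, 5], [2, 3, 4, 5, 6]])
logits = jax.random.normal(jax.random.PRNGKey(0), (2, 5, 7))
mask = jnp.ones(labels.shape)  # what this commit returns as decoder_attention_mask

print(compute_loss(logits, labels, mask))  # with an all-ones mask this is a plain mean

With an all-ones mask the weighted average reduces to a plain mean over tokens, which is why the TODO questions whether the mask is needed at all; it is kept only so the loss function does not have to change.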