Pedro Cuenca committed
Commit 835ea55
1 Parent(s): 945d86c

Shift tokens in numpy because the built-in shift function stalls.


A possible cause is the conversion to jax arrays and then back to numpy:
we might be moving data to/from the TPU.
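
For context, a minimal sketch of the idea behind the change (the helper name `shift_right_np` and the toy token ids below are illustrative, not part of the commit): building `decoder_input_ids` directly in numpy, instead of calling the model file's `shift_tokens_right` on a `jnp.array` and converting the result back, keeps preprocessing entirely on the host.

    import numpy as np

    def shift_right_np(labels: np.ndarray, decoder_start_token_id: int) -> np.ndarray:
        """Prepend the decoder start token and drop the last label, staying in numpy."""
        shifted = np.zeros_like(labels)
        shifted[:, 1:] = labels[:, :-1]
        shifted[:, 0] = decoder_start_token_id
        return shifted

    labels = np.array([[31, 42, 53, 1]])    # toy token ids
    print(shift_right_np(labels, 0))        # [[ 0 31 42 53]]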

Files changed (1)
  1. seq2seq/run_seq2seq_flax.py +11 -11
seq2seq/run_seq2seq_flax.py CHANGED
@@ -462,16 +462,19 @@ def main():
     # Temporarily set max_target_length for training.
     max_target_length = data_args.max_target_length
 
-    # In Flax, for seq2seq models we need to pass `decoder_input_ids`
-    # as the Flax models don't accept `labels`, we need to prepare the decoder_input_ids here
-    # for that dynamically import the `shift_tokens_right` function from the model file
-    model_module = __import__(model.__module__, fromlist=["shift_tokens_tight"])
-    shift_tokens_right_fn = getattr(model_module, "shift_tokens_right")
+    def shift_tokens_right(input_ids: np.array, decoder_start_token_id: int):
+        """
+        Shift input ids one token to the right.
+        """
+        shifted_input_ids = np.zeros(input_ids.shape)
+        shifted_input_ids[:, 1:] = input_ids[:, :-1]
+        shifted_input_ids[:, 0] = decoder_start_token_id
+        return shifted_input_ids
 
-    # Setting padding="max_length" as we need fixed length inputs for jitted functions
     def preprocess_function(examples):
         inputs = examples[text_column]
         inputs = [prefix + inp for inp in inputs]
+        # Setting padding="max_length" as we need fixed length inputs for jitted functions
         model_inputs = tokenizer(
             inputs, max_length=data_args.max_source_length, padding="max_length", truncation=True, return_tensors="np"
         )
@@ -486,11 +489,8 @@ def main():
         model_inputs["labels"] = labels
 
         # In our case, this prepends the bos token and removes the last one
-        decoder_input_ids = shift_tokens_right_fn(
-            jnp.array(labels), config.pad_token_id, config.decoder_start_token_id
-        )
-
-        model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids)
+        decoder_input_ids = shift_tokens_right(labels, config.decoder_start_token_id)
+        model_inputs["decoder_input_ids"] = decoder_input_ids
 
         return model_inputs
 
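
As a usage note on the new helper (the toy ids below are illustrative, not from the commit): `np.zeros(input_ids.shape)` defaults to float64, so the shifted ids come back as floats and a cast may be needed downstream depending on how the decoder inputs are consumed.

    import numpy as np

    labels = np.array([[5, 6, 7, 1]])       # toy token ids
    shifted = np.zeros(labels.shape)        # float64 by default
    shifted[:, 1:] = labels[:, :-1]
    shifted[:, 0] = 2                       # e.g. decoder_start_token_id
    print(shifted)                          # [[2. 5. 6. 7.]]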