versae commited on
Commit
986ff4e
1 Parent(s): 1c5d797

Adding pad_to_multiple_of=16

Browse files
Files changed (1) hide show
  1. run_mlm_flax_stream.py +4 -4
run_mlm_flax_stream.py CHANGED
@@ -218,9 +218,9 @@ class FlaxDataCollatorForLanguageModeling:
218
  "You should pass `mlm=False` to train on causal language modeling instead."
219
  )
220
 
221
- def __call__(self, examples: List[Dict[str, np.ndarray]]) -> Dict[str, np.ndarray]:
222
  # Handle dict or lists with proper padding and conversion to tensor.
223
- batch = self.tokenizer.pad(examples, return_tensors=TensorType.NUMPY)
224
 
225
  # If special token mask has been preprocessed, pop it from the dict.
226
  special_tokens_mask = batch.pop("special_tokens_mask", None)
@@ -653,7 +653,7 @@ if __name__ == "__main__":
653
  samples = advance_iter_and_group_samples(training_iter, train_batch_size, max_seq_length)
654
 
655
  # process input samples
656
- model_inputs = data_collator(samples)
657
 
658
  # Model forward
659
  model_inputs = shard(model_inputs.data)
@@ -678,7 +678,7 @@ if __name__ == "__main__":
678
  for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=1)):
679
  # process input samples
680
  batch_eval_samples = {k: [v[idx] for idx in batch_idx] for k, v in eval_samples.items()}
681
- model_inputs = data_collator(batch_eval_samples)
682
 
683
  # Model forward
684
  model_inputs = shard(model_inputs.data)
218
  "You should pass `mlm=False` to train on causal language modeling instead."
219
  )
220
 
221
+ def __call__(self, examples: List[Dict[str, np.ndarray]], pad_to_multiple_of: int) -> Dict[str, np.ndarray]:
222
  # Handle dict or lists with proper padding and conversion to tensor.
223
+ batch = self.tokenizer.pad(examples, pad_to_multiple_of=pad_to_multiple_of, return_tensors=TensorType.NUMPY)
224
 
225
  # If special token mask has been preprocessed, pop it from the dict.
226
  special_tokens_mask = batch.pop("special_tokens_mask", None)
653
  samples = advance_iter_and_group_samples(training_iter, train_batch_size, max_seq_length)
654
 
655
  # process input samples
656
+ model_inputs = data_collator(samples, pad_to_multiple_of=16)
657
 
658
  # Model forward
659
  model_inputs = shard(model_inputs.data)
678
  for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=1)):
679
  # process input samples
680
  batch_eval_samples = {k: [v[idx] for idx in batch_idx] for k, v in eval_samples.items()}
681
+ model_inputs = data_collator(batch_eval_samples, pad_to_multiple_of=16)
682
 
683
  # Model forward
684
  model_inputs = shard(model_inputs.data)