pere commited on
Commit
a8c4f2a
·
verified ·
1 Parent(s): 2c614c4

Update run_mlm_flax.py

Browse files
Files changed (1) hide show
  1. run_mlm_flax.py +13 -1
run_mlm_flax.py CHANGED
@@ -687,7 +687,18 @@ def main():
687
  per_device_eval_batch_size = training_args.per_device_eval_batch_size
688
  eval_batch_size = per_device_eval_batch_size * local_device_count
689
 
 
 
 
 
 
 
 
 
 
 
690
  num_train_steps = (len(tokenized_datasets["train"]) // (train_batch_size * jax.process_count())) * num_epochs
 
691
 
692
  # Create learning rate schedule
693
  warmup_fn = optax.linear_schedule(
@@ -817,7 +828,8 @@ def main():
817
 
818
  train_samples_idx = np.arange(num_train_samples)
819
  train_samples_idx = np.random.permutation(train_samples_idx)
820
- # Split the training indices across processes train_samples_idx = np.array_split(train_samples_idx, jax.process_count())[jax.process_index()]
 
821
  train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size, drop_last=True)
822
 
823
  # Gather the indexes for creating the batch and do a training step
 
687
  per_device_eval_batch_size = training_args.per_device_eval_batch_size
688
  eval_batch_size = per_device_eval_batch_size * local_device_count
689
 
690
+ # Calculate Global Batch Sizes
691
+ global_train_batch_size = train_batch_size * jax.process_count()
692
+ global_eval_batch_size = eval_batch_size * jax.process_count()
693
+
694
+ # Log Batch Sizes
695
+ logger.info(f"Per-process train batch size: {train_batch_size}")
696
+ logger.info(f"Global train batch size: {global_train_batch_size}")
697
+ logger.info(f"Per-process eval batch size: {per_device_eval_batch_size}")
698
+ logger.info(f"Global eval batch size: {global_eval_batch_size}")
699
+
700
  num_train_steps = (len(tokenized_datasets["train"]) // (train_batch_size * jax.process_count())) * num_epochs
701
+ logger.info(f"Number of training steps: {num_train_steps}")
702
 
703
  # Create learning rate schedule
704
  warmup_fn = optax.linear_schedule(
 
828
 
829
  train_samples_idx = np.arange(num_train_samples)
830
  train_samples_idx = np.random.permutation(train_samples_idx)
831
+ # Split the training indices across processes
832
+ train_samples_idx = np.array_split(train_samples_idx, jax.process_count())[jax.process_index()]
833
  train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size, drop_last=True)
834
 
835
  # Gather the indexes for creating the batch and do a training step