dat committed
Commit afb3179
1 Parent(s): bc11ccf

with grad_accum

Files changed (1)
  1. run_mlm_flax.py +34 -20
run_mlm_flax.py CHANGED
@@ -288,7 +288,8 @@ def rotate_checkpoints(ckpt_dir: str, save_total_limit: int):
 
 
 
-
+class TrainState(train_state.TrainState):
+    grad_accum: jnp.ndarray
 
 
 
@@ -396,10 +397,10 @@ if __name__ == "__main__":
         return train, val
     train, val = train_val_files()
     datasets = load_dataset('json', data_files={'train': train, 'validation': val})
-    #datasets["train"] = datasets["train"].select(range(int(0.8*len(datasets["train"]))))
-    #datasets["validation"] = datasets["validation"].select(range(int(0.8*len(datasets["validation"]))))
-    datasets["train"] = datasets["train"].select(range(10000))
-    datasets["validation"] = datasets["validation"].select(range(10000))
+    datasets["train"] = datasets["train"].select(range(int(0.8*len(datasets["train"]))))
+    datasets["validation"] = datasets["validation"].select(range(int(0.8*len(datasets["validation"]))))
+    #datasets["train"] = datasets["train"].select(range(10000))
+    #datasets["validation"] = datasets["validation"].select(range(10000))
 
 
 
@@ -566,7 +567,7 @@ if __name__ == "__main__":
 
     # Store some constant
     num_epochs = int(training_args.num_train_epochs)
-    train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count() #* training_args.gradient_accumulation_steps
+    train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
     eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
 
     num_train_steps = len(train_dataset) // train_batch_size * num_epochs
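With accumulation handled inside the train step, train_batch_size stays at per-device size times device count, and the batch that actually reaches the optimizer is that value times gradient_accumulation_steps. Illustrative arithmetic only; the concrete numbers are assumptions:

per_device_train_batch_size = 8
device_count = 8                                                        # e.g. one TPU v3-8
gradient_accumulation_steps = 4

train_batch_size = per_device_train_batch_size * device_count          # 64 samples per train_step
effective_batch_size = train_batch_size * gradient_accumulation_steps  # 256 samples per optimizer update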
@@ -613,14 +614,14 @@ if __name__ == "__main__":
         mask=decay_mask_fn,
     )
 
-    if training_args.gradient_accumulation_steps > 1:
-        optimizer = optax.MultiSteps(optimizer, training_args.gradient_accumulation_steps)
-    grad_accum_steps = training_args.gradient_accumulation_steps
+    #if training_args.gradient_accumulation_steps > 1:
+    #    optimizer = optax.MultiSteps(optimizer, training_args.gradient_accumulation_steps)
+    #grad_accum_steps = training_args.gradient_accumulation_steps
 
     # Setup train state
 
 
-    state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer)
+    state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer,grad_accum=jax.tree_map(jnp.zeros_like, model.params))
 
     if training_args.resume_from_checkpoint:
         state = restore_checkpoint(training_args.resume_from_checkpoint, state)
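The commented-out lines are the optax.MultiSteps route this commit replaces with manual accumulation in the custom TrainState. For comparison, a sketch of the MultiSteps approach, assuming a base optax optimizer and an accumulation factor of 4:

import optax

k = 4                                                    # assumed gradient_accumulation_steps
optimizer = optax.adamw(learning_rate=1e-3)
optimizer = optax.MultiSteps(optimizer, every_k_schedule=k)
# MultiSteps stores the gradient sum in its own optimizer state, so the plain
# flax TrainState keeps working; parameters only move on every k-th update.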
@@ -645,17 +646,30 @@ if __name__ == "__main__":
             # take average
             loss = loss.sum() / label_mask.sum()
 
-            return loss
+            return loss / training_args.gradient_accumulation_steps
 
         grad_fn = jax.value_and_grad(loss_fn)
-        loss, grad = grad_fn(state.params)
-        grad = jax.lax.pmean(grad, "batch")
-        new_state = state.apply_gradients(grads=grad)
-
+        loss, grads = grad_fn(state.params)
+        grad_accum = jax.tree_multimap(lambda x, y: x + y, grads, state.grad_accum)
+
+        def update_fn():
+            grads = jax.tree_map(lambda x: x / training_args.gradient_accumulation_steps, grad_accum)
+            grads = jax.lax.pmean(grad_accum, "batch")
+            new_state = state.apply_gradients(grads=grads,grad_accum=jax.tree_map(jnp.zeros_like, grads))
+            return new_state
+
+        new_state = jax.lax.cond(
+            state.step % training_args.gradient_accumulation_steps == 0,
+            lambda _: update_fn(),
+            lambda _: state.replace(grad_accum=grad_accum, step=state.step + 1),
+            None,
+        )
+
         metrics = jax.lax.pmean(
-            {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step // grad_accum_steps)}, axis_name="batch" #
+            {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}, axis_name="batch" #
         )
 
+        #return new_state.replace(new_dropout_rng=new_dropout_rng), metrics
         return new_state, metrics, new_dropout_rng
 
     # Create parallel version of the train step
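The rewritten train_step sums gradients into state.grad_accum on every micro-batch and uses jax.lax.cond to either apply the averaged gradients or just stash the running sum, which keeps both branches jit/pmap friendly. A condensed, self-contained sketch of the same accumulate-or-apply idea; it is not a verbatim extract, the toy loss and SGD optimizer are assumptions, and the modulo check uses state.step + 1 so the first update lands after accum_steps micro-batches:

import jax
import jax.numpy as jnp
import optax
from flax.training import train_state


class TrainState(train_state.TrainState):
    grad_accum: jnp.ndarray


def toy_loss(params, batch):
    return jnp.sum((params["w"] * batch["x"] - batch["y"]) ** 2)


def train_step(state, batch, accum_steps):
    loss, grads = jax.value_and_grad(toy_loss)(state.params, batch)
    # Add this micro-batch's gradients to the running sum.
    grad_accum = jax.tree_map(lambda g, a: g + a, grads, state.grad_accum)

    def apply_update(_):
        # The real script also jax.lax.pmean's the gradients across devices here.
        mean_grads = jax.tree_map(lambda g: g / accum_steps, grad_accum)
        return state.apply_gradients(
            grads=mean_grads,
            grad_accum=jax.tree_map(jnp.zeros_like, mean_grads),  # reset the sum
        )

    def keep_accumulating(_):
        # No parameter update; store the sum and advance the step counter.
        return state.replace(grad_accum=grad_accum, step=state.step + 1)

    new_state = jax.lax.cond(
        (state.step + 1) % accum_steps == 0, apply_update, keep_accumulating, None
    )
    return new_state, loss


params = {"w": jnp.ones((3,))}
state = TrainState.create(
    apply_fn=lambda *args, **kwargs: None,            # placeholder for model.__call__
    params=params,
    tx=optax.sgd(0.1),
    grad_accum=jax.tree_map(jnp.zeros_like, params),
)
batch = {"x": jnp.arange(3.0), "y": jnp.zeros(3)}
step_fn = jax.jit(train_step, static_argnums=2)
for _ in range(4):
    state, loss = step_fn(state, batch, 4)            # params change on the 4th call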
@@ -699,10 +713,10 @@ if __name__ == "__main__":
         # Generate an epoch by shuffling sampling indices from the train dataset
         num_train_samples = len(train_dataset)
         train_samples_idx = jax.random.permutation(input_rng, jnp.arange(num_train_samples))
-        train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size // grad_accum_steps) #
+        train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size) #// grad_accum_steps
 
         # Gather the indexes for creating the batch and do a training step
-        for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1,initial=resume_step // grad_accum_steps)): #
+        for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1,initial=resume_step)): #grad_accum
             samples = [train_dataset[int(idx)] for idx in batch_idx]
             model_inputs = data_collator(samples, pad_to_multiple_of=16)
 
@@ -716,7 +730,7 @@ if __name__ == "__main__":
             if cur_step < resume_step:
                 continue
 
-            if (cur_step % training_args.logging_steps * grad_accum_steps) == 0 and cur_step > 0: #
+            if (cur_step % training_args.logging_steps) == 0 and cur_step > 0: # * grad_accum_steps
                 # Save metrics
                 train_metric = jax_utils.unreplicate(train_metric)
                 train_time += time.time() - train_start
@@ -733,7 +747,7 @@ if __name__ == "__main__":
 
             train_metrics = []
 
-            if cur_step % (training_args.eval_steps * grad_accum_steps) == 0 and cur_step > 0: #
+            if cur_step % (training_args.eval_steps) == 0 and cur_step > 0: #* grad_accum_steps
                 # ======================== Evaluating ==============================
                 num_eval_samples = len(eval_dataset)
                 eval_samples_idx = jnp.arange(num_eval_samples)
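Because each micro-batch now advances cur_step by one while parameters move only every gradient_accumulation_steps steps, the logging and eval intervals above are back to being expressed in plain micro-steps. Illustrative bookkeeping; the values are assumptions:

gradient_accumulation_steps = 4
logging_steps = 500                 # micro-steps between metric logs
eval_steps = 2000                   # micro-steps between evaluations

cur_step = 10_000                   # micro-batches processed so far
optimizer_updates = cur_step // gradient_accumulation_steps    # 2500 parameter updates
should_log = cur_step % logging_steps == 0 and cur_step > 0    # True
should_eval = cur_step % eval_steps == 0 and cur_step > 0      # True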