boris commited on
Commit
c9e9575
1 Parent(s): cbeacb9

feat: gradient accumulation

Browse files
Files changed (1) hide show
  1. seq2seq/run_seq2seq_flax.py +37 -10
seq2seq/run_seq2seq_flax.py CHANGED
@@ -239,6 +239,8 @@ class DataTrainingArguments:
239
 
240
  class TrainState(train_state.TrainState):
241
  dropout_rng: jnp.ndarray
 
 
242
 
243
  def replicate(self):
244
  return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
@@ -590,14 +592,16 @@ def main():
590
  # Store some constant
591
  num_epochs = int(training_args.num_train_epochs)
592
  train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
 
593
  eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
594
  steps_per_epoch = len(train_dataset) // train_batch_size
595
- total_train_steps = steps_per_epoch * num_epochs
 
596
 
597
  # Create learning rate schedule
598
  linear_decay_lr_schedule_fn = create_learning_rate_fn(
599
  len(train_dataset),
600
- train_batch_size,
601
  training_args.num_train_epochs,
602
  training_args.warmup_steps,
603
  training_args.learning_rate,
@@ -636,7 +640,14 @@ def main():
636
  )
637
 
638
  # Setup train state
639
- state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer, dropout_rng=dropout_rng)
 
 
 
 
 
 
 
640
 
641
  # label smoothed cross entropy
642
  def loss_fn(logits, labels):
@@ -655,15 +666,28 @@ def main():
655
  return loss
656
 
657
  grad_fn = jax.value_and_grad(compute_loss)
658
- loss, grad = grad_fn(state.params)
659
- grad = jax.lax.pmean(grad, "batch")
 
 
 
 
 
 
 
 
660
 
661
- new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
 
 
 
 
 
662
 
663
- metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
664
  metrics = jax.lax.pmean(metrics, axis_name="batch")
665
 
666
- return new_state, metrics
667
 
668
  # Define eval fn
669
  def eval_step(params, batch):
@@ -702,8 +726,11 @@ def main():
702
  logger.info(f" Num examples = {len(train_dataset)}")
703
  logger.info(f" Num Epochs = {num_epochs}")
704
  logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
705
- logger.info(f" Total train batch size (w. parallel & distributed) = {train_batch_size}")
706
- logger.info(f" Total optimization steps = {total_train_steps}")
 
 
 
707
 
708
  train_time = 0
709
  epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
 
239
 
240
  class TrainState(train_state.TrainState):
241
  dropout_rng: jnp.ndarray
242
+ grad_accum: jnp.ndarray
243
+ optimizer_step: int
244
 
245
  def replicate(self):
246
  return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
 
592
  # Store some constant
593
  num_epochs = int(training_args.num_train_epochs)
594
  train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
595
+ total_batch_size = int(train_batch_size) * training_args.gradient_accumulation_steps
596
  eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
597
  steps_per_epoch = len(train_dataset) // train_batch_size
598
+ total_steps = steps_per_epoch * num_epochs
599
+ total_optimization_steps = (len(train_dataset) // total_batch_size) * num_epochs
600
 
601
  # Create learning rate schedule
602
  linear_decay_lr_schedule_fn = create_learning_rate_fn(
603
  len(train_dataset),
604
+ total_batch_size,
605
  training_args.num_train_epochs,
606
  training_args.warmup_steps,
607
  training_args.learning_rate,
 
640
  )
641
 
642
  # Setup train state
643
+ state = TrainState.create(
644
+ apply_fn=model.__call__,
645
+ params=model.params,
646
+ tx=adamw,
647
+ dropout_rng=dropout_rng,
648
+ grad_accum=jax.tree_map(jnp.zeros_like, model.params),
649
+ optimizer_step=0,
650
+ )
651
 
652
  # label smoothed cross entropy
653
  def loss_fn(logits, labels):
 
666
  return loss
667
 
668
  grad_fn = jax.value_and_grad(compute_loss)
669
+ loss, grads = grad_fn(state.params)
670
+ grad_accum = jax.tree_multimap(lambda x, y: x + y, grads, state.grad_accum)
671
+
672
+ def update_fn():
673
+ grads = jax.tree_map(lambda x: x / training_args.gradient_accumulation_steps, grad_accum)
674
+ grads = jax.lax.pmean(grads, "batch")
675
+ new_state = state.apply_gradients(
676
+ grads=grads, grad_accum=jax.tree_map(jnp.zeros_like, grads), optimizer_step=state.optimizer_step
677
+ )
678
+ return new_state
679
 
680
+ new_state = jax.lax.cond(
681
+ state.step % training_args.gradient_accumulation_steps == 0,
682
+ lambda _: update_fn(),
683
+ lambda _: state.replace(grad_accum=grad_accum, step=state.step + 1),
684
+ None,
685
+ )
686
 
687
+ metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.optimizer_step)}
688
  metrics = jax.lax.pmean(metrics, axis_name="batch")
689
 
690
+ return new_state.replace(dropout_rng=new_dropout_rng), metrics
691
 
692
  # Define eval fn
693
  def eval_step(params, batch):
 
726
  logger.info(f" Num examples = {len(train_dataset)}")
727
  logger.info(f" Num Epochs = {num_epochs}")
728
  logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
729
+ logger.info(
730
+ f" Total train batch size (w. parallel & distributed) = {train_batch_size * training_args.gradient_accumulation_steps}"
731
+ )
732
+ logger.info(f" Total global steps = {total_steps}")
733
+ logger.info(f" Total optimization steps = {total_optimization_steps}")
734
 
735
  train_time = 0
736
  epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)