boris committed
Commit 7b5868f
1 Parent(s): 00710bc

feat(pjit): follow t5x style

Files changed (1)
  1. tools/train/train.py +65 -58
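
Note (editor's sketch, not part of the commit): the change drops the explicit dp_devices axis, so each step's data now arrives as (gradient_accumulation_steps, minibatch_size, ...), with the accumulation axis replicated and the minibatch axis sharded over "batch" (see grad_batch_spec below). A minimal runnable illustration of that layout and of the shape check train_step now performs, using made-up sizes:

    import jax
    import jax.numpy as jnp

    # hypothetical sizes, for illustration only
    gradient_accumulation_steps, minibatch_size, seq_len = 2, 4, 8

    batch = {
        "input_ids": jnp.zeros((gradient_accumulation_steps * minibatch_size, seq_len), jnp.int32),
        "labels": jnp.zeros((gradient_accumulation_steps * minibatch_size, seq_len), jnp.int32),
    }

    # reshape each array to (gradient_accumulation_steps, minibatch_size, ...), as the training loop now does
    batch = jax.tree_util.tree_map(
        lambda x: x.reshape((gradient_accumulation_steps, minibatch_size) + x.shape[1:]),
        batch,
    )

    # the shape assertion train_step makes at trace time
    assert batch["labels"].shape[0:2] == (gradient_accumulation_steps, minibatch_size)
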
tools/train/train.py CHANGED
@@ -765,6 +765,7 @@ def main():
    # define batch specs
    keys = ["attention_mask", "decoder_input_ids", "input_ids", "labels"]
    batch_spec = freeze({k: PartitionSpec("batch") for k in keys})
+    grad_batch_spec = freeze({k: PartitionSpec(None, "batch") for k in keys})

    # label smoothed cross entropy
    def loss_fn(logits, labels):
@@ -774,18 +775,22 @@ def main():

    # Define gradient update step fn
    def train_step(state, batch, delta_time):
+        # batch is (gradient_accumulation_steps, minibatch_size, ...)
        # check correct batch shape during compilation
-        assert batch["labels"].shape[0:3] == (
-            training_args.dp_devices,
+        assert batch["labels"].shape[0:2] == (
            training_args.gradient_accumulation_steps,
-            training_args.per_device_train_batch_size,
+            minibatch_size,
        ), f"Expected label batch of shape dp_devices x gradient_acculumation x batch_per_device and got {batch['labels'].shape}"
-        # create a new rng
-        dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
-        # use a different rng per node
-        dropout_rng = jax.random.fold_in(dropout_rng, jax.process_index())

-        def compute_loss(params, minibatch):
+        # get a minibatch (one gradient accumulation slice)
+        def get_minibatch(batch, grad_idx):
+            return jax.tree_map(
+                lambda x: jax.lax.dynamic_index_in_dim(x, grad_idx, keepdims=False),
+                batch,
+            )
+
+        def compute_loss(params, minibatch, dropout_rng):
+            # minibatch has dim (batch_size, ...)
            minibatch = unfreeze(minibatch)
            labels = minibatch.pop("labels")
            logits = state.apply_fn(
@@ -795,58 +800,61 @@ def main():

        grad_fn = jax.value_and_grad(compute_loss)

-        def loss_grad_per_device(device_batch):
-            # device_batch has format (gradient_accumulation_steps, batch_size, ...)
+        def loss_and_grad(grad_idx, dropout_rng):
+            minibatch = get_minibatch(batch, grad_idx)
+            # ensure batch is sharded over devices
+            minibatch = jax.tree_map(
+                lambda x: with_sharding_constraint(x, PartitionSpec("batch")), minibatch
+            )
+            # return loss and grads
+            return grad_fn(state.params, minibatch, dropout_rng)

-            if training_args.gradient_accumulation_steps == 1:
-                minibatch = jax.tree_map(
-                    lambda x: x[0],
-                    device_batch,
-                )
-                loss, grads = grad_fn(state.params, minibatch)
-            else:
+        # create a new rng
+        dropout_rng, _ = jax.random.split(state.dropout_rng)
+        # use a different rng per node
+        dropout_rng = jax.random.fold_in(dropout_rng, jax.process_index())

-                def _cumul_loss_grads(i, cumul_loss_grads):
-                    minibatch = jax.tree_map(
-                        lambda x: x[i],
-                        device_batch,
-                    )
-                    return jax.tree_map(
-                        lambda x, y: x + y,
-                        cumul_loss_grads,
-                        grad_fn(state.params, minibatch),
-                    )
+        if training_args.gradient_accumulation_steps == 1:

-                init_loss_grads = (
-                    0.0,
-                    jax.tree_map(jnp.zeros_like, state.params),
-                )
-                loss, grads = jax.tree_map(
-                    lambda x: x / training_args.gradient_accumulation_steps,
-                    jax.lax.fori_loop(
-                        0,
-                        training_args.gradient_accumulation_steps,
-                        _cumul_loss_grads,
-                        init_loss_grads,
-                    ),
-                )
-            return loss, grads
-
-        # calculate loss, grads per dp device
-        # batch has shape (dp_devices, gradient_accumulation_steps, batch_per_dp_device, ...)
-        loss, grads = jax.vmap(loss_grad_per_device, in_axes=0, out_axes=(0, 0))(batch)
-        # enforce sharding constraints to avoid OOM
-        loss = with_sharding_constraint(loss, PartitionSpec("batch"))
-        grads = jax.tree_map(
-            lambda x: with_sharding_constraint(x, PartitionSpec("batch")), grads
-        )
-        # calculate the mean over all devices
-        loss = jnp.mean(loss)
-        grads = jax.tree_map(lambda x: jnp.mean(x, axis=0), grads)
+            def batch_step(dropout_rng):
+                dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
+                loss_grad = loss_and_grad(0, dropout_rng)
+                return loss_grad, new_dropout_rng
+
+            loss_grad, dropout_rng = batch_step(dropout_rng)
+        else:
+            # create initial state for per_minibatch_step loop
+            init_cumul_loss_grad = (
+                0.0,
+                jax.tree_map(jnp.zeros_like, state.params),
+            )
+            init_minibatch_step = (init_cumul_loss_grad, dropout_rng)
+
+            # accumulate gradients
+            def cumul_minibatch_step(grad_idx, cumul_loss_grad_dropout):
+                cumul_loss_grad, dropout_rng = cumul_loss_grad_dropout
+                dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
+                loss_grad = loss_and_grad(grad_idx, dropout_rng)
+                cumul_loss_grad = jax.tree_map(jnp.add, cumul_loss_grad, loss_grad)
+                return cumul_loss_grad, new_dropout_rng
+
+            # loop over gradients
+            loss_grad, dropout_rng = jax.lax.fori_loop(
+                0,
+                training_args.gradient_accumulation_steps,
+                cumul_minibatch_step,
+                init_minibatch_step,
+            )
+            # sum -> mean
+            loss_grad = jax.tree_map(
+                lambda x: x / training_args.gradient_accumulation_steps, loss_grad
+            )

+        # update state
+        loss, grads = loss_grad
        state = state.apply_gradients(
            grads=grads,
-            dropout_rng=new_dropout_rng,
+            dropout_rng=dropout_rng,
            train_time=state.train_time + delta_time,
            train_samples=state.train_samples + batch_size_per_step,
        )
@@ -872,7 +880,7 @@ def main():
    # Create parallel version of the train and eval step
    p_train_step = pjit(
        train_step,
-        in_axis_resources=(state_spec, batch_spec, None),
+        in_axis_resources=(state_spec, grad_batch_spec, None),
        out_axis_resources=(state_spec, None),
        donate_argnums=(0,),
    )
@@ -1053,13 +1061,12 @@ def main():
            delta_time = new_time - last_time
            last_time = new_time

-            # reshape data into (dp_devices, gradient_accumulation_steps, batch_per_dp_device, ...)
+            # reshape data into (gradient_accumulation_steps, minibatch_size, ...)
            batch = jax.tree_map(
                lambda x: x.reshape(
                    (
-                        training_args.dp_devices,
                        training_args.gradient_accumulation_steps,
-                        -1,
+                        minibatch_size,
                    )
                    + x.shape[1:]
                ),
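
For reference, a self-contained sketch (editor's illustration with a toy loss and made-up names, not code from the repository) of the t5x-style accumulation pattern the new train_step uses: index one accumulation slice per step with jax.lax.dynamic_index_in_dim, split the dropout rng every iteration, sum loss and grads inside jax.lax.fori_loop, then divide by the number of steps:

    import jax
    import jax.numpy as jnp

    # toy setup: a weight vector and data shaped (grad_steps, minibatch_size, features)
    params = jnp.ones(3)
    data = jnp.arange(2 * 4 * 3, dtype=jnp.float32).reshape(2, 4, 3)
    grad_steps = data.shape[0]

    def loss_fn(params, minibatch):
        # stand-in loss; a real model would also consume a dropout rng here
        return jnp.mean((minibatch @ params) ** 2)

    grad_fn = jax.value_and_grad(loss_fn)

    def loss_and_grad(grad_idx):
        # select one accumulation slice, like get_minibatch in the commit
        minibatch = jax.lax.dynamic_index_in_dim(data, grad_idx, keepdims=False)
        return grad_fn(params, minibatch)

    def cumul_minibatch_step(grad_idx, carry):
        (cumul_loss, cumul_grads), rng = carry
        # a real train step would feed step_rng to dropout; the toy loss ignores it
        rng, step_rng = jax.random.split(rng)
        loss, grads = loss_and_grad(grad_idx)
        return (cumul_loss + loss, cumul_grads + grads), rng

    init = ((jnp.zeros(()), jnp.zeros_like(params)), jax.random.PRNGKey(0))
    (loss_sum, grad_sum), _ = jax.lax.fori_loop(0, grad_steps, cumul_minibatch_step, init)
    loss, grads = loss_sum / grad_steps, grad_sum / grad_steps

Keeping the loop inside a single pjit-compiled train_step is what lets the accumulation axis stay replicated (PartitionSpec(None, "batch")) while each minibatch slice is sharded over the "batch" mesh axis.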