boris committed on
Commit
2b7f5f1
1 Parent(s): 7b5868f

feat(train): overhead from 70% to 1% 🥳

Files changed (1)
  1. tools/train/train.py +21 -5
tools/train/train.py CHANGED
@@ -777,9 +777,10 @@ def main():
     def train_step(state, batch, delta_time):
         # batch is (gradient_accumulation_steps, minibatch_size, ...)
         # check correct batch shape during compilation
-        assert batch["labels"].shape[0:2] == (
+        assert batch["labels"].shape[0:3] == (
             training_args.gradient_accumulation_steps,
-            minibatch_size,
+            training_args.dp_devices,
+            training_args.per_device_train_batch_size,
         ), f"Expected label batch of shape dp_devices x gradient_acculumation x batch_per_device and got {batch['labels'].shape}"
 
         # get a minibatch (one gradient accumulation slice)
@@ -801,13 +802,27 @@ def main():
         grad_fn = jax.value_and_grad(compute_loss)
 
         def loss_and_grad(grad_idx, dropout_rng):
+            # minibatch at grad_idx, shape (dp_devices, per_device_train_batch_size, ...)
             minibatch = get_minibatch(batch, grad_idx)
             # ensure batch is sharded over devices
             minibatch = jax.tree_map(
                 lambda x: with_sharding_constraint(x, PartitionSpec("batch")), minibatch
             )
+            # calculate loss and grads independently per dp_device
+            loss_grads = jax.vmap(grad_fn, in_axes=(None, 0, None), out_axes=(0, 0))(
+                state.params, minibatch, dropout_rng
+            )
+            # ensure they are sharded over devices
+            loss_grads = jax.tree_map(
+                lambda x: with_sharding_constraint(x, PartitionSpec("batch")),
+                loss_grads,
+            )
+
+            # average across all devices
+            loss_grads = jax.tree_map(lambda x: jnp.mean(x, axis=0), loss_grads)
+
             # return loss and grads
-            return grad_fn(state.params, minibatch, dropout_rng)
+            return loss_grads
 
         # create a new rng
         dropout_rng, _ = jax.random.split(state.dropout_rng)
@@ -1061,12 +1076,13 @@ def main():
             delta_time = new_time - last_time
             last_time = new_time
 
-            # reshape data into (gradient_accumulation_steps, minibatch_size, ...)
+            # reshape data into (gradient_accumulation_steps, dp_devices, batch_per_dp, ...)
             batch = jax.tree_map(
                 lambda x: x.reshape(
                     (
                         training_args.gradient_accumulation_steps,
+                        training_args.dp_devices,
+                        training_args.per_device_train_batch_size,
                     )
                     + x.shape[1:]
                 ),
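
The new loss_and_grad computes loss and gradients per data-parallel slice with jax.vmap and then averages them, instead of calling grad_fn once on the whole sharded minibatch. Below is a minimal, self-contained sketch of that vmap-and-average pattern, assuming toy shapes and names (toy_loss, params, batch, dp_devices, per_device_batch are illustrative, not from train.py); it deliberately omits the sharding constraints and gradient accumulation of the real train_step.

# Minimal sketch of the vmap-and-average pattern above (toy names, not train.py).
import jax
import jax.numpy as jnp


def toy_loss(params, minibatch, rng):
    # toy squared-error loss on one (per_device_batch, features) slice;
    # rng is unused here and only stands in for dropout_rng
    preds = minibatch["x"] @ params["w"]
    return jnp.mean((preds - minibatch["y"]) ** 2)


grad_fn = jax.value_and_grad(toy_loss)


def loss_and_grad(params, batch, rng):
    # batch leaves have shape (dp_devices, per_device_batch, ...)
    # map grad_fn over the leading data-parallel axis; params and rng are broadcast
    loss, grads = jax.vmap(grad_fn, in_axes=(None, 0, None), out_axes=(0, 0))(
        params, batch, rng
    )
    # average loss and grads across the data-parallel axis
    loss = jnp.mean(loss, axis=0)
    grads = jax.tree_util.tree_map(lambda g: jnp.mean(g, axis=0), grads)
    return loss, grads


dp_devices, per_device_batch, features = 4, 8, 16
rng = jax.random.PRNGKey(0)
params = {"w": jnp.zeros((features,))}
batch = {
    "x": jax.random.normal(rng, (dp_devices, per_device_batch, features)),
    "y": jnp.ones((dp_devices, per_device_batch)),
}
loss, grads = jax.jit(loss_and_grad)(params, batch, rng)
print(loss, jax.tree_util.tree_map(jnp.shape, grads))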