boris committed on
Commit f254058
1 Parent(s): b7c7458

feat(train): improve pjit speed

Files changed (2)
  1. src/dalle_mini/data.py +7 -28
  2. tools/train/train.py +78 -42
src/dalle_mini/data.py CHANGED
@@ -152,14 +152,7 @@ class Dataset:
             ),
         )
 
-    def dataloader(
-        self, split, per_device_batch_size, gradient_accumulation_steps=None, epoch=None
-    ):
-        num_devices = jax.local_device_count()
-        total_batch_size = per_device_batch_size * num_devices
-        if gradient_accumulation_steps is not None:
-            total_batch_size *= gradient_accumulation_steps
-
+    def dataloader(self, split, batch_size, epoch=None):
         def _dataloader_datasets_non_streaming(
             dataset: Dataset,
             rng: jax.random.PRNGKey = None,
@@ -168,7 +161,7 @@ class Dataset:
             Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
             Shuffle batches if rng is set.
             """
-            steps_per_epoch = len(dataset) // total_batch_size
+            steps_per_epoch = len(dataset) // batch_size
 
             if rng is not None:
                 batch_idx = jax.random.permutation(rng, len(dataset))
@@ -176,20 +169,13 @@ class Dataset:
                 batch_idx = jnp.arange(len(dataset))
 
             batch_idx = batch_idx[
-                : steps_per_epoch * total_batch_size
+                : steps_per_epoch * batch_size
             ]  # Skip incomplete batch.
-            batch_idx = batch_idx.reshape((steps_per_epoch, total_batch_size))
+            batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
 
             for idx in batch_idx:
                 batch = dataset[idx]
                 batch = {k: jnp.array(v) for k, v in batch.items()}
-                if gradient_accumulation_steps is not None:
-                    batch = jax.tree_map(
-                        lambda x: x.reshape(
-                            (gradient_accumulation_steps, -1) + x.shape[1:]
-                        ),
-                        batch,
-                    )
                 yield batch
 
         def _dataloader_datasets_streaming(
@@ -205,22 +191,15 @@ class Dataset:
                 # For validation data we put the entire set on each host as we could lose
                 # too many samples on pods
                 if epoch is not None:
-                    # reshuffle training data at each epoch (not applicable with validation set)
+                    assert split == "train"
+                    # reshuffle training data at each epoch
                     dataset.set_epoch(epoch)
                     epoch += 1
                 for item in dataset:
                     for k, v in item.items():
                         batch[k].append(v)
-                    if len(batch[keys[0]]) == total_batch_size:
+                    if len(batch[keys[0]]) == batch_size:
                         batch = {k: jnp.array(v) for k, v in batch.items()}
-                        if gradient_accumulation_steps is not None:
-                            # training mode
-                            batch = jax.tree_map(
-                                lambda x: x.reshape(
-                                    (gradient_accumulation_steps, -1) + x.shape[1:]
-                                ),
-                                batch,
-                            )
                         yield batch
                         batch = {k: [] for k in keys}
                 first_loop = False
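
With this change the dataloader no longer needs to know about devices or gradient accumulation: it yields flat batches of `batch_size` items, and the reshaping into a per-device layout moves into the training loop (see the train.py diff below). A minimal sketch of that contract, using made-up sizes, keys, and dummy arrays rather than the repository's real datasets:

# Illustrative sketch only: sizes, keys and arrays below are invented for the example.
import jax
import numpy as np

dp_devices = 2
gradient_accumulation_steps = 4
per_device_batch_size = 8
batch_size_per_node = per_device_batch_size * dp_devices * gradient_accumulation_steps

# a flat batch, as the simplified dataloader would yield it
batch = {
    "input_ids": np.zeros((batch_size_per_node, 64), dtype=np.int32),
    "labels": np.zeros((batch_size_per_node, 256), dtype=np.int32),
}

# the same reshape the new training loop applies before calling p_train_step
batch = jax.tree_map(
    lambda x: x.reshape(
        (dp_devices, gradient_accumulation_steps, -1) + x.shape[1:]
    ),
    batch,
)

# this is the layout train_step now asserts on
assert batch["labels"].shape[:3] == (
    dp_devices,
    gradient_accumulation_steps,
    per_device_batch_size,
)

Keeping the loader shape-agnostic lets the same `dataloader` serve both training, where the accumulation and data-parallel layout matters, and evaluation, which only needs a flat batch of `eval_batch_size` items.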
tools/train/train.py CHANGED
@@ -36,12 +36,12 @@ import transformers
 import wandb
 from datasets import Dataset
 from distributed_shampoo import GraftingType, distributed_shampoo
-from flax.core.frozen_dict import FrozenDict, freeze
+from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
 from flax.serialization import from_bytes, to_bytes
 from flax.training import train_state
 from flax.training.common_utils import onehot, stack_forest
 from jax.experimental import PartitionSpec, maps
-from jax.experimental.pjit import pjit
+from jax.experimental.pjit import pjit, with_sharding_constraint
 from tqdm import tqdm
 from transformers import HfArgumentParser
 
@@ -551,12 +551,12 @@ def main():
     num_epochs = training_args.num_train_epochs
     # batch size
     minibatch_size = (
-        training_args.per_device_train_batch_size * jax.local_device_count()
+        training_args.per_device_train_batch_size * training_args.dp_devices
     )
     batch_size_per_node = minibatch_size * training_args.gradient_accumulation_steps
     batch_size_per_step = batch_size_per_node * jax.process_count()
     eval_batch_size = (
-        training_args.per_device_eval_batch_size * jax.local_device_count()
+        training_args.per_device_eval_batch_size * training_args.dp_devices
     )
     len_train_dataset, len_eval_dataset = dataset.length
     steps_per_epoch = (
@@ -762,6 +762,10 @@ def main():
     # free memory
     del model._params
 
+    # define batch specs
+    keys = ["attention_mask", "decoder_input_ids", "input_ids", "labels"]
+    batch_spec = freeze({k: PartitionSpec("batch") for k in keys})
+
     # label smoothed cross entropy
     def loss_fn(logits, labels):
         loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1]))
@@ -771,16 +775,18 @@ def main():
     # Define gradient update step fn
     def train_step(state, batch, delta_time):
         # check correct batch shape during compilation
-        assert batch["labels"].shape[0:2] == (
+        assert batch["labels"].shape[0:3] == (
+            training_args.dp_devices,
             training_args.gradient_accumulation_steps,
-            minibatch_size,
-        ), f"Expected label batch of shape gradient_acculumation x minibatch_size x items and got {batch['labels'].shape}"
+            training_args.per_device_train_batch_size,
+        ), f"Expected label batch of shape dp_devices x gradient_acculumation x batch_per_device and got {batch['labels'].shape}"
         # create a new rng
         dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
         # use a different rng per node
         dropout_rng = jax.random.fold_in(dropout_rng, jax.process_index())
 
         def compute_loss(params, minibatch):
+            minibatch = unfreeze(minibatch)
             labels = minibatch.pop("labels")
             logits = state.apply_fn(
                 **minibatch, params=params, dropout_rng=dropout_rng, train=True
@@ -789,32 +795,52 @@ def main():
 
         grad_fn = jax.value_and_grad(compute_loss)
 
-        if training_args.gradient_accumulation_steps == 1:
-            minibatch = jax.tree_map(lambda x: x[0], batch)
-            loss, grads = grad_fn(state.params, minibatch)
-        else:
+        def loss_grad_per_device(device_batch):
+            # device_batch has format (gradient_accumulation_steps, batch_size, ...)
 
-            def _cumul_loss_grads(i, cumul_loss_grads):
-                minibatch = jax.tree_map(lambda x: x[i], batch)
-                return jax.tree_map(
-                    lambda x, y: x + y,
-                    cumul_loss_grads,
-                    grad_fn(state.params, minibatch),
+            if training_args.gradient_accumulation_steps == 1:
+                minibatch = jax.tree_map(
+                    lambda x: x[0],
+                    device_batch,
                 )
+                loss, grads = grad_fn(state.params, minibatch)
+            else:
 
-            init_loss_grads = (
-                0.0,
-                jax.tree_map(jnp.zeros_like, state.params),
-            )
-            loss, grads = jax.tree_map(
-                lambda x: x / training_args.gradient_accumulation_steps,
-                jax.lax.fori_loop(
-                    0,
-                    training_args.gradient_accumulation_steps,
-                    _cumul_loss_grads,
-                    init_loss_grads,
-                ),
-            )
+                def _cumul_loss_grads(i, cumul_loss_grads):
+                    minibatch = jax.tree_map(
+                        lambda x: x[i],
+                        device_batch,
+                    )
+                    return jax.tree_map(
+                        lambda x, y: x + y,
+                        cumul_loss_grads,
+                        grad_fn(state.params, minibatch),
+                    )
+
+                init_loss_grads = (
+                    0.0,
+                    jax.tree_map(jnp.zeros_like, state.params),
+                )
+                loss, grads = jax.tree_map(
+                    lambda x: x / training_args.gradient_accumulation_steps,
+                    jax.lax.fori_loop(
+                        0,
+                        training_args.gradient_accumulation_steps,
+                        _cumul_loss_grads,
+                        init_loss_grads,
+                    ),
+                )
+            return loss, grads
+
+        # calculate loss, grads per dp device
+        # batch has shape (dp_devices, gradient_accumulation_steps, batch_per_dp_device, ...)
+        loss, grads = jax.vmap(loss_grad_per_device, in_axes=0, out_axes=(0, 0))(batch)
+        # enforce sharding constraints to avoid OOM
+        loss = with_sharding_constraint(loss, PartitionSpec("batch"))
+        grads = with_sharding_constraint(grads, PartitionSpec("batch"))
+        # calculate the mean over all devices
+        loss = jnp.mean(loss)
+        grads = jax.tree_map(lambda x: jnp.mean(x, axis=0), grads)
 
         state = state.apply_gradients(
             grads=grads,
@@ -832,6 +858,7 @@ def main():
 
     # Define eval fn
     def eval_step(params, batch):
+        batch = unfreeze(batch)
        labels = batch.pop("labels")
        logits = model(**batch, params=params, train=False)[0]
        loss = loss_fn(logits, labels)
@@ -843,13 +870,13 @@ def main():
     # Create parallel version of the train and eval step
     p_train_step = pjit(
         train_step,
-        in_axis_resources=(state_spec, PartitionSpec(None, "batch"), None),
+        in_axis_resources=(state_spec, batch_spec, None),
         out_axis_resources=(state_spec, None),
         donate_argnums=(0,),
     )
     p_eval_step = pjit(
         eval_step,
-        in_axis_resources=(param_spec, PartitionSpec("batch")),
+        in_axis_resources=(param_spec, batch_spec),
         out_axis_resources=None,
     )
 
@@ -890,9 +917,7 @@ def main():
         # ======================== Evaluating ==============================
         eval_metrics = []
         if training_args.do_eval:
-            eval_loader = dataset.dataloader(
-                "eval", training_args.per_device_eval_batch_size
-            )
+            eval_loader = dataset.dataloader("eval", eval_batch_size)
             eval_steps = (
                 len_eval_dataset // eval_batch_size
                 if len_eval_dataset is not None
@@ -905,8 +930,8 @@ def main():
                 leave=False,
                 total=eval_steps,
             ):
-                # Model forward
-                metrics = p_eval_step(state.params, batch)
+                # TODO: make this more efficient once training loop is fast
+                metrics = p_eval_step(state.params, freeze(batch))
                 eval_metrics.append(metrics)
 
             # normalize eval metrics
@@ -1010,8 +1035,7 @@ def main():
         # Generate an epoch by shuffling sampling indices from the train dataset
         train_loader = dataset.dataloader(
             "train",
-            training_args.per_device_train_batch_size,
-            training_args.gradient_accumulation_steps,
+            batch_size_per_node,
             epoch,
         )
         # train
@@ -1022,15 +1046,27 @@ def main():
             leave=False,
             total=steps_per_epoch,
         ):
-
             # calculate delta time (we have a lag of one step but it's ok)
             new_time = time.perf_counter()
             delta_time = new_time - last_time
             last_time = new_time
 
+            # reshape data into (dp_devices, gradient_accumulation_steps, batch_per_dp_device, ...)
+            batch = jax.tree_map(
+                lambda x: x.reshape(
+                    (
+                        training_args.dp_devices,
+                        training_args.gradient_accumulation_steps,
+                        -1,
+                    )
+                    + x.shape[1:]
+                ),
+                batch,
+            )
+
             # train step
-            state, train_metrics = p_train_step(state, batch, delta_time)
-            step = state.step
+            state, train_metrics = p_train_step(state, freeze(batch), delta_time)
+            step = int(state.step)
 
             if step % training_args.logging_steps == 0 and jax.process_index() == 0:
                 all_metrics = metrics_logger.get_all_train_metrics(
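
The reworked `train_step` follows a per-device pattern: `jax.vmap` a per-device loss/gradient function over the leading `dp_devices` axis, pin intermediate shardings with `with_sharding_constraint`, then average across devices. Below is a minimal sketch of that pattern on a toy loss, written against the `jax.experimental` pjit/maps API this file imports; the mesh layout, `toy_loss`, and all sizes are illustrative assumptions rather than the repository's code, and the real `train_step` accumulates gradients with `jax.lax.fori_loop` instead of the inner `vmap` used here.

# Minimal, illustrative sketch (not the repository's code): data-parallel pjit pattern.
import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental import PartitionSpec, maps
from jax.experimental.pjit import pjit, with_sharding_constraint


def toy_loss(params, minibatch):
    # stand-in for the model forward pass + loss
    return jnp.mean((minibatch * params) ** 2)


def step(params, batch):
    # batch: (dp_devices, gradient_accumulation_steps, per_device_batch, features)
    def per_device(device_batch):
        # average over the accumulation axis for one device slice
        losses = jax.vmap(lambda mb: toy_loss(params, mb))(device_batch)
        return jnp.mean(losses)

    loss = jax.vmap(per_device)(batch)  # one loss per dp device
    loss = with_sharding_constraint(loss, PartitionSpec("batch"))
    return jnp.mean(loss)  # mean over dp devices


dp_devices = jax.local_device_count()
params = jnp.float32(0.5)
batch = np.ones((dp_devices, 2, 4, 8), dtype=np.float32)

p_step = pjit(
    step,
    in_axis_resources=(None, PartitionSpec("batch")),
    out_axis_resources=None,
)

# data parallelism over a 1-D "batch" mesh axis spanning the local devices
with maps.Mesh(np.array(jax.local_devices()), ("batch",)):
    print(p_step(params, batch))

As the diff's own comment notes, the sharding constraints are there to avoid OOM: they keep the per-device losses and gradients sharded along the "batch" mesh axis until they are reduced to their mean.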