boris commited on
Commit
0952927
1 Parent(s): 1bb3269

feat(train) - handle multiple nodes (#130)

Browse files
Files changed (2) hide show
  1. src/dalle_mini/data.py +1 -1
  2. tools/train/train.py +69 -54
src/dalle_mini/data.py CHANGED
@@ -94,7 +94,7 @@ class Dataset:
94
  if self.streaming:
95
  # we need to shuffle early in streaming mode
96
  if hasattr(self, "train_dataset"):
97
- self.train_dataset = self.train_dataset.shuffle(1000, self.seed_dataset)
98
  else:
99
  # prepare rng for later shuffling
100
  if self.seed_dataset is None:
 
94
  if self.streaming:
95
  # we need to shuffle early in streaming mode
96
  if hasattr(self, "train_dataset"):
97
+ self.train_dataset = self.train_dataset.shuffle(5000, self.seed_dataset)
98
  else:
99
  # prepare rng for later shuffling
100
  if self.seed_dataset is None:
tools/train/train.py CHANGED
@@ -1,6 +1,6 @@
1
  #!/usr/bin/env python
2
  # coding=utf-8
3
- # Copyright 2021 The HuggingFace Team All rights reserved.
4
  #
5
  # Licensed under the Apache License, Version 2.0 (the "License");
6
  # you may not use this file except in compliance with the License.
@@ -14,7 +14,7 @@
14
  # See the License for the specific language governing permissions and
15
  # limitations under the License.
16
  """
17
- Fine-tuning the library models for seq2seq, text to image.
18
  Script adapted from run_summarization_flax.py
19
  """
20
 
@@ -527,23 +527,29 @@ def main():
527
  dataset.preprocess(tokenizer=tokenizer, config=model.config)
528
 
529
  # Initialize our training
530
- rng = jax.random.PRNGKey(training_args.seed_model)
531
- rng, dropout_rng = jax.random.split(rng)
532
 
533
  # Store some constant
534
  num_epochs = training_args.num_train_epochs
535
  # batch size
536
- minibatch_size = (
537
- training_args.per_device_train_batch_size * training_args.dp_devices
 
 
 
 
 
538
  )
539
- batch_size_per_node = minibatch_size * training_args.gradient_accumulation_steps
540
  batch_size_per_step = batch_size_per_node * jax.process_count()
541
- eval_batch_size = (
542
- training_args.per_device_eval_batch_size * training_args.dp_devices
 
 
543
  )
 
544
  len_train_dataset, len_eval_dataset = dataset.length
545
  steps_per_epoch = (
546
- len_train_dataset // batch_size_per_node
547
  if len_train_dataset is not None
548
  else None
549
  )
@@ -763,13 +769,21 @@ def main():
763
 
764
  # Define gradient update step fn
765
  def train_step(state, batch, delta_time):
766
- # batch is (gradient_accumulation_steps, minibatch_size, ...)
767
- # check correct batch shape during compilation
768
- assert batch["labels"].shape[0:3] == (
769
- training_args.gradient_accumulation_steps,
770
- training_args.dp_devices,
771
- training_args.per_device_train_batch_size,
772
- ), f"Expected label batch of shape dp_devices x gradient_acculumation x batch_per_device and got {batch['labels'].shape}"
 
 
 
 
 
 
 
 
773
 
774
  # get a minibatch (one gradient accumulation slice)
775
  def get_minibatch(batch, grad_idx):
@@ -791,54 +805,45 @@ def main():
791
  def loss_and_grad(grad_idx, dropout_rng):
792
  # minibatch at grad_idx, shape (dp_devices, per_device_train_batch_size, ...)
793
  minibatch = get_minibatch(batch, grad_idx)
794
- # ensure batch is sharded over devices
 
 
795
  minibatch = jax.tree_map(
796
- lambda x: with_sharding_constraint(x, PartitionSpec("batch")), minibatch
 
797
  )
798
- # calculate loss and grads independently per dp_device
799
  loss_grads = jax.vmap(grad_fn, in_axes=(None, 0, None), out_axes=(0, 0))(
800
  state.params, minibatch, dropout_rng
801
  )
802
- # ensure they are sharded over devices
803
  loss_grads = jax.tree_map(
804
  lambda x: with_sharding_constraint(x, PartitionSpec("batch")),
805
  loss_grads,
806
  )
807
-
808
  # average across all devices
809
  loss_grads = jax.tree_map(lambda x: jnp.mean(x, axis=0), loss_grads)
810
-
811
  # return loss and grads
812
- return loss_grads
813
-
814
- # create a new rng
815
- dropout_rng, _ = jax.random.split(state.dropout_rng)
816
- # use a different rng per node
817
- dropout_rng = jax.random.fold_in(dropout_rng, jax.process_index())
818
 
819
  if training_args.gradient_accumulation_steps == 1:
820
-
821
- def batch_step(dropout_rng):
822
- dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
823
- loss_grad = loss_and_grad(0, dropout_rng)
824
- return loss_grad, new_dropout_rng
825
-
826
- loss_grad, dropout_rng = batch_step(dropout_rng)
827
  else:
828
- # create initial state for per_minibatch_step loop
829
- init_cumul_loss_grad = (
830
- 0.0,
831
- jax.tree_map(jnp.zeros_like, state.params),
 
 
 
832
  )
833
- init_minibatch_step = (init_cumul_loss_grad, dropout_rng)
834
 
835
  # accumulate gradients
836
  def cumul_minibatch_step(grad_idx, cumul_loss_grad_dropout):
837
  cumul_loss_grad, dropout_rng = cumul_loss_grad_dropout
838
- dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
839
- loss_grad = loss_and_grad(grad_idx, dropout_rng)
840
  cumul_loss_grad = jax.tree_map(jnp.add, cumul_loss_grad, loss_grad)
841
- return cumul_loss_grad, new_dropout_rng
842
 
843
  # loop over gradients
844
  loss_grad, dropout_rng = jax.lax.fori_loop(
@@ -870,6 +875,20 @@ def main():
870
 
871
  # Define eval fn
872
  def eval_step(state, batch):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
873
  def compute_eval_loss(batch):
874
  batch, labels = batch.pop("labels")
875
  logits = state.apply_fn(**batch, params=state.params, train=False)[0]
@@ -936,9 +955,9 @@ def main():
936
  def run_evaluation():
937
  # ======================== Evaluating ==============================
938
  if training_args.do_eval:
939
- eval_loader = dataset.dataloader("eval", eval_batch_size)
940
  eval_steps = (
941
- len_eval_dataset // eval_batch_size
942
  if len_eval_dataset is not None
943
  else None
944
  )
@@ -950,17 +969,14 @@ def main():
950
  leave=False,
951
  total=eval_steps,
952
  ):
953
- # reshape data into (dp_devices, batch_per_dp, ...)
954
  batch = jax.tree_map(
955
  lambda x: x.reshape(
956
- (
957
- training_args.dp_devices,
958
- training_args.per_device_eval_batch_size,
959
- )
960
- + x.shape[1:]
961
  ),
962
  batch,
963
  )
 
964
  # freeze batch to pass safely to jax transforms
965
  batch = freeze(batch)
966
  # accumulate losses async
@@ -1081,8 +1097,7 @@ def main():
1081
  lambda x: x.reshape(
1082
  (
1083
  training_args.gradient_accumulation_steps,
1084
- training_args.dp_devices,
1085
- training_args.per_device_train_batch_size,
1086
  )
1087
  + x.shape[1:]
1088
  ),
 
1
  #!/usr/bin/env python
2
  # coding=utf-8
3
+ # Copyright 2021-2022 The HuggingFace & DALL·E Mini Team All rights reserved.
4
  #
5
  # Licensed under the Apache License, Version 2.0 (the "License");
6
  # you may not use this file except in compliance with the License.
 
14
  # See the License for the specific language governing permissions and
15
  # limitations under the License.
16
  """
17
+ Training DALL·E Mini.
18
  Script adapted from run_summarization_flax.py
19
  """
20
 
 
527
  dataset.preprocess(tokenizer=tokenizer, config=model.config)
528
 
529
  # Initialize our training
530
+ dropout_rng = jax.random.PRNGKey(training_args.seed_model)
 
531
 
532
  # Store some constant
533
  num_epochs = training_args.num_train_epochs
534
  # batch size
535
+ batch_size_per_node_per_grad_step = (
536
+ training_args.per_device_train_batch_size
537
+ * jax.local_device_count()
538
+ // training_args.mp_devices
539
+ )
540
+ batch_size_per_node = (
541
+ batch_size_per_node_per_grad_step * training_args.gradient_accumulation_steps
542
  )
 
543
  batch_size_per_step = batch_size_per_node * jax.process_count()
544
+ eval_batch_size_per_node = (
545
+ training_args.per_device_eval_batch_size
546
+ * jax.local_device_count()
547
+ // training_args.mp_devices
548
  )
549
+ eval_batch_size_per_step = eval_batch_size_per_node * jax.process_count()
550
  len_train_dataset, len_eval_dataset = dataset.length
551
  steps_per_epoch = (
552
+ len_train_dataset // batch_size_per_step
553
  if len_train_dataset is not None
554
  else None
555
  )
 
769
 
770
  # Define gradient update step fn
771
  def train_step(state, batch, delta_time):
772
+ # we reshape to (gradient_accumulation_steps, dp_devices, ...)
773
+ # allows feeding partial batch size per node for full model parallel
774
+ batch = jax.tree_map(
775
+ lambda x: x.reshape(
776
+ (
777
+ training_args.gradient_accumulation_steps,
778
+ training_args.dp_devices,
779
+ training_args.per_device_train_batch_size,
780
+ )
781
+ + x.shape[2:]
782
+ ),
783
+ batch,
784
+ )
785
+ # ensure data is sharded correctly per dp device
786
+ batch = with_sharding_constraint(batch, grad_batch_spec)
787
 
788
  # get a minibatch (one gradient accumulation slice)
789
  def get_minibatch(batch, grad_idx):
 
805
  def loss_and_grad(grad_idx, dropout_rng):
806
  # minibatch at grad_idx, shape (dp_devices, per_device_train_batch_size, ...)
807
  minibatch = get_minibatch(batch, grad_idx)
808
+ # calculate loss and grads independently per dp_device
809
+ dropout_rng, _ = jax.random.split(dropout_rng)
810
+ # ensure inputs are sharded per device
811
  minibatch = jax.tree_map(
812
+ lambda x: with_sharding_constraint(x, PartitionSpec("batch")),
813
+ minibatch,
814
  )
815
+ # only 1 single rng per grad step, let us handle larger batch size
816
  loss_grads = jax.vmap(grad_fn, in_axes=(None, 0, None), out_axes=(0, 0))(
817
  state.params, minibatch, dropout_rng
818
  )
819
+ # ensure outputs are sharded per device
820
  loss_grads = jax.tree_map(
821
  lambda x: with_sharding_constraint(x, PartitionSpec("batch")),
822
  loss_grads,
823
  )
 
824
  # average across all devices
825
  loss_grads = jax.tree_map(lambda x: jnp.mean(x, axis=0), loss_grads)
 
826
  # return loss and grads
827
+ return loss_grads, dropout_rng
 
 
 
 
 
828
 
829
  if training_args.gradient_accumulation_steps == 1:
830
+ loss_grad, dropout_rng = loss_and_grad(0, state.dropout_rng)
 
 
 
 
 
 
831
  else:
832
+ # create initial state for cumul_minibatch_step loop
833
+ init_minibatch_step = (
834
+ (
835
+ 0.0,
836
+ jax.tree_map(jnp.zeros_like, state.params),
837
+ ),
838
+ state.dropout_rng,
839
  )
 
840
 
841
  # accumulate gradients
842
  def cumul_minibatch_step(grad_idx, cumul_loss_grad_dropout):
843
  cumul_loss_grad, dropout_rng = cumul_loss_grad_dropout
844
+ loss_grad, dropout_rng = loss_and_grad(grad_idx, dropout_rng)
 
845
  cumul_loss_grad = jax.tree_map(jnp.add, cumul_loss_grad, loss_grad)
846
+ return cumul_loss_grad, dropout_rng
847
 
848
  # loop over gradients
849
  loss_grad, dropout_rng = jax.lax.fori_loop(
 
875
 
876
  # Define eval fn
877
  def eval_step(state, batch):
878
+ # we reshape to (dp_devices, ...)
879
+ batch = jax.tree_map(
880
+ lambda x: x.reshape(
881
+ (
882
+ training_args.dp_devices,
883
+ training_args.per_device_eval_batch_size,
884
+ )
885
+ + x.shape[1:]
886
+ ),
887
+ batch,
888
+ )
889
+ # ensure data is sharded correctly per dp device
890
+ batch = with_sharding_constraint(batch, batch_spec)
891
+
892
  def compute_eval_loss(batch):
893
  batch, labels = batch.pop("labels")
894
  logits = state.apply_fn(**batch, params=state.params, train=False)[0]
 
955
  def run_evaluation():
956
  # ======================== Evaluating ==============================
957
  if training_args.do_eval:
958
+ eval_loader = dataset.dataloader("eval", eval_batch_size_per_step)
959
  eval_steps = (
960
+ len_eval_dataset // eval_batch_size_per_step
961
  if len_eval_dataset is not None
962
  else None
963
  )
 
969
  leave=False,
970
  total=eval_steps,
971
  ):
972
+ # need to keep only eval_batch_size_per_node items relevant to the node
973
  batch = jax.tree_map(
974
  lambda x: x.reshape(
975
+ (jax.process_count(), eval_batch_size_per_node) + x.shape[1:]
 
 
 
 
976
  ),
977
  batch,
978
  )
979
+ batch = jax.tree_map(lambda x: x[jax.process_index()], batch)
980
  # freeze batch to pass safely to jax transforms
981
  batch = freeze(batch)
982
  # accumulate losses async
 
1097
  lambda x: x.reshape(
1098
  (
1099
  training_args.gradient_accumulation_steps,
1100
+ batch_size_per_node_per_grad_step,
 
1101
  )
1102
  + x.shape[1:]
1103
  ),