boris committed
Commit 0081723
1 Parent(s): a5ed112

feat(train): start pjit support

Files changed (1)
  1. tools/train/train.py +144 -106
tools/train/train.py CHANGED
@@ -30,21 +30,30 @@ from typing import Callable, Optional
 import datasets
 import jax
 import jax.numpy as jnp
+import numpy as np
 import optax
 import transformers
-import wandb
 from datasets import Dataset
-from distributed_shampoo import GraftingType, distributed_shampoo
+from distributed_shampoo import GraftingType, distributed_shampoo, pad_matrix
 from flax import jax_utils, traverse_util
+from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
 from flax.jax_utils import unreplicate
 from flax.serialization import from_bytes, to_bytes
 from flax.training import train_state
 from flax.training.common_utils import get_metrics, onehot, shard_prng_key
+from jax.experimental import PartitionSpec, maps
+from jax.experimental.pjit import pjit
 from tqdm import tqdm
 from transformers import AutoTokenizer, HfArgumentParser

+import wandb
 from dalle_mini.data import Dataset
-from dalle_mini.model import DalleBart, DalleBartConfig, DalleBartTokenizer
+from dalle_mini.model import (
+    DalleBart,
+    DalleBartConfig,
+    DalleBartTokenizer,
+    set_partitions,
+)

 logger = logging.getLogger(__name__)

@@ -223,7 +232,6 @@ class TrainingArguments:
             "help": 'The optimizer to use. Can be "distributed_shampoo" (default), "adam" or "adafactor"'
         },
     )
-    weight_decay: float = field(default=None, metadata={"help": "Weight decay."})
     beta1: float = field(
         default=0.9,
         metadata={"help": "Beta1 for Adam & Distributed Shampoo."},
@@ -332,6 +340,13 @@ class TrainingArguments:
         metadata={"help": "Verify that TPU is not in use."},
     )

+    mp_devices: Optional[int] = field(
+        default=1,
+        metadata={
+            "help": "Number of devices required for model parallelism. The other dimension of available devices is used for data parallelism."
+        },
+    )
+
     def __post_init__(self):
         assert self.optim in [
             "distributed_shampoo",
@@ -340,9 +355,6 @@
         ], f"Selected optimizer not supported: {self.optim}"
         if self.per_device_eval_batch_size is None:
             self.per_device_eval_batch_size = self.per_device_train_batch_size
-        if self.weight_decay is None:
-            if self.optim in ["distributed_shampoo", "adam"]:
-                self.weight_decay = 0.0
         if (
             os.path.exists(self.output_dir)
             and os.listdir(self.output_dir)
@@ -353,6 +365,10 @@
                 f"Output directory ({self.output_dir}) already exists and is not empty."
                 "Use --overwrite_output_dir to overcome."
             )
+        assert (
+            jax.device_count() % self.mp_devices == 0
+        ), f"Number of available devices ({jax.device_count()} must be divisible by number of devices used for model parallelism ({self.mp_devices})."
+        self.dp_devices = jax.device_count() // self.mp_devices


 class TrainState(train_state.TrainState):
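Note (illustration only, not part of the commit): the new mp_devices field, together with the dp_devices value derived in __post_init__ above, defines a two-dimensional device mesh used later in this diff. A minimal standalone sketch of that mapping, assuming jax.device_count() is divisible by the chosen mp_devices and the jax.experimental.maps API used in this file:

# Sketch: turn dp_devices x mp_devices into the ("batch", "mp") mesh.
import numpy as np
import jax
from jax.experimental import maps

mp_devices = 2  # hypothetical value for the sketch
assert jax.device_count() % mp_devices == 0
dp_devices = jax.device_count() // mp_devices
devices = np.asarray(jax.devices()).reshape(dp_devices, mp_devices)
mesh = maps.Mesh(devices, ("batch", "mp"))
print(dict(mesh.shape))  # e.g. {'batch': 4, 'mp': 2} on an 8-device host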
@@ -361,11 +377,6 @@ class TrainState(train_state.TrainState):
     train_time: float = 0.0  # total time the model trained
     train_samples: int = 0  # number of samples seen

-    def replicate(self):
-        return jax_utils.replicate(self).replace(
-            dropout_rng=shard_prng_key(self.dropout_rng)
-        )
-
     def restore_state(self, artifact_dir):
         # restore optimizer state
         with (Path(artifact_dir) / "opt_state.msgpack").open("rb") as f:
@@ -487,8 +498,6 @@ def main():
             dtype=getattr(jnp, model_args.dtype),
             abstract_init=True,
         )
-        # avoid OOM on TPU: see https://github.com/google/flax/issues/1658
-        print(model.params)

         # load tokenizer
         tokenizer = DalleBartTokenizer.from_pretrained(
@@ -512,8 +521,6 @@ def main():
             dtype=getattr(jnp, model_args.dtype),
             abstract_init=True,
         )
-        # avoid OOM on TPU: see https://github.com/google/flax/issues/1658
-        print(model.params)
     else:
         model = DalleBart(
             config,
@@ -544,7 +551,7 @@ def main():

     # Initialize our training
     rng = jax.random.PRNGKey(training_args.seed_model)
-    rng, dropout_rng = jax.random.split(rng)
+    rng, *dropout_rng = jax.random.split(rng, num=training_args.dp_devices + 1)

     # Store some constant
     num_epochs = training_args.num_train_epochs
@@ -601,32 +608,9 @@ def main():

     learning_rate_fn = create_learning_rate_fn()

-    # We use Optax's "masking" functionality to not apply weight decay
-    # to bias and LayerNorm scale parameters. decay_mask_fn returns a
-    # mask boolean with the same structure as the parameters.
-    # The mask is True for parameters that should be decayed.
-    # Note that this mask is specifically adapted for FlaxBart.
-    def decay_mask_fn(params):
-        flat_params = traverse_util.flatten_dict(params)
-        layer_norm_params = [
-            (name, "scale")
-            for name in [
-                "self_attn_layer_norm",
-                "layernorm_embedding",
-                "final_layer_norm",
-            ]
-        ]
-        flat_mask = {
-            path: (path[-1] != "bias" and path[-2:] not in layer_norm_params)
-            for path in flat_params
-        }
-        return traverse_util.unflatten_dict(flat_mask)
-
     # create adam optimizer
     if training_args.optim == "distributed_shampoo":
         # parameters from https://github.com/tensorflow/lingvo/blob/03ee9d7cd50764b0424c7c863733c91fc0b053ec/lingvo/jax/optimizers.py#L729
-        # Notes:
-        # - mask for weight decay is not implemented
         optimizer = distributed_shampoo(
             learning_rate_fn,
             block_size=training_args.block_size,
@@ -634,7 +618,6 @@
             beta2=training_args.beta2,
             diagonal_epsilon=1e-10,
             matrix_epsilon=1e-8,
-            weight_decay=training_args.weight_decay,
             start_preconditioning_step=training_args.warmup_steps,
             preconditioning_compute_steps=training_args.preconditioning_compute_steps,
             statistics_compute_steps=1,
@@ -657,26 +640,76 @@
             b1=training_args.beta1,
             b2=training_args.beta2,
             eps=training_args.adam_epsilon,
-            weight_decay=training_args.weight_decay,
-            mask=decay_mask_fn,
         )
     elif training_args.optim == "adafactor":
         # We use the default parameters here to initialize adafactor,
         # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
         optimizer = optax.adafactor(
             learning_rate=learning_rate_fn,
-            weight_decay_rate=training_args.weight_decay,
-            weight_decay_mask=decay_mask_fn,
             clipping_threshold=training_args.max_grad_norm,
         )

+    # get opt_state shape without actual init
+    param_shape = jax.tree_map(lambda x: x.shape, model.params)
+    opt_state_shape = jax.eval_shape(lambda x: optimizer.init(x), param_shape)
+
+    # get PartitionSpec for model params
+    param_spec = set_partitions(model.params)
+
+    # create PartitionSpec for opt_state
+    def opt_state_spec_per_leaf(x):
+        if training_args.optim in ["adam", "adafactor"]:
+            if isinstance(x, dict):
+                # variables with same structure as params
+                return param_spec
+            else:
+                # other variables such as count
+                return None
+        else:
+            # TODO: create spec for Distributed Shampoo
+            raise NotImplementedError
+
+    opt_state_spec = jax.tree_map(
+        opt_state_spec_per_leaf,
+        opt_state_shape,
+        # return None spec for empty elements
+        is_leaf=lambda x: isinstance(x, (dict, optax.EmptyState)),
+    )
+
+    # create a mesh
+    mesh_shape = (training_args.dp_devices, training_args.mp_devices)
+    devices = np.asarray(jax.devices()).reshape(*mesh_shape)
+    mesh = maps.Mesh(devices, ("batch", "mp"))
+
+    # move params & init opt_state over specified devices
+    with maps.mesh(mesh.devices, mesh.axis_names):
+        params, opt_state = pjit(
+            lambda x: (x, optimizer.init(x)),
+            in_axis_resources=None,
+            out_axis_resources=(param_spec, opt_state_spec),
+        )(freeze(model.params))
+
     # Setup train state
-    state = TrainState.create(
+    state = TrainState(
         apply_fn=model.__call__,
-        params=model.params,
+        params=params,
+        opt_state=opt_state,
         tx=optimizer,
         dropout_rng=dropout_rng,
+        step=0,
     )
+
+    # create PartitionSpec for state
+    state_spec = {
+        "params": param_spec,
+        "opt_state": opt_state_spec,
+        "dropout_rng": PartitionSpec("batch", None),
+        "epoch": None,
+        "step": None,
+        "train_samples": None,
+        "train_time": None,
+    }
+
     if training_args.resume_from_checkpoint is not None:
         # restore optimizer state and other parameters
         # we currently ignore partial epoch training: see https://github.com/borisdayma/dalle-mini/issues/105
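Note (illustration only, not part of the commit): the hunk above obtains the optimizer-state pytree without materializing it by tracing optimizer.init with jax.eval_shape, so the PartitionSpec tree can be built before any device memory is allocated. A minimal standalone sketch of the same idea, where a toy parameter tree and optax.adamw stand in for the real model and optimizer:

# Sketch: shape-inference for optimizer state, no allocation.
import jax
import jax.numpy as jnp
import optax

# Toy "parameters" described only by shape/dtype.
params_shape = {
    "dense": {"kernel": jax.ShapeDtypeStruct((1024, 1024), jnp.float32)},
}
optimizer = optax.adamw(learning_rate=1e-3)

# eval_shape traces optimizer.init abstractly: no 1024x1024 buffers are created.
opt_state_shape = jax.eval_shape(optimizer.init, params_shape)
print(jax.tree_util.tree_map(lambda x: x.shape, opt_state_shape))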
@@ -756,8 +789,17 @@ def main():
         return metrics

     # Create parallel version of the train and eval step
-    p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
-    p_eval_step = jax.pmap(eval_step, "batch")
+    p_train_step = pjit(
+        train_step,
+        in_axis_resources=(state_spec, None, None),
+        out_axis_resources=(state_spec, None),
+        donate_argnums=(0,),
+    )
+    p_eval_step = pjit(
+        eval_step,
+        in_axis_resources=(param_spec, PartitionSpec("batch", None)),
+        out_axis_resources=None,
+    )

     logger.info("***** Running training *****")
     logger.info(f"  Num examples = {len_train_dataset}")
@@ -792,9 +834,6 @@ def main():
         }
     )

-    # replicate state on each device
-    state = state.replicate()
-
     def run_evaluation():
         # ======================== Evaluating ==============================
         eval_metrics = []
@@ -823,9 +862,7 @@
         eval_metrics = jax.tree_map(jnp.mean, eval_metrics)

         # log metrics
-        metrics_logger.log(
-            eval_metrics, step=unreplicate(state.step), prefix="eval"
-        )
+        metrics_logger.log(eval_metrics, step=state.step, prefix="eval")

         # Print metrics and update progress bar
         desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']})"
@@ -836,7 +873,7 @@

     def run_save_model(state, eval_metrics=None):
         if jax.process_index() == 0:
-            params = jax.device_get(unreplicate(state.params))
+            params = jax.device_get(state.params)
             # save model locally
             model.save_pretrained(
                 training_args.output_dir,
@@ -847,11 +884,11 @@
             tokenizer.save_pretrained(training_args.output_dir)

             # save state
-            opt_state = unreplicate(state.opt_state)
+            opt_state = jax.device_get(state.opt_state)
             with (Path(training_args.output_dir) / "opt_state.msgpack").open("wb") as f:
                 f.write(to_bytes(opt_state))
             state_dict = {
-                k: jax.device_get(unreplicate(getattr(state, k))).item()
+                k: jax.device_get(getattr(state, k)).item()
                 for k in ["step", "epoch", "train_time", "train_samples"]
             }
             with (Path(training_args.output_dir) / "training_state.json").open(
@@ -912,63 +949,64 @@ def main():
     last_time = time.perf_counter()
     train_metrics = None

-    for epoch in epochs:
-        state.replace(epoch=jax_utils.replicate(epoch))
-        # ======================== Training ================================
-        metrics_logger.log({"train/epoch": epoch}, step=unreplicate(state.step))
-
-        # Generate an epoch by shuffling sampling indices from the train dataset
-        train_loader = dataset.dataloader(
-            "train",
-            training_args.per_device_train_batch_size,
-            training_args.gradient_accumulation_steps,
-            epoch,
-        )
-        # train
-        for batch in tqdm(
-            train_loader,
-            desc="Training...",
-            position=1,
-            leave=False,
-            total=steps_per_epoch,
-        ):
+    with maps.mesh(mesh.devices, mesh.axis_names):
+        for epoch in epochs:
+            state.replace(epoch=epoch)
+            # ======================== Training ================================
+            metrics_logger.log({"train/epoch": epoch}, step=state.step)
+
+            # Generate an epoch by shuffling sampling indices from the train dataset
+            train_loader = dataset.dataloader(
+                "train",
+                training_args.per_device_train_batch_size,
+                training_args.gradient_accumulation_steps,
+                epoch,
+            )
+            # train
+            for batch in tqdm(
+                train_loader,
+                desc="Training...",
+                position=1,
+                leave=False,
+                total=steps_per_epoch,
+            ):

-            # calculate delta time (we have a lag of one step but it's ok)
-            new_time = time.perf_counter()
-            delta_time = new_time - last_time
-            last_time = new_time
+                # calculate delta time (we have a lag of one step but it's ok)
+                new_time = time.perf_counter()
+                delta_time = new_time - last_time
+                last_time = new_time

-            # train step
-            state, train_metrics = p_train_step(
-                state, batch, jax_utils.replicate(delta_time)
-            )
-            step = unreplicate(state.step)
+                # train step
+                state, train_metrics = p_train_step(state, batch, delta_time)
+                step = state.step

-            if step % training_args.logging_steps == 0 and jax.process_index() == 0:
-                all_metrics = metrics_logger.get_all_train_metrics(train_metrics, state)
-                metrics_logger.log(all_metrics, step=step, prefix="train")
+                if step % training_args.logging_steps == 0 and jax.process_index() == 0:
+                    all_metrics = metrics_logger.get_all_train_metrics(
+                        train_metrics, state
+                    )
+                    metrics_logger.log(all_metrics, step=step, prefix="train")

-            eval_metrics = None
-            if training_args.eval_steps and step % training_args.eval_steps == 0:
-                eval_metrics = run_evaluation()
+                eval_metrics = None
+                if training_args.eval_steps and step % training_args.eval_steps == 0:
+                    eval_metrics = run_evaluation()

-            if step % training_args.save_steps == 0:
-                run_save_model(state, eval_metrics)
+                if step % training_args.save_steps == 0:
+                    run_save_model(state, eval_metrics)

-        # log final train metrics
-        if train_metrics is not None:
-            all_metrics = metrics_logger.get_all_train_metrics(train_metrics, state)
-            metrics_logger.log(all_metrics, step=step, prefix="train")
+            # log final train metrics
+            if train_metrics is not None:
+                all_metrics = metrics_logger.get_all_train_metrics(train_metrics, state)
+                metrics_logger.log(all_metrics, step=step, prefix="train")

-        epochs.write(
-            f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metrics['loss']}, Learning Rate: {train_metrics['learning_rate']})"
-        )
+            epochs.write(
+                f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metrics['loss']}, Learning Rate: {train_metrics['learning_rate']})"
+            )

-        # Final evaluation
-        eval_metrics = run_evaluation()
+            # Final evaluation
+            eval_metrics = run_evaluation()

-        # save checkpoint after each epoch
-        run_save_model(state, eval_metrics)
+            # save checkpoint after each epoch
+            run_save_model(state, eval_metrics)


 if __name__ == "__main__":