boris commited on
Commit
cc34d07
1 Parent(s): f5239e1

feat(train): distributed_shampoo with pjit

Browse files
src/dalle_mini/model/modeling.py CHANGED
@@ -312,7 +312,7 @@ class FlaxBartPreTrainedModel(FlaxBartPreTrainedModel):
312
  seed: int = 0,
313
  dtype: jnp.dtype = jnp.float32,
314
  abstract_init: bool = False,
315
- load_on_cpu: bool = True,
316
  **kwargs,
317
  ):
318
  module = self.module_class(config=config, dtype=dtype, **kwargs)
 
312
  seed: int = 0,
313
  dtype: jnp.dtype = jnp.float32,
314
  abstract_init: bool = False,
315
+ load_on_cpu: bool = False,
316
  **kwargs,
317
  ):
318
  module = self.module_class(config=config, dtype=dtype, **kwargs)
tools/train/train.py CHANGED
@@ -36,7 +36,7 @@ import transformers
36
  import wandb
37
  from datasets import Dataset
38
  from distributed_shampoo import GraftingType, distributed_shampoo
39
- from flax.core.frozen_dict import freeze
40
  from flax.serialization import from_bytes, to_bytes
41
  from flax.training import train_state
42
  from flax.training.common_utils import onehot, stack_forest
@@ -478,6 +478,7 @@ def main():
478
  artifact_dir,
479
  dtype=getattr(jnp, model_args.dtype),
480
  abstract_init=True,
 
481
  )
482
 
483
  # load tokenizer
@@ -501,12 +502,14 @@ def main():
501
  seed=training_args.seed_model,
502
  dtype=getattr(jnp, model_args.dtype),
503
  abstract_init=True,
 
504
  )
505
  else:
506
  model = DalleBart(
507
  config,
508
  seed=training_args.seed_model,
509
  dtype=getattr(jnp, model_args.dtype),
 
510
  )
511
 
512
  # Load tokenizer
@@ -606,7 +609,10 @@ def main():
606
  graft_type=GraftingType.RMSPROP_NORMALIZED,
607
  nesterov=False,
608
  exponent_override=0,
609
- batch_axis_name="batch",
 
 
 
610
  inverse_failure_threshold=0.1,
611
  moving_average_for_momentum=True,
612
  skip_preconditioning_dim_size_gt=training_args.skip_preconditioning_dim_size_gt,
@@ -630,31 +636,48 @@ def main():
630
  clipping_threshold=training_args.max_grad_norm,
631
  )
632
 
633
- # get opt_state shape without actual init
634
- opt_state_shape = jax.eval_shape(lambda x: optimizer.init(x), model.params)
635
-
636
  # get PartitionSpec for model params
637
  param_spec = set_partitions(model.params)
638
 
639
- # create PartitionSpec for opt_state
640
- def opt_state_spec_per_leaf(x):
641
- if training_args.optim in ["adam", "adafactor"]:
642
- if isinstance(x, dict):
643
- # variables with same structure as params
644
- return param_spec
645
- else:
646
- # other variables such as count
647
- return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
648
  else:
649
- # TODO: create spec for Distributed Shampoo
650
  raise NotImplementedError
 
651
 
652
- opt_state_spec = jax.tree_map(
653
- opt_state_spec_per_leaf,
654
- opt_state_shape,
655
- # return None spec for empty elements
656
- is_leaf=lambda x: isinstance(x, (dict, optax.EmptyState)),
657
- )
658
 
659
  # create a mesh
660
  mesh_shape = (training_args.dp_devices, training_args.mp_devices)
@@ -674,51 +697,62 @@ def main():
674
  tx=optimizer,
675
  )
676
 
677
- opt_state, attr_state = None, None
678
- if training_args.resume_from_checkpoint is not None:
679
- # restore opt_state
680
- with (Path(artifact_dir) / "opt_state.msgpack").open("rb") as f:
681
- opt_state = from_bytes(opt_state_shape, f.read())
682
- # need to freeze dict for pjit
683
- opt_state = jax.tree_map(
684
- lambda x: freeze(x) if isinstance(x, dict) else x,
685
- opt_state,
686
- is_leaf=lambda x: isinstance(x, (dict, optax.EmptyState)),
687
- )
688
- # restore other attributes
689
- with (Path(artifact_dir) / "training_state.json").open("r") as f:
690
- attr_state = json.load(f)
691
-
692
  # create training state
693
- def init_state(params, opt_state):
694
  if training_args.resume_from_checkpoint is None:
695
- state = TrainState.create(
696
- apply_fn=model.__call__,
697
- tx=optimizer,
698
- params=freeze(params),
699
- dropout_rng=dropout_rng,
700
- )
 
 
 
 
 
 
 
 
 
 
701
  else:
702
- state = TrainState(
703
- apply_fn=model.__call__,
704
- tx=optimizer,
705
- params=freeze(params),
706
- opt_state=opt_state,
707
- dropout_rng=dropout_rng,
708
- **attr_state,
709
- )
710
- return state
711
 
712
- with maps.mesh(mesh.devices, mesh.axis_names):
713
- state = pjit(
714
- init_state,
715
- in_axis_resources=(param_spec, opt_state_spec),
716
- out_axis_resources=state_spec,
717
- donate_argnums=(0, 1),
718
- )(freeze(model.params), opt_state)
719
-
720
- # free memory from large parameters
721
- del model._params, opt_state
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
722
 
723
  # label smoothed cross entropy
724
  def loss_fn(logits, labels):
 
36
  import wandb
37
  from datasets import Dataset
38
  from distributed_shampoo import GraftingType, distributed_shampoo
39
+ from flax.core.frozen_dict import freeze, unfreeze
40
  from flax.serialization import from_bytes, to_bytes
41
  from flax.training import train_state
42
  from flax.training.common_utils import onehot, stack_forest
 
478
  artifact_dir,
479
  dtype=getattr(jnp, model_args.dtype),
480
  abstract_init=True,
481
+ load_on_cpu=True,
482
  )
483
 
484
  # load tokenizer
 
502
  seed=training_args.seed_model,
503
  dtype=getattr(jnp, model_args.dtype),
504
  abstract_init=True,
505
+ load_on_cpu=True,
506
  )
507
  else:
508
  model = DalleBart(
509
  config,
510
  seed=training_args.seed_model,
511
  dtype=getattr(jnp, model_args.dtype),
512
+ load_on_cpu=True,
513
  )
514
 
515
  # Load tokenizer
 
609
  graft_type=GraftingType.RMSPROP_NORMALIZED,
610
  nesterov=False,
611
  exponent_override=0,
612
+ statistics_partition_spec=PartitionSpec(None, "batch", None),
613
+ preconditioner_partition_spec=PartitionSpec("batch", None, None),
614
+ num_devices_for_pjit=training_args.dp_devices,
615
+ shard_optimizer_states=True,
616
  inverse_failure_threshold=0.1,
617
  moving_average_for_momentum=True,
618
  skip_preconditioning_dim_size_gt=training_args.skip_preconditioning_dim_size_gt,
 
636
  clipping_threshold=training_args.max_grad_norm,
637
  )
638
 
 
 
 
639
  # get PartitionSpec for model params
640
  param_spec = set_partitions(model.params)
641
 
642
+ # get PartitionSpec for optimizer state
643
+ def get_opt_state_spec_and_shape(param_spec):
644
+ if training_args.optim == "adam":
645
+ # get opt_state shape without actual init
646
+ opt_state_shape = jax.eval_shape(optimizer.init, model.params)
647
+
648
+ def _opt_state_spec_per_leaf(x):
649
+ if isinstance(x, dict):
650
+ # variables with same structure as params
651
+ return param_spec
652
+ else:
653
+ # other variables such as count
654
+ return None
655
+
656
+ opt_state_spec = jax.tree_map(
657
+ _opt_state_spec_per_leaf,
658
+ opt_state_shape,
659
+ # return None spec for empty elements
660
+ is_leaf=lambda x: isinstance(x, (dict, optax.EmptyState)),
661
+ )
662
+
663
+ elif training_args.optim == "adafactor":
664
+ # factorized state must be replicated (rank different than params)
665
+ opt_state_spec = None
666
+
667
+ elif training_args.optim == "distributed_shampoo":
668
+ # memory efficient in distributed_shampoo, fake init
669
+ _opt_state = optimizer.init(model.params)
670
+ opt_state_spec = _opt_state.pspec_fn(
671
+ params=model.params,
672
+ params_partition_spec=unfreeze(param_spec),
673
+ partition_spec_for_statistics=PartitionSpec(None, "batch", None),
674
+ )
675
+ opt_state_shape = _opt_state.shape_and_dtype_fn(model.params)
676
  else:
 
677
  raise NotImplementedError
678
+ return opt_state_spec, opt_state_shape
679
 
680
+ opt_state_spec, opt_state_shape = get_opt_state_spec_and_shape(param_spec)
 
 
 
 
 
681
 
682
  # create a mesh
683
  mesh_shape = (training_args.dp_devices, training_args.mp_devices)
 
697
  tx=optimizer,
698
  )
699
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
700
  # create training state
701
+ with maps.mesh(mesh.devices, mesh.axis_names):
702
  if training_args.resume_from_checkpoint is None:
703
+
704
+ def init_state(params):
705
+ return TrainState.create(
706
+ apply_fn=model.__call__,
707
+ tx=optimizer,
708
+ params=params,
709
+ dropout_rng=dropout_rng,
710
+ )
711
+
712
+ state = pjit(
713
+ init_state,
714
+ in_axis_resources=(param_spec,),
715
+ out_axis_resources=state_spec,
716
+ donate_argnums=(0,),
717
+ )(freeze(model.params))
718
+
719
  else:
720
+ # restore opt_state
721
+ with (Path(artifact_dir) / "opt_state.msgpack").open("rb") as f:
722
+ opt_state = from_bytes(opt_state_shape, f.read())
723
+ # need to freeze dict for pjit
724
+ opt_state = jax.tree_map(
725
+ lambda x: freeze(x) if isinstance(x, dict) else x,
726
+ opt_state,
727
+ is_leaf=lambda x: isinstance(x, (dict, optax.EmptyState)),
728
+ )
729
 
730
+ # restore other attributes
731
+ with (Path(artifact_dir) / "training_state.json").open("r") as f:
732
+ attr_state = json.load(f)
733
+
734
+ def restore_state(params, opt_state):
735
+ return TrainState(
736
+ apply_fn=model.__call__,
737
+ tx=optimizer,
738
+ params=params,
739
+ opt_state=opt_state,
740
+ dropout_rng=dropout_rng,
741
+ **attr_state,
742
+ )
743
+
744
+ state = pjit(
745
+ restore_state,
746
+ in_axis_resources=(param_spec, opt_state_spec),
747
+ out_axis_resources=state_spec,
748
+ donate_argnums=(0, 1),
749
+ )(freeze(model.params), opt_state)
750
+
751
+ # remove opt_state from CPU
752
+ del opt_state
753
+
754
+ # free memory
755
+ del model._params
756
 
757
  # label smoothed cross entropy
758
  def loss_fn(logits, labels):