boris committed
Commit 1bfc1b5
Parent: 12f323d

feat(train): restore opt_state efficiently

Files changed (1)
  1. tools/train/train.py +40 -35
tools/train/train.py CHANGED
@@ -42,7 +42,7 @@ from flax.training.common_utils import onehot, stack_forest
 from jax.experimental import PartitionSpec, maps
 from jax.experimental.pjit import pjit
 from tqdm import tqdm
-from transformers import AutoTokenizer, HfArgumentParser
+from transformers import HfArgumentParser
 
 import wandb
 from dalle_mini.data import Dataset
@@ -375,23 +375,6 @@ class TrainState(train_state.TrainState):
     train_time: float = 0.0  # total time the model trained
     train_samples: int = 0  # number of samples seen
 
-    def restore_state(self, artifact_dir):
-        # restore optimizer state
-        with (Path(artifact_dir) / "opt_state.msgpack").open("rb") as f:
-            new_opt_state = from_bytes(self.opt_state, f.read())
-
-        # restore other parameters
-        with (Path(artifact_dir) / "training_state.json").open("r") as f:
-            training_state = json.load(f)
-
-        # replace state
-        return self.replace(
-            opt_state=new_opt_state,
-            step=training_state["step"],
-            train_time=training_state["train_time"],
-            train_samples=training_state["train_samples"],
-        )
-
 
 class MetricsLogger:
     def __init__(self, state):
@@ -528,7 +511,7 @@ def main():
 
     # Load tokenizer
     if model_args.tokenizer_name is not None:
-        tokenizer = AutoTokenizer.from_pretrained(
+        tokenizer = DalleBartTokenizer.from_pretrained(
             model_args.tokenizer_name, use_fast=True
         )
     else:
@@ -648,8 +631,7 @@ def main():
     )
 
     # get opt_state shape without actual init
-    param_shape = jax.tree_map(lambda x: x.shape, model.params)
-    opt_state_shape = jax.eval_shape(lambda x: optimizer.init(x), param_shape)
+    opt_state_shape = jax.eval_shape(lambda x: optimizer.init(x), model.params)
 
     # get PartitionSpec for model params
     param_spec = set_partitions(model.params)
@@ -692,28 +674,51 @@ def main():
         tx=optimizer,
     )
 
+    opt_state, attr_state = None, None
+    if training_args.resume_from_checkpoint is not None:
+        # restore opt_state
+        with (Path(artifact_dir) / "opt_state.msgpack").open("rb") as f:
+            opt_state = from_bytes(opt_state_shape, f.read())
+        # need to freeze dict for pjit
+        opt_state = jax.tree_map(
+            lambda x: freeze(x) if isinstance(x, dict) else x,
+            opt_state,
+            is_leaf=lambda x: isinstance(x, (dict, optax.EmptyState)),
+        )
+        # restore other attributes
+        with (Path(artifact_dir) / "training_state.json").open("r") as f:
+            attr_state = json.load(f)
+
     # create training state
-    def init_state(params):
-        state = TrainState.create(
-            apply_fn=model.__call__,
-            tx=optimizer,
-            params=freeze(params),
-            dropout_rng=dropout_rng,
-        )
+    def init_state(params, opt_state):
+        if training_args.resume_from_checkpoint is None:
+            state = TrainState.create(
+                apply_fn=model.__call__,
+                tx=optimizer,
+                params=freeze(params),
+                dropout_rng=dropout_rng,
+            )
+        else:
+            state = TrainState(
+                apply_fn=model.__call__,
+                tx=optimizer,
+                params=freeze(params),
+                opt_state=opt_state,
+                dropout_rng=dropout_rng,
+                **attr_state,
+            )
         return state
 
     with maps.mesh(mesh.devices, mesh.axis_names):
         state = pjit(
            init_state,
-            in_axis_resources=None,
+            in_axis_resources=(param_spec, opt_state_spec),
            out_axis_resources=state_spec,
-            donate_argnums=(0,),
-        )(freeze(model.params))
+            donate_argnums=(0, 1),
+        )(freeze(model.params), opt_state)
 
-    if training_args.resume_from_checkpoint is not None:
-        # restore optimizer state and other parameters
-        # we currently ignore partial epoch training: see https://github.com/borisdayma/dalle-mini/issues/105
-        state = state.restore_state(artifact_dir)
+    # free memory from large parameters
+    del model._params, opt_state
 
     # label smoothed cross entropy
     def loss_fn(logits, labels):
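What the commit is doing, in isolation: rather than building a full TrainState and then overwriting its optimizer state from the checkpoint (the removed restore_state method), the new code asks jax.eval_shape for the shape of the optimizer state (no buffers allocated) and deserializes opt_state.msgpack directly into that shape template, so the large state is only materialized once. Below is a minimal sketch of that pattern, using toy parameters and an in-memory stand-in for the checkpoint rather than the actual train.py objects.

    # Minimal sketch (not the actual train.py code): restore an optax state from
    # serialized bytes without initializing it twice.
    import jax
    import jax.numpy as jnp
    import optax
    from flax.serialization import from_bytes, to_bytes

    params = {"w": jnp.ones((8, 8)), "b": jnp.zeros((8,))}  # toy parameter tree
    optimizer = optax.adamw(learning_rate=1e-3)

    # Abstract init: a pytree of jax.ShapeDtypeStruct leaves, no memory allocated.
    opt_state_shape = jax.eval_shape(lambda p: optimizer.init(p), params)

    # Stand-in for the opt_state.msgpack artifact written at checkpoint time.
    checkpoint_bytes = to_bytes(optimizer.init(params))

    # Deserialize the checkpoint straight into the shape template; the real state
    # is materialized exactly once, from the checkpoint itself.
    opt_state = from_bytes(opt_state_shape, checkpoint_bytes)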
 
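The jax.tree_map(... freeze ...) pass after from_bytes is there because deserialization hands back plain Python dicts, while, per the diff's own comment ("need to freeze dict for pjit"), the pytree structure expected by pjit and its partition specs presumably uses FrozenDict nodes; re-freezing only the dict nodes makes the restored tree's structure match again. Continuing the sketch above:

    # Continuing the sketch: re-freeze dict nodes inside the restored state so its
    # tree structure matches what pjit / the partition specs expect. Array leaves
    # such as the step count and optax.EmptyState pass through unchanged.
    from flax.core.frozen_dict import freeze

    opt_state = jax.tree_map(
        lambda x: freeze(x) if isinstance(x, dict) else x,
        opt_state,
        is_leaf=lambda x: isinstance(x, (dict, optax.EmptyState)),
    )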