boris commited on
Commit
f5239e1
2 Parent(s): a5ed112 7a176b9

feat(train): use pjit (#125)

Browse files
src/dalle_mini/data.py CHANGED
@@ -6,7 +6,6 @@ import jax.numpy as jnp
6
  import numpy as np
7
  from braceexpand import braceexpand
8
  from datasets import Dataset, load_dataset
9
- from flax.training.common_utils import shard
10
 
11
  from .text import TextNormalizer
12
 
@@ -191,7 +190,6 @@ class Dataset:
191
  lambda x: x.reshape((-1, per_device_batch_size) + x.shape[1:]),
192
  batch,
193
  )
194
- batch = shard(batch)
195
  yield batch
196
 
197
  def _dataloader_datasets_streaming(
@@ -232,7 +230,6 @@ class Dataset:
232
  ),
233
  batch,
234
  )
235
- batch = shard(batch)
236
  yield batch
237
  batch = {k: [] for k in keys}
238
  first_loop = False
 
6
  import numpy as np
7
  from braceexpand import braceexpand
8
  from datasets import Dataset, load_dataset
 
9
 
10
  from .text import TextNormalizer
11
 
 
190
  lambda x: x.reshape((-1, per_device_batch_size) + x.shape[1:]),
191
  batch,
192
  )
 
193
  yield batch
194
 
195
  def _dataloader_datasets_streaming(
 
230
  ),
231
  batch,
232
  )
 
233
  yield batch
234
  batch = {k: [] for k in keys}
235
  first_loop = False
src/dalle_mini/model/__init__.py CHANGED
@@ -1,3 +1,4 @@
1
  from .configuration import DalleBartConfig
2
  from .modeling import DalleBart
 
3
  from .tokenizer import DalleBartTokenizer
 
1
  from .configuration import DalleBartConfig
2
  from .modeling import DalleBart
3
+ from .partitions import set_partitions
4
  from .tokenizer import DalleBartTokenizer
src/dalle_mini/model/modeling.py CHANGED
@@ -300,6 +300,7 @@ class FlaxBartPreTrainedModel(FlaxBartPreTrainedModel):
300
  - added num_params property
301
  - config_class replaced to DalleBartConfig
302
  - __init__ accepts abstract_init which does uses parameter shape to initialize the model
 
303
  """
304
 
305
  config_class = DalleBartConfig
@@ -311,6 +312,7 @@ class FlaxBartPreTrainedModel(FlaxBartPreTrainedModel):
311
  seed: int = 0,
312
  dtype: jnp.dtype = jnp.float32,
313
  abstract_init: bool = False,
 
314
  **kwargs,
315
  ):
316
  module = self.module_class(config=config, dtype=dtype, **kwargs)
@@ -330,15 +332,21 @@ class FlaxBartPreTrainedModel(FlaxBartPreTrainedModel):
330
  self.key = PRNGKey(seed)
331
  self.dtype = dtype
332
 
 
 
 
 
 
 
333
  # randomly initialized parameters
334
  if abstract_init:
335
  # init the model weights only abstractly, eval_shape will return a pytree
336
  # with the structure as weights but without any actual values, this will just contain
337
  # the shape information. Weights need to be loaded later.
338
- init_fn = partial(self.init_weights, input_shape=input_shape)
339
  random_params = jax.eval_shape(init_fn, self.key)
340
  else:
341
- random_params = self.init_weights(self.key, input_shape)
342
 
343
  # save required_params as set
344
  self._required_params = set(flatten_dict(unfreeze(random_params)).keys())
 
300
  - added num_params property
301
  - config_class replaced to DalleBartConfig
302
  - __init__ accepts abstract_init which does uses parameter shape to initialize the model
303
+ - init weights on CPU
304
  """
305
 
306
  config_class = DalleBartConfig
 
312
  seed: int = 0,
313
  dtype: jnp.dtype = jnp.float32,
314
  abstract_init: bool = False,
315
+ load_on_cpu: bool = True,
316
  **kwargs,
317
  ):
318
  module = self.module_class(config=config, dtype=dtype, **kwargs)
 
332
  self.key = PRNGKey(seed)
333
  self.dtype = dtype
334
 
335
+ # init weights on CPU
336
+ if load_on_cpu:
337
+ init_fn = jax.jit(self.init_weights, static_argnums=(1,), backend="cpu")
338
+ else:
339
+ init_fn = self.init_weights
340
+
341
  # randomly initialized parameters
342
  if abstract_init:
343
  # init the model weights only abstractly, eval_shape will return a pytree
344
  # with the structure as weights but without any actual values, this will just contain
345
  # the shape information. Weights need to be loaded later.
346
+ init_fn = partial(init_fn, input_shape=input_shape)
347
  random_params = jax.eval_shape(init_fn, self.key)
348
  else:
349
+ random_params = init_fn(self.key, input_shape)
350
 
351
  # save required_params as set
352
  self._required_params = set(flatten_dict(unfreeze(random_params)).keys())
tools/train/train.py CHANGED
@@ -30,21 +30,28 @@ from typing import Callable, Optional
30
  import datasets
31
  import jax
32
  import jax.numpy as jnp
 
33
  import optax
34
  import transformers
35
  import wandb
36
  from datasets import Dataset
37
  from distributed_shampoo import GraftingType, distributed_shampoo
38
- from flax import jax_utils, traverse_util
39
- from flax.jax_utils import unreplicate
40
  from flax.serialization import from_bytes, to_bytes
41
  from flax.training import train_state
42
- from flax.training.common_utils import get_metrics, onehot, shard_prng_key
 
 
43
  from tqdm import tqdm
44
- from transformers import AutoTokenizer, HfArgumentParser
45
 
46
  from dalle_mini.data import Dataset
47
- from dalle_mini.model import DalleBart, DalleBartConfig, DalleBartTokenizer
 
 
 
 
 
48
 
49
  logger = logging.getLogger(__name__)
50
 
@@ -223,7 +230,6 @@ class TrainingArguments:
223
  "help": 'The optimizer to use. Can be "distributed_shampoo" (default), "adam" or "adafactor"'
224
  },
225
  )
226
- weight_decay: float = field(default=None, metadata={"help": "Weight decay."})
227
  beta1: float = field(
228
  default=0.9,
229
  metadata={"help": "Beta1 for Adam & Distributed Shampoo."},
@@ -332,6 +338,13 @@ class TrainingArguments:
332
  metadata={"help": "Verify that TPU is not in use."},
333
  )
334
 
 
 
 
 
 
 
 
335
  def __post_init__(self):
336
  assert self.optim in [
337
  "distributed_shampoo",
@@ -340,9 +353,6 @@ class TrainingArguments:
340
  ], f"Selected optimizer not supported: {self.optim}"
341
  if self.per_device_eval_batch_size is None:
342
  self.per_device_eval_batch_size = self.per_device_train_batch_size
343
- if self.weight_decay is None:
344
- if self.optim in ["distributed_shampoo", "adam"]:
345
- self.weight_decay = 0.0
346
  if (
347
  os.path.exists(self.output_dir)
348
  and os.listdir(self.output_dir)
@@ -353,6 +363,10 @@ class TrainingArguments:
353
  f"Output directory ({self.output_dir}) already exists and is not empty."
354
  "Use --overwrite_output_dir to overcome."
355
  )
 
 
 
 
356
 
357
 
358
  class TrainState(train_state.TrainState):
@@ -361,28 +375,6 @@ class TrainState(train_state.TrainState):
361
  train_time: float = 0.0 # total time the model trained
362
  train_samples: int = 0 # number of samples seen
363
 
364
- def replicate(self):
365
- return jax_utils.replicate(self).replace(
366
- dropout_rng=shard_prng_key(self.dropout_rng)
367
- )
368
-
369
- def restore_state(self, artifact_dir):
370
- # restore optimizer state
371
- with (Path(artifact_dir) / "opt_state.msgpack").open("rb") as f:
372
- new_opt_state = from_bytes(self.opt_state, f.read())
373
-
374
- # restore other parameters
375
- with (Path(artifact_dir) / "training_state.json").open("r") as f:
376
- training_state = json.load(f)
377
-
378
- # replace state
379
- return self.replace(
380
- opt_state=new_opt_state,
381
- step=training_state["step"],
382
- train_time=training_state["train_time"],
383
- train_samples=training_state["train_samples"],
384
- )
385
-
386
 
387
  class MetricsLogger:
388
  def __init__(self, state):
@@ -391,14 +383,14 @@ class MetricsLogger:
391
 
392
  def get_all_train_metrics(self, train_metrics, state):
393
  """Make a dict of training metrics to be logged"""
394
- metrics = unreplicate(train_metrics)
395
  # get state parameters
396
  state_dict = {
397
- k.split("_")[-1]: unreplicate(getattr(state, k))
398
  for k in ["epoch", "train_time", "train_samples"]
399
  }
400
  # timing metrics
401
- new_step = int(unreplicate(state.step))
402
  new_time = time.perf_counter()
403
  if new_step > self.step:
404
  time_per_step = (new_time - self.time) / (new_step - self.step)
@@ -487,8 +479,6 @@ def main():
487
  dtype=getattr(jnp, model_args.dtype),
488
  abstract_init=True,
489
  )
490
- # avoid OOM on TPU: see https://github.com/google/flax/issues/1658
491
- print(model.params)
492
 
493
  # load tokenizer
494
  tokenizer = DalleBartTokenizer.from_pretrained(
@@ -512,8 +502,6 @@ def main():
512
  dtype=getattr(jnp, model_args.dtype),
513
  abstract_init=True,
514
  )
515
- # avoid OOM on TPU: see https://github.com/google/flax/issues/1658
516
- print(model.params)
517
  else:
518
  model = DalleBart(
519
  config,
@@ -523,7 +511,7 @@ def main():
523
 
524
  # Load tokenizer
525
  if model_args.tokenizer_name is not None:
526
- tokenizer = AutoTokenizer.from_pretrained(
527
  model_args.tokenizer_name, use_fast=True
528
  )
529
  else:
@@ -601,32 +589,9 @@ def main():
601
 
602
  learning_rate_fn = create_learning_rate_fn()
603
 
604
- # We use Optax's "masking" functionality to not apply weight decay
605
- # to bias and LayerNorm scale parameters. decay_mask_fn returns a
606
- # mask boolean with the same structure as the parameters.
607
- # The mask is True for parameters that should be decayed.
608
- # Note that this mask is specifically adapted for FlaxBart.
609
- def decay_mask_fn(params):
610
- flat_params = traverse_util.flatten_dict(params)
611
- layer_norm_params = [
612
- (name, "scale")
613
- for name in [
614
- "self_attn_layer_norm",
615
- "layernorm_embedding",
616
- "final_layer_norm",
617
- ]
618
- ]
619
- flat_mask = {
620
- path: (path[-1] != "bias" and path[-2:] not in layer_norm_params)
621
- for path in flat_params
622
- }
623
- return traverse_util.unflatten_dict(flat_mask)
624
-
625
  # create adam optimizer
626
  if training_args.optim == "distributed_shampoo":
627
  # parameters from https://github.com/tensorflow/lingvo/blob/03ee9d7cd50764b0424c7c863733c91fc0b053ec/lingvo/jax/optimizers.py#L729
628
- # Notes:
629
- # - mask for weight decay is not implemented
630
  optimizer = distributed_shampoo(
631
  learning_rate_fn,
632
  block_size=training_args.block_size,
@@ -634,7 +599,6 @@ def main():
634
  beta2=training_args.beta2,
635
  diagonal_epsilon=1e-10,
636
  matrix_epsilon=1e-8,
637
- weight_decay=training_args.weight_decay,
638
  start_preconditioning_step=training_args.warmup_steps,
639
  preconditioning_compute_steps=training_args.preconditioning_compute_steps,
640
  statistics_compute_steps=1,
@@ -657,30 +621,104 @@ def main():
657
  b1=training_args.beta1,
658
  b2=training_args.beta2,
659
  eps=training_args.adam_epsilon,
660
- weight_decay=training_args.weight_decay,
661
- mask=decay_mask_fn,
662
  )
663
  elif training_args.optim == "adafactor":
664
  # We use the default parameters here to initialize adafactor,
665
  # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
666
  optimizer = optax.adafactor(
667
  learning_rate=learning_rate_fn,
668
- weight_decay_rate=training_args.weight_decay,
669
- weight_decay_mask=decay_mask_fn,
670
  clipping_threshold=training_args.max_grad_norm,
671
  )
672
 
673
- # Setup train state
674
- state = TrainState.create(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
675
  apply_fn=model.__call__,
676
- params=model.params,
677
  tx=optimizer,
678
- dropout_rng=dropout_rng,
679
  )
 
 
680
  if training_args.resume_from_checkpoint is not None:
681
- # restore optimizer state and other parameters
682
- # we currently ignore partial epoch training: see https://github.com/borisdayma/dalle-mini/issues/105
683
- state = state.restore_state(artifact_dir)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
684
 
685
  # label smoothed cross entropy
686
  def loss_fn(logits, labels):
@@ -691,6 +729,8 @@ def main():
691
  # Define gradient update step fn
692
  def train_step(state, batch, delta_time):
693
  dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
 
 
694
 
695
  def compute_loss(params, minibatch):
696
  labels = minibatch.pop("labels")
@@ -728,7 +768,6 @@ def main():
728
  ),
729
  )
730
 
731
- grads = jax.lax.pmean(grads, "batch")
732
  state = state.apply_gradients(
733
  grads=grads,
734
  dropout_rng=new_dropout_rng,
@@ -740,7 +779,6 @@ def main():
740
  "loss": loss,
741
  "learning_rate": learning_rate_fn(state.step),
742
  }
743
- metrics = jax.lax.pmean(metrics, axis_name="batch")
744
 
745
  return state, metrics
746
 
@@ -752,12 +790,20 @@ def main():
752
 
753
  # summarize metrics
754
  metrics = {"loss": loss}
755
- metrics = jax.lax.pmean(metrics, axis_name="batch")
756
  return metrics
757
 
758
  # Create parallel version of the train and eval step
759
- p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
760
- p_eval_step = jax.pmap(eval_step, "batch")
 
 
 
 
 
 
 
 
 
761
 
762
  logger.info("***** Running training *****")
763
  logger.info(f" Num examples = {len_train_dataset}")
@@ -792,9 +838,6 @@ def main():
792
  }
793
  )
794
 
795
- # replicate state on each device
796
- state = state.replicate()
797
-
798
  def run_evaluation():
799
  # ======================== Evaluating ==============================
800
  eval_metrics = []
@@ -819,13 +862,11 @@ def main():
819
  eval_metrics.append(metrics)
820
 
821
  # normalize eval metrics
822
- eval_metrics = get_metrics(eval_metrics)
823
  eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
824
 
825
  # log metrics
826
- metrics_logger.log(
827
- eval_metrics, step=unreplicate(state.step), prefix="eval"
828
- )
829
 
830
  # Print metrics and update progress bar
831
  desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']})"
@@ -836,7 +877,7 @@ def main():
836
 
837
  def run_save_model(state, eval_metrics=None):
838
  if jax.process_index() == 0:
839
- params = jax.device_get(unreplicate(state.params))
840
  # save model locally
841
  model.save_pretrained(
842
  training_args.output_dir,
@@ -847,11 +888,11 @@ def main():
847
  tokenizer.save_pretrained(training_args.output_dir)
848
 
849
  # save state
850
- opt_state = unreplicate(state.opt_state)
851
  with (Path(training_args.output_dir) / "opt_state.msgpack").open("wb") as f:
852
  f.write(to_bytes(opt_state))
853
  state_dict = {
854
- k: jax.device_get(unreplicate(getattr(state, k))).item()
855
  for k in ["step", "epoch", "train_time", "train_samples"]
856
  }
857
  with (Path(training_args.output_dir) / "training_state.json").open(
@@ -912,63 +953,64 @@ def main():
912
  last_time = time.perf_counter()
913
  train_metrics = None
914
 
915
- for epoch in epochs:
916
- state.replace(epoch=jax_utils.replicate(epoch))
917
- # ======================== Training ================================
918
- metrics_logger.log({"train/epoch": epoch}, step=unreplicate(state.step))
919
-
920
- # Generate an epoch by shuffling sampling indices from the train dataset
921
- train_loader = dataset.dataloader(
922
- "train",
923
- training_args.per_device_train_batch_size,
924
- training_args.gradient_accumulation_steps,
925
- epoch,
926
- )
927
- # train
928
- for batch in tqdm(
929
- train_loader,
930
- desc="Training...",
931
- position=1,
932
- leave=False,
933
- total=steps_per_epoch,
934
- ):
 
935
 
936
- # calculate delta time (we have a lag of one step but it's ok)
937
- new_time = time.perf_counter()
938
- delta_time = new_time - last_time
939
- last_time = new_time
940
 
941
- # train step
942
- state, train_metrics = p_train_step(
943
- state, batch, jax_utils.replicate(delta_time)
944
- )
945
- step = unreplicate(state.step)
946
 
947
- if step % training_args.logging_steps == 0 and jax.process_index() == 0:
948
- all_metrics = metrics_logger.get_all_train_metrics(train_metrics, state)
949
- metrics_logger.log(all_metrics, step=step, prefix="train")
 
 
950
 
951
- eval_metrics = None
952
- if training_args.eval_steps and step % training_args.eval_steps == 0:
953
- eval_metrics = run_evaluation()
954
 
955
- if step % training_args.save_steps == 0:
956
- run_save_model(state, eval_metrics)
957
 
958
- # log final train metrics
959
- if train_metrics is not None:
960
- all_metrics = metrics_logger.get_all_train_metrics(train_metrics, state)
961
- metrics_logger.log(all_metrics, step=step, prefix="train")
962
 
963
- epochs.write(
964
- f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metrics['loss']}, Learning Rate: {train_metrics['learning_rate']})"
965
- )
966
 
967
- # Final evaluation
968
- eval_metrics = run_evaluation()
969
 
970
- # save checkpoint after each epoch
971
- run_save_model(state, eval_metrics)
972
 
973
 
974
  if __name__ == "__main__":
 
30
  import datasets
31
  import jax
32
  import jax.numpy as jnp
33
+ import numpy as np
34
  import optax
35
  import transformers
36
  import wandb
37
  from datasets import Dataset
38
  from distributed_shampoo import GraftingType, distributed_shampoo
39
+ from flax.core.frozen_dict import freeze
 
40
  from flax.serialization import from_bytes, to_bytes
41
  from flax.training import train_state
42
+ from flax.training.common_utils import onehot, stack_forest
43
+ from jax.experimental import PartitionSpec, maps
44
+ from jax.experimental.pjit import pjit
45
  from tqdm import tqdm
46
+ from transformers import HfArgumentParser
47
 
48
  from dalle_mini.data import Dataset
49
+ from dalle_mini.model import (
50
+ DalleBart,
51
+ DalleBartConfig,
52
+ DalleBartTokenizer,
53
+ set_partitions,
54
+ )
55
 
56
  logger = logging.getLogger(__name__)
57
 
 
230
  "help": 'The optimizer to use. Can be "distributed_shampoo" (default), "adam" or "adafactor"'
231
  },
232
  )
 
233
  beta1: float = field(
234
  default=0.9,
235
  metadata={"help": "Beta1 for Adam & Distributed Shampoo."},
 
338
  metadata={"help": "Verify that TPU is not in use."},
339
  )
340
 
341
+ mp_devices: Optional[int] = field(
342
+ default=1,
343
+ metadata={
344
+ "help": "Number of devices required for model parallelism. The other dimension of available devices is used for data parallelism."
345
+ },
346
+ )
347
+
348
  def __post_init__(self):
349
  assert self.optim in [
350
  "distributed_shampoo",
 
353
  ], f"Selected optimizer not supported: {self.optim}"
354
  if self.per_device_eval_batch_size is None:
355
  self.per_device_eval_batch_size = self.per_device_train_batch_size
 
 
 
356
  if (
357
  os.path.exists(self.output_dir)
358
  and os.listdir(self.output_dir)
 
363
  f"Output directory ({self.output_dir}) already exists and is not empty."
364
  "Use --overwrite_output_dir to overcome."
365
  )
366
+ assert (
367
+ jax.device_count() % self.mp_devices == 0
368
+ ), f"Number of available devices ({jax.device_count()} must be divisible by number of devices used for model parallelism ({self.mp_devices})."
369
+ self.dp_devices = jax.device_count() // self.mp_devices
370
 
371
 
372
  class TrainState(train_state.TrainState):
 
375
  train_time: float = 0.0 # total time the model trained
376
  train_samples: int = 0 # number of samples seen
377
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
378
 
379
  class MetricsLogger:
380
  def __init__(self, state):
 
383
 
384
  def get_all_train_metrics(self, train_metrics, state):
385
  """Make a dict of training metrics to be logged"""
386
+ metrics = train_metrics
387
  # get state parameters
388
  state_dict = {
389
+ k.split("_")[-1]: getattr(state, k)
390
  for k in ["epoch", "train_time", "train_samples"]
391
  }
392
  # timing metrics
393
+ new_step = int(state.step)
394
  new_time = time.perf_counter()
395
  if new_step > self.step:
396
  time_per_step = (new_time - self.time) / (new_step - self.step)
 
479
  dtype=getattr(jnp, model_args.dtype),
480
  abstract_init=True,
481
  )
 
 
482
 
483
  # load tokenizer
484
  tokenizer = DalleBartTokenizer.from_pretrained(
 
502
  dtype=getattr(jnp, model_args.dtype),
503
  abstract_init=True,
504
  )
 
 
505
  else:
506
  model = DalleBart(
507
  config,
 
511
 
512
  # Load tokenizer
513
  if model_args.tokenizer_name is not None:
514
+ tokenizer = DalleBartTokenizer.from_pretrained(
515
  model_args.tokenizer_name, use_fast=True
516
  )
517
  else:
 
589
 
590
  learning_rate_fn = create_learning_rate_fn()
591
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
592
  # create adam optimizer
593
  if training_args.optim == "distributed_shampoo":
594
  # parameters from https://github.com/tensorflow/lingvo/blob/03ee9d7cd50764b0424c7c863733c91fc0b053ec/lingvo/jax/optimizers.py#L729
 
 
595
  optimizer = distributed_shampoo(
596
  learning_rate_fn,
597
  block_size=training_args.block_size,
 
599
  beta2=training_args.beta2,
600
  diagonal_epsilon=1e-10,
601
  matrix_epsilon=1e-8,
 
602
  start_preconditioning_step=training_args.warmup_steps,
603
  preconditioning_compute_steps=training_args.preconditioning_compute_steps,
604
  statistics_compute_steps=1,
 
621
  b1=training_args.beta1,
622
  b2=training_args.beta2,
623
  eps=training_args.adam_epsilon,
 
 
624
  )
625
  elif training_args.optim == "adafactor":
626
  # We use the default parameters here to initialize adafactor,
627
  # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
628
  optimizer = optax.adafactor(
629
  learning_rate=learning_rate_fn,
 
 
630
  clipping_threshold=training_args.max_grad_norm,
631
  )
632
 
633
+ # get opt_state shape without actual init
634
+ opt_state_shape = jax.eval_shape(lambda x: optimizer.init(x), model.params)
635
+
636
+ # get PartitionSpec for model params
637
+ param_spec = set_partitions(model.params)
638
+
639
+ # create PartitionSpec for opt_state
640
+ def opt_state_spec_per_leaf(x):
641
+ if training_args.optim in ["adam", "adafactor"]:
642
+ if isinstance(x, dict):
643
+ # variables with same structure as params
644
+ return param_spec
645
+ else:
646
+ # other variables such as count
647
+ return None
648
+ else:
649
+ # TODO: create spec for Distributed Shampoo
650
+ raise NotImplementedError
651
+
652
+ opt_state_spec = jax.tree_map(
653
+ opt_state_spec_per_leaf,
654
+ opt_state_shape,
655
+ # return None spec for empty elements
656
+ is_leaf=lambda x: isinstance(x, (dict, optax.EmptyState)),
657
+ )
658
+
659
+ # create a mesh
660
+ mesh_shape = (training_args.dp_devices, training_args.mp_devices)
661
+ devices = np.asarray(jax.devices()).reshape(*mesh_shape)
662
+ mesh = maps.Mesh(devices, ("batch", "mp"))
663
+
664
+ # Create state spec
665
+ state_spec = TrainState(
666
+ params=param_spec,
667
+ opt_state=opt_state_spec,
668
+ dropout_rng=None,
669
+ step=None,
670
+ epoch=None,
671
+ train_time=None,
672
+ train_samples=None,
673
  apply_fn=model.__call__,
 
674
  tx=optimizer,
 
675
  )
676
+
677
+ opt_state, attr_state = None, None
678
  if training_args.resume_from_checkpoint is not None:
679
+ # restore opt_state
680
+ with (Path(artifact_dir) / "opt_state.msgpack").open("rb") as f:
681
+ opt_state = from_bytes(opt_state_shape, f.read())
682
+ # need to freeze dict for pjit
683
+ opt_state = jax.tree_map(
684
+ lambda x: freeze(x) if isinstance(x, dict) else x,
685
+ opt_state,
686
+ is_leaf=lambda x: isinstance(x, (dict, optax.EmptyState)),
687
+ )
688
+ # restore other attributes
689
+ with (Path(artifact_dir) / "training_state.json").open("r") as f:
690
+ attr_state = json.load(f)
691
+
692
+ # create training state
693
+ def init_state(params, opt_state):
694
+ if training_args.resume_from_checkpoint is None:
695
+ state = TrainState.create(
696
+ apply_fn=model.__call__,
697
+ tx=optimizer,
698
+ params=freeze(params),
699
+ dropout_rng=dropout_rng,
700
+ )
701
+ else:
702
+ state = TrainState(
703
+ apply_fn=model.__call__,
704
+ tx=optimizer,
705
+ params=freeze(params),
706
+ opt_state=opt_state,
707
+ dropout_rng=dropout_rng,
708
+ **attr_state,
709
+ )
710
+ return state
711
+
712
+ with maps.mesh(mesh.devices, mesh.axis_names):
713
+ state = pjit(
714
+ init_state,
715
+ in_axis_resources=(param_spec, opt_state_spec),
716
+ out_axis_resources=state_spec,
717
+ donate_argnums=(0, 1),
718
+ )(freeze(model.params), opt_state)
719
+
720
+ # free memory from large parameters
721
+ del model._params, opt_state
722
 
723
  # label smoothed cross entropy
724
  def loss_fn(logits, labels):
 
729
  # Define gradient update step fn
730
  def train_step(state, batch, delta_time):
731
  dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
732
+ # use a different rng per node
733
+ dropout_rng = jax.random.fold_in(dropout_rng, jax.process_index())
734
 
735
  def compute_loss(params, minibatch):
736
  labels = minibatch.pop("labels")
 
768
  ),
769
  )
770
 
 
771
  state = state.apply_gradients(
772
  grads=grads,
773
  dropout_rng=new_dropout_rng,
 
779
  "loss": loss,
780
  "learning_rate": learning_rate_fn(state.step),
781
  }
 
782
 
783
  return state, metrics
784
 
 
790
 
791
  # summarize metrics
792
  metrics = {"loss": loss}
 
793
  return metrics
794
 
795
  # Create parallel version of the train and eval step
796
+ p_train_step = pjit(
797
+ train_step,
798
+ in_axis_resources=(state_spec, PartitionSpec("batch", None), None),
799
+ out_axis_resources=(state_spec, None),
800
+ donate_argnums=(0,),
801
+ )
802
+ p_eval_step = pjit(
803
+ eval_step,
804
+ in_axis_resources=(param_spec, PartitionSpec("batch", None)),
805
+ out_axis_resources=None,
806
+ )
807
 
808
  logger.info("***** Running training *****")
809
  logger.info(f" Num examples = {len_train_dataset}")
 
838
  }
839
  )
840
 
 
 
 
841
  def run_evaluation():
842
  # ======================== Evaluating ==============================
843
  eval_metrics = []
 
862
  eval_metrics.append(metrics)
863
 
864
  # normalize eval metrics
865
+ eval_metrics = stack_forest(eval_metrics)
866
  eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
867
 
868
  # log metrics
869
+ metrics_logger.log(eval_metrics, step=state.step, prefix="eval")
 
 
870
 
871
  # Print metrics and update progress bar
872
  desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']})"
 
877
 
878
  def run_save_model(state, eval_metrics=None):
879
  if jax.process_index() == 0:
880
+ params = jax.device_get(state.params)
881
  # save model locally
882
  model.save_pretrained(
883
  training_args.output_dir,
 
888
  tokenizer.save_pretrained(training_args.output_dir)
889
 
890
  # save state
891
+ opt_state = jax.device_get(state.opt_state)
892
  with (Path(training_args.output_dir) / "opt_state.msgpack").open("wb") as f:
893
  f.write(to_bytes(opt_state))
894
  state_dict = {
895
+ k: jax.device_get(getattr(state, k)).item()
896
  for k in ["step", "epoch", "train_time", "train_samples"]
897
  }
898
  with (Path(training_args.output_dir) / "training_state.json").open(
 
953
  last_time = time.perf_counter()
954
  train_metrics = None
955
 
956
+ with maps.mesh(mesh.devices, mesh.axis_names):
957
+ for epoch in epochs:
958
+ state.replace(epoch=epoch)
959
+ # ======================== Training ================================
960
+ metrics_logger.log({"train/epoch": epoch}, step=state.step)
961
+
962
+ # Generate an epoch by shuffling sampling indices from the train dataset
963
+ train_loader = dataset.dataloader(
964
+ "train",
965
+ training_args.per_device_train_batch_size,
966
+ training_args.gradient_accumulation_steps,
967
+ epoch,
968
+ )
969
+ # train
970
+ for batch in tqdm(
971
+ train_loader,
972
+ desc="Training...",
973
+ position=1,
974
+ leave=False,
975
+ total=steps_per_epoch,
976
+ ):
977
 
978
+ # calculate delta time (we have a lag of one step but it's ok)
979
+ new_time = time.perf_counter()
980
+ delta_time = new_time - last_time
981
+ last_time = new_time
982
 
983
+ # train step
984
+ state, train_metrics = p_train_step(state, batch, delta_time)
985
+ step = state.step
 
 
986
 
987
+ if step % training_args.logging_steps == 0 and jax.process_index() == 0:
988
+ all_metrics = metrics_logger.get_all_train_metrics(
989
+ train_metrics, state
990
+ )
991
+ metrics_logger.log(all_metrics, step=step, prefix="train")
992
 
993
+ eval_metrics = None
994
+ if training_args.eval_steps and step % training_args.eval_steps == 0:
995
+ eval_metrics = run_evaluation()
996
 
997
+ if step % training_args.save_steps == 0:
998
+ run_save_model(state, eval_metrics)
999
 
1000
+ # log final train metrics
1001
+ if train_metrics is not None:
1002
+ all_metrics = metrics_logger.get_all_train_metrics(train_metrics, state)
1003
+ metrics_logger.log(all_metrics, step=step, prefix="train")
1004
 
1005
+ epochs.write(
1006
+ f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metrics['loss']}, Learning Rate: {train_metrics['learning_rate']})"
1007
+ )
1008
 
1009
+ # Final evaluation
1010
+ eval_metrics = run_evaluation()
1011
 
1012
+ # save checkpoint after each epoch
1013
+ run_save_model(state, eval_metrics)
1014
 
1015
 
1016
  if __name__ == "__main__":