boris committed on
Commit 14abe8c
1 Parent(s): e4401dd

feat(train): another 25% faster

Files changed (1)
  1. tools/train/train.py +21 -21
tools/train/train.py CHANGED
@@ -36,10 +36,10 @@ import transformers
 import wandb
 from datasets import Dataset
 from distributed_shampoo import GraftingType, distributed_shampoo
-from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
+from flax.core.frozen_dict import FrozenDict, freeze
 from flax.serialization import from_bytes, to_bytes
 from flax.training import train_state
-from flax.training.common_utils import onehot, stack_forest
+from flax.training.common_utils import onehot
 from jax.experimental import PartitionSpec, maps
 from jax.experimental.pjit import pjit, with_sharding_constraint
 from tqdm import tqdm
@@ -382,7 +382,7 @@ class TrainState(train_state.TrainState):
 
 class MetricsLogger:
     def __init__(self, state):
-        self.step = state.step
+        self.step = int(state.step)
         self.time = time.perf_counter()
 
     def get_all_train_metrics(self, train_metrics, state):
@@ -792,8 +792,7 @@ def main():
 
         def compute_loss(params, minibatch, dropout_rng):
             # minibatch has dim (batch_size, ...)
-            minibatch = unfreeze(minibatch)
-            labels = minibatch.pop("labels")
+            minibatch, labels = minibatch.pop("labels")
            logits = state.apply_fn(
                **minibatch, params=params, dropout_rng=dropout_rng, train=True
            )[0]
@@ -883,14 +882,10 @@ def main():
 
    # Define eval fn
    def eval_step(params, batch):
-        batch = unfreeze(batch)
-        labels = batch.pop("labels")
+        batch, labels = batch.pop("labels")
        logits = model(**batch, params=params, train=False)[0]
        loss = loss_fn(logits, labels)
-
-        # summarize metrics
-        metrics = {"loss": loss}
-        return metrics
+        return loss
 
    # Create parallel version of the train and eval step
    p_train_step = pjit(
@@ -940,7 +935,6 @@ def main():
 
    def run_evaluation():
        # ======================== Evaluating ==============================
-        eval_metrics = []
        if training_args.do_eval:
            eval_loader = dataset.dataloader("eval", eval_batch_size)
            eval_steps = (
@@ -948,6 +942,7 @@ def main():
                if len_eval_dataset is not None
                else None
            )
+            eval_loss = []
            for batch in tqdm(
                eval_loader,
                desc="Evaluating...",
@@ -955,13 +950,15 @@ def main():
                leave=False,
                total=eval_steps,
            ):
-                # TODO: make this more efficient once training loop is fast
-                metrics = p_eval_step(state.params, freeze(batch))
-                eval_metrics.append(metrics)
+                # freeze batch to pass safely to JAX transforms
+                batch = freeze(batch)
+                # accumulate losses async
+                eval_loss.append(p_eval_step(state.params, batch))
 
-            # normalize eval metrics
-            eval_metrics = stack_forest(eval_metrics)
-            eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
+            # get the mean of the loss
+            eval_loss = jnp.stack(eval_loss)
+            eval_loss = jnp.mean(eval_loss)
+            eval_metrics = {"loss": eval_loss}
 
            # log metrics
            metrics_logger.log(eval_metrics, step=state.step, prefix="eval")
@@ -1050,6 +1047,7 @@ def main():
    # init variables
    last_time = time.perf_counter()
    train_metrics = None
+    step = int(state.step)
 
    with maps.mesh(mesh.devices, mesh.axis_names):
        for epoch in epochs:
@@ -1088,10 +1086,12 @@ def main():
                    ),
                    batch,
                )
+                # freeze batch to pass safely to jax transforms
+                batch = freeze(batch)
 
                # train step
-                state, train_metrics = p_train_step(state, freeze(batch), delta_time)
-                step = int(state.step)
+                state, train_metrics = p_train_step(state, batch, delta_time)
+                step += 1
 
                if step % training_args.logging_steps == 0 and jax.process_index() == 0:
                    all_metrics = metrics_logger.get_all_train_metrics(
@@ -1100,7 +1100,7 @@ def main():
                    metrics_logger.log(all_metrics, step=step, prefix="train")
 
                eval_metrics = None
-                if training_args.eval_steps and step % training_args.eval_steps == 0:
+                if step % training_args.eval_steps == 0:
                    eval_metrics = run_evaluation()
 
                if step % training_args.save_steps == 0:
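
The speedup this commit claims likely comes mostly from keeping the host out of the device's way: the old `step = int(state.step)` forced a blocking device-to-host transfer after every train step (the loop had to wait for the update to finish), while the new local `step += 1` counter and the deferred eval-loss reduction let JAX keep dispatching work asynchronously. A minimal sketch of the two patterns, not code from this repo (`fake_train_step`, `slow_loop`, and `fast_loop` are hypothetical names):

import jax
import jax.numpy as jnp


@jax.jit
def fake_train_step(step, loss):
    # stand-in for the real pjit-compiled train step: bump the step, shrink the loss
    return step + 1, loss * 0.99


def slow_loop(n_steps=100):
    step = jnp.array(0)
    loss = jnp.array(1.0)
    for _ in range(n_steps):
        step, loss = fake_train_step(step, loss)
        host_step = int(step)  # blocks until the update finishes on device, every iteration
    return loss


def fast_loop(n_steps=100):
    device_step = jnp.array(0)
    loss = jnp.array(1.0)
    host_step = 0  # plain Python counter, no device sync needed
    losses = []
    for _ in range(n_steps):
        device_step, loss = fake_train_step(device_step, loss)
        host_step += 1       # mirrors `step += 1` in the diff above
        losses.append(loss)  # keep device arrays; don't pull values back yet
    # one reduction at the end, like the eval loop's jnp.stack / jnp.mean
    return host_step, jnp.mean(jnp.stack(losses))

The other small change worth noting is `minibatch, labels = minibatch.pop("labels")`: `flax.core.FrozenDict.pop` does not mutate, it returns a `(dict_without_key, value)` pair, so the batch no longer needs to be unfrozen just to split off the labels. A tiny illustration with toy data (not from the repo):

from flax.core.frozen_dict import freeze

batch = freeze({"input_ids": [[1, 2]], "labels": [[3, 4]]})
rest, labels = batch.pop("labels")  # rest is a new FrozenDict without "labels"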