boris committed on
Commit
baa52db
1 Parent(s): 5f44d34

feat(train): merge logged dict

Browse files
Files changed (1) hide show
  1. tools/train/train.py +8 -8
tools/train/train.py CHANGED
@@ -797,7 +797,7 @@ def main():
797
 
798
  # init variables
799
  last_time = time.perf_counter()
800
- train_metric = None
801
 
802
  for epoch in epochs:
803
  state.replace(epoch=jax_utils.replicate(epoch))
@@ -821,20 +821,20 @@ def main():
821
  last_time = new_time
822
 
823
  # train step
824
- state, train_metric = p_train_step(
825
  state, batch, jax_utils.replicate(delta_time)
826
  )
827
  step = unreplicate(state.step)
828
 
829
  if step % training_args.logging_steps == 0 and jax.process_index() == 0:
830
  # log metrics
831
- wandb_log(unreplicate(train_metric), step=step, prefix="train")
832
  # log state parameters
833
  state_dict = {
834
  k.split("_")[-1]: unreplicate(getattr(state, k))
835
  for k in ["epoch", "train_time", "train_samples"]
836
  }
837
- wandb_log(state_dict, step=step, prefix="train")
838
 
839
  eval_metrics = None
840
  if training_args.eval_steps and step % training_args.eval_steps == 0:
@@ -844,12 +844,12 @@ def main():
844
  run_save_model(state, eval_metrics)
845
 
846
  # log final train metrics
847
- if train_metric is not None:
848
- train_metric = unreplicate(train_metric)
849
- wandb_log(train_metric, step=step, prefix="train")
850
 
851
  epochs.write(
852
- f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
853
  )
854
 
855
  # Final evaluation
 
797
 
798
  # init variables
799
  last_time = time.perf_counter()
800
+ train_metrics = None
801
 
802
  for epoch in epochs:
803
  state.replace(epoch=jax_utils.replicate(epoch))
 
821
  last_time = new_time
822
 
823
  # train step
824
+ state, train_metrics = p_train_step(
825
  state, batch, jax_utils.replicate(delta_time)
826
  )
827
  step = unreplicate(state.step)
828
 
829
  if step % training_args.logging_steps == 0 and jax.process_index() == 0:
830
  # log metrics
831
+ metrics = unreplicate(train_metrics)
832
  # log state parameters
833
  state_dict = {
834
  k.split("_")[-1]: unreplicate(getattr(state, k))
835
  for k in ["epoch", "train_time", "train_samples"]
836
  }
837
+ wandb_log({**metrics, **state_dict}, step=step, prefix="train")
838
 
839
  eval_metrics = None
840
  if training_args.eval_steps and step % training_args.eval_steps == 0:
 
844
  run_save_model(state, eval_metrics)
845
 
846
  # log final train metrics
847
+ if train_metrics is not None:
848
+ train_metrics = unreplicate(train_metrics)
849
+ wandb_log(train_metrics, step=step, prefix="train")
850
 
851
  epochs.write(
852
+ f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metrics['loss']}, Learning Rate: {train_metrics['learning_rate']})"
853
  )
854
 
855
  # Final evaluation