boris commited on
Commit
4a4820f
1 Parent(s): 272552a

feat: get rid of global_step + log more metrics

Browse files
Files changed (1) hide show
  1. dev/seq2seq/run_seq2seq_flax.py +29 -20
dev/seq2seq/run_seq2seq_flax.py CHANGED
@@ -419,11 +419,10 @@ def create_learning_rate_fn(
419
  def wandb_log(metrics, step=None, prefix=None):
420
  if jax.process_index() == 0:
421
  log_metrics = {
422
- f"{prefix}/{k}" if prefix is not None else k: jax.device_get(v)
423
- for k, v in metrics.items()
424
  }
425
  if step is not None:
426
- log_metrics["train/step"] = step
427
  wandb.log(log_metrics)
428
 
429
 
@@ -512,10 +511,6 @@ def main():
512
  save_code=True,
513
  )
514
 
515
- # set default x-axis as 'train/step'
516
- wandb.define_metric("train/step")
517
- wandb.define_metric("*", step_metric="train/step")
518
-
519
  if model_args.from_checkpoint is not None:
520
  artifact = wandb.run.use_artifact(model_args.from_checkpoint)
521
  artifact_dir = artifact.download()
@@ -867,13 +862,27 @@ def main():
867
  f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}"
868
  )
869
  logger.info(
870
- f" Total train batch size (w. parallel & distributed) = {train_batch_size * training_args.gradient_accumulation_steps}"
871
  )
872
  logger.info(f" Total global steps = {total_steps}")
873
  logger.info(f" Total optimization steps = {total_optimization_steps}")
874
 
875
  epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
876
- global_step = 0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
877
 
878
  def run_evaluation():
879
  # ======================== Evaluating ==============================
@@ -900,7 +909,7 @@ def main():
900
  eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
901
 
902
  # log metrics
903
- wandb_log(eval_metrics, step=global_step, prefix="eval")
904
 
905
  # Print metrics and update progress bar
906
  desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']})"
@@ -923,6 +932,7 @@ def main():
923
  tokenizer.save_pretrained(training_args.output_dir)
924
 
925
  # save state
 
926
  state = unreplicate(state)
927
  with (Path(training_args.output_dir) / "opt_state.msgpack").open("wb") as f:
928
  f.write(to_bytes(state.opt_state))
@@ -978,7 +988,7 @@ def main():
978
 
979
  for epoch in epochs:
980
  # ======================== Training ================================
981
- wandb_log({"train/epoch": epoch}, step=global_step)
982
 
983
  # Create sampling rng
984
  rng, input_rng = jax.random.split(rng)
@@ -999,21 +1009,20 @@ def main():
999
  leave=False,
1000
  total=steps_per_epoch,
1001
  ):
1002
- global_step += 1
1003
  state, train_metric = p_train_step(state, batch)
1004
 
1005
- if global_step % data_args.log_interval == 0 and jax.process_index() == 0:
1006
  # log metrics
1007
- wandb_log(unreplicate(train_metric), step=global_step, prefix="train")
1008
 
1009
- if training_args.eval_steps and global_step % training_args.eval_steps == 0:
1010
  run_evaluation()
1011
 
1012
- if global_step % data_args.save_model_steps == 0:
1013
- run_save_model(state, global_step, epoch)
1014
 
1015
  # log final train metrics
1016
- wandb_log(unreplicate(train_metric), step=global_step, prefix="train")
1017
 
1018
  train_metric = unreplicate(train_metric)
1019
  epochs.write(
@@ -1023,8 +1032,8 @@ def main():
1023
  # Final evaluation
1024
  eval_metrics = run_evaluation()
1025
 
1026
- # save checkpoint after each epoch and push checkpoint to the hub
1027
- run_save_model(state, global_step, epoch, eval_metrics)
1028
 
1029
 
1030
  if __name__ == "__main__":
 
419
  def wandb_log(metrics, step=None, prefix=None):
420
  if jax.process_index() == 0:
421
  log_metrics = {
422
+ f"{prefix}/{k}" if prefix is not None else k: v for k, v in metrics.items()
 
423
  }
424
  if step is not None:
425
+ log_metrics["train/step"] = unreplicate(step)
426
  wandb.log(log_metrics)
427
 
428
 
 
511
  save_code=True,
512
  )
513
 
 
 
 
 
514
  if model_args.from_checkpoint is not None:
515
  artifact = wandb.run.use_artifact(model_args.from_checkpoint)
516
  artifact_dir = artifact.download()
 
862
  f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}"
863
  )
864
  logger.info(
865
+ f" Total train batch size (w. parallel & distributed) = {batch_size_per_update}"
866
  )
867
  logger.info(f" Total global steps = {total_steps}")
868
  logger.info(f" Total optimization steps = {total_optimization_steps}")
869
 
870
  epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
871
+
872
+ # set default x-axis as 'train/step'
873
+ wandb_log({}, step=state.step)
874
+ wandb.define_metric("*", step_metric="train/step")
875
+
876
+ # add interesting config parameters
877
+ wandb.config.update(
878
+ {
879
+ "len_train": len_train_dataset,
880
+ "len_eval": len_eval_dataset,
881
+ "batch_size_per_update": batch_size_per_update,
882
+ "total_steps": total_steps,
883
+ "total_optimization_steps": total_optimization_steps,
884
+ }
885
+ )
886
 
887
  def run_evaluation():
888
  # ======================== Evaluating ==============================
 
909
  eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
910
 
911
  # log metrics
912
+ wandb_log(eval_metrics, step=state.step, prefix="eval")
913
 
914
  # Print metrics and update progress bar
915
  desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']})"
 
932
  tokenizer.save_pretrained(training_args.output_dir)
933
 
934
  # save state
935
+ # TODO: maybe we should just save the full state object without params
936
  state = unreplicate(state)
937
  with (Path(training_args.output_dir) / "opt_state.msgpack").open("wb") as f:
938
  f.write(to_bytes(state.opt_state))
 
988
 
989
  for epoch in epochs:
990
  # ======================== Training ================================
991
+ wandb_log({"train/epoch": epoch}, step=state.step)
992
 
993
  # Create sampling rng
994
  rng, input_rng = jax.random.split(rng)
 
1009
  leave=False,
1010
  total=steps_per_epoch,
1011
  ):
 
1012
  state, train_metric = p_train_step(state, batch)
1013
 
1014
+ if state.step % data_args.log_interval == 0 and jax.process_index() == 0:
1015
  # log metrics
1016
+ wandb_log(unreplicate(train_metric), step=state.step, prefix="train")
1017
 
1018
+ if training_args.eval_steps and state.step % training_args.eval_steps == 0:
1019
  run_evaluation()
1020
 
1021
+ if state.step % data_args.save_model_steps == 0:
1022
+ run_save_model(state, state.step, epoch)
1023
 
1024
  # log final train metrics
1025
+ wandb_log(unreplicate(train_metric), step=state.step, prefix="train")
1026
 
1027
  train_metric = unreplicate(train_metric)
1028
  epochs.write(
 
1032
  # Final evaluation
1033
  eval_metrics = run_evaluation()
1034
 
1035
+ # save checkpoint after each epoch
1036
+ run_save_model(state, state.step, epoch, eval_metrics)
1037
 
1038
 
1039
  if __name__ == "__main__":