boris committed on
Commit 5faf0fd
2 parents: 708a42c 4a4820f

Merge branch 'main' of https://github.com/borisdayma/dalle-mini

Files changed (1):
  1. dev/seq2seq/run_seq2seq_flax.py +29 -20
dev/seq2seq/run_seq2seq_flax.py CHANGED
@@ -413,11 +413,10 @@ def create_learning_rate_fn(
 def wandb_log(metrics, step=None, prefix=None):
     if jax.process_index() == 0:
         log_metrics = {
-            f"{prefix}/{k}" if prefix is not None else k: jax.device_get(v)
-            for k, v in metrics.items()
+            f"{prefix}/{k}" if prefix is not None else k: v for k, v in metrics.items()
         }
         if step is not None:
-            log_metrics["train/step"] = step
+            log_metrics["train/step"] = unreplicate(step)
         wandb.log(log_metrics)
 
 
@@ -506,10 +505,6 @@ def main():
         save_code=True,
     )
 
-    # set default x-axis as 'train/step'
-    wandb.define_metric("train/step")
-    wandb.define_metric("*", step_metric="train/step")
-
     if model_args.from_checkpoint is not None:
         artifact = wandb.run.use_artifact(model_args.from_checkpoint)
         artifact_dir = artifact.download()
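A note on the reworked wandb_log helper in the first hunk: metric keys are now optionally prefixed, and the step is un-replicated before logging because state.step lives on every device under pmap. Below is a minimal sketch of that pattern; the print-based log_fn stands in for wandb.log (an assumption for illustration only), and the replicated step is a toy value, not the script's training state.

# Sketch only: reproduces the prefix/unreplicate pattern from the hunk above,
# with print() standing in for wandb.log (assumption, not part of the commit).
import jax
import jax.numpy as jnp
from flax.jax_utils import replicate, unreplicate


def wandb_log(metrics, step=None, prefix=None, log_fn=print):
    # only the first host logs
    if jax.process_index() == 0:
        log_metrics = {
            f"{prefix}/{k}" if prefix is not None else k: v for k, v in metrics.items()
        }
        if step is not None:
            # step may be replicated across devices (e.g. state.step under pmap)
            log_metrics["train/step"] = unreplicate(step)
        log_fn(log_metrics)


# replicated scalar, as state.step would be inside a pmapped TrainState
step = replicate(jnp.array(7))
wandb_log({"loss": jnp.array(0.42)}, step=step, prefix="train")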
@@ -851,13 +846,27 @@ def main():
         f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}"
     )
     logger.info(
-        f" Total train batch size (w. parallel & distributed) = {train_batch_size * training_args.gradient_accumulation_steps}"
+        f" Total train batch size (w. parallel & distributed) = {batch_size_per_update}"
     )
     logger.info(f" Total global steps = {total_steps}")
     logger.info(f" Total optimization steps = {total_optimization_steps}")
 
     epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
-    global_step = 0
+
+    # set default x-axis as 'train/step'
+    wandb_log({}, step=state.step)
+    wandb.define_metric("*", step_metric="train/step")
+
+    # add interesting config parameters
+    wandb.config.update(
+        {
+            "len_train": len_train_dataset,
+            "len_eval": len_eval_dataset,
+            "batch_size_per_update": batch_size_per_update,
+            "total_steps": total_steps,
+            "total_optimization_steps": total_optimization_steps,
+        }
+    )
 
     def run_evaluation():
         # ======================== Evaluating ==============================
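The hunk above moves the W&B x-axis setup to after the training state exists, so 'train/step' can be seeded from state.step, and records run-level parameters in wandb.config. A rough standalone sketch of that setup follows; the project name, placeholder numbers, and mode="offline" are assumptions for illustration, not part of the training script.

# Sketch only: the define_metric / config.update pattern from the hunk above.
# Project name, numbers, and offline mode are placeholders (assumptions).
import wandb

run = wandb.init(project="seq2seq-sketch", mode="offline")

# seed the custom step metric, like wandb_log({}, step=state.step) does
wandb.log({"train/step": 0})
# make every logged metric use 'train/step' as its x-axis by default
wandb.define_metric("*", step_metric="train/step")

# record run-level parameters next to the metrics
wandb.config.update(
    {
        "len_train": 10_000,             # placeholder for len_train_dataset
        "len_eval": 1_000,               # placeholder for len_eval_dataset
        "batch_size_per_update": 512,    # placeholder for batch_size_per_update
        "total_steps": 20,               # placeholder for total_steps
        "total_optimization_steps": 20,  # placeholder
    }
)

run.finish()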
@@ -884,7 +893,7 @@ def main():
         eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
 
         # log metrics
-        wandb_log(eval_metrics, step=global_step, prefix="eval")
+        wandb_log(eval_metrics, step=state.step, prefix="eval")
 
         # Print metrics and update progress bar
         desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']})"
@@ -907,6 +916,7 @@ def main():
         tokenizer.save_pretrained(training_args.output_dir)
 
         # save state
+        # TODO: maybe we should just save the full state object without params
         state = unreplicate(state)
         with (Path(training_args.output_dir) / "opt_state.msgpack").open("wb") as f:
             f.write(to_bytes(state.opt_state))
@@ -962,7 +972,7 @@ def main():
 
     for epoch in epochs:
         # ======================== Training ================================
-        wandb_log({"train/epoch": epoch}, step=global_step)
+        wandb_log({"train/epoch": epoch}, step=state.step)
 
         # Create sampling rng
         rng, input_rng = jax.random.split(rng)
@@ -983,21 +993,20 @@ def main():
             leave=False,
             total=steps_per_epoch,
         ):
-            global_step += 1
             state, train_metric = p_train_step(state, batch)
 
-            if global_step % data_args.log_interval == 0 and jax.process_index() == 0:
+            if state.step % data_args.log_interval == 0 and jax.process_index() == 0:
                 # log metrics
-                wandb_log(unreplicate(train_metric), step=global_step, prefix="train")
+                wandb_log(unreplicate(train_metric), step=state.step, prefix="train")
 
-            if training_args.eval_steps and global_step % training_args.eval_steps == 0:
+            if training_args.eval_steps and state.step % training_args.eval_steps == 0:
                 run_evaluation()
 
-            if global_step % data_args.save_model_steps == 0:
-                run_save_model(state, global_step, epoch)
+            if state.step % data_args.save_model_steps == 0:
+                run_save_model(state, state.step, epoch)
 
         # log final train metrics
-        wandb_log(unreplicate(train_metric), step=global_step, prefix="train")
+        wandb_log(unreplicate(train_metric), step=state.step, prefix="train")
 
         train_metric = unreplicate(train_metric)
         epochs.write(
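Dropping the manual global_step counter in the hunk above relies on Flax's TrainState.step, which apply_gradients increments inside p_train_step, so the log/eval/save interval checks key off the training state itself. A small sketch of that behaviour follows; the toy Dense model and zero gradients are assumptions for illustration, not the script's model.

# Sketch only: TrainState.step advances with apply_gradients, so it can
# replace a hand-maintained global_step. The Dense model is a placeholder.
import jax
import jax.numpy as jnp
import optax
from flax import linen as nn
from flax.training.train_state import TrainState

model = nn.Dense(features=1)
variables = model.init(jax.random.PRNGKey(0), jnp.ones((1, 3)))
state = TrainState.create(apply_fn=model.apply, params=variables["params"], tx=optax.sgd(1e-3))

print(int(state.step))  # 0
grads = jax.tree_util.tree_map(jnp.zeros_like, state.params)
state = state.apply_gradients(grads=grads)
print(int(state.step))  # 1 -- usable for `state.step % log_interval == 0` style checks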
@@ -1007,8 +1016,8 @@ def main():
     # Final evaluation
     eval_metrics = run_evaluation()
 
-    # save checkpoint after each epoch and push checkpoint to the hub
-    run_save_model(state, global_step, epoch, eval_metrics)
+    # save checkpoint after each epoch
+    run_save_model(state, state.step, epoch, eval_metrics)
 
 
 if __name__ == "__main__":