boris commited on
Commit
47e006f
·
1 Parent(s): 5faf0fd

fix: state.step type

Browse files
Files changed (1) hide show
  1. dev/seq2seq/run_seq2seq_flax.py +14 -12
dev/seq2seq/run_seq2seq_flax.py CHANGED
@@ -416,7 +416,7 @@ def wandb_log(metrics, step=None, prefix=None):
416
  f"{prefix}/{k}" if prefix is not None else k: v for k, v in metrics.items()
417
  }
418
  if step is not None:
419
- log_metrics["train/step"] = unreplicate(step)
420
  wandb.log(log_metrics)
421
 
422
 
@@ -846,7 +846,7 @@ def main():
846
  f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}"
847
  )
848
  logger.info(
849
- f" Total train batch size (w. parallel & distributed) = {batch_size_per_update}"
850
  )
851
  logger.info(f" Total global steps = {total_steps}")
852
  logger.info(f" Total optimization steps = {total_optimization_steps}")
@@ -854,7 +854,7 @@ def main():
854
  epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
855
 
856
  # set default x-axis as 'train/step'
857
- wandb_log({}, step=state.step)
858
  wandb.define_metric("*", step_metric="train/step")
859
 
860
  # add interesting config parameters
@@ -893,7 +893,7 @@ def main():
893
  eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
894
 
895
  # log metrics
896
- wandb_log(eval_metrics, step=state.step, prefix="eval")
897
 
898
  # Print metrics and update progress bar
899
  desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']})"
@@ -956,7 +956,7 @@ def main():
956
  )
957
  # save some space
958
  c = wandb.wandb_sdk.wandb_artifacts.get_artifacts_cache()
959
- c.cleanup(wandb.util.from_human_size("5GB"))
960
 
961
  wandb.run.log_artifact(artifact)
962
 
@@ -972,7 +972,8 @@ def main():
972
 
973
  for epoch in epochs:
974
  # ======================== Training ================================
975
- wandb_log({"train/epoch": epoch}, step=state.step)
 
976
 
977
  # Create sampling rng
978
  rng, input_rng = jax.random.split(rng)
@@ -994,19 +995,20 @@ def main():
994
  total=steps_per_epoch,
995
  ):
996
  state, train_metric = p_train_step(state, batch)
 
997
 
998
- if state.step % data_args.log_interval == 0 and jax.process_index() == 0:
999
  # log metrics
1000
- wandb_log(unreplicate(train_metric), step=state.step, prefix="train")
1001
 
1002
- if training_args.eval_steps and state.step % training_args.eval_steps == 0:
1003
  run_evaluation()
1004
 
1005
- if state.step % data_args.save_model_steps == 0:
1006
- run_save_model(state, state.step, epoch)
1007
 
1008
  # log final train metrics
1009
- wandb_log(unreplicate(train_metric), step=state.step, prefix="train")
1010
 
1011
  train_metric = unreplicate(train_metric)
1012
  epochs.write(
 
416
  f"{prefix}/{k}" if prefix is not None else k: v for k, v in metrics.items()
417
  }
418
  if step is not None:
419
+ log_metrics["train/step"] = step
420
  wandb.log(log_metrics)
421
 
422
 
 
846
  f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}"
847
  )
848
  logger.info(
849
+ f" Total train batch size (w. parallel, distributed & gradient accumulation) = {batch_size_per_update}"
850
  )
851
  logger.info(f" Total global steps = {total_steps}")
852
  logger.info(f" Total optimization steps = {total_optimization_steps}")
 
854
  epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
855
 
856
  # set default x-axis as 'train/step'
857
+ wandb_log({}, step=unreplicate(state.step))
858
  wandb.define_metric("*", step_metric="train/step")
859
 
860
  # add interesting config parameters
 
893
  eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
894
 
895
  # log metrics
896
+ wandb_log(eval_metrics, step=unreplicate(state.step), prefix="eval")
897
 
898
  # Print metrics and update progress bar
899
  desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']})"
 
956
  )
957
  # save some space
958
  c = wandb.wandb_sdk.wandb_artifacts.get_artifacts_cache()
959
+ c.cleanup("5GB")
960
 
961
  wandb.run.log_artifact(artifact)
962
 
 
972
 
973
  for epoch in epochs:
974
  # ======================== Training ================================
975
+ step = unreplicate(state.step)
976
+ wandb_log({"train/epoch": epoch}, step=step)
977
 
978
  # Create sampling rng
979
  rng, input_rng = jax.random.split(rng)
 
995
  total=steps_per_epoch,
996
  ):
997
  state, train_metric = p_train_step(state, batch)
998
+ step = unreplicate(state.step)
999
 
1000
+ if step % data_args.log_interval == 0 and jax.process_index() == 0:
1001
  # log metrics
1002
+ wandb_log(unreplicate(train_metric), step=step, prefix="train")
1003
 
1004
+ if training_args.eval_steps and step % training_args.eval_steps == 0:
1005
  run_evaluation()
1006
 
1007
+ if step % data_args.save_model_steps == 0:
1008
+ run_save_model(state, step, epoch)
1009
 
1010
  # log final train metrics
1011
+ wandb_log(unreplicate(train_metric), step=step, prefix="train")
1012
 
1013
  train_metric = unreplicate(train_metric)
1014
  epochs.write(