boris commited on
Commit
8ba598c
1 Parent(s): 06f1345

fix: wandb logging with sync_tensorboard

Browse files
Files changed (1) hide show
  1. seq2seq/run_seq2seq_flax.py +2 -1
seq2seq/run_seq2seq_flax.py CHANGED
@@ -755,7 +755,8 @@ def main():
755
 
756
  if global_step % data_args.log_interval == 0 and jax.process_index() == 0:
757
  for k, v in unreplicate(train_metric).items():
758
- wandb.log({f"train/{k}": jax.device_get(v)}, step=global_step)
 
759
 
760
  train_time += time.time() - train_start
761
 
 
755
 
756
  if global_step % data_args.log_interval == 0 and jax.process_index() == 0:
757
  for k, v in unreplicate(train_metric).items():
758
+ wandb.log({"train/step": global_step})
759
+ wandb.log({f"train/{k}": jax.device_get(v)})
760
 
761
  train_time += time.time() - train_start
762