boris commited on
Commit
3fef9c1
1 Parent(s): 4c5e5a7

fix: log correct metrics

Browse files
Files changed (1) hide show
  1. seq2seq/run_seq2seq_flax.py +5 -2
seq2seq/run_seq2seq_flax.py CHANGED
@@ -340,7 +340,7 @@ def wandb_log(metrics, step=None, prefix=None):
340
  if jax.process_index() == 0:
341
  log_metrics = {f'{prefix}/k' if prefix is not None else k: jax.device_get(v) for k,v in metrics.items()}
342
  if step is not None:
343
- log_metrics = {**metrics, 'train/step': step}
344
  wandb.log(log_metrics)
345
 
346
 
@@ -791,10 +791,13 @@ def main():
791
 
792
  if global_step % data_args.log_interval == 0 and jax.process_index() == 0:
793
  # log metrics
794
- wandb_log(unreplicate(train_metric), step=global_step, prefix='tran')
795
 
796
  if global_step % training_args.eval_steps == 0:
797
  run_evaluation()
 
 
 
798
 
799
  train_time += time.time() - train_start
800
  train_metric = unreplicate(train_metric)
 
340
  if jax.process_index() == 0:
341
  log_metrics = {f'{prefix}/k' if prefix is not None else k: jax.device_get(v) for k,v in metrics.items()}
342
  if step is not None:
343
+ log_metrics = {**log_metrics, 'train/step': step}
344
  wandb.log(log_metrics)
345
 
346
 
 
791
 
792
  if global_step % data_args.log_interval == 0 and jax.process_index() == 0:
793
  # log metrics
794
+ wandb_log(unreplicate(train_metric), step=global_step, prefix='train')
795
 
796
  if global_step % training_args.eval_steps == 0:
797
  run_evaluation()
798
+
799
+ # log final train metrics
800
+ wandb_log(unreplicate(train_metric), step=global_step, prefix='train')
801
 
802
  train_time += time.time() - train_start
803
  train_metric = unreplicate(train_metric)