boris commited on
Commit
b20769d
1 Parent(s): 3fef9c1

fix: use correct key

Browse files
Files changed (1) hide show
  1. seq2seq/run_seq2seq_flax.py +1 -1
seq2seq/run_seq2seq_flax.py CHANGED
@@ -338,7 +338,7 @@ def create_learning_rate_fn(
338
 
339
  def wandb_log(metrics, step=None, prefix=None):
340
  if jax.process_index() == 0:
341
- log_metrics = {f'{prefix}/k' if prefix is not None else k: jax.device_get(v) for k,v in metrics.items()}
342
  if step is not None:
343
  log_metrics = {**log_metrics, 'train/step': step}
344
  wandb.log(log_metrics)
 
338
 
339
  def wandb_log(metrics, step=None, prefix=None):
340
  if jax.process_index() == 0:
341
+ log_metrics = {f'{prefix}/{k}' if prefix is not None else k: jax.device_get(v) for k,v in metrics.items()}
342
  if step is not None:
343
  log_metrics = {**log_metrics, 'train/step': step}
344
  wandb.log(log_metrics)