Spaces:
Running
Running
fix: use correct key
Browse files
seq2seq/run_seq2seq_flax.py
CHANGED
@@ -338,7 +338,7 @@ def create_learning_rate_fn(
|
|
338 |
|
339 |
def wandb_log(metrics, step=None, prefix=None):
|
340 |
if jax.process_index() == 0:
|
341 |
-
log_metrics = {f'{prefix}/k' if prefix is not None else k: jax.device_get(v) for k,v in metrics.items()}
|
342 |
if step is not None:
|
343 |
log_metrics = {**log_metrics, 'train/step': step}
|
344 |
wandb.log(log_metrics)
|
|
|
338 |
|
339 |
def wandb_log(metrics, step=None, prefix=None):
|
340 |
if jax.process_index() == 0:
|
341 |
+
log_metrics = {f'{prefix}/{k}' if prefix is not None else k: jax.device_get(v) for k,v in metrics.items()}
|
342 |
if step is not None:
|
343 |
log_metrics = {**log_metrics, 'train/step': step}
|
344 |
wandb.log(log_metrics)
|