boris commited on
Commit
833a2d5
1 Parent(s): b66b951
Files changed (1) hide show
  1. seq2seq/run_seq2seq_flax.py +2 -2
seq2seq/run_seq2seq_flax.py CHANGED
@@ -216,7 +216,7 @@ class DataTrainingArguments:
216
  default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
217
  )
218
  log_interval: Optional[int] = field(
219
- default=500,
220
  metadata={
221
  "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
222
  "value if set."
@@ -753,7 +753,7 @@ def main():
753
 
754
  if global_step % data_args.log_interval == 0 and jax.process_index() == 0:
755
  for k, v in unreplicate(train_metric).items():
756
- wandb.log(f{'train/{k}': jax.device_get(v), step=global_step)
757
 
758
  train_time += time.time() - train_start
759
 
 
216
  default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
217
  )
218
  log_interval: Optional[int] = field(
219
+ default=5,
220
  metadata={
221
  "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
222
  "value if set."
 
753
 
754
  if global_step % data_args.log_interval == 0 and jax.process_index() == 0:
755
  for k, v in unreplicate(train_metric).items():
756
+ wandb.log(f{'train/{k}': jax.device_get(v)}, step=global_step)
757
 
758
  train_time += time.time() - train_start
759