boris commited on
Commit
9ed6378
1 Parent(s): 061c06b

feat: update defaults

Browse files
Files changed (1) hide show
  1. dev/seq2seq/run_seq2seq_flax.py +4 -4
dev/seq2seq/run_seq2seq_flax.py CHANGED
@@ -253,7 +253,7 @@ class DataTrainingArguments:
253
  metadata={"help": "Overwrite the cached training and evaluation sets"},
254
  )
255
  save_model_steps: Optional[int] = field(
256
- default=3000, # about once every hour in our experiments
257
  metadata={
258
  "help": "For logging the model more frequently. Used only when `log_model` is set."
259
  },
@@ -290,9 +290,9 @@ class DataTrainingArguments:
290
 
291
 
292
  class TrainState(train_state.TrainState):
293
- dropout_rng: jnp.ndarray
294
- grad_accum: jnp.ndarray
295
- optimizer_step: int
296
 
297
  def replicate(self):
298
  return jax_utils.replicate(self).replace(
 
253
  metadata={"help": "Overwrite the cached training and evaluation sets"},
254
  )
255
  save_model_steps: Optional[int] = field(
256
+ default=5000, # about once every 1.5h in our experiments
257
  metadata={
258
  "help": "For logging the model more frequently. Used only when `log_model` is set."
259
  },
 
290
 
291
 
292
  class TrainState(train_state.TrainState):
293
+ dropout_rng: jnp.ndarray = None
294
+ grad_accum: jnp.ndarray = None
295
+ optimizer_step: int = None
296
 
297
  def replicate(self):
298
  return jax_utils.replicate(self).replace(