boris commited on
Commit
a37cd75
1 Parent(s): 7253e56

feat: no need for default values

Browse files
Files changed (1) hide show
  1. dev/seq2seq/run_seq2seq_flax.py +3 -3
dev/seq2seq/run_seq2seq_flax.py CHANGED
@@ -280,9 +280,9 @@ class DataTrainingArguments:
280
 
281
 
282
  class TrainState(train_state.TrainState):
283
- dropout_rng: jnp.ndarray = None
284
- grad_accum: jnp.ndarray = None
285
- optimizer_step: int = None
286
 
287
  def replicate(self):
288
  return jax_utils.replicate(self).replace(
 
280
 
281
 
282
  class TrainState(train_state.TrainState):
283
+ dropout_rng: jnp.ndarray
284
+ grad_accum: jnp.ndarray
285
+ optimizer_step: int
286
 
287
  def replicate(self):
288
  return jax_utils.replicate(self).replace(