Spaces:
Running
Running
feat: no need for default values
Browse files
dev/seq2seq/run_seq2seq_flax.py
CHANGED
@@ -280,9 +280,9 @@ class DataTrainingArguments:
|
|
280 |
|
281 |
|
282 |
class TrainState(train_state.TrainState):
|
283 |
-
dropout_rng: jnp.ndarray
|
284 |
-
grad_accum: jnp.ndarray
|
285 |
-
optimizer_step: int
|
286 |
|
287 |
def replicate(self):
|
288 |
return jax_utils.replicate(self).replace(
|
|
|
280 |
|
281 |
|
282 |
class TrainState(train_state.TrainState):
|
283 |
+
dropout_rng: jnp.ndarray
|
284 |
+
grad_accum: jnp.ndarray
|
285 |
+
optimizer_step: int
|
286 |
|
287 |
def replicate(self):
|
288 |
return jax_utils.replicate(self).replace(
|