Spaces:
Running
Running
fix: OOM
Browse files
dev/seq2seq/run_seq2seq_flax.py
CHANGED
@@ -801,8 +801,8 @@ def main():
|
|
801 |
p_eval_step = jax.pmap(eval_step, "batch")
|
802 |
|
803 |
# Replicate the train state on each device
|
804 |
-
state = state.replicate()
|
805 |
del model._params
|
|
|
806 |
|
807 |
logger.info("***** Running training *****")
|
808 |
logger.info(f" Num examples = {len_train_dataset}")
|
|
|
801 |
p_eval_step = jax.pmap(eval_step, "batch")
|
802 |
|
803 |
# Replicate the train state on each device
|
|
|
804 |
del model._params
|
805 |
+
state = state.replicate()
|
806 |
|
807 |
logger.info("***** Running training *****")
|
808 |
logger.info(f" Num examples = {len_train_dataset}")
|