boris commited on
Commit
86c6c90
1 Parent(s): e2400cc
Files changed (1) hide show
  1. dev/seq2seq/run_seq2seq_flax.py +1 -1
dev/seq2seq/run_seq2seq_flax.py CHANGED
@@ -801,8 +801,8 @@ def main():
801
  p_eval_step = jax.pmap(eval_step, "batch")
802
 
803
  # Replicate the train state on each device
804
- state = state.replicate()
805
  del model._params
 
806
 
807
  logger.info("***** Running training *****")
808
  logger.info(f" Num examples = {len_train_dataset}")
 
801
  p_eval_step = jax.pmap(eval_step, "batch")
802
 
803
  # Replicate the train state on each device
 
804
  del model._params
805
+ state = state.replicate()
806
 
807
  logger.info("***** Running training *****")
808
  logger.info(f" Num examples = {len_train_dataset}")