boris commited on
Commit
2d212d8
1 Parent(s): df1fe19

feat(train): different rng per node

Browse files
Files changed (1) hide show
  1. tools/train/train.py +2 -0
tools/train/train.py CHANGED
@@ -727,6 +727,8 @@ def main():
727
  # Define gradient update step fn
728
  def train_step(state, batch, delta_time):
729
  dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
 
 
730
 
731
  def compute_loss(params, minibatch):
732
  labels = minibatch.pop("labels")
 
727
  # Define gradient update step fn
728
  def train_step(state, batch, delta_time):
729
  dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
730
+ # use a different rng per node
731
+ dropout_rng = jax.random.fold_in(dropout_rng, jax.process_index())
732
 
733
  def compute_loss(params, minibatch):
734
  labels = minibatch.pop("labels")