acul3 committed on
Commit
44e5f40
1 Parent(s): a1ae28e

add support for v3-32

Browse files
Files changed (1) hide show
  1. run_mlm_flax_stream.py +10 -1
run_mlm_flax_stream.py CHANGED
@@ -551,6 +551,10 @@ if __name__ == "__main__":
551
  # define number steps per stream epoch
552
  num_train_steps = data_args.num_train_steps
553
 
 
 
 
 
554
  # Create learning rate schedule
555
  warmup_fn = optax.linear_schedule(
556
  init_value=0.0, end_value=training_args.learning_rate, transition_steps=training_args.warmup_steps
@@ -714,8 +718,13 @@ if __name__ == "__main__":
714
  # process input samples
715
  model_inputs = data_collator(samples, pad_to_multiple_of=16)
716
 
 
 
 
 
 
717
  # Model forward
718
- model_inputs = shard(model_inputs.data)
719
  state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
720
 
721
  train_metrics.append(train_metric)
 
551
  # define number steps per stream epoch
552
  num_train_steps = data_args.num_train_steps
553
 
554
+ num_of_hosts = jax.process_count()
555
+ current_host_idx = jax.process_index()
556
+
557
+
558
  # Create learning rate schedule
559
  warmup_fn = optax.linear_schedule(
560
  init_value=0.0, end_value=training_args.learning_rate, transition_steps=training_args.warmup_steps
 
718
  # process input samples
719
  model_inputs = data_collator(samples, pad_to_multiple_of=16)
720
 
721
+ local_host_model_inputs = {
722
+ key: np.split(model_inputs.data[key], num_of_hosts, axis=0)[current_host_idx]
723
+ for key, value in model_inputs.data.items()
724
+ }
725
+
726
  # Model forward
727
+ model_inputs = shard(local_host_model_inputs)
728
  state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
729
 
730
  train_metrics.append(train_metric)