add support for v3-32

run_mlm_flax_stream.py (+10 -1)

@@ -551,6 +551,10 @@ if __name__ == "__main__":
     # define number steps per stream epoch
     num_train_steps = data_args.num_train_steps
 
+    num_of_hosts = jax.process_count()
+    current_host_idx = jax.process_index()
+
+
     # Create learning rate schedule
     warmup_fn = optax.linear_schedule(
         init_value=0.0, end_value=training_args.learning_rate, transition_steps=training_args.warmup_steps
@@ -714,8 +718,13 @@ if __name__ == "__main__":
         # process input samples
         model_inputs = data_collator(samples, pad_to_multiple_of=16)
 
+        local_host_model_inputs = {
+            key: np.split(model_inputs.data[key], num_of_hosts, axis=0)[current_host_idx]
+            for key, value in model_inputs.data.items()
+        }
+
         # Model forward
-        model_inputs = shard(model_inputs.data)
+        model_inputs = shard(local_host_model_inputs)
         state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
 
         train_metrics.append(train_metric)
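
Note (not part of the commit): a TPU v3-32 pod spans 4 hosts with 8 local devices each. Every host runs the same input pipeline and collates the same global batch, so each host keeps only the slice at its jax.process_index() before shard() lays that slice out across its local devices for pmap. Below is a minimal, self-contained sketch of the pattern; the batch size (256), sequence length (128), and key names are hypothetical stand-ins for the collator output model_inputs.data.

import jax
import numpy as np
from flax.training.common_utils import shard

num_of_hosts = jax.process_count()      # 4 on a v3-32 pod, 1 on a single host
current_host_idx = jax.process_index()  # unique per host: 0 .. num_of_hosts - 1

# Hypothetical stand-in for `model_inputs.data`: a dict of numpy arrays
# with a leading global-batch axis.
global_batch = {
    "input_ids": np.zeros((256, 128), dtype=np.int32),
    "labels": np.zeros((256, 128), dtype=np.int32),
}

# Keep only this host's slice of the global batch
# (on v3-32: 256 / 4 = 64 examples per host).
local_host_batch = {
    key: np.split(value, num_of_hosts, axis=0)[current_host_idx]
    for key, value in global_batch.items()
}

# Reshape the host's slice across its local devices for pmap:
# the leading axis 64 becomes (8, 8), i.e. (local_device_count, per_device_batch).
local_device_batch = shard(local_host_batch)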
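Two small observations on the new code. First, np.split raises unless the leading batch axis divides evenly by the host count; that should hold here assuming the script sizes its train batch as per_device_train_batch_size * jax.device_count() (as the other Flax examples do), since jax.device_count() equals process_count() * local_device_count(). Second, the dict comprehension binds value without using it; iterating over the keys alone, or splitting value instead of re-indexing model_inputs.data[key], would be equivalent.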