test
Browse files- run_mlm_flax_stream.py +0 -4
run_mlm_flax_stream.py
CHANGED
@@ -450,11 +450,7 @@ if __name__ == "__main__":
|
|
450 |
# Store some constant
|
451 |
num_epochs = int(training_args.num_train_epochs)
|
452 |
train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count() * jax.process_count()
|
453 |
-
<<<<<<< HEAD
|
454 |
-
eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count() * jax_process_count()
|
455 |
-
=======
|
456 |
eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count() * jax.process_count()
|
457 |
-
>>>>>>> 3e5f60917796ec70bad7ad043cfcb5559f582cd1
|
458 |
|
459 |
# define number steps per stream epoch
|
460 |
num_train_steps = data_args.num_train_steps
|
|
|
450 |
# Store some constant
|
451 |
num_epochs = int(training_args.num_train_epochs)
|
452 |
train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count() * jax.process_count()
|
|
|
|
|
|
|
453 |
eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count() * jax.process_count()
|
|
|
454 |
|
455 |
# define number steps per stream epoch
|
456 |
num_train_steps = data_args.num_train_steps
|