test
Browse files
events.out.tfevents.1667932351.t1v-n-98890786-w-1.1781335.0.v2
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:ee2a2884488898757efca40de1cb207cf04f9195b44731a735abe0d026c0f1ca
|
3 |
-
size 40
|
|
|
|
|
|
|
|
run_mlm_flax_stream.py
CHANGED
@@ -450,7 +450,11 @@ if __name__ == "__main__":
|
|
450 |
# Store some constant
|
451 |
num_epochs = int(training_args.num_train_epochs)
|
452 |
train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count() * jax.process_count()
|
|
|
453 |
eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count() * jax_process_count()
|
|
|
|
|
|
|
454 |
|
455 |
# define number steps per stream epoch
|
456 |
num_train_steps = data_args.num_train_steps
|
|
|
450 |
# Store some constant
|
451 |
num_epochs = int(training_args.num_train_epochs)
|
452 |
train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count() * jax.process_count()
|
453 |
+
<<<<<<< HEAD
|
454 |
eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count() * jax_process_count()
|
455 |
+
=======
|
456 |
+
eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count() * jax.process_count()
|
457 |
+
>>>>>>> 3e5f60917796ec70bad7ad043cfcb5559f582cd1
|
458 |
|
459 |
# define number steps per stream epoch
|
460 |
num_train_steps = data_args.num_train_steps
|