test
Browse files
events.out.tfevents.1667932958.t1v-n-98890786-w-2.1496415.0.v2
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:148055659b4645e86d9078ef467300ea5d08d7889d96ea2ef876fada84c50b90
|
3 |
+
size 40
|
run_mlm_flax_stream.py
CHANGED
@@ -450,7 +450,7 @@ if __name__ == "__main__":
|
|
450 |
# Store some constant
|
451 |
num_epochs = int(training_args.num_train_epochs)
|
452 |
train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count() * jax.process_count()
|
453 |
-
eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count() *
|
454 |
|
455 |
# define number steps per stream epoch
|
456 |
num_train_steps = data_args.num_train_steps
|
|
|
450 |
# Store some constant
|
451 |
num_epochs = int(training_args.num_train_epochs)
|
452 |
train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count() * jax.process_count()
|
453 |
+
eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count() * jax.process_count()
|
454 |
|
455 |
# define number steps per stream epoch
|
456 |
num_train_steps = data_args.num_train_steps
|