test
Browse files
events.out.tfevents.1667934713.t1v-n-98890786-w-0.1515223.0.v2
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3be7414f5d7c27bbf6cb1a7f2c46ddb7ce0631e60a6a60eed140fd0ab4103ed0
|
3 |
+
size 40
|
run_mlm_flax_stream.py
CHANGED
@@ -450,7 +450,7 @@ if __name__ == "__main__":
|
|
450 |
# Store some constant
|
451 |
num_epochs = int(training_args.num_train_epochs)
|
452 |
train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count() * jax.process_count()
|
453 |
-
eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count() *
|
454 |
|
455 |
# define number steps per stream epoch
|
456 |
num_train_steps = data_args.num_train_steps
|
|
|
450 |
# Store some constant
|
451 |
num_epochs = int(training_args.num_train_epochs)
|
452 |
train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count() * jax.process_count()
|
453 |
+
eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count() * jax.process_count()
|
454 |
|
455 |
# define number steps per stream epoch
|
456 |
num_train_steps = data_args.num_train_steps
|