pere commited on
Commit
a07d183
2 Parent(s): 2d0b847 6bece68
events.out.tfevents.1667932958.t1v-n-98890786-w-2.1496415.0.v2 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:148055659b4645e86d9078ef467300ea5d08d7889d96ea2ef876fada84c50b90
3
+ size 40
run_mlm_flax_stream.py CHANGED
@@ -450,7 +450,7 @@ if __name__ == "__main__":
450
  # Store some constant
451
  num_epochs = int(training_args.num_train_epochs)
452
  train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count() * jax.process_count()
453
- eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count() * jax_process_count()
454
 
455
  # define number steps per stream epoch
456
  num_train_steps = data_args.num_train_steps
 
450
  # Store some constant
451
  num_epochs = int(training_args.num_train_epochs)
452
  train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count() * jax.process_count()
453
+ eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count() * jax.process_count()
454
 
455
  # define number steps per stream epoch
456
  num_train_steps = data_args.num_train_steps