pere commited on
Commit
dfca229
2 Parent(s): 8751206 3e5f609
events.out.tfevents.1667932351.t1v-n-98890786-w-1.1781335.0.v2 DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:ee2a2884488898757efca40de1cb207cf04f9195b44731a735abe0d026c0f1ca
3
- size 40
 
 
 
 
run_mlm_flax_stream.py CHANGED
@@ -450,7 +450,11 @@ if __name__ == "__main__":
450
  # Store some constant
451
  num_epochs = int(training_args.num_train_epochs)
452
  train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count() * jax.process_count()
 
453
  eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count() * jax_process_count()
 
 
 
454
 
455
  # define number steps per stream epoch
456
  num_train_steps = data_args.num_train_steps
 
450
  # Store some constant
451
  num_epochs = int(training_args.num_train_epochs)
452
  train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count() * jax.process_count()
453
+ <<<<<<< HEAD
454
  eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count() * jax_process_count()
455
+ =======
456
+ eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count() * jax.process_count()
457
+ >>>>>>> 3e5f60917796ec70bad7ad043cfcb5559f582cd1
458
 
459
  # define number steps per stream epoch
460
  num_train_steps = data_args.num_train_steps