pere commited on
Commit
fbc1fb5
1 Parent(s): e249bce
events.out.tfevents.1667934713.t1v-n-98890786-w-0.1515223.0.v2 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3be7414f5d7c27bbf6cb1a7f2c46ddb7ce0631e60a6a60eed140fd0ab4103ed0
3
+ size 40
run_mlm_flax_stream.py CHANGED
@@ -450,7 +450,7 @@ if __name__ == "__main__":
450
  # Store some constant
451
  num_epochs = int(training_args.num_train_epochs)
452
  train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count() * jax.process_count()
453
- eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count() * jax_process_count()
454
 
455
  # define number steps per stream epoch
456
  num_train_steps = data_args.num_train_steps
 
450
  # Store some constant
451
  num_epochs = int(training_args.num_train_epochs)
452
  train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count() * jax.process_count()
453
+ eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count() * jax.process_count()
454
 
455
  # define number steps per stream epoch
456
  num_train_steps = data_args.num_train_steps