pere commited on
Commit
dfcb939
1 Parent(s): 7d84292

Update run_mlm_flax_stream.py

Browse files
Files changed (1) hide show
  1. run_mlm_flax_stream.py +2 -2
run_mlm_flax_stream.py CHANGED
@@ -449,8 +449,8 @@ if __name__ == "__main__":
449
 
450
  # Store some constant
451
  num_epochs = int(training_args.num_train_epochs)
452
- train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count() * jax.process_count()
453
- eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count() * jax.process_count()
454
 
455
  # define number steps per stream epoch
456
  num_train_steps = data_args.num_train_steps
449
 
450
  # Store some constant
451
  num_epochs = int(training_args.num_train_epochs)
452
+ train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
453
+ eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
454
 
455
  # define number steps per stream epoch
456
  num_train_steps = data_args.num_train_steps