pere commited on
Commit
6bece68
1 Parent(s): dfca229
Files changed (1) hide show
  1. run_mlm_flax_stream.py +0 -4
run_mlm_flax_stream.py CHANGED
@@ -450,11 +450,7 @@ if __name__ == "__main__":
450
  # Store some constant
451
  num_epochs = int(training_args.num_train_epochs)
452
  train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count() * jax.process_count()
453
- <<<<<<< HEAD
454
- eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count() * jax_process_count()
455
- =======
456
  eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count() * jax.process_count()
457
- >>>>>>> 3e5f60917796ec70bad7ad043cfcb5559f582cd1
458
 
459
  # define number steps per stream epoch
460
  num_train_steps = data_args.num_train_steps
 
450
  # Store some constant
451
  num_epochs = int(training_args.num_train_epochs)
452
  train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count() * jax.process_count()
 
 
 
453
  eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count() * jax.process_count()
 
454
 
455
  # define number steps per stream epoch
456
  num_train_steps = data_args.num_train_steps