boris commited on
Commit
fdf7698
1 Parent(s): 15993e3

feat: load data first

Browse files
Files changed (1) hide show
  1. tools/train/train.py +3 -3
tools/train/train.py CHANGED
@@ -375,9 +375,6 @@ def main():
375
  datasets.utils.logging.set_verbosity_error()
376
  transformers.utils.logging.set_verbosity_error()
377
 
378
- logger.info(f"Local TPUs: {jax.local_device_count()}")
379
- assert jax.local_device_count() == 8, "TPUs in use, please check running processes"
380
-
381
  # Set the verbosity to info of the Transformers logger (on main process only):
382
  logger.info(f"Training/evaluation parameters {training_args}")
383
 
@@ -388,6 +385,9 @@ def main():
388
  do_eval=training_args.do_eval,
389
  )
390
 
 
 
 
391
  # Set up wandb run
392
  if jax.process_index() == 0:
393
  wandb.init(
 
375
  datasets.utils.logging.set_verbosity_error()
376
  transformers.utils.logging.set_verbosity_error()
377
 
 
 
 
378
  # Set the verbosity to info of the Transformers logger (on main process only):
379
  logger.info(f"Training/evaluation parameters {training_args}")
380
 
 
385
  do_eval=training_args.do_eval,
386
  )
387
 
388
+ logger.info(f"Local TPUs: {jax.local_device_count()}")
389
+ assert jax.local_device_count() == 8, "TPUs in use, please check running processes"
390
+
391
  # Set up wandb run
392
  if jax.process_index() == 0:
393
  wandb.init(