boris commited on
Commit
87fed1b
1 Parent(s): 3f13951

fix: check local TPU instances only

Browse files
Files changed (1) hide show
  1. tools/train/train.py +1 -1
tools/train/train.py CHANGED
@@ -376,7 +376,7 @@ def main():
376
  transformers.utils.logging.set_verbosity_error()
377
 
378
  logger.info(f"TPUs: {jax.device_count()}")
379
- assert jax.device_count() == 8, "TPUs in use, please check running processes"
380
 
381
  # Set the verbosity to info of the Transformers logger (on main process only):
382
  logger.info(f"Training/evaluation parameters {training_args}")
 
376
  transformers.utils.logging.set_verbosity_error()
377
 
378
  logger.info(f"TPUs: {jax.device_count()}")
379
+ assert jax.local_device_count() == 8, "TPUs in use, please check running processes"
380
 
381
  # Set the verbosity to info of the Transformers logger (on main process only):
382
  logger.info(f"Training/evaluation parameters {training_args}")