m3hrdadfi commited on
Commit
3733ce3
1 Parent(s): 51b14d7

Add log info

Browse files
Files changed (1) hide show
  1. src/run_clm_flax.py +3 -1
src/run_clm_flax.py CHANGED
@@ -64,6 +64,8 @@ from data_utils import (
64
  normalizer
65
  )
66
 
 
 
67
  logger = logging.getLogger(__name__)
68
 
69
  # Cache the result
@@ -366,7 +368,7 @@ def main():
366
  # dataset = dataset.map(normalizer)
367
  # logger.info(f"Preprocessed dataset kept {len(dataset)} out of {len(raw_dataset)}")
368
  dataset = raw_dataset
369
-
370
  # Load pretrained model and tokenizer
371
 
372
  # Distributed training:
 
64
  normalizer
65
  )
66
 
67
+ print(jax.devices())
68
+
69
  logger = logging.getLogger(__name__)
70
 
71
  # Cache the result
 
368
  # dataset = dataset.map(normalizer)
369
  # logger.info(f"Preprocessed dataset kept {len(dataset)} out of {len(raw_dataset)}")
370
  dataset = raw_dataset
371
+
372
  # Load pretrained model and tokenizer
373
 
374
  # Distributed training: