boris committed on
Commit 7cfe576
1 Parent(s): 5996680

feat: log num_parameters early

Files changed (1)
  1. tools/train/train.py +32 -31
tools/train/train.py CHANGED
@@ -558,6 +558,35 @@ def main():
     )
     num_params = model.num_params
 
+    logger.info("***** Running training *****")
+    logger.info(f" Num examples = {len_train_dataset}")
+    logger.info(f" Num Epochs = {num_epochs}")
+    logger.info(
+        f" Batch size per device = {training_args.per_device_train_batch_size}"
+    )
+    logger.info(f" Number of devices = {jax.device_count()}")
+    logger.info(
+        f" Gradient accumulation steps = {training_args.gradient_accumulation_steps}"
+    )
+    logger.info(f" Batch size per update = {batch_size_per_step}")
+    logger.info(f" Model parameters = {num_params:,}")
+
+    # create wandb run
+    if jax.process_index() == 0:
+        # set default x-axis as 'train/step'
+        wandb.define_metric("*", step_metric="train/step")
+
+        # add interesting config parameters
+        wandb.config.update(
+            {
+                "len_train_dataset": len_train_dataset,
+                "len_eval_dataset": len_eval_dataset,
+                "batch_size_per_step": batch_size_per_step,
+                "num_params": num_params,
+                "num_devices": jax.device_count(),
+            }
+        )
+
     # Create learning rate schedule
     def create_learning_rate_fn() -> Callable[[int], jnp.array]:
         """Create the learning rate function."""
@@ -915,42 +944,14 @@ def main():
         out_axis_resources=None,
     )
 
-    logger.info("***** Running training *****")
-    logger.info(f" Num examples = {len_train_dataset}")
-    logger.info(f" Num Epochs = {num_epochs}")
-    logger.info(
-        f" Batch size per device = {training_args.per_device_train_batch_size}"
-    )
-    logger.info(f" Number of devices = {jax.device_count()}")
-    logger.info(
-        f" Gradient accumulation steps = {training_args.gradient_accumulation_steps}"
-    )
-    logger.info(f" Batch size per update = {batch_size_per_step}")
-    logger.info(f" Model parameters = {num_params:,}")
-    epochs = tqdm(
-        range(state.epoch, num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0
-    )
-
     # init variables
     last_time = time.perf_counter()
     train_metrics = None
     step = int(state.step)
     metrics_logger = MetricsLogger(step)
-
-    if jax.process_index() == 0:
-        # set default x-axis as 'train/step'
-        wandb.define_metric("*", step_metric="train/step")
-
-        # add interesting config parameters
-        wandb.config.update(
-            {
-                "len_train_dataset": len_train_dataset,
-                "len_eval_dataset": len_eval_dataset,
-                "batch_size_per_step": batch_size_per_step,
-                "num_params": num_params,
-                "num_devices": jax.device_count(),
-            }
-        )
+    epochs = tqdm(
+        range(state.epoch, num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0
+    )
 
     def run_evaluation():
         # ======================== Evaluating ==============================
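For context, `model.num_params` is defined elsewhere in this codebase; the commit only logs it earlier. A minimal sketch of how such a count can be computed for a JAX/Flax params pytree (the `count_params` helper below is illustrative, not the repo's actual implementation):

# Illustrative sketch only: count parameters in a JAX/Flax params pytree
# by summing the number of elements in every leaf array.
import jax
import jax.numpy as jnp

def count_params(params) -> int:
    # tree_leaves flattens the nested dict of arrays into a list of leaves.
    return sum(leaf.size for leaf in jax.tree_util.tree_leaves(params))

# Tiny example pytree standing in for real model parameters.
params = {"dense": {"kernel": jnp.zeros((4, 8)), "bias": jnp.zeros((8,))}}
print(f"Model parameters = {count_params(params):,}")  # -> 40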
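The wandb block this diff moves follows a standard pattern: `wandb.define_metric("*", step_metric="train/step")` makes every logged metric plot against `train/step` by default, and `wandb.config.update(...)` attaches static metadata to the run. A minimal standalone sketch of that pattern (the project name and config values here are hypothetical placeholders; under multi-host JAX only process 0 should talk to wandb, which is what the diff's `jax.process_index() == 0` guard ensures):

# Minimal sketch of the wandb pattern used in the diff; assumes a run is
# created with wandb.init (project name and values are placeholders).
import wandb

wandb.init(project="demo-project")

# Plot all metrics against 'train/step' instead of wandb's internal step.
wandb.define_metric("*", step_metric="train/step")

# Record static run configuration alongside the metrics.
wandb.config.update({"num_params": 1_000_000, "num_devices": 8})

# Each log call then includes the custom step metric explicitly.
wandb.log({"train/step": 1, "train/loss": 2.5})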