boris committed
Commit: 274ba73
Parent: 2d07559

refactor(train): cleanup

Files changed (1)
  1. tools/train/train.py +51 -31
tools/train/train.py CHANGED
@@ -310,12 +310,40 @@ class TrainingArguments:
         metadata={"help": "Reference to a wandb artifact for resuming training."},
     )
 
+    wandb_entity: Optional[str] = field(
+        default=None,
+        metadata={"help": "The wandb entity to use (for teams)."},
+    )
+    wandb_project: str = field(
+        default="dalle-mini",
+        metadata={"help": "The name of the wandb project."},
+    )
+    wandb_job_type: str = field(
+        default="Seq2Seq",
+        metadata={"help": "The name of the wandb job type."},
+    )
+
+    assert_TPU_available: bool = field(
+        default=False,
+        metadata={"help": "Verify that TPU is not in use."},
+    )
+
     def __post_init__(self):
         assert self.optim in [
             "distributed_shampoo",
             "adam",
             "adafactor",
         ], f"Selected optimizer not supported: {self.optim}"
+        if (
+            os.path.exists(self.output_dir)
+            and os.listdir(self.output_dir)
+            and self.do_train
+            and not self.overwrite_output_dir
+        ):
+            raise ValueError(
+                f"Output directory ({training_args.output_dir}) already exists and is not empty."
+                "Use --overwrite_output_dir to overcome."
+            )
 
 
 class TrainState(train_state.TrainState):
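
Note: the new wandb_entity, wandb_project, wandb_job_type and assert_TPU_available fields are ordinary dataclass fields, so HfArgumentParser exposes them as command-line flags with the defaults shown above. The following is a minimal, self-contained sketch of that pattern for illustration only; WandbArgs and the script name are hypothetical stand-ins, not the project's actual TrainingArguments.

    from dataclasses import dataclass, field
    from typing import Optional

    from transformers import HfArgumentParser


    @dataclass
    class WandbArgs:
        # Illustrative subset of the fields added in this commit.
        wandb_entity: Optional[str] = field(
            default=None, metadata={"help": "The wandb entity to use (for teams)."}
        )
        wandb_project: str = field(
            default="dalle-mini", metadata={"help": "The name of the wandb project."}
        )
        wandb_job_type: str = field(
            default="Seq2Seq", metadata={"help": "The name of the wandb job type."}
        )
        assert_TPU_available: bool = field(
            default=False, metadata={"help": "Verify that TPU is not in use."}
        )


    if __name__ == "__main__":
        # e.g. python wandb_args_sketch.py --wandb_project my-project --assert_TPU_available
        (args,) = HfArgumentParser(WandbArgs).parse_args_into_dataclasses()
        print(args)
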
@@ -396,17 +424,6 @@ def main():
     else:
         model_args, data_args, training_args = parser.parse_args_into_dataclasses()
 
-    if (
-        os.path.exists(training_args.output_dir)
-        and os.listdir(training_args.output_dir)
-        and training_args.do_train
-        and not training_args.overwrite_output_dir
-    ):
-        raise ValueError(
-            f"Output directory ({training_args.output_dir}) already exists and is not empty."
-            "Use --overwrite_output_dir to overcome."
-        )
-
     # Make one log on every process with the configuration for debugging.
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -433,14 +450,18 @@ def main():
     )
 
     logger.info(f"Local TPUs: {jax.local_device_count()}")
-    assert jax.local_device_count() == 8, "TPUs in use, please check running processes"
+    logger.info(f"Global TPUs: {jax.device_count()}")
+    if training_args.assert_TPU_available:
+        assert (
+            jax.local_device_count() == 8
+        ), "TPUs in use, please check running processes"
 
     # Set up wandb run
     if jax.process_index() == 0:
         wandb.init(
-            entity="dalle-mini",
-            project="dalle-mini",
-            job_type="Seq2Seq",
+            entity=training_args.wandb_entity,
+            project=training_args.wandb_project,
+            job_type=training_args.wandb_job_type,
             config=parser.parse_args(),
         )
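
Note: the wandb run metadata now comes from the new training arguments instead of hard-coded values. Below is a minimal sketch of the resulting call with the default flag values; it assumes wandb is installed and either logged in or run with WANDB_MODE=offline, and the config dict is only a placeholder for parser.parse_args().

    import wandb

    run = wandb.init(
        entity=None,            # personal account unless --wandb_entity is given
        project="dalle-mini",   # --wandb_project default
        job_type="Seq2Seq",     # --wandb_job_type default
        config={"note": "placeholder for parser.parse_args()"},
    )
    run.finish()
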
 
@@ -520,17 +541,14 @@ def main():
     train_batch_size = (
         training_args.per_device_train_batch_size * jax.local_device_count()
     )
-    batch_size_per_update = (
-        train_batch_size
-        * training_args.gradient_accumulation_steps
-        * jax.process_count()
-    )
+    batch_size_per_node = train_batch_size * training_args.gradient_accumulation_steps
+    batch_size_per_step = batch_size_per_node * jax.process_count()
     eval_batch_size = (
         training_args.per_device_eval_batch_size * jax.local_device_count()
     )
     len_train_dataset, len_eval_dataset = dataset.length
     steps_per_epoch = (
-        len_train_dataset // (train_batch_size * jax.process_count())
+        len_train_dataset // batch_size_per_node
         if len_train_dataset is not None
         else None
     )
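
Note: the old batch_size_per_update is split into batch_size_per_node (samples one host consumes per optimizer update) and batch_size_per_step (samples per update across all hosts). A worked example with assumed values, purely for illustration:

    # Assumed values: 8 local devices, per-device batch 56, 4 accumulation steps, 2 hosts.
    per_device_train_batch_size = 56
    local_device_count = 8            # jax.local_device_count()
    gradient_accumulation_steps = 4
    process_count = 2                 # jax.process_count()

    train_batch_size = per_device_train_batch_size * local_device_count    # 448 per host, per step
    batch_size_per_node = train_batch_size * gradient_accumulation_steps   # 1792 per host, per update
    batch_size_per_step = batch_size_per_node * process_count              # 3584 per update, all hosts
    print(train_batch_size, batch_size_per_node, batch_size_per_step)      # 448 1792 3584
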
@@ -708,14 +726,12 @@ def main():
             grads=grads,
             dropout_rng=new_dropout_rng,
             train_time=state.train_time + delta_time,
-            train_samples=state.train_samples + train_batch_size * jax.process_count(),
+            train_samples=state.train_samples + batch_size_per_step,
         )
 
         metrics = {
             "loss": loss,
-            "learning_rate": learning_rate_fn(
-                state.step // training_args.gradient_accumulation_steps
-            ),
+            "learning_rate": learning_rate_fn(state.step),
         }
         metrics = jax.lax.pmean(metrics, axis_name="batch")
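
Note: the logged learning rate is now simply the schedule evaluated at state.step, without dividing by gradient_accumulation_steps. A minimal sketch, using a hypothetical optax schedule in place of the script's learning_rate_fn:

    import optax

    # Hypothetical schedule; the real learning_rate_fn is built elsewhere in train.py.
    learning_rate_fn = optax.warmup_cosine_decay_schedule(
        init_value=0.0, peak_value=5e-3, warmup_steps=1_000, decay_steps=10_000
    )

    step = 1_500  # assumed current value of state.step
    print(float(learning_rate_fn(step)))  # the value logged as "learning_rate"
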
 
@@ -733,19 +749,20 @@ def main():
         return metrics
 
     # Create parallel version of the train and eval step
-    p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
-    p_eval_step = jax.pmap(eval_step, "batch")
+    p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0, 1))
+    p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(1,))
 
     logger.info("***** Running training *****")
     logger.info(f" Num examples = {len_train_dataset}")
     logger.info(f" Num Epochs = {num_epochs}")
     logger.info(
-        f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}"
+        f" Batch size per device = {training_args.per_device_train_batch_size}"
     )
     logger.info(f" Number of devices = {jax.device_count()}")
     logger.info(
-        f" Total train batch size (w. parallel, distributed & gradient accumulation) = {batch_size_per_update}"
+        f" Gradient accumulation steps = {training_args.gradient_accumulation_steps}"
     )
+    logger.info(f" Batch size per update = {batch_size_per_step}")
     logger.info(f" Model parameters = {num_params:,}")
     epochs = tqdm(
         range(state.epoch, num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0
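
Note: donate_argnums now donates both the train state and the batch to the compiled step, which lets XLA reuse their device buffers for the outputs and lowers peak memory. A toy illustration of the mechanism; step below is not the project's train_step, and on backends without donation support JAX only emits a warning and ignores the hint.

    import jax
    import jax.numpy as jnp


    def step(state, batch):
        # Toy update: treat state as one parameter vector per device.
        return state - 0.1 * batch.mean(axis=0)


    # Donating argument 0 lets XLA reuse state's buffers for the output.
    p_step = jax.pmap(step, axis_name="batch", donate_argnums=(0,))

    n = jax.local_device_count()
    state = jnp.zeros((n, 4))
    batch = jnp.ones((n, 8, 4))
    state = p_step(state, batch)  # safe: the donated (old) state is rebound immediately
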
@@ -762,8 +779,9 @@ def main():
             {
                 "len_train_dataset": len_train_dataset,
                 "len_eval_dataset": len_eval_dataset,
-                "batch_size_per_update": batch_size_per_update,
+                "batch_size_per_step": batch_size_per_step,
                 "num_params": num_params,
+                "num_devices": jax.device_count(),
             }
         )
 
@@ -774,7 +792,9 @@ def main():
         # ======================== Evaluating ==============================
         eval_metrics = []
         if training_args.do_eval:
-            eval_loader = dataset.dataloader("eval", eval_batch_size)
+            eval_loader = dataset.dataloader(
+                "eval", training_args.per_device_eval_batch_size
+            )
             eval_steps = (
                 len_eval_dataset // eval_batch_size
                 if len_eval_dataset is not None
 