boris commited on
Commit
5b533b5
·
1 Parent(s): ed93c8a

feat(train): handle multi-hosts

Browse files
Files changed (1) hide show
  1. tools/train/train.py +95 -71
tools/train/train.py CHANGED
@@ -389,15 +389,19 @@ def main():
389
  )
390
 
391
  # Set up wandb run
392
- wandb.init(
393
- entity="dalle-mini",
394
- project="dalle-mini",
395
- job_type="Seq2Seq",
396
- config=parser.parse_args(),
397
- )
 
398
 
399
  if training_args.resume_from_checkpoint is not None:
400
- artifact = wandb.run.use_artifact(training_args.resume_from_checkpoint)
 
 
 
401
  artifact_dir = artifact.download()
402
 
403
  # load model
@@ -462,14 +466,23 @@ def main():
462
 
463
  # Store some constant
464
  num_epochs = int(training_args.num_train_epochs)
 
465
  train_batch_size = (
466
- int(training_args.per_device_train_batch_size) * jax.device_count()
 
 
 
 
 
 
 
 
467
  )
468
- batch_size_per_update = train_batch_size * training_args.gradient_accumulation_steps
469
- eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
470
  len_train_dataset, len_eval_dataset = dataset.length
471
  steps_per_epoch = (
472
- len_train_dataset // train_batch_size if len_train_dataset is not None else None
 
 
473
  )
474
  num_train_steps = (
475
  steps_per_epoch * num_epochs if steps_per_epoch is not None else None
@@ -568,7 +581,7 @@ def main():
568
  grads=grads,
569
  dropout_rng=new_dropout_rng,
570
  train_time=state.train_time + delta_time,
571
- train_samples=state.train_samples + train_batch_size,
572
  )
573
 
574
  metrics = {
@@ -600,6 +613,7 @@ def main():
600
  logger.info(
601
  f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}"
602
  )
 
603
  logger.info(
604
  f" Total train batch size (w. parallel, distributed & gradient accumulation) = {batch_size_per_update}"
605
  )
@@ -608,19 +622,20 @@ def main():
608
  range(state.epoch, num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0
609
  )
610
 
611
- # set default x-axis as 'train/step'
612
- wandb_log({}, step=state.step)
613
- wandb.define_metric("*", step_metric="train/step")
614
-
615
- # add interesting config parameters
616
- wandb.config.update(
617
- {
618
- "len_train_dataset": len_train_dataset,
619
- "len_eval_dataset": len_eval_dataset,
620
- "batch_size_per_update": batch_size_per_update,
621
- "num_params": num_params,
622
- }
623
- )
 
624
 
625
  # replicate state on each device
626
  state = state.replicate()
@@ -688,52 +703,61 @@ def main():
688
  f,
689
  )
690
 
691
- # save to W&B
692
- if training_args.log_model:
693
- # save some space
694
- c = wandb.wandb_sdk.wandb_artifacts.get_artifacts_cache()
695
- c.cleanup(wandb.util.from_human_size("10GB"))
696
-
697
- metadata = dict(state_dict)
698
- metadata["num_params"] = num_params
699
- if eval_metrics is not None:
700
- metadata["eval"] = eval_metrics
701
- artifact = wandb.Artifact(
702
- name=f"model-{wandb.run.id}", type="bart_model", metadata=metadata
703
- )
704
- artifact.add_file(
705
- str(Path(training_args.output_dir) / "flax_model.msgpack")
706
- )
707
- artifact.add_file(str(Path(training_args.output_dir) / "config.json"))
708
- artifact.add_file(
709
- str(Path(training_args.output_dir) / "tokenizer.json")
710
- )
711
- artifact.add_file(
712
- str(Path(training_args.output_dir) / "tokenizer_config.json")
713
- )
714
- artifact.add_file(str(Path(training_args.output_dir) / "vocab.json"))
715
- artifact.add_file(str(Path(training_args.output_dir) / "merges.txt"))
716
- artifact.add_file(
717
- str(Path(training_args.output_dir) / "special_tokens_map.json")
718
- )
719
- artifact.add_file(
720
- str(Path(training_args.output_dir) / "opt_state.msgpack")
721
- )
722
- artifact.add_file(
723
- str(Path(training_args.output_dir) / "training_state.json")
724
- )
725
-
726
- wandb.run.log_artifact(artifact)
727
-
728
- # save to the hub
729
- if training_args.push_to_hub:
730
- model.save_pretrained(
731
- training_args.output_dir,
732
- params=params,
733
- push_to_hub=training_args.push_to_hub,
734
- commit_message=f"Saving weights and logs at step {unreplicate(state.step)+1}",
735
- temp_dir=True, # avoid issues with being in a repository
736
- )
 
 
 
 
 
 
 
 
 
737
 
738
  # init variables
739
  last_time = time.perf_counter()
 
389
  )
390
 
391
  # Set up wandb run
392
+ if jax.process_index() == 0:
393
+ wandb.init(
394
+ entity="dalle-mini",
395
+ project="dalle-mini",
396
+ job_type="Seq2Seq",
397
+ config=parser.parse_args(),
398
+ )
399
 
400
  if training_args.resume_from_checkpoint is not None:
401
+ if jax.process_index() == 0:
402
+ artifact = wandb.run.use_artifact(training_args.resume_from_checkpoint)
403
+ else:
404
+ artifact = wandb.Api().artifact(training_args.resume_from_checkpoint)
405
  artifact_dir = artifact.download()
406
 
407
  # load model
 
466
 
467
  # Store some constant
468
  num_epochs = int(training_args.num_train_epochs)
469
+ # batch size per node
470
  train_batch_size = (
471
+ int(training_args.per_device_train_batch_size) * jax.local_device_count()
472
+ )
473
+ batch_size_per_update = (
474
+ train_batch_size
475
+ * training_args.gradient_accumulation_steps
476
+ * jax.process_count()
477
+ )
478
+ eval_batch_size = (
479
+ int(training_args.per_device_eval_batch_size) * jax.local_device_count()
480
  )
 
 
481
  len_train_dataset, len_eval_dataset = dataset.length
482
  steps_per_epoch = (
483
+ len_train_dataset // (train_batch_size * jax.process_count())
484
+ if len_train_dataset is not None
485
+ else None
486
  )
487
  num_train_steps = (
488
  steps_per_epoch * num_epochs if steps_per_epoch is not None else None
 
581
  grads=grads,
582
  dropout_rng=new_dropout_rng,
583
  train_time=state.train_time + delta_time,
584
+ train_samples=state.train_samples + train_batch_size * jax.process_count(),
585
  )
586
 
587
  metrics = {
 
613
  logger.info(
614
  f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}"
615
  )
616
+ logger.info(f" Number of devices = {jax.device_count()}")
617
  logger.info(
618
  f" Total train batch size (w. parallel, distributed & gradient accumulation) = {batch_size_per_update}"
619
  )
 
622
  range(state.epoch, num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0
623
  )
624
 
625
+ if jax.process_index() == 0:
626
+ # set default x-axis as 'train/step'
627
+ wandb_log({}, step=state.step)
628
+ wandb.define_metric("*", step_metric="train/step")
629
+
630
+ # add interesting config parameters
631
+ wandb.config.update(
632
+ {
633
+ "len_train_dataset": len_train_dataset,
634
+ "len_eval_dataset": len_eval_dataset,
635
+ "batch_size_per_update": batch_size_per_update,
636
+ "num_params": num_params,
637
+ }
638
+ )
639
 
640
  # replicate state on each device
641
  state = state.replicate()
 
703
  f,
704
  )
705
 
706
+ if jax.process_index() == 0:
707
+ # save to W&B
708
+ if training_args.log_model:
709
+ # save some space
710
+ c = wandb.wandb_sdk.wandb_artifacts.get_artifacts_cache()
711
+ c.cleanup(wandb.util.from_human_size("10GB"))
712
+
713
+ metadata = dict(state_dict)
714
+ metadata["num_params"] = num_params
715
+ if eval_metrics is not None:
716
+ metadata["eval"] = eval_metrics
717
+ artifact = wandb.Artifact(
718
+ name=f"model-{wandb.run.id}",
719
+ type="bart_model",
720
+ metadata=metadata,
721
+ )
722
+ artifact.add_file(
723
+ str(Path(training_args.output_dir) / "flax_model.msgpack")
724
+ )
725
+ artifact.add_file(
726
+ str(Path(training_args.output_dir) / "config.json")
727
+ )
728
+ artifact.add_file(
729
+ str(Path(training_args.output_dir) / "tokenizer.json")
730
+ )
731
+ artifact.add_file(
732
+ str(Path(training_args.output_dir) / "tokenizer_config.json")
733
+ )
734
+ artifact.add_file(
735
+ str(Path(training_args.output_dir) / "vocab.json")
736
+ )
737
+ artifact.add_file(
738
+ str(Path(training_args.output_dir) / "merges.txt")
739
+ )
740
+ artifact.add_file(
741
+ str(Path(training_args.output_dir) / "special_tokens_map.json")
742
+ )
743
+ artifact.add_file(
744
+ str(Path(training_args.output_dir) / "opt_state.msgpack")
745
+ )
746
+ artifact.add_file(
747
+ str(Path(training_args.output_dir) / "training_state.json")
748
+ )
749
+
750
+ wandb.run.log_artifact(artifact)
751
+
752
+ # save to the hub
753
+ if training_args.push_to_hub:
754
+ model.save_pretrained(
755
+ training_args.output_dir,
756
+ params=params,
757
+ push_to_hub=training_args.push_to_hub,
758
+ commit_message=f"Saving weights and logs at step {unreplicate(state.step)+1}",
759
+ temp_dir=True, # avoid issues with being in a repository
760
+ )
761
 
762
  # init variables
763
  last_time = time.perf_counter()