Spaces:
Running
Running
feat(train): handle multi-hosts
Browse files- tools/train/train.py +95 -71
tools/train/train.py
CHANGED
@@ -389,15 +389,19 @@ def main():
|
|
389 |
)
|
390 |
|
391 |
# Set up wandb run
|
392 |
-
|
393 |
-
|
394 |
-
|
395 |
-
|
396 |
-
|
397 |
-
|
|
|
398 |
|
399 |
if training_args.resume_from_checkpoint is not None:
|
400 |
-
|
|
|
|
|
|
|
401 |
artifact_dir = artifact.download()
|
402 |
|
403 |
# load model
|
@@ -462,14 +466,23 @@ def main():
|
|
462 |
|
463 |
# Store some constant
|
464 |
num_epochs = int(training_args.num_train_epochs)
|
|
|
465 |
train_batch_size = (
|
466 |
-
int(training_args.per_device_train_batch_size) * jax.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
467 |
)
|
468 |
-
batch_size_per_update = train_batch_size * training_args.gradient_accumulation_steps
|
469 |
-
eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
|
470 |
len_train_dataset, len_eval_dataset = dataset.length
|
471 |
steps_per_epoch = (
|
472 |
-
len_train_dataset // train_batch_size
|
|
|
|
|
473 |
)
|
474 |
num_train_steps = (
|
475 |
steps_per_epoch * num_epochs if steps_per_epoch is not None else None
|
@@ -568,7 +581,7 @@ def main():
|
|
568 |
grads=grads,
|
569 |
dropout_rng=new_dropout_rng,
|
570 |
train_time=state.train_time + delta_time,
|
571 |
-
train_samples=state.train_samples + train_batch_size,
|
572 |
)
|
573 |
|
574 |
metrics = {
|
@@ -600,6 +613,7 @@ def main():
|
|
600 |
logger.info(
|
601 |
f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}"
|
602 |
)
|
|
|
603 |
logger.info(
|
604 |
f" Total train batch size (w. parallel, distributed & gradient accumulation) = {batch_size_per_update}"
|
605 |
)
|
@@ -608,19 +622,20 @@ def main():
|
|
608 |
range(state.epoch, num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0
|
609 |
)
|
610 |
|
611 |
-
|
612 |
-
|
613 |
-
|
614 |
-
|
615 |
-
|
616 |
-
|
617 |
-
|
618 |
-
|
619 |
-
|
620 |
-
|
621 |
-
|
622 |
-
|
623 |
-
|
|
|
624 |
|
625 |
# replicate state on each device
|
626 |
state = state.replicate()
|
@@ -688,52 +703,61 @@ def main():
|
|
688 |
f,
|
689 |
)
|
690 |
|
691 |
-
|
692 |
-
|
693 |
-
|
694 |
-
|
695 |
-
|
696 |
-
|
697 |
-
|
698 |
-
|
699 |
-
|
700 |
-
|
701 |
-
|
702 |
-
|
703 |
-
|
704 |
-
|
705 |
-
|
706 |
-
|
707 |
-
|
708 |
-
|
709 |
-
|
710 |
-
|
711 |
-
|
712 |
-
|
713 |
-
|
714 |
-
|
715 |
-
|
716 |
-
|
717 |
-
|
718 |
-
|
719 |
-
|
720 |
-
|
721 |
-
|
722 |
-
|
723 |
-
|
724 |
-
|
725 |
-
|
726 |
-
|
727 |
-
|
728 |
-
|
729 |
-
|
730 |
-
|
731 |
-
|
732 |
-
|
733 |
-
|
734 |
-
|
735 |
-
|
736 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
737 |
|
738 |
# init variables
|
739 |
last_time = time.perf_counter()
|
|
|
389 |
)
|
390 |
|
391 |
# Set up wandb run
|
392 |
+
if jax.process_index() == 0:
|
393 |
+
wandb.init(
|
394 |
+
entity="dalle-mini",
|
395 |
+
project="dalle-mini",
|
396 |
+
job_type="Seq2Seq",
|
397 |
+
config=parser.parse_args(),
|
398 |
+
)
|
399 |
|
400 |
if training_args.resume_from_checkpoint is not None:
|
401 |
+
if jax.process_index() == 0:
|
402 |
+
artifact = wandb.run.use_artifact(training_args.resume_from_checkpoint)
|
403 |
+
else:
|
404 |
+
artifact = wandb.Api().artifact(training_args.resume_from_checkpoint)
|
405 |
artifact_dir = artifact.download()
|
406 |
|
407 |
# load model
|
|
|
466 |
|
467 |
# Store some constant
|
468 |
num_epochs = int(training_args.num_train_epochs)
|
469 |
+
# batch size per node
|
470 |
train_batch_size = (
|
471 |
+
int(training_args.per_device_train_batch_size) * jax.local_device_count()
|
472 |
+
)
|
473 |
+
batch_size_per_update = (
|
474 |
+
train_batch_size
|
475 |
+
* training_args.gradient_accumulation_steps
|
476 |
+
* jax.process_count()
|
477 |
+
)
|
478 |
+
eval_batch_size = (
|
479 |
+
int(training_args.per_device_eval_batch_size) * jax.local_device_count()
|
480 |
)
|
|
|
|
|
481 |
len_train_dataset, len_eval_dataset = dataset.length
|
482 |
steps_per_epoch = (
|
483 |
+
len_train_dataset // (train_batch_size * jax.process_count())
|
484 |
+
if len_train_dataset is not None
|
485 |
+
else None
|
486 |
)
|
487 |
num_train_steps = (
|
488 |
steps_per_epoch * num_epochs if steps_per_epoch is not None else None
|
|
|
581 |
grads=grads,
|
582 |
dropout_rng=new_dropout_rng,
|
583 |
train_time=state.train_time + delta_time,
|
584 |
+
train_samples=state.train_samples + train_batch_size * jax.process_count(),
|
585 |
)
|
586 |
|
587 |
metrics = {
|
|
|
613 |
logger.info(
|
614 |
f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}"
|
615 |
)
|
616 |
+
logger.info(f" Number of devices = {jax.device_count()}")
|
617 |
logger.info(
|
618 |
f" Total train batch size (w. parallel, distributed & gradient accumulation) = {batch_size_per_update}"
|
619 |
)
|
|
|
622 |
range(state.epoch, num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0
|
623 |
)
|
624 |
|
625 |
+
if jax.process_index() == 0:
|
626 |
+
# set default x-axis as 'train/step'
|
627 |
+
wandb_log({}, step=state.step)
|
628 |
+
wandb.define_metric("*", step_metric="train/step")
|
629 |
+
|
630 |
+
# add interesting config parameters
|
631 |
+
wandb.config.update(
|
632 |
+
{
|
633 |
+
"len_train_dataset": len_train_dataset,
|
634 |
+
"len_eval_dataset": len_eval_dataset,
|
635 |
+
"batch_size_per_update": batch_size_per_update,
|
636 |
+
"num_params": num_params,
|
637 |
+
}
|
638 |
+
)
|
639 |
|
640 |
# replicate state on each device
|
641 |
state = state.replicate()
|
|
|
703 |
f,
|
704 |
)
|
705 |
|
706 |
+
if jax.process_index() == 0:
|
707 |
+
# save to W&B
|
708 |
+
if training_args.log_model:
|
709 |
+
# save some space
|
710 |
+
c = wandb.wandb_sdk.wandb_artifacts.get_artifacts_cache()
|
711 |
+
c.cleanup(wandb.util.from_human_size("10GB"))
|
712 |
+
|
713 |
+
metadata = dict(state_dict)
|
714 |
+
metadata["num_params"] = num_params
|
715 |
+
if eval_metrics is not None:
|
716 |
+
metadata["eval"] = eval_metrics
|
717 |
+
artifact = wandb.Artifact(
|
718 |
+
name=f"model-{wandb.run.id}",
|
719 |
+
type="bart_model",
|
720 |
+
metadata=metadata,
|
721 |
+
)
|
722 |
+
artifact.add_file(
|
723 |
+
str(Path(training_args.output_dir) / "flax_model.msgpack")
|
724 |
+
)
|
725 |
+
artifact.add_file(
|
726 |
+
str(Path(training_args.output_dir) / "config.json")
|
727 |
+
)
|
728 |
+
artifact.add_file(
|
729 |
+
str(Path(training_args.output_dir) / "tokenizer.json")
|
730 |
+
)
|
731 |
+
artifact.add_file(
|
732 |
+
str(Path(training_args.output_dir) / "tokenizer_config.json")
|
733 |
+
)
|
734 |
+
artifact.add_file(
|
735 |
+
str(Path(training_args.output_dir) / "vocab.json")
|
736 |
+
)
|
737 |
+
artifact.add_file(
|
738 |
+
str(Path(training_args.output_dir) / "merges.txt")
|
739 |
+
)
|
740 |
+
artifact.add_file(
|
741 |
+
str(Path(training_args.output_dir) / "special_tokens_map.json")
|
742 |
+
)
|
743 |
+
artifact.add_file(
|
744 |
+
str(Path(training_args.output_dir) / "opt_state.msgpack")
|
745 |
+
)
|
746 |
+
artifact.add_file(
|
747 |
+
str(Path(training_args.output_dir) / "training_state.json")
|
748 |
+
)
|
749 |
+
|
750 |
+
wandb.run.log_artifact(artifact)
|
751 |
+
|
752 |
+
# save to the hub
|
753 |
+
if training_args.push_to_hub:
|
754 |
+
model.save_pretrained(
|
755 |
+
training_args.output_dir,
|
756 |
+
params=params,
|
757 |
+
push_to_hub=training_args.push_to_hub,
|
758 |
+
commit_message=f"Saving weights and logs at step {unreplicate(state.step)+1}",
|
759 |
+
temp_dir=True, # avoid issues with being in a repository
|
760 |
+
)
|
761 |
|
762 |
# init variables
|
763 |
last_time = time.perf_counter()
|