Spaces:
Running
Running
refactor(train): cleanup
Browse files- tools/train/train.py +51 -31
tools/train/train.py
CHANGED
@@ -310,12 +310,40 @@ class TrainingArguments:
|
|
310 |
metadata={"help": "Reference to a wandb artifact for resuming training."},
|
311 |
)
|
312 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
313 |
def __post_init__(self):
|
314 |
assert self.optim in [
|
315 |
"distributed_shampoo",
|
316 |
"adam",
|
317 |
"adafactor",
|
318 |
], f"Selected optimizer not supported: {self.optim}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
319 |
|
320 |
|
321 |
class TrainState(train_state.TrainState):
|
@@ -396,17 +424,6 @@ def main():
|
|
396 |
else:
|
397 |
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
398 |
|
399 |
-
if (
|
400 |
-
os.path.exists(training_args.output_dir)
|
401 |
-
and os.listdir(training_args.output_dir)
|
402 |
-
and training_args.do_train
|
403 |
-
and not training_args.overwrite_output_dir
|
404 |
-
):
|
405 |
-
raise ValueError(
|
406 |
-
f"Output directory ({training_args.output_dir}) already exists and is not empty."
|
407 |
-
"Use --overwrite_output_dir to overcome."
|
408 |
-
)
|
409 |
-
|
410 |
# Make one log on every process with the configuration for debugging.
|
411 |
logging.basicConfig(
|
412 |
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
@@ -433,14 +450,18 @@ def main():
|
|
433 |
)
|
434 |
|
435 |
logger.info(f"Local TPUs: {jax.local_device_count()}")
|
436 |
-
|
|
|
|
|
|
|
|
|
437 |
|
438 |
# Set up wandb run
|
439 |
if jax.process_index() == 0:
|
440 |
wandb.init(
|
441 |
-
entity=
|
442 |
-
project=
|
443 |
-
job_type=
|
444 |
config=parser.parse_args(),
|
445 |
)
|
446 |
|
@@ -520,17 +541,14 @@ def main():
|
|
520 |
train_batch_size = (
|
521 |
training_args.per_device_train_batch_size * jax.local_device_count()
|
522 |
)
|
523 |
-
|
524 |
-
|
525 |
-
* training_args.gradient_accumulation_steps
|
526 |
-
* jax.process_count()
|
527 |
-
)
|
528 |
eval_batch_size = (
|
529 |
training_args.per_device_eval_batch_size * jax.local_device_count()
|
530 |
)
|
531 |
len_train_dataset, len_eval_dataset = dataset.length
|
532 |
steps_per_epoch = (
|
533 |
-
len_train_dataset //
|
534 |
if len_train_dataset is not None
|
535 |
else None
|
536 |
)
|
@@ -708,14 +726,12 @@ def main():
|
|
708 |
grads=grads,
|
709 |
dropout_rng=new_dropout_rng,
|
710 |
train_time=state.train_time + delta_time,
|
711 |
-
train_samples=state.train_samples +
|
712 |
)
|
713 |
|
714 |
metrics = {
|
715 |
"loss": loss,
|
716 |
-
"learning_rate": learning_rate_fn(
|
717 |
-
state.step // training_args.gradient_accumulation_steps
|
718 |
-
),
|
719 |
}
|
720 |
metrics = jax.lax.pmean(metrics, axis_name="batch")
|
721 |
|
@@ -733,19 +749,20 @@ def main():
|
|
733 |
return metrics
|
734 |
|
735 |
# Create parallel version of the train and eval step
|
736 |
-
p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
|
737 |
-
p_eval_step = jax.pmap(eval_step, "batch")
|
738 |
|
739 |
logger.info("***** Running training *****")
|
740 |
logger.info(f" Num examples = {len_train_dataset}")
|
741 |
logger.info(f" Num Epochs = {num_epochs}")
|
742 |
logger.info(
|
743 |
-
f"
|
744 |
)
|
745 |
logger.info(f" Number of devices = {jax.device_count()}")
|
746 |
logger.info(
|
747 |
-
f"
|
748 |
)
|
|
|
749 |
logger.info(f" Model parameters = {num_params:,}")
|
750 |
epochs = tqdm(
|
751 |
range(state.epoch, num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0
|
@@ -762,8 +779,9 @@ def main():
|
|
762 |
{
|
763 |
"len_train_dataset": len_train_dataset,
|
764 |
"len_eval_dataset": len_eval_dataset,
|
765 |
-
"
|
766 |
"num_params": num_params,
|
|
|
767 |
}
|
768 |
)
|
769 |
|
@@ -774,7 +792,9 @@ def main():
|
|
774 |
# ======================== Evaluating ==============================
|
775 |
eval_metrics = []
|
776 |
if training_args.do_eval:
|
777 |
-
eval_loader = dataset.dataloader(
|
|
|
|
|
778 |
eval_steps = (
|
779 |
len_eval_dataset // eval_batch_size
|
780 |
if len_eval_dataset is not None
|
|
|
310 |
metadata={"help": "Reference to a wandb artifact for resuming training."},
|
311 |
)
|
312 |
|
313 |
+
wandb_entity: Optional[str] = field(
|
314 |
+
default=None,
|
315 |
+
metadata={"help": "The wandb entity to use (for teams)."},
|
316 |
+
)
|
317 |
+
wandb_project: str = field(
|
318 |
+
default="dalle-mini",
|
319 |
+
metadata={"help": "The name of the wandb project."},
|
320 |
+
)
|
321 |
+
wandb_job_type: str = field(
|
322 |
+
default="Seq2Seq",
|
323 |
+
metadata={"help": "The name of the wandb job type."},
|
324 |
+
)
|
325 |
+
|
326 |
+
assert_TPU_available: bool = field(
|
327 |
+
default=False,
|
328 |
+
metadata={"help": "Verify that TPU is not in use."},
|
329 |
+
)
|
330 |
+
|
331 |
def __post_init__(self):
|
332 |
assert self.optim in [
|
333 |
"distributed_shampoo",
|
334 |
"adam",
|
335 |
"adafactor",
|
336 |
], f"Selected optimizer not supported: {self.optim}"
|
337 |
+
if (
|
338 |
+
os.path.exists(self.output_dir)
|
339 |
+
and os.listdir(self.output_dir)
|
340 |
+
and self.do_train
|
341 |
+
and not self.overwrite_output_dir
|
342 |
+
):
|
343 |
+
raise ValueError(
|
344 |
+
f"Output directory ({training_args.output_dir}) already exists and is not empty."
|
345 |
+
"Use --overwrite_output_dir to overcome."
|
346 |
+
)
|
347 |
|
348 |
|
349 |
class TrainState(train_state.TrainState):
|
|
|
424 |
else:
|
425 |
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
426 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
427 |
# Make one log on every process with the configuration for debugging.
|
428 |
logging.basicConfig(
|
429 |
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
|
|
450 |
)
|
451 |
|
452 |
logger.info(f"Local TPUs: {jax.local_device_count()}")
|
453 |
+
logger.info(f"Global TPUs: {jax.device_count()}")
|
454 |
+
if training_args.assert_TPU_available:
|
455 |
+
assert (
|
456 |
+
jax.local_device_count() == 8
|
457 |
+
), "TPUs in use, please check running processes"
|
458 |
|
459 |
# Set up wandb run
|
460 |
if jax.process_index() == 0:
|
461 |
wandb.init(
|
462 |
+
entity=training_args.wandb_entity,
|
463 |
+
project=training_args.wandb_project,
|
464 |
+
job_type=training_args.wandb_job_type,
|
465 |
config=parser.parse_args(),
|
466 |
)
|
467 |
|
|
|
541 |
train_batch_size = (
|
542 |
training_args.per_device_train_batch_size * jax.local_device_count()
|
543 |
)
|
544 |
+
batch_size_per_node = train_batch_size * training_args.gradient_accumulation_steps
|
545 |
+
batch_size_per_step = batch_size_per_node * jax.process_count()
|
|
|
|
|
|
|
546 |
eval_batch_size = (
|
547 |
training_args.per_device_eval_batch_size * jax.local_device_count()
|
548 |
)
|
549 |
len_train_dataset, len_eval_dataset = dataset.length
|
550 |
steps_per_epoch = (
|
551 |
+
len_train_dataset // batch_size_per_node
|
552 |
if len_train_dataset is not None
|
553 |
else None
|
554 |
)
|
|
|
726 |
grads=grads,
|
727 |
dropout_rng=new_dropout_rng,
|
728 |
train_time=state.train_time + delta_time,
|
729 |
+
train_samples=state.train_samples + batch_size_per_step,
|
730 |
)
|
731 |
|
732 |
metrics = {
|
733 |
"loss": loss,
|
734 |
+
"learning_rate": learning_rate_fn(state.step),
|
|
|
|
|
735 |
}
|
736 |
metrics = jax.lax.pmean(metrics, axis_name="batch")
|
737 |
|
|
|
749 |
return metrics
|
750 |
|
751 |
# Create parallel version of the train and eval step
|
752 |
+
p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0, 1))
|
753 |
+
p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(1,))
|
754 |
|
755 |
logger.info("***** Running training *****")
|
756 |
logger.info(f" Num examples = {len_train_dataset}")
|
757 |
logger.info(f" Num Epochs = {num_epochs}")
|
758 |
logger.info(
|
759 |
+
f" Batch size per device = {training_args.per_device_train_batch_size}"
|
760 |
)
|
761 |
logger.info(f" Number of devices = {jax.device_count()}")
|
762 |
logger.info(
|
763 |
+
f" Gradient accumulation steps = {training_args.gradient_accumulation_steps}"
|
764 |
)
|
765 |
+
logger.info(f" Batch size per update = {batch_size_per_step}")
|
766 |
logger.info(f" Model parameters = {num_params:,}")
|
767 |
epochs = tqdm(
|
768 |
range(state.epoch, num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0
|
|
|
779 |
{
|
780 |
"len_train_dataset": len_train_dataset,
|
781 |
"len_eval_dataset": len_eval_dataset,
|
782 |
+
"batch_size_per_step": batch_size_per_step,
|
783 |
"num_params": num_params,
|
784 |
+
"num_devices": jax.device_count(),
|
785 |
}
|
786 |
)
|
787 |
|
|
|
792 |
# ======================== Evaluating ==============================
|
793 |
eval_metrics = []
|
794 |
if training_args.do_eval:
|
795 |
+
eval_loader = dataset.dataloader(
|
796 |
+
"eval", training_args.per_device_eval_batch_size
|
797 |
+
)
|
798 |
eval_steps = (
|
799 |
len_eval_dataset // eval_batch_size
|
800 |
if len_eval_dataset is not None
|