Adding shampoo training
run_mlm_flax.py  CHANGED  (+31 -0)
@@ -60,6 +60,7 @@ from transformers import (
 )
 from transformers.file_utils import get_full_repo_name
 
+from distributed_shampoo import distributed_shampoo, GraftingType
 
 MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
 MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
@@ -88,6 +89,10 @@ class TrainingArguments:
         default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."}
     )
     learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for AdamW."})
+    distributed_shampoo: bool = field(
+        default=False,
+        metadata={"help": "Use Distributed Shampoo optimizer instead of AdamW."},
+    )
     weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW if we apply some."})
     adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for AdamW optimizer"})
     adam_beta2: float = field(default=0.999, metadata={"help": "Beta2 for AdamW optimizer"})
@@ -629,6 +634,32 @@ def main():
         optimizer = optax.adafactor(
             learning_rate=linear_decay_lr_schedule_fn,
         )
+    elif training_args.distributed_shampoo:
+        # parameters from https://github.com/tensorflow/lingvo/blob/03ee9d7cd50764b0424c7c863733c91fc0b053ec/lingvo/jax/optimizers.py#L729
+        # Notes:
+        #  - mask for weight decay is not implemented but we don't use it anyway
+        optimizer = distributed_shampoo(
+            linear_decay_lr_schedule_fn,
+            block_size=1024,  # recommended default for large LM is 1536
+            beta1=training_args.adam_beta1,  # 0.9,
+            beta2=training_args.adam_beta2,  # 0.999,
+            diagonal_epsilon=training_args.adam_epsilon,  # 1e-10,
+            matrix_epsilon=1e-8,
+            weight_decay=training_args.weight_decay,  # 0.0,
+            start_preconditioning_step=1001,
+            preconditioning_compute_steps=10,
+            statistics_compute_steps=1,
+            best_effort_shape_interpretation=True,
+            graft_type=GraftingType.RMSPROP_NORMALIZED,
+            nesterov=False,
+            exponent_override=0,
+            batch_axis_name="batch",
+            inverse_failure_threshold=0.1,
+            moving_average_for_momentum=True,
+            skip_preconditioning_dim_size_gt=4096,
+            clip_by_scaled_gradient_norm=None,
+            precision=jax.lax.Precision.HIGHEST,
+        )
     else:
         optimizer = optax.adamw(
             learning_rate=linear_decay_lr_schedule_fn,
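
For context (not part of the commit): every branch above, including the new distributed_shampoo(...) call, is expected to return an optax-style GradientTransformation; the upstream run_mlm_flax.py then hands it to the Flax TrainState and applies it inside a train step pmapped over an axis named "batch", which is what the batch_axis_name="batch" argument refers to. The toy sketch below only illustrates that selection pattern, it is not the commit's code: the use_shampoo flag stands in for training_args.distributed_shampoo, the tiny parameter dict stands in for the model, and the Shampoo branch is guarded so the sketch runs with just jax and optax installed (enabling it assumes the distributed_shampoo package from the import above is available).

# Toy sketch of the optimizer-selection pattern added by this commit.
# Assumptions (not from the diff): jax + optax installed; `use_shampoo` is a
# stand-in for `training_args.distributed_shampoo`; a tiny parameter dict
# replaces the real model parameters.
import jax
import jax.numpy as jnp
import optax

use_shampoo = False  # flip to True only if the distributed_shampoo package is importable

if use_shampoo:
    # Same import and entry point as the diff; remaining arguments left at their defaults here.
    from distributed_shampoo import distributed_shampoo, GraftingType

    tx = distributed_shampoo(
        5e-5,
        block_size=1024,
        graft_type=GraftingType.RMSPROP_NORMALIZED,
    )
else:
    tx = optax.adamw(learning_rate=5e-5)

params = {"w": jnp.ones((8, 8))}
opt_state = tx.init(params)

def loss_fn(p):
    return jnp.sum(p["w"] ** 2)

grads = jax.grad(loss_fn)(params)
updates, opt_state = tx.update(grads, opt_state, params)  # optax protocol: compute updates, then apply
params = optax.apply_updates(params, updates)

Roughly, the settings in the diff mean: the RMSProp-normalized grafted update drives training until start_preconditioning_step=1001, the (expensive) inverse-pth-root preconditioners are recomputed only every preconditioning_compute_steps=10 steps while statistics are gathered every step, and block_size=1024 caps the size of the parameter blocks that receive full preconditioning.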