versae committed
Commit 182f272
Parent: f072d39

Adding shampoo training

Files changed (1):
  run_mlm_flax.py (+31, -0)
run_mlm_flax.py CHANGED
@@ -60,6 +60,7 @@ from transformers import (
 )
 from transformers.file_utils import get_full_repo_name
 
+from distributed_shampoo import distributed_shampoo, GraftingType
 
 MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
 MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
@@ -88,6 +89,10 @@ class TrainingArguments:
         default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."}
     )
     learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for AdamW."})
+    distributed_shampoo: bool = field(
+        default=False,
+        metadata={"help": "Use Distributed Shampoo optimizer instead of AdamW."},
+    )
     weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW if we apply some."})
     adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for AdamW optimizer"})
     adam_beta2: float = field(default=0.999, metadata={"help": "Beta2 for AdamW optimizer"})
@@ -629,6 +634,32 @@ def main():
         optimizer = optax.adafactor(
             learning_rate=linear_decay_lr_schedule_fn,
         )
+    elif training_args.distributed_shampoo:
+        # parameters from https://github.com/tensorflow/lingvo/blob/03ee9d7cd50764b0424c7c863733c91fc0b053ec/lingvo/jax/optimizers.py#L729
+        # Notes:
+        #   - mask for weight decay is not implemented, but we don't use it anyway
+        optimizer = distributed_shampoo(
+            linear_decay_lr_schedule_fn,
+            block_size=1024,  # recommended default for large LMs is 1536
+            beta1=training_args.adam_beta1,  # 0.9
+            beta2=training_args.adam_beta2,  # 0.999
+            diagonal_epsilon=training_args.adam_epsilon,  # 1e-10
+            matrix_epsilon=1e-8,
+            weight_decay=training_args.weight_decay,  # 0.0
+            start_preconditioning_step=1001,
+            preconditioning_compute_steps=10,
+            statistics_compute_steps=1,
+            best_effort_shape_interpretation=True,
+            graft_type=GraftingType.RMSPROP_NORMALIZED,
+            nesterov=False,
+            exponent_override=0,
+            batch_axis_name="batch",
+            inverse_failure_threshold=0.1,
+            moving_average_for_momentum=True,
+            skip_preconditioning_dim_size_gt=4096,
+            clip_by_scaled_gradient_norm=None,
+            precision=jax.lax.Precision.HIGHEST,
+        )
     else:
         optimizer = optax.adamw(
             learning_rate=linear_decay_lr_schedule_fn,
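
For reference, a minimal usage sketch, not part of the commit. It assumes that HfArgumentParser exposes the new TrainingArguments field as a --distributed_shampoo command-line flag, and that distributed_shampoo(...) returns an optax-compatible GradientTransformation (as in the scalable_shampoo implementation the new import points at), so the chosen optimizer drops into the script's existing TrainState setup unchanged.

# Minimal sketch under the assumptions stated above; names mirror run_mlm_flax.py.
# Enabling the new option would look like:
#   python run_mlm_flax.py ... --distributed_shampoo
import optax
from flax.training import train_state


def create_train_state(model, optimizer: optax.GradientTransformation) -> train_state.TrainState:
    # `model` is the Flax masked-LM model built earlier in main(); `optimizer` is
    # whichever branch above was selected (adafactor, distributed_shampoo, or adamw).
    return train_state.TrainState.create(
        apply_fn=model.__call__,
        params=model.params,
        tx=optimizer,
    )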