boris committed
Commit edae62d
1 Parent(s): 604a65d

fix: shampoo -> distributed shampoo

Files changed (1):
  tools/train/train.py (+4 -4)
tools/train/train.py CHANGED
@@ -214,11 +214,11 @@ class TrainingArguments:
     )
     adafactor: bool = field(
         default=False,
-        metadata={"help": "Whether or not to replace AdamW by Adafactor."},
+        metadata={"help": "Use Adafactor instead of AdamW."},
     )
-    shampoo: bool = field(
+    distributed_shampoo: bool = field(
         default=False,
-        metadata={"help": "Whether or not to replace AdamW by Adafactor."},
+        metadata={"help": "Use Distributed Shampoo optimizer instead of AdamW."},
     )
     weight_decay: float = field(
         default=None, metadata={"help": "Weight decay if we apply some."}
@@ -566,7 +566,7 @@ def main():
             weight_decay_mask=decay_mask_fn,
             clipping_threshold=training_args.max_grad_norm,
         )
-    elif training_args.shampoo:
+    elif training_args.distributed_shampoo:
         # parameters from https://github.com/tensorflow/lingvo/blob/03ee9d7cd50764b0424c7c863733c91fc0b053ec/lingvo/jax/optimizers.py#L729
         # Notes:
         # - mask for weight decay is not implemented but we don't use it anyway
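For context, the renamed flag is just a boolean field on the TrainingArguments dataclass, so the practical effect of this commit is that the Shampoo branch of the optimizer selection is now gated by training_args.distributed_shampoo instead of training_args.shampoo. The sketch below is a minimal, hedged reconstruction of that selection logic, not the full train.py: make_optimizer, learning_rate_fn, decay_mask_fn, the max_grad_norm field, and the AdamW fallback are assumptions for illustration, and the Distributed Shampoo branch is left as a placeholder because its vendored factory and keyword arguments are not shown in this diff.

from dataclasses import dataclass, field

import optax


@dataclass
class TrainingArguments:
    # Fields touched by this commit (all other fields omitted).
    adafactor: bool = field(
        default=False,
        metadata={"help": "Use Adafactor instead of AdamW."},
    )
    distributed_shampoo: bool = field(
        default=False,
        metadata={"help": "Use Distributed Shampoo optimizer instead of AdamW."},
    )
    weight_decay: float = field(
        default=None, metadata={"help": "Weight decay if we apply some."}
    )
    # Assumed field: only its use as clipping_threshold appears in the diff.
    max_grad_norm: float = field(
        default=1.0, metadata={"help": "Clipping threshold for Adafactor."}
    )


def make_optimizer(training_args, learning_rate_fn, decay_mask_fn):
    """Hypothetical helper mirroring the if/elif selection in train.py's main()."""
    if training_args.adafactor:
        # These keyword arguments exist on optax.adafactor and match the
        # unchanged context lines of the second hunk above.
        return optax.adafactor(
            learning_rate=learning_rate_fn,
            weight_decay_rate=training_args.weight_decay,
            weight_decay_mask=decay_mask_fn,
            clipping_threshold=training_args.max_grad_norm,
        )
    elif training_args.distributed_shampoo:
        # Placeholder: train.py calls a Distributed Shampoo factory here with the
        # parameters referenced from lingvo/jax/optimizers.py; its import path and
        # keyword arguments are not part of this diff, so they are not guessed here.
        raise NotImplementedError("wire in the vendored Distributed Shampoo optimizer")
    # Fallback shown for illustration only; the real script configures AdamW with
    # its own schedule, betas, eps, and decay mask.
    return optax.adamw(
        learning_rate=learning_rate_fn,
        weight_decay=training_args.weight_decay or 0.0,
        mask=decay_mask_fn,
    )

If train.py parses these arguments with transformers.HfArgumentParser (not shown in this diff), a boolean field defaulting to False becomes a command-line flag of the same name, so runs that previously passed --shampoo would now pass --distributed_shampoo.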