boris committed on
Commit: 604a65d
1 Parent(s): 0b87452

feat: update params

Files changed (1)
  1. tools/train/train.py +3 -3
tools/train/train.py CHANGED

@@ -569,7 +569,7 @@ def main():
     elif training_args.shampoo:
         # parameters from https://github.com/tensorflow/lingvo/blob/03ee9d7cd50764b0424c7c863733c91fc0b053ec/lingvo/jax/optimizers.py#L729
         # Notes:
-        # - mask for weight decay is not implemented so we don't use it
+        # - mask for weight decay is not implemented but we don't use it anyway
         optimizer = distributed_shampoo(
             learning_rate_fn,
             block_size=1024,  # recommended default for large LM is 1536
@@ -578,8 +578,8 @@
             diagonal_epsilon=1e-10,
             matrix_epsilon=1e-8,
             weight_decay=0.0,
-            start_preconditioning_step=51,
-            preconditioning_compute_steps=50,
+            start_preconditioning_step=1001,
+            preconditioning_compute_steps=10,
             statistics_compute_steps=1,
             best_effort_shape_interpretation=True,
             graft_type=GraftingType.RMSPROP_NORMALIZED,
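For context, here is a minimal sketch of the full call site as it reads after this commit. It assumes the distributed_shampoo implementation vendored from google-research's scalable_shampoo (optax flavor), which exposes these keyword arguments and the GraftingType enum; the import path and learning_rate_fn schedule below are stand-ins, not the ones from train.py. The net effect of the change: Shampoo preconditioning now kicks in at step 1001 instead of step 51 (the grafted RMSProp direction is used before then), and the expensive inverse-pth-root preconditioners are refreshed every 10 steps instead of every 50.

# Sketch only: reconstructs the call site after this commit. The import path
# and the schedule are assumptions; distributed_shampoo / GraftingType follow
# the google-research scalable_shampoo (optax) API.
import optax
from distributed_shampoo import GraftingType, distributed_shampoo

# Stand-in schedule; train.py builds its own learning_rate_fn from training args.
learning_rate_fn = optax.warmup_cosine_decay_schedule(
    init_value=0.0, peak_value=5e-3, warmup_steps=1_000, decay_steps=100_000
)

optimizer = distributed_shampoo(
    learning_rate_fn,
    block_size=1024,  # recommended default for large LM is 1536
    diagonal_epsilon=1e-10,
    matrix_epsilon=1e-8,
    weight_decay=0.0,
    start_preconditioning_step=1001,   # use grafted RMSProp direction before this step
    preconditioning_compute_steps=10,  # recompute inverse-pth-root preconditioners every 10 steps
    statistics_compute_steps=1,        # accumulate gradient statistics every step
    best_effort_shape_interpretation=True,
    graft_type=GraftingType.RMSPROP_NORMALIZED,
)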