feat: update params
tools/train/train.py (+3 -3)
```diff
@@ -569,7 +569,7 @@ def main():
     elif training_args.shampoo:
         # parameters from https://github.com/tensorflow/lingvo/blob/03ee9d7cd50764b0424c7c863733c91fc0b053ec/lingvo/jax/optimizers.py#L729
         # Notes:
-        # - mask for weight decay is not implemented
+        # - mask for weight decay is not implemented but we don't use it anyway
         optimizer = distributed_shampoo(
             learning_rate_fn,
             block_size=1024,  # recommended default for large LM is 1536
@@ -578,8 +578,8 @@ def main():
             diagonal_epsilon=1e-10,
             matrix_epsilon=1e-8,
             weight_decay=0.0,
-            start_preconditioning_step=
-            preconditioning_compute_steps=
+            start_preconditioning_step=1001,
+            preconditioning_compute_steps=10,
             statistics_compute_steps=1,
             best_effort_shape_interpretation=True,
             graft_type=GraftingType.RMSPROP_NORMALIZED,
```
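For context, here is a minimal sketch of how the full optimizer block reads after this commit. It assumes the optax `distributed_shampoo` implementation from google-research/scalable_shampoo; the import path and the example `learning_rate_fn` schedule are assumptions for illustration, and `beta1`/`beta2` (lines 576-577) sit outside the hunks shown above, so only the keyword arguments visible in the diff are taken from the source.

```python
# Sketch of the patched optimizer block, assuming the optax distributed_shampoo
# from google-research/scalable_shampoo. Import path and schedule are
# illustrative assumptions; the keyword arguments mirror the diff above.
import optax
from distributed_shampoo import GraftingType, distributed_shampoo

# learning_rate_fn is defined elsewhere in train.py; any optax schedule
# works here (this particular schedule is an assumption).
learning_rate_fn = optax.warmup_cosine_decay_schedule(
    init_value=0.0, peak_value=5e-3, warmup_steps=1_000, decay_steps=100_000
)

optimizer = distributed_shampoo(
    learning_rate_fn,
    block_size=1024,  # recommended default for large LM is 1536
    # beta1/beta2 (lines 576-577) are unchanged and not shown in the diff
    diagonal_epsilon=1e-10,
    matrix_epsilon=1e-8,
    weight_decay=0.0,
    start_preconditioning_step=1001,   # run on the graft alone for the first 1000 steps
    preconditioning_compute_steps=10,  # recompute matrix preconditioners every 10 steps
    statistics_compute_steps=1,        # accumulate gradient statistics every step
    best_effort_shape_interpretation=True,
    graft_type=GraftingType.RMSPROP_NORMALIZED,
)
```

The net effect of the two new values: until step 1001 the updates come from the RMSPROP_NORMALIZED grafting optimizer alone, and once Shampoo preconditioning kicks in, the expensive inverse-pth-root computations are amortized by refreshing preconditioners only every 10 steps while second-moment statistics are still gathered every step.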