feat: add more config of distributed_shampoo
tools/train/train.py CHANGED (+13 -7)
@@ -220,15 +220,15 @@ class TrainingArguments:
         },
     )
     weight_decay: float = field(
-        default=None, metadata={"help": "Weight decay
+        default=None, metadata={"help": "Weight decay."}
     )
     beta1: float = field(
         default=0.9,
-        metadata={"help": "Beta1 for
+        metadata={"help": "Beta1 for Adam & Distributed Shampoo."},
     )
     beta2: float = field(
         default=0.999,
-        metadata={"help": "Beta2 for
+        metadata={"help": "Beta2 for for Adam & Distributed Shampoo."},
     )
     adam_epsilon: float = field(
         default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."}
@@ -236,13 +236,19 @@ class TrainingArguments:
     max_grad_norm: float = field(
         default=1.0, metadata={"help": "Max gradient norm for Adafactor."}
     )
+    block_size: int = field(
+        default=1024, metadata={"help": "Chunked size for large layers with Distributed Shampoo."}
+    )
     preconditioning_compute_steps: int = field(
         default=10, metadata={"help": "Number of steps to update preconditioner."}
     )
+    skip_preconditioning_dim_size_gt: int = field(
+        default=4096, metadata={"help": "Max size for preconditioning with Distributed Shampoo."}
+    )
     optim_quantized: bool = field(
         default=False,
         metadata={
-            "help": "Whether to quantize optimizer (only supported with
+            "help": "Whether to quantize optimizer (only supported with Distributed Shampoo)."
         },
     )
 
@@ -594,7 +600,7 @@ def main():
         # - mask for weight decay is not implemented
         optimizer = distributed_shampoo(
             learning_rate_fn,
-            block_size=
+            block_size=training_args.block_size,
             beta1=training_args.beta1,
             beta2=training_args.beta2,
             diagonal_epsilon=1e-10,
@@ -602,7 +608,7 @@ def main():
             weight_decay=training_args.weight_decay
             if training_args.weight_decay is not None
             else 0.0,
-            start_preconditioning_step=
+            start_preconditioning_step=training_args.warmup_steps,
             preconditioning_compute_steps=training_args.preconditioning_compute_steps,
             statistics_compute_steps=1,
             best_effort_shape_interpretation=True,
@@ -612,7 +618,7 @@ def main():
             batch_axis_name="batch",
             inverse_failure_threshold=0.1,
             moving_average_for_momentum=True,
-            skip_preconditioning_dim_size_gt=
+            skip_preconditioning_dim_size_gt=training_args.skip_preconditioning_dim_size_gt,
             clip_by_scaled_gradient_norm=None,
             precision=jax.lax.Precision.HIGHEST,
             best_effort_memory_usage_reduction=training_args.optim_quantized,
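For context, the new fields flow from the command line into the distributed_shampoo() call shown in the diff. The sketch below is a minimal, hypothetical reconstruction, not the project's actual TrainingArguments: it assumes the script parses its dataclass with transformers.HfArgumentParser (the usual way Hugging Face-style training scripts turn dataclass fields into --flags) and only shows the Shampoo-related subset added in this commit, collecting the values into the keyword arguments that train.py forwards to the optimizer.

# Minimal sketch (hypothetical, stand-alone) of the Distributed Shampoo knobs
# added in this commit and how they map onto the optimizer's keyword arguments.
from dataclasses import dataclass, field

from transformers import HfArgumentParser  # assumes transformers is installed


@dataclass
class ShampooArguments:
    # Hypothetical subset of the script's TrainingArguments, mirroring the new fields.
    block_size: int = field(
        default=1024,
        metadata={"help": "Chunked size for large layers with Distributed Shampoo."},
    )
    preconditioning_compute_steps: int = field(
        default=10, metadata={"help": "Number of steps to update preconditioner."}
    )
    skip_preconditioning_dim_size_gt: int = field(
        default=4096,
        metadata={"help": "Max size for preconditioning with Distributed Shampoo."},
    )
    optim_quantized: bool = field(
        default=False,
        metadata={"help": "Whether to quantize optimizer (only supported with Distributed Shampoo)."},
    )


if __name__ == "__main__":
    # e.g. python sketch.py --block_size 512 --skip_preconditioning_dim_size_gt 2048
    parser = HfArgumentParser(ShampooArguments)
    (args,) = parser.parse_args_into_dataclasses()

    # Keyword arguments that tools/train/train.py passes to distributed_shampoo().
    shampoo_kwargs = dict(
        block_size=args.block_size,
        preconditioning_compute_steps=args.preconditioning_compute_steps,
        skip_preconditioning_dim_size_gt=args.skip_preconditioning_dim_size_gt,
        best_effort_memory_usage_reduction=args.optim_quantized,
    )
    print(shampoo_kwargs)

The commit also sets start_preconditioning_step=training_args.warmup_steps, which ties the start of preconditioning to the end of the learning-rate warmup rather than a hard-coded step count.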