boris committed
Commit 89cf9ea
Parent: ddcbc6a

feat: add more config of distributed_shampoo

Files changed (1): tools/train/train.py (+13 -7)
tools/train/train.py CHANGED
@@ -220,15 +220,15 @@ class TrainingArguments:
         },
     )
     weight_decay: float = field(
-        default=None, metadata={"help": "Weight decay if we apply some."}
+        default=None, metadata={"help": "Weight decay."}
     )
     beta1: float = field(
         default=0.9,
-        metadata={"help": "Beta1 for adam & distributed_shampoo optimizers"},
+        metadata={"help": "Beta1 for Adam & Distributed Shampoo."},
     )
     beta2: float = field(
         default=0.999,
-        metadata={"help": "Beta2 for adam & distributed_shampoo optimizers"},
+        metadata={"help": "Beta2 for Adam & Distributed Shampoo."},
     )
     adam_epsilon: float = field(
         default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."}
@@ -236,13 +236,19 @@ class TrainingArguments:
     max_grad_norm: float = field(
         default=1.0, metadata={"help": "Max gradient norm for Adafactor."}
     )
+    block_size: int = field(
+        default=1024, metadata={"help": "Chunked size for large layers with Distributed Shampoo."}
+    )
     preconditioning_compute_steps: int = field(
         default=10, metadata={"help": "Number of steps to update preconditioner."}
     )
+    skip_preconditioning_dim_size_gt: int = field(
+        default=4096, metadata={"help": "Max size for preconditioning with Distributed Shampoo."}
+    )
     optim_quantized: bool = field(
         default=False,
         metadata={
-            "help": "Whether to quantize optimizer (only supported with distributed_shampoo)."
+            "help": "Whether to quantize optimizer (only supported with Distributed Shampoo)."
         },
     )

@@ -594,7 +600,7 @@ def main():
         # - mask for weight decay is not implemented
         optimizer = distributed_shampoo(
             learning_rate_fn,
-            block_size=1024,  # recommended default for large LM is 1536
+            block_size=training_args.block_size,
             beta1=training_args.beta1,
             beta2=training_args.beta2,
             diagonal_epsilon=1e-10,
@@ -602,7 +608,7 @@ def main():
             weight_decay=training_args.weight_decay
             if training_args.weight_decay is not None
             else 0.0,
-            start_preconditioning_step=1001,
+            start_preconditioning_step=training_args.warmup_steps,
             preconditioning_compute_steps=training_args.preconditioning_compute_steps,
             statistics_compute_steps=1,
             best_effort_shape_interpretation=True,
@@ -612,7 +618,7 @@ def main():
             batch_axis_name="batch",
             inverse_failure_threshold=0.1,
             moving_average_for_momentum=True,
-            skip_preconditioning_dim_size_gt=4096,
+            skip_preconditioning_dim_size_gt=training_args.skip_preconditioning_dim_size_gt,
             clip_by_scaled_gradient_norm=None,
             precision=jax.lax.Precision.HIGHEST,
             best_effort_memory_usage_reduction=training_args.optim_quantized,
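
As a usage aid, here is a minimal sketch of how the newly added TrainingArguments fields surface as command-line options. It assumes the script parses its dataclass arguments with transformers.HfArgumentParser (typical for Hugging Face-style training scripts); ShampooArguments is a hypothetical, trimmed-down stand-in for the real TrainingArguments, and the forwarding to distributed_shampoo is only echoed in comments, following the diff above.

```python
# Sketch only: a hypothetical, trimmed-down stand-in for TrainingArguments
# exposing just the options added in this commit (defaults match the diff).
from dataclasses import dataclass, field

from transformers import HfArgumentParser  # assumption: train.py parses args this way


@dataclass
class ShampooArguments:
    block_size: int = field(
        default=1024,
        metadata={"help": "Chunked size for large layers with Distributed Shampoo."},
    )
    preconditioning_compute_steps: int = field(
        default=10, metadata={"help": "Number of steps to update preconditioner."}
    )
    skip_preconditioning_dim_size_gt: int = field(
        default=4096,
        metadata={"help": "Max size for preconditioning with Distributed Shampoo."},
    )
    optim_quantized: bool = field(
        default=False,
        metadata={"help": "Whether to quantize optimizer (only supported with Distributed Shampoo)."},
    )


if __name__ == "__main__":
    (args,) = HfArgumentParser(ShampooArguments).parse_args_into_dataclasses()
    # In train.py these values are forwarded to distributed_shampoo(...), per the diff:
    #   block_size=args.block_size,
    #   preconditioning_compute_steps=args.preconditioning_compute_steps,
    #   skip_preconditioning_dim_size_gt=args.skip_preconditioning_dim_size_gt,
    #   best_effort_memory_usage_reduction=args.optim_quantized,
    print(args)
```

Against the real script the same flags should then be available, e.g. `python tools/train/train.py ... --block_size 1536 --preconditioning_compute_steps 20`; note that the commit also ties start_preconditioning_step to the existing warmup_steps argument instead of the previously hard-coded 1001.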