boris committed on
Commit
adbdff9
1 Parent(s): 23389f6

feat: refactor TrainingArguments

Files changed (1)
  1. tools/train/train.py +63 -42
tools/train/train.py CHANGED
@@ -65,7 +65,7 @@ class ModelArguments:
     config_name: Optional[str] = field(
         default=None,
         metadata={
-            "help": "Pretrained config name or path if not the same as model_name"
+            "help": "Pretrained config name or path if not the same as model_name_or_path"
         },
     )
     tokenizer_name: Optional[str] = field(
@@ -77,7 +77,7 @@ class ModelArguments:
     dtype: Optional[str] = field(
         default="float32",
         metadata={
-            "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
+            "help": "Floating-point format in which the computations will be performed (not the model weights). Choose one of `[float32, float16, bfloat16]`."
         },
     )
 
@@ -106,11 +106,15 @@ class DataTrainingArguments:
     )
     train_file: Optional[str] = field(
         default=None,
-        metadata={"help": "The input training data file (glob acceptable)."},
+        metadata={
+            "help": "The input training data file (glob & braceexpand acceptable)."
+        },
     )
     validation_file: Optional[str] = field(
         default=None,
-        metadata={"help": "An optional input evaluation data file (glob acceptable)."},
+        metadata={
+            "help": "An optional input evaluation data file (glob & braceexpand acceptable)."
+        },
     )
     # data loading should not be a bottleneck so we use "streaming" mode by default
     streaming: Optional[bool] = field(
@@ -132,15 +136,13 @@ class DataTrainingArguments:
     max_train_samples: Optional[int] = field(
         default=None,
         metadata={
-            "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
-            "value if set."
+            "help": "For debugging purposes or quicker training, truncate the number of training examples."
         },
     )
     max_eval_samples: Optional[int] = field(
         default=None,
         metadata={
-            "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
-            "value if set."
+            "help": "For debugging purposes or quicker training, truncate the number of evaluation examples."
         },
     )
     preprocessing_num_workers: Optional[int] = field(
@@ -191,42 +193,42 @@ class TrainingArguments:
 
     do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
     do_eval: bool = field(
-        default=False, metadata={"help": "Whether to run eval on the dev set."}
+        default=False, metadata={"help": "Whether to run eval on the validation set."}
     )
 
     per_device_train_batch_size: int = field(
-        default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for training."}
+        default=8, metadata={"help": "Batch size per GPU/TPU/CPU for training."}
     )
     per_device_eval_batch_size: int = field(
-        default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."}
+        default=8, metadata={"help": "Batch size per GPU/TPU/CPU for evaluation."}
     )
 
     gradient_accumulation_steps: int = field(
         default=1,
         metadata={
-            "help": "Number of updates steps to accumulate before performing a backward/update pass."
+            "help": "Number of updates steps to accumulate before performing an update pass."
         },
     )
 
     learning_rate: float = field(
         default=5e-5, metadata={"help": "The initial learning rate."}
     )
-    adafactor: bool = field(
-        default=False,
-        metadata={"help": "Use Adafactor instead of AdamW."},
-    )
-    distributed_shampoo: bool = field(
-        default=False,
-        metadata={"help": "Use Distributed Shampoo optimizer instead of AdamW."},
+    optim: str = field(
+        default="distributed_shampoo",
+        metadata={
+            "help": 'The optimizer to use. Can be "distributed_shampoo" (default), "adam" or "adafactor"'
+        },
     )
     weight_decay: float = field(
         default=None, metadata={"help": "Weight decay if we apply some."}
     )
-    adam_beta1: float = field(
-        default=0.9, metadata={"help": "Beta1 for AdamW optimizer"}
+    beta1: float = field(
+        default=0.9,
+        metadata={"help": "Beta1 for adam & distributed_shampoo optimizers"},
    )
-    adam_beta2: float = field(
-        default=0.999, metadata={"help": "Beta2 for AdamW optimizer"}
+    beta2: float = field(
+        default=0.999,
+        metadata={"help": "Beta2 for adam & distributed_shampoo optimizers"},
     )
     adam_epsilon: float = field(
         default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."}
@@ -234,6 +236,16 @@ class TrainingArguments:
     max_grad_norm: float = field(
         default=1.0, metadata={"help": "Max gradient norm for Adafactor."}
     )
+    preconditioning_compute_steps: int = field(
+        default=10, metadata={"help": "Number of steps to update preconditioner."}
+    )
+    optim_quantized: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether to quantize optimizer (only supported with distributed_shampoo)."
+        },
+    )
+
     use_decay: bool = field(
         default=False,
         metadata={"help": "Whether to use decay in the learning rate scheduler."},
@@ -272,6 +284,13 @@ class TrainingArguments:
         metadata={"help": "Reference to a wandb artifact for resuming training."},
     )
 
+    def __post_init__(self):
+        assert self.optim in [
+            "distributed_shampoo",
+            "adam",
+            "adafactor",
+        ], f"Selected optimizer not supported: {self.optim}"
+
 
 class TrainState(train_state.TrainState):
     dropout_rng: jnp.ndarray = None
@@ -551,29 +570,22 @@ def main():
         return traverse_util.unflatten_dict(flat_mask)
 
     # create adam optimizer
-    if training_args.adafactor:
-        # We use the default parameters here to initialize adafactor,
-        # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
-        optimizer = optax.adafactor(
-            learning_rate=learning_rate_fn,
-            weight_decay_rate=training_args.weight_decay,
-            weight_decay_mask=decay_mask_fn,
-            clipping_threshold=training_args.max_grad_norm,
-        )
-    elif training_args.distributed_shampoo:
+    if training_args.optim == "distributed_shampoo":
         # parameters from https://github.com/tensorflow/lingvo/blob/03ee9d7cd50764b0424c7c863733c91fc0b053ec/lingvo/jax/optimizers.py#L729
         # Notes:
-        # - mask for weight decay is not implemented but we don't use it anyway
+        # - mask for weight decay is not implemented
         optimizer = distributed_shampoo(
             learning_rate_fn,
             block_size=1024,  # recommended default for large LM is 1536
-            beta1=0.9,
-            beta2=0.999,
+            beta1=training_args.beta1,
+            beta2=training_args.beta2,
             diagonal_epsilon=1e-10,
             matrix_epsilon=1e-8,
-            weight_decay=0.0,
+            weight_decay=training_args.weight_decay
+            if training_args.weight_decay is not None
+            else 0.0,
             start_preconditioning_step=1001,
-            preconditioning_compute_steps=10,
+            preconditioning_compute_steps=training_args.preconditioning_compute_steps,
             statistics_compute_steps=1,
             best_effort_shape_interpretation=True,
             graft_type=GraftingType.RMSPROP_NORMALIZED,
@@ -585,20 +597,29 @@ def main():
             skip_preconditioning_dim_size_gt=4096,
             clip_by_scaled_gradient_norm=None,
             precision=jax.lax.Precision.HIGHEST,
-            best_effort_memory_usage_reduction=False,
+            best_effort_memory_usage_reduction=training_args.optim_quantized,
         )
 
-    else:
+    elif training_args.optim == "adam":
        optimizer = optax.adamw(
             learning_rate=learning_rate_fn,
-            b1=training_args.adam_beta1,
-            b2=training_args.adam_beta2,
+            b1=training_args.beta1,
+            b2=training_args.beta2,
             eps=training_args.adam_epsilon,
             weight_decay=training_args.weight_decay
             if training_args.weight_decay is not None
             else 0.0,
             mask=decay_mask_fn,
         )
+    elif training_args.optim == "adafactor":
+        # We use the default parameters here to initialize adafactor,
+        # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
+        optimizer = optax.adafactor(
+            learning_rate=learning_rate_fn,
+            weight_decay_rate=training_args.weight_decay,
+            weight_decay_mask=decay_mask_fn,
+            clipping_threshold=training_args.max_grad_norm,
+        )
 
     # add gradient accumulation
     if training_args.gradient_accumulation_steps > 1:
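
The pattern introduced above can be seen in isolation in the sketch below: a single `optim` string field, validated in `__post_init__`, replaces the old mutually exclusive `adafactor` / `distributed_shampoo` booleans. The `OptimizerArgs` class is illustrative only and is not part of the repository; the field names, defaults, and assertion mirror the diff.

from dataclasses import dataclass, field


@dataclass
class OptimizerArgs:
    # Mirrors the new TrainingArguments fields added in this commit.
    optim: str = field(default="distributed_shampoo")
    beta1: float = field(default=0.9)
    beta2: float = field(default=0.999)
    preconditioning_compute_steps: int = field(default=10)
    optim_quantized: bool = field(default=False)

    def __post_init__(self):
        # Fail fast on an unsupported optimizer name; with the old boolean
        # flags, conflicting settings were silently resolved by branch order.
        assert self.optim in [
            "distributed_shampoo",
            "adam",
            "adafactor",
        ], f"Selected optimizer not supported: {self.optim}"


OptimizerArgs(optim="adafactor")  # passes validation
# OptimizerArgs(optim="adamw")    # would raise AssertionError

If the script parses these dataclasses with an argument parser such as transformers.HfArgumentParser (not shown in this diff), the same check fires when `--optim` receives an unsupported value on the command line, since the parser constructs the dataclass and triggers `__post_init__`.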