boris committed
Commit e2781bc
1 Parent(s): 8b72ed8

feat(train): refactor learning rate params

Files changed (1)
  1. tools/train/train.py +53 -35
tools/train/train.py CHANGED
@@ -246,9 +246,29 @@ class TrainingArguments:
         },
     )
 
-    use_decay: bool = field(
+    lr_decay: str = field(
+        default=None,
+        metadata={
+            "help": "Decay to be used in the learning rate scheduler. Can be None (default), linear or exponential."
+        },
+    )
+    lr_transition_steps: int = field(
+        default=None,
+        metadata={
+            "help": "Number of transition steps associated with learning rate decay when using exponential decay."
+        },
+    )
+    lr_decay_rate: float = field(
+        default=None,
+        metadata={
+            "help": "Decay rate associated with learning rate when using exponential decay."
+        },
+    )
+    lr_staircase: bool = field(
         default=False,
-        metadata={"help": "Whether to use decay in the learning rate scheduler."},
+        metadata={
+            "help": "Whether to use staircase or continuous learning rate when using exponential decay."
+        },
     )
 
     num_train_epochs: float = field(
@@ -321,33 +341,6 @@ class TrainState(train_state.TrainState):
         )
 
 
-def create_learning_rate_fn(
-    num_warmup_steps: int,
-    learning_rate: float,
-    use_decay: bool,
-    num_train_steps: int = None,  # used only with `use_decay`, typically train_size // batch_size * num_epochs
-) -> Callable[[int], jnp.array]:
-    """Returns a linear warmup, linear_decay learning rate function."""
-    if use_decay:
-        assert (
-            num_train_steps is not None
-        ), "Learning rate with decay requires number of training steps"
-    warmup_fn = optax.linear_schedule(
-        init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps
-    )
-    if not use_decay:
-        return warmup_fn
-    decay_fn = optax.linear_schedule(
-        init_value=learning_rate,
-        end_value=0,
-        transition_steps=num_train_steps - num_warmup_steps,
-    )
-    schedule_fn = optax.join_schedules(
-        schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps]
-    )
-    return schedule_fn
-
-
 class MetricsLogger:
     def __init__(self, state):
         self.step = state.step
@@ -541,12 +534,37 @@ def main():
     num_params = model.num_params
 
     # Create learning rate schedule
-    learning_rate_fn = create_learning_rate_fn(
-        training_args.warmup_steps,
-        training_args.learning_rate,
-        training_args.use_decay,
-        num_train_steps,
-    )
+    def create_learning_rate_fn() -> Callable[[int], jnp.array]:
+        """Create the learning rate function."""
+        warmup_fn = optax.linear_schedule(
+            init_value=0.0,
+            end_value=training_args.learning_rate,
+            transition_steps=training_args.warmup_steps,
+        )
+        if training_args.lr_decay is None:
+            return warmup_fn
+        elif training_args.lr_decay == "linear":
+            assert (
+                num_train_steps is not None
+            ), "linear decay requires knowing the dataset length"
+            decay_fn = optax.linear_schedule(
+                init_value=training_args.learning_rate,
+                end_value=0,
+                transition_steps=num_train_steps - training_args.warmup_steps,
+            )
+        elif training_args.lr_decay == "exponential":
+            decay_fn = optax.exponential_decay(
+                init_value=training_args.learning_rate,
+                transition_steps=training_args.lr_transition_steps,
+                decay_rate=training_args.lr_decay_rate,
+                staircase=training_args.lr_staircase,
+            )
+        schedule_fn = optax.join_schedules(
+            schedules=[warmup_fn, decay_fn], boundaries=[training_args.warmup_steps]
+        )
+        return schedule_fn
+
+    learning_rate_fn = create_learning_rate_fn()
 
     # We use Optax's "masking" functionality to not apply weight decay
     # to bias and LayerNorm scale parameters. decay_mask_fn returns a
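
For illustration only (not part of the commit): the refactor replaces the old boolean use_decay, which only supported linear decay, with an lr_decay option selecting no decay, linear decay, or optax exponential decay, plus the lr_transition_steps / lr_decay_rate / lr_staircase knobs. A minimal sketch of the exponential path is below; the hard-coded values are made up and simply stand in for the parsed training_args.

# Illustrative sketch, not from the commit: rebuild the new schedule with
# example values in place of training_args.
import optax

learning_rate = 5e-3          # stands in for training_args.learning_rate
warmup_steps = 1_000          # stands in for training_args.warmup_steps
lr_transition_steps = 10_000  # stands in for training_args.lr_transition_steps
lr_decay_rate = 0.5           # stands in for training_args.lr_decay_rate
lr_staircase = False          # stands in for training_args.lr_staircase

# Linear warmup from 0 to the peak learning rate.
warmup_fn = optax.linear_schedule(
    init_value=0.0, end_value=learning_rate, transition_steps=warmup_steps
)
# Exponential decay after warmup (the new `lr_decay == "exponential"` branch).
decay_fn = optax.exponential_decay(
    init_value=learning_rate,
    transition_steps=lr_transition_steps,
    decay_rate=lr_decay_rate,
    staircase=lr_staircase,
)
# Switch from warmup to decay at `warmup_steps`; optax restarts the step count
# passed to `decay_fn` at the boundary.
schedule_fn = optax.join_schedules(
    schedules=[warmup_fn, decay_fn], boundaries=[warmup_steps]
)

print(schedule_fn(0))       # 0.0      (start of warmup)
print(schedule_fn(1_000))   # 5e-3     (peak, end of warmup)
print(schedule_fn(11_000))  # 2.5e-3   (one decay period after warmup)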