ricdomolm and winglian committed
Commit 04b978b
Parent: c3e8165

Cosine learning rate schedule - minimum learning rate (#1062)

* Cosine min lr

* Cosine min lr - warn if using deepspeed

* cosine_min_lr_ratio readme

* chore: lint

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>

README.md CHANGED
```diff
@@ -755,6 +755,7 @@ early_stopping_patience: 3
 # Specify a scheduler and kwargs to use with the optimizer
 lr_scheduler: # 'one_cycle' | 'log_sweep' | empty for cosine
 lr_scheduler_kwargs:
+cosine_min_lr_ratio: # decay lr to some percentage of the peak lr, e.g. cosine_min_lr_ratio=0.1 for 10% of peak lr
 
 # For one_cycle optim
 lr_div_factor: # Learning rate div factor
```
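
For context, a minimal config excerpt using the new knob might look like the following. Everything except `lr_scheduler` and `cosine_min_lr_ratio` is illustrative and not part of this change:

```yaml
# Illustrative excerpt (not from this commit): anneal to 10% of the peak lr
# instead of decaying all the way to zero.
learning_rate: 0.0002
lr_scheduler: cosine        # leaving this empty also defaults to cosine
cosine_min_lr_ratio: 0.1    # final lr ~= 0.1 * 0.0002 = 2e-5
warmup_steps: 100
```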
src/axolotl/core/trainer_builder.py CHANGED
```diff
@@ -38,7 +38,10 @@ from axolotl.utils.collators import (
     MambaDataCollator,
 )
 from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
-from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup
+from axolotl.utils.schedulers import (
+    get_cosine_schedule_with_min_lr,
+    get_cosine_schedule_with_quadratic_warmup,
+)
 
 try:
     import torch._dynamo  # pylint: disable=ungrouped-imports
@@ -120,6 +123,10 @@ class AxolotlTrainingArguments(TrainingArguments):
         default=None,
         metadata={"help": "prefetch_factor argument to the dataloader"},
     )
+    cosine_min_lr_ratio: Optional[float] = field(
+        default=None,
+        metadata={"help": "Minimum learning rate is min_lr_ratio * learning_rate"},
+    )
 
 
 class AxolotlTrainer(Trainer):
@@ -159,6 +166,17 @@ class AxolotlTrainer(Trainer):
                 num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
                 num_training_steps=num_training_steps,
             )
+        elif self.args.lr_scheduler_type == "cosine" and self.args.cosine_min_lr_ratio is not None:
+            assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
+            if self.args.deepspeed:
+                LOG.warning("Using cosine scheduler with deepspeed. This may be ignored if a scheduler is set \
+                    in the deepspeed JSON")
+            self.lr_scheduler = get_cosine_schedule_with_min_lr(  # pylint: disable=attribute-defined-outside-init
+                optimizer,
+                num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
+                num_training_steps=num_training_steps,
+                min_lr_ratio=self.args.cosine_min_lr_ratio,
+            )
         else:
             return super().create_scheduler(num_training_steps, optimizer)
         return self.lr_scheduler
@@ -745,6 +763,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         training_arguments_kwargs["lr_scheduler_kwargs"] = (
             self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
        )
+        training_arguments_kwargs["cosine_min_lr_ratio"] = self.cfg.cosine_min_lr_ratio
         training_arguments_kwargs["weight_decay"] = (
             self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0
         )
```
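
Note the guarded warning above: when training with deepspeed, a scheduler defined in the deepspeed JSON may take precedence over this one. As a quick illustration of the new argument's semantics (its help text defines the floor as `min_lr_ratio * learning_rate`), here is a hypothetical helper that is not part of the commit:

```python
# Hypothetical illustration (not part of this commit): the absolute learning
# rate floor implied by cosine_min_lr_ratio, per the help text
# "Minimum learning rate is min_lr_ratio * learning_rate".
def implied_min_lr(learning_rate: float, cosine_min_lr_ratio: float) -> float:
    # mirrors the bounds enforced by the assert in create_scheduler above
    assert 0 <= cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
    return learning_rate * cosine_min_lr_ratio


print(implied_min_lr(2e-4, 0.1))  # 2e-05
```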
src/axolotl/utils/schedulers.py CHANGED
```diff
@@ -100,3 +100,43 @@ def get_cosine_schedule_with_quadratic_warmup(
         num_cycles=num_cycles,
     )
     return LambdaLR(optimizer, lr_lambda, last_epoch)
+
+
+def _get_cosine_schedule_with_min_lr_lambda(
+    current_step: int,
+    *,
+    num_warmup_steps: int,
+    num_training_steps: int,
+    min_lr_ratio: float
+):
+    # Warm up
+    if current_step < num_warmup_steps:
+        return float(current_step) / float(max(1, num_warmup_steps))
+
+    # Cosine learning rate decay
+    progress = float(current_step - num_warmup_steps) / float(
+        max(1, num_training_steps - num_warmup_steps)
+    )
+    scaling = 0.5 * (1.0 + math.cos(math.pi * progress))
+    return (1 - min_lr_ratio) * scaling + min_lr_ratio
+
+
+def get_cosine_schedule_with_min_lr(
+    optimizer: Optimizer,
+    num_warmup_steps: int,
+    num_training_steps: int,
+    min_lr_ratio: float = 0.0,
+):
+    """
+    Create a learning rate schedule which has:
+        - linear warmup from 0 -> `max_lr` over `num_warmup_steps`
+        - cosine learning rate annealing from `max_lr` -> `min_lr` over `num_training_steps`
+    """
+
+    lr_lambda = partial(
+        _get_cosine_schedule_with_min_lr_lambda,
+        num_warmup_steps=num_warmup_steps,
+        num_training_steps=num_training_steps,
+        min_lr_ratio=min_lr_ratio,
+    )
+    return LambdaLR(optimizer, lr_lambda)
```
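
A minimal sketch of how the new schedule behaves end to end, assuming axolotl is installed so that `get_cosine_schedule_with_min_lr` is importable as added above; the toy model, peak lr, and step counts are arbitrary:

```python
# Sketch (assumes an axolotl install): the schedule warms up linearly to the
# peak lr, then follows a cosine curve down to min_lr_ratio * peak_lr.
import torch

from axolotl.utils.schedulers import get_cosine_schedule_with_min_lr

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4)
scheduler = get_cosine_schedule_with_min_lr(
    optimizer,
    num_warmup_steps=10,
    num_training_steps=100,
    min_lr_ratio=0.1,
)

lrs = []
for _ in range(100):
    optimizer.step()
    scheduler.step()
    lrs.append(scheduler.get_last_lr()[0])

print(lrs[9])   # ~2e-4: warmup ends at the peak lr
print(lrs[-1])  # ~2e-5: the floor is 0.1 * 2e-4, not 0.0
```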