Cosine learning rate schedule - minimum learning rate (#1062)
Browse files* Cosine min lr
* Cosine min lr - warn if using deepspeed
* cosine_min_lr_ratio readme
* chore: lint
---------
Co-authored-by: Wing Lian <wing.lian@gmail.com>
- README.md +1 -0
- src/axolotl/core/trainer_builder.py +20 -1
- src/axolotl/utils/schedulers.py +40 -0
README.md
CHANGED
@@ -755,6 +755,7 @@ early_stopping_patience: 3
|
|
755 |
# Specify a scheduler and kwargs to use with the optimizer
|
756 |
lr_scheduler: # 'one_cycle' | 'log_sweep' | empty for cosine
|
757 |
lr_scheduler_kwargs:
|
|
|
758 |
|
759 |
# For one_cycle optim
|
760 |
lr_div_factor: # Learning rate div factor
|
|
|
755 |
# Specify a scheduler and kwargs to use with the optimizer
|
756 |
lr_scheduler: # 'one_cycle' | 'log_sweep' | empty for cosine
|
757 |
lr_scheduler_kwargs:
|
758 |
+
cosine_min_lr_ratio: # decay lr to some percentage of the peak lr, e.g. cosine_min_lr_ratio=0.1 for 10% of peak lr
|
759 |
|
760 |
# For one_cycle optim
|
761 |
lr_div_factor: # Learning rate div factor
|
src/axolotl/core/trainer_builder.py
CHANGED
@@ -38,7 +38,10 @@ from axolotl.utils.collators import (
|
|
38 |
MambaDataCollator,
|
39 |
)
|
40 |
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
41 |
-
from axolotl.utils.schedulers import
|
|
|
|
|
|
|
42 |
|
43 |
try:
|
44 |
import torch._dynamo # pylint: disable=ungrouped-imports
|
@@ -120,6 +123,10 @@ class AxolotlTrainingArguments(TrainingArguments):
|
|
120 |
default=None,
|
121 |
metadata={"help": "prefetch_factor argument to the dataloader"},
|
122 |
)
|
|
|
|
|
|
|
|
|
123 |
|
124 |
|
125 |
class AxolotlTrainer(Trainer):
|
@@ -159,6 +166,17 @@ class AxolotlTrainer(Trainer):
|
|
159 |
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
160 |
num_training_steps=num_training_steps,
|
161 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
162 |
else:
|
163 |
return super().create_scheduler(num_training_steps, optimizer)
|
164 |
return self.lr_scheduler
|
@@ -745,6 +763,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|
745 |
training_arguments_kwargs["lr_scheduler_kwargs"] = (
|
746 |
self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
|
747 |
)
|
|
|
748 |
training_arguments_kwargs["weight_decay"] = (
|
749 |
self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0
|
750 |
)
|
|
|
38 |
MambaDataCollator,
|
39 |
)
|
40 |
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
41 |
+
from axolotl.utils.schedulers import (
|
42 |
+
get_cosine_schedule_with_min_lr,
|
43 |
+
get_cosine_schedule_with_quadratic_warmup,
|
44 |
+
)
|
45 |
|
46 |
try:
|
47 |
import torch._dynamo # pylint: disable=ungrouped-imports
|
|
|
123 |
default=None,
|
124 |
metadata={"help": "prefetch_factor argument to the dataloader"},
|
125 |
)
|
126 |
+
cosine_min_lr_ratio: Optional[float] = field(
|
127 |
+
default=None,
|
128 |
+
metadata={"help": "Minimum learning rate is min_lr_ratio * learning_rate"},
|
129 |
+
)
|
130 |
|
131 |
|
132 |
class AxolotlTrainer(Trainer):
|
|
|
166 |
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
167 |
num_training_steps=num_training_steps,
|
168 |
)
|
169 |
+
elif self.args.lr_scheduler_type == "cosine" and self.args.cosine_min_lr_ratio is not None:
|
170 |
+
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
|
171 |
+
if self.args.deepspeed:
|
172 |
+
LOG.warning("Using cosine scheduler with deepspeed. This may be ignored if a scheduler is set \
|
173 |
+
in the deepspeed JSON")
|
174 |
+
self.lr_scheduler = get_cosine_schedule_with_min_lr( # pylint: disable=attribute-defined-outside-init
|
175 |
+
optimizer,
|
176 |
+
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
177 |
+
num_training_steps=num_training_steps,
|
178 |
+
min_lr_ratio=self.args.cosine_min_lr_ratio,
|
179 |
+
)
|
180 |
else:
|
181 |
return super().create_scheduler(num_training_steps, optimizer)
|
182 |
return self.lr_scheduler
|
|
|
763 |
training_arguments_kwargs["lr_scheduler_kwargs"] = (
|
764 |
self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
|
765 |
)
|
766 |
+
training_arguments_kwargs["cosine_min_lr_ratio"] = self.cfg.cosine_min_lr_ratio
|
767 |
training_arguments_kwargs["weight_decay"] = (
|
768 |
self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0
|
769 |
)
|
src/axolotl/utils/schedulers.py
CHANGED
@@ -100,3 +100,43 @@ def get_cosine_schedule_with_quadratic_warmup(
|
|
100 |
num_cycles=num_cycles,
|
101 |
)
|
102 |
return LambdaLR(optimizer, lr_lambda, last_epoch)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
100 |
num_cycles=num_cycles,
|
101 |
)
|
102 |
return LambdaLR(optimizer, lr_lambda, last_epoch)
|
103 |
+
|
104 |
+
|
105 |
+
def _get_cosine_schedule_with_min_lr_lambda(
|
106 |
+
current_step: int,
|
107 |
+
*,
|
108 |
+
num_warmup_steps: int,
|
109 |
+
num_training_steps: int,
|
110 |
+
min_lr_ratio: float
|
111 |
+
):
|
112 |
+
# Warm up
|
113 |
+
if current_step < num_warmup_steps:
|
114 |
+
return float(current_step) / float(max(1, num_warmup_steps))
|
115 |
+
|
116 |
+
# Cosine learning rate decay
|
117 |
+
progress = float(current_step - num_warmup_steps) / float(
|
118 |
+
max(1, num_training_steps - num_warmup_steps)
|
119 |
+
)
|
120 |
+
scaling = 0.5 * (1.0 + math.cos(math.pi * progress))
|
121 |
+
return (1 - min_lr_ratio) * scaling + min_lr_ratio
|
122 |
+
|
123 |
+
|
124 |
+
def get_cosine_schedule_with_min_lr(
|
125 |
+
optimizer: Optimizer,
|
126 |
+
num_warmup_steps: int,
|
127 |
+
num_training_steps: int,
|
128 |
+
min_lr_ratio: float = 0.0,
|
129 |
+
):
|
130 |
+
"""
|
131 |
+
Create a learning rate schedule which has:
|
132 |
+
- linear warmup from 0 -> `max_lr` over `num_warmup_steps`
|
133 |
+
- cosine learning rate annealing from `max_lr` -> `min_lr` over `num_training_steps`
|
134 |
+
"""
|
135 |
+
|
136 |
+
lr_lambda = partial(
|
137 |
+
_get_cosine_schedule_with_min_lr_lambda,
|
138 |
+
num_warmup_steps=num_warmup_steps,
|
139 |
+
num_training_steps=num_training_steps,
|
140 |
+
min_lr_ratio=min_lr_ratio,
|
141 |
+
)
|
142 |
+
return LambdaLR(optimizer, lr_lambda)
|