Create optimization.py
optimization.py  ADDED  +756 -0
@@ -0,0 +1,756 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch optimization for BERT model."""

import math
import warnings
from functools import partial
from typing import Callable, Iterable, Optional, Tuple, Union

import torch
from torch import nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR

from .trainer_utils import SchedulerType
from .utils import logging
from .utils.versions import require_version


logger = logging.get_logger(__name__)


def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1):
    """
    Create a schedule with a constant learning rate, using the learning rate set in optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    return LambdaLR(optimizer, lambda _: 1, last_epoch=last_epoch)


def _get_constant_schedule_with_warmup_lr_lambda(current_step: int, *, num_warmup_steps: int):
    if current_step < num_warmup_steps:
        return float(current_step) / float(max(1.0, num_warmup_steps))
    return 1.0


def get_constant_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: int, last_epoch: int = -1):
    """
    Create a schedule with a constant learning rate preceded by a warmup period during which the learning rate
    increases linearly between 0 and the initial lr set in the optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    lr_lambda = partial(_get_constant_schedule_with_warmup_lr_lambda, num_warmup_steps=num_warmup_steps)
    return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)


def _get_linear_schedule_with_warmup_lr_lambda(current_step: int, *, num_warmup_steps: int, num_training_steps: int):
    if current_step < num_warmup_steps:
        return float(current_step) / float(max(1, num_warmup_steps))
    return max(0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)))


def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
    """
    Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
    a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    lr_lambda = partial(
        _get_linear_schedule_with_warmup_lr_lambda,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps,
    )
    return LambdaLR(optimizer, lr_lambda, last_epoch)
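
# The lambda above returns step / num_warmup_steps during warmup, then decays linearly from 1.0
# to 0.0 over the remaining num_training_steps - num_warmup_steps steps.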


def _get_cosine_schedule_with_warmup_lr_lambda(
    current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_cycles: float
):
    if current_step < num_warmup_steps:
        return float(current_step) / float(max(1, num_warmup_steps))
    progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))


def get_cosine_schedule_with_warmup(
    optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1
):
    """
    Create a schedule with a learning rate that decreases following the values of the cosine function between the
    initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
    initial lr set in the optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        num_cycles (`float`, *optional*, defaults to 0.5):
            The number of waves in the cosine schedule (the default is to just decrease from the max value to 0
            following a half-cosine).
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    lr_lambda = partial(
        _get_cosine_schedule_with_warmup_lr_lambda,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps,
        num_cycles=num_cycles,
    )
    return LambdaLR(optimizer, lr_lambda, last_epoch)
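
# After warmup, the multiplier above follows 0.5 * (1 + cos(pi * 2 * num_cycles * progress)) with
# progress running from 0 to 1; the default num_cycles=0.5 traces a single half-cosine from 1.0 to 0.0.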


def _get_cosine_with_hard_restarts_schedule_with_warmup_lr_lambda(
    current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_cycles: int
):
    if current_step < num_warmup_steps:
        return float(current_step) / float(max(1, num_warmup_steps))
    progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
    if progress >= 1.0:
        return 0.0
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0))))


def get_cosine_with_hard_restarts_schedule_with_warmup(
    optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: int = 1, last_epoch: int = -1
):
    """
    Create a schedule with a learning rate that decreases following the values of the cosine function between the
    initial lr set in the optimizer to 0, with several hard restarts, after a warmup period during which it increases
    linearly between 0 and the initial lr set in the optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        num_cycles (`int`, *optional*, defaults to 1):
            The number of hard restarts to use.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    lr_lambda = partial(
        _get_cosine_with_hard_restarts_schedule_with_warmup_lr_lambda,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps,
        num_cycles=num_cycles,
    )
    return LambdaLR(optimizer, lr_lambda, last_epoch)


def _get_polynomial_decay_schedule_with_warmup_lr_lambda(
    current_step: int,
    *,
    num_warmup_steps: int,
    num_training_steps: int,
    lr_end: float,
    power: float,
    lr_init: int,
):
    if current_step < num_warmup_steps:
        return float(current_step) / float(max(1, num_warmup_steps))
    elif current_step > num_training_steps:
        return lr_end / lr_init  # as LambdaLR multiplies by lr_init
    else:
        lr_range = lr_init - lr_end
        decay_steps = num_training_steps - num_warmup_steps
        pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps
        decay = lr_range * pct_remaining**power + lr_end
        return decay / lr_init  # as LambdaLR multiplies by lr_init


def get_polynomial_decay_schedule_with_warmup(
    optimizer, num_warmup_steps, num_training_steps, lr_end=1e-7, power=1.0, last_epoch=-1
):
    """
    Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the
    optimizer to end lr defined by *lr_end*, after a warmup period during which it increases linearly from 0 to the
    initial lr set in the optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        lr_end (`float`, *optional*, defaults to 1e-7):
            The end LR.
        power (`float`, *optional*, defaults to 1.0):
            Power factor.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Note: *power* defaults to 1.0 as in the fairseq implementation, which in turn is based on the original BERT
    implementation at
    https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/optimization.py#L37

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.

    """

    lr_init = optimizer.defaults["lr"]
    if not (lr_init > lr_end):
        raise ValueError(f"lr_end ({lr_end}) must be smaller than initial lr ({lr_init})")

    lr_lambda = partial(
        _get_polynomial_decay_schedule_with_warmup_lr_lambda,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps,
        lr_end=lr_end,
        power=power,
        lr_init=lr_init,
    )
    return LambdaLR(optimizer, lr_lambda, last_epoch)


def _get_inverse_sqrt_schedule_lr_lambda(current_step: int, *, num_warmup_steps: int, timescale: int = None):
    if current_step < num_warmup_steps:
        return float(current_step) / float(max(1, num_warmup_steps))
    shift = timescale - num_warmup_steps
    decay = 1.0 / math.sqrt((current_step + shift) / timescale)
    return decay


def get_inverse_sqrt_schedule(
    optimizer: Optimizer, num_warmup_steps: int, timescale: int = None, last_epoch: int = -1
):
    """
    Create a schedule with an inverse square-root learning rate, from the initial lr set in the optimizer, after a
    warmup period which increases lr linearly from 0 to the initial lr set in the optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        timescale (`int`, *optional*, defaults to `num_warmup_steps`):
            Time scale.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """
    # Note: this implementation is adapted from
    # https://github.com/google-research/big_vision/blob/f071ce68852d56099437004fd70057597a95f6ef/big_vision/utils.py#L930

    if timescale is None:
        timescale = num_warmup_steps

    lr_lambda = partial(_get_inverse_sqrt_schedule_lr_lambda, num_warmup_steps=num_warmup_steps, timescale=timescale)
    return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
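
# With the default timescale == num_warmup_steps, the post-warmup multiplier above simplifies to
# sqrt(num_warmup_steps / current_step): it is exactly 1.0 at the end of warmup and then decays
# with the inverse square root of the step count.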


TYPE_TO_SCHEDULER_FUNCTION = {
    SchedulerType.LINEAR: get_linear_schedule_with_warmup,
    SchedulerType.COSINE: get_cosine_schedule_with_warmup,
    SchedulerType.COSINE_WITH_RESTARTS: get_cosine_with_hard_restarts_schedule_with_warmup,
    SchedulerType.POLYNOMIAL: get_polynomial_decay_schedule_with_warmup,
    SchedulerType.CONSTANT: get_constant_schedule,
    SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup,
    SchedulerType.INVERSE_SQRT: get_inverse_sqrt_schedule,
}


def get_scheduler(
    name: Union[str, SchedulerType],
    optimizer: Optimizer,
    num_warmup_steps: Optional[int] = None,
    num_training_steps: Optional[int] = None,
):
    """
    Unified API to get any scheduler from its name.

    Args:
        name (`str` or `SchedulerType`):
            The name of the scheduler to use.
        optimizer (`torch.optim.Optimizer`):
            The optimizer that will be used during training.
        num_warmup_steps (`int`, *optional*):
            The number of warmup steps to do. This is not required by all schedulers (hence the argument being
            optional), the function will raise an error if it's unset and the scheduler type requires it.
        num_training_steps (`int`, *optional*):
            The number of training steps to do. This is not required by all schedulers (hence the argument being
            optional), the function will raise an error if it's unset and the scheduler type requires it.
    """
    name = SchedulerType(name)
    schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
    if name == SchedulerType.CONSTANT:
        return schedule_func(optimizer)

    # All other schedulers require `num_warmup_steps`
    if num_warmup_steps is None:
        raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")

    if name == SchedulerType.CONSTANT_WITH_WARMUP:
        return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)

    if name == SchedulerType.INVERSE_SQRT:
        return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)

    # All other schedulers require `num_training_steps`
    if num_training_steps is None:
        raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")

    return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
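
# Illustrative usage sketch (placeholder values; assumes a `model` defined elsewhere):
#
#     optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
#     lr_scheduler = get_scheduler("linear", optimizer, num_warmup_steps=500, num_training_steps=5000)
#     # call lr_scheduler.step() once after each optimizer.step()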


class AdamW(Optimizer):
    """
    Implements Adam algorithm with weight decay fix as introduced in [Decoupled Weight Decay
    Regularization](https://arxiv.org/abs/1711.05101).

    Parameters:
        params (`Iterable[nn.parameter.Parameter]`):
            Iterable of parameters to optimize or dictionaries defining parameter groups.
        lr (`float`, *optional*, defaults to 1e-3):
            The learning rate to use.
        betas (`Tuple[float,float]`, *optional*, defaults to (0.9, 0.999)):
            Adam's betas parameters (b1, b2).
        eps (`float`, *optional*, defaults to 1e-6):
            Adam's epsilon for numerical stability.
        weight_decay (`float`, *optional*, defaults to 0):
            Decoupled weight decay to apply.
        correct_bias (`bool`, *optional*, defaults to `True`):
            Whether or not to correct bias in Adam (for instance, in Bert TF repository they use `False`).
        no_deprecation_warning (`bool`, *optional*, defaults to `False`):
            A flag used to disable the deprecation warning (set to `True` to disable the warning).
    """

    def __init__(
        self,
        params: Iterable[nn.parameter.Parameter],
        lr: float = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-6,
        weight_decay: float = 0.0,
        correct_bias: bool = True,
        no_deprecation_warning: bool = False,
    ):
        if not no_deprecation_warning:
            warnings.warn(
                "This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch"
                " implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this"
                " warning",
                FutureWarning,
            )
        require_version("torch>=1.5.0")  # add_ with alpha
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr} - should be >= 0.0")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter: {betas[0]} - should be in [0.0, 1.0)")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter: {betas[1]} - should be in [0.0, 1.0)")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps} - should be >= 0.0")
        defaults = {"lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay, "correct_bias": correct_bias}
        super().__init__(params, defaults)

    def step(self, closure: Callable = None):
        """
        Performs a single optimization step.

        Arguments:
            closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead")

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                beta1, beta2 = group["betas"]

                state["step"] += 1

                # Decay the first and second moment running average coefficient
                # In-place operations to update the averages at the same time
                exp_avg.mul_(beta1).add_(grad, alpha=(1.0 - beta1))
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
                denom = exp_avg_sq.sqrt().add_(group["eps"])

                step_size = group["lr"]
                if group["correct_bias"]:  # No bias correction for Bert
                    bias_correction1 = 1.0 - beta1 ** state["step"]
                    bias_correction2 = 1.0 - beta2 ** state["step"]
                    step_size = step_size * math.sqrt(bias_correction2) / bias_correction1

                p.data.addcdiv_(exp_avg, denom, value=-step_size)

                # Just adding the square of the weights to the loss function is *not*
                # the correct way of using L2 regularization/weight decay with Adam,
                # since that will interact with the m and v parameters in strange ways.
                #
                # Instead we want to decay the weights in a manner that doesn't interact
                # with the m/v parameters. This is equivalent to adding the square
                # of the weights to the loss with plain (non-momentum) SGD.
                # Add weight decay at the end (fixed version)
                if group["weight_decay"] > 0.0:
                    p.data.add_(p.data, alpha=(-group["lr"] * group["weight_decay"]))

        return loss
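
# Note: the decoupled weight decay above multiplies each parameter by (1 - lr * weight_decay) after
# the Adam update, so the decay never flows through the exp_avg / exp_avg_sq moment estimates.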


class Adafactor(Optimizer):
    """
    AdaFactor pytorch implementation can be used as a drop in replacement for Adam original fairseq code:
    https://github.com/pytorch/fairseq/blob/master/fairseq/optim/adafactor.py

    Paper: *Adafactor: Adaptive Learning Rates with Sublinear Memory Cost* https://arxiv.org/abs/1804.04235 Note that
    this optimizer internally adjusts the learning rate depending on the `scale_parameter`, `relative_step` and
    `warmup_init` options. To use a manual (external) learning rate schedule you should set `scale_parameter=False` and
    `relative_step=False`.

    Arguments:
        params (`Iterable[nn.parameter.Parameter]`):
            Iterable of parameters to optimize or dictionaries defining parameter groups.
        lr (`float`, *optional*):
            The external learning rate.
        eps (`Tuple[float, float]`, *optional*, defaults to (1e-30, 1e-3)):
            Regularization constants for square gradient and parameter scale respectively
        clip_threshold (`float`, *optional*, defaults to 1.0):
            Threshold of root mean square of final gradient update
        decay_rate (`float`, *optional*, defaults to -0.8):
            Coefficient used to compute running averages of square
        beta1 (`float`, *optional*):
            Coefficient used for computing running averages of gradient
        weight_decay (`float`, *optional*, defaults to 0):
            Weight decay (L2 penalty)
        scale_parameter (`bool`, *optional*, defaults to `True`):
            If True, learning rate is scaled by root mean square
        relative_step (`bool`, *optional*, defaults to `True`):
            If True, time-dependent learning rate is computed instead of external learning rate
        warmup_init (`bool`, *optional*, defaults to `False`):
            Time-dependent learning rate computation depends on whether warm-up initialization is being used

    This implementation handles low-precision (FP16, bfloat) values, but we have not thoroughly tested.

    Recommended T5 finetuning settings (https://discuss.huggingface.co/t/t5-finetuning-tips/684/3):

    - Training without LR warmup or clip_threshold is not recommended.

       - use scheduled LR warm-up to fixed LR
       - use clip_threshold=1.0 (https://arxiv.org/abs/1804.04235)
    - Disable relative updates
    - Use scale_parameter=False
    - Additional optimizer operations like gradient clipping should not be used alongside Adafactor

    Example:

    ```python
    Adafactor(model.parameters(), scale_parameter=False, relative_step=False, warmup_init=False, lr=1e-3)
    ```

    Others reported the following combination to work well:

    ```python
    Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None)
    ```

    When using `lr=None` with [`Trainer`] you will most likely need to use [`~optimization.AdafactorSchedule`]
    scheduler as following:

    ```python
    from transformers.optimization import Adafactor, AdafactorSchedule

    optimizer = Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None)
    lr_scheduler = AdafactorSchedule(optimizer)
    trainer = Trainer(..., optimizers=(optimizer, lr_scheduler))
    ```

    Usage:

    ```python
    # replace AdamW with Adafactor
    optimizer = Adafactor(
        model.parameters(),
        lr=1e-3,
        eps=(1e-30, 1e-3),
        clip_threshold=1.0,
        decay_rate=-0.8,
        beta1=None,
        weight_decay=0.0,
        relative_step=False,
        scale_parameter=False,
        warmup_init=False,
    )
    ```"""

    def __init__(
        self,
        params,
        lr=None,
        eps=(1e-30, 1e-3),
        clip_threshold=1.0,
        decay_rate=-0.8,
        beta1=None,
        weight_decay=0.0,
        scale_parameter=True,
        relative_step=True,
        warmup_init=False,
    ):
        require_version("torch>=1.5.0")  # add_ with alpha
        if lr is not None and relative_step:
            raise ValueError("Cannot combine manual `lr` and `relative_step=True` options")
        if warmup_init and not relative_step:
            raise ValueError("`warmup_init=True` requires `relative_step=True`")

        defaults = {
            "lr": lr,
            "eps": eps,
            "clip_threshold": clip_threshold,
            "decay_rate": decay_rate,
            "beta1": beta1,
            "weight_decay": weight_decay,
            "scale_parameter": scale_parameter,
            "relative_step": relative_step,
            "warmup_init": warmup_init,
        }
        super().__init__(params, defaults)

    @staticmethod
    def _get_lr(param_group, param_state):
        rel_step_sz = param_group["lr"]
        if param_group["relative_step"]:
            min_step = 1e-6 * param_state["step"] if param_group["warmup_init"] else 1e-2
            rel_step_sz = min(min_step, 1.0 / math.sqrt(param_state["step"]))
        param_scale = 1.0
        if param_group["scale_parameter"]:
            param_scale = max(param_group["eps"][1], param_state["RMS"])
        return param_scale * rel_step_sz

    @staticmethod
    def _get_options(param_group, param_shape):
        factored = len(param_shape) >= 2
        use_first_moment = param_group["beta1"] is not None
        return factored, use_first_moment

    @staticmethod
    def _rms(tensor):
        return tensor.norm(2) / (tensor.numel() ** 0.5)

    @staticmethod
    def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col):
        # copy from fairseq's adafactor implementation:
        # https://github.com/huggingface/transformers/blob/8395f14de6068012787d83989c3627c3df6a252b/src/transformers/optimization.py#L505
        r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
        c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
        return torch.mul(r_factor, c_factor)

    def step(self, closure=None):
        """
        Performs a single optimization step

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.dtype in {torch.float16, torch.bfloat16}:
                    grad = grad.float()
                if grad.is_sparse:
                    raise RuntimeError("Adafactor does not support sparse gradients.")

                state = self.state[p]
                grad_shape = grad.shape

                factored, use_first_moment = self._get_options(group, grad_shape)
                # State Initialization
                if len(state) == 0:
                    state["step"] = 0

                    if use_first_moment:
                        # Exponential moving average of gradient values
                        state["exp_avg"] = torch.zeros_like(grad)
                    if factored:
                        state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad)
                        state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad)
                    else:
                        state["exp_avg_sq"] = torch.zeros_like(grad)

                    state["RMS"] = 0
                else:
                    if use_first_moment:
                        state["exp_avg"] = state["exp_avg"].to(grad)
                    if factored:
                        state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad)
                        state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad)
                    else:
                        state["exp_avg_sq"] = state["exp_avg_sq"].to(grad)

                p_data_fp32 = p.data
                if p.data.dtype in {torch.float16, torch.bfloat16}:
                    p_data_fp32 = p_data_fp32.float()

                state["step"] += 1
                state["RMS"] = self._rms(p_data_fp32)
                lr = self._get_lr(group, state)

                beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
                update = (grad**2) + group["eps"][0]
                if factored:
                    exp_avg_sq_row = state["exp_avg_sq_row"]
                    exp_avg_sq_col = state["exp_avg_sq_col"]

                    exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t))
                    exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t))

                    # Approximation of exponential moving average of square of gradient
                    update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
                    update.mul_(grad)
                else:
                    exp_avg_sq = state["exp_avg_sq"]

                    exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t))
                    update = exp_avg_sq.rsqrt().mul_(grad)

                update.div_((self._rms(update) / group["clip_threshold"]).clamp_(min=1.0))
                update.mul_(lr)

                if use_first_moment:
                    exp_avg = state["exp_avg"]
                    exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"]))
                    update = exp_avg

                if group["weight_decay"] != 0:
                    p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr))

                p_data_fp32.add_(-update)

                if p.data.dtype in {torch.float16, torch.bfloat16}:
                    p.data.copy_(p_data_fp32)

        return loss
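
# For parameters with two or more dimensions, the factored path above keeps only per-row and
# per-column second-moment averages and reconstructs the per-element scaling in _approx_sq_grad,
# which is what gives Adafactor its sublinear memory cost.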


class AdafactorSchedule(LambdaLR):
    """
    Since [`~optimization.Adafactor`] performs its own scheduling, if the training loop relies on a scheduler (e.g.,
    for logging), this class creates a proxy object that retrieves the current lr values from the optimizer.

    It returns `initial_lr` during startup and the actual `lr` during stepping.
    """

    def __init__(self, optimizer, initial_lr=0.0):
        def lr_lambda(_):
            return initial_lr

        for group in optimizer.param_groups:
            group["initial_lr"] = initial_lr
        super().__init__(optimizer, lr_lambda)
        for group in optimizer.param_groups:
            del group["initial_lr"]

    def get_lr(self):
        opt = self.optimizer
        lrs = [
            opt._get_lr(group, opt.state[group["params"][0]])
            for group in opt.param_groups
            if group["params"][0].grad is not None
        ]
        if len(lrs) == 0:
            lrs = self.base_lrs  # if called before stepping
        return lrs


def get_adafactor_schedule(optimizer, initial_lr=0.0):
    """
    Get a proxy schedule for [`~optimization.Adafactor`]

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        initial_lr (`float`, *optional*, defaults to 0.0):
            Initial lr

    Return:
        [`~optimization.Adafactor`] proxy schedule object.

    """
    return AdafactorSchedule(optimizer, initial_lr)
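
A minimal end-to-end sketch of how these pieces are typically combined, assuming a standard PyTorch setup where `model` and `dataloader` are defined elsewhere and the step counts are placeholders:

```python
import torch

from transformers.optimization import AdamW, get_linear_schedule_with_warmup

# model and dataloader are assumed to exist; num_training_steps would normally be
# num_epochs * len(dataloader)
optimizer = AdamW(model.parameters(), lr=5e-5, weight_decay=0.01, no_deprecation_warning=True)
lr_scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=100, num_training_steps=1000)

for batch in dataloader:
    loss = model(**batch).loss
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    lr_scheduler.step()  # advance the schedule once per optimizer step
    optimizer.zero_grad()
```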