# -*- coding: utf-8 -*-
r"""
Schedulers
==============
Learning rate schedulers used to train Polos models.
"""
from argparse import Namespace
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR


class ConstantPolicy:
    """Policy for updating the LR of the ConstantLR scheduler.

    With this class, LambdaLR objects become picklable.
    """

    def __call__(self, *args, **kwargs):
        return 1


class ConstantLR(LambdaLR):
    """
    Constant learning rate schedule.

    Wrapper for the Hugging Face Constant LR Scheduler.
    https://huggingface.co/transformers/v2.1.1/main_classes/optimizer_schedules.html

    :param optimizer: torch.optim.Optimizer
    :param last_epoch: Index of the last epoch when resuming training (default: -1).
    """

    def __init__(self, optimizer: Optimizer, last_epoch: int = -1) -> None:
        super(ConstantLR, self).__init__(optimizer, ConstantPolicy(), last_epoch)

    @classmethod
    def from_hparams(
        cls, optimizer: Optimizer, hparams: Namespace, **kwargs
    ) -> LambdaLR:
        """Initializes a constant learning rate scheduler."""
        return ConstantLR(optimizer)


class WarmupPolicy:
    """Policy for updating the LR of the WarmupConstant scheduler.

    With this class, LambdaLR objects become picklable.
    """

    def __init__(self, warmup_steps):
        self.warmup_steps = warmup_steps

    def __call__(self, current_step):
        if current_step < self.warmup_steps:
            return float(current_step) / float(max(1.0, self.warmup_steps))
        return 1.0
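

# Illustrative behaviour of WarmupPolicy (a sketch, not part of the original
# module): with warmup_steps=4 the LR multiplier ramps linearly from 0 to 1,
# then stays at 1.0 for every later step.
#
#     >>> policy = WarmupPolicy(warmup_steps=4)
#     >>> [round(policy(step), 2) for step in range(6)]
#     [0.0, 0.25, 0.5, 0.75, 1.0, 1.0]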


class WarmupConstant(LambdaLR):
    """
    Warmup Constant scheduler.

    1) Linearly increases the learning rate multiplier from 0 to 1 over
       warmup_steps training steps.
    2) Keeps the learning rate constant afterwards.

    :param optimizer: torch.optim.Optimizer
    :param warmup_steps: Linearly increases learning rate from 0 to 1 over warmup_steps.
    :param last_epoch: Index of the last epoch when resuming training (default: -1).
    """

    def __init__(
        self, optimizer: Optimizer, warmup_steps: int, last_epoch: int = -1
    ) -> None:
        super(WarmupConstant, self).__init__(
            optimizer, WarmupPolicy(warmup_steps), last_epoch
        )

    @classmethod
    def from_hparams(
        cls, optimizer: Optimizer, hparams: Namespace, **kwargs
    ) -> LambdaLR:
        """Initializes a constant learning rate scheduler with a warmup period."""
        return WarmupConstant(optimizer, hparams.warmup_steps)


class LinearWarmupPolicy:
    """Policy for updating the LR of the LinearWarmup scheduler.

    With this class, LambdaLR objects become picklable.
    """

    def __init__(self, warmup_steps, num_training_steps):
        self.num_training_steps = num_training_steps
        self.warmup_steps = warmup_steps

    def __call__(self, current_step):
        if current_step < self.warmup_steps:
            return float(current_step) / float(max(1, self.warmup_steps))
        return max(
            0.0,
            float(self.num_training_steps - current_step)
            / float(max(1, self.num_training_steps - self.warmup_steps)),
        )
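

# Illustrative behaviour of LinearWarmupPolicy (a sketch, not part of the
# original module): with warmup_steps=2 and num_training_steps=10 the
# multiplier ramps up over 2 steps, then decays linearly to 0 at step 10.
#
#     >>> policy = LinearWarmupPolicy(warmup_steps=2, num_training_steps=10)
#     >>> [round(policy(step), 2) for step in (0, 1, 2, 6, 10)]
#     [0.0, 0.5, 1.0, 0.5, 0.0]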


class LinearWarmup(LambdaLR):
    """
    Creates a schedule with a learning rate that decreases linearly after
    linearly increasing during a warmup period.

    :param optimizer: torch.optim.Optimizer
    :param warmup_steps: Linearly increases the learning rate from 0 to the
        base learning rate over warmup_steps.
    :param num_training_steps: Linearly decreases the learning rate from the
        base learning rate to 0 over the remaining
        num_training_steps - warmup_steps steps.
    :param last_epoch: Index of the last epoch when resuming training (default: -1).
    """

    def __init__(
        self,
        optimizer: Optimizer,
        warmup_steps: int,
        num_training_steps: int,
        last_epoch: int = -1,
    ) -> None:
        super(LinearWarmup, self).__init__(
            optimizer, LinearWarmupPolicy(warmup_steps, num_training_steps), last_epoch
        )

    @classmethod
    def from_hparams(
        cls, optimizer: Optimizer, hparams: Namespace, num_training_steps: int
    ) -> LambdaLR:
        """Initializes a learning rate scheduler with warmup and decay periods."""
        return LinearWarmup(optimizer, hparams.warmup_steps, num_training_steps)


str2scheduler = {
    "linear_warmup": LinearWarmup,
    "constant": ConstantLR,
    "warmup_constant": WarmupConstant,
}
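

if __name__ == "__main__":
    # A minimal usage sketch (illustrative only; the toy model, learning
    # rate, and step counts below are assumptions, not part of Polos).
    import torch

    model = torch.nn.Linear(10, 1)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    hparams = Namespace(warmup_steps=100)

    # Look up a scheduler class by name and build it from hyperparameters.
    scheduler = str2scheduler["linear_warmup"].from_hparams(
        optimizer, hparams, num_training_steps=1000
    )

    for step in range(1000):
        optimizer.step()  # one (dummy) optimization step
        scheduler.step()  # then advance the LR schedule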