import math

from torch.optim.lr_scheduler import _LRScheduler


class ConstantLRScheduler(_LRScheduler):
    def __init__(
        self,
        optimizer,
        last_epoch: int = -1,
        verbose: bool = False,
        init_lr: float = 0.0,
    ):
        """
        A constant learning rate scheduler: the learning rate is held at
        ``init_lr`` for every training step.
        Args:
            optimizer: Optimizer

            last_epoch: The index of the last epoch. Default: -1

            verbose: If ``True``, prints a message to stdout for each update. Default: ``False``

            init_lr: The constant learning rate used for every step
        """

        self.init_lr = init_lr
        super().__init__(optimizer, last_epoch, verbose)

    def state_dict(self):
        state_dict = {k: v for k, v in self.__dict__.items() if k not in ["optimizer"]}
        return state_dict

    def load_state_dict(self, state_dict):
        self.__dict__.update(state_dict)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            raise RuntimeError(
                "To get the last learning rate computed by the scheduler, use "
                "get_last_lr()"
            )

        # The same constant learning rate is used for every parameter group.
        return [self.init_lr for _ in self.optimizer.param_groups]
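
# Usage sketch for ``ConstantLRScheduler`` (an illustration, not part of the
# original module; ``model`` and ``loader`` are hypothetical placeholders):
# the scheduler is stepped once per optimizer update, but the learning rate
# never moves from ``init_lr``.
#
#     optimizer = torch.optim.AdamW(model.parameters(), lr=4e-4)
#     scheduler = ConstantLRScheduler(optimizer, init_lr=4e-4)
#     for batch in loader:
#         loss = model(batch)
#         loss.backward()
#         optimizer.step()
#         optimizer.zero_grad()
#         scheduler.step()  # scheduler.get_last_lr() stays [4e-4] throughout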


class CosineAnnealingLRScheduler(_LRScheduler):
    def __init__(
        self,
        optimizer,
        last_epoch: int = -1,
        verbose: bool = False,
        init_lr: float = 4.0e-5,
        max_lr: float = 4e-4,
        final_lr: float = 4e-5,
        warmup_steps: int = 2000,
        cosine_steps: int = 10000,
    ):
        """
        A cosine annealing learning rate scheduler with linear warmup: the learning rate
        rises linearly from ``init_lr`` to ``max_lr`` over ``warmup_steps`` steps, then
        follows a half-cosine decay from ``max_lr`` to ``final_lr`` over ``cosine_steps`` steps.
        Args:
            optimizer: Optimizer

            last_epoch: The index of the last epoch. Default: -1

            verbose: If ``True``, prints a message to stdout for each update. Default: ``False``

            init_lr: Initial learning rate at the start of warmup

            max_lr: Maximum learning rate after warmup

            final_lr: Final learning rate after decay

            warmup_steps: Number of steps for warmup

            cosine_steps: Number of steps for cosine annealing
        """

        self.init_lr = init_lr
        self.max_lr = max_lr
        self.final_lr = final_lr
        self.warmup_steps = warmup_steps
        self.cosine_steps = cosine_steps
        super().__init__(optimizer, last_epoch, verbose)

    def state_dict(self):
        state_dict = {k: v for k, v in self.__dict__.items() if k not in ["optimizer"]}
        return state_dict

    def load_state_dict(self, state_dict):
        self.__dict__.update(state_dict)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            raise RuntimeError(
                "To get the last learning rate computed by the scheduler, use "
                "get_last_lr()"
            )

        step_no = self.last_epoch

        if step_no <= self.warmup_steps:
            # Linear warmup from init_lr to max_lr over warmup_steps.
            lr = self.init_lr + step_no / self.warmup_steps * (
                self.max_lr - self.init_lr
            )

        else:
            # Half-cosine decay from max_lr down to final_lr over cosine_steps.
            lr = self.final_lr + 0.5 * (self.max_lr - self.final_lr) * (
                1
                + math.cos(math.pi * (step_no - self.warmup_steps) / self.cosine_steps)
            )

        return [lr for _ in self.optimizer.param_groups]
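
# Illustration (not part of the original module) of the schedule produced by
# ``CosineAnnealingLRScheduler`` with its default hyperparameters
# (init_lr=4e-5, max_lr=4e-4, final_lr=4e-5, warmup_steps=2000,
# cosine_steps=10000); the values follow directly from ``get_lr`` above:
#
#     step     0 -> 4.0e-5   (start of linear warmup)
#     step  1000 -> 2.2e-4   (halfway through warmup)
#     step  2000 -> 4.0e-4   (warmup ends at max_lr)
#     step  7000 -> 2.2e-4   (cosine midpoint: cos(pi/2) = 0)
#     step 12000 -> 4.0e-5   (cosine reaches final_lr: cos(pi) = -1)
#
# Note that ``get_lr`` does not clamp the step counter, so stepping past
# warmup_steps + cosine_steps makes the cosine term rise again.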


class Esm2LRScheduler(_LRScheduler):
    def __init__(
        self,
        optimizer,
        last_epoch: int = -1,
        verbose: bool = False,
        init_lr: float = 4e-5,
        max_lr: float = 4e-4,
        final_lr: float = 4e-5,
        warmup_steps: int = 2000,
        start_decay_after_n_steps: int = 500000,
        end_decay_after_n_steps: int = 5000000,
        on_use: bool = True,
    ):
        """
        An implementation of ESM2's learning rate schedule: linear warmup from
        ``init_lr`` to ``max_lr``, a constant plateau at ``max_lr``, then linear
        decay to ``final_lr``.
        Args:
            optimizer: Optimizer

            last_epoch: The index of the last epoch. Default: -1

            verbose: If ``True``, prints a message to stdout for each update. Default: ``False``

            init_lr: Initial learning rate at the start of warmup

            max_lr: Maximum learning rate after warmup

            final_lr: Final learning rate after decay

            warmup_steps: Number of steps for warmup

            start_decay_after_n_steps: Step at which linear decay begins; the learning rate is held
            at ``max_lr`` from the end of warmup until this step

            end_decay_after_n_steps: Step at which linear decay ends; ``final_lr`` is used from then on

            on_use: Whether to use this scheduler. If ``False``, the learning rate is left at the
            optimizer's initial value (``base_lrs``) and is never changed. Default: ``True``
        """

        self.init_lr = init_lr
        self.max_lr = max_lr
        self.final_lr = final_lr
        self.warmup_steps = warmup_steps
        self.start_decay_after_n_steps = start_decay_after_n_steps
        self.end_decay_after_n_steps = end_decay_after_n_steps
        self.on_use = on_use
        super().__init__(optimizer, last_epoch, verbose)

    def state_dict(self):
        state_dict = {k: v for k, v in self.__dict__.items() if k not in ["optimizer"]}
        return state_dict

    def load_state_dict(self, state_dict):
        self.__dict__.update(state_dict)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            raise RuntimeError(
                "To get the last learning rate computed by the scheduler, use "
                "get_last_lr()"
            )

        step_no = self.last_epoch
        if not self.on_use:
            # Scheduler disabled: keep the optimizer's own initial learning rates.
            return list(self.base_lrs)

        if step_no <= self.warmup_steps:
            # Linear warmup from init_lr to max_lr over warmup_steps.
            lr = self.init_lr + step_no / self.warmup_steps * (
                self.max_lr - self.init_lr
            )

        elif step_no <= self.start_decay_after_n_steps:
            # Hold max_lr until decay starts.
            lr = self.max_lr

        elif step_no <= self.end_decay_after_n_steps:
            # Linear decay from max_lr to final_lr.
            portion = (step_no - self.start_decay_after_n_steps) / (
                self.end_decay_after_n_steps - self.start_decay_after_n_steps
            )
            lr = self.max_lr - portion * (self.max_lr - self.final_lr)

        else:
            # Decay finished: hold final_lr.
            lr = self.final_lr

        return [lr for _ in self.optimizer.param_groups]
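

# Minimal runnable sketch (added as an illustration, not part of the original
# module): attach ``Esm2LRScheduler`` to a throwaway optimizer over a single
# dummy parameter and print the learning rate at a few steps, using small,
# made-up hyperparameters so the warmup / plateau / linear-decay phases show
# up quickly.
if __name__ == "__main__":
    import torch

    param = torch.nn.Parameter(torch.zeros(1))
    optimizer = torch.optim.SGD([param], lr=4e-4)
    scheduler = Esm2LRScheduler(
        optimizer,
        warmup_steps=10,
        start_decay_after_n_steps=20,
        end_decay_after_n_steps=40,
    )

    # Expected trajectory with the default init_lr/max_lr/final_lr: the LR
    # climbs to max_lr (4e-4) by step 10, stays there through step 20, decays
    # linearly to final_lr (4e-5) by step 40, and is constant afterwards.
    for _ in range(50):
        optimizer.step()
        scheduler.step()
        if scheduler.last_epoch % 10 == 0:
            print(f"step {scheduler.last_epoch:3d}  lr {scheduler.get_last_lr()[0]:.2e}")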