import torch
from torch.optim.lr_scheduler import LRScheduler


class LinearSchedulerWithWarmup(LRScheduler):
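    """Linear learning-rate schedule with warmup.

    Scales each base LR linearly from 0 to its full value over the first
    ``num_warmup_steps`` steps, then decays it linearly back to 0 at
    ``num_training_steps``. ``step()`` is expected to be called once per
    training step (``last_epoch`` counts steps, not epochs).
    """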
    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        num_warmup_steps: int,
        num_training_steps: int,
        last_epoch: int = -1,
        verbose: bool = False,
        **kwargs,
    ):
        self.num_warmup_steps = num_warmup_steps
        self.num_training_steps = num_training_steps
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        def scheduler_fn(current_step):
            # Warmup phase: scale the base LR linearly from 0 up to 1.
            if current_step < self.num_warmup_steps:
                return current_step / max(1, self.num_warmup_steps)
            # Decay phase: scale linearly from 1 down to 0, clamped at 0
            # in case training runs past num_training_steps.
            return max(
                0.0,
                float(self.num_training_steps - current_step)
                / float(max(1, self.num_training_steps - self.num_warmup_steps)),
            )

        return [base_lr * scheduler_fn(self.last_epoch) for base_lr in self.base_lrs]


class LinearScheduler(LRScheduler):
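    """Linear learning-rate decay without warmup.

    Scales each base LR linearly from its full value down to 0 over
    ``num_training_steps`` steps. ``step()`` is expected to be called once
    per training step.
    """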
    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        num_training_steps: int,
        last_epoch: int = -1,
        verbose: bool = False,
        **kwargs,
    ):
        self.num_training_steps = num_training_steps
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        def scheduler_fn(current_step):
            # Decay only: scale linearly from 1 down to 0, clamped at 0
            # in case training runs past num_training_steps. (Dead commented-out
            # warmup code referencing a nonexistent attribute was removed.)
            return max(
                0.0,
                float(self.num_training_steps - current_step)
                / float(max(1, self.num_training_steps)),
            )

        return [base_lr * scheduler_fn(self.last_epoch) for base_lr in self.base_lrs]
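

if __name__ == "__main__":
    # Minimal usage sketch, not part of the schedulers above: the dummy model,
    # optimizer, and step counts below are illustrative assumptions only.
    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scheduler = LinearSchedulerWithWarmup(
        optimizer, num_warmup_steps=3, num_training_steps=10
    )

    for step in range(10):
        optimizer.step()  # a real loop would compute a loss and call backward() first
        scheduler.step()  # advances last_epoch; call once per training step
        print(step, scheduler.get_last_lr())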