"""Learning-rate schedulers: a Noam-style warmup schedule and a three-stage
(warmup / constant / decay) schedule, both built on torch's _LRScheduler."""
import math

from torch.optim.lr_scheduler import _LRScheduler


class WarmUpScheduler(_LRScheduler):
    """Noam schedule from "Attention Is All You Need" (Vaswani et al., 2017):

        lr = factor * feature_size**-0.5
             * min(step**-0.5, step * warmup_steps**-1.5)

    The rate grows linearly for the first `warmup_steps` optimizer steps,
    then decays with the inverse square root of the step count. The
    optimizer's base learning rate is ignored; the schedule is absolute.
    """

    def __init__(
        self,
        optimizer,
        warmup_steps: int,
        feature_size: int,
        factor: float = 1.0,
        last_epoch=-1,
    ):
        self.warmup_steps = warmup_steps
        self.feature_size = feature_size
        self.factor = factor
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        # Every param group receives the same absolute rate.
        lr = self._compute_lr()
        return [lr] * len(self.base_lrs)

    def _compute_lr(self):
        # _LRScheduler.__init__ triggers an initial step() with
        # last_epoch == 0; return 0.0 there to avoid 0 ** -0.5.
        if self.last_epoch == 0:
            return 0.0

        # min() picks the linear-warmup term until step == warmup_steps,
        # the inverse-square-root decay term afterwards.
        lr = (self.feature_size ** (-0.5)) * min(
            self.last_epoch ** (-0.5), self.last_epoch * self.warmup_steps ** (-1.5)
        )

        return lr * self.factor
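
# A minimal usage sketch for WarmUpScheduler (illustrative only: `model` and
# `loader` are assumed to exist, and the hyperparameters are the Transformer-
# base values from the paper, not anything this module prescribes):
#
#     optimizer = torch.optim.Adam(model.parameters(), lr=0.0, betas=(0.9, 0.98))
#     scheduler = WarmUpScheduler(optimizer, warmup_steps=4000, feature_size=512)
#     for batch in loader:
#         ...                 # forward, backward
#         optimizer.step()
#         scheduler.step()    # once per optimizer step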


class TriStateScheduler(_LRScheduler):
    """Three-stage schedule: a cosine ramp from `factor * base_lr` up to
    `base_lr` over `warmup_steps`, a constant plateau at `base_lr` for
    `constant_steps`, then a cosine decay back to `factor * base_lr` over
    the remaining `total_steps - warmup_steps - constant_steps` steps.
    """

    def __init__(
        self,
        optimizer,
        total_steps: int,
        warmup_steps: int,
        constant_steps: int,
        factor: float = 0.3,
        last_epoch: int = -1,
    ):
        self.warmup_steps = warmup_steps
        self.constant_steps = constant_steps
        self.total_steps = total_steps
        self.factor = factor

        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        # base_lrs is only populated by _LRScheduler.__init__, which itself
        # performs the first step()/get_lr() call, so the schedule bounds are
        # derived lazily on first use: eta_max is each group's base LR,
        # eta_min its scaled floor.
        if not hasattr(self, "eta_min"):
            self.eta_max = self.base_lrs.copy()
            self.eta_min = [eta_max * self.factor for eta_max in self.eta_max]

        return [
            self._compute_lr(group["lr"], eta_min, eta_max)
            for group, eta_min, eta_max in zip(
                self.optimizer.param_groups, self.eta_min, self.eta_max
            )
        ]

    def _compute_lr(self, prev_lr: float, eta_min: float, eta_max: float):
        # Stage 1: cosine ramp from eta_min up to eta_max.
        if self.last_epoch <= self.warmup_steps:
            lr = eta_max - 0.5 * (eta_max - eta_min) * (
                1 + math.cos(math.pi * self.last_epoch / self.warmup_steps)
            )
        # Stage 2: hold the rate reached at the end of warmup.
        elif self.last_epoch <= self.warmup_steps + self.constant_steps:
            lr = prev_lr
        # Stage 3: cosine decay from eta_max back down to eta_min.
        else:
            decay_steps = self.total_steps - self.warmup_steps - self.constant_steps
            k = self.last_epoch - self.warmup_steps - self.constant_steps
            lr = eta_min + 0.5 * (eta_max - eta_min) * (
                1 + math.cos(math.pi * k / decay_steps)
            )

        return lr

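
if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the module's public surface):
    # drive both schedulers with a single dummy parameter and print sampled
    # learning rates so the warmup curve and the three stages can be eyeballed.
    import torch

    param = torch.nn.Parameter(torch.zeros(1))

    optimizer = torch.optim.Adam([param], lr=0.0)
    noam = WarmUpScheduler(optimizer, warmup_steps=4000, feature_size=512)
    for step in range(1, 8001):
        optimizer.step()
        noam.step()
        if step % 2000 == 0:
            print(f"WarmUpScheduler   step {step:5d}  lr={noam.get_last_lr()[0]:.6f}")

    optimizer = torch.optim.Adam([param], lr=5e-4)
    tri = TriStateScheduler(
        optimizer, total_steps=10_000, warmup_steps=1_000, constant_steps=4_000
    )
    for step in range(1, 10_001):
        optimizer.step()
        tri.step()
        if step % 2500 == 0:
            print(f"TriStateScheduler step {step:5d}  lr={tri.get_last_lr()[0]:.6f}")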