import numpy as np
from utils.commons.hparams import hparams


class NoneSchedule(object):
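    """Keeps the learning rate fixed at `lr`; also serves as the base class for the schedules below."""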
    def __init__(self, optimizer, lr):
        self.optimizer = optimizer
        self.constant_lr = lr
        self.step(0)

    def step(self, num_updates):
        self.lr = self.constant_lr
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.lr
        return self.lr

    def get_lr(self):
        return self.optimizer.param_groups[0]['lr']

    def get_last_lr(self):
        return self.get_lr()


class RSQRTSchedule(NoneSchedule):
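    """Transformer-style schedule: linear warmup to the peak lr, then decay proportional to num_updates ** -0.5, scaled by hidden_size ** -0.5."""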
    def __init__(self, optimizer, lr, warmup_updates, hidden_size):
        self.optimizer = optimizer
        self.constant_lr = lr
        self.warmup_updates = warmup_updates
        self.hidden_size = hidden_size
        self.lr = lr
        for param_group in optimizer.param_groups:
            param_group['lr'] = self.lr
        self.step(0)

    def step(self, num_updates):
        constant_lr = self.constant_lr
        warmup = min(num_updates / self.warmup_updates, 1.0)
        rsqrt_decay = max(self.warmup_updates, num_updates) ** -0.5
        rsqrt_hidden = self.hidden_size ** -0.5
        self.lr = max(constant_lr * warmup * rsqrt_decay * rsqrt_hidden, 1e-7)
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.lr
        return self.lr
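
# Worked example of the RSQRT formula above (illustrative numbers, not
# necessarily this repo's config): with lr=2.0, warmup_updates=8000 and
# hidden_size=256, the peak lr reached at num_updates == 8000 is
#   2.0 * 8000 ** -0.5 * 256 ** -0.5 ≈ 1.4e-3,
# after which the lr decays proportionally to num_updates ** -0.5.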


class WarmupSchedule(NoneSchedule):
    def __init__(self, optimizer, lr, warmup_updates):
        self.optimizer = optimizer
        self.constant_lr = self.lr = lr
        self.warmup_updates = warmup_updates
        for param_group in optimizer.param_groups:
            param_group['lr'] = self.lr
        self.step(0)

    def step(self, num_updates):
        constant_lr = self.constant_lr
        warmup = min(num_updates / self.warmup_updates, 1.0)
        self.lr = max(constant_lr * warmup, 1e-7)
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.lr
        return self.lr


class ExponentialSchedule(NoneSchedule):
    def __init__(self, optimizer, lr, warmup_updates):
        self.optimizer = optimizer
        self.constant_lr = self.lr = lr
        self.warmup_updates = warmup_updates
        for param_group in optimizer.param_groups:
            param_group['lr'] = self.lr
        self.step(0)

    def step(self, num_updates):
        constant_lr = self.constant_lr
        if self.warmup_updates > 0 and num_updates <= self.warmup_updates:
            warmup = min(num_updates / self.warmup_updates, 1.0)
            self.lr = max(constant_lr * warmup, 1e-7)
        else:
            new_lrate = constant_lr * (0.1 ** (num_updates / 250_000)) # decay by 0.1x for every 250k steps
            self.lr = max(new_lrate, hparams.get("min_lr", 1e-6))
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.lr
        return self.lr


class ExponentialScheduleWithAudattNet(NoneSchedule):
    """
    Default Scheduler in AD-NeRF
    for audatt net, since it starts at 20_0000 steps, we need to enlarge its lr
    in optimizer, we set param_groups[1] to optimize audatt net
    """
    def __init__(self, optimizer, lr, warmup_updates=0):
        self.optimizer = optimizer
        self.constant_lr = self.lr = lr
        self.warmup_updates = warmup_updates
        optimizer.param_groups[0]['lr'] = self.lr
        optimizer.param_groups[1]['lr'] = self.lr * 5
        self.step(0)

    def step(self, num_updates):
        constant_lr = self.constant_lr
        if self.warmup_updates > 0 and num_updates <= self.warmup_updates:
            warmup = min(num_updates / self.warmup_updates, 1.0)
            self.lr = max(constant_lr * warmup, 1e-7)
        else:
            new_lrate = constant_lr * (0.1 ** (num_updates / 250_000)) # decay by 0.1x for every 250k steps
            self.lr = max(new_lrate, 1e-7)

        self.optimizer.param_groups[0]['lr'] = self.lr
        self.optimizer.param_groups[1]['lr'] = self.lr * 5
        return self.lr
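
# A hedged sketch of how the matching optimizer is expected to be laid out
# (`nerf_params`, `audatt_params` and the 'lr' hparams key are placeholder
# names for illustration, not necessarily this repo's exact API):
#
#   optimizer = torch.optim.Adam(
#       [{'params': nerf_params},     # param_groups[0]: base lr
#        {'params': audatt_params}],  # param_groups[1]: stepped to 5x the base lr
#       lr=hparams['lr'])
#   scheduler = ExponentialScheduleWithAudattNet(optimizer, lr=hparams['lr'])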

class ExponentialScheduleForRADNeRF(NoneSchedule):
    """
    Default Scheduler in RAD-NeRF
    RAD-NeRF has two groups of params with different lr
    for tileGrid embedding, the lr=5e-3
    for other network params, the lr=5e-4
    """
    def __init__(self, optimizer, lr, warmup_updates=0):
        self.optimizer = optimizer
        self.constant_lr = self.lr = lr # 0.0005
        self.warmup_updates = warmup_updates
        self.finetune_lips = hparams['finetune_lips']
        self.finetune_lips_start_iter = hparams['finetune_lips_start_iter']

        optimizer.param_groups[0]['lr'] = self.lr # for Net_params in RAD-NeRF, lr starts from 0.0005
        optimizer.param_groups[1]['lr'] = self.lr * 10 # for tileGrid, lr starts from 0.005
        optimizer.param_groups[2]['lr'] = self.lr * 5 # for Att Net, lr starts from 0.0025
        self.step(0)

    def step(self, num_updates):
        constant_lr = self.constant_lr
        if self.warmup_updates > 0 and num_updates <= self.warmup_updates:
            warmup = min(num_updates / self.warmup_updates, 1.0)
            self.lr = max(constant_lr * warmup, 1e-5)
        else:
            if self.finetune_lips and num_updates > self.finetune_lips_start_iter:
                new_lrate = constant_lr * (0.1 ** (num_updates / 250_000))  # decay by 0.1x for every 250k steps
            else:
                new_lrate = constant_lr * (0.1 ** (num_updates / 250_000))  # decay by 0.1x for every 250k steps

            self.lr = max(new_lrate, 1e-5)

        self.optimizer.param_groups[0]['lr'] = self.lr
        self.optimizer.param_groups[1]['lr'] = self.lr * 10
        self.optimizer.param_groups[2]['lr'] = self.lr * 5
        return self.lr
    

class ExponentialScheduleForRADNeRFTorso(NoneSchedule):
    """
    Default Scheduler in RAD-NeRF
    RAD-NeRF has two groups of params with different lr
    for tileGrid embedding, the lr=5e-3
    for other network params, the lr=5e-4
    """
    def __init__(self, optimizer, lr, warmup_updates=0):
        self.optimizer = optimizer
        self.constant_lr = self.lr = lr # 0.0005
        self.warmup_updates = warmup_updates

        optimizer.param_groups[0]['lr'] = self.lr # for Net_params in RAD-NeRF, lr starts from 0.0005
        optimizer.param_groups[1]['lr'] = self.lr * 10 # for tileGrid, lr starts from 0.005
        self.step(0)

    def step(self, num_updates):
        constant_lr = self.constant_lr
        if self.warmup_updates > 0 and num_updates <= self.warmup_updates:
            warmup = min(num_updates / self.warmup_updates, 1.0)
            self.lr = max(constant_lr * warmup, 1e-5)
        else:
            new_lrate = constant_lr * (0.1 ** (num_updates / 250_000))  # decay by 0.1x for every 250k steps
            self.lr = max(new_lrate, 1e-5)
        self.optimizer.param_groups[0]['lr'] = self.lr
        self.optimizer.param_groups[1]['lr'] = self.lr * 10
        return self.lr
    

class CosineSchedule(NoneSchedule):
    def __init__(self, optimizer, lr, warmup_updates, total_updates):
        self.optimizer = optimizer
        self.constant_lr = lr
        self.warmup_updates = warmup_updates
        self.total_updates = total_updates
        self.lr = lr
        self.assign_learning_rate(self.optimizer, self.lr)
        self.step(0)

    def assign_learning_rate(self, optimizer, new_lr):
        for param_group in optimizer.param_groups:
            param_group["lr"] = new_lr

    def _warmup_lr(self, base_lr, warmup_length, step):
        return base_lr * (step + 1) / warmup_length

    def step(self, num_updates):
        if self.warmup_updates > 0 and num_updates <= self.warmup_updates:
            lr = self._warmup_lr(self.lr, self.warmup_updates, num_updates)
        elif num_updates <= self.total_updates:
            e = num_updates - self.warmup_updates
            es = self.total_updates - self.warmup_updates
            lr = 0.5 * (1 + np.cos(np.pi * e / es)) * self.lr
        else:
            lr = 1e-5
        lr = max(1e-5, lr)
        self.assign_learning_rate(self.optimizer, lr)
        return lr
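

# Minimal usage sketch of CosineSchedule: assumes torch is installed and that
# the repo root is on sys.path so the `utils.commons.hparams` import at the top
# of this file resolves. The numbers below are illustrative only.
if __name__ == "__main__":
    import torch

    param = torch.nn.Parameter(torch.zeros(1))
    optimizer = torch.optim.Adam([param], lr=1e-3)
    scheduler = CosineSchedule(optimizer, lr=1e-3, warmup_updates=100, total_updates=1000)
    for num_updates in (0, 50, 100, 500, 1000, 2000):
        # lr ramps up linearly for 100 updates, follows a cosine down to 1e-5
        # by update 1000, and stays at the 1e-5 floor afterwards
        print(f"step {num_updates}: lr = {scheduler.step(num_updates):.2e}")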