|
from utils.hparams import hparams |
|
|
|
|
|
class RSQRTSchedule(object):
    """Linear-warmup, inverse-square-root-decay learning-rate schedule.

    For an update count ``n`` and ``W = warmup_updates`` the rate is::

        lr = max(base_lr * min(n / W, 1) * max(W, n) ** -0.5
                 * hidden_size ** -0.5, MIN_LR)

    ``base_lr``, ``W`` and ``hidden_size`` are read from the global
    ``hparams`` mapping (keys ``'lr'``, ``'warmup_updates'`` and
    ``'hidden_size'``) at construction time.
    """

    # Lower bound applied to every computed learning rate.
    MIN_LR = 1e-7

    def __init__(self, optimizer):
        """Bind to *optimizer* and apply the learning rate for update 0.

        Args:
            optimizer: Object exposing ``param_groups``, an indexable
                sequence of dicts whose ``'lr'`` entry is overwritten on
                every :meth:`step`.
        """
        super().__init__()
        self.optimizer = optimizer
        self.constant_lr = hparams['lr']
        self.warmup_updates = hparams['warmup_updates']
        self.hidden_size = hparams['hidden_size']
        self.lr = self.constant_lr
        # step(0) writes the initial rate into every param group, so no
        # separate initialisation loop over the groups is needed here.
        self.step(0)

    def step(self, num_updates):
        """Compute the rate for *num_updates* and push it to the optimizer.

        Args:
            num_updates: Number of optimizer updates performed so far.

        Returns:
            The learning rate that was written to every param group.
        """
        if self.warmup_updates > 0:
            warmup = min(num_updates / self.warmup_updates, 1.0)
            rsqrt_decay = max(self.warmup_updates, num_updates) ** -0.5
        else:
            # Warmup disabled: the regular formula would divide by zero
            # (n / 0, or 0 ** -0.5 at n == 0), so skip the ramp and clamp
            # the decay base to 1 to keep the very first step finite.
            warmup = 1.0
            rsqrt_decay = max(num_updates, 1) ** -0.5
        rsqrt_hidden = self.hidden_size ** -0.5

        scaled = self.constant_lr * warmup * rsqrt_decay * rsqrt_hidden
        self.lr = max(scaled, self.MIN_LR)
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.lr
        return self.lr

    def get_lr(self):
        """Return the ``'lr'`` currently stored in the first param group."""
        return self.optimizer.param_groups[0]['lr']
|
|