import torch.optim as optim
from collections import Counter


class WarmupScheduler(optim.lr_scheduler._LRScheduler):
    """Linearly warms the learning rate from ``initial_lr`` to ``max_lr`` over
    ``warmup_epochs`` epochs, then multiplies it by ``gamma`` at each milestone."""

    def __init__(self, optimizer, warmup_epochs, initial_lr, max_lr, milestones, gamma=0.1, last_epoch=-1):
        assert warmup_epochs < milestones[0], "warmup must finish before the first milestone"
        self.warmup_epochs = warmup_epochs
        self.milestones = Counter(milestones)
        self.gamma = gamma
        initial_lrs = self._format_param("initial_lr", optimizer, initial_lr)
        max_lrs = self._format_param("max_lr", optimizer, max_lr)
        if last_epoch == -1:
            for idx, group in enumerate(optimizer.param_groups):
                group["initial_lr"] = initial_lrs[idx]
                group["max_lr"] = max_lrs[idx]
        super(WarmupScheduler, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        # if not self._get_lr_called_within_step:
        #     warnings.warn("To get the last learning rate computed by the scheduler, "
        #                   "please use `get_last_lr()`.", DeprecationWarning)
        if self.last_epoch <= self.warmup_epochs:
            # Linear warmup: interpolate between initial_lr and max_lr.
            pct = self.last_epoch / self.warmup_epochs
            return [
                (group["max_lr"] - group["initial_lr"]) * pct + group["initial_lr"]
                for group in self.optimizer.param_groups]
        else:
            # After warmup: keep the current lr, decaying it by gamma at each milestone.
            if self.last_epoch not in self.milestones:
                return [group['lr'] for group in self.optimizer.param_groups]
            return [group['lr'] * self.gamma ** self.milestones[self.last_epoch]
                    for group in self.optimizer.param_groups]

    @staticmethod
    def _format_param(name, optimizer, param):
        """Return correctly formatted lr/momentum for each param group."""
        if isinstance(param, (list, tuple)):
            if len(param) != len(optimizer.param_groups):
                raise ValueError("expected {} values for {}, got {}".format(
                    len(optimizer.param_groups), name, len(param)))
            return param
        else:
            return [param] * len(optimizer.param_groups)
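

# Worked example (added as a hedged illustration, not part of the original file):
# with warmup_epochs=5, initial_lr=0.05, max_lr=0.1, milestones=[6, 14] and
# gamma=0.5, get_lr() above yields
#   epoch 0..5  : 0.05, 0.06, 0.07, 0.08, 0.09, 0.10   (linear warmup)
#   epoch 6     : 0.10 * 0.5 = 0.05                    (first milestone)
#   epoch 7..13 : 0.05                                  (held unchanged)
#   epoch 14    : 0.05 * 0.5 = 0.025                    (second milestone)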


class WarmupScheduler_noUseMilestones(optim.lr_scheduler._LRScheduler):
    """Same linear warmup as ``WarmupScheduler``, but the milestone decay is
    disabled: after the warmup the learning rate is held constant."""

    def __init__(self, optimizer, warmup_epochs, initial_lr, max_lr, milestones, gamma=0.1, last_epoch=-1):
        assert warmup_epochs < milestones[0], "warmup must finish before the first milestone"
        self.warmup_epochs = warmup_epochs
        self.milestones = Counter(milestones)
        self.gamma = gamma
        initial_lrs = self._format_param("initial_lr", optimizer, initial_lr)
        max_lrs = self._format_param("max_lr", optimizer, max_lr)
        if last_epoch == -1:
            for idx, group in enumerate(optimizer.param_groups):
                group["initial_lr"] = initial_lrs[idx]
                group["max_lr"] = max_lrs[idx]
        super(WarmupScheduler_noUseMilestones, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        # if not self._get_lr_called_within_step:
        #     warnings.warn("To get the last learning rate computed by the scheduler, "
        #                   "please use `get_last_lr()`.", DeprecationWarning)
        if self.last_epoch <= self.warmup_epochs:
            # Linear warmup: interpolate between initial_lr and max_lr.
            pct = self.last_epoch / self.warmup_epochs
            return [
                (group["max_lr"] - group["initial_lr"]) * pct + group["initial_lr"]
                for group in self.optimizer.param_groups]
        else:
            # Milestone decay intentionally disabled: hold the current lr.
            # if self.last_epoch not in self.milestones:
            return [group['lr'] for group in self.optimizer.param_groups]
            # return [group['lr'] * self.gamma ** self.milestones[self.last_epoch]
            #         for group in self.optimizer.param_groups]

    @staticmethod
    def _format_param(name, optimizer, param):
        """Return correctly formatted lr/momentum for each param group."""
        if isinstance(param, (list, tuple)):
            if len(param) != len(optimizer.param_groups):
                raise ValueError("expected {} values for {}, got {}".format(
                    len(optimizer.param_groups), name, len(param)))
            return param
        else:
            return [param] * len(optimizer.param_groups)


if __name__ == '__main__':
    import torch

    # Smoke test: one dummy parameter, linear warmup from 0.05 to 0.1 over 5 epochs,
    # then a 0.5x decay at epochs 6 and 14.
    model = [torch.nn.Parameter(torch.randn(2, 2, requires_grad=True))]
    optimizer = optim.SGD(model, 0.1)
    scheduler = WarmupScheduler(optimizer, 5, 0.05, 0.1, [6, 14], 0.5)
    for epoch in range(1, 12):
        optimizer.zero_grad()
        print(epoch, optimizer.param_groups[0]['lr'])
        optimizer.step()
        scheduler.step()

    # Checkpoint and resume: the schedule continues where it left off after
    # round-tripping optimizer and scheduler state_dicts.
    checkpoint_dict = {
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict()
    }
    optimizer = optim.SGD(model, 0.1)
    scheduler = WarmupScheduler(optimizer, 5, 0.05, 0.1, [6, 14], 0.5)
    optimizer.load_state_dict(checkpoint_dict["optimizer"])
    scheduler.load_state_dict(checkpoint_dict["scheduler"])
    for epoch in range(12, 20):
        optimizer.zero_grad()
        print(epoch, optimizer.param_groups[0]['lr'])
        optimizer.step()
        scheduler.step()
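
    # Hedged sketch (not in the original demo; optimizer2/scheduler2 are illustrative
    # names): the *_noUseMilestones variant only performs the warmup and then holds
    # the learning rate, so the print-out should ramp 0.05 -> 0.10 over the first six
    # epochs and stay at 0.10 afterwards.
    optimizer2 = optim.SGD(model, 0.1)
    scheduler2 = WarmupScheduler_noUseMilestones(optimizer2, 5, 0.05, 0.1, [6, 14], 0.5)
    for epoch in range(1, 10):
        optimizer2.zero_grad()
        print("noUseMilestones", epoch, optimizer2.param_groups[0]['lr'])
        optimizer2.step()
        scheduler2.step()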