"""Learning-rate schedulers with restarts, plus a small plotting demo.

Both schedulers subclass ``torch.optim.lr_scheduler._LRScheduler`` and reset
the learning rate to ``initial_lr * weight`` at configured restart points.
"""
import math
from collections import Counter, defaultdict

import torch
from torch.optim.lr_scheduler import _LRScheduler


class MultiStepLR_Restart(_LRScheduler):
    """Multi-step decay (multiply by ``gamma`` at milestones) with restarts.

    Args:
        optimizer: Wrapped optimizer.
        milestones: Iterable of step indices at which the learning rate is
            multiplied by ``gamma``. A value listed ``k`` times applies
            ``gamma ** k`` at that step.
        restarts: Step indices of the restarts. Each value is stored as
            ``value + 1``, so the reset happens on the step whose
            ``last_epoch`` equals ``value + 1``. ``None``/empty means ``[0]``.
        weights: One multiplier per restart; at a restart the learning rate
            becomes ``initial_lr * weight``. ``None``/empty means ``[1]``.
        gamma: Multiplicative decay factor.
        clear_state: If true, replace ``optimizer.state`` with an empty
            ``defaultdict(dict)`` at every restart.
        last_epoch: Index of the last step, forwarded to ``_LRScheduler``.
    """

    def __init__(
        self, optimizer, milestones, restarts=None, weights=None, gamma=0.1, clear_state=False, last_epoch=-1
    ):
        self.milestones = Counter(milestones)
        self.gamma = gamma
        self.clear_state = clear_state
        self.restarts = restarts if restarts else [0]
        # Shift by one: the reset is applied on the step *after* the given index.
        self.restarts = [v + 1 for v in self.restarts]
        self.restart_weights = weights if weights else [1]
        assert len(self.restarts) == len(self.restart_weights), "restarts and their weights do not match."
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the learning rate of every param group for ``last_epoch``."""
        param_groups = self.optimizer.param_groups

        # A restart takes precedence over a milestone on the same step.
        if self.last_epoch in self.restarts:
            if self.clear_state:
                self.optimizer.state = defaultdict(dict)
            weight = self.restart_weights[self.restarts.index(self.last_epoch)]
            return [group["initial_lr"] * weight for group in param_groups]

        if self.last_epoch not in self.milestones:
            return [group["lr"] for group in param_groups]

        decay = self.gamma ** self.milestones[self.last_epoch]
        return [group["lr"] * decay for group in param_groups]


class CosineAnnealingLR_Restart(_LRScheduler):
    """Cosine annealing with restarts and a per-cycle period.

    Args:
        optimizer: Wrapped optimizer.
        T_period: Sequence of cycle lengths. ``T_period[0]`` is used until the
            first restart; restart number ``i`` (0-based) switches to
            ``T_period[i + 1]`` when that entry exists.
        restarts: Step indices of the restarts. Each value is stored as
            ``value + 1``, so the reset happens on the step whose
            ``last_epoch`` equals ``value + 1``. ``None``/empty means ``[0]``.
        weights: One multiplier per restart; at a restart the learning rate
            becomes ``initial_lr * weight``. ``None``/empty means ``[1]``.
        eta_min: Lower bound the cosine curve anneals towards.
        last_epoch: Index of the last step, forwarded to ``_LRScheduler``.
    """

    def __init__(self, optimizer, T_period, restarts=None, weights=None, eta_min=0, last_epoch=-1):
        self.T_period = T_period
        self.T_max = self.T_period[0]  # current T period
        self.eta_min = eta_min
        self.restarts = restarts if restarts else [0]
        # Shift by one: the reset is applied on the step *after* the given index.
        self.restarts = [v + 1 for v in self.restarts]
        self.restart_weights = weights if weights else [1]
        self.last_restart = 0
        assert len(self.restarts) == len(self.restart_weights), "restarts and their weights do not match."
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the learning rate of every param group for ``last_epoch``.

        Note that this updates ``last_restart`` and ``T_max`` when the current
        step is a restart.
        """
        param_groups = self.optimizer.param_groups

        if self.last_epoch == 0:
            return self.base_lrs

        if self.last_epoch in self.restarts:
            restart_idx = self.restarts.index(self.last_epoch)
            self.last_restart = self.last_epoch
            # Switch to the next cycle length only if one was configured.
            # Indexing unconditionally raised IndexError for the default
            # ``restarts=None`` with a single-entry ``T_period`` (and for the
            # final restart whenever len(T_period) == len(restarts)); in that
            # case the current period is kept.
            next_period_idx = restart_idx + 1
            if next_period_idx < len(self.T_period):
                self.T_max = self.T_period[next_period_idx]
            weight = self.restart_weights[restart_idx]
            return [group["initial_lr"] * weight for group in param_groups]

        steps_since_restart = self.last_epoch - self.last_restart

        # One step past the end of a period the denominator of the recursive
        # formula below, 1 + cos(pi * (steps - 1) / T_max), is zero, so the
        # additive form is used instead.
        if (steps_since_restart - 1 - self.T_max) % (2 * self.T_max) == 0:
            return [
                group["lr"] + (base_lr - self.eta_min) * (1 - math.cos(math.pi / self.T_max)) / 2
                for base_lr, group in zip(self.base_lrs, param_groups)
            ]

        # Recursive cosine update relative to the previous step's value.
        ratio = (1 + math.cos(math.pi * steps_since_restart / self.T_max)) / (
            1 + math.cos(math.pi * (steps_since_restart - 1) / self.T_max)
        )
        return [ratio * (group["lr"] - self.eta_min) + self.eta_min for group in param_groups]


if __name__ == "__main__":
    optimizer = torch.optim.Adam([torch.zeros(3, 64, 3, 3)], lr=2e-4, weight_decay=0, betas=(0.9, 0.99))

    ##############################
    # MultiStepLR_Restart
    ##############################
    # NOTE: the configurations below are alternatives; each one overwrites the
    # previous, so only the last assignment of every name is actually used.
    # Original
    lr_steps = [200000, 400000, 600000, 800000]
    restarts = None
    restart_weights = None

    # two
    lr_steps = [100000, 200000, 300000, 400000, 490000, 600000, 700000, 800000, 900000, 990000]
    restarts = [500000]
    restart_weights = [1]

    # four
    lr_steps = [
        50000,
        100000,
        150000,
        200000,
        240000,
        300000,
        350000,
        400000,
        450000,
        490000,
        550000,
        600000,
        650000,
        700000,
        740000,
        800000,
        850000,
        900000,
        950000,
        990000,
    ]
    restarts = [250000, 500000, 750000]
    restart_weights = [1, 1, 1]

    scheduler = MultiStepLR_Restart(optimizer, lr_steps, restarts, restart_weights, gamma=0.5, clear_state=False)

    ##############################
    # Cosine Annealing Restart
    ##############################
    # two
    T_period = [500000, 500000]
    restarts = [500000]
    restart_weights = [1]

    # four
    T_period = [250000, 250000, 250000, 250000]
    restarts = [250000, 500000, 750000]
    restart_weights = [1, 1, 1]

    # This replaces the multi-step scheduler above; only this one is plotted.
    scheduler = CosineAnnealingLR_Restart(
        optimizer, T_period, eta_min=1e-7, restarts=restarts, weights=restart_weights
    )

    ##############################
    # Draw figure
    ##############################
    N_iter = 1000000
    lr_l = []
    for _ in range(N_iter):
        scheduler.step()
        lr_l.append(optimizer.param_groups[0]["lr"])

    # Plotting dependencies are imported here so the schedulers stay usable
    # without matplotlib/seaborn installed.
    import matplotlib as mpl
    import matplotlib.ticker as mtick
    from matplotlib import pyplot as plt

    mpl.style.use("default")
    import seaborn

    seaborn.set(style="whitegrid")
    seaborn.set_context("paper")

    plt.figure(1)
    plt.subplot(111)
    plt.ticklabel_format(style="sci", axis="x", scilimits=(0, 0))
    plt.title("Title", fontsize=16, color="k")
    plt.plot(list(range(N_iter)), lr_l, linewidth=1.5, label="learning rate scheme")
    plt.legend(loc="upper right", shadow=False)
    ax = plt.gca()
    # Show x ticks in thousands of iterations, e.g. 250000 -> "250K".
    labels = [str(int(v / 1000)) + "K" for v in ax.get_xticks().tolist()]
    ax.set_xticklabels(labels)
    ax.yaxis.set_major_formatter(mtick.FormatStrFormatter("%.1e"))

    ax.set_ylabel("Learning rate")
    ax.set_xlabel("Iteration")
    plt.show()