|  | """Module for custom LRScheduler class""" | 
					
						
						|  |  | 
					
						
						|  | from torch.optim.lr_scheduler import LRScheduler | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class InterpolatingLogScheduler(LRScheduler): | 
					
						
						|  | """ | 
					
						
						|  | A scheduler that interpolates learning rates in a logarithmic fashion | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
    def __init__(self, optimizer, num_steps, min_lr, max_lr, last_epoch=-1):
        """A scheduler that interpolates learning rates in a logarithmic fashion

        Args:
        - optimizer: a PyTorch optimizer
        - num_steps: int, the number of steps over which to increase the learning rate from min_lr to max_lr
        - min_lr: float, the minimum learning rate
        - max_lr: float, the maximum learning rate

        Usage:
            from torch import nn, optim

            fc = nn.Linear(1, 1)
            optimizer = optim.Adam(fc.parameters())
            lr_scheduler = InterpolatingLogScheduler(optimizer, num_steps=400, min_lr=1e-6, max_lr=1e-4)
        """
        self.num_steps = num_steps
        self.min_lr = min_lr
        self.max_lr = max_lr
        # Common ratio of the geometric progression: multiplying min_lr by q
        # (num_steps - 1) times lands exactly on max_lr.
        self.q = (max_lr / min_lr) ** (1 / (num_steps - 1))
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        if self.last_epoch <= 0:
            # Before the first step, hold every parameter group at min_lr.
            lrs = [self.min_lr for _ in self.base_lrs]
        elif self.last_epoch < self.num_steps:
            # Step k yields min_lr * q**(k - 1): a log-linear ramp starting at
            # min_lr (step 1) and reaching max_lr at step num_steps.
            lrs = [
                self.min_lr * (self.q ** (self.last_epoch - 1))
                for _ in self.base_lrs
            ]
        else:
            # Once the ramp is complete, stay at max_lr.
            lrs = [self.max_lr for _ in self.base_lrs]

        return lrs
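

if __name__ == "__main__":
    # Minimal usage sketch (illustrative values, assumes torch is installed):
    # ramp the learning rate geometrically from 1e-6 to 1e-4 over 4 steps,
    # then hold it at 1e-4. The tiny Linear model and Adam optimizer are
    # placeholders; any optimizer with parameter groups works the same way.
    from torch import nn, optim

    fc = nn.Linear(1, 1)
    optimizer = optim.Adam(fc.parameters())
    lr_scheduler = InterpolatingLogScheduler(
        optimizer, num_steps=4, min_lr=1e-6, max_lr=1e-4
    )

    for step in range(6):
        optimizer.step()
        lr_scheduler.step()
        # Expected ramp: 1e-06, ~4.64e-06, ~2.15e-05, then 1e-04 from step 4 on.
        print(f"step {step + 1}: lr={lr_scheduler.get_last_lr()[0]:.2e}")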