|  | """Module for custom LRScheduler class""" | 
					
						
						|  | import math | 
					
						
						|  | from functools import partial | 
					
						
						|  |  | 
					
						
						|  | from torch.optim import Optimizer | 
					
						
						|  | from torch.optim.lr_scheduler import LambdaLR, LRScheduler | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
class InterpolatingLogScheduler(LRScheduler):
    """
    A scheduler that interpolates learning rates in a logarithmic fashion
    """

    def __init__(self, optimizer, num_steps, min_lr, max_lr, last_epoch=-1):
        """A scheduler that interpolates learning rates in a logarithmic fashion

        Args:
        - optimizer: PyTorch optimizer
        - num_steps: int, the number of steps over which to increase the learning rate from min_lr to max_lr
        - min_lr: float, the minimum learning rate
        - max_lr: float, the maximum learning rate

        Usage:
            fc = nn.Linear(1, 1)
            optimizer = optim.Adam(fc.parameters())
            lr_scheduler = InterpolatingLogScheduler(optimizer, num_steps=400, min_lr=1e-6, max_lr=1e-4)
        """
        self.num_steps = num_steps
        self.min_lr = min_lr
        self.max_lr = max_lr
        # Common ratio of the geometric progression that takes min_lr to max_lr over num_steps steps
        self.q = (max_lr / min_lr) ** (1 / (num_steps - 1))
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        if self.last_epoch <= 0:
            lrs = [self.min_lr for base_lr in self.base_lrs]
        elif self.last_epoch < self.num_steps:
            # Geometric (log-linear) interpolation: min_lr * q ** (step - 1)
            lrs = [
                self.min_lr * (self.q ** (self.last_epoch - 1))
                for base_lr in self.base_lrs
            ]
        else:
            lrs = [self.max_lr for base_lr in self.base_lrs]

        return lrs
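
# A quick worked example of the interpolation above, with assumed (illustrative) values:
# for num_steps=3, min_lr=1e-6 and max_lr=1e-4, q = (1e-4 / 1e-6) ** (1 / 2) = 10, so
# get_lr() yields 1e-6 at step 1, 1e-5 at step 2, and 1e-4 from step 3 onward, i.e. the
# learning rates are evenly spaced on a log scale between min_lr and max_lr.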
					
						
def _get_cosine_schedule_with_quadratic_warmup_lr_lambda(
    current_step: int,
    *,
    num_warmup_steps: int,
    num_training_steps: int,
    num_cycles: float,
):
    if current_step < num_warmup_steps:
        # Quadratic warmup: scale the lr by (current_step / num_warmup_steps) ** 2
        return (float(current_step) / float(max(1, num_warmup_steps))) ** 2
    progress = float(current_step - num_warmup_steps) / float(
        max(1, num_training_steps - num_warmup_steps)
    )
    return max(
        0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
    )


def get_cosine_schedule_with_quadratic_warmup(
    optimizer: Optimizer,
    num_warmup_steps: int,
    num_training_steps: int,
    num_cycles: float = 0.5,
    last_epoch: int = -1,
):
    """
    Create a schedule with a learning rate that decreases following the values of the cosine function from the
    initial lr set in the optimizer to 0, after a warmup period during which it increases quadratically from 0 to
    the initial lr set in the optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        num_cycles (`float`, *optional*, defaults to 0.5):
            The number of waves in the cosine schedule (the default is to just decrease from the max value to 0
            following a half-cosine).
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    lr_lambda = partial(
        _get_cosine_schedule_with_quadratic_warmup_lr_lambda,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps,
        num_cycles=num_cycles,
    )
    return LambdaLR(optimizer, lr_lambda, last_epoch)
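
# Example usage (an illustrative sketch; the model, optimizer and step counts are
# arbitrary placeholders, not part of this module):
#
#     model = torch.nn.Linear(1, 1)
#     optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
#     scheduler = get_cosine_schedule_with_quadratic_warmup(
#         optimizer, num_warmup_steps=100, num_training_steps=1000
#     )
#     for _ in range(1000):
#         optimizer.step()
#         scheduler.step()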
					
						
def _get_cosine_schedule_with_min_lr_lambda(
    current_step: int,
    *,
    num_warmup_steps: int,
    num_training_steps: int,
    min_lr_ratio: float,
):
    # Linear warmup
    if current_step < num_warmup_steps:
        return float(current_step) / float(max(1, num_warmup_steps))

    # Cosine annealing from 1.0 down to min_lr_ratio
    progress = float(current_step - num_warmup_steps) / float(
        max(1, num_training_steps - num_warmup_steps)
    )
    scaling = 0.5 * (1.0 + math.cos(math.pi * progress))
    return (1 - min_lr_ratio) * scaling + min_lr_ratio


def get_cosine_schedule_with_min_lr(
    optimizer: Optimizer,
    num_warmup_steps: int,
    num_training_steps: int,
    min_lr_ratio: float = 0.0,
):
    """
    Create a learning rate schedule which has:
        - linear warmup from 0 -> `max_lr` (the lr set in the optimizer) over `num_warmup_steps`
        - cosine annealing from `max_lr` -> `min_lr_ratio * max_lr` over the remaining steps up to `num_training_steps`
    """

    lr_lambda = partial(
        _get_cosine_schedule_with_min_lr_lambda,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps,
        min_lr_ratio=min_lr_ratio,
    )
    return LambdaLR(optimizer, lr_lambda)
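
# Illustrative numbers (assumed for this comment, not taken from any config): with the
# optimizer lr set to 2e-5 and min_lr_ratio=0.1, the lambda above warms the multiplier
# up linearly from 0 to 1, then the cosine term anneals it from 1 down to 0.1, so the
# learning rate ends at 2e-6 at num_training_steps instead of decaying all the way to 0.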
					
						
def _get_cosine_schedule_with_warmup_decay_constant_lr_lambda(
    current_step: int,
    *,
    num_warmup_steps: int,
    num_training_steps: int,
    constant_lr_ratio: float,
    min_lr_ratio: float,
    num_cycles: float,
):
    # Linear warmup
    if current_step < num_warmup_steps:
        return float(current_step) / float(max(1, num_warmup_steps))

    # Clamp the step so that beyond num_training_steps * constant_lr_ratio the
    # multiplier is held at its final value (min_lr_ratio)
    num_constant_steps = int(num_training_steps * constant_lr_ratio)
    current_step = min(current_step, num_constant_steps)

    progress = float(current_step - num_warmup_steps) / float(
        max(1, num_constant_steps - num_warmup_steps)
    )

    return (
        max(
            0,
            (1 - min_lr_ratio)
            * 0.5
            * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)),
        )
        + min_lr_ratio
    )


def get_cosine_schedule_with_warmup_decay_constant(
    optimizer: Optimizer,
    num_warmup_steps: int,
    num_training_steps: int,
    constant_lr_ratio: float,
    min_lr_ratio: float,
    num_cycles: float = 0.5,
    last_epoch: int = -1,
):
    """
    Implementation of "Continual Pre-Training of Large Language Models: How to (re)warm your model?"
    (https://arxiv.org/pdf/2308.04014.pdf).

    Create a schedule with a learning rate that increases linearly from 0 to the initial lr set in the optimizer
    during the warmup phase, then decreases following the values of the cosine function from the initial lr down to
    min_lr_ratio * initial lr until step num_training_steps * constant_lr_ratio, and finally stays constant at
    min_lr_ratio * initial lr for the remaining steps.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        constant_lr_ratio (`float`):
            The fraction of `num_training_steps` at which the cosine decay ends and the learning rate is held
            constant.
        min_lr_ratio (`float`):
            The ratio of the minimum learning rate to the initial (maximum) learning rate that the cosine decays to.
        num_cycles (`float`, *optional*, defaults to 0.5):
            The number of waves in the cosine schedule (the default is to just decrease from the max value to the
            minimum following a half-cosine).
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    lr_lambda = partial(
        _get_cosine_schedule_with_warmup_decay_constant_lr_lambda,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps,
        constant_lr_ratio=constant_lr_ratio,
        min_lr_ratio=min_lr_ratio,
        num_cycles=num_cycles,
    )
    return LambdaLR(optimizer, lr_lambda, last_epoch)
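
# A worked example under assumed settings (none of these numbers come from the repo):
# with num_warmup_steps=100, num_training_steps=1000, constant_lr_ratio=0.8,
# min_lr_ratio=0.1 and num_cycles=0.5, the lr multiplier ramps linearly to 1.0 by step
# 100, follows a half-cosine from 1.0 down to 0.1 by step 800 (= 1000 * 0.8), and is
# then held constant at 0.1 for the remaining steps.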
					
						