# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import List, Optional, Tuple, Union

import torch
from torch.optim import Optimizer


def calc_lr(step, dim_embed, warmup_steps):
    """Return the Noam learning-rate multiplier for a given step.

    The value is
    ``dim_embed ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)``,
    i.e. a linear ramp over ``warmup_steps`` steps followed by an
    inverse-square-root decay.

    Args:
        step: Current step count.
        dim_embed: Embedding dimension used to scale the rate.
        warmup_steps: Number of warm-up steps. A non-positive value disables
            the warm-up ramp and leaves only the inverse-square-root decay.

    Returns:
        The multiplier as a float. Non-positive steps yield ``0.0`` (the
        value of the warm-up ramp at step 0) instead of raising
        ``ZeroDivisionError``.
    """
    if step <= 0:
        return 0.0
    scale = step ** (-0.5)
    if warmup_steps > 0:
        scale = min(scale, step * warmup_steps ** (-1.5))
    return dim_embed ** (-0.5) * scale


# The function is modified from
# https://github.com/lifeiteng/vall-e/blob/9c69096d603ce13174fb5cb025f185e2e9b36ac7/valle/modules/scheduler.py
class NoamScheduler(torch.optim.lr_scheduler._LRScheduler):
    """Noam schedule: every parameter group gets ``base_lr * calc_lr(...)``.

    Args:
        base_lr: Factor multiplied with the Noam multiplier.
        optimizer: Optimizer whose parameter groups are updated.
        dim_embed: Embedding dimension passed to :func:`calc_lr`.
        warmup_steps: Number of warm-up steps passed to :func:`calc_lr`.
        last_epoch: Forwarded to the PyTorch base scheduler.
        verbose: Forwarded to the PyTorch base scheduler only when truthy.
    """

    def __init__(
        self,
        base_lr: float,
        optimizer: torch.optim.Optimizer,
        dim_embed: int,
        warmup_steps: int,
        last_epoch: int = -1,
        verbose: bool = False,
    ) -> None:
        # These must be set before the base constructor runs, because it
        # performs an initial step that calls get_lr().
        self.dim_embed = dim_embed
        self.base_lr = base_lr
        self.warmup_steps = warmup_steps
        self.num_param_groups = len(optimizer.param_groups)

        # Recent PyTorch releases deprecate the `verbose` argument of the
        # base scheduler, so it is only forwarded when explicitly enabled.
        if verbose:
            super().__init__(optimizer, last_epoch, verbose)
        else:
            super().__init__(optimizer, last_epoch)

    def get_lr(self) -> List[float]:
        """Return the current learning rate, repeated once per param group."""
        lr = self.base_lr * calc_lr(
            self._step_count, self.dim_embed, self.warmup_steps
        )
        return [lr] * self.num_param_groups

    def set_step(self, step: int):
        """Overwrite the internal step counter read by :meth:`get_lr`."""
        self._step_count = step


class LRScheduler(object):
    """
    Base-class for learning rate schedulers where the learning-rate depends on both the
    batch and the epoch.
    """

    def __init__(self, optimizer: Optimizer, verbose: bool = False):
        """Attach the scheduler to ``optimizer``.

        Args:
            optimizer: Optimizer whose parameter groups will be updated.
            verbose: Stored as ``self.verbose``; not used by this base class.

        Raises:
            TypeError: If ``optimizer`` is not a ``torch.optim.Optimizer``.
        """
        # Attach optimizer
        if not isinstance(optimizer, Optimizer):
            raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__))
        self.optimizer = optimizer
        self.verbose = verbose

        # Remember each group's starting rate; an existing "base_lr" entry
        # is left untouched.
        for group in optimizer.param_groups:
            group.setdefault("base_lr", group["lr"])

        self.base_lrs = [group["base_lr"] for group in optimizer.param_groups]

        self.epoch = 0
        self.batch = 0

        # Make get_last_lr() usable before the first step_batch()/step_epoch()
        # call; until then it reports the rates currently in the optimizer.
        self._last_lr = [group["lr"] for group in optimizer.param_groups]

    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`.

        It contains the entries ``base_lrs``, ``epoch`` and ``batch``.
        """
        return {
            "base_lrs": self.base_lrs,
            "epoch": self.epoch,
            "batch": self.batch,
        }

    def load_state_dict(self, state_dict):
        """Loads the schedulers state.

        Args:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        self.__dict__.update(state_dict)

    def get_last_lr(self) -> List[float]:
        """Return last computed learning rate by current scheduler.  Will be a list of float."""
        return self._last_lr

    def get_lr(self):
        # Compute list of learning rates from self.epoch and self.batch and
        # self.base_lrs; this must be overloaded by the user.
        # e.g. return [some_formula(self.batch, self.epoch, base_lr) for base_lr in self.base_lrs ]
        raise NotImplementedError

    def step_batch(self, batch: Optional[int] = None) -> None:
        # Step the batch index, or just set it.  If `batch` is specified, it
        # must be the batch index from the start of training, i.e. summed over
        # all epochs.
        # You can call this in any order; if you don't provide 'batch', it should
        # of course be called once per batch.
        if batch is not None:
            self.batch = batch
        else:
            self.batch = self.batch + 1
        self._set_lrs()

    def step_epoch(self, epoch: Optional[int] = None):
        # Step the epoch index, or just set it.  If you provide the 'epoch' arg,
        # you should call this at the start of the epoch; if you don't provide the 'epoch'
        # arg, you should call it at the end of the epoch.
        if epoch is not None:
            self.epoch = epoch
        else:
            self.epoch = self.epoch + 1
        self._set_lrs()

    def _set_lrs(self):
        """Write the rates from :meth:`get_lr` into the optimizer's param groups."""
        values = self.get_lr()
        assert len(values) == len(self.optimizer.param_groups)

        for param_group, lr in zip(self.optimizer.param_groups, values):
            param_group["lr"] = lr
        self._last_lr = [group["lr"] for group in self.optimizer.param_groups]


class Eden(LRScheduler):
    """
    Eden scheduler.
    The basic formula (before warmup) is:
      lr = base_lr * (((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25 *
                     (((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25)) * warmup
    where `warmup` increases linearly from 0.5 to 1 over `warmup_batches` batches
    and then stays constant at 1.

     E.g. suggest base_lr = 0.04 (passed to optimizer) if used with ScaledAdam

    Args:
        optimizer: the optimizer to change the learning rates on
        lr_batches: the number of batches after which we start significantly
              decreasing the learning rate, suggest 5000.
        lr_epochs: the number of epochs after which we start significantly
              decreasing the learning rate, suggest 6 if you plan to do e.g.
              20 to 40 epochs, but may need smaller number if dataset is huge
              and you will do few epochs.
        warmup_batches: the number of batches over which the warmup factor
              rises from 0.5 to 1.
        verbose: passed through to the LRScheduler base class.
    """

    def __init__(
        self,
        optimizer: Optimizer,
        lr_batches: Union[int, float],
        lr_epochs: Union[int, float],
        warmup_batches: Union[int, float] = 500.0,
        verbose: bool = False,
    ):
        super().__init__(optimizer, verbose)
        self.lr_batches = lr_batches
        self.lr_epochs = lr_epochs
        self.warmup_batches = warmup_batches

    def get_lr(self):
        """Return one learning rate per param group for the current batch/epoch."""
        batch_factor = (
            (self.batch**2 + self.lr_batches**2) / self.lr_batches**2
        ) ** -0.25
        epoch_factor = (
            (self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2
        ) ** -0.25
        factor = batch_factor * epoch_factor

        # Warm-up only depends on the batch index: 0.5 at batch 0, rising
        # linearly to 1.0 at `warmup_batches`, then constant.
        warmup_factor = (
            1.0
            if self.batch >= self.warmup_batches
            else 0.5 + 0.5 * (self.batch / self.warmup_batches)
        )

        return [x * factor * warmup_factor for x in self.base_lrs]