from functools import partial from typing import Callable def linear_warm_up( step: int, warm_up_steps: int, reduce_lr_steps: int ) -> float: r"""Get linear warm up scheduler for LambdaLR. Args: step (int): global step warm_up_steps (int): steps for warm up reduce_lr_steps (int): reduce learning rate by a factor of 0.9 #reduce_lr_steps step .. code-block: python >>> lr_lambda = partial(linear_warm_up, warm_up_steps=1000, reduce_lr_steps=10000) >>> from torch.optim.lr_scheduler import LambdaLR >>> LambdaLR(optimizer, lr_lambda) Returns: lr_scale (float): learning rate scaler """ if step <= warm_up_steps: lr_scale = step / warm_up_steps else: lr_scale = 0.9 ** (step // reduce_lr_steps) return lr_scale def constant_warm_up( step: int, warm_up_steps: int, reduce_lr_steps: int ) -> float: r"""Get constant warm up scheduler for LambdaLR. Args: step (int): global step warm_up_steps (int): steps for warm up reduce_lr_steps (int): reduce learning rate by a factor of 0.9 #reduce_lr_steps step .. code-block: python >>> lr_lambda = partial(constant_warm_up, warm_up_steps=1000, reduce_lr_steps=10000) >>> from torch.optim.lr_scheduler import LambdaLR >>> LambdaLR(optimizer, lr_lambda) Returns: lr_scale (float): learning rate scaler """ if 0 <= step < warm_up_steps: lr_scale = 0.001 elif warm_up_steps <= step < 2 * warm_up_steps: lr_scale = 0.01 elif 2 * warm_up_steps <= step < 3 * warm_up_steps: lr_scale = 0.1 else: lr_scale = 1 return lr_scale def get_lr_lambda( lr_lambda_type: str, **kwargs ) -> Callable: r"""Get learning scheduler. Args: lr_lambda_type (str), e.g., "constant_warm_up" | "linear_warm_up" Returns: lr_lambda_func (Callable) """ if lr_lambda_type == "constant_warm_up": lr_lambda_func = partial( constant_warm_up, warm_up_steps=kwargs["warm_up_steps"], reduce_lr_steps=kwargs["reduce_lr_steps"], ) elif lr_lambda_type == "linear_warm_up": lr_lambda_func = partial( linear_warm_up, warm_up_steps=kwargs["warm_up_steps"], reduce_lr_steps=kwargs["reduce_lr_steps"], ) else: raise NotImplementedError return lr_lambda_func