AudioSep / optimizers /lr_schedulers.py
badayvedat's picture
Initial commit
ae29df4
from functools import partial
from typing import Callable
def linear_warm_up(
    step: int,
    warm_up_steps: int,
    reduce_lr_steps: int
) -> float:
    r"""Get linear warm up scheduler for LambdaLR.

    During warm up (``step <= warm_up_steps``) the scale grows linearly from
    0 at step 0 to 1 at ``warm_up_steps``. Afterwards the scale is
    ``0.9 ** (step // reduce_lr_steps)``; note that the exponent is computed
    from the global step, not from the number of steps since warm up ended.

    Args:
        step (int): global step
        warm_up_steps (int): steps for warm up. A value <= 0 disables the
            warm up phase (the decay rule applies from the first step).
        reduce_lr_steps (int): reduce learning rate by a factor of 0.9
            every #reduce_lr_steps step

    .. code-block:: python

        >>> lr_lambda = partial(linear_warm_up, warm_up_steps=1000, reduce_lr_steps=10000)
        >>> from torch.optim.lr_scheduler import LambdaLR
        >>> LambdaLR(optimizer, lr_lambda)

    Returns:
        lr_scale (float): learning rate scaler
    """
    # Guard on warm_up_steps > 0 so that a disabled warm up (0 steps) does
    # not divide by zero on the very first call (step == 0).
    if warm_up_steps > 0 and step <= warm_up_steps:
        return step / warm_up_steps

    # Staircase decay: one factor of 0.9 per completed reduce_lr_steps window.
    return 0.9 ** (step // reduce_lr_steps)
def constant_warm_up(
    step: int,
    warm_up_steps: int,
    reduce_lr_steps: int
) -> float:
    r"""Get constant (staircase) warm up scheduler for LambdaLR.

    The scale is held constant inside each of three consecutive windows of
    ``warm_up_steps`` steps (0.001, then 0.01, then 0.1) and is 1 for every
    step outside those windows.

    Args:
        step (int): global step
        warm_up_steps (int): length of each of the three warm up windows
        reduce_lr_steps (int): accepted so the signature matches
            ``linear_warm_up``; this function does not read it

    .. code-block:: python

        >>> lr_lambda = partial(constant_warm_up, warm_up_steps=1000, reduce_lr_steps=10000)
        >>> from torch.optim.lr_scheduler import LambdaLR
        >>> LambdaLR(optimizer, lr_lambda)

    Returns:
        lr_scale (float): learning rate scaler
    """
    stage_scales = (0.001, 0.01, 0.1)

    # Window k covers [k * warm_up_steps, (k + 1) * warm_up_steps).
    for stage, scale in enumerate(stage_scales):
        if stage * warm_up_steps <= step < (stage + 1) * warm_up_steps:
            return scale

    # Past the third window (or outside every window): full learning rate.
    return 1
def get_lr_lambda(
    lr_lambda_type: str,
    **kwargs
) -> Callable:
    r"""Get learning scheduler.

    Selects one of the schedule functions in this module and binds its
    ``warm_up_steps`` and ``reduce_lr_steps`` arguments, leaving a callable
    that only takes the global step.

    Args:
        lr_lambda_type (str), e.g., "constant_warm_up" | "linear_warm_up"
        **kwargs: must contain "warm_up_steps" and "reduce_lr_steps";
            any other keys are ignored

    Returns:
        lr_lambda_func (Callable)

    Raises:
        NotImplementedError: if ``lr_lambda_type`` is not a supported name
        KeyError: if "warm_up_steps" or "reduce_lr_steps" is missing from
            ``kwargs`` for a supported type
    """
    if lr_lambda_type == "constant_warm_up":
        schedule = constant_warm_up
    elif lr_lambda_type == "linear_warm_up":
        schedule = linear_warm_up
    else:
        # Name the rejected value and the valid choices so a config typo is
        # easy to diagnose.
        raise NotImplementedError(
            "Unknown lr_lambda_type {!r}; expected 'constant_warm_up' or "
            "'linear_warm_up'.".format(lr_lambda_type)
        )

    return partial(
        schedule,
        warm_up_steps=kwargs["warm_up_steps"],
        reduce_lr_steps=kwargs["reduce_lr_steps"],
    )