# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

"""Learning rate policy."""

import math


def get_lr_at_epoch(cfg, cur_epoch):
    """
    Retrieve the learning rate for the current epoch, optionally applying a
    linear warm up at the beginning of the training stage.

    While `cur_epoch` is below `cfg.SOLVER.WARMUP_EPOCHS`, the returned value
    is linearly interpolated between `cfg.SOLVER.WARMUP_START_LR` (at epoch 0)
    and the value the configured policy yields at `cfg.SOLVER.WARMUP_EPOCHS`.

    Args:
        cfg (CfgNode): configs. Details can be found in
            slowfast/config/defaults.py
        cur_epoch (float): the number of epoch of the current training stage.

    Returns:
        The learning rate to use at `cur_epoch`.

    Raises:
        NotImplementedError: if `cfg.SOLVER.LR_POLICY` does not name a
            `lr_func_*` function defined in this module.
    """
    # Resolve the policy function once; it is needed for both the regular
    # schedule and the warm up target.
    lr_func = get_lr_func(cfg.SOLVER.LR_POLICY)
    lr = lr_func(cfg, cur_epoch)
    # Perform warm up.
    if cur_epoch < cfg.SOLVER.WARMUP_EPOCHS:
        lr_start = cfg.SOLVER.WARMUP_START_LR
        lr_end = lr_func(cfg, cfg.SOLVER.WARMUP_EPOCHS)
        # Slope of the straight line from (0, lr_start) to
        # (WARMUP_EPOCHS, lr_end).
        alpha = (lr_end - lr_start) / cfg.SOLVER.WARMUP_EPOCHS
        lr = cur_epoch * alpha + lr_start
    return lr


def lr_func_cosine(cfg, cur_epoch):
    """
    Retrieve the learning rate at the given epoch under the cosine learning
    rate schedule. Details can be found in:
    Ilya Loshchilov, and Frank Hutter
    SGDR: Stochastic Gradient Descent With Warm Restarts.

    The value equals `cfg.SOLVER.BASE_LR` at epoch 0 and
    `cfg.SOLVER.COSINE_END_LR` at `cfg.SOLVER.MAX_EPOCH`.

    Args:
        cfg (CfgNode): configs. Details can be found in
            slowfast/config/defaults.py
        cur_epoch (float): the number of epoch of the current training stage.

    Returns:
        float: the learning rate at `cur_epoch`.
    """
    assert cfg.SOLVER.COSINE_END_LR < cfg.SOLVER.BASE_LR
    return (
        cfg.SOLVER.COSINE_END_LR
        + (cfg.SOLVER.BASE_LR - cfg.SOLVER.COSINE_END_LR)
        * (math.cos(math.pi * cur_epoch / cfg.SOLVER.MAX_EPOCH) + 1.0)
        * 0.5
    )


def lr_func_steps_with_relative_lrs(cfg, cur_epoch):
    """
    Retrieve the learning rate at the given epoch under the step schedule
    with relative learning rates: the entry of `cfg.SOLVER.LRS` selected by
    `get_step_index` is multiplied by `cfg.SOLVER.BASE_LR`.

    Args:
        cfg (CfgNode): configs. Details can be found in
            slowfast/config/defaults.py
        cur_epoch (float): the number of epoch of the current training stage.

    Returns:
        The learning rate at `cur_epoch`.
    """
    ind = get_step_index(cfg, cur_epoch)
    return cfg.SOLVER.LRS[ind] * cfg.SOLVER.BASE_LR


def get_step_index(cfg, cur_epoch):
    """
    Retrieves the lr step index for the given epoch.

    The boundaries are `cfg.SOLVER.STEPS` followed by `cfg.SOLVER.MAX_EPOCH`.
    The result is the position of the first boundary strictly greater than
    `cur_epoch`, minus one. If no boundary is greater than `cur_epoch`, the
    index of the last entry of `cfg.SOLVER.STEPS` is returned.

    Note: when `cur_epoch` is smaller than the first boundary the result is
    -1, which `lr_func_steps_with_relative_lrs` uses to index
    `cfg.SOLVER.LRS` (i.e. it selects the last entry).

    Args:
        cfg (CfgNode): configs. Details can be found in
            slowfast/config/defaults.py
        cur_epoch (float): the number of epoch of the current training stage.

    Returns:
        int: the step index for `cur_epoch`.
    """
    steps = cfg.SOLVER.STEPS + [cfg.SOLVER.MAX_EPOCH]
    for ind, step in enumerate(steps):
        if cur_epoch < step:
            return ind - 1
    # cur_epoch is at or beyond every boundary (including MAX_EPOCH): stay on
    # the last configured step.
    return len(steps) - 2


def get_lr_func(lr_policy):
    """
    Given the policy name, retrieve the specified lr policy function.

    The function is looked up in this module's globals under the name
    `"lr_func_" + lr_policy`.

    Args:
        lr_policy (string): the learning rate policy to use for the job.

    Returns:
        The module-level object named `lr_func_<lr_policy>`.

    Raises:
        NotImplementedError: if no such name exists in this module.
    """
    policy = "lr_func_" + lr_policy
    if policy not in globals():
        raise NotImplementedError("Unknown LR policy: {}".format(lr_policy))
    return globals()[policy]