File size: 4,701 Bytes
ad16788 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 |
"""Schedulers."""
import argparse
from espnet.utils.dynamic_import import dynamic_import
from espnet.utils.fill_missing_args import fill_missing_args
class _PrefixParser:
def __init__(self, parser, prefix):
self.parser = parser
self.prefix = prefix
def add_argument(self, name, **kwargs):
assert name.startswith("--")
self.parser.add_argument(self.prefix + name[2:], **kwargs)
class SchedulerInterface:
"""Scheduler interface."""
alias = ""
def __init__(self, key: str, args: argparse.Namespace):
"""Initialize class."""
self.key = key
prefix = key + "_" + self.alias + "_"
for k, v in vars(args).items():
if k.startswith(prefix):
setattr(self, k[len(prefix) :], v)
def get_arg(self, name):
"""Get argument without prefix."""
return getattr(self.args, f"{self.key}_{self.alias}_{name}")
@classmethod
def add_arguments(cls, key: str, parser: argparse.ArgumentParser):
"""Add arguments for CLI."""
group = parser.add_argument_group(f"{cls.alias} scheduler")
cls._add_arguments(_PrefixParser(parser=group, prefix=f"--{key}-{cls.alias}-"))
return parser
@staticmethod
def _add_arguments(parser: _PrefixParser):
pass
@classmethod
def build(cls, key: str, **kwargs):
"""Initialize this class with python-level args.
Args:
key (str): key of hyper parameter
Returns:
LMinterface: A new instance of LMInterface.
"""
def add(parser):
return cls.add_arguments(key, parser)
kwargs = {f"{key}_{cls.alias}_" + k: v for k, v in kwargs.items()}
args = argparse.Namespace(**kwargs)
args = fill_missing_args(args, add)
return cls(key, args)
def scale(self, n_iter: int) -> float:
"""Scale at `n_iter`.
Args:
n_iter (int): number of current iterations.
Returns:
float: current scale of learning rate.
"""
raise NotImplementedError()
# Registry: scheduler alias -> "module:ClassName" import path.
SCHEDULER_DICT = {}


def register_scheduler(cls):
    """Class decorator that records ``cls`` in ``SCHEDULER_DICT``.

    The entry is keyed by ``cls.alias`` and holds ``"<module>:<ClassName>"``.
    A later registration with the same alias replaces the earlier one.

    Returns:
        The decorated class, unchanged.
    """
    import_path = ":".join((cls.__module__, cls.__name__))
    SCHEDULER_DICT[cls.alias] = import_path
    return cls
def dynamic_import_scheduler(module):
    """Import Scheduler class dynamically.

    Args:
        module (str): module_name:class_name or alias in `SCHEDULER_DICT`
            (e.g. ``"noam"``, ``"cosine"``, ``"none"``).

    Returns:
        type: Scheduler class

    Raises:
        AssertionError: if the resolved class is not a subclass of
            ``SchedulerInterface``.
    """
    model_class = dynamic_import(module, SCHEDULER_DICT)
    # An explicit raise replaces ``assert`` so the check still runs under
    # ``python -O``. AssertionError is kept so existing callers see the same
    # exception type.
    if not issubclass(model_class, SchedulerInterface):
        raise AssertionError(f"{module} does not implement SchedulerInterface")
    return model_class
@register_scheduler
class NoScheduler(SchedulerInterface):
    """Identity scheduler: the hyper parameter is never rescaled."""

    alias = "none"

    def scale(self, n_iter):
        """Return the constant scale 1.0, whatever ``n_iter`` is."""
        del n_iter  # unused: the scale does not depend on the iteration
        return 1.0
@register_scheduler
class NoamScheduler(SchedulerInterface):
    """Linear warmup followed by inverse-square-root decay ("Noam" schedule).

    The raw schedule is ``min(s ** -0.5, s * warmup ** -1.5)`` with
    ``s = step + 1``. It is multiplied by a normalizer chosen so that the
    scale is (up to float rounding) 1.0 at ``s == warmup``.

    Args:
        noam_warmup (int): number of warmup iterations.
    """

    alias = "noam"

    @staticmethod
    def _add_arguments(parser: _PrefixParser):
        """Declare the ``warmup`` option."""
        parser.add_argument(
            "--warmup",
            type=int,
            default=1000,
            help="Number of warmup iterations.",
        )

    def __init__(self, key, args):
        """Read options via the base class, then precompute the normalizer."""
        super().__init__(key, args)
        warmup = self.warmup
        # Inverse of the raw schedule's value at its peak (s == warmup).
        self.normalize = 1 / (warmup * warmup ** -1.5)

    def scale(self, step):
        """Return the scale for the zero-based iteration ``step``."""
        shifted = step + 1  # the formula expects a one-based step
        decay = shifted ** -0.5
        ramp = shifted * self.warmup ** -1.5
        return self.normalize * min(decay, ramp)
@register_scheduler
class CyclicCosineScheduler(SchedulerInterface):
    """Cyclic cosine annealing.

    The scale is ``0.5 * (cos(pi * (n_iter - warmup) / total) + 1)``. It
    equals 1.0 at ``n_iter == warmup`` and 0.0 at ``n_iter == warmup + total``.
    It then rises again, repeating with period ``2 * total``.

    Args:
        cosine_warmup (int): number of warmup iterations.
        cosine_total (int): number of total annealing iterations.

    Notes:
        Proposed in https://openreview.net/pdf?id=BJYwwY9ll
        (and https://arxiv.org/pdf/1608.03983.pdf).
        Used in the GPT2 config of Megatron-LM https://github.com/NVIDIA/Megatron-LM
    """

    alias = "cosine"

    @staticmethod
    def _add_arguments(parser: _PrefixParser):
        """Declare the ``warmup`` and ``total`` options."""
        option_specs = (
            ("--warmup", 1000, "Number of warmup iterations."),
            ("--total", 100000, "Number of total annealing iterations."),
        )
        for flag, default_value, help_text in option_specs:
            parser.add_argument(flag, type=int, default=default_value, help=help_text)

    def scale(self, n_iter):
        """Return the annealing scale at iteration ``n_iter``."""
        import math

        # NOTE(review): cosine is even, so for n_iter < warmup the scale
        # mirrors the decay: scale(warmup - d) == scale(warmup + d). It does
        # not ramp up from 0. Confirm this "warmup" behaviour is intended.
        phase = math.pi * (n_iter - self.warmup) / self.total
        return 0.5 * (math.cos(phase) + 1)
|