from abc import ABC
from abc import abstractmethod
from distutils.version import LooseVersion
from typing import Optional

import torch
import torch.optim.lr_scheduler as L
|

class AbsScheduler(ABC):
    """Common interface of the learning-rate schedulers used here.

    Every scheduler must be able to advance (``step``) and to serialize
    and restore its state for checkpointing.
    """

    @abstractmethod
    def step(self, epoch: Optional[int] = None):
        pass

    @abstractmethod
    def state_dict(self):
        pass

    @abstractmethod
    def load_state_dict(self, state):
        pass
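

# Illustrative sketch, not part of the original module: the state_dict /
# load_state_dict pair is what lets a trainer checkpoint and resume any
# scheduler behind this interface.  The helper name and the ``path``
# argument are assumptions for demonstration only.
def _example_checkpoint_roundtrip(scheduler: AbsScheduler, path: str):
    torch.save(scheduler.state_dict(), path)  # save with the model checkpoint
    scheduler.load_state_dict(torch.load(path))  # restore when resuming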
|

class AbsBatchStepScheduler(AbsScheduler):
    """Scheduler whose ``step()`` is called after every training batch."""

    @abstractmethod
    def step(self, epoch: Optional[int] = None):
        pass

    @abstractmethod
    def state_dict(self):
        pass

    @abstractmethod
    def load_state_dict(self, state):
        pass
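

# Illustrative sketch, not part of the original module: a hypothetical
# minimal scheduler implementing the batch-level interface.  The class name
# and the linear-warmup rule are assumptions for demonstration only.
class _ExampleWarmupLR(AbsBatchStepScheduler):
    """Linearly ramp each parameter group's LR from 0 to its configured value."""

    def __init__(self, optimizer, warmup_steps: int = 1000):
        self.optimizer = optimizer
        self.warmup_steps = warmup_steps
        self.step_num = 0
        # Remember the LR configured for each parameter group.
        self.base_lrs = [g["lr"] for g in optimizer.param_groups]

    def step(self, epoch: Optional[int] = None):
        # Called once per training batch.
        self.step_num += 1
        scale = min(1.0, self.step_num / self.warmup_steps)
        for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups):
            group["lr"] = base_lr * scale

    def state_dict(self):
        return {"step_num": self.step_num, "base_lrs": self.base_lrs}

    def load_state_dict(self, state):
        self.step_num = state["step_num"]
        self.base_lrs = state["base_lrs"]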
|

class AbsEpochStepScheduler(AbsScheduler):
    """Scheduler whose ``step()`` is called once at the end of every epoch."""

    @abstractmethod
    def step(self, epoch: Optional[int] = None):
        pass

    @abstractmethod
    def state_dict(self):
        pass

    @abstractmethod
    def load_state_dict(self, state):
        pass
|

class AbsValEpochStepScheduler(AbsEpochStepScheduler):
    """Epoch-level scheduler that additionally needs a validation metric.

    ``step(val)`` takes the monitored value (e.g. validation loss), as
    ``torch.optim.lr_scheduler.ReduceLROnPlateau`` does.
    """

    @abstractmethod
    def step(self, val, epoch: Optional[int] = None):
        pass

    @abstractmethod
    def state_dict(self):
        pass

    @abstractmethod
    def load_state_dict(self, state):
        pass
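

# Illustrative sketch, not part of the original module: how a training loop
# might dispatch on this hierarchy at the end of an epoch.  The function and
# argument names are assumptions.  AbsValEpochStepScheduler must be tested
# first because it is itself an AbsEpochStepScheduler.
def _example_step_at_epoch_end(schedulers, epoch: int, val_loss: float):
    for scheduler in schedulers:
        if isinstance(scheduler, AbsValEpochStepScheduler):
            # e.g. ReduceLROnPlateau: stepped with the monitored metric.
            scheduler.step(val_loss)
        elif isinstance(scheduler, AbsEpochStepScheduler):
            scheduler.step()
        # AbsBatchStepScheduler instances are stepped inside the batch loop
        # instead, once per iteration.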
|

# The stock torch schedulers do not inherit from the ABCs above, so they are
# attached as virtual subclasses via ABCMeta.register(); isinstance() checks
# against the ABCs then succeed for them.
AbsValEpochStepScheduler.register(L.ReduceLROnPlateau)
for s in [
    L.ReduceLROnPlateau,
    L.LambdaLR,
    L.StepLR,
    L.MultiStepLR,
    L.ExponentialLR,
    L.CosineAnnealingLR,
]:
    AbsEpochStepScheduler.register(s)

# Batch-level schedulers; guarded because not all of them exist in torch
# releases older than 1.3.0.
if LooseVersion(torch.__version__) >= LooseVersion("1.3.0"):
    for s in [L.CyclicLR, L.OneCycleLR, L.CosineAnnealingWarmRestarts]:
        AbsBatchStepScheduler.register(s)
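

if __name__ == "__main__":
    # Quick sanity check (illustrative only): after the register() calls
    # above, stock torch schedulers pass isinstance() against these ABCs
    # even though they do not inherit from them.
    optimizer = torch.optim.SGD(torch.nn.Linear(2, 2).parameters(), lr=0.1)
    assert isinstance(L.StepLR(optimizer, step_size=10), AbsEpochStepScheduler)
    assert isinstance(L.ReduceLROnPlateau(optimizer), AbsValEpochStepScheduler)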
|
|