conex / espnet /scheduler /pytorch.py
tobiasc's picture
Initial commit
ad16788
raw
history blame contribute delete
801 Bytes
"""PyTorch optimizer schdulers."""
from typing import List
from torch.optim import Optimizer
from espnet.scheduler.scheduler import SchedulerInterface
class PyTorchScheduler:
    """Apply a list of schedulers to the parameter groups of an optimizer.

    Each scheduler exposes a ``key`` naming the parameter-group entry it
    controls and a ``scale(n_iter)`` method returning the factor applied to
    that entry's stored initial value.
    """

    def __init__(self, schedulers: List[SchedulerInterface], optimizer: Optimizer):
        """Store the schedulers and optimizer and record initial values.

        For every scheduler and every parameter group, the group's current
        value under the scheduler's key is saved as ``"initial_" + key``.
        An entry that already exists under that name is left untouched.

        Args:
            schedulers: Schedulers to apply on every call to ``step``.
            optimizer: Optimizer whose ``param_groups`` are rescaled.

        """
        self.schedulers = schedulers
        self.optimizer = optimizer
        for scheduler in self.schedulers:
            for param_group in optimizer.param_groups:
                baseline_key = "initial_" + scheduler.key
                # setdefault keeps a baseline that is already present in the
                # group instead of overwriting it with the current value.
                param_group.setdefault(baseline_key, param_group[scheduler.key])

    def step(self, n_iter: int):
        """Rescale every scheduled parameter-group entry for ``n_iter``.

        Each entry is set to its stored initial value multiplied by the
        scheduler's ``scale(n_iter)``. Because the product always starts from
        the initial value, when several schedulers share a key the one that
        comes last in the list determines the final value.

        Args:
            n_iter: Iteration number passed to each scheduler's ``scale``.

        """
        for scheduler in self.schedulers:
            for param_group in self.optimizer.param_groups:
                baseline = param_group["initial_" + scheduler.key]
                param_group[scheduler.key] = baseline * scheduler.scale(n_iter)