""" File copied from https://github.com/nicola-decao/diffmask/blob/master/diffmask/optim/lookahead.py """ import torch import torch.optim as optim from collections import defaultdict from torch import Tensor from torch.optim.optimizer import Optimizer from typing import Iterable, Optional, Union _params_type = Union[Iterable[Tensor], Iterable[dict]] class Lookahead(Optimizer): """Lookahead optimizer: https://arxiv.org/abs/1907.08610""" # noinspection PyMissingConstructor def __init__(self, base_optimizer: Optimizer, alpha: float = 0.5, k: int = 6): if not 0.0 <= alpha <= 1.0: raise ValueError(f"Invalid slow update rate: {alpha}") if not 1 <= k: raise ValueError(f"Invalid lookahead steps: {k}") defaults = dict(lookahead_alpha=alpha, lookahead_k=k, lookahead_step=0) self.base_optimizer = base_optimizer self.param_groups = self.base_optimizer.param_groups self.defaults = base_optimizer.defaults self.defaults.update(defaults) self.state = defaultdict(dict) # manually add our defaults to the param groups for name, default in defaults.items(): for group in self.param_groups: group.setdefault(name, default) def update_slow(self, group: dict): for fast_p in group["params"]: if fast_p.grad is None: continue param_state = self.state[fast_p] if "slow_buffer" not in param_state: param_state["slow_buffer"] = torch.empty_like(fast_p.data) param_state["slow_buffer"].copy_(fast_p.data) slow = param_state["slow_buffer"] slow.add_(fast_p.data - slow, alpha=group["lookahead_alpha"]) fast_p.data.copy_(slow) def sync_lookahead(self): for group in self.param_groups: self.update_slow(group) def step(self, closure: Optional[callable] = None) -> Optional[float]: # print(self.k) # assert id(self.param_groups) == id(self.base_optimizer.param_groups) loss = self.base_optimizer.step(closure) for group in self.param_groups: group["lookahead_step"] += 1 if group["lookahead_step"] % group["lookahead_k"] == 0: self.update_slow(group) return loss def state_dict(self) -> dict: fast_state_dict = self.base_optimizer.state_dict() slow_state = { (id(k) if isinstance(k, torch.Tensor) else k): v for k, v in self.state.items() } fast_state = fast_state_dict["state"] param_groups = fast_state_dict["param_groups"] return { "state": fast_state, "slow_state": slow_state, "param_groups": param_groups, } def load_state_dict(self, state_dict: dict): fast_state_dict = { "state": state_dict["state"], "param_groups": state_dict["param_groups"], } self.base_optimizer.load_state_dict(fast_state_dict) # We want to restore the slow state, but share param_groups reference # with base_optimizer. 
        # This is a bit redundant but least code.
        slow_state_new = False
        if "slow_state" not in state_dict:
            print("Loading state_dict from optimizer without Lookahead applied.")
            state_dict["slow_state"] = defaultdict(dict)
            slow_state_new = True
        slow_state_dict = {
            "state": state_dict["slow_state"],
            # this is pointless but saves code
            "param_groups": state_dict["param_groups"],
        }
        super(Lookahead, self).load_state_dict(slow_state_dict)
        # make both ref same container
        self.param_groups = self.base_optimizer.param_groups
        if slow_state_new:
            # reapply defaults to catch missing lookahead specific ones
            for name, default in self.defaults.items():
                for group in self.param_groups:
                    group.setdefault(name, default)


def LookaheadAdam(
    params: _params_type,
    lr: float = 1e-3,
    betas: tuple[float, float] = (0.9, 0.999),
    eps: float = 1e-08,
    weight_decay: float = 0,
    amsgrad: bool = False,
    lalpha: float = 0.5,
    k: int = 6,
):
    return Lookahead(
        torch.optim.Adam(params, lr, betas, eps, weight_decay, amsgrad), lalpha, k
    )


def LookaheadRAdam(
    params: _params_type,
    lr: float = 1e-3,
    betas: tuple[float, float] = (0.9, 0.999),
    eps: float = 1e-8,
    weight_decay: float = 0,
    lalpha: float = 0.5,
    k: int = 6,
):
    return Lookahead(optim.RAdam(params, lr, betas, eps, weight_decay), lalpha, k)


def LookaheadRMSprop(
    params: _params_type,
    lr: float = 1e-2,
    alpha: float = 0.99,
    eps: float = 1e-08,
    weight_decay: float = 0,
    momentum: float = 0,
    centered: bool = False,
    lalpha: float = 0.5,
    k: int = 6,
):
    return Lookahead(
        torch.optim.RMSprop(params, lr, alpha, eps, weight_decay, momentum, centered),
        lalpha,
        k,
    )
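

# --- Hypothetical usage sketch (not part of the original diffmask file) ---
# A minimal example, assuming a toy nn.Linear model and random data, showing
# how the wrapper is meant to be driven: build the optimizer via
# LookaheadAdam, train for a few steps, then call sync_lookahead() so the
# model holds the slow (averaged) weights before evaluation or checkpointing.
if __name__ == "__main__":
    import torch.nn as nn

    model = nn.Linear(10, 1)
    optimizer = LookaheadAdam(model.parameters(), lr=1e-3, lalpha=0.5, k=6)
    loss_fn = nn.MSELoss()

    inputs, targets = torch.randn(32, 10), torch.randn(32, 1)
    for _ in range(12):
        # Zero grads via the model: Lookahead skips Optimizer.__init__, so we
        # avoid relying on state the base Optimizer class would normally set up.
        model.zero_grad()
        loss = loss_fn(model(inputs), targets)
        loss.backward()
        optimizer.step()  # every k-th call also updates the slow weights

    optimizer.sync_lookahead()  # push the slow weights into the model
    state = optimizer.state_dict()  # includes "slow_state"; restore with load_state_dict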