vision-diffmask / code /utils /optimizer.py
din0s's picture
Add code
d4ab5ac unverified
raw
history blame
No virus
5.14 kB
"""
File copied from
https://github.com/nicola-decao/diffmask/blob/master/diffmask/optim/lookahead.py
"""
from collections import defaultdict
from typing import Callable, Iterable, Optional, Union

import torch
import torch.optim as optim
from torch import Tensor
from torch.optim.optimizer import Optimizer
# Type of the ``params`` argument of the Lookahead* factory functions: either
# an iterable of tensors or an iterable of parameter-group dicts.
_params_type = Union[Iterable[torch.Tensor], Iterable[dict]]
class Lookahead(Optimizer):
    """Lookahead optimizer: https://arxiv.org/abs/1907.08610

    Wraps ``base_optimizer`` (the "fast" optimizer). Each param group counts
    its own steps in ``lookahead_step``. Every ``lookahead_k`` steps, each
    parameter of that group that has a gradient is pulled toward a
    per-parameter "slow" copy::

        slow += lookahead_alpha * (fast - slow)
        fast  = slow

    ``param_groups`` is the very same list object as
    ``base_optimizer.param_groups``. Each group carries the extra keys
    ``lookahead_alpha``, ``lookahead_k`` and ``lookahead_step``.

    NOTE: ``Optimizer.__init__`` is deliberately not called. ``zero_grad``,
    ``state_dict`` and ``load_state_dict`` are therefore overridden so they
    do not depend on bookkeeping attributes that constructor would create.
    """

    # noinspection PyMissingConstructor
    def __init__(self, base_optimizer: Optimizer, alpha: float = 0.5, k: int = 6):
        """
        Args:
            base_optimizer: optimizer performing the fast updates. Its
                ``param_groups`` list is shared, and its ``defaults`` dict is
                shared and updated in place with the lookahead defaults.
            alpha: slow-weight interpolation factor, must lie in ``[0, 1]``.
            k: number of fast steps between slow-weight syncs, must be ``>= 1``.

        Raises:
            ValueError: if ``alpha`` or ``k`` is out of range.
        """
        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f"Invalid slow update rate: {alpha}")
        if not 1 <= k:
            raise ValueError(f"Invalid lookahead steps: {k}")
        defaults = dict(lookahead_alpha=alpha, lookahead_k=k, lookahead_step=0)
        self.base_optimizer = base_optimizer
        self.param_groups = self.base_optimizer.param_groups
        self.defaults = base_optimizer.defaults
        self.defaults.update(defaults)
        # Slow weights keyed by parameter tensor: {param: {"slow_buffer": Tensor}}.
        self.state = defaultdict(dict)
        # manually add our defaults to the param groups
        for name, default in defaults.items():
            for group in self.param_groups:
                group.setdefault(name, default)

    def update_slow(self, group: dict) -> None:
        """Move the slow weights of ``group`` toward its fast weights.

        The result is then copied back into the fast weights. Parameters whose
        ``grad`` is ``None`` are skipped.

        NOTE: a parameter's slow buffer is created lazily from its *current*
        fast weights the first time it is synced. That first sync therefore
        leaves the parameter unchanged, whereas the paper starts the slow
        weights at the initial parameters. This is kept as-is to match the
        implementation this file was copied from.
        """
        for fast_p in group["params"]:
            if fast_p.grad is None:
                continue
            param_state = self.state[fast_p]
            if "slow_buffer" not in param_state:
                param_state["slow_buffer"] = torch.empty_like(fast_p.data)
                param_state["slow_buffer"].copy_(fast_p.data)
            slow = param_state["slow_buffer"]
            slow.add_(fast_p.data - slow, alpha=group["lookahead_alpha"])
            fast_p.data.copy_(slow)

    def sync_lookahead(self) -> None:
        """Sync slow weights for every param group, ignoring the step counters."""
        for group in self.param_groups:
            self.update_slow(group)

    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
        """Run one fast step, then sync any group whose counter hits a multiple of k.

        Returns:
            Whatever ``base_optimizer.step(closure)`` returns.
        """
        loss = self.base_optimizer.step(closure)
        for group in self.param_groups:
            group["lookahead_step"] += 1
            if group["lookahead_step"] % group["lookahead_k"] == 0:
                self.update_slow(group)
        return loss

    def zero_grad(self, *args, **kwargs) -> None:
        """Reset gradients by delegating to ``base_optimizer.zero_grad``.

        The inherited ``Optimizer.zero_grad`` expects attributes set up by
        ``Optimizer.__init__``, which is never called here. In recent PyTorch
        releases this can cause an ``AttributeError`` on the next ``step()``.
        The base optimizer owns the same parameters, so delegating is
        equivalent. Arguments are forwarded unchanged because the signature of
        ``zero_grad`` differs between PyTorch versions.
        """
        self.base_optimizer.zero_grad(*args, **kwargs)

    def state_dict(self) -> dict:
        """Return the base optimizer's state plus the slow weights.

        ``slow_state`` is keyed by the same integer parameter indices that the
        base optimizer uses in its packed ``param_groups``. This lets it be
        matched back to parameters after a reload, which ``id()`` keys cannot
        do outside the saving process.
        """
        fast_state_dict = self.base_optimizer.state_dict()
        param_groups = fast_state_dict["param_groups"]
        # Packed groups list parameter indices in the same order as the live
        # groups list the tensors, so zipping recovers tensor -> index.
        index_of = {
            id(param): index
            for packed, group in zip(param_groups, self.base_optimizer.param_groups)
            for index, param in zip(packed["params"], group["params"])
        }
        slow_state = {
            index_of[id(param)]: value
            for param, value in self.state.items()
            if id(param) in index_of
        }
        return {
            "state": fast_state_dict["state"],
            "slow_state": slow_state,
            "param_groups": param_groups,
        }

    def load_state_dict(self, state_dict: dict) -> None:
        """Restore state produced by :meth:`state_dict`.

        A plain (non-Lookahead) optimizer state dict is also accepted. In that
        case the slow weights start empty and missing ``lookahead_*`` group
        keys are filled from ``self.defaults``. ``state_dict`` itself is not
        modified. Slow-state entries whose key is not a known parameter index
        are skipped; this covers the ``id()``-keyed entries written by older
        versions of this class.
        """
        fast_state_dict = {
            "state": state_dict["state"],
            "param_groups": state_dict["param_groups"],
        }
        self.base_optimizer.load_state_dict(fast_state_dict)
        # The base optimizer rebuilds its param_groups list; keep sharing it.
        self.param_groups = self.base_optimizer.param_groups

        slow_state = state_dict.get("slow_state")
        if slow_state is None:
            print("Loading state_dict from optimizer without Lookahead applied.")
            slow_state = {}
            # reapply defaults to catch missing lookahead specific ones
            for name, default in self.defaults.items():
                for group in self.param_groups:
                    group.setdefault(name, default)

        # Map saved integer indices back to live parameters, using the same
        # flattened group order the base optimizer uses when loading.
        saved_indices = [i for g in state_dict["param_groups"] for i in g["params"]]
        live_params = [p for g in self.param_groups for p in g["params"]]
        param_at = dict(zip(saved_indices, live_params))

        self.state = defaultdict(dict)
        for index, saved in slow_state.items():
            param = param_at.get(index)
            if param is None:
                continue
            # Copy onto the parameter's device/dtype so later in-place updates
            # do not write back into the caller's state dict.
            self.state[param] = {
                key: value.to(param, copy=True) if isinstance(value, Tensor) else value
                for key, value in saved.items()
            }
def LookaheadAdam(
    params: _params_type,
    lr: float = 1e-3,
    betas: tuple[float, float] = (0.9, 0.999),
    eps: float = 1e-08,
    weight_decay: float = 0,
    amsgrad: bool = False,
    lalpha: float = 0.5,
    k: int = 6,
):
    """Create ``torch.optim.Adam`` over ``params`` and wrap it in ``Lookahead``.

    ``lr``, ``betas``, ``eps``, ``weight_decay`` and ``amsgrad`` go to Adam.
    ``lalpha`` and ``k`` become Lookahead's ``alpha`` and ``k``.
    """
    fast_optimizer = torch.optim.Adam(
        params,
        lr=lr,
        betas=betas,
        eps=eps,
        weight_decay=weight_decay,
        amsgrad=amsgrad,
    )
    return Lookahead(fast_optimizer, alpha=lalpha, k=k)
def LookaheadRAdam(
    params: _params_type,
    lr: float = 1e-3,
    betas: tuple[float, float] = (0.9, 0.999),
    eps: float = 1e-8,
    weight_decay: float = 0,
    lalpha: float = 0.5,
    k: int = 6,
):
    """Create ``torch.optim.RAdam`` over ``params`` and wrap it in ``Lookahead``.

    ``lr``, ``betas``, ``eps`` and ``weight_decay`` go to RAdam.
    ``lalpha`` and ``k`` become Lookahead's ``alpha`` and ``k``.
    """
    fast_optimizer = optim.RAdam(
        params,
        lr=lr,
        betas=betas,
        eps=eps,
        weight_decay=weight_decay,
    )
    return Lookahead(fast_optimizer, alpha=lalpha, k=k)
def LookaheadRMSprop(
    params: _params_type,
    lr: float = 1e-2,
    alpha: float = 0.99,
    eps: float = 1e-08,
    weight_decay: float = 0,
    momentum: float = 0,
    centered: bool = False,
    lalpha: float = 0.5,
    k: int = 6,
):
    """Create ``torch.optim.RMSprop`` over ``params`` and wrap it in ``Lookahead``.

    ``lr``, ``alpha``, ``eps``, ``weight_decay``, ``momentum`` and
    ``centered`` go to RMSprop. Note that ``alpha`` here is RMSprop's
    smoothing constant. Lookahead's interpolation factor is ``lalpha``, and
    ``k`` is Lookahead's sync period.
    """
    fast_optimizer = torch.optim.RMSprop(
        params,
        lr=lr,
        alpha=alpha,
        eps=eps,
        weight_decay=weight_decay,
        momentum=momentum,
        centered=centered,
    )
    return Lookahead(fast_optimizer, alpha=lalpha, k=k)