|
""" Lookahead Optimizer Wrapper. |
|
Implementation modified from: https://github.com/alphadl/lookahead.pytorch |
|
Paper: `Lookahead Optimizer: k steps forward, 1 step back` - https://arxiv.org/abs/1907.08610 |
|
|
|
Hacked together by / Copyright 2020 Ross Wightman |
|
""" |
|
import torch |
|
from torch.optim.optimizer import Optimizer |
|
from collections import defaultdict |
|
|
|
|
|
class Lookahead(Optimizer):
    """Wrap an optimizer with Lookahead ("k steps forward, 1 step back").

    The wrapped ``base_optimizer`` updates the fast weights on every call to
    ``step``. Every ``k`` steps, each parameter's slow buffer is moved towards
    the fast weights by a factor ``alpha`` and the fast weights are reset to
    the slow buffer.

    ``param_groups`` and ``defaults`` are shared with (not copied from) the
    base optimizer; the Lookahead settings are stored in them under the
    ``lookahead_alpha``, ``lookahead_k`` and ``lookahead_step`` keys.

    Args:
        base_optimizer: optimizer instance that performs the fast updates.
        alpha: slow update rate, must be within ``[0, 1]``.
        k: number of fast steps between two slow updates, must be ``>= 1``.

    Raises:
        ValueError: if ``alpha`` or ``k`` is outside its valid range.
    """

    def __init__(self, base_optimizer, alpha=0.5, k=6):
        # Local import so the module-level import block stays untouched.
        from collections import OrderedDict

        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f'Invalid slow update rate: {alpha}')
        if not 1 <= k:
            raise ValueError(f'Invalid lookahead steps: {k}')

        # NOTE: Optimizer.__init__ is deliberately not called (the param
        # groups belong to the base optimizer). Recent PyTorch releases wrap
        # ``step`` with a function that iterates these hook registries, so
        # they have to exist even though nothing here registers a hook.
        self._optimizer_step_pre_hooks = OrderedDict()
        self._optimizer_step_post_hooks = OrderedDict()

        defaults = dict(lookahead_alpha=alpha, lookahead_k=k, lookahead_step=0)
        self.base_optimizer = base_optimizer
        # Share (do not copy) the containers so both optimizers see the same
        # hyper-parameters and step counters.
        self.param_groups = self.base_optimizer.param_groups
        self.defaults = base_optimizer.defaults
        self.defaults.update(defaults)
        self.state = defaultdict(dict)

        # Add the lookahead settings to every existing param group without
        # overriding values a group may already carry.
        for name, default in defaults.items():
            for group in self.param_groups:
                group.setdefault(name, default)

    @staticmethod
    def _cast_to_param(param, value):
        """Return a copy of ``value`` on ``param``'s device (and float dtype)."""
        if not isinstance(value, torch.Tensor):
            return value
        dtype = param.dtype if param.is_floating_point() else value.dtype
        # copy=True: never alias tensors owned by the caller's state dict,
        # because slow buffers are updated in place.
        return value.detach().to(device=param.device, dtype=dtype, copy=True)

    def update_slow(self, group):
        """Pull the slow weights of ``group`` towards the fast weights.

        For every parameter with a gradient: ``slow += alpha * (fast - slow)``
        and then ``fast = slow``. The slow buffer is created lazily from the
        current fast weights the first time a parameter is seen.
        """
        alpha = group['lookahead_alpha']
        for fast_p in group["params"]:
            if fast_p.grad is None:
                continue
            param_state = self.state[fast_p]
            if 'slow_buffer' not in param_state:
                param_state['slow_buffer'] = fast_p.data.clone()
            slow = param_state['slow_buffer']
            # Keyword form; the positional ``add_(scalar, tensor)`` overload
            # is deprecated.
            slow.add_(fast_p.data - slow, alpha=alpha)
            fast_p.data.copy_(slow)

    def sync_lookahead(self):
        """Run a slow update on every param group right now."""
        for group in self.param_groups:
            self.update_slow(group)

    def step(self, closure=None):
        """Run one base optimizer step and, every ``k`` steps, a slow update.

        Args:
            closure: optional callable forwarded to the base optimizer.

        Returns:
            Whatever ``base_optimizer.step(closure)`` returns.
        """
        loss = self.base_optimizer.step(closure)
        for group in self.param_groups:
            group['lookahead_step'] += 1
            if group['lookahead_step'] % group['lookahead_k'] == 0:
                self.update_slow(group)
        return loss

    def state_dict(self):
        """Return the base optimizer state plus the lookahead slow state.

        Returns:
            dict with ``'state'`` and ``'param_groups'`` taken from the base
            optimizer's ``state_dict()`` and ``'slow_state'`` holding the slow
            buffers. Slow-state entries use the same per-parameter keys that
            the base optimizer put into ``param_groups[i]['params']`` so they
            can be matched to parameters again on load.
        """
        fast_state_dict = self.base_optimizer.state_dict()
        fast_state = fast_state_dict['state']
        param_groups = fast_state_dict['param_groups']

        # id(live param) -> key the base optimizer serialised it under.
        packed_key_of = {}
        for packed_group, live_group in zip(param_groups, self.base_optimizer.param_groups):
            for packed_key, param in zip(packed_group['params'], live_group['params']):
                packed_key_of[id(param)] = packed_key

        slow_state = {}
        for key, value in self.state.items():
            if isinstance(key, torch.Tensor):
                # Tensors that belong to no param group fall back to id().
                key = packed_key_of.get(id(key), id(key))
            slow_state[key] = value

        return {
            'state': fast_state,
            'slow_state': slow_state,
            'param_groups': param_groups,
        }

    def load_state_dict(self, state_dict):
        """Restore base optimizer state and lookahead slow state.

        Args:
            state_dict: dict as produced by ``state_dict()``. A plain
                optimizer state dict without ``'slow_state'`` is accepted as
                well; the slow state then starts empty and missing lookahead
                settings are filled in from the defaults. The passed dict is
                not modified.
        """
        fast_state_dict = {
            'state': state_dict['state'],
            'param_groups': state_dict['param_groups'],
        }
        self.base_optimizer.load_state_dict(fast_state_dict)
        # Loading may replace the base optimizer's param_groups list; keep
        # both optimizers referring to the same container.
        self.param_groups = self.base_optimizer.param_groups

        saved_slow_state = state_dict.get('slow_state')
        slow_state_new = saved_slow_state is None
        if slow_state_new:
            print('Loading state_dict from optimizer without Lookahead applied.')
            saved_slow_state = {}

        # Saved per-parameter keys line up positionally with the live params.
        saved_keys = [key for group in state_dict['param_groups'] for key in group['params']]
        live_params = [param for group in self.param_groups for param in group['params']]
        param_of_key = dict(zip(saved_keys, live_params))

        restored = defaultdict(dict)
        for key, buffers in saved_slow_state.items():
            param = param_of_key.get(key)
            if param is None:
                # Unknown key: keep the entry untouched rather than drop it.
                restored[key] = buffers
            else:
                restored[param] = {
                    name: self._cast_to_param(param, value)
                    for name, value in buffers.items()
                }
        self.state = restored

        if slow_state_new:
            # Reapply defaults to catch missing lookahead specific ones.
            for name, default in self.defaults.items():
                for group in self.param_groups:
                    group.setdefault(name, default)
|
|