File size: 4,355 Bytes
a5f8a35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
r"""
`Lookahead Optimizer: k steps forward, 1 step back <https://arxiv.org/abs/1907.08610>`_.

This implementation is adapted with minimal modifications from the
`authors' implementation <https://github.com/michaelrzhang/lookahead>`_.

If you take it from here, please cite them:

.. code-block:: text

    @inproceedings{zhang2019lookahead,
        title={Lookahead Optimizer: k steps forward, 1 step back},
        author={Zhang, Michael R and Lucas, James and Hinton, Geoffrey and Ba, Jimmy},
        journal={NeurIPS},
        year={2019}
    }
"""
from collections import defaultdict
from typing import Any, Callable, Dict

import torch
from torch.optim.optimizer import Optimizer


class Lookahead(Optimizer):
    r"""
    Implements Lookahead optimizer.

    The wrapped optimizer updates the "fast" weights on every call to
    :meth:`step`. Every ``k`` calls, the fast weights are replaced by
    ``alpha * fast + (1 - alpha) * slow`` and the "slow" weights are set to
    that same interpolated value.

    Parameters
    ----------
    optimizer: torch.optim.Optimizer
        Wrapped inner optimizer. The weights it manages will be the "fast"
        weights.
    k: int, optional (default = 5)
        Number of lookahead steps before updating "slow" weights. Must be at
        least 1.
    alpha: float, optional (default = 0.8)
        Linear interpolation factor, 1.0 recovers inner optimizer. Must lie
        in ``[0, 1]``.

    Raises
    ------
    ValueError
        If ``k < 1`` or ``alpha`` is outside ``[0, 1]``.

    Notes
    -----
    :meth:`state_dict` and :meth:`load_state_dict` delegate to the inner
    optimizer only; the slow weights and the step counter kept by this
    wrapper are not part of that dictionary.
    """

    def __init__(self, optimizer: Optimizer, k: int = 5, alpha: float = 0.8):
        if k < 1:
            raise ValueError(f"Invalid number of lookahead steps: {k}")
        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f"Invalid slow update rate: {alpha}")

        # NOTE: ``Optimizer.__init__`` is intentionally not called; parameter
        # groups are exposed through the read-only ``param_groups`` property
        # that forwards to the inner optimizer.
        self.optimizer = optimizer
        self.k = k
        self.alpha = alpha

        # Counter for inner optimizer.
        self._k_counter = 0

        # Cache the current optimizer parameters as the initial slow weights.
        # The mapping is keyed by parameter.
        self.state: Dict[Any, Dict[str, Any]] = defaultdict(dict)
        for group in optimizer.param_groups:
            for p in group["params"]:
                self.state[p]["slow_params"] = p.data.clone()

    def __getstate__(self):
        return {
            "state": self.state,
            "optimizer": self.optimizer,
            "alpha": self.alpha,
            "k": self.k,
            "_k_counter": self._k_counter,
        }

    def __setstate__(self, state: Dict[str, Any]):
        # Restore exactly what ``__getstate__`` saved. The base-class
        # ``__setstate__`` may rely on attributes that only
        # ``Optimizer.__init__`` sets, and this class never calls it.
        self.__dict__.update(state)

    @property
    def param_groups(self):
        return self.optimizer.param_groups

    def zero_grad(self, *args, **kwargs):
        r"""
        Clear all grad buffers at the start of new forward pass.

        Any arguments are forwarded unchanged to the inner optimizer's
        ``zero_grad``.
        """
        self.optimizer.zero_grad(*args, **kwargs)

    def state_dict(self):
        r"""Return the state dict of the inner optimizer."""
        return self.optimizer.state_dict()

    def load_state_dict(self, state_dict: Dict[str, Any]):
        r"""Load ``state_dict`` into the inner optimizer."""
        self.optimizer.load_state_dict(state_dict)

    def step(self, closure: Callable = None):
        r"""
        Perform a single Lookahead optimization step.

        Parameters
        ----------
        closure: Callable, optional (default = None)
            A callable that re-evaluates the model and returns the loss.

        Returns
        -------
        Whatever the inner optimizer's ``step(closure)`` returns.
        """
        loss = self.optimizer.step(closure)
        self._k_counter += 1

        if self._k_counter < self.k:
            return loss

        self._k_counter = 0
        # Lookahead and cache the current optimizer parameters
        for group in self.optimizer.param_groups:
            for p in group["params"]:
                param_state = self.state[p]
                slow = param_state.get("slow_params")
                if slow is None:
                    # Parameter was added to the inner optimizer after this
                    # wrapper was built: start its slow weights from the
                    # current fast weights instead of failing.
                    param_state["slow_params"] = p.data.clone()
                    continue
                p.data.mul_(self.alpha).add_(slow, alpha=1.0 - self.alpha)
                slow.copy_(p.data)
        return loss

    def load_slow_weights(self):
        r"""
        Load slow weights from Lookahead optimizer. Useful for performing
        evaluation on the slow weights (which typically generalize better).

        This method backs up fast weights to load them after evaluation. No
        need to call this method if evaluation happens just after a lookahead
        step.

        Calling it again before :meth:`restore_fast_weights` keeps the first
        backup, so the fast weights are never overwritten by slow ones.
        Parameters that have no slow weights yet are left untouched.
        """
        for group in self.optimizer.param_groups:
            for p in group["params"]:
                param_state = self.state[p]
                slow = param_state.get("slow_params")
                if slow is None:
                    continue
                # Only back up when no backup is pending; otherwise ``p``
                # already holds slow weights and would clobber the backup.
                if "backup_params" not in param_state:
                    param_state["backup_params"] = p.data.clone()
                p.data.copy_(slow)

    def restore_fast_weights(self):
        r"""
        Restore fast weights for optimization. Call this after evaluation if
        :meth:`load_slow_weights` was called.

        Parameters without a pending backup are left unchanged, so calling
        this method without a preceding :meth:`load_slow_weights` is a no-op.
        """
        for group in self.optimizer.param_groups:
            for p in group["params"]:
                backup = self.state[p].pop("backup_params", None)
                if backup is not None:
                    p.data.copy_(backup)