# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import torch
from torch.optim.optimizer import Optimizer, required


class Madam(Optimizer):
    r"""MADAM optimizer implementation (https://arxiv.org/abs/2006.14560)"""
    def __init__(self, params, lr=required, scale=3.0,
                 g_bound=None, momentum=0):
        self.scale = scale
        self.g_bound = g_bound
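        # Note: momentum is kept in the per-group defaults for compatibility
        # but is not used by the update rule below.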
        defaults = dict(lr=lr, momentum=momentum)
        super(Madam, self).__init__(params, defaults)

    def step(self, closure=None):
        r"""Performs a single optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                state = self.state[p]
                # Lazily initialize per-parameter state on the first step.
                if len(state) == 0:
                    # Clamp bound: ``scale`` times the parameter's initial
                    # root-mean-square value.
                    state['max'] = self.scale * (p * p).mean().sqrt().item()
                    state['step'] = 0
                    state['exp_avg_sq'] = torch.zeros_like(p)

                state['step'] += 1
                # Bias-corrected second-moment estimate, as in Adam
                # (beta2 = 0.999 is hard-coded).
                bias_correction = 1 - 0.999 ** state['step']
                state['exp_avg_sq'] = 0.999 * state['exp_avg_sq'] \
                    + 0.001 * p.grad.data ** 2
                g_normed = \
                    p.grad.data / (state['exp_avg_sq'] / bias_correction).sqrt()
                # 0/0 divisions produce NaNs; treat them as zero gradients.
                g_normed[torch.isnan(g_normed)] = 0
                # Optionally clip the normalized gradient.
                if self.g_bound is not None:
                    g_normed.clamp_(-self.g_bound, self.g_bound)

                # Multiplicative update: scale each weight by
                # exp(-lr * sign(w) * g_hat), then clamp it back into the
                # initial trust region. Note this never changes a weight's
                # sign.
                p.data *= torch.exp(
                    -group['lr'] * g_normed * torch.sign(p.data))
                p.data.clamp_(-state['max'], state['max'])

        return loss
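

if __name__ == '__main__':
    # Minimal usage sketch (an illustration, not part of the original file):
    # fit a toy linear regression with Madam. The model, data, and
    # hyperparameters below are assumed for demonstration only.
    torch.manual_seed(0)
    model = torch.nn.Linear(1, 1)
    optimizer = Madam(model.parameters(), lr=0.01, scale=3.0, g_bound=10.0)
    x = torch.randn(64, 1)
    y = 2.0 * x
    for _ in range(200):
        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(model(x), y)
        loss.backward()
        optimizer.step()
    # Because the update is multiplicative, parameters keep their initial
    # signs and stay inside the clamped trust region, so the final loss
    # depends on the random initialization.
    print('final loss: %.4f' % loss.item())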