import math
from typing import Callable, Optional, Tuple

import torch
from came_pytorch import CAME
from mmcv import Config
from mmcv.runner import (OPTIMIZER_BUILDERS, OPTIMIZERS,
                         DefaultOptimizerConstructor)
from mmcv.runner import build_optimizer as mm_build_optimizer
from mmcv.utils import _BatchNorm, _InstanceNorm
from torch.nn import GroupNorm, LayerNorm
from torch.optim.optimizer import Optimizer

from .logger import get_root_logger


def auto_scale_lr(effective_bs, optimizer_cfg, rule='linear', base_batch_size=256):
    """Scale optimizer_cfg['lr'] in place according to the effective batch size."""
    assert rule in ['linear', 'sqrt']
    logger = get_root_logger()
    # scale lr by the ratio of the effective batch size to the base batch size
    if rule == 'sqrt':
        scale_ratio = math.sqrt(effective_bs / base_batch_size)
    elif rule == 'linear':
        scale_ratio = effective_bs / base_batch_size
    optimizer_cfg['lr'] *= scale_ratio
    logger.info(f'Automatically adapt lr to {optimizer_cfg["lr"]:.5f} (using {rule} scaling rule).')
    return scale_ratio
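
# Illustrative usage (values are hypothetical): with an effective batch size of
# 1024 and the default base of 256, the linear rule multiplies lr by 4.0 and
# the sqrt rule by 2.0.
#   cfg = dict(type='AdamW', lr=1e-4)
#   auto_scale_lr(1024, cfg)               # linear: cfg['lr'] -> 4e-4
#   cfg = dict(type='AdamW', lr=1e-4)
#   auto_scale_lr(1024, cfg, rule='sqrt')  # sqrt: cfg['lr'] -> 2e-4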


@OPTIMIZER_BUILDERS.register_module()
class MyOptimizerConstructor(DefaultOptimizerConstructor):

    def add_params(self, params, module, prefix='', is_dcn_module=None):
        """Add all parameters of module to the params list.

        The parameters of the given module will be added to the list of param
        groups, with specific rules defined by paramwise_cfg.

        Args:
            params (list[dict]): A list of param groups; it will be modified
                in place.
            module (nn.Module): The module whose parameters will be added.
            prefix (str): The prefix of the module.
            is_dcn_module (int|float|None): Whether the current module is a
                submodule of DCN; used to suppress the bias lr/decay rules
                for DCN offset layers.
        """
        # get param-wise options
        custom_keys = self.paramwise_cfg.get('custom_keys', {})

        bias_lr_mult = self.paramwise_cfg.get('bias_lr_mult', 1.)
        bias_decay_mult = self.paramwise_cfg.get('bias_decay_mult', 1.)
        norm_decay_mult = self.paramwise_cfg.get('norm_decay_mult', 1.)
        bypass_duplicate = self.paramwise_cfg.get('bypass_duplicate', False)

        # special rule for norm layers
        is_norm = isinstance(module,
                             (_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm))

        for name, param in module.named_parameters(recurse=False):
            base_lr = self.base_lr
            if name == 'bias' and not (is_norm or is_dcn_module):
                base_lr *= bias_lr_mult

            # apply weight decay policies
            base_wd = self.base_wd
            if self.base_wd is not None:
                # norm decay
                if is_norm:
                    base_wd *= norm_decay_mult
                # bias lr and decay
                elif name == 'bias' and not is_dcn_module:
                    # TODO: the current bias_decay_mult will also affect DCN bias parameters
                    base_wd *= bias_decay_mult

            param_group = {'params': [param]}
            if not param.requires_grad:
                param_group['requires_grad'] = False
                params.append(param_group)
                continue
            if bypass_duplicate and self._is_in(param_group, params):
                logger = get_root_logger()
                logger.warning(f'{prefix} is duplicate. It is skipped since '
                               f'bypass_duplicate={bypass_duplicate}')
                continue
            # if the parameter match one of the custom keys, ignore other rules
            is_custom = False
            for key in custom_keys:
                if isinstance(key, tuple):
                    scope, key_name = key
                else:
                    scope, key_name = None, key
                if scope is not None and scope not in prefix:
                    continue
                if key_name in f'{prefix}.{name}':
                    is_custom = True
                    if 'lr_mult' in custom_keys[key]:
                        param_group['lr'] = self.base_lr * custom_keys[key]['lr_mult']
                    elif 'lr' not in param_group:
                        param_group['lr'] = base_lr
                    if self.base_wd is not None:
                        if 'decay_mult' in custom_keys[key]:
                            param_group['weight_decay'] = self.base_wd * custom_keys[key]['decay_mult']
                        elif 'weight_decay' not in param_group:
                            param_group['weight_decay'] = base_wd

            if not is_custom:
                # bias_lr_mult affects all bias parameters
                # except for norm biases and dcn.conv_offset.bias
                if base_lr != self.base_lr:
                    param_group['lr'] = base_lr
                if base_wd != self.base_wd:
                    param_group['weight_decay'] = base_wd
            params.append(param_group)

        for child_name, child_mod in module.named_children():
            child_prefix = f'{prefix}.{child_name}' if prefix else child_name
            self.add_params(
                params,
                child_mod,
                prefix=child_prefix,
                is_dcn_module=is_dcn_module)
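
# Illustrative paramwise_cfg consumed by MyOptimizerConstructor (the key names
# below are hypothetical): biases get 2x lr and no weight decay, norm layers
# get no weight decay, and any parameter whose qualified name contains
# 'pos_embed' gets decay_mult=0 via custom_keys.
#   paramwise_cfg = dict(
#       bias_lr_mult=2.0,
#       bias_decay_mult=0.0,
#       norm_decay_mult=0.0,
#       custom_keys={'pos_embed': dict(decay_mult=0.0)},
#   )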


def build_optimizer(model, optimizer_cfg):
    """Build an optimizer, filling in a default parameter-wise config."""
    logger = get_root_logger()

    # unwrap the model from DataParallel/DistributedDataParallel
    if hasattr(model, 'module'):
        model = model.module
    # set the default optimizer constructor
    optimizer_cfg.setdefault('constructor', 'MyOptimizerConstructor')
    # parameter-wise setting: cancel weight decay for some specific modules
    custom_keys = dict()
    for name, module in model.named_modules():
        if hasattr(module, 'zero_weight_decay'):
            custom_keys.update({(name, key): dict(decay_mult=0) for key in module.zero_weight_decay})

    paramwise_cfg = Config(dict(cfg=dict(custom_keys=custom_keys)))
    given_cfg = optimizer_cfg.get('paramwise_cfg')
    if given_cfg:
        paramwise_cfg.merge_from_dict(dict(cfg=given_cfg))
    optimizer_cfg['paramwise_cfg'] = paramwise_cfg.cfg
    # build optimizer
    optimizer = mm_build_optimizer(model, optimizer_cfg)

    weight_decay_groups = dict()
    lr_groups = dict()
    for group in optimizer.param_groups:
        if not group.get('requires_grad', True):
            continue
        lr_groups.setdefault(group['lr'], []).append(group)
        weight_decay_groups.setdefault(group['weight_decay'], []).append(group)

    learnable_count, fix_count = 0, 0
    for p in model.parameters():
        if p.requires_grad:
            learnable_count += 1
        else:
            fix_count += 1
    fix_info = f"{learnable_count} are learnable, {fix_count} are fixed"
    lr_info = "Lr group: " + ", ".join([f'{len(group)} params with lr {lr:.5f}' for lr, group in lr_groups.items()])
    wd_info = "Weight decay group: " + ", ".join(
        [f'{len(group)} params with weight decay {wd}' for wd, group in weight_decay_groups.items()])
    opt_info = f"{optimizer.__class__.__name__} Optimizer: total {len(optimizer.param_groups)} param groups, {fix_info}. {lr_info}; {wd_info}."
    logger.info(opt_info)

    return optimizer
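
# Illustrative usage (config values are hypothetical):
#   optimizer_cfg = dict(type='AdamW', lr=1e-4, weight_decay=3e-2)
#   optimizer = build_optimizer(model, optimizer_cfg)
# Modules may expose a `zero_weight_decay` collection of parameter-name keys;
# those parameters are matched through custom_keys and receive decay_mult=0.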


@OPTIMIZERS.register_module()
class Lion(Optimizer):
    def __init__(
            self,
            params,
            lr: float = 1e-4,
            betas: Tuple[float, float] = (0.9, 0.99),
            weight_decay: float = 0.0,
    ):
        assert lr > 0.
        assert all([0. <= beta <= 1. for beta in betas])

        defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay)

        super().__init__(params, defaults)

    @staticmethod
    def update_fn(p, grad, exp_avg, lr, wd, beta1, beta2):
        # decoupled (stepweight) weight decay, as in AdamW
        p.data.mul_(1 - lr * wd)

        # weight update: sign of the beta1-interpolation between momentum and grad
        update = exp_avg.clone().lerp_(grad, 1 - beta1).sign_()
        p.add_(update, alpha=-lr)

        # update the exponential moving average of gradients with beta2
        exp_avg.lerp_(grad, 1 - beta2)

    @staticmethod
    def exists(val):
        return val is not None

    @torch.no_grad()
    def step(
            self,
            closure: Optional[Callable] = None
    ):

        loss = None
        if self.exists(closure):
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in filter(lambda p: self.exists(p.grad), group['params']):

                grad, lr, wd, beta1, beta2, state = (
                    p.grad, group['lr'], group['weight_decay'],
                    *group['betas'], self.state[p],
                )

                # init state - exponential moving average of gradient values
                if len(state) == 0:
                    state['exp_avg'] = torch.zeros_like(p)

                exp_avg = state['exp_avg']

                self.update_fn(
                    p,
                    grad,
                    exp_avg,
                    lr,
                    wd,
                    beta1,
                    beta2
                )

        return loss
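
# Illustrative config (hyper-parameter values are hypothetical) for building
# Lion through the mmcv registry:
#   optimizer_cfg = dict(type='Lion', lr=1e-4, betas=(0.9, 0.99), weight_decay=1e-2)
#   optimizer = build_optimizer(model, optimizer_cfg)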


@OPTIMIZERS.register_module()
class CAMEWrapper(CAME):
    """Thin wrapper that registers came_pytorch's CAME in the OPTIMIZERS registry."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
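
# Illustrative config for building CAME through the registry (the
# hyper-parameter names and values below are assumptions; check the
# came_pytorch signature for the exact keyword arguments):
#   optimizer_cfg = dict(type='CAMEWrapper', lr=1e-4, weight_decay=0.0)
#   optimizer = build_optimizer(model, optimizer_cfg)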