Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
from collections import defaultdict | |
from typing import Callable, List, Optional | |
from mmengine.logging import MMLogger | |
from mmengine.optim import DefaultOptimWrapperConstructor | |
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm, _InstanceNorm | |
from torch import nn | |
from torch.nn import GroupNorm, LayerNorm | |
from mmpretrain.registry import OPTIM_WRAPPER_CONSTRUCTORS | |
class LearningRateDecayOptimWrapperConstructor(DefaultOptimWrapperConstructor): | |
"""Different learning rates are set for different layers of backbone. | |
By default, each parameter share the same optimizer settings, and we | |
provide an argument ``paramwise_cfg`` to specify parameter-wise settings. | |
It is a dict and may contain the following fields: | |
- ``layer_decay_rate`` (float): The learning rate of a parameter will | |
multiply it by multiple times according to the layer depth of the | |
parameter. Usually, it's less than 1, so that the earlier layers will | |
have a lower learning rate. Defaults to 1. | |
- ``bias_decay_mult`` (float): It will be multiplied to the weight | |
decay for all bias parameters (except for those in normalization layers). | |
- ``norm_decay_mult`` (float): It will be multiplied to the weight | |
decay for all weight and bias parameters of normalization layers. | |
- ``flat_decay_mult`` (float): It will be multiplied to the weight | |
decay for all one-dimensional parameters | |
- ``custom_keys`` (dict): Specified parameters-wise settings by keys. If | |
one of the keys in ``custom_keys`` is a substring of the name of one | |
parameter, then the setting of the parameter will be specified by | |
``custom_keys[key]`` and other setting like ``bias_decay_mult`` will be | |
ignored. It should be a dict and may contain fields ``decay_mult``. | |
(The ``lr_mult`` is disabled in this constructor). | |
Example: | |
In the config file, you can use this constructor as below: | |
.. code:: python | |
optim_wrapper = dict( | |
optimizer=dict( | |
type='AdamW', | |
lr=4e-3, | |
weight_decay=0.05, | |
eps=1e-8, | |
betas=(0.9, 0.999)), | |
constructor='LearningRateDecayOptimWrapperConstructor', | |
paramwise_cfg=dict( | |
layer_decay_rate=0.75, # layer-wise lr decay factor | |
norm_decay_mult=0., | |
flat_decay_mult=0., | |
custom_keys={ | |
'.cls_token': dict(decay_mult=0.0), | |
'.pos_embed': dict(decay_mult=0.0) | |
})) | |
""" | |
def add_params(self, | |
params: List[dict], | |
module: nn.Module, | |
prefix: str = '', | |
get_layer_depth: Optional[Callable] = None, | |
**kwargs) -> None: | |
"""Add all parameters of module to the params list. | |
The parameters of the given module will be added to the list of param | |
groups, with specific rules defined by paramwise_cfg. | |
Args: | |
params (List[dict]): A list of param groups, it will be modified | |
in place. | |
module (nn.Module): The module to be added. | |
optimizer_cfg (dict): The configuration of optimizer. | |
prefix (str): The prefix of the module. | |
""" | |
# get param-wise options | |
custom_keys = self.paramwise_cfg.get('custom_keys', {}) | |
# first sort with alphabet order and then sort with reversed len of str | |
sorted_keys = sorted(sorted(custom_keys.keys()), key=len, reverse=True) | |
logger = MMLogger.get_current_instance() | |
# The model should have `get_layer_depth` method | |
if get_layer_depth is None and not hasattr(module, 'get_layer_depth'): | |
raise NotImplementedError('The layer-wise learning rate decay need' | |
f' the model {type(module)} has' | |
' `get_layer_depth` method.') | |
else: | |
get_layer_depth = get_layer_depth or module.get_layer_depth | |
bias_decay_mult = self.paramwise_cfg.get('bias_decay_mult', None) | |
norm_decay_mult = self.paramwise_cfg.get('norm_decay_mult', None) | |
flat_decay_mult = self.paramwise_cfg.get('flat_decay_mult', None) | |
decay_rate = self.paramwise_cfg.get('layer_decay_rate', 1.0) | |
# special rules for norm layers and depth-wise conv layers | |
is_norm = isinstance(module, | |
(_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm)) | |
for name, param in module.named_parameters(recurse=False): | |
param_group = {'params': [param]} | |
param_name = prefix + name | |
if not param.requires_grad: | |
continue | |
if self.base_wd is not None: | |
base_wd = self.base_wd | |
custom_key = next( | |
filter(lambda k: k in param_name, sorted_keys), None) | |
# custom parameters decay | |
if custom_key is not None: | |
custom_cfg = custom_keys[custom_key].copy() | |
decay_mult = custom_cfg.pop('decay_mult', 1.) | |
param_group['weight_decay'] = base_wd * decay_mult | |
# add custom settings to param_group | |
param_group.update(custom_cfg) | |
# norm decay | |
elif is_norm and norm_decay_mult is not None: | |
param_group['weight_decay'] = base_wd * norm_decay_mult | |
# bias decay | |
elif name == 'bias' and bias_decay_mult is not None: | |
param_group['weight_decay'] = base_wd * bias_decay_mult | |
# flatten parameters decay | |
elif param.ndim == 1 and flat_decay_mult is not None: | |
param_group['weight_decay'] = base_wd * flat_decay_mult | |
else: | |
param_group['weight_decay'] = base_wd | |
layer_id, max_id = get_layer_depth(param_name) | |
scale = decay_rate**(max_id - layer_id - 1) | |
param_group['lr'] = self.base_lr * scale | |
param_group['lr_scale'] = scale | |
param_group['layer_id'] = layer_id | |
param_group['param_name'] = param_name | |
params.append(param_group) | |
for child_name, child_mod in module.named_children(): | |
child_prefix = f'{prefix}{child_name}.' | |
self.add_params( | |
params, | |
child_mod, | |
prefix=child_prefix, | |
get_layer_depth=get_layer_depth, | |
) | |
if prefix == '': | |
layer_params = defaultdict(list) | |
for param in params: | |
layer_params[param['layer_id']].append(param) | |
for layer_id, layer_params in layer_params.items(): | |
lr_scale = layer_params[0]['lr_scale'] | |
lr = layer_params[0]['lr'] | |
msg = [ | |
f'layer {layer_id} params ' | |
f'(lr={lr:.3g}, lr_scale={lr_scale:.3g}):' | |
] | |
for param in layer_params: | |
msg.append(f'\t{param["param_name"]}: ' | |
f'weight_decay={param["weight_decay"]:.3g}') | |
logger.debug('\n'.join(msg)) | |