RSPrompter / mmpretrain /engine /optimizers /layer_decay_optim_wrapper_constructor.py
KyanChen's picture
Upload 303 files
4d0eb62
raw
history blame
7.41 kB
# Copyright (c) OpenMMLab. All rights reserved.
from collections import defaultdict
from typing import Callable, List, Optional
from mmengine.logging import MMLogger
from mmengine.optim import DefaultOptimWrapperConstructor
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm, _InstanceNorm
from torch import nn
from torch.nn import GroupNorm, LayerNorm
from mmpretrain.registry import OPTIM_WRAPPER_CONSTRUCTORS
@OPTIM_WRAPPER_CONSTRUCTORS.register_module()
class LearningRateDecayOptimWrapperConstructor(DefaultOptimWrapperConstructor):
"""Different learning rates are set for different layers of backbone.
By default, each parameter share the same optimizer settings, and we
provide an argument ``paramwise_cfg`` to specify parameter-wise settings.
It is a dict and may contain the following fields:
- ``layer_decay_rate`` (float): The learning rate of a parameter will
multiply it by multiple times according to the layer depth of the
parameter. Usually, it's less than 1, so that the earlier layers will
have a lower learning rate. Defaults to 1.
- ``bias_decay_mult`` (float): It will be multiplied to the weight
decay for all bias parameters (except for those in normalization layers).
- ``norm_decay_mult`` (float): It will be multiplied to the weight
decay for all weight and bias parameters of normalization layers.
- ``flat_decay_mult`` (float): It will be multiplied to the weight
decay for all one-dimensional parameters
- ``custom_keys`` (dict): Specified parameters-wise settings by keys. If
one of the keys in ``custom_keys`` is a substring of the name of one
parameter, then the setting of the parameter will be specified by
``custom_keys[key]`` and other setting like ``bias_decay_mult`` will be
ignored. It should be a dict and may contain fields ``decay_mult``.
(The ``lr_mult`` is disabled in this constructor).
Example:
In the config file, you can use this constructor as below:
.. code:: python
optim_wrapper = dict(
optimizer=dict(
type='AdamW',
lr=4e-3,
weight_decay=0.05,
eps=1e-8,
betas=(0.9, 0.999)),
constructor='LearningRateDecayOptimWrapperConstructor',
paramwise_cfg=dict(
layer_decay_rate=0.75, # layer-wise lr decay factor
norm_decay_mult=0.,
flat_decay_mult=0.,
custom_keys={
'.cls_token': dict(decay_mult=0.0),
'.pos_embed': dict(decay_mult=0.0)
}))
"""
def add_params(self,
params: List[dict],
module: nn.Module,
prefix: str = '',
get_layer_depth: Optional[Callable] = None,
**kwargs) -> None:
"""Add all parameters of module to the params list.
The parameters of the given module will be added to the list of param
groups, with specific rules defined by paramwise_cfg.
Args:
params (List[dict]): A list of param groups, it will be modified
in place.
module (nn.Module): The module to be added.
optimizer_cfg (dict): The configuration of optimizer.
prefix (str): The prefix of the module.
"""
# get param-wise options
custom_keys = self.paramwise_cfg.get('custom_keys', {})
# first sort with alphabet order and then sort with reversed len of str
sorted_keys = sorted(sorted(custom_keys.keys()), key=len, reverse=True)
logger = MMLogger.get_current_instance()
# The model should have `get_layer_depth` method
if get_layer_depth is None and not hasattr(module, 'get_layer_depth'):
raise NotImplementedError('The layer-wise learning rate decay need'
f' the model {type(module)} has'
' `get_layer_depth` method.')
else:
get_layer_depth = get_layer_depth or module.get_layer_depth
bias_decay_mult = self.paramwise_cfg.get('bias_decay_mult', None)
norm_decay_mult = self.paramwise_cfg.get('norm_decay_mult', None)
flat_decay_mult = self.paramwise_cfg.get('flat_decay_mult', None)
decay_rate = self.paramwise_cfg.get('layer_decay_rate', 1.0)
# special rules for norm layers and depth-wise conv layers
is_norm = isinstance(module,
(_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm))
for name, param in module.named_parameters(recurse=False):
param_group = {'params': [param]}
param_name = prefix + name
if not param.requires_grad:
continue
if self.base_wd is not None:
base_wd = self.base_wd
custom_key = next(
filter(lambda k: k in param_name, sorted_keys), None)
# custom parameters decay
if custom_key is not None:
custom_cfg = custom_keys[custom_key].copy()
decay_mult = custom_cfg.pop('decay_mult', 1.)
param_group['weight_decay'] = base_wd * decay_mult
# add custom settings to param_group
param_group.update(custom_cfg)
# norm decay
elif is_norm and norm_decay_mult is not None:
param_group['weight_decay'] = base_wd * norm_decay_mult
# bias decay
elif name == 'bias' and bias_decay_mult is not None:
param_group['weight_decay'] = base_wd * bias_decay_mult
# flatten parameters decay
elif param.ndim == 1 and flat_decay_mult is not None:
param_group['weight_decay'] = base_wd * flat_decay_mult
else:
param_group['weight_decay'] = base_wd
layer_id, max_id = get_layer_depth(param_name)
scale = decay_rate**(max_id - layer_id - 1)
param_group['lr'] = self.base_lr * scale
param_group['lr_scale'] = scale
param_group['layer_id'] = layer_id
param_group['param_name'] = param_name
params.append(param_group)
for child_name, child_mod in module.named_children():
child_prefix = f'{prefix}{child_name}.'
self.add_params(
params,
child_mod,
prefix=child_prefix,
get_layer_depth=get_layer_depth,
)
if prefix == '':
layer_params = defaultdict(list)
for param in params:
layer_params[param['layer_id']].append(param)
for layer_id, layer_params in layer_params.items():
lr_scale = layer_params[0]['lr_scale']
lr = layer_params[0]['lr']
msg = [
f'layer {layer_id} params '
f'(lr={lr:.3g}, lr_scale={lr_scale:.3g}):'
]
for param in layer_params:
msg.append(f'\t{param["param_name"]}: '
f'weight_decay={param["weight_decay"]:.3g}')
logger.debug('\n'.join(msg))