Spaces:
Running
Running
File size: 12,353 Bytes
3b49518 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 |
# --------------------------------------------------------
# Based on timm and MAE-priv code bases
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/BUPT-PRIV/MAE-priv
# --------------------------------------------------------
""" Model / state_dict utils
Hacked together by / Copyright 2020 Ross Wightman
"""
import fnmatch
import torch
from torchvision.ops.misc import FrozenBatchNorm2d
from .model_ema import ModelEma
def unwrap_model(model):
    """Return the underlying model with any wrappers removed.

    `ModelEma` instances are peeled off by following their `.ema` attribute
    until something that is not a `ModelEma` remains. If that object exposes
    a `.module` attribute, the value of that attribute is returned;
    otherwise the object itself is returned.
    """
    # Walk through (possibly nested) EMA wrappers first.
    while isinstance(model, ModelEma):
        model = model.ema
    if hasattr(model, 'module'):
        return model.module
    return model
def get_state_dict(model, unwrap_fn=unwrap_model):
    """Return the `state_dict()` of `model` after passing it through `unwrap_fn`.

    Args:
        model: Model (possibly wrapped) whose state dict is wanted.
        unwrap_fn: Callable applied to `model` first; its result must provide
            a `state_dict()` method. Defaults to `unwrap_model`.
    """
    bare_model = unwrap_fn(model)
    return bare_model.state_dict()
def avg_sq_ch_mean(model, input, output):
    """Forward-hook function: average of the squared per-channel means of `output`.

    `output` is reduced with a mean over dimensions 0, 2 and 3 (so it needs at
    least four dimensions), the resulting values are squared, and their mean
    is returned as a Python float. `model` and `input` are accepted only to
    match the forward-hook call signature and are not used.
    """
    channel_means = output.mean(dim=(0, 2, 3))
    return channel_means.pow(2).mean().item()
def avg_ch_var(model, input, output):
    """Forward-hook function: average of the per-channel variances of `output`.

    The variance of `output` is taken over dimensions 0, 2 and 3 (so it needs
    at least four dimensions) and the mean of those variances is returned as
    a Python float. `model` and `input` are accepted only to match the
    forward-hook call signature and are not used.
    """
    channel_vars = output.var(dim=(0, 2, 3))
    return channel_vars.mean().item()
def avg_ch_var_residual(model, input, output):
    """Forward-hook function: average of the per-channel variances of `output`.

    The computation is the same as in `avg_ch_var`: variance over dimensions
    0, 2 and 3, then the mean of those variances as a Python float. Having a
    separately named function means `ActivationStatsHook` stores its results
    under a separate key, because stats are keyed by the hook function's
    `__name__`. `model` and `input` are not used.
    """
    channel_vars = output.var(dim=(0, 2, 3))
    return channel_vars.mean().item()
class ActivationStatsHook:
    """Register forward hooks on matching modules and collect what they return.

    Iterates through each of `model`'s named modules, matches the module names
    against the Unix-style patterns in `hook_fn_locs` (via `fnmatch`) and
    registers the corresponding `hook_fn` as a forward hook on every match.
    Each value returned by a hook function is appended to
    `self.stats[hook_fn.__name__]`, so hook functions sharing a `__name__`
    share one list.

    Arguments:
        model (nn.Module): model from which we will extract the activation stats
        hook_fn_locs (List[str]): List of `hook_fn` locations based on Unix type string
            matching with the name of model's modules.
        hook_fns (List[Callable]): List of hook functions, one per entry of
            `hook_fn_locs`; `hook_fns[i]` is registered on every module whose
            name matches `hook_fn_locs[i]`.

    Attributes:
        stats (dict): maps each hook function's `__name__` to the list of
            values it has produced so far.
        handles (list): handles returned by `register_forward_hook` for every
            hook registered through this object; consumed by `remove()`.

    Raises:
        ValueError: if `hook_fn_locs` and `hook_fns` differ in length.

    Inspiration from https://docs.fast.ai/callback.hook.html.

    Refer to https://gist.github.com/amaarora/6e56942fcb46e67ba203f3009b30d950 for an example
    on how to plot Signal Propagation Plots using `ActivationStatsHook`.
    """

    def __init__(self, model, hook_fn_locs, hook_fns):
        self.model = model
        self.hook_fn_locs = hook_fn_locs
        self.hook_fns = hook_fns
        if len(hook_fn_locs) != len(hook_fns):
            raise ValueError(
                "Please provide `hook_fns` for each `hook_fn_locs`, "
                "their lengths are different.")
        self.stats = {hook_fn.__name__: [] for hook_fn in hook_fns}
        # Keep every hook handle so the hooks can be detached again; without
        # them the model would stay instrumented for the rest of its life.
        self.handles = []
        for hook_fn_loc, hook_fn in zip(hook_fn_locs, hook_fns):
            self.register_hook(hook_fn_loc, hook_fn)

    def _create_hook(self, hook_fn):
        """Wrap `hook_fn` so that its return value is recorded in `self.stats`."""
        def append_activation_stats(module, input, output):
            out = hook_fn(module, input, output)
            self.stats[hook_fn.__name__].append(out)
        return append_activation_stats

    def register_hook(self, hook_fn_loc, hook_fn):
        """Attach `hook_fn` to every module whose name matches `hook_fn_loc`."""
        for name, module in self.model.named_modules():
            if not fnmatch.fnmatch(name, hook_fn_loc):
                continue
            handle = module.register_forward_hook(self._create_hook(hook_fn))
            self.handles.append(handle)

    def remove(self):
        """Detach every hook registered through this object.

        Already collected `stats` are kept. Calling this more than once is
        harmless because the handle list is emptied on the first call.
        """
        for handle in self.handles:
            handle.remove()
        self.handles = []
def extract_spp_stats(
        model,
        hook_fn_locs,
        hook_fns,
        input_shape=[8, 3, 224, 224]):
    """Collect activation statistics from one forward pass on random input.

    An `ActivationStatsHook` is attached to `model` with the given locations
    and hook functions, a tensor of shape `input_shape` drawn from a standard
    normal distribution is pushed through the model, and the hook's `stats`
    dictionary is returned. The statistics are meant for Signal Propagation
    Plots (SPP).

    Note that this function does not detach the hooks afterwards; they stay
    registered on `model`.

    Paper: https://arxiv.org/abs/2101.08692

    Example Usage: https://gist.github.com/amaarora/6e56942fcb46e67ba203f3009b30d950

    Args:
        model (nn.Module): model to instrument and run.
        hook_fn_locs (List[str]): module-name patterns, forwarded to `ActivationStatsHook`.
        hook_fns (List[Callable]): hook functions, forwarded to `ActivationStatsHook`.
        input_shape: shape of the random input tensor.

    Returns:
        dict: the `stats` attribute of the created `ActivationStatsHook`.
    """
    stats_hook = ActivationStatsHook(model, hook_fn_locs=hook_fn_locs, hook_fns=hook_fns)
    random_input = torch.normal(0., 1., input_shape)
    # The forward pass is run only for its side effect of firing the hooks.
    model(random_input)
    return stats_hook.stats
def freeze_batch_norm_2d(module):
    """
    Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is
    itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and
    returned. Otherwise, the module is walked recursively and submodules are converted in place.

    The frozen layer receives its own copies of the affine parameters and of the running statistics, so it does not
    share tensor storage with the layer it was created from. `num_features` and `affine` are also recorded on the
    frozen layer; `unfreeze_batch_norm_2d` reads them back.

    Args:
        module (torch.nn.Module): Any PyTorch module.

    Returns:
        torch.nn.Module: Resulting module

    Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
    """
    res = module
    if isinstance(module, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.batchnorm.SyncBatchNorm)):
        res = FrozenBatchNorm2d(module.num_features)
        res.num_features = module.num_features
        res.affine = module.affine
        if module.affine:
            res.weight.data = module.weight.data.clone().detach()
            res.bias.data = module.bias.data.clone().detach()
        # Clone the running stats as well: assigning `.data` directly would alias the source layer's buffers, and
        # a caller who keeps using the source layer would then silently change the "frozen" statistics.
        res.running_mean.data = module.running_mean.data.clone().detach()
        res.running_var.data = module.running_var.data.clone().detach()
        res.eps = module.eps
    else:
        for name, child in module.named_children():
            new_child = freeze_batch_norm_2d(child)
            # Only re-register children that were actually replaced by a converted layer.
            if new_child is not child:
                res.add_module(name, new_child)
    return res
def unfreeze_batch_norm_2d(module):
    """
    Converts all `FrozenBatchNorm2d` layers of provided module into `BatchNorm2d`. If `module` is itself an instance
    of `FrozenBatchNorm2d`, it is converted into `BatchNorm2d` and returned. Otherwise, the module is walked
    recursively and submodules are converted in place.

    The new `BatchNorm2d` receives its own copies of the affine values and running statistics, so training it does
    not modify the tensors of the frozen layer it was created from.

    Args:
        module (torch.nn.Module): Any PyTorch module.

    Returns:
        torch.nn.Module: Resulting module

    Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
    """
    res = module
    if isinstance(module, FrozenBatchNorm2d):
        # `num_features` / `affine` are attached by `freeze_batch_norm_2d`. A frozen layer that was built some other
        # way may not carry them (NOTE(review): confirm against the torchvision constructor), so fall back to the
        # length of the weight buffer and to treating the layer as affine.
        num_features = getattr(module, 'num_features', module.weight.shape[0])
        affine = getattr(module, 'affine', True)
        # NOTE(review): the rebuilt layer is created with BatchNorm2d's default arguments, i.e. it always has
        # weight/bias parameters even when `affine` is False; in that case they simply keep their initial values.
        res = torch.nn.BatchNorm2d(num_features)
        if affine:
            res.weight.data = module.weight.data.clone().detach()
            res.bias.data = module.bias.data.clone().detach()
        # Clone so that in-place running-stat updates of the new layer do not write into the frozen layer's buffers.
        res.running_mean.data = module.running_mean.data.clone().detach()
        res.running_var.data = module.running_var.data.clone().detach()
        res.eps = module.eps
    else:
        for name, child in module.named_children():
            new_child = unfreeze_batch_norm_2d(child)
            # Only re-register children that were actually replaced by a converted layer.
            if new_child is not child:
                res.add_module(name, new_child)
    return res
def _freeze_unfreeze(root_module, submodules=[], include_bn_running_stats=True, mode='freeze'):
    """
    Freeze or unfreeze parameters of the specified modules and those of all their hierarchical descendants. This is
    done in place.

    Args:
        root_module (nn.Module): Root module relative to which the `submodules` are referenced.
        submodules (list[str] or str): Modules for which the parameters will be (un)frozen. They are to be provided
            as named modules relative to the root module (accessible via `root_module.named_modules()`). A single
            name may be given as a plain string. An empty list means that the whole root module will be (un)frozen:
            parameters registered directly on the root module as well as every child module. Defaults to []
        include_bn_running_stats (bool): Whether to also (un)freeze the running statistics of batch norm 2d layers
            by converting them with `freeze_batch_norm_2d` / `unfreeze_batch_norm_2d`. Defaults to `True`.
        mode (str): Whether to freeze ("freeze") or unfreeze ("unfreeze"). Defaults to `"freeze"`.

    Raises:
        AssertionError: if `mode` is not one of the two accepted values, or if `root_module` is itself a
            `BatchNorm2d` / `SyncBatchNorm` layer (it cannot be swapped in place from here).
    """
    assert mode in ["freeze", "unfreeze"], '`mode` must be one of "freeze" or "unfreeze"'

    if isinstance(root_module, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.batchnorm.SyncBatchNorm)):
        # Raise assertion here because we can't convert it in place
        raise AssertionError(
            "You have provided a batch norm layer as the `root module`. Please use "
            "`timm.utils.model.freeze_batch_norm_2d` or `timm.utils.model.unfreeze_batch_norm_2d` instead.")

    if isinstance(submodules, str):
        submodules = [submodules]

    requires_grad = mode == 'unfreeze'

    if len(submodules):
        named_modules = list(submodules)
        target_modules = [root_module.get_submodule(m) for m in named_modules]
    else:
        # Whole root module requested. Parameters owned directly by the root are not reachable through any child,
        # so handle them here; the children are covered by the loop below. Building the two lists explicitly also
        # keeps a root module without children from failing on tuple unpacking.
        for p in root_module.parameters(recurse=False):
            p.requires_grad = requires_grad
        children = list(root_module.named_children())
        named_modules = [n for n, _ in children]
        target_modules = [m for _, m in children]

    # Helper to add submodule specified as a named_module
    def _add_submodule(module, name, submodule):
        split = name.rsplit('.', 1)
        if len(split) > 1:
            module.get_submodule(split[0]).add_module(split[1], submodule)
        else:
            module.add_module(name, submodule)

    for n, m in zip(named_modules, target_modules):
        # (Un)freeze parameters
        for p in m.parameters():
            p.requires_grad = requires_grad
        if not include_bn_running_stats:
            continue
        if mode == 'freeze':
            # Freeze batch norm
            res = freeze_batch_norm_2d(m)
            # It's possible that `m` is a type of BatchNorm in itself, in which case `freeze_batch_norm_2d` won't
            # convert it in place, but will return the converted result. In this case `res` holds the converted
            # result and we re-assign the named module
            needs_reassign = isinstance(
                m, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.batchnorm.SyncBatchNorm))
        else:
            # Unfreeze batch norm; same re-assignment reasoning as in the freeze branch
            res = unfreeze_batch_norm_2d(m)
            needs_reassign = isinstance(m, FrozenBatchNorm2d)
        if needs_reassign:
            _add_submodule(root_module, n, res)
def freeze(root_module, submodules=[], include_bn_running_stats=True):
    """
    Freeze the parameters of the selected modules, including all of their descendants. The change is made in place;
    the work is delegated to `_freeze_unfreeze` with `mode="freeze"`.

    Args:
        root_module (nn.Module): Module that the names in `submodules` are relative to.
        submodules (list[str]): Names of the modules to freeze, given as named modules relative to the root module
            (as listed by `root_module.named_modules()`). With an empty list the selection is left to
            `_freeze_unfreeze`, which then works on the root module as a whole. Defaults to `[]`.
        include_bn_running_stats (bool): Whether the running statistics of `BatchNorm2d` and `SyncBatchNorm` layers
            are frozen too, which is done by converting those layers to `FrozenBatchNorm2d` in place. Hint: freezing
            batch norm stats is good practice during fine tuning. The running stats are distinct from the affine
            parameters, which are ordinary PyTorch parameters. Defaults to `True`.

    Hint: To freeze batch norm ONLY, use `timm.utils.model.freeze_batch_norm_2d`.

    Examples::

        >>> model = timm.create_model('resnet18')
        >>> # Freeze up to and including layer2
        >>> submodules = [n for n, _ in model.named_children()]
        >>> print(submodules)
        ['conv1', 'bn1', 'act1', 'maxpool', 'layer1', 'layer2', 'layer3', 'layer4', 'global_pool', 'fc']
        >>> freeze(model, submodules[:submodules.index('layer2') + 1])
        >>> # Check for yourself that it works as expected
        >>> print(model.layer2[0].conv1.weight.requires_grad)
        False
        >>> print(model.layer3[0].conv1.weight.requires_grad)
        True
        >>> # Unfreeze
        >>> unfreeze(model)
    """
    _freeze_unfreeze(
        root_module,
        submodules,
        include_bn_running_stats=include_bn_running_stats,
        mode="freeze",
    )
def unfreeze(root_module, submodules=[], include_bn_running_stats=True):
    """
    Unfreeze the parameters of the selected modules, including all of their descendants. The change is made in
    place; the work is delegated to `_freeze_unfreeze` with `mode="unfreeze"`.

    Args:
        root_module (nn.Module): Module that the names in `submodules` are relative to.
        submodules (list[str]): Names of the modules to unfreeze, given as named modules relative to the root module
            (as listed by `root_module.named_modules()`). With an empty list the selection is left to
            `_freeze_unfreeze`, which then works on the root module as a whole. Defaults to `[]`.
        include_bn_running_stats (bool): Whether the running statistics of `FrozenBatchNorm2d` layers are unfrozen
            too, which is done by converting those layers to `BatchNorm2d` in place. Defaults to `True`.

    See example in docstring for `freeze`.
    """
    _freeze_unfreeze(
        root_module,
        submodules,
        include_bn_running_stats=include_bn_running_stats,
        mode="unfreeze",
    )
|