Spaces:
Runtime error
Runtime error
# Copyright (c) Facebook, Inc. and its affiliates. | |
import logging | |
import torch | |
import torch.distributed as dist | |
from fvcore.nn.distributed import differentiable_all_reduce | |
from torch import nn | |
from torch.nn import functional as F | |
from detectron2.utils import comm, env | |
from .wrappers import BatchNorm2d | |
class FrozenBatchNorm2d(nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed.

    It contains non-trainable buffers called
    "weight" and "bias", "running_mean", "running_var",
    initialized to perform identity transformation.

    The pre-trained backbone models from Caffe2 only contain "weight" and "bias",
    which are computed from the original four parameters of BN.
    The affine transform `x * weight + bias` will perform the equivalent
    computation of `(x - running_mean) / sqrt(running_var) * weight + bias`.
    When loading a backbone model from Caffe2, "running_mean" and "running_var"
    will be left unchanged as identity transformation.

    Other pre-trained backbone models may contain all 4 parameters.

    The forward is implemented by `F.batch_norm(..., training=False)` when the
    input does not require gradients, and by an explicit per-channel affine
    transform otherwise.
    """

    # Serialization version, consulted by `_load_from_state_dict` below.
    _version = 3

    def __init__(self, num_features, eps=1e-5):
        """
        Args:
            num_features (int): number of channels; the length of every buffer.
            eps (float): value added to ``running_var`` before the square root.
        """
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.register_buffer("weight", torch.ones(num_features))
        self.register_buffer("bias", torch.zeros(num_features))
        self.register_buffer("running_mean", torch.zeros(num_features))
        # Stored as (1 - eps) so that `running_var + eps` is 1 for a fresh module.
        self.register_buffer("running_var", torch.ones(num_features) - eps)

    def forward(self, x):
        """
        Normalize ``x`` with the frozen statistics and affine parameters.

        Args:
            x (Tensor): input whose dim 1 has ``num_features`` entries
                (the scale/bias are reshaped to ``(1, -1, 1, 1)`` for broadcasting).

        Returns:
            Tensor: the normalized tensor.
        """
        if x.requires_grad:
            # When gradients are needed, F.batch_norm will use extra memory
            # because its backward op computes gradients for weight/bias as well.
            scale = self.weight * (self.running_var + self.eps).rsqrt()
            bias = self.bias - self.running_mean * scale
            scale = scale.reshape(1, -1, 1, 1)
            bias = bias.reshape(1, -1, 1, 1)
            out_dtype = x.dtype  # may be half
            return x * scale.to(out_dtype) + bias.to(out_dtype)
        else:
            # When gradients are not needed, F.batch_norm is a single fused op
            # and provide more optimization opportunities.
            return F.batch_norm(
                x,
                self.running_mean,
                self.running_var,
                self.weight,
                self.bias,
                training=False,
                eps=self.eps,
            )

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        """
        Upgrade entries of older checkpoints in ``state_dict`` and then defer to
        ``nn.Module._load_from_state_dict``.

        * version missing or < 2: absent "running_mean"/"running_var" entries are
          filled with zeros/ones, unless some key of ``state_dict`` contains
          ``'ignore_others'``.
        * version present and < 3: ``eps`` is subtracted from "running_var".
        """
        version = local_metadata.get("version", None)

        if version is None or version < 2:
            # when use offline modules, avoid overwriting running mean and var for loaded weights
            skip_reset = any("ignore_others" in key for key in state_dict)  # checkpoint weights
            if not skip_reset:
                # No running_mean/var in early versions
                # This will silent the warnings
                if prefix + "running_mean" not in state_dict:
                    state_dict[prefix + "running_mean"] = torch.zeros_like(self.running_mean)
                if prefix + "running_var" not in state_dict:
                    state_dict[prefix + "running_var"] = torch.ones_like(self.running_var)

        # NOTE: if a checkpoint is trained with BatchNorm and loaded (together with
        # version number) to FrozenBatchNorm, running_var will be wrong. One solution
        # is to remove the version number from the checkpoint.
        if version is not None and version < 3:
            logger = logging.getLogger(__name__)
            logger.info("FrozenBatchNorm {} is upgraded to version 3.".format(prefix.rstrip(".")))
            # In version < 3, running_var are used without +eps.
            # Subtract out-of-place: an in-place `-=` would modify the tensor owned by
            # the caller's checkpoint, so loading the same checkpoint object a second
            # time would subtract eps twice.
            var_key = prefix + "running_var"
            state_dict[var_key] = state_dict[var_key] - self.eps

        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )

    def __repr__(self):
        return "FrozenBatchNorm2d(num_features={}, eps={})".format(self.num_features, self.eps)

    @classmethod
    def convert_frozen_batchnorm(cls, module):
        """
        Convert all BatchNorm/SyncBatchNorm in module into FrozenBatchNorm.

        Args:
            module (torch.nn.Module):

        Returns:
            If module is BatchNorm/SyncBatchNorm, returns a new module.
            Otherwise, in-place convert module and return it.

        Similar to convert_sync_batchnorm in
        https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py
        """
        bn_module = nn.modules.batchnorm
        bn_module = (bn_module.BatchNorm2d, bn_module.SyncBatchNorm)
        res = module
        if isinstance(module, bn_module):
            res = cls(module.num_features)
            if module.affine:
                res.weight.data = module.weight.data.clone().detach()
                res.bias.data = module.bias.data.clone().detach()
            # The running statistics are assigned without cloning.
            res.running_mean.data = module.running_mean.data
            res.running_var.data = module.running_var.data
            res.eps = module.eps
        else:
            for name, child in module.named_children():
                new_child = cls.convert_frozen_batchnorm(child)
                # Only re-register children that were actually replaced.
                if new_child is not child:
                    res.add_module(name, new_child)
        return res
def get_norm(norm, out_channels):
    """
    Build a normalization layer for ``out_channels`` channels.

    Args:
        norm (str or callable): either one of BN, SyncBN, FrozenBN, GN;
            or a callable that takes a channel number and returns
            the normalization layer as a nn.Module.
        out_channels (int): channel count passed to the layer factory.

    Returns:
        nn.Module or None: the normalization layer; None when ``norm`` is
        None or an empty string.
    """
    if norm is None:
        return None

    # A non-string `norm` is treated as the factory itself.
    if not isinstance(norm, str):
        return norm(out_channels)

    if not norm:
        return None

    # Name -> factory table; an unrecognized name raises KeyError on lookup.
    factories = {
        "BN": BatchNorm2d,
        # Fixed in https://github.com/pytorch/pytorch/pull/36382
        "SyncBN": NaiveSyncBatchNorm if env.TORCH_VERSION <= (1, 5) else nn.SyncBatchNorm,
        "FrozenBN": FrozenBatchNorm2d,
        "GN": lambda channels: nn.GroupNorm(32, channels),
        # for debugging:
        "nnSyncBN": nn.SyncBatchNorm,
        "naiveSyncBN": NaiveSyncBatchNorm,
    }
    make_layer = factories[norm]
    return make_layer(out_channels)
class NaiveSyncBatchNorm(BatchNorm2d):
    """
    A slower but correct replacement for ``nn.SyncBatchNorm``.

    In PyTorch<=1.5, ``nn.SyncBatchNorm`` has incorrect gradient
    when the batch size on each worker is different
    (e.g., when scale augmentation is used, or when it is applied to mask head).

    Note:
        There isn't a single definition of Sync BatchNorm.

        When ``stats_mode==""``, every worker's statistics get equal weight.
        The result equals the true statistics of all samples (as if they were
        all on one worker) only when all workers have the same (N, H, W).
        This mode does not support inputs with zero batch size.

        When ``stats_mode=="N"``, each worker's statistics are weighted by its
        ``N``. The result equals the true statistics of all samples only when
        all workers have the same (H, W). It is slower than ``stats_mode==""``.

        Even though the result of this module may not be the true statistics of
        all samples, it may still be reasonable because it might be preferable
        to assign equal weights to all workers, regardless of their (H, W)
        dimension, instead of putting larger weight on larger images. From
        preliminary experiments, little difference is found between such a
        simplified implementation and an accurate computation of overall
        mean & variance.
    """

    def __init__(self, *args, stats_mode="", **kwargs):
        """
        Args:
            *args, **kwargs: forwarded unchanged to the base class constructor.
            stats_mode (str): ``""`` or ``"N"``; see the class docstring.
        """
        super().__init__(*args, **kwargs)
        assert stats_mode in ("", "N")
        self._stats_mode = stats_mode

    def forward(self, input):
        """
        Normalize ``input`` with statistics all-reduced across workers.

        With a world size of 1, or when the module is not in training mode,
        the base class forward is used instead.
        """
        if comm.get_world_size() == 1 or not self.training:
            return super().forward(input)

        batch_size = input.shape[0]
        num_channels = input.shape[1]

        is_fp16 = input.dtype == torch.float16
        if is_fp16:
            # fp16 does not have good enough numerics for the reduction here
            input = input.float()

        # Per-channel first and second moments on this worker.
        reduce_dims = [0, 2, 3]
        mean = input.mean(dim=reduce_dims)
        meansqr = (input * input).mean(dim=reduce_dims)

        if self._stats_mode == "":
            assert batch_size > 0, 'SyncBatchNorm(stats_mode="") does not support zero batch size.'
            stats = torch.cat([mean, meansqr], dim=0)
            stats = differentiable_all_reduce(stats) * (1.0 / dist.get_world_size())
            mean, meansqr = torch.split(stats, num_channels)
            momentum = self.momentum
        else:
            if batch_size == 0:
                stats = torch.zeros([2 * num_channels + 1], device=mean.device, dtype=mean.dtype)
                stats = stats + input.sum()  # make sure there is gradient w.r.t input
            else:
                # The trailing 1 becomes this worker's batch size after scaling below.
                one = torch.ones([1], device=mean.device, dtype=mean.dtype)
                stats = torch.cat([mean, meansqr, one], dim=0)
            stats = differentiable_all_reduce(stats * batch_size)

            total_batch = stats[-1].detach()
            # no update if total_batch is 0
            momentum = total_batch.clamp(max=1) * self.momentum
            # avoid div-by-zero
            mean, meansqr, _ = torch.split(stats / total_batch.clamp(min=1), num_channels)

        var = meansqr - mean * mean
        scale = self.weight * torch.rsqrt(var + self.eps)
        bias = self.bias - mean * scale

        # Update the running statistics in place with the synchronized values.
        self.running_mean += momentum * (mean.detach() - self.running_mean)
        self.running_var += momentum * (var.detach() - self.running_var)

        ret = input * scale.reshape(1, -1, 1, 1) + bias.reshape(1, -1, 1, 1)
        return ret.half() if is_fp16 else ret