import torch
from torch import nn
from torch.nn import functional as F


class FrozenBatchNorm2d(nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed.

    It contains non-trainable buffers called
    "weight", "bias", "running_mean", and "running_var",
    initialized so that the layer performs an identity transform.

    The pre-trained backbone models from Caffe2 only contain "weight" and "bias",
    which are computed from the original four parameters of BN.
    The affine transform `x * weight + bias` will perform the equivalent
    computation of `(x - running_mean) / sqrt(running_var + eps) * weight + bias`.
    When loading a backbone model from Caffe2, "running_mean" and "running_var"
    are therefore left at their identity defaults.

    Other pre-trained backbone models may contain all 4 parameters.

    When gradients are not required, the forward pass is implemented by
    `F.batch_norm(..., training=False)`; otherwise the equivalent affine
    transform is applied explicitly (see `forward`).
    """

    def __init__(self, num_features, eps=1e-5):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.register_buffer("weight", torch.ones(num_features))
        self.register_buffer("bias", torch.zeros(num_features))
        self.register_buffer("running_mean", torch.zeros(num_features))
        # Initialize running_var to 1 - eps so that (running_var + eps) is
        # exactly 1 and the freshly constructed layer is an exact identity.
        self.register_buffer("running_var", torch.ones(num_features) - eps)

    def forward(self, x):
        if x.requires_grad:
            # When gradients are needed, F.batch_norm would use extra memory
            # because its backward op also computes gradients for weight and
            # bias. Apply the equivalent affine transform manually instead.
            scale = self.weight * (self.running_var + self.eps).rsqrt()
            bias = self.bias - self.running_mean * scale
            if x.dim() == 5:
                # (N, C, D, H, W) inputs, e.g. features from 3D convolutions.
                scale = scale.reshape(1, -1, 1, 1, 1)
                bias = bias.reshape(1, -1, 1, 1, 1)
            else:
                # (N, C, H, W) inputs.
                scale = scale.reshape(1, -1, 1, 1)
                bias = bias.reshape(1, -1, 1, 1)
            out_dtype = x.dtype  # may be half under mixed precision
            return x * scale.to(out_dtype) + bias.to(out_dtype)
        else:
            # When gradients are not needed, F.batch_norm is a single fused op
            # and provides more optimization opportunities.
            return F.batch_norm(
                x,
                self.running_mean,
                self.running_var,
                self.weight,
                self.bias,
                training=False,
                eps=self.eps,
            )

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        # Frozen BN has no use for BatchNorm's num_batches_tracked buffer;
        # drop it so it does not show up as an unexpected key.
        num_batches_tracked_key = prefix + "num_batches_tracked"
        if num_batches_tracked_key in state_dict:
            del state_dict[num_batches_tracked_key]

        version = local_metadata.get("version", None)
        if version is None or version < 2:
            # Early checkpoints (e.g. converted from Caffe2) contain only
            # "weight" and "bias"; fall back to identity statistics so they
            # still load correctly.
            if prefix + "running_mean" not in state_dict:
                state_dict[prefix + "running_mean"] = torch.zeros_like(self.running_mean)
            if prefix + "running_var" not in state_dict:
                state_dict[prefix + "running_var"] = torch.ones_like(self.running_var)

        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )

    def __repr__(self):
        return "FrozenBatchNorm2d(num_features={}, eps={})".format(self.num_features, self.eps)
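

if __name__ == "__main__":
    # Minimal sanity-check sketch (an addition for illustration, not part of
    # the original module). It verifies that FrozenBatchNorm2d reproduces
    # nn.BatchNorm2d in eval mode on both forward() code paths, and that the
    # layer exposes no trainable parameters.
    torch.manual_seed(0)
    bn = nn.BatchNorm2d(8).eval()
    # Give the reference layer non-trivial statistics and affine parameters.
    bn.running_mean.uniform_(-1.0, 1.0)
    bn.running_var.uniform_(0.5, 2.0)
    bn.weight.data.uniform_(0.5, 1.5)
    bn.bias.data.uniform_(-1.0, 1.0)

    frozen = convert_frozen_batchnorm(bn)
    x = torch.randn(2, 8, 4, 4)

    with torch.no_grad():
        expected = bn(x)
        actual = frozen(x)  # no-grad path: fused F.batch_norm
    assert torch.allclose(actual, expected, atol=1e-5)

    x.requires_grad_(True)
    out = frozen(x)  # grad path: explicit scale/bias affine transform
    assert torch.allclose(out, expected, atol=1e-5)
    out.sum().backward()
    assert x.grad is not None  # gradients still flow to the input
    assert len(list(frozen.parameters())) == 0  # all state lives in buffers
    print(frozen)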