import torch
import torch.nn.functional as F
from torch import nn


def c2_msra_fill(module: nn.Module) -> None:
    """
    Initialize `module.weight` using the "MSRAFill" implemented in Caffe2.
    Also initializes `module.bias` to 0.

    Args:
        module (torch.nn.Module): module to initialize.
    """
    nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
    if module.bias is not None:
        nn.init.constant_(module.bias, 0)
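

# Illustrative usage (a sketch, not part of the original module): `c2_msra_fill` is meant
# to be applied to freshly constructed conv layers, e.g.
#
#     conv = nn.Conv2d(3, 64, kernel_size=3)
#     c2_msra_fill(conv)  # Kaiming-normal ("MSRA") weight, zero bias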
def get_norm(norm, out_channels):
    """
    Args:
        norm (str or callable): one of "BN", "FrozenBN", "GN", "nnSyncBN";
            or a callable that takes a channel number and returns
            the normalization layer as an nn.Module.

    Returns:
        nn.Module or None: the normalization layer
    """
    if norm is None:
        return None
    if isinstance(norm, str):
        if len(norm) == 0:
            return None
        norm = {
            "BN": torch.nn.BatchNorm2d,
            "FrozenBN": FrozenBatchNorm2d,
            "GN": lambda channels: nn.GroupNorm(32, channels),
            "nnSyncBN": nn.SyncBatchNorm,
        }[norm]
    return norm(out_channels)
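

# A minimal usage sketch (illustrative, not from the original file): a string selects one
# of the registered layers, while a callable receives the channel count directly.
#
#     gn = get_norm("GN", 64)                    # -> nn.GroupNorm(32, 64)
#     frozen = get_norm(FrozenBatchNorm2d, 64)   # callable form
#     none = get_norm("", 64)                    # empty string -> None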
class Conv2d(torch.nn.Conv2d):
    """
    A wrapper around :class:`torch.nn.Conv2d` to support empty inputs and more features.
    """

    def __init__(self, *args, **kwargs):
        """
        Extra keyword arguments supported in addition to those in `torch.nn.Conv2d`:

        Args:
            norm (nn.Module, optional): a normalization layer
            activation (callable(Tensor) -> Tensor): a callable activation function

        It assumes that the norm layer is applied before the activation.
        """
        norm = kwargs.pop("norm", None)
        activation = kwargs.pop("activation", None)
        super().__init__(*args, **kwargs)

        self.norm = norm
        self.activation = activation

    def forward(self, x):
        # The empty-input check is only relevant in eager mode during training;
        # it is skipped when the module is scripted.
        if not torch.jit.is_scripting():
            if x.numel() == 0 and self.training:
                # SyncBatchNorm cannot compute statistics over an empty batch.
                assert not isinstance(
                    self.norm, torch.nn.SyncBatchNorm
                ), "SyncBatchNorm does not support empty inputs!"

        x = F.conv2d(
            x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups
        )
        if self.norm is not None:
            x = self.norm(x)
        if self.activation is not None:
            x = self.activation(x)
        return x
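

# Illustrative usage (a sketch, not part of the original module): the extra kwargs bundle
# a conv, its norm, and its activation into a single module, e.g.
#
#     conv = Conv2d(3, 64, kernel_size=3, padding=1, bias=False,
#                   norm=get_norm("BN", 64), activation=F.relu)
#     y = conv(torch.randn(2, 3, 32, 32))  # conv -> BN -> ReLU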
class FrozenBatchNorm2d(nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed.

    It contains non-trainable buffers called
    "weight", "bias", "running_mean", and "running_var",
    initialized to perform an identity transformation.

    The pre-trained backbone models from Caffe2 only contain "weight" and "bias",
    which are computed from the original four parameters of BN.
    The affine transform `x * weight + bias` will perform the equivalent
    computation of `(x - running_mean) / sqrt(running_var) * weight + bias`.
    When loading a backbone model from Caffe2, "running_mean" and "running_var"
    will be left unchanged as the identity transformation.

    Other pre-trained backbone models may contain all 4 parameters.

    The forward is implemented by `F.batch_norm(..., training=False)`.
    """

    _version = 3

    def __init__(self, num_features, eps=1e-5):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.register_buffer("weight", torch.ones(num_features))
        self.register_buffer("bias", torch.zeros(num_features))
        self.register_buffer("running_mean", torch.zeros(num_features))
        # Initialized to `1 - eps` so that `sqrt(running_var + eps) == 1`, i.e. the
        # freshly constructed module is an exact identity transformation.
        self.register_buffer("running_var", torch.ones(num_features) - eps)
    def forward(self, x):
        if x.requires_grad:
            # When gradients are needed, F.batch_norm would use extra memory because its
            # backward op also computes gradients for weight and bias; instead, fold the
            # frozen statistics into a per-channel scale and bias.
            scale = self.weight * (self.running_var + self.eps).rsqrt()
            bias = self.bias - self.running_mean * scale
            scale = scale.reshape(1, -1, 1, 1)
            bias = bias.reshape(1, -1, 1, 1)
            out_dtype = x.dtype  # may be half when running under mixed precision
            return x * scale.to(out_dtype) + bias.to(out_dtype)
        else:
            # When gradients are not needed, F.batch_norm is a single fused op and
            # provides more optimization opportunities.
            return F.batch_norm(
                x,
                self.running_mean,
                self.running_var,
                self.weight,
                self.bias,
                training=False,
                eps=self.eps,
            )
    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        version = local_metadata.get("version", None)

        if version is None or version < 2:
            # Checkpoints from older versions of this module have no running_mean/running_var;
            # fill in identity values so that loading does not report missing keys.
            if prefix + "running_mean" not in state_dict:
                state_dict[prefix + "running_mean"] = torch.zeros_like(self.running_mean)
            if prefix + "running_var" not in state_dict:
                state_dict[prefix + "running_var"] = torch.ones_like(self.running_var)

        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )

    def __repr__(self):
        return "FrozenBatchNorm2d(num_features={}, eps={})".format(self.num_features, self.eps)
    @classmethod
    def convert_frozen_batchnorm(cls, module):
        """
        Convert all BatchNorm/SyncBatchNorm in module into FrozenBatchNorm.

        Args:
            module (torch.nn.Module):

        Returns:
            If module is BatchNorm/SyncBatchNorm, returns a new module.
            Otherwise, in-place convert module and return it.

        Similar to convert_sync_batchnorm in
        https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py
        """
        bn_module = nn.modules.batchnorm
        bn_module = (bn_module.BatchNorm2d, bn_module.SyncBatchNorm)
        res = module
        if isinstance(module, bn_module):
            res = cls(module.num_features)
            if module.affine:
                res.weight.data = module.weight.data.clone().detach()
                res.bias.data = module.bias.data.clone().detach()
            res.running_mean.data = module.running_mean.data
            res.running_var.data = module.running_var.data
            res.eps = module.eps
        else:
            for name, child in module.named_children():
                new_child = cls.convert_frozen_batchnorm(child)
                if new_child is not child:
                    res.add_module(name, new_child)
        return res
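

# Illustrative usage of the converter (a sketch, not part of the original class):
#
#     model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.BatchNorm2d(8))
#     model = FrozenBatchNorm2d.convert_frozen_batchnorm(model)
#     # model[1] is now a FrozenBatchNorm2d carrying the original statistics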
class CNNBlockBase(nn.Module):
    """
    A CNN block is assumed to have input channels, output channels and a stride.
    The input and output of the `forward()` method must be NCHW tensors.
    The method can perform arbitrary computation but must match the given
    channels and stride specification.

    Attributes:
        in_channels (int):
        out_channels (int):
        stride (int):
    """

    def __init__(self, in_channels, out_channels, stride):
        """
        The `__init__` method of any subclass should also contain these arguments.

        Args:
            in_channels (int):
            out_channels (int):
            stride (int):
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = stride

    def freeze(self):
        """
        Make this block not trainable.
        This method sets all parameters to `requires_grad=False`,
        and converts all BatchNorm layers to FrozenBatchNorm.

        Returns:
            the block itself
        """
        for p in self.parameters():
            p.requires_grad = False
        FrozenBatchNorm2d.convert_frozen_batchnorm(self)
        return self
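

# A usage sketch for `freeze()` (illustrative, not part of the original module) appears in
# the `__main__` block at the end of this file: subclasses such as `BottleneckBlock`
# inherit it, and `block.freeze()` both stops gradient updates and swaps every BatchNorm
# for a FrozenBatchNorm2d.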
class BottleneckBlock(CNNBlockBase):
    """
    The standard bottleneck residual block used by ResNet-50, 101 and 152
    defined in :paper:`ResNet`. It contains 3 conv layers with kernels
    1x1, 3x3, 1x1, and a projection shortcut if needed.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        bottleneck_channels,
        stride=1,
        num_groups=1,
        norm="BN",
        stride_in_1x1=False,
        dilation=1,
    ):
        """
        Args:
            bottleneck_channels (int): number of output channels for the 3x3
                "bottleneck" conv layers.
            num_groups (int): number of groups for the 3x3 conv layer.
            norm (str or callable): normalization for all conv layers.
                See :func:`get_norm` for supported formats.
            stride_in_1x1 (bool): when stride>1, whether to put stride in the
                first 1x1 convolution or the bottleneck 3x3 convolution.
            dilation (int): the dilation rate of the 3x3 conv layer.
        """
        super().__init__(in_channels, out_channels, stride)

        if in_channels != out_channels:
            self.shortcut = Conv2d(
                in_channels,
                out_channels,
                kernel_size=1,
                stride=stride,
                bias=False,
                norm=get_norm(norm, out_channels),
            )
        else:
            self.shortcut = None

        # The original MSRA ResNet models put the stride in the first 1x1 conv; later
        # Torch and Caffe2 ResNet implementations put it in the 3x3 conv instead.
        stride_1x1, stride_3x3 = (stride, 1) if stride_in_1x1 else (1, stride)

        self.conv1 = Conv2d(
            in_channels,
            bottleneck_channels,
            kernel_size=1,
            stride=stride_1x1,
            bias=False,
            norm=get_norm(norm, bottleneck_channels),
        )

        self.conv2 = Conv2d(
            bottleneck_channels,
            bottleneck_channels,
            kernel_size=3,
            stride=stride_3x3,
            padding=1 * dilation,
            bias=False,
            groups=num_groups,
            dilation=dilation,
            norm=get_norm(norm, bottleneck_channels),
        )

        self.conv3 = Conv2d(
            bottleneck_channels,
            out_channels,
            kernel_size=1,
            bias=False,
            norm=get_norm(norm, out_channels),
        )

        for layer in [self.conv1, self.conv2, self.conv3, self.shortcut]:
            if layer is not None:
                c2_msra_fill(layer)
    def forward(self, x):
        out = self.conv1(x)
        out = F.relu_(out)

        out = self.conv2(out)
        out = F.relu_(out)

        out = self.conv3(out)

        if self.shortcut is not None:
            shortcut = self.shortcut(x)
        else:
            shortcut = x

        out += shortcut
        out = F.relu_(out)
        return out
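

if __name__ == "__main__":
    # Minimal smoke test, added as an illustrative sketch (not part of the original
    # module): build a bottleneck block the way a ResNet stage would, run a forward
    # pass, then freeze it.
    block = BottleneckBlock(
        in_channels=64,
        out_channels=256,
        bottleneck_channels=64,
        stride=1,
        norm="BN",
    )
    x = torch.randn(2, 64, 56, 56)
    print(block(x).shape)  # expected: torch.Size([2, 256, 56, 56])

    # `freeze()` turns off gradients and replaces every BatchNorm with FrozenBatchNorm2d.
    block.freeze()
    print(all(not p.requires_grad for p in block.parameters()))  # expected: True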