DRv2 / Model /AttDes /models /resblock.py
Zhonathon's picture
update all file v1
aa7fb02
# Code adapted from https://github.com/facebookresearch/detectron2
from torch import nn
import torch
import torch.nn.functional as F
def c2_msra_fill(module: nn.Module) -> None:
    """
    Apply Caffe2-style "MSRAFill" initialization to a layer.

    The weight is drawn with Kaiming-normal initialization (fan-out mode,
    ReLU gain); the bias, when the layer has one, is set to zero.

    Args:
        module (torch.nn.Module): layer exposing `weight` and `bias` attributes
            (`bias` may be None).
    """
    nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
    bias = module.bias
    if bias is None:
        # Bias-free layer: nothing more to initialize.
        return
    nn.init.constant_(bias, 0)
def get_norm(norm, out_channels):
    """
    Build a normalization layer for `out_channels` channels.

    Args:
        norm (str or callable or None): either one of "BN", "SyncBN",
            "FrozenBN", "GN", "nnSyncBN"; or a callable that takes a channel
            number and returns the normalization layer as a nn.Module.
            `None` and the empty string both mean "no normalization".
        out_channels (int): channel count handed to the layer constructor.

    Returns:
        nn.Module or None: the normalization layer

    Raises:
        KeyError: if `norm` is a non-empty string that is not a supported name.
    """
    if norm is None:
        return None
    if isinstance(norm, str):
        if len(norm) == 0:
            return None
        factories = {
            "BN": torch.nn.BatchNorm2d,
            # The docstring has always listed "SyncBN"; map it to the native
            # implementation so the documented name actually resolves.
            "SyncBN": nn.SyncBatchNorm,
            # Wrapped in a lambda so the class is only resolved when
            # "FrozenBN" is actually requested.
            "FrozenBN": lambda channels: FrozenBatchNorm2d(channels),
            # GroupNorm always uses 32 groups here.
            "GN": lambda channels: nn.GroupNorm(32, channels),
            # for debugging:
            "nnSyncBN": nn.SyncBatchNorm,
        }
        if norm not in factories:
            raise KeyError(
                "Unsupported norm {!r}; expected one of {}".format(norm, sorted(factories))
            )
        norm = factories[norm]
    return norm(out_channels)
class Conv2d(torch.nn.Conv2d):
    """
    :class:`torch.nn.Conv2d` wrapper that supports empty inputs and can run an
    optional normalization layer and activation right after the convolution.
    """

    def __init__(self, *args, **kwargs):
        """
        Accepts every argument of `torch.nn.Conv2d`, plus two optional keywords:

        Args:
            norm (nn.Module, optional): a normalization layer
            activation (callable(Tensor) -> Tensor): a callable activation function

        The normalization layer, when given, always runs before the activation.
        """
        # Strip the extra keywords before delegating: the base class would
        # reject them as unexpected arguments.
        extra_norm = kwargs.pop("norm", None)
        extra_activation = kwargs.pop("activation", None)
        super().__init__(*args, **kwargs)

        self.norm = extra_norm
        self.activation = extra_activation

    def forward(self, x):
        """Run convolution, then the optional norm, then the optional activation."""
        # torchscript does not support SyncBatchNorm yet
        # https://github.com/pytorch/pytorch/issues/40507
        # This check is skipped under torchscript because:
        # 1. torchscript is currently only supported in evaluation mode
        # 2. the features needed to export a module to torchscript arrived in
        #    PyTorch 1.6 or later, where `Conv2d` already handles empty inputs.
        if not torch.jit.is_scripting():
            if self.training and x.numel() == 0:
                # https://github.com/pytorch/pytorch/issues/12013
                uses_sync_bn = isinstance(self.norm, torch.nn.SyncBatchNorm)
                assert not uses_sync_bn, "SyncBatchNorm does not support empty inputs!"

        out = F.conv2d(
            x,
            self.weight,
            self.bias,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            groups=self.groups,
        )
        if self.norm is not None:
            out = self.norm(out)
        if self.activation is not None:
            out = self.activation(out)
        return out
class FrozenBatchNorm2d(nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed.

    It contains non-trainable buffers called
    "weight" and "bias", "running_mean", "running_var",
    initialized to perform identity transformation.

    The pre-trained backbone models from Caffe2 only contain "weight" and "bias",
    which are computed from the original four parameters of BN.
    The affine transform `x * weight + bias` will perform the equivalent
    computation of `(x - running_mean) / sqrt(running_var) * weight + bias`.
    When such a state dict is loaded without version metadata, the missing
    "running_mean" / "running_var" entries are filled with zeros / ones.

    Other pre-trained backbone models may contain all 4 parameters.

    The forward is implemented by `F.batch_norm(..., training=False)`.
    """

    _version = 3

    def __init__(self, num_features, eps=1e-5):
        """
        Args:
            num_features (int): number of channels; the size of every buffer.
            eps (float): value added to the variance before the square root.
        """
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.register_buffer("weight", torch.ones(num_features))
        self.register_buffer("bias", torch.zeros(num_features))
        self.register_buffer("running_mean", torch.zeros(num_features))
        # running_var + eps == 1, so a freshly built layer is an identity map.
        self.register_buffer("running_var", torch.ones(num_features) - eps)

    def forward(self, x):
        """Normalize `x` with the stored statistics and affine parameters."""
        if x.requires_grad:
            # When gradients are needed, F.batch_norm will use extra memory
            # because its backward op computes gradients for weight/bias as well.
            scale = self.weight * (self.running_var + self.eps).rsqrt()
            bias = self.bias - self.running_mean * scale
            # Reshape so the per-channel terms broadcast over dim 1 of a 4-D input.
            scale = scale.reshape(1, -1, 1, 1)
            bias = bias.reshape(1, -1, 1, 1)
            out_dtype = x.dtype  # may be half
            return x * scale.to(out_dtype) + bias.to(out_dtype)
        else:
            # When gradients are not needed, F.batch_norm is a single fused op
            # and provide more optimization opportunities.
            return F.batch_norm(
                x,
                self.running_mean,
                self.running_var,
                self.weight,
                self.bias,
                training=False,
                eps=self.eps,
            )

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        """Fill in missing running statistics for old checkpoints, then load as usual."""
        version = local_metadata.get("version", None)

        if version is None or version < 2:
            # No running_mean/var in early versions.
            # Supplying defaults here keeps them from being reported as missing keys.
            if prefix + "running_mean" not in state_dict:
                state_dict[prefix + "running_mean"] = torch.zeros_like(self.running_mean)
            if prefix + "running_var" not in state_dict:
                state_dict[prefix + "running_var"] = torch.ones_like(self.running_var)

        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )

    def __repr__(self):
        return "FrozenBatchNorm2d(num_features={}, eps={})".format(self.num_features, self.eps)

    @classmethod
    def convert_frozen_batchnorm(cls, module):
        """
        Convert all BatchNorm/SyncBatchNorm in module into FrozenBatchNorm.

        Args:
            module (torch.nn.Module):

        Returns:
            If module is BatchNorm/SyncBatchNorm, returns a new module.
            Otherwise, in-place convert module and return it.

        The new module owns copies of the source layer's parameters and
        running statistics, so later updates to the source layer do not
        change the frozen one.

        Similar to convert_sync_batchnorm in
        https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py
        """
        bn_module = nn.modules.batchnorm
        bn_module = (bn_module.BatchNorm2d, bn_module.SyncBatchNorm)
        res = module
        if isinstance(module, bn_module):
            res = cls(module.num_features)
            if module.affine:
                res.weight.data = module.weight.data.clone().detach()
                res.bias.data = module.bias.data.clone().detach()
            # Clone the statistics as well: assigning `.data` directly would
            # make the frozen buffers share storage with the source layer.
            res.running_mean.data = module.running_mean.data.clone().detach()
            res.running_var.data = module.running_var.data.clone().detach()
            res.eps = module.eps
        else:
            for name, child in module.named_children():
                new_child = cls.convert_frozen_batchnorm(child)
                if new_child is not child:
                    # Re-registering under the same name replaces the old child.
                    res.add_module(name, new_child)
        return res
class CNNBlockBase(nn.Module):
    """
    Base class for CNN blocks described by an input channel count, an output
    channel count and a stride.

    Subclasses' `forward()` must take and return NCHW tensors. It may perform
    arbitrary computation, but the result has to agree with the declared
    channels and stride.

    Attribute:
        in_channels (int):
        out_channels (int):
        stride (int):
    """

    def __init__(self, in_channels, out_channels, stride):
        """
        Subclasses are expected to accept these same arguments in their own
        `__init__`.

        Args:
            in_channels (int):
            out_channels (int):
            stride (int):
        """
        super().__init__()
        self.in_channels, self.out_channels, self.stride = in_channels, out_channels, stride

    def freeze(self):
        """
        Make this block not trainable.

        Every parameter gets `requires_grad=False`, and all BatchNorm layers
        inside the block are converted to FrozenBatchNorm.

        Returns:
            the block itself
        """
        for param in self.parameters():
            param.requires_grad_(False)
        # Child BatchNorm layers are swapped in place on this block.
        FrozenBatchNorm2d.convert_frozen_batchnorm(self)
        return self
class BottleneckBlock(CNNBlockBase):
    """
    The standard bottleneck residual block used by ResNet-50, 101 and 152
    defined in :paper:`ResNet`. It contains 3 conv layers with kernels
    1x1, 3x3, 1x1, and a projection shortcut if needed.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        #*,
        bottleneck_channels,
        stride=1,
        num_groups=1,
        norm="BN",
        stride_in_1x1=False,
        dilation=1,
    ):
        """
        Args:
            in_channels (int): number of channels of the block input.
            out_channels (int): number of channels of the block output.
            bottleneck_channels (int): number of output channels for the 3x3
                "bottleneck" conv layers.
            stride (int): stride of the block, applied in exactly one of the
                first two conv layers and in the projection shortcut.
            num_groups (int): number of groups for the 3x3 conv layer.
            norm (str or callable): normalization for all conv layers.
                See :func:`get_norm` for supported format.
            stride_in_1x1 (bool): when stride>1, whether to put stride in the
                first 1x1 convolution or the bottleneck 3x3 convolution.
            dilation (int): the dilation rate of the 3x3 conv layer.
        """
        super().__init__(in_channels, out_channels, stride)

        # A projection is needed whenever the identity path would not match the
        # residual branch: different channel count, or a strided block (the
        # residual branch is spatially downsampled, so the raw input could not
        # be added to it).
        if in_channels != out_channels or stride != 1:
            self.shortcut = Conv2d(
                in_channels,
                out_channels,
                kernel_size=1,
                stride=stride,
                bias=False,
                norm=get_norm(norm, out_channels),
            )
        else:
            self.shortcut = None

        # The original MSRA ResNet models have stride in the first 1x1 conv
        # The subsequent fb.torch.resnet and Caffe2 ResNe[X]t implementations have
        # stride in the 3x3 conv
        stride_1x1, stride_3x3 = (stride, 1) if stride_in_1x1 else (1, stride)

        self.conv1 = Conv2d(
            in_channels,
            bottleneck_channels,
            kernel_size=1,
            stride=stride_1x1,
            bias=False,
            norm=get_norm(norm, bottleneck_channels),
        )

        self.conv2 = Conv2d(
            bottleneck_channels,
            bottleneck_channels,
            kernel_size=3,
            stride=stride_3x3,
            # padding == dilation keeps the 3x3 conv size-preserving at stride 1
            padding=1 * dilation,
            bias=False,
            groups=num_groups,
            dilation=dilation,
            norm=get_norm(norm, bottleneck_channels),
        )

        self.conv3 = Conv2d(
            bottleneck_channels,
            out_channels,
            kernel_size=1,
            bias=False,
            norm=get_norm(norm, out_channels),
        )

        for layer in [self.conv1, self.conv2, self.conv3, self.shortcut]:
            if layer is not None:  # shortcut can be None
                c2_msra_fill(layer)

        # Zero-initialize the last normalization in each residual branch,
        # so that at the beginning, the residual branch starts with zeros,
        # and each residual block behaves like an identity.
        # See Sec 5.1 in "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour":
        # "For BN layers, the learnable scaling coefficient γ is initialized
        # to be 1, except for each residual block's last BN
        # where γ is initialized to be 0."

        # nn.init.constant_(self.conv3.norm.weight, 0)
        # TODO this somehow hurts performance when training GN models from scratch.
        # Add it as an option when we need to use this code to train a backbone.

    def forward(self, x):
        """Apply the three conv layers, add the shortcut, and finish with ReLU."""
        out = self.conv1(x)
        out = F.relu_(out)

        out = self.conv2(out)
        out = F.relu_(out)

        # No ReLU after conv3: the activation comes after the residual sum.
        out = self.conv3(out)

        if self.shortcut is not None:
            shortcut = self.shortcut(x)
        else:
            shortcut = x

        out += shortcut
        out = F.relu_(out)
        return out