# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
"""
Variant of the resnet module that takes cfg as an argument.
Example usage. Strings may be specified in the config file.
    model = ResNet(
        "StemWithFixedBatchNorm",
        "BottleneckWithFixedBatchNorm",
        "ResNet50StagesTo4",
    )
OR:
    model = ResNet(
        "StemWithGN",
        "BottleneckWithGN",
        "ResNet50StagesTo4",
    )
Custom implementations may be written in user code and hooked in via the
`register_*` functions.
"""
from collections import namedtuple

import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import BatchNorm2d, SyncBatchNorm

from maskrcnn_benchmark.layers import FrozenBatchNorm2d, NaiveSyncBatchNorm2d
from maskrcnn_benchmark.layers import Conv2d, DFConv2d, SELayer
from maskrcnn_benchmark.modeling.make_layers import group_norm
from maskrcnn_benchmark.utils.registry import Registry

# ResNet stage specification
StageSpec = namedtuple(
    "StageSpec",
    [
        "index",  # Index of the stage, e.g. 1, 2, ..., 5
        "block_count",  # Number of residual blocks in the stage
        "return_features",  # True => return the last feature map from this stage
    ],
)
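# For example, StageSpec(index=1, block_count=3, return_features=False)
# describes res2 of a ResNet-50: three residual blocks whose final feature
# map is not returned from forward().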

# -----------------------------------------------------------------------------
# Standard ResNet models
# -----------------------------------------------------------------------------
# ResNet-50 (including all stages)
ResNet50StagesTo5 = tuple(
    StageSpec(index=i, block_count=c, return_features=r)
    for (i, c, r) in ((1, 3, False), (2, 4, False), (3, 6, False), (4, 3, True))
)
# ResNet-50 up to stage 4 (excludes stage 5)
ResNet50StagesTo4 = tuple(
    StageSpec(index=i, block_count=c, return_features=r)
    for (i, c, r) in ((1, 3, False), (2, 4, False), (3, 6, True))
)
# ResNet-101 (including all stages)
ResNet101StagesTo5 = tuple(
    StageSpec(index=i, block_count=c, return_features=r)
    for (i, c, r) in ((1, 3, False), (2, 4, False), (3, 23, False), (4, 3, True))
)
# ResNet-101 up to stage 4 (excludes stage 5)
ResNet101StagesTo4 = tuple(
    StageSpec(index=i, block_count=c, return_features=r)
    for (i, c, r) in ((1, 3, False), (2, 4, False), (3, 23, True))
)
# ResNet-50-FPN (including all stages)
ResNet50FPNStagesTo5 = tuple(
    StageSpec(index=i, block_count=c, return_features=r)
    for (i, c, r) in ((1, 3, True), (2, 4, True), (3, 6, True), (4, 3, True))
)
# ResNet-101-FPN (including all stages)
ResNet101FPNStagesTo5 = tuple(
    StageSpec(index=i, block_count=c, return_features=r)
    for (i, c, r) in ((1, 3, True), (2, 4, True), (3, 23, True), (4, 3, True))
)
# ResNet-152-FPN (including all stages)
ResNet152FPNStagesTo5 = tuple(
    StageSpec(index=i, block_count=c, return_features=r)
    for (i, c, r) in ((1, 3, True), (2, 8, True), (3, 36, True), (4, 3, True))
)
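
# The block counts follow He et al.'s tables: e.g. (3, 4, 6, 3) for ResNet-50
# gives 16 bottleneck blocks of 3 convs each, which together with the stem
# conv and the (here absent) classification fc account for the 50 layers.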


class ResNet(nn.Module):
    def __init__(self, cfg):
        super(ResNet, self).__init__()

        # If we want to use the cfg in forward(), then we should make a copy
        # of it and store it for later use:
        # self.cfg = cfg.clone()

        # Translate string names to implementations
        norm_level = None
        stem_module = _STEM_MODULES[cfg.MODEL.RESNETS.STEM_FUNC]
        stage_specs = _STAGE_SPECS[cfg.MODEL.BACKBONE.CONV_BODY]
        transformation_module = _TRANSFORMATION_MODULES[cfg.MODEL.RESNETS.TRANS_FUNC]
        if cfg.MODEL.BACKBONE.USE_BN:
            stem_module = StemWithBatchNorm
            transformation_module = BottleneckWithBatchNorm
            norm_level = cfg.MODEL.BACKBONE.NORM_LEVEL
        elif cfg.MODEL.BACKBONE.USE_NSYNCBN:
            stem_module = StemWithNaiveSyncBatchNorm
            transformation_module = BottleneckWithNaiveSyncBatchNorm
            norm_level = cfg.MODEL.BACKBONE.NORM_LEVEL
        elif cfg.MODEL.BACKBONE.USE_SYNCBN:
            stem_module = StemWithSyncBatchNorm
            transformation_module = BottleneckWithSyncBatchNorm
            norm_level = cfg.MODEL.BACKBONE.NORM_LEVEL

        # Construct the stem module
        self.stem = stem_module(cfg)

        # Construct the specified ResNet stages
        num_groups = cfg.MODEL.RESNETS.NUM_GROUPS
        width_per_group = cfg.MODEL.RESNETS.WIDTH_PER_GROUP
        in_channels = cfg.MODEL.RESNETS.STEM_OUT_CHANNELS
        stage2_bottleneck_channels = num_groups * width_per_group
        stage2_out_channels = cfg.MODEL.RESNETS.RES2_OUT_CHANNELS
        with_se = cfg.MODEL.RESNETS.WITH_SE
        self.stages = []
        self.out_channels = []
        self.return_features = {}
        for stage_spec in stage_specs:
            name = "layer" + str(stage_spec.index)
            # Channel counts double at each stage relative to stage 2
            stage2_relative_factor = 2 ** (stage_spec.index - 1)
            bottleneck_channels = stage2_bottleneck_channels * stage2_relative_factor
            out_channels = stage2_out_channels * stage2_relative_factor
            stage_with_dcn = cfg.MODEL.RESNETS.STAGE_WITH_DCN[stage_spec.index - 1]
            if cfg.MODEL.RESNETS.USE_AVG_DOWN:
                avg_down_stride = 1 if stage_spec.index == 1 else 2
            else:
                avg_down_stride = 0
            module = _make_stage(
                transformation_module,
                in_channels,
                bottleneck_channels,
                out_channels,
                stage_spec.block_count,
                num_groups,
                cfg.MODEL.RESNETS.STRIDE_IN_1X1,
                first_stride=int(stage_spec.index > 1) + 1,
                dcn_config={
                    "stage_with_dcn": stage_with_dcn,
                    "with_modulated_dcn": cfg.MODEL.RESNETS.WITH_MODULATED_DCN,
                    "deformable_groups": cfg.MODEL.RESNETS.DEFORMABLE_GROUPS,
                },
                norm_level=norm_level,
                with_se=with_se,
                avg_down_stride=avg_down_stride,
            )
            in_channels = out_channels
            self.add_module(name, module)
            self.stages.append(name)
            self.out_channels.append(out_channels)
            self.return_features[name] = stage_spec.return_features

        # Optionally freeze (requires_grad=False) parts of the backbone
        self._freeze_backbone(cfg.MODEL.BACKBONE.FREEZE_CONV_BODY_AT)

    def _freeze_backbone(self, freeze_at):
        if freeze_at < 0:
            return
        for stage_index in range(freeze_at):
            if stage_index == 0:
                m = self.stem  # stage 0 is the stem
            else:
                m = getattr(self, "layer" + str(stage_index))
            for p in m.parameters():
                p.requires_grad = False

    def forward(self, x):
        outputs = []
        x = self.stem(x)
        for stage_name in self.stages:
            x = getattr(self, stage_name)(x)
            if self.return_features[stage_name]:
                outputs.append(x)
        return outputs
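

# A minimal usage sketch, assuming this fork's config defines the keys read in
# __init__ above (MODEL.BACKBONE.USE_BN etc.); the values are illustrative:
#
#   from maskrcnn_benchmark.config import cfg
#   cfg.merge_from_list(["MODEL.BACKBONE.CONV_BODY", "R-50-FPN"])
#   body = ResNet(cfg)
#   feats = body(torch.randn(2, 3, 224, 224))
#   # R-50-FPN returns all four stages; with RES2_OUT_CHANNELS=256 the widths
#   # are body.out_channels == [256, 512, 1024, 2048].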


class ResNetHead(nn.Module):
    def __init__(
        self,
        block_module,
        stages,
        num_groups=1,
        width_per_group=64,
        stride_in_1x1=True,
        stride_init=None,
        res2_out_channels=256,
        dilation=1,
        dcn_config=None,
    ):
        super(ResNetHead, self).__init__()

        stage2_relative_factor = 2 ** (stages[0].index - 1)
        stage2_bottleneck_channels = num_groups * width_per_group
        out_channels = res2_out_channels * stage2_relative_factor
        in_channels = out_channels // 2
        bottleneck_channels = stage2_bottleneck_channels * stage2_relative_factor

        block_module = _TRANSFORMATION_MODULES[block_module]

        self.stages = []
        stride = stride_init
        for stage in stages:
            name = "layer" + str(stage.index)
            if not stride:
                stride = int(stage.index > 1) + 1
            module = _make_stage(
                block_module,
                in_channels,
                bottleneck_channels,
                out_channels,
                stage.block_count,
                num_groups,
                stride_in_1x1,
                first_stride=stride,
                dilation=dilation,
                dcn_config=dcn_config,
            )
            stride = None
            self.add_module(name, module)
            self.stages.append(name)
        self.out_channels = out_channels

    def forward(self, x):
        for stage in self.stages:
            x = getattr(self, stage)(x)
        return x
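

# Sketch of how a C4 model typically attaches this as its res5 RoI head; the
# spec mirrors stage 4 of the tables above:
#
#   stage = StageSpec(index=4, block_count=3, return_features=False)
#   head = ResNetHead("BottleneckWithFixedBatchNorm", (stage,))
#   # in_channels is inferred as 2048 // 2 = 1024 (the res4 output width),
#   # and head.out_channels == 2048.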


def _make_stage(
    transformation_module,
    in_channels,
    bottleneck_channels,
    out_channels,
    block_count,
    num_groups,
    stride_in_1x1,
    first_stride,
    dilation=1,
    dcn_config=None,
    norm_level=None,
    **kwargs
):
    blocks = []
    stride = first_stride
    for li in range(block_count):
        if norm_level is not None:
            # NORM_LEVEL controls how many blocks per stage use the selected
            # (trainable) norm; the rest keep frozen batch norm:
            #   1 => first block only, 2 => first and last block, 3 => all blocks
            layer_module = BottleneckWithFixedBatchNorm
            if norm_level >= 1 and li == 0:
                layer_module = transformation_module
            if norm_level >= 2 and li == block_count - 1:
                layer_module = transformation_module
            if norm_level >= 3:
                layer_module = transformation_module
        else:
            layer_module = transformation_module
        blocks.append(
            layer_module(
                in_channels,
                bottleneck_channels,
                out_channels,
                num_groups,
                stride_in_1x1,
                stride,
                dilation=dilation,
                dcn_config=dcn_config,
                **kwargs
            )
        )
        stride = 1
        in_channels = out_channels
    return nn.Sequential(*blocks)


class Bottleneck(nn.Module):
    def __init__(
        self,
        in_channels,
        bottleneck_channels,
        out_channels,
        num_groups,
        stride_in_1x1,
        stride,
        dilation,
        norm_func,
        dcn_config,
        with_se=False,
        avg_down_stride=0,
    ):
        super(Bottleneck, self).__init__()

        # Callers that don't use DCN may pass dcn_config=None
        if dcn_config is None:
            dcn_config = {}

        self.downsample = None
        if in_channels != out_channels:
            down_stride = stride if dilation == 1 else 1
            if avg_down_stride > 0:
                self.downsample = nn.Sequential(
                    nn.AvgPool2d(
                        kernel_size=avg_down_stride,
                        stride=avg_down_stride,
                        ceil_mode=True,
                        count_include_pad=False,
                    ),
                    nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
                    norm_func(out_channels),
                )
            else:
                self.downsample = nn.Sequential(
                    Conv2d(in_channels, out_channels, kernel_size=1, stride=down_stride, bias=False),
                    norm_func(out_channels),
                )
            for modules in [self.downsample]:
                for l in modules.modules():
                    if isinstance(l, Conv2d):
                        nn.init.kaiming_uniform_(l.weight, a=1)

        if dilation > 1:
            stride = 1  # reset to be 1

        # The original MSRA ResNet models have stride in the first 1x1 conv
        # The subsequent fb.torch.resnet and Caffe2 ResNe[X]t implementations have
        # stride in the 3x3 conv
        stride_1x1, stride_3x3 = (stride, 1) if stride_in_1x1 else (1, stride)

        self.conv1 = Conv2d(
            in_channels,
            bottleneck_channels,
            kernel_size=1,
            stride=stride_1x1,
            bias=False,
        )
        self.bn1 = norm_func(bottleneck_channels)
        # TODO: specify init for the above
        with_dcn = dcn_config.get("stage_with_dcn", False)
        if with_dcn:
            deformable_groups = dcn_config.get("deformable_groups", 1)
            with_modulated_dcn = dcn_config.get("with_modulated_dcn", False)
            self.conv2 = DFConv2d(
                bottleneck_channels,
                bottleneck_channels,
                with_modulated_dcn=with_modulated_dcn,
                kernel_size=3,
                stride=stride_3x3,
                groups=num_groups,
                dilation=dilation,
                deformable_groups=deformable_groups,
                bias=False,
            )
        else:
            self.conv2 = Conv2d(
                bottleneck_channels,
                bottleneck_channels,
                kernel_size=3,
                stride=stride_3x3,
                padding=dilation,
                bias=False,
                groups=num_groups,
                dilation=dilation,
            )
            nn.init.kaiming_uniform_(self.conv2.weight, a=1)
        self.bn2 = norm_func(bottleneck_channels)

        self.conv3 = Conv2d(bottleneck_channels, out_channels, kernel_size=1, bias=False)
        self.bn3 = norm_func(out_channels)
        self.se = SELayer(out_channels) if with_se and not with_dcn else None

        for l in [self.conv1, self.conv3]:
            nn.init.kaiming_uniform_(l.weight, a=1)

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = F.relu_(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = F.relu_(out)

        out0 = self.conv3(out)
        out = self.bn3(out0)

        if self.se:
            out = self.se(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = F.relu_(out)

        return out
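

# E.g. the first res2 block, with the usual defaults (NUM_GROUPS=1,
# WIDTH_PER_GROUP=64, STEM_OUT_CHANNELS=64, RES2_OUT_CHANNELS=256), maps 64
# stem channels to 256: conv1 1x1 64->64, conv2 3x3 64->64, conv3 1x1 64->256,
# with a projection shortcut since in_channels != out_channels.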


class BaseStem(nn.Module):
    def __init__(self, cfg, norm_func):
        super(BaseStem, self).__init__()

        out_channels = cfg.MODEL.RESNETS.STEM_OUT_CHANNELS
        self.stem_3x3 = cfg.MODEL.RESNETS.USE_STEM3X3
        if self.stem_3x3:
            self.conv1 = Conv2d(3, out_channels, kernel_size=3, stride=2, padding=1, bias=False)
            self.bn1 = norm_func(out_channels)
            self.conv2 = Conv2d(out_channels, out_channels, kernel_size=3, stride=2, padding=1, bias=False)
            self.bn2 = norm_func(out_channels)
            for l in [self.conv1, self.conv2]:
                nn.init.kaiming_uniform_(l.weight, a=1)
        else:
            self.conv1 = Conv2d(3, out_channels, kernel_size=7, stride=2, padding=3, bias=False)
            self.bn1 = norm_func(out_channels)
            for l in [self.conv1]:
                nn.init.kaiming_uniform_(l.weight, a=1)

    def forward(self, x):
        if self.stem_3x3:
            x = self.conv1(x)
            x = self.bn1(x)
            x = F.relu_(x)
            x = self.conv2(x)
            x = self.bn2(x)
            x = F.relu_(x)
        else:
            x = self.conv1(x)
            x = self.bn1(x)
            x = F.relu_(x)
            x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)
        return x
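

# Both stem variants reduce spatial resolution by 4x overall: a 7x7/s2 conv
# followed by a 3x3/s2 max pool, or (with USE_STEM3X3) two 3x3/s2 convs.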


class BottleneckWithFixedBatchNorm(Bottleneck):
    def __init__(
        self,
        in_channels,
        bottleneck_channels,
        out_channels,
        num_groups=1,
        stride_in_1x1=True,
        stride=1,
        dilation=1,
        dcn_config=None,
        **kwargs
    ):
        super(BottleneckWithFixedBatchNorm, self).__init__(
            in_channels=in_channels,
            bottleneck_channels=bottleneck_channels,
            out_channels=out_channels,
            num_groups=num_groups,
            stride_in_1x1=stride_in_1x1,
            stride=stride,
            dilation=dilation,
            norm_func=FrozenBatchNorm2d,
            dcn_config=dcn_config,
            **kwargs
        )


class StemWithFixedBatchNorm(BaseStem):
    def __init__(self, cfg):
        super(StemWithFixedBatchNorm, self).__init__(cfg, norm_func=FrozenBatchNorm2d)


class BottleneckWithBatchNorm(Bottleneck):
    def __init__(
        self,
        in_channels,
        bottleneck_channels,
        out_channels,
        num_groups=1,
        stride_in_1x1=True,
        stride=1,
        dilation=1,
        dcn_config=None,
        **kwargs
    ):
        super(BottleneckWithBatchNorm, self).__init__(
            in_channels=in_channels,
            bottleneck_channels=bottleneck_channels,
            out_channels=out_channels,
            num_groups=num_groups,
            stride_in_1x1=stride_in_1x1,
            stride=stride,
            dilation=dilation,
            norm_func=BatchNorm2d,
            dcn_config=dcn_config,
            **kwargs
        )


class StemWithBatchNorm(BaseStem):
    def __init__(self, cfg):
        super(StemWithBatchNorm, self).__init__(cfg, norm_func=BatchNorm2d)


class BottleneckWithNaiveSyncBatchNorm(Bottleneck):
    def __init__(
        self,
        in_channels,
        bottleneck_channels,
        out_channels,
        num_groups=1,
        stride_in_1x1=True,
        stride=1,
        dilation=1,
        dcn_config=None,
        **kwargs
    ):
        super(BottleneckWithNaiveSyncBatchNorm, self).__init__(
            in_channels=in_channels,
            bottleneck_channels=bottleneck_channels,
            out_channels=out_channels,
            num_groups=num_groups,
            stride_in_1x1=stride_in_1x1,
            stride=stride,
            dilation=dilation,
            norm_func=NaiveSyncBatchNorm2d,
            dcn_config=dcn_config,
            **kwargs
        )


class StemWithNaiveSyncBatchNorm(BaseStem):
    def __init__(self, cfg):
        super(StemWithNaiveSyncBatchNorm, self).__init__(cfg, norm_func=NaiveSyncBatchNorm2d)


class BottleneckWithSyncBatchNorm(Bottleneck):
    def __init__(
        self,
        in_channels,
        bottleneck_channels,
        out_channels,
        num_groups=1,
        stride_in_1x1=True,
        stride=1,
        dilation=1,
        dcn_config=None,
        **kwargs
    ):
        super(BottleneckWithSyncBatchNorm, self).__init__(
            in_channels=in_channels,
            bottleneck_channels=bottleneck_channels,
            out_channels=out_channels,
            num_groups=num_groups,
            stride_in_1x1=stride_in_1x1,
            stride=stride,
            dilation=dilation,
            norm_func=SyncBatchNorm,
            dcn_config=dcn_config,
            **kwargs
        )


class StemWithSyncBatchNorm(BaseStem):
    def __init__(self, cfg):
        super(StemWithSyncBatchNorm, self).__init__(cfg, norm_func=SyncBatchNorm)


class BottleneckWithGN(Bottleneck):
    def __init__(
        self,
        in_channels,
        bottleneck_channels,
        out_channels,
        num_groups=1,
        stride_in_1x1=True,
        stride=1,
        dilation=1,
        dcn_config=None,
        **kwargs
    ):
        super(BottleneckWithGN, self).__init__(
            in_channels=in_channels,
            bottleneck_channels=bottleneck_channels,
            out_channels=out_channels,
            num_groups=num_groups,
            stride_in_1x1=stride_in_1x1,
            stride=stride,
            dilation=dilation,
            norm_func=group_norm,
            dcn_config=dcn_config,
            **kwargs
        )


class StemWithGN(BaseStem):
    def __init__(self, cfg):
        super(StemWithGN, self).__init__(cfg, norm_func=group_norm)


_TRANSFORMATION_MODULES = Registry(
    {
        "BottleneckWithFixedBatchNorm": BottleneckWithFixedBatchNorm,
        "BottleneckWithGN": BottleneckWithGN,
    }
)

_STEM_MODULES = Registry(
    {
        "StemWithFixedBatchNorm": StemWithFixedBatchNorm,
        "StemWithGN": StemWithGN,
    }
)

_STAGE_SPECS = Registry(
    {
        "R-50-C4": ResNet50StagesTo4,
        "R-50-C5": ResNet50StagesTo5,
        "R-50-RETINANET": ResNet50StagesTo5,
        "R-101-C4": ResNet101StagesTo4,
        "R-101-C5": ResNet101StagesTo5,
        "R-101-RETINANET": ResNet101StagesTo5,
        "R-50-FPN": ResNet50FPNStagesTo5,
        "R-50-FPN-RETINANET": ResNet50FPNStagesTo5,
        "R-50-FPN-FCOS": ResNet50FPNStagesTo5,
        "R-101-FPN": ResNet101FPNStagesTo5,
        "R-101-FPN-RETINANET": ResNet101FPNStagesTo5,
        "R-101-FPN-FCOS": ResNet101FPNStagesTo5,
        "R-152-FPN": ResNet152FPNStagesTo5,
    }
)
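
# To hook in a custom implementation from user code (a sketch; it assumes the
# dict-style Registry interface used by the literals above, and MyBottleneck
# is a hypothetical user-defined Bottleneck subclass):
#
#   _TRANSFORMATION_MODULES["MyBottleneck"] = MyBottleneck
#
# after which "MyBottleneck" can be selected via MODEL.RESNETS.TRANS_FUNC.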