zdou0830's picture
desco
749745d
raw
history blame
No virus
15.3 kB
"""
FBNet model builder
"""
from __future__ import absolute_import, division, print_function, unicode_literals
import copy
import logging
import math
from collections import OrderedDict
import torch
import torch.nn as nn
from torch.nn import BatchNorm2d, SyncBatchNorm
from maskrcnn_benchmark.layers import Conv2d, interpolate
from maskrcnn_benchmark.layers import NaiveSyncBatchNorm2d, FrozenBatchNorm2d
from maskrcnn_benchmark.layers.misc import _NewEmptyTensorOp
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
def _py2_round(x):
return math.floor(x + 0.5) if x >= 0.0 else math.ceil(x - 0.5)
def _get_divisible_by(num, divisible_by, min_val):
ret = int(num)
if divisible_by > 0 and num % divisible_by != 0:
ret = int((_py2_round(num / divisible_by) or min_val) * divisible_by)
return ret
class Identity(nn.Module):
    """Skip connection that projects only when the shape changes.

    If ``C_in == C_out`` and ``stride == 1``, the input tensor is returned
    as-is. Otherwise a 1x1 ``ConvBNRelu`` (BatchNorm + ReLU, no conv bias)
    with the given stride maps the input to ``C_out`` channels.
    """

    def __init__(self, C_in, C_out, stride):
        super(Identity, self).__init__()
        if C_in == C_out and stride == 1:
            self.conv = None
        else:
            self.conv = ConvBNRelu(
                C_in,
                C_out,
                kernel=1,
                stride=stride,
                pad=0,
                no_bias=1,
                use_relu="relu",
                bn_type="bn",
            )

    def forward(self, x):
        if self.conv is None:
            return x
        return self.conv(x)
class CascadeConv3x3(nn.Sequential):
    """Two stacked 3x3 convolutions with an optional residual connection.

    Layout: conv3x3 (C_in -> C_in, ``stride``) -> BN -> ReLU ->
    conv3x3 (C_in -> C_out, stride 1) -> BN. Both convs have no bias.
    The input is added to the output when ``stride == 1`` and
    ``C_in == C_out``.
    """

    def __init__(self, C_in, C_out, stride):
        assert stride in [1, 2]
        first_stage = [
            Conv2d(C_in, C_in, 3, stride, 1, bias=False),
            BatchNorm2d(C_in),
            nn.ReLU(inplace=True),
        ]
        second_stage = [
            Conv2d(C_in, C_out, 3, 1, 1, bias=False),
            BatchNorm2d(C_out),
        ]
        super(CascadeConv3x3, self).__init__(*(first_stage + second_stage))
        self.res_connect = stride == 1 and C_in == C_out

    def forward(self, x):
        out = super(CascadeConv3x3, self).forward(x)
        if not self.res_connect:
            return out
        out += x
        return out
class Shift(nn.Module):
    """Parameter-free spatial shift built as a fixed depthwise convolution.

    The ``C`` channels are split into ``kernel_size**2`` groups. Each group
    gets a one-hot ``kernel_size x kernel_size`` filter, so every group reads
    the input at a different offset within the window. The centre position
    also receives the ``C % kernel_size**2`` leftover channels. The kernel
    is stored as a non-trainable parameter.

    Args:
        C: number of input (and output) channels.
        kernel_size: side length of the shift window.
        stride: 1 or 2.
        padding: zero padding applied on each spatial side.
    """

    def __init__(self, C, kernel_size, stride, padding):
        super(Shift, self).__init__()
        self.C = C
        kernel = torch.zeros((C, 1, kernel_size, kernel_size), dtype=torch.float32)
        ch_idx = 0
        assert stride in [1, 2]
        self.stride = stride
        self.padding = padding
        self.kernel_size = kernel_size
        self.dilation = 1
        hks = kernel_size // 2
        ksq = kernel_size ** 2
        for i in range(kernel_size):
            for j in range(kernel_size):
                num_ch = C // ksq
                if i == hks and j == hks:
                    # Leftover channels (C not divisible by ksq) go to the centre tap.
                    num_ch += C % ksq
                kernel[ch_idx : ch_idx + num_ch, 0, i, j] = 1
                ch_idx += num_ch
        self.register_parameter("bias", None)
        self.kernel = nn.Parameter(kernel, requires_grad=False)

    def _output_spatial_shape(self, spatial):
        """Return ``[H_out, W_out]`` for input spatial dims ``spatial``.

        This uses the standard convolution output-size formula with this
        module's padding, dilation, kernel size and stride, applied equally to
        both spatial dimensions.
        """
        pad, dil, k, s = self.padding, self.dilation, self.kernel_size, self.stride
        return [(size + 2 * pad - (dil * (k - 1) + 1)) // s + 1 for size in spatial]

    def forward(self, x):
        if x.numel() > 0:
            return nn.functional.conv2d(
                x,
                self.kernel,
                self.bias,
                (self.stride, self.stride),
                (self.padding, self.padding),
                self.dilation,
                self.C,  # groups: one filter per channel (depthwise)
            )
        # Empty input: build an empty output of the shape the conv would produce.
        # Padding must be applied to BOTH spatial dims. Previously the W dim used
        # the dilation value in its place.
        output_shape = [x.shape[0], self.C] + self._output_spatial_shape(x.shape[-2:])
        return _NewEmptyTensorOp.apply(x, output_shape)
class ShiftBlock5x5(nn.Sequential):
    """Pointwise -> 5x5 Shift -> pointwise-linear block with optional residual.

    Layout: 1x1 conv (C_in -> C_mid) -> BN -> ReLU -> Shift(5x5, ``stride``,
    padding 2) -> 1x1 conv (C_mid -> C_out) -> BN. Here C_mid is
    ``C_in * expansion`` rounded by ``_get_divisible_by(..., 8, 8)``. The input
    is added back when ``stride == 1`` and ``C_in == C_out``.
    """

    def __init__(self, C_in, C_out, expansion, stride):
        assert stride in [1, 2]
        hidden = _get_divisible_by(C_in * expansion, 8, 8)
        expand = [
            Conv2d(C_in, hidden, 1, 1, 0, bias=False),
            BatchNorm2d(hidden),
            nn.ReLU(inplace=True),
        ]
        shift = Shift(hidden, 5, stride, 2)
        project = [
            Conv2d(hidden, C_out, 1, 1, 0, bias=False),
            BatchNorm2d(C_out),
        ]
        super(ShiftBlock5x5, self).__init__(*(expand + [shift] + project))
        # Assigned after nn.Module initialisation rather than before it.
        self.res_connect = stride == 1 and C_in == C_out

    def forward(self, x):
        out = super(ShiftBlock5x5, self).forward(x)
        if not self.res_connect:
            return out
        out += x
        return out
class ChannelShuffle(nn.Module):
    """ShuffleNet-style channel shuffle that interleaves channels across ``groups``."""

    def __init__(self, groups):
        super(ChannelShuffle, self).__init__()
        self.groups = groups

    def forward(self, x):
        """Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,W] -> [N,C,H,W]"""
        N, C, H, W = x.size()
        g = self.groups
        assert C % g == 0, "Incompatible group size {} for input channel {}".format(g, C)
        # C is a multiple of g (asserted above), so use exact integer division
        # instead of a float division followed by int().
        return x.view(N, g, C // g, H, W).permute(0, 2, 1, 3, 4).contiguous().view(N, C, H, W)
class ConvBNRelu(nn.Sequential):
    """Conv2d followed by an optional normalization layer and an optional ReLU.

    Submodules are registered in order: "conv", then "bn" (unless ``bn_type``
    is None), then "relu" (only when ``use_relu == "relu"``). Conv weights
    get Kaiming-normal init (``fan_out``, relu), and the conv bias, if any,
    is zeroed.

    Args:
        input_depth: input channel count.
        output_depth: output channel count.
        kernel: convolution kernel size.
        stride: convolution stride; one of 1, 2, 4.
        pad: convolution padding.
        no_bias: truthy to build the convolution without a bias term.
        use_relu: "relu" to append an in-place ReLU, or None.
        bn_type: "bn" (BatchNorm2d), "sbn" (SyncBatchNorm), "nsbn"
            (NaiveSyncBatchNorm2d), "af" (FrozenBatchNorm2d),
            ("gn", num_groups) for GroupNorm, or None for no normalization.
        group: number of convolution groups.
        *args, **kwargs: forwarded to ``Conv2d``.

    Raises:
        ValueError: if ``bn_type`` requests GroupNorm without a group count.
    """

    def __init__(
        self, input_depth, output_depth, kernel, stride, pad, no_bias, use_relu, bn_type, group=1, *args, **kwargs
    ):
        super(ConvBNRelu, self).__init__()
        assert use_relu in ["relu", None]
        gn_group = None
        if isinstance(bn_type, (list, tuple)):
            # The ("gn", num_groups) form is the only way to supply a GroupNorm group count.
            assert len(bn_type) == 2
            assert bn_type[0] == "gn"
            gn_group = bn_type[1]
            bn_type = bn_type[0]
        assert bn_type in ["bn", "nsbn", "sbn", "af", "gn", None]
        if bn_type == "gn" and gn_group is None:
            # GroupNorm cannot be built without a group count.
            raise ValueError('bn_type "gn" requires a group count: pass ("gn", num_groups)')
        assert stride in [1, 2, 4]
        # NOTE(review): extra *args are passed positionally after the in/out
        # channel counts, which presumably collides with kernel_size — confirm
        # whether any caller actually passes *args.
        op = Conv2d(
            input_depth,
            output_depth,
            kernel_size=kernel,
            stride=stride,
            padding=pad,
            bias=not no_bias,
            groups=group,
            *args,
            **kwargs
        )
        nn.init.kaiming_normal_(op.weight, mode="fan_out", nonlinearity="relu")
        if op.bias is not None:
            nn.init.constant_(op.bias, 0.0)
        self.add_module("conv", op)

        if bn_type == "bn":
            bn_op = BatchNorm2d(output_depth)
        elif bn_type == "sbn":
            bn_op = SyncBatchNorm(output_depth)
        elif bn_type == "nsbn":
            bn_op = NaiveSyncBatchNorm2d(output_depth)
        elif bn_type == "gn":
            bn_op = nn.GroupNorm(num_groups=gn_group, num_channels=output_depth)
        elif bn_type == "af":
            bn_op = FrozenBatchNorm2d(output_depth)
        if bn_type is not None:
            self.add_module("bn", bn_op)

        if use_relu == "relu":
            self.add_module("relu", nn.ReLU(inplace=True))
class SEModule(nn.Module):
    """Squeeze-and-excitation gate that rescales each channel of the input.

    Pipeline: global average pool -> 1x1 conv (C -> hidden) -> ReLU ->
    1x1 conv (hidden -> C) -> sigmoid. The input is multiplied by the
    resulting per-channel gate. ``hidden`` is ``C // reduction`` with a
    floor of 8.
    """

    reduction = 4

    def __init__(self, C):
        super(SEModule, self).__init__()
        hidden = C // self.reduction
        if hidden < 8:
            hidden = 8
        squeeze = Conv2d(C, hidden, 1, 1, 0)
        excite = Conv2d(hidden, C, 1, 1, 0)
        layers = [nn.AdaptiveAvgPool2d(1), squeeze, nn.ReLU(inplace=True), excite, nn.Sigmoid()]
        self.op = nn.Sequential(*layers)

    def forward(self, x):
        gate = self.op(x)
        return x * gate
class Upsample(nn.Module):
    """Thin module that calls ``interpolate`` with fixed settings.

    The scale factor, interpolation mode and ``align_corners`` flag are fixed
    at construction time.
    """

    def __init__(self, scale_factor, mode, align_corners=None):
        super(Upsample, self).__init__()
        self.scale = scale_factor
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x):
        options = dict(scale_factor=self.scale, mode=self.mode, align_corners=self.align_corners)
        return interpolate(x, **options)
def _get_upsample_op(stride):
assert (
stride in [1, 2, 4]
or stride in [-1, -2, -4]
or (isinstance(stride, tuple) and all(x in [-1, -2, -4] for x in stride))
)
scales = stride
ret = None
if isinstance(stride, tuple) or stride < 0:
scales = [-x for x in stride] if isinstance(stride, tuple) else -stride
stride = 1
ret = Upsample(scale_factor=scales, mode="nearest", align_corners=None)
return ret, stride
class IRFBlock(nn.Module):
    """Inverted-residual block: pointwise expand, depthwise, pointwise project.

    Forward pipeline:
    1. 1x1 pointwise ``ConvBNRelu`` (input_depth -> mid_depth).
    2. Channel shuffle, only when ``shuffle_type == "mid"``.
    3. Nearest upsampling when ``stride`` is negative.
    4. Depthwise ``kernel x kernel`` conv, or two cascaded ones when ``cdw``.
       This stage is empty when ``kernel == 1``.
    5. 1x1 linear projection (mid_depth -> output_depth, no ReLU).
    6. Residual add when ``stride == 1`` and ``input_depth == output_depth``.
    7. Optional ``SEModule``.

    Args:
        input_depth: input channel count.
        output_depth: output channel count.
        expansion: multiplier for the hidden width. The width is
            ``int(input_depth * expansion)`` rounded by ``_get_divisible_by``
            with ``width_divisor``.
        stride: 1, 2 or 4 for the depthwise conv. Negative values (or a tuple
            of negatives) upsample by ``-stride`` instead.
        bn_type: normalization selector forwarded to every ``ConvBNRelu``.
        kernel: depthwise kernel size; one of 1, 3, 5, 7.
        width_divisor: rounding granularity for the hidden width.
        shuffle_type: "mid" to shuffle channels after the pointwise conv.
            A ``ChannelShuffle`` is built for any non-None value, but only
            "mid" is used in ``forward``.
        pw_group: group count for both 1x1 convs and the channel shuffle.
        se: append an ``SEModule`` after the residual add.
        cdw: use two cascaded depthwise convs instead of one.
        dw_skip_bn: drop BN from the last depthwise conv.
        dw_skip_relu: drop ReLU from the last depthwise conv.
    """

    def __init__(
        self,
        input_depth,
        output_depth,
        expansion,
        stride,
        bn_type="bn",
        kernel=3,
        width_divisor=1,
        shuffle_type=None,
        pw_group=1,
        se=False,
        cdw=False,
        dw_skip_bn=False,
        dw_skip_relu=False,
    ):
        super(IRFBlock, self).__init__()
        assert kernel in [1, 3, 5, 7], kernel
        self.output_depth = output_depth
        # Decided from the caller's stride, before a negative (upsampling)
        # stride is rewritten to 1 below.
        self.use_res_connect = stride == 1 and input_depth == output_depth

        mid_depth = _get_divisible_by(int(input_depth * expansion), width_divisor, width_divisor)

        # Submodules are created in the same order as before (pw, upscale, dw,
        # pwl, shuffle, se4). This keeps weight-init RNG consumption and module
        # registration order unchanged.
        self.pw = ConvBNRelu(
            input_depth,
            mid_depth,
            kernel=1,
            stride=1,
            pad=0,
            no_bias=1,
            use_relu="relu",
            bn_type=bn_type,
            group=pw_group,
        )
        # A negative stride means "upsample"; the stride returned here drives the depthwise conv.
        self.upscale, stride = _get_upsample_op(stride)
        self.dw = self._build_depthwise(mid_depth, kernel, stride, bn_type, cdw, dw_skip_bn, dw_skip_relu)
        self.pwl = ConvBNRelu(
            mid_depth,
            output_depth,
            kernel=1,
            stride=1,
            pad=0,
            no_bias=1,
            use_relu=None,
            bn_type=bn_type,
            group=pw_group,
        )
        self.shuffle_type = shuffle_type
        if shuffle_type is not None:
            self.shuffle = ChannelShuffle(pw_group)
        self.se4 = SEModule(output_depth) if se else nn.Sequential()

    @staticmethod
    def _build_depthwise(channels, kernel, stride, bn_type, cdw, dw_skip_bn, dw_skip_relu):
        """Build the depthwise stage: empty for kernel 1, else one or two depthwise ConvBNRelu."""
        if kernel == 1:
            # NOTE(review): no conv is built here, so a positive stride > 1 is
            # silently not applied anywhere in the block.
            return nn.Sequential()

        def depthwise(conv_stride, relu, norm):
            return ConvBNRelu(
                channels,
                channels,
                kernel=kernel,
                stride=conv_stride,
                pad=(kernel // 2),
                group=channels,
                no_bias=1,
                use_relu=relu,
                bn_type=norm,
            )

        last_relu = None if dw_skip_relu else "relu"
        last_norm = None if dw_skip_bn else bn_type
        if not cdw:
            return depthwise(stride, last_relu, last_norm)
        # Cascaded variant: the first conv carries the stride and always keeps BN + ReLU.
        first = depthwise(stride, "relu", bn_type)
        second = depthwise(1, last_relu, last_norm)
        return nn.Sequential(OrderedDict([("dw1", first), ("dw2", second)]))

    def forward(self, x):
        out = self.pw(x)
        if self.shuffle_type == "mid":
            out = self.shuffle(out)
        if self.upscale is not None:
            out = self.upscale(out)
        out = self.pwl(self.dw(out))
        if self.use_res_connect:
            out += x
        return self.se4(out)
def skip(C_in, C_out, stride, **kwargs):
    """Build an ``Identity`` block; extra keyword arguments are ignored."""
    return Identity(C_in, C_out, stride)


def basic_block(C_in, C_out, stride, **kwargs):
    """Build a ``CascadeConv3x3`` block; extra keyword arguments are ignored."""
    return CascadeConv3x3(C_in, C_out, stride)
# layer search 2
# ir_k{K}_e{E}: IRFBlock with depthwise kernel K and expansion E.
# ir_k{K}_s4: expansion 4, 4-group pointwise convs and a "mid" channel shuffle.
# Extra keyword arguments are forwarded to IRFBlock.
def ir_k3_e1(C_in, C_out, stride, **kwargs):
    return IRFBlock(C_in, C_out, 1, stride, kernel=3, **kwargs)


def ir_k3_e3(C_in, C_out, stride, **kwargs):
    return IRFBlock(C_in, C_out, 3, stride, kernel=3, **kwargs)


def ir_k3_e6(C_in, C_out, stride, **kwargs):
    return IRFBlock(C_in, C_out, 6, stride, kernel=3, **kwargs)


def ir_k3_s4(C_in, C_out, stride, **kwargs):
    return IRFBlock(C_in, C_out, 4, stride, kernel=3, shuffle_type="mid", pw_group=4, **kwargs)


def ir_k5_e1(C_in, C_out, stride, **kwargs):
    return IRFBlock(C_in, C_out, 1, stride, kernel=5, **kwargs)


def ir_k5_e3(C_in, C_out, stride, **kwargs):
    return IRFBlock(C_in, C_out, 3, stride, kernel=5, **kwargs)


def ir_k5_e6(C_in, C_out, stride, **kwargs):
    return IRFBlock(C_in, C_out, 6, stride, kernel=5, **kwargs)


def ir_k5_s4(C_in, C_out, stride, **kwargs):
    return IRFBlock(C_in, C_out, 4, stride, kernel=5, shuffle_type="mid", pw_group=4, **kwargs)
# layer search se
# Same as layer search 2, with an SEModule appended to each block (se=True).
# Extra keyword arguments are forwarded to IRFBlock.
def ir_k3_e1_se(C_in, C_out, stride, **kwargs):
    return IRFBlock(C_in, C_out, 1, stride, kernel=3, se=True, **kwargs)


def ir_k3_e3_se(C_in, C_out, stride, **kwargs):
    return IRFBlock(C_in, C_out, 3, stride, kernel=3, se=True, **kwargs)


def ir_k3_e6_se(C_in, C_out, stride, **kwargs):
    return IRFBlock(C_in, C_out, 6, stride, kernel=3, se=True, **kwargs)


def ir_k3_s4_se(C_in, C_out, stride, **kwargs):
    # shuffle_type must be the string "mid"; a bare `mid` is an undefined name
    # and raised NameError whenever this factory was called.
    return IRFBlock(C_in, C_out, 4, stride, kernel=3, shuffle_type="mid", pw_group=4, se=True, **kwargs)


def ir_k5_e1_se(C_in, C_out, stride, **kwargs):
    return IRFBlock(C_in, C_out, 1, stride, kernel=5, se=True, **kwargs)


def ir_k5_e3_se(C_in, C_out, stride, **kwargs):
    return IRFBlock(C_in, C_out, 3, stride, kernel=5, se=True, **kwargs)


def ir_k5_e6_se(C_in, C_out, stride, **kwargs):
    return IRFBlock(C_in, C_out, 6, stride, kernel=5, se=True, **kwargs)


def ir_k5_s4_se(C_in, C_out, stride, **kwargs):
    return IRFBlock(C_in, C_out, 4, stride, kernel=5, shuffle_type="mid", pw_group=4, se=True, **kwargs)
# layer search 3 (in addition to layer search 2)
# ir_k{K}_s2[_se]: expansion 1, 2-group pointwise convs, "mid" channel shuffle,
# depthwise kernel K, and optionally an SEModule. Extra keyword arguments are
# forwarded to IRFBlock.
def ir_k3_s2(C_in, C_out, stride, **kwargs):
    return IRFBlock(C_in, C_out, 1, stride, kernel=3, shuffle_type="mid", pw_group=2, **kwargs)


def ir_k5_s2(C_in, C_out, stride, **kwargs):
    return IRFBlock(C_in, C_out, 1, stride, kernel=5, shuffle_type="mid", pw_group=2, **kwargs)


def ir_k3_s2_se(C_in, C_out, stride, **kwargs):
    return IRFBlock(C_in, C_out, 1, stride, kernel=3, shuffle_type="mid", pw_group=2, se=True, **kwargs)


def ir_k5_s2_se(C_in, C_out, stride, **kwargs):
    return IRFBlock(C_in, C_out, 1, stride, kernel=5, shuffle_type="mid", pw_group=2, se=True, **kwargs)
# layer search 4 (in addition to layer search 3)
# ir_k33_e{E}: two cascaded 3x3 depthwise convs (cdw=True) with expansion E.
# Extra keyword arguments are forwarded to IRFBlock.
def ir_k33_e1(C_in, C_out, stride, **kwargs):
    return IRFBlock(C_in, C_out, 1, stride, kernel=3, cdw=True, **kwargs)


def ir_k33_e3(C_in, C_out, stride, **kwargs):
    return IRFBlock(C_in, C_out, 3, stride, kernel=3, cdw=True, **kwargs)


def ir_k33_e6(C_in, C_out, stride, **kwargs):
    return IRFBlock(C_in, C_out, 6, stride, kernel=3, cdw=True, **kwargs)
# layer search 5 (in addition to layer search 4)
# ir_k7_e{E}: single 7x7 depthwise conv with expansion E.
# ir_k7_sep_e{E}: two cascaded 7x7 depthwise convs (cdw=True) with expansion E.
# Extra keyword arguments are forwarded to IRFBlock.
def ir_k7_e1(C_in, C_out, stride, **kwargs):
    return IRFBlock(C_in, C_out, 1, stride, kernel=7, **kwargs)


def ir_k7_e3(C_in, C_out, stride, **kwargs):
    return IRFBlock(C_in, C_out, 3, stride, kernel=7, **kwargs)


def ir_k7_e6(C_in, C_out, stride, **kwargs):
    return IRFBlock(C_in, C_out, 6, stride, kernel=7, **kwargs)


def ir_k7_sep_e1(C_in, C_out, stride, **kwargs):
    return IRFBlock(C_in, C_out, 1, stride, kernel=7, cdw=True, **kwargs)


def ir_k7_sep_e3(C_in, C_out, stride, **kwargs):
    return IRFBlock(C_in, C_out, 3, stride, kernel=7, cdw=True, **kwargs)


def ir_k7_sep_e6(C_in, C_out, stride, **kwargs):
    return IRFBlock(C_in, C_out, 6, stride, kernel=7, cdw=True, **kwargs)