zdou0830's picture
desco
749745d
raw
history blame
No virus
17.6 kB
import torch
import torch.nn as nn
import torch.nn.functional as F
from .deform_conv import DeformConv2d
def add_conv(in_ch, out_ch, ksize, stride, leaky=True):
    """
    Build a conv2d -> batchnorm -> activation block.

    The convolution uses ``(ksize - 1) // 2`` zero padding and no bias
    (a BatchNorm follows it).

    Args:
        in_ch (int): number of input channels of the convolution.
        out_ch (int): number of output channels of the convolution / BatchNorm.
        ksize (int): square kernel size of the convolution.
        stride (int): stride of the convolution.
        leaky (bool): if True the activation is ``LeakyReLU(0.1)`` registered
            as ``"leaky"``; otherwise an in-place ``ReLU6`` registered as
            ``"relu6"``.
    Returns:
        nn.Sequential: children ``"conv"``, ``"batch_norm"`` and the activation.
    """
    block = nn.Sequential()
    conv = nn.Conv2d(in_ch, out_ch, kernel_size=ksize, stride=stride, padding=(ksize - 1) // 2, bias=False)
    block.add_module("conv", conv)
    block.add_module("batch_norm", nn.BatchNorm2d(out_ch))
    act_name, act = ("leaky", nn.LeakyReLU(0.1)) if leaky else ("relu6", nn.ReLU6(inplace=True))
    block.add_module(act_name, act)
    return block
class upsample(nn.Module):
    """
    Module wrapper around ``F.interpolate`` with fixed resize settings.

    Args:
        size: target output size forwarded to ``F.interpolate``.
        scale_factor: spatial multiplier forwarded to ``F.interpolate``.
        mode (str): interpolation mode, ``"nearest"`` by default.
        align_corners: forwarded unchanged to ``F.interpolate``.
    """

    __constants__ = ["size", "scale_factor", "mode", "align_corners", "name"]

    def __init__(self, size=None, scale_factor=None, mode="nearest", align_corners=None):
        super().__init__()
        self.name = type(self).__name__
        self.size, self.scale_factor = size, scale_factor
        self.mode, self.align_corners = mode, align_corners

    def forward(self, input):
        return F.interpolate(
            input,
            size=self.size,
            scale_factor=self.scale_factor,
            mode=self.mode,
            align_corners=self.align_corners,
        )

    def extra_repr(self):
        # Report whichever of scale_factor / size is driving the resize.
        if self.scale_factor is None:
            setting = "size=" + str(self.size)
        else:
            setting = "scale_factor=" + str(self.scale_factor)
        return setting + ", mode=" + self.mode
class SPPLayer(nn.Module):
    """
    Spatial pyramid pooling without resolution change.

    Concatenates the input with stride-1 max pools of kernel 5, 9 and 13
    (each padded by ``kernel // 2`` so the spatial size is unchanged) along
    dim 1 (the channel axis for NCHW input), giving 4x the input channels.
    """

    _POOL_KERNELS = (5, 9, 13)

    def __init__(self):
        super().__init__()

    def forward(self, x):
        pooled = [F.max_pool2d(x, k, stride=1, padding=k // 2) for k in self._POOL_KERNELS]
        return torch.cat([x, *pooled], dim=1)
class DropBlock(nn.Module):
    """
    DropBlock regularization: zero out contiguous square regions of a feature map.

    During training, seed positions are sampled with probability ``gamma``,
    grown into ``block_size`` x ``block_size`` squares by a stride-1 max pool,
    and the surviving activations are rescaled by ``numel / kept`` so the
    overall magnitude is preserved. In eval mode, or when ``keep_prob == 1``,
    the input is returned unchanged.

    Args:
        block_size (int): side length of each dropped square.
        keep_prob (float): target fraction of activations to keep.
    """

    def __init__(self, block_size=7, keep_prob=0.9):
        super(DropBlock, self).__init__()
        self.reset(block_size, keep_prob)

    def reset(self, block_size, keep_prob):
        """Set a new block size / keep probability and clear the stored gamma."""
        self.block_size = block_size
        self.keep_prob = keep_prob
        self.gamma = None
        self.kernel_size = (block_size, block_size)
        self.stride = (1, 1)
        self.padding = (block_size // 2, block_size // 2)

    def calculate_gamma(self, x):
        """
        Return the seed-drop probability for an input shaped like ``x``.

        The last two dimensions are used as height and width, so non-square
        maps are handled. For square maps this equals
        ``(1 - keep_prob) * W**2 / (block_size**2 * (W - block_size + 1)**2)``.
        The number of valid block centres per axis is clamped to at least 1,
        so maps no larger than the block neither divide by zero nor use a
        negative count.
        """
        height, width = x.shape[-2], x.shape[-1]
        valid_h = max(height - self.block_size + 1, 1)
        valid_w = max(width - self.block_size + 1, 1)
        return (1 - self.keep_prob) * (height * width) / (self.block_size**2 * (valid_h * valid_w))

    def forward(self, x):
        if not self.training or self.keep_prob == 1:  # set keep_prob=1 to turn off dropblock
            return x
        # gamma depends on the spatial size, which can change between calls
        # (e.g. multi-scale training), so it is refreshed on every call; the
        # latest value is still exposed as self.gamma.
        self.gamma = self.calculate_gamma(x)
        # Sample in float32: low-precision inputs are converted and cast back.
        orig_dtype = x.dtype
        low_precision = orig_dtype in (torch.float16, torch.bfloat16)
        if low_precision:
            x = x.float()
        p = torch.ones_like(x) * self.gamma
        block_mask = F.max_pool2d(torch.bernoulli(p), self.kernel_size, self.stride, self.padding)
        # With an even block_size, symmetric padding of block_size // 2 makes
        # the pooled map one row/column larger than x; crop it back.
        block_mask = block_mask[..., : x.shape[-2], : x.shape[-1]]
        mask = 1 - block_mask
        # Rescale by numel / kept; the clamp avoids 0 * inf = NaN when every
        # unit was dropped (the output is then all zeros).
        out = mask * x * (mask.numel() / mask.sum().clamp(min=1))
        if low_precision:
            out = out.to(orig_dtype)
        return out
class resblock(nn.Module):
    """
    Stack of residual bottleneck units, each made of two ``add_conv`` blocks:
    a 1x1 conv from ``ch`` to ``ch // 2`` followed by a 3x3 conv back to ``ch``.

    Args:
        ch (int): number of input and output channels.
        nblocks (int): number of residual units.
        shortcut (bool): if True each unit returns ``x + f(x)``; otherwise ``f(x)``.
    """

    def __init__(self, ch, nblocks=1, shortcut=True):
        super().__init__()
        self.shortcut = shortcut
        bottleneck = ch // 2
        self.module_list = nn.ModuleList(
            nn.ModuleList([add_conv(ch, bottleneck, 1, 1), add_conv(bottleneck, ch, 3, 1)])
            for _ in range(nblocks)
        )

    def forward(self, x):
        out = x
        for convs in self.module_list:
            branch = out
            for conv in convs:
                branch = conv(branch)
            out = (out + branch) if self.shortcut else branch
        return out
class RFBblock(nn.Module):
    """
    Receptive-field block: four parallel conv branches concatenated on dim 1.

    Each branch starts with a 1x1 conv to ``in_ch // 4`` channels:
      * branch_0: 1x1 only
      * branch_1: + 3x3
      * branch_2: + 3x3, then 3x3 with dilation 2
      * branch_3: + 5x5, then 3x3 with dilation 3
    Padding keeps the spatial size. No activations are placed between convs.

    Args:
        in_ch (int): input channels; the output has ``4 * (in_ch // 4)`` channels.
        residual (bool): if True the input is added to the concatenated output
            (requires ``in_ch`` divisible by 4 for the shapes to match).
    """

    def __init__(self, in_ch, residual=False):
        super().__init__()
        quarter = in_ch // 4

        def conv(cin, k, pad, dil=1):
            return nn.Conv2d(in_channels=cin, out_channels=quarter, kernel_size=k, stride=1, padding=pad, dilation=dil)

        self.branch_0 = nn.Sequential(conv(in_ch, 1, 0))
        self.branch_1 = nn.Sequential(conv(in_ch, 1, 0), conv(quarter, 3, 1))
        self.branch_2 = nn.Sequential(conv(in_ch, 1, 0), conv(quarter, 3, 1), conv(quarter, 3, 2, dil=2))
        self.branch_3 = nn.Sequential(conv(in_ch, 1, 0), conv(quarter, 5, 2), conv(quarter, 3, 3, dil=3))
        self.residual = residual

    def forward(self, x):
        branches = (self.branch_0, self.branch_1, self.branch_2, self.branch_3)
        out = torch.cat([branch(x) for branch in branches], 1)
        return out + x if self.residual else out
class FeatureAdaption(nn.Module):
    """
    Deformable-conv feature adaption driven by a per-anchor prediction map.

    A grouped 1x1 conv turns ``wh_pred`` (``2 * n_anchors`` channels,
    presumably a (w, h) pair per anchor — confirm with caller) into
    ``2 * 9 * n_anchors`` offsets, i.e. an x/y offset for each of the 9
    sampling points of the 3x3 deformable conv per anchor group.

    Args:
        in_ch (int): channels of ``input``.
        out_ch (int): output channels of the deformable conv.
        n_anchors (int): number of anchors (offset groups).
        rfb (bool): if True an ``RFBblock(out_ch)`` is applied to ``input``
            before the deformable conv.
        sep (bool): if True ``forward`` returns ``input`` unchanged (the
            offset/deformable layers are still constructed).
    """

    def __init__(self, in_ch, out_ch, n_anchors, rfb=False, sep=False):
        super().__init__()
        self.sep = bool(sep)
        self.conv_offset = nn.Conv2d(
            in_channels=2 * n_anchors,
            out_channels=2 * 9 * n_anchors,
            groups=n_anchors,
            kernel_size=1,
            stride=1,
            padding=0,
        )
        self.dconv = DeformConv2d(
            in_channels=in_ch, out_channels=out_ch, kernel_size=3, stride=1, padding=1, deformable_groups=n_anchors
        )
        # NOTE(review): RFBblock(out_ch) runs on `input`, which has in_ch
        # channels — this only works when in_ch == out_ch; confirm at call sites.
        self.rfb = RFBblock(out_ch) if rfb else None

    def forward(self, input, wh_pred):
        # For mobilenet (sep=True) rfb and feature adaption are not supported.
        if self.sep:
            return input
        features = input if self.rfb is None else self.rfb(input)
        # detach(): no gradient flows from the offsets back into wh_pred.
        offsets = self.conv_offset(wh_pred.detach())
        return self.dconv(features, offsets)
class ASFFmobile(nn.Module):
    """
    Adaptive spatial feature fusion using ReLU6 conv blocks.

    Brings three pyramid inputs to the resolution and channel width of
    ``level`` (widths ``[512, 256, 128]``), blends them with per-pixel softmax
    weights and expands the blend with a 3x3 conv to 1024 / 512 / 256 channels
    for level 0 / 1 / 2.

    Args:
        level (int): target pyramid level (0, 1 or 2).
        rfb (bool): if True the per-level weight branches use 8 channels
            instead of 16 (to save memory).
        vis (bool): if True ``forward`` also returns the softmax weights and
            the channel-summed fused map.

    Assumed inputs (inferred from the conv widths and resize factors — confirm
    with the caller): x_level_0 / x_level_1 / x_level_2 carry 512 / 256 / 128
    channels, and each level has twice the spatial size of the previous one.
    """

    def __init__(self, level, rfb=False, vis=False):
        super().__init__()
        self.level = level
        self.dim = [512, 256, 128]
        self.inter_dim = self.dim[self.level]
        width = self.inter_dim

        def block(cin, cout, k, s):
            # Every conv block in this head uses the ReLU6 variant.
            return add_conv(cin, cout, k, s, leaky=False)

        if level == 0:
            self.stride_level_1 = block(256, width, 3, 2)
            self.stride_level_2 = block(128, width, 3, 2)
            self.expand = block(width, 1024, 3, 1)
        elif level == 1:
            self.compress_level_0 = block(512, width, 1, 1)
            self.stride_level_2 = block(128, width, 3, 2)
            self.expand = block(width, 512, 3, 1)
        elif level == 2:
            self.compress_level_0 = block(512, width, 1, 1)
            self.compress_level_1 = block(256, width, 1, 1)
            self.expand = block(width, 256, 3, 1)

        compress_c = 8 if rfb else 16  # rfb: half the weight channels to save memory
        self.weight_level_0 = block(width, compress_c, 1, 1)
        self.weight_level_1 = block(width, compress_c, 1, 1)
        self.weight_level_2 = block(width, compress_c, 1, 1)
        self.weight_levels = nn.Conv2d(compress_c * 3, 3, kernel_size=1, stride=1, padding=0)
        self.vis = vis

    def forward(self, x_level_0, x_level_1, x_level_2):
        if self.level == 0:
            r0 = x_level_0
            r1 = self.stride_level_1(x_level_1)
            # Level 2 is 4x finer: max-pool by 2, then a stride-2 conv.
            r2 = self.stride_level_2(F.max_pool2d(x_level_2, 3, stride=2, padding=1))
        elif self.level == 1:
            r0 = F.interpolate(self.compress_level_0(x_level_0), scale_factor=2, mode="nearest")
            r1 = x_level_1
            r2 = self.stride_level_2(x_level_2)
        elif self.level == 2:
            r0 = F.interpolate(self.compress_level_0(x_level_0), scale_factor=4, mode="nearest")
            r1 = F.interpolate(self.compress_level_1(x_level_1), scale_factor=2, mode="nearest")
            r2 = x_level_2

        logits = self.weight_levels(
            torch.cat((self.weight_level_0(r0), self.weight_level_1(r1), self.weight_level_2(r2)), 1)
        )
        weights = F.softmax(logits, dim=1)  # one weight map per level, summing to 1 per pixel
        fused = r0 * weights[:, 0:1] + r1 * weights[:, 1:2] + r2 * weights[:, 2:]
        out = self.expand(fused)
        return (out, weights, fused.sum(dim=1)) if self.vis else out
class ASFF(nn.Module):
    """
    Adaptive spatial feature fusion using LeakyReLU conv blocks.

    Brings three pyramid inputs to the resolution and channel width of
    ``level`` (widths ``[512, 256, 256]``), blends them with per-pixel softmax
    weights and expands the blend with a 3x3 conv to 1024 / 512 / 256 channels
    for level 0 / 1 / 2.

    Args:
        level (int): target pyramid level (0, 1 or 2).
        rfb (bool): if True the per-level weight branches use 8 channels
            instead of 16 (to save memory).
        vis (bool): if True ``forward`` also returns the softmax weights and
            the channel-summed fused map.

    Assumed inputs (inferred from the conv widths and resize factors — confirm
    with the caller): x_level_0 / x_level_1 / x_level_2 carry 512 / 256 / 256
    channels, and each level has twice the spatial size of the previous one.
    """

    def __init__(self, level, rfb=False, vis=False):
        super().__init__()
        self.level = level
        self.dim = [512, 256, 256]
        self.inter_dim = self.dim[self.level]
        width = self.inter_dim

        if level == 0:
            self.stride_level_1 = add_conv(256, width, 3, 2)
            self.stride_level_2 = add_conv(256, width, 3, 2)
            self.expand = add_conv(width, 1024, 3, 1)
        elif level == 1:
            self.compress_level_0 = add_conv(512, width, 1, 1)
            self.stride_level_2 = add_conv(256, width, 3, 2)
            self.expand = add_conv(width, 512, 3, 1)
        elif level == 2:
            self.compress_level_0 = add_conv(512, width, 1, 1)
            self.expand = add_conv(width, 256, 3, 1)

        compress_c = 8 if rfb else 16  # rfb: half the weight channels to save memory
        self.weight_level_0 = add_conv(width, compress_c, 1, 1)
        self.weight_level_1 = add_conv(width, compress_c, 1, 1)
        self.weight_level_2 = add_conv(width, compress_c, 1, 1)
        self.weight_levels = nn.Conv2d(compress_c * 3, 3, kernel_size=1, stride=1, padding=0)
        self.vis = vis

    def forward(self, x_level_0, x_level_1, x_level_2):
        if self.level == 0:
            r0 = x_level_0
            r1 = self.stride_level_1(x_level_1)
            # Level 2 is 4x finer: max-pool by 2, then a stride-2 conv.
            r2 = self.stride_level_2(F.max_pool2d(x_level_2, 3, stride=2, padding=1))
        elif self.level == 1:
            r0 = F.interpolate(self.compress_level_0(x_level_0), scale_factor=2, mode="nearest")
            r1 = x_level_1
            r2 = self.stride_level_2(x_level_2)
        elif self.level == 2:
            r0 = F.interpolate(self.compress_level_0(x_level_0), scale_factor=4, mode="nearest")
            # Level 1 already has the target width (256), so it is only upsampled.
            r1 = F.interpolate(x_level_1, scale_factor=2, mode="nearest")
            r2 = x_level_2

        logits = self.weight_levels(
            torch.cat((self.weight_level_0(r0), self.weight_level_1(r1), self.weight_level_2(r2)), 1)
        )
        weights = F.softmax(logits, dim=1)  # one weight map per level, summing to 1 per pixel
        fused = r0 * weights[:, 0:1] + r1 * weights[:, 1:2] + r2 * weights[:, 2:]
        out = self.expand(fused)
        return (out, weights, fused.sum(dim=1)) if self.vis else out
def make_divisible(v, divisor, min_value=None):
    """
    Round ``v`` to a nearby multiple of ``divisor``.

    Adapted from the TensorFlow MobileNet reference implementation:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py

    For positive ``v`` this rounds to the nearest multiple (ties round up).
    The result is never below ``min_value``, and if rounding down lost more
    than 10% of ``v``, one extra ``divisor`` is added.

    :param v: value to round (typically a channel count).
    :param divisor: the rounding granularity.
    :param min_value: lower bound on the result; ``None`` means ``divisor``.
    :return: the rounded value.
    """
    lower_bound = divisor if min_value is None else min_value
    nearest = (int(v + divisor / 2) // divisor) * divisor
    result = max(lower_bound, nearest)
    if result < 0.9 * v:  # never shrink by more than 10%
        result += divisor
    return result
class ConvBNReLU(nn.Sequential):
    """
    Conv2d (no bias, ``(kernel_size - 1) // 2`` padding) -> BatchNorm2d -> ReLU6.

    Children are registered positionally as ``0``, ``1``, ``2``. ``groups`` is
    forwarded to the convolution (``groups == in_planes`` gives a depthwise conv).
    """

    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
        same_pad = (kernel_size - 1) // 2
        conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, same_pad, groups=groups, bias=False)
        norm = nn.BatchNorm2d(out_planes)
        act = nn.ReLU6(inplace=True)
        super().__init__(conv, norm, act)
def add_sepconv(in_ch, out_ch, ksize, stride):
    """
    Build a depthwise-separable conv block.

    A depthwise ``ksize`` conv (``groups=in_ch``, ``(ksize - 1) // 2``
    padding) + BatchNorm + ReLU6 is followed by a pointwise 1x1 conv to
    ``out_ch`` + BatchNorm + ReLU6. Neither conv has a bias.

    Args:
        in_ch (int): input channels (kept through the depthwise stage).
        out_ch (int): output channels of the pointwise conv.
        ksize (int): depthwise kernel size.
        stride (int): depthwise stride.
    Returns:
        nn.Sequential: children ``sepconv``, ``sepbn``, ``seprelu6``,
        ``ptconv``, ``ptbn``, ``ptrelu6``.
    """
    layers = [
        (
            "sepconv",
            nn.Conv2d(
                in_channels=in_ch,
                out_channels=in_ch,
                kernel_size=ksize,
                stride=stride,
                padding=(ksize - 1) // 2,
                groups=in_ch,
                bias=False,
            ),
        ),
        ("sepbn", nn.BatchNorm2d(in_ch)),
        ("seprelu6", nn.ReLU6(inplace=True)),
        ("ptconv", nn.Conv2d(in_ch, out_ch, 1, 1, 0, bias=False)),
        ("ptbn", nn.BatchNorm2d(out_ch)),
        ("ptrelu6", nn.ReLU6(inplace=True)),
    ]
    stage = nn.Sequential()
    for layer_name, layer in layers:
        stage.add_module(layer_name, layer)
    return stage
class InvertedResidual(nn.Module):
    """
    Inverted-residual block: optional 1x1 expansion, 3x3 depthwise conv,
    then a linear 1x1 projection (BatchNorm but no activation).

    Args:
        inp (int): input channels.
        oup (int): output channels.
        stride (int): depthwise stride; must be 1 or 2.
        expand_ratio (float): hidden width is ``int(round(inp * expand_ratio))``;
            the expansion conv is skipped when this is exactly 1.

    The input is added to the output only when ``stride == 1`` and ``inp == oup``.
    """

    def __init__(self, inp, oup, stride, expand_ratio):
        super().__init__()
        self.stride = stride
        assert stride in [1, 2]
        hidden_dim = int(round(inp * expand_ratio))
        self.use_res_connect = stride == 1 and inp == oup
        expansion = [] if expand_ratio == 1 else [ConvBNReLU(inp, hidden_dim, kernel_size=1)]
        depthwise = ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim)
        projection = [nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), nn.BatchNorm2d(oup)]
        self.conv = nn.Sequential(*expansion, depthwise, *projection)

    def forward(self, x):
        y = self.conv(x)
        return x + y if self.use_res_connect else y
class ressepblock(nn.Module):
    """
    Single residual bottleneck made of two ReLU6 ``add_conv`` blocks.

    Applies a 1x1 conv ``ch -> in_ch`` and then a 3x3 conv ``in_ch -> out_ch``
    (both with ``leaky=False``). When ``shortcut`` is True the result is added
    to the input.

    Args:
        ch (int): number of input channels.
        out_ch (int): output channels of the 3x3 conv. With ``shortcut=True``
            it must match ``ch`` (or be broadcast-compatible) for the addition.
        in_ch (int, optional): despite its name, the width of the intermediate
            1x1 conv output; defaults to ``ch // 2``.
        shortcut (bool): if True return ``x + f(x)``, otherwise ``f(x)``.
    """

    def __init__(self, ch, out_ch, in_ch=None, shortcut=True):
        super().__init__()
        self.shortcut = shortcut
        self.module_list = nn.ModuleList()
        # Compare against the None singleton by identity (PEP 8).
        in_ch = ch // 2 if in_ch is None else in_ch
        resblock_one = nn.ModuleList()
        resblock_one.append(add_conv(ch, in_ch, 1, 1, leaky=False))
        resblock_one.append(add_conv(in_ch, out_ch, 3, 1, leaky=False))
        self.module_list.append(resblock_one)

    def forward(self, x):
        for module in self.module_list:
            h = x
            for res in module:
                h = res(h)
            x = x + h if self.shortcut else h
        return x