# Provenance: file from the "desco" repository (author: zdou0830), revision 749745d, 8.71 kB.
import torch
import torch.nn as nn

from .ops import *
class stem(nn.Module):
    """Single ``conv -> norm -> ReLU`` stage.

    Args:
        conv: factory invoked as ``conv(inplanes, planes, stride)`` that
            returns the convolution module.
        inplanes: number of input channels.
        planes: number of output channels.
        stride: stride handed to ``conv``.
        norm_layer: normalization factory invoked with ``planes``.
    """

    # Class-level layer count; not read inside this class (presumably
    # consumed by the code that assembles these blocks -- confirm).
    num_layer = 1

    def __init__(self, conv, inplanes, planes, stride=1, norm_layer=nn.BatchNorm2d):
        super().__init__()
        # Attribute names and assignment order determine state_dict keys and
        # parameter order, so they are kept exactly as before.
        self.conv1 = conv(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``relu(bn1(conv1(x)))``."""
        return self.relu(self.bn1(self.conv1(x)))
class basic(nn.Module):
    """ResNet-style basic block: two ``conv -> norm`` stages plus a shortcut.

    Main path: ``conv1 -> bn1 -> ReLU -> conv2 -> bn2``. The shortcut is the
    identity, or ``conv1x1 -> norm`` when the stride or the channel count
    changes. The sum of both is passed through ReLU.

    Args:
        conv: factory invoked as ``conv(in, out, stride)`` / ``conv(in, out)``.
        inplanes: input channel count.
        planes: output channel count.
        stride: stride of the first convolution and of the projection shortcut.
        midplanes: channels between the two convolutions; defaults to ``planes``.
        norm_layer: normalization factory invoked with a channel count.
    """

    expansion = 1  # output channels = planes * expansion
    # Class-level layer count; not read inside this class (presumably used by
    # the model builder -- confirm).
    num_layer = 2

    def __init__(self, conv, inplanes, planes, stride=1, midplanes=None, norm_layer=nn.BatchNorm2d):
        super().__init__()
        width = planes if midplanes is None else midplanes
        # Keep attribute names/order: they define the state_dict layout.
        self.conv1 = conv(inplanes, width, stride)
        self.bn1 = norm_layer(width)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv(width, planes)
        self.bn2 = norm_layer(planes)
        needs_projection = stride != 1 or inplanes != planes * self.expansion
        # expansion == 1, so projecting to `planes` matches the main-path width.
        self.downsample = (
            nn.Sequential(conv1x1(inplanes, planes, stride), norm_layer(planes))
            if needs_projection
            else None
        )

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut
        return self.relu(out)
class bottleneck(nn.Module):
    """ResNet-style bottleneck block: 1x1 reduce, spatial conv, 1x1 expand.

    Main path: ``conv1x1 -> norm -> ReLU -> conv -> norm -> ReLU -> conv1x1 ->
    norm``, producing ``planes * expansion`` channels. The shortcut is the
    identity, or ``conv1x1 -> norm`` when the stride or channel count changes.
    The sum is passed through ReLU.

    Args:
        conv: factory invoked as ``conv(in, out, stride)`` for the middle conv.
        inplanes: input channel count.
        planes: base width; the block outputs ``planes * expansion`` channels.
        stride: stride of the middle conv and of the projection shortcut.
        midplanes: width of the middle stage; defaults to ``planes``.
        norm_layer: normalization factory invoked with a channel count.

    ``conv1x1`` comes from ``.ops`` (not visible here); it is called as
    ``conv1x1(in, out)`` and ``conv1x1(in, out, stride)``.
    """

    expansion = 4
    # Class-level layer count; not read inside this class (presumably used by
    # the model builder -- confirm).
    num_layer = 3

    def __init__(self, conv, inplanes, planes, stride=1, midplanes=None, norm_layer=nn.BatchNorm2d):
        super().__init__()
        width = planes if midplanes is None else midplanes
        out_planes = planes * self.expansion
        # Keep attribute names/order: they define the state_dict layout.
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv(width, width, stride)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, out_planes)
        self.bn3 = norm_layer(out_planes)
        self.relu = nn.ReLU(inplace=True)
        if stride == 1 and inplanes == out_planes:
            self.downsample = None
        else:
            # Project the input so it can be added to the expanded output.
            self.downsample = nn.Sequential(
                conv1x1(inplanes, out_planes, stride),
                norm_layer(out_planes),
            )

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        out = x
        for conv_mod, norm_mod in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            out = self.relu(norm_mod(conv_mod(out)))
        out = self.bn3(self.conv3(out))  # no activation before the residual add
        residual = x if self.downsample is None else self.downsample(x)
        out += residual
        return self.relu(out)
class invert(nn.Module):
    """MobileNetV2-style inverted residual block.

    Layout (``self.conv``): an optional 1x1 expansion to
    ``round(inp * expand_ratio)`` channels with ReLU6 (skipped when
    ``expand_ratio == 1``), then the spatial conv from ``conv`` with norm and
    ReLU6, then a linear 1x1 projection to ``oup`` channels with norm. The
    input is added back when ``stride == 1`` and ``inp == oup``.

    Args:
        conv: factory invoked as ``conv(hidden, hidden, stride)``; labelled
            "dw" here, so presumably depthwise (depends on the factory).
        inp: input channel count.
        oup: output channel count.
        stride: 1 or 2; anything else fails the assertion below.
        expand_ratio: hidden-width multiplier.
        norm_layer: normalization factory invoked with a channel count.
    """

    def __init__(self, conv, inp, oup, stride=1, expand_ratio=1, norm_layer=nn.BatchNorm2d):
        super().__init__()
        self.stride = stride
        assert stride in (1, 2)
        hidden_dim = round(inp * expand_ratio)
        self.use_res_connect = stride == 1 and inp == oup

        layers = []
        if expand_ratio != 1:
            # pw: expand to the hidden width
            layers += [
                nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
                norm_layer(hidden_dim),
                nn.ReLU6(inplace=True),
            ]
        layers += [
            # dw
            conv(hidden_dim, hidden_dim, stride),
            norm_layer(hidden_dim),
            nn.ReLU6(inplace=True),
            # pw-linear: projection without an activation
            nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
            norm_layer(oup),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Apply ``self.conv``, adding ``x`` back when the shapes allow it."""
        out = self.conv(x)
        return x + out if self.use_res_connect else out
# Fixed-expansion shortcuts for `invert`. Defined with `def` rather than
# assigned lambdas so each carries a real __name__ and docstring.
def invert2(op, inp, outp, stride, **kwargs):
    """Build an ``invert`` block with ``expand_ratio=2``; extra kwargs are forwarded."""
    return invert(op, inp, outp, stride, expand_ratio=2, **kwargs)


def invert3(op, inp, outp, stride, **kwargs):
    """Build an ``invert`` block with ``expand_ratio=3``; extra kwargs are forwarded."""
    return invert(op, inp, outp, stride, expand_ratio=3, **kwargs)


def invert4(op, inp, outp, stride, **kwargs):
    """Build an ``invert`` block with ``expand_ratio=4``; extra kwargs are forwarded."""
    return invert(op, inp, outp, stride, expand_ratio=4, **kwargs)


def invert6(op, inp, outp, stride, **kwargs):
    """Build an ``invert`` block with ``expand_ratio=6``; extra kwargs are forwarded."""
    return invert(op, inp, outp, stride, expand_ratio=6, **kwargs)
def channel_shuffle(x, groups):
    """Interleave the channels of ``x`` across ``groups`` (ShuffleNet shuffle).

    The channel axis is viewed as ``(groups, channels_per_group)``, the two
    factors are swapped and the result is flattened back, so input channel
    ``g * channels_per_group + k`` ends up at position ``k * groups + g``.

    Args:
        x: 4-D tensor laid out as ``(batch, channels, height, width)``.
        groups: number of groups; must be positive and divide the channel count.

    Returns:
        Tensor of the same shape as ``x`` with its channels permuted.

    Raises:
        ValueError: if ``groups`` is not positive or does not divide the
            channel count.
    """
    # Read the shape from the tensor itself; `.data` is unnecessary here.
    batchsize, num_channels, height, width = x.size()
    # Check up front instead of failing inside `view` (or dividing by zero).
    if groups <= 0 or num_channels % groups != 0:
        raise ValueError(
            "channel_shuffle: groups=%r must be positive and divide %d channels"
            % (groups, num_channels)
        )
    channels_per_group = num_channels // groups
    # reshape to (batch, groups, channels_per_group, h, w) and swap the two
    # channel factors; `view` needs contiguous memory afterwards.
    x = x.view(batchsize, groups, channels_per_group, height, width)
    x = x.transpose(1, 2).contiguous()
    # flatten back to (batch, channels, h, w)
    return x.view(batchsize, -1, height, width)
class shuffle(nn.Module):
    """ShuffleNetV2-style unit: channel split/concat followed by a shuffle.

    ``stride == 2``: ``left_branch`` and ``right_branch`` both process the
    full input and their outputs are concatenated.
    Otherwise: the first ``C // 2`` input channels pass through unchanged and
    the remaining channels go through ``right_branch``.
    In both cases the concatenation has ``outplanes`` channels and is mixed
    with ``channel_shuffle(..., 2)``.

    Args:
        conv: factory invoked as ``conv(c, c, stride)`` for the spatial ("dw")
            convolutions (presumably depthwise -- depends on the factory).
        inplanes: input channels of the unit. With stride 1 the branch is
            built for ``inplanes // 2`` inputs, so it should be even.
        outplanes: output channels of the unit.
        stride: 2 selects the two-branch form; any other value uses the split
            form (1 is presumably the only other intended value -- confirm).
        midplanes: hidden width of ``right_branch``; defaults to
            ``outplanes // 2``.
        norm_layer: normalization factory invoked with a channel count.
    """

    expansion = 1
    # Class-level layer count; not read inside this class (presumably used by
    # the model builder -- confirm).
    num_layer = 3

    def __init__(self, conv, inplanes, outplanes, stride=1, midplanes=None, norm_layer=nn.BatchNorm2d):
        super().__init__()
        reduce = stride == 2
        # Channels fed into the branch(es): half of the input in split mode.
        if stride == 1:
            branch_in = inplanes // 2
        else:
            branch_in = inplanes
        mid = outplanes // 2 if midplanes is None else midplanes
        # The right branch supplies whatever the left side (branch or
        # passthrough, both `branch_in` wide) does not.
        right_out = outplanes - branch_in

        if reduce:
            self.left_branch = nn.Sequential(
                conv(branch_in, branch_in, stride),  # dw (strided)
                norm_layer(branch_in),
                conv1x1(branch_in, branch_in),  # pw, followed by ReLU
                norm_layer(branch_in),
                nn.ReLU(inplace=True),
            )
        self.right_branch = nn.Sequential(
            conv1x1(branch_in, mid),  # pw
            norm_layer(mid),
            nn.ReLU(inplace=True),
            conv(mid, mid, stride),  # dw
            norm_layer(mid),
            conv1x1(mid, right_out),  # pw, followed by ReLU
            norm_layer(right_out),
            nn.ReLU(inplace=True),
        )
        self.reduce = reduce

    def forward(self, x):
        """Run the branches, concatenate along channels and shuffle."""
        if self.reduce:
            left = self.left_branch(x)
            right = self.right_branch(x)
        else:
            split = x.shape[1] // 2
            left = x[:, :split, :, :]
            right = self.right_branch(x[:, split:, :, :])
        return channel_shuffle(torch.cat((left, right), 1), 2)
class shufflex(nn.Module):
    """ShuffleNetV2 unit with a deeper (Xception-like) right branch.

    Same split/concat/shuffle scheme as the plain ShuffleNetV2 unit. The right
    branch is three stages of ``spatial conv -> norm -> 1x1 conv -> norm ->
    ReLU``; only the first stage uses ``stride``.

    ``stride == 2``: both branches process the full input and their outputs
    are concatenated. Otherwise the first ``C // 2`` input channels pass
    through unchanged and the rest go through ``right_branch``.

    Args:
        conv: factory invoked as ``conv(c, c, stride)`` for the spatial ("dw")
            convolutions (presumably depthwise -- depends on the factory).
        inplanes: input channels of the unit. With stride 1 the branch is
            built for ``inplanes // 2`` inputs, so it should be even.
        outplanes: output channels of the unit.
        stride: 2 selects the two-branch form; any other value uses the split
            form (1 is presumably the only other intended value -- confirm).
        midplanes: hidden width of ``right_branch``; defaults to
            ``outplanes // 2``.
        norm_layer: normalization factory invoked with a channel count.
    """

    expansion = 1
    # Class-level layer count; not read inside this class (presumably used by
    # the model builder -- confirm).
    num_layer = 3

    def __init__(self, conv, inplanes, outplanes, stride=1, midplanes=None, norm_layer=nn.BatchNorm2d):
        super().__init__()
        reduce = stride == 2
        # Channels fed into the branch(es): half of the input in split mode.
        if stride == 1:
            branch_in = inplanes // 2
        else:
            branch_in = inplanes
        mid = outplanes // 2 if midplanes is None else midplanes
        right_out = outplanes - branch_in

        if reduce:
            self.left_branch = nn.Sequential(
                conv(branch_in, branch_in, stride),  # dw (strided)
                norm_layer(branch_in),
                conv1x1(branch_in, branch_in),  # pw, followed by ReLU
                norm_layer(branch_in),
                nn.ReLU(inplace=True),
            )

        # (in_channels, out_channels, stride) of each dw -> pw stage; modules
        # are created in the same order as before, so Sequential indices (and
        # state_dict keys) are unchanged.
        stages = ((branch_in, mid, stride), (mid, mid, 1), (mid, right_out, 1))
        layers = []
        for c_in, c_out, s in stages:
            layers += [
                conv(c_in, c_in, s),  # dw
                norm_layer(c_in),
                conv1x1(c_in, c_out),  # pw, followed by ReLU
                norm_layer(c_out),
                nn.ReLU(inplace=True),
            ]
        self.right_branch = nn.Sequential(*layers)
        self.reduce = reduce

    def forward(self, x):
        """Run the branches, concatenate along channels and shuffle."""
        if self.reduce:
            left = self.left_branch(x)
            right = self.right_branch(x)
        else:
            split = x.shape[1] // 2
            left = x[:, :split, :, :]
            right = self.right_branch(x[:, split:, :, :])
        return channel_shuffle(torch.cat((left, right), 1), 2)