Spaces:
Sleeping
Sleeping
import torch
import torch.nn as nn

from .ops import *
class stem(nn.Module):
    """Input stem: one ``conv`` layer, then normalization, then ReLU.

    Args:
        conv: factory called as ``conv(inplanes, planes, stride)``; presumably
            returns a convolution module (supplied by the caller).
        inplanes: number of input channels handed to ``conv``.
        planes: number of output channels of ``conv`` and of the norm layer.
        stride: stride forwarded to ``conv``.
        norm_layer: normalization factory called with the channel count.
    """

    num_layer = 1

    def __init__(self, conv, inplanes, planes, stride=1, norm_layer=nn.BatchNorm2d):
        super().__init__()
        # Attribute names double as state_dict keys, so they must stay stable.
        self.conv1 = conv(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.bn1(self.conv1(x)))
class basic(nn.Module): | |
expansion = 1 | |
num_layer = 2 | |
def __init__(self, conv, inplanes, planes, stride=1, midplanes=None, norm_layer=nn.BatchNorm2d): | |
super(basic, self).__init__() | |
midplanes = planes if midplanes is None else midplanes | |
self.conv1 = conv(inplanes, midplanes, stride) | |
self.bn1 = norm_layer(midplanes) | |
self.relu = nn.ReLU(inplace=True) | |
self.conv2 = conv(midplanes, planes) | |
self.bn2 = norm_layer(planes) | |
if stride != 1 or inplanes != planes * self.expansion: | |
self.downsample = nn.Sequential( | |
conv1x1(inplanes, planes, stride), | |
norm_layer(planes), | |
) | |
else: | |
self.downsample = None | |
def forward(self, x): | |
identity = x | |
out = self.conv1(x) | |
out = self.bn1(out) | |
out = self.relu(out) | |
out = self.conv2(out) | |
out = self.bn2(out) | |
if self.downsample is not None: | |
identity = self.downsample(x) | |
out += identity | |
out = self.relu(out) | |
return out | |
class bottleneck(nn.Module):
    """ResNet-style bottleneck residual block.

    Main path: ``conv1x1`` (to ``midplanes``) -> norm -> ReLU -> ``conv``
    (strided) -> norm -> ReLU -> ``conv1x1`` (to ``planes * expansion``) ->
    norm. The result is added to the shortcut and passed through a final ReLU.
    The shortcut is a strided ``conv1x1`` + norm projection whenever the
    stride is not 1 or the channel count changes. Otherwise it is the input.

    Args:
        conv: factory called as ``conv(midplanes, midplanes, stride)`` for the
            middle convolution.
        inplanes: number of input channels.
        planes: base width; the block outputs ``planes * expansion`` channels.
        stride: stride of the middle convolution and of the projection.
        midplanes: width of the middle convolution (default ``planes``).
        norm_layer: normalization factory called with a channel count.
    """

    expansion = 4
    num_layer = 3

    def __init__(self, conv, inplanes, planes, stride=1, midplanes=None, norm_layer=nn.BatchNorm2d):
        super().__init__()
        if midplanes is None:
            midplanes = planes
        outplanes = planes * self.expansion
        # Creation order is kept so parameter init / state_dict order is unchanged.
        self.conv1 = conv1x1(inplanes, midplanes)
        self.bn1 = norm_layer(midplanes)
        self.conv2 = conv(midplanes, midplanes, stride)
        self.bn2 = norm_layer(midplanes)
        self.conv3 = conv1x1(midplanes, outplanes)
        self.bn3 = norm_layer(outplanes)
        self.relu = nn.ReLU(inplace=True)
        if stride == 1 and inplanes == outplanes:
            self.downsample = None
        else:
            self.downsample = nn.Sequential(
                conv1x1(inplanes, outplanes, stride),
                norm_layer(outplanes),
            )

    def forward(self, x):
        out = x
        for conv_layer, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            out = self.relu(norm(conv_layer(out)))
        # No activation before the residual add.
        out = self.bn3(self.conv3(out))
        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut
        return self.relu(out)
class invert(nn.Module):
    """MobileNetV2-style inverted residual block.

    Layers in ``self.conv`` (an ``nn.Sequential``):

    * only when ``expand_ratio != 1``: 1x1 ``nn.Conv2d`` expanding ``inp`` to
      ``hidden_dim`` channels, norm, ReLU6;
    * ``conv(hidden_dim, hidden_dim, stride)`` (labelled "dw"; presumably a
      depthwise conv supplied by the caller), norm, ReLU6;
    * 1x1 ``nn.Conv2d`` projecting to ``oup`` channels, then norm, with no
      activation afterwards.

    The input is added back when ``stride == 1`` and ``inp == oup``.

    Args:
        conv: factory called as ``conv(hidden_dim, hidden_dim, stride)``.
        inp: number of input channels.
        oup: number of output channels.
        stride: must be 1 or 2 (checked with ``assert``).
        expand_ratio: hidden width is ``round(inp * expand_ratio)``.
        norm_layer: normalization factory called with a channel count.
    """

    def __init__(self, conv, inp, oup, stride=1, expand_ratio=1, norm_layer=nn.BatchNorm2d):
        super().__init__()
        self.stride = stride
        assert stride in [1, 2]
        hidden_dim = round(inp * expand_ratio)
        self.use_res_connect = self.stride == 1 and inp == oup

        # Layers are appended in the same order in both configurations, so the
        # Sequential indices (state_dict keys) match the original layout.
        layers = []
        if expand_ratio != 1:
            layers.extend([
                # pw
                nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
                norm_layer(hidden_dim),
                nn.ReLU6(inplace=True),
            ])
        layers.extend([
            # dw
            conv(hidden_dim, hidden_dim, stride),
            norm_layer(hidden_dim),
            nn.ReLU6(inplace=True),
            # pw-linear
            nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
            norm_layer(oup),
        ])
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv(x)
        return x + out if self.use_res_connect else out
# Fixed-expansion shortcuts for ``invert``. These are plain module-level
# functions rather than lambdas: they get meaningful ``__name__``s and can be
# pickled by reference, which a module-level lambda cannot be.


def invert2(op, inp, outp, stride, **kwargs):
    """Build ``invert(op, inp, outp, stride, expand_ratio=2, **kwargs)``."""
    return invert(op, inp, outp, stride, expand_ratio=2, **kwargs)


def invert3(op, inp, outp, stride, **kwargs):
    """Build ``invert(op, inp, outp, stride, expand_ratio=3, **kwargs)``."""
    return invert(op, inp, outp, stride, expand_ratio=3, **kwargs)


def invert4(op, inp, outp, stride, **kwargs):
    """Build ``invert(op, inp, outp, stride, expand_ratio=4, **kwargs)``."""
    return invert(op, inp, outp, stride, expand_ratio=4, **kwargs)


def invert6(op, inp, outp, stride, **kwargs):
    """Build ``invert(op, inp, outp, stride, expand_ratio=6, **kwargs)``."""
    return invert(op, inp, outp, stride, expand_ratio=6, **kwargs)
def channel_shuffle(x, groups):
    """Interleave the channels of ``x`` across ``groups`` (channel shuffle).

    The channel axis (dim 1) is viewed as ``(groups, channels_per_group)``.
    Those two axes are swapped and then flattened back. As a result, channel
    ``g * channels_per_group + i`` moves to position ``i * groups + g``.

    Args:
        x: 4-D tensor ``(batch, channels, height, width)``.
        groups: number of groups; must divide ``channels``.

    Returns:
        Tensor with the same shape as ``x`` with its channels permuted.

    Raises:
        ValueError: if ``channels`` is not divisible by ``groups``.
    """
    # Use ``x.size()`` rather than the legacy ``x.data.size()``: ``.data``
    # steps outside autograd and is discouraged. ``x.size()`` keeps the shape
    # visible to tracing/scripting.
    batchsize, num_channels, height, width = x.size()
    if num_channels % groups != 0:
        raise ValueError(
            "channel_shuffle: {} channels cannot be split into {} equal groups".format(
                num_channels, groups
            )
        )
    channels_per_group = num_channels // groups
    x = x.view(batchsize, groups, channels_per_group, height, width)
    # Tensor.transpose is used directly, so no module-level ``torch`` is needed.
    x = x.transpose(1, 2).contiguous()
    return x.view(batchsize, -1, height, width)
class shuffle(nn.Module):
    """Two-branch unit followed by a channel shuffle (ShuffleNetV2-like).

    * ``stride == 2``: the full input feeds both ``left_branch`` and
      ``right_branch``, and their outputs are concatenated on dim 1.
    * any other stride: the input channels are split in half at
      ``x.shape[1] // 2``. The first half passes through untouched and only
      the second half goes through ``right_branch``.

    The concatenation is then shuffled with ``channel_shuffle(out, 2)``.

    NOTE(review): in the split form the right branch is built for
    ``inplanes // 2`` inputs but receives ``C - C // 2`` channels, so an odd
    channel count presumably fails. Strides other than 1/2 also look
    shape-inconsistent. Confirm callers only use 1 or 2 with even widths.

    Args:
        conv: factory called as ``conv(c, c, stride)`` for the layers the
            original labels "dw" (presumably depthwise convs).
        inplanes: channels of the full input tensor.
        outplanes: channels of the block output.
        stride: 2 selects the two-branch form; it is also passed to ``conv``.
        midplanes: width inside the right branch (default ``outplanes // 2``).
        norm_layer: normalization factory called with a channel count.
    """

    expansion = 1
    num_layer = 3

    def __init__(self, conv, inplanes, outplanes, stride=1, midplanes=None, norm_layer=nn.BatchNorm2d):
        super().__init__()
        branch_in = inplanes if stride != 1 else inplanes // 2
        if midplanes is None:
            midplanes = outplanes // 2
        # The right branch makes up whatever the left/pass-through part lacks.
        branch_out = outplanes - branch_in
        if stride == 2:
            self.left_branch = nn.Sequential(
                # dw
                conv(branch_in, branch_in, stride),
                norm_layer(branch_in),
                # pw-linear
                conv1x1(branch_in, branch_in),
                norm_layer(branch_in),
                nn.ReLU(inplace=True),
            )
        self.right_branch = nn.Sequential(
            # pw
            conv1x1(branch_in, midplanes),
            norm_layer(midplanes),
            nn.ReLU(inplace=True),
            # dw
            conv(midplanes, midplanes, stride),
            norm_layer(midplanes),
            # pw-linear
            conv1x1(midplanes, branch_out),
            norm_layer(branch_out),
            nn.ReLU(inplace=True),
        )
        self.reduce = stride == 2

    def forward(self, x):
        if self.reduce:
            pieces = (self.left_branch(x), self.right_branch(x))
        else:
            split = x.shape[1] // 2
            untouched = x[:, :split, :, :]
            processed = self.right_branch(x[:, split:, :, :])
            pieces = (untouched, processed)
        return channel_shuffle(torch.cat(pieces, 1), 2)
class shufflex(nn.Module):
    """Deeper variant of the two-branch shuffle unit.

    Same layout as the plain shuffle unit: with ``stride == 2`` both branches
    see the whole input; otherwise the first half of the channels passes
    through and the second half goes through ``right_branch``. The results
    are concatenated and shuffled with ``channel_shuffle(out, 2)``.

    The right branch here is three stacked stages, each
    ``conv(c, c, s) -> norm -> conv1x1(c, c_out) -> norm -> ReLU``. Channels
    go ``branch_in -> midplanes -> midplanes -> outplanes - branch_in``, and
    only the first stage uses ``stride``.

    NOTE(review): as in the plain unit, the split form presumably requires an
    even channel count and a stride of 1. Confirm with callers.

    Args:
        conv: factory called as ``conv(c, c, stride)`` for the layers the
            original labels "dw" (presumably depthwise convs).
        inplanes: channels of the full input tensor.
        outplanes: channels of the block output.
        stride: 2 selects the two-branch form; also passed to the first ``conv``.
        midplanes: width inside the right branch (default ``outplanes // 2``).
        norm_layer: normalization factory called with a channel count.
    """

    expansion = 1
    num_layer = 3

    def __init__(self, conv, inplanes, outplanes, stride=1, midplanes=None, norm_layer=nn.BatchNorm2d):
        super().__init__()
        branch_in = inplanes if stride != 1 else inplanes // 2
        if midplanes is None:
            midplanes = outplanes // 2
        branch_out = outplanes - branch_in
        if stride == 2:
            self.left_branch = nn.Sequential(
                # dw
                conv(branch_in, branch_in, stride),
                norm_layer(branch_in),
                # pw-linear
                conv1x1(branch_in, branch_in),
                norm_layer(branch_in),
                nn.ReLU(inplace=True),
            )
        # (input width, output width, dw stride) per stage. Modules are created
        # in the original order, so Sequential indices 0..14 are unchanged.
        stages = (
            (branch_in, midplanes, stride),
            (midplanes, midplanes, 1),
            (midplanes, branch_out, 1),
        )
        layers = []
        for width, next_width, dw_stride in stages:
            layers += [
                # dw
                conv(width, width, dw_stride),
                norm_layer(width),
                # pw-linear
                conv1x1(width, next_width),
                norm_layer(next_width),
                nn.ReLU(inplace=True),
            ]
        self.right_branch = nn.Sequential(*layers)
        self.reduce = stride == 2

    def forward(self, x):
        if not self.reduce:
            half = x.shape[1] // 2
            merged = torch.cat((x[:, :half, :, :], self.right_branch(x[:, half:, :, :])), 1)
        else:
            merged = torch.cat((self.left_branch(x), self.right_branch(x)), 1)
        return channel_shuffle(merged, 2)