from collections import OrderedDict

import torch
import torch.nn as nn


class IdentityLayer(nn.Module):
    """Pass-through op, used as the shortcut branch of a residual block."""

    def __init__(self):
        super(IdentityLayer, self).__init__()

    def forward(self, x):
        return x

    @staticmethod
    def is_zero_layer():
        return False


class ZeroLayer(nn.Module):
    """Candidate op that outputs zeros, i.e. 'no connection' in the search space."""

    def __init__(self, stride):
        super(ZeroLayer, self).__init__()
        self.stride = stride

    def forward(self, x):
        n, c, h, w = x.shape
        # Emit zeros with the spatial size a conv of the same stride would produce.
        h //= self.stride[0]
        w //= self.stride[1]
        return torch.zeros(n, c, h, w, device=x.device, requires_grad=False)

    @staticmethod
    def is_zero_layer():
        return True

    def get_flops(self, x):
        return 0, self.forward(x)
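
# Usage sketch (illustrative shapes, not from the original source): a ZeroLayer
# with stride (2, 1) maps an (N, C, H, W) input to zeros of shape
# (N, C, H // 2, W), letting the search treat 'no connection' like any other
# strided candidate op:
#
#   zero = ZeroLayer(stride=(2, 1))
#   zero(torch.randn(1, 8, 32, 100)).shape  # torch.Size([1, 8, 16, 100])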


def get_same_padding(kernel_size):
    if isinstance(kernel_size, tuple):
        assert len(kernel_size) == 2, \
            'invalid kernel size: %s' % str(kernel_size)
        p1 = get_same_padding(kernel_size[0])
        p2 = get_same_padding(kernel_size[1])
        return p1, p2
    assert isinstance(kernel_size,
                      int), 'kernel size should be either `int` or `tuple`'
    assert kernel_size % 2 > 0, 'kernel size should be an odd number'
    return kernel_size // 2
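
# For example, get_same_padding(3) == 1 and get_same_padding((3, 5)) == (1, 2):
# at stride 1, padding of k // 2 keeps the spatial size unchanged for an odd kernel.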


class MBInvertedConvLayer(nn.Module):
    """MobileNetV2-style inverted bottleneck: 1x1 expand -> depthwise -> 1x1 project."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 stride=(1, 1),
                 expand_ratio=6,
                 mid_channels=None):
        super(MBInvertedConvLayer, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.expand_ratio = expand_ratio
        self.mid_channels = mid_channels

        feature_dim = round(
            self.in_channels *
            self.expand_ratio) if mid_channels is None else mid_channels

        if self.expand_ratio == 1:
            # No expansion: the depthwise conv runs directly on the input channels.
            self.inverted_bottleneck = None
        else:
            self.inverted_bottleneck = nn.Sequential(
                OrderedDict([
                    ('conv',
                     nn.Conv2d(self.in_channels,
                               feature_dim,
                               1,
                               1,
                               0,
                               bias=False)),
                    ('bn', nn.BatchNorm2d(feature_dim)),
                    ('act', nn.ReLU6(inplace=True)),
                ]))

        pad = get_same_padding(self.kernel_size)
        self.depth_conv = nn.Sequential(
            OrderedDict([
                ('conv',
                 nn.Conv2d(feature_dim,
                           feature_dim,
                           kernel_size,
                           stride,
                           pad,
                           groups=feature_dim,
                           bias=False)),
                ('bn', nn.BatchNorm2d(feature_dim)),
                ('act', nn.ReLU6(inplace=True)),
            ]))
        self.point_conv = nn.Sequential(
            OrderedDict([
                ('conv',
                 nn.Conv2d(feature_dim, out_channels, 1, 1, 0, bias=False)),
                ('bn', nn.BatchNorm2d(out_channels)),
            ]))

    def forward(self, x):
        if self.inverted_bottleneck:
            x = self.inverted_bottleneck(x)
        x = self.depth_conv(x)
        x = self.point_conv(x)
        return x

    @staticmethod
    def is_zero_layer():
        return False
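
# Usage sketch (illustrative numbers): with expand_ratio=6, a 3x3 MBConv from
# 32 to 64 channels computes 32 -> 192 (1x1 expand) -> 192 (3x3 depthwise,
# groups=192) -> 64 (1x1 project):
#
#   mb = MBInvertedConvLayer(32, 64, kernel_size=3, stride=(2, 1), expand_ratio=6)
#   mb(torch.randn(1, 32, 32, 100)).shape  # torch.Size([1, 64, 16, 100])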


def conv_func_by_name(name):
    """Map an op name to a constructor taking (in_channels, out_channels, stride)."""
    name2ops = {
        'Identity': lambda in_C, out_C, S: IdentityLayer(),
        'Zero': lambda in_C, out_C, S: ZeroLayer(stride=S),
    }
    # MBConv candidates: kernel sizes {3, 5, 7} crossed with expand ratios 1..6,
    # e.g. '5x5_MBConv6'. Default arguments (k=k, e=e) bind the loop variables
    # at definition time, avoiding the late-binding closure pitfall.
    for k in (3, 5, 7):
        for e in range(1, 7):
            name2ops['%dx%d_MBConv%d' % (k, k, e)] = (
                lambda in_C, out_C, S, k=k, e=e:
                MBInvertedConvLayer(in_C, out_C, k, S, e))
    return name2ops[name]
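
# Usage sketch (illustrative arguments): look up a constructor by name, then
# call it with (in_channels, out_channels, stride) to build the op:
#
#   op = conv_func_by_name('5x5_MBConv6')(32, 64, (2, 2))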


def build_candidate_ops(candidate_ops, in_channels, out_channels, stride,
                        ops_order):
    """Instantiate every named candidate op for one block.

    ``ops_order`` is accepted for API compatibility but unused, since
    ``IdentityLayer`` takes no constructor arguments.
    """
    if candidate_ops is None:
        raise ValueError('please specify a candidate set')
    # Reuse the name -> constructor table from conv_func_by_name rather than
    # duplicating it here.
    return [
        conv_func_by_name(name)(in_channels, out_channels, stride)
        for name in candidate_ops
    ]


class MobileInvertedResidualBlock(nn.Module):

    def __init__(self, mobile_inverted_conv, shortcut):
        super(MobileInvertedResidualBlock, self).__init__()
        self.mobile_inverted_conv = mobile_inverted_conv
        self.shortcut = shortcut

    def forward(self, x):
        if self.mobile_inverted_conv.is_zero_layer():
            # A Zero conv op reduces the whole block to an identity skip.
            res = x
        elif self.shortcut is None or self.shortcut.is_zero_layer():
            # No usable shortcut: plain feed-forward block.
            res = self.mobile_inverted_conv(x)
        else:
            conv_x = self.mobile_inverted_conv(x)
            skip_x = self.shortcut(x)
            res = skip_x + conv_x
        return res
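
# E.g. a block whose conv op is a ZeroLayer acts as a pure identity
# (hypothetical shapes, for illustration):
#
#   blk = MobileInvertedResidualBlock(ZeroLayer(stride=(1, 1)), None)
#   x = torch.randn(1, 16, 8, 25)
#   assert torch.equal(blk(x), x)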


class AutoSTREncoder(nn.Module):
    """AutoSTR searched backbone: a stack of mobile inverted residual blocks
    followed by an optional 2-layer bidirectional LSTM sequence encoder."""

    def __init__(self,
                 in_channels,
                 out_dim=256,
                 with_lstm=True,
                 stride_stages='[(2, 2), (2, 2), (2, 1), (2, 1), (2, 1)]',
                 n_cell_stages=[3, 3, 3, 3, 3],
                 conv_op_ids=[5, 5, 5, 5, 5, 5, 5, 6, 6, 5, 4, 3, 4, 6, 6],
                 **kwargs):
        super().__init__()
        self.first_conv = nn.Sequential(
            nn.Conv2d(in_channels,
                      32,
                      kernel_size=(3, 3),
                      stride=1,
                      padding=1,
                      bias=False), nn.BatchNorm2d(32), nn.ReLU(inplace=True))

        # stride_stages arrives as a string (config-friendly) and is evaluated
        # into a list of per-stage (h, w) stride tuples.
        stride_stages = eval(stride_stages)
        width_stages = [32, 64, 128, 256, 512]
        conv_candidates = [
            '5x5_MBConv1', '5x5_MBConv3', '5x5_MBConv6', '3x3_MBConv1',
            '3x3_MBConv3', '3x3_MBConv6', 'Zero'
        ]
        # One searched op id per cell; each id indexes into conv_candidates.
        assert len(conv_op_ids) == sum(n_cell_stages)

        blocks = []
        input_channel = 32
        for width, n_cell, s in zip(width_stages, n_cell_stages,
                                    stride_stages):
            for i in range(n_cell):
                # Only the first cell of each stage downsamples.
                stride = s if i == 0 else (1, 1)
                block_i = len(blocks)
                conv_op = conv_func_by_name(
                    conv_candidates[conv_op_ids[block_i]])(input_channel,
                                                           width, stride)
                # A residual shortcut is only valid when shapes are preserved.
                if stride == (1, 1) and input_channel == width:
                    shortcut = IdentityLayer()
                else:
                    shortcut = None
                blocks.append(MobileInvertedResidualBlock(conv_op, shortcut))
                input_channel = width
        self.out_channels = input_channel
        self.blocks = nn.ModuleList(blocks)

        self.with_lstm = with_lstm
        if with_lstm:
            self.rnn = nn.LSTM(input_channel,
                               out_dim // 2,
                               bidirectional=True,
                               num_layers=2,
                               batch_first=True)
            self.out_channels = out_dim

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight,
                                        mode='fan_out',
                                        nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.first_conv(x)
        for block in self.blocks:
            x = block(x)
        # The five height strides collapse H to 1 for standard input heights,
        # so squeezing dim 2 gives (N, C, W), then transpose to (N, W, C)
        # to form a width-ordered feature sequence.
        cnn_feat = x.squeeze(dim=2)
        cnn_feat = cnn_feat.transpose(2, 1)
        if self.with_lstm:
            rnn_feat, _ = self.rnn(cnn_feat)
            return rnn_feat
        return cnn_feat
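

if __name__ == '__main__':
    # Smoke test (illustrative input size, not from the original source):
    # with the default strides the total height downsampling is 2^5 = 32 and
    # the width downsampling is 4, so a 32x100 input yields a length-25 sequence.
    encoder = AutoSTREncoder(in_channels=1, out_dim=256)
    feat = encoder(torch.randn(2, 1, 32, 100))
    print(feat.shape)  # expected: torch.Size([2, 25, 256])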