anydoor / iseg /coarse_mask_refine_util.py
olfp's picture
Upload 162 files
054c447 verified
raw
history blame
No virus
10.8 kB
"""MobileNet and MobileNetV2."""
'''
Code adapted from https://github.com/LikeLy-Journey/SegmenTron/blob/master/segmentron/models/backbones/mobilenet.py
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
# ============ Basic Blocks ============
class _ConvBNReLU(nn.Module):
    """Convolution -> normalization -> ReLU (or ReLU6) building block.

    The convolution is created without a bias term. The normalization layer
    is built by calling ``norm_layer(out_channels)``. ``relu6`` selects
    ``nn.ReLU6`` instead of ``nn.ReLU``; either activation runs in place.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
                 dilation=1, groups=1, relu6=False, norm_layer=nn.BatchNorm2d):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=False,
        )
        self.bn = norm_layer(out_channels)
        activation_cls = nn.ReLU6 if relu6 else nn.ReLU
        self.relu = activation_cls(True)

    def forward(self, x):
        """Run ``x`` through conv, norm and activation and return the result."""
        return self.relu(self.bn(self.conv(x)))
class _DepthwiseConv(nn.Module):
    """Depthwise-separable convolution (``conv_dw`` in MobileNet).

    A 3x3 convolution with ``groups == in_channels`` (padding 1, the given
    stride) is followed by a 1x1 convolution that maps to ``out_channels``.
    Both are ``_ConvBNReLU`` units. Extra keyword arguments are accepted and
    ignored.
    """

    def __init__(self, in_channels, out_channels, stride, norm_layer=nn.BatchNorm2d, **kwargs):
        super().__init__()
        depthwise = _ConvBNReLU(
            in_channels, in_channels, 3,
            stride=stride, padding=1, groups=in_channels, norm_layer=norm_layer,
        )
        pointwise = _ConvBNReLU(in_channels, out_channels, 1, norm_layer=norm_layer)
        self.conv = nn.Sequential(depthwise, pointwise)

    def forward(self, x):
        """Apply the depthwise then the pointwise stage to ``x``."""
        out = self.conv(x)
        return out
class InvertedResidual(nn.Module):
    """MobileNetV2 inverted-residual block.

    Layout: optional 1x1 expansion (skipped when ``expand_ratio == 1``),
    a 3x3 depthwise convolution, then a linear 1x1 projection followed by a
    normalization layer. The input is added to the output only when
    ``stride == 1`` and the channel count is unchanged.
    """

    def __init__(self, in_channels, out_channels, stride, expand_ratio, dilation=1, norm_layer=nn.BatchNorm2d):
        super().__init__()
        assert stride in (1, 2)
        self.use_res_connect = (stride == 1) and (in_channels == out_channels)

        inter_channels = int(round(in_channels * expand_ratio))
        stages = []
        if expand_ratio != 1:
            # pointwise expansion
            expand = _ConvBNReLU(in_channels, inter_channels, 1, relu6=True, norm_layer=norm_layer)
            stages.append(expand)
        # depthwise: padding equals dilation, so a stride-1 block keeps the spatial size
        depthwise = _ConvBNReLU(
            inter_channels, inter_channels, 3,
            stride=stride, padding=dilation, dilation=dilation,
            groups=inter_channels, relu6=True, norm_layer=norm_layer,
        )
        # pointwise linear projection (no activation afterwards)
        project = nn.Conv2d(inter_channels, out_channels, 1, bias=False)
        project_norm = norm_layer(out_channels)
        stages += [depthwise, project, project_norm]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Return ``x + conv(x)`` when the skip connection applies, else ``conv(x)``."""
        if not self.use_res_connect:
            return self.conv(x)
        return x + self.conv(x)
# ============ Backbone ============
class MobileNetV2(nn.Module):
    """MobileNetV2 backbone that returns four intermediate feature maps.

    Only the convolutional trunk is built; no pooling or classification head
    is created. The output of the stem convolution is summed with an
    externally supplied ``side_feature`` before entering the first block.

    Args:
        num_classes: Unused. Kept only so existing call sites keep working.
        norm_layer: Callable that builds a normalization layer from a
            channel count.
        output_stride: Total downsampling factor of the deepest feature map.
            One of 8 (default), 16 or 32. For 8 and 16 the strides of the
            late stages are replaced by dilation.

    Raises:
        NotImplementedError: If ``output_stride`` is not 8, 16 or 32.
    """

    def __init__(self, num_classes=1000, norm_layer=nn.BatchNorm2d, output_stride=8):
        super(MobileNetV2, self).__init__()
        self.multiplier = 1

        # Dilation used for (block4, block5) at each supported output stride.
        dilations_by_stride = {32: (1, 1), 16: (1, 2), 8: (2, 4)}
        if output_stride not in dilations_by_stride:
            raise NotImplementedError(
                'output_stride must be one of 8, 16 or 32, got %r' % (output_stride,))
        dilations = dilations_by_stride[output_stride]

        inverted_residual_setting = [
            # t (expansion), c (channels), n (repeats), s (stride)
            [1, 16, 1, 1],
            [6, 24, 2, 2],
            [6, 32, 3, 2],
            [6, 64, 4, 2],
            [6, 96, 3, 1],
            [6, 160, 3, 2],
            [6, 320, 1, 1]]

        # Stem: 3 -> 32 channels at stride 2 (for multiplier <= 1).
        input_channels = int(32 * self.multiplier) if self.multiplier > 1.0 else 32
        self.conv1 = _ConvBNReLU(3, input_channels, 3, 2, 1, relu6=True, norm_layer=norm_layer)

        # Inverted-residual stages. _make_layer updates self.planes, so each
        # call below starts from the channel count left by the previous one.
        self.planes = input_channels
        self.block1 = self._make_layer(InvertedResidual, self.planes, inverted_residual_setting[0:1],
                                       norm_layer=norm_layer)
        self.block2 = self._make_layer(InvertedResidual, self.planes, inverted_residual_setting[1:2],
                                       norm_layer=norm_layer)
        self.block3 = self._make_layer(InvertedResidual, self.planes, inverted_residual_setting[2:3],
                                       norm_layer=norm_layer)
        self.block4 = self._make_layer(InvertedResidual, self.planes, inverted_residual_setting[3:5],
                                       dilations[0], norm_layer=norm_layer)
        self.block5 = self._make_layer(InvertedResidual, self.planes, inverted_residual_setting[5:],
                                       dilations[1], norm_layer=norm_layer)
        # Channel count of the deepest feature map (320 with multiplier 1).
        self.last_inp_channels = self.planes

        # Weight initialization.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def _make_layer(self, block, planes, inverted_residual_setting, dilation=1, norm_layer=nn.BatchNorm2d):
        """Build one stage from ``(t, c, n, s)`` rows and record its output width.

        For every row, the first block receives the row's stride (forced to 1
        when ``dilation != 1``) and the given dilation. The remaining
        ``n - 1`` blocks use stride 1 and the block's default dilation.
        ``self.planes`` is set to the channel count of the last block.
        """
        features = list()
        for t, c, n, s in inverted_residual_setting:
            out_channels = int(c * self.multiplier)
            # A dilated stage must not downsample: dilation replaces the stride.
            stride = s if dilation == 1 else 1
            features.append(block(planes, out_channels, stride, t, dilation, norm_layer))
            planes = out_channels
            for _ in range(n - 1):
                features.append(block(planes, out_channels, 1, t, norm_layer=norm_layer))
                planes = out_channels
        self.planes = planes
        return nn.Sequential(*features)

    def forward(self, x, side_feature):
        """Return the outputs of block2, block3, block4 and block5.

        ``side_feature`` is added element-wise to the stem output, so it must
        be addable to a tensor with the stem's channel count and resolution.
        With multiplier 1 the returned maps have 24, 32, 96 and 320 channels.
        """
        x = self.conv1(x)
        x = x + side_feature
        x = self.block1(x)
        c1 = self.block2(x)
        c2 = self.block3(c1)
        c3 = self.block4(c2)
        c4 = self.block5(c3)
        return c1, c2, c3, c4
def mobilenet_v2(norm_layer=nn.BatchNorm2d):
    """Build a ``MobileNetV2`` backbone that uses ``norm_layer`` for normalization."""
    backbone = MobileNetV2(norm_layer=norm_layer)
    return backbone
# ============ Segmentor ============
class LRASPP(nn.Module):
    """Lite R-ASPP head.

    Two branches read the same input:
      * ``b0``: 1x1 conv -> norm -> ReLU, at full input resolution.
      * ``b1``: adaptive average pool to 2x2 -> 1x1 conv -> sigmoid.
    The ``b1`` output is bilinearly resized to the input's spatial size and
    multiplied element-wise with the ``b0`` output. Extra keyword arguments
    are accepted and ignored.
    """

    def __init__(self, in_channels, out_channels, norm_layer=nn.BatchNorm2d, **kwargs):
        super().__init__()
        feature_layers = [
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            norm_layer(out_channels),
            nn.ReLU(True),
        ]
        self.b0 = nn.Sequential(*feature_layers)

        gate_layers = [
            nn.AdaptiveAvgPool2d((2, 2)),
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.Sigmoid(),
        ]
        self.b1 = nn.Sequential(*gate_layers)

    def forward(self, x):
        """Return the ``b0`` features scaled by the upsampled ``b1`` gate."""
        spatial_size = x.shape[2:]
        features = self.b0(x)
        gate = F.interpolate(self.b1(x), spatial_size, mode='bilinear', align_corners=True)
        return features * gate
class MobileSeg(nn.Module):
    """Segmentation network: MobileNetV2 backbone plus an LR-ASPP head.

    The deepest backbone feature goes through ``LRASPP``, is resized to the
    resolution of the shallowest returned feature, and both are projected to
    16 channels and summed. ``head`` predicts from the fused map; ``aux_head``
    predicts from the projected deep feature alone. Extra keyword arguments
    are accepted and ignored.
    """

    def __init__(self, nclass=1, **kwargs):
        super().__init__()
        self.backbone = mobilenet_v2()
        self.lraspp = LRASPP(320, 128)
        # 1x1 projections that bring both paths to a common 16-channel width.
        self.fusion_conv1 = nn.Conv2d(128, 16, kernel_size=1, stride=1, padding=0)
        self.fusion_conv2 = nn.Conv2d(24, 16, kernel_size=1, stride=1, padding=0)
        self.head = nn.Conv2d(16, nclass, kernel_size=1, stride=1, padding=0)
        self.aux_head = nn.Conv2d(16, nclass, kernel_size=1, stride=1, padding=0)

    def forward(self, x, side_feature):
        """Return ``(pred, pred_aux, fused_feature)``.

        ``side_feature`` is forwarded unchanged to the backbone.
        """
        shallow, _, _, deep = self.backbone(x, side_feature)

        deep = self.lraspp(deep)
        deep = F.interpolate(deep, shallow.shape[2:], mode='bilinear', align_corners=True)
        deep = self.fusion_conv1(deep)
        pred_aux = self.aux_head(deep)

        fused = self.fusion_conv2(shallow) + deep
        pred = self.head(fused)
        return pred, pred_aux, fused

    def load_pretrained_weights(self, path_to_weights= ' '):
        """Load backbone weights from ``path_to_weights`` and print key mismatches.

        Checkpoint entries overwrite same-named backbone entries. Loading is
        non-strict, so backbone keys absent from the checkpoint keep their
        current values.
        """
        own_state = self.backbone.state_dict()
        # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
        checkpoint = torch.load(path_to_weights, map_location='cpu')

        checkpoint_names = set(checkpoint.keys())
        own_names = set(own_state.keys())

        print('Loading Mobilnet V2')
        print('Missing Keys: ', own_names - checkpoint_names)
        print('Unexpected Keys: ', checkpoint_names - own_names)

        own_state.update(checkpoint)
        self.backbone.load_state_dict(own_state, strict=False)
class ScaleLayer(nn.Module):
    """Multiply the input by a single learnable, non-negative scalar.

    The stored parameter is ``init_value / lr_mult``. In ``forward`` it is
    multiplied by ``lr_mult`` again and passed through ``abs``, so the
    effective scale starts at ``|init_value|``.
    """

    def __init__(self, init_value=1.0, lr_mult=1):
        super().__init__()
        self.lr_mult = lr_mult
        initial = torch.full((1,), init_value / lr_mult, dtype=torch.float32)
        self.scale = nn.Parameter(initial)

    def forward(self, x):
        """Return ``x`` times the absolute value of the rescaled parameter."""
        effective_scale = (self.scale * self.lr_mult).abs()
        return x * effective_scale
# ============ Interactive Segmentor ============
class BaselineModel(nn.Module):
    """Refine a coarse mask using the image and a ``MobileSeg`` network.

    The coarse mask is stacked with two all-zero channels, encoded by
    ``maps_transform`` (stride-2 conv -> LeakyReLU -> conv to 32 channels ->
    ``ScaleLayer``), and passed to the feature extractor as its side feature.

    Args:
        backbone_lr_mult: Accepted for compatibility; not used here.
        norm_layer: Accepted for compatibility; not used here.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, backbone_lr_mult=0.1,
                 norm_layer=nn.BatchNorm2d, **kwargs):
        super().__init__()
        self.feature_extractor = MobileSeg()

        # Channel count of the encoded mask; it is added to the backbone's
        # stem output, which also has 32 channels.
        side_feature_ch = 32
        mt_layers = [
            nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(negative_slope=0.2),
            nn.Conv2d(in_channels=16, out_channels=side_feature_ch, kernel_size=3, stride=1, padding=1),
            ScaleLayer(init_value=0.05, lr_mult=1)
        ]
        self.maps_transform = nn.Sequential(*mt_layers)

    def backbone_forward(self, image, coord_features=None):
        """Run the feature extractor and pack its three outputs into a dict.

        Returns:
            Dict with keys ``'instances'`` (main prediction),
            ``'instances_aux'`` (auxiliary prediction) and ``'feature'``.
        """
        mask, mask_aux, feature = self.feature_extractor(image, coord_features)
        return {'instances': mask, 'instances_aux': mask_aux, 'feature': feature}

    def prepare_input(self, image):
        """Return ``image`` unchanged plus an all-zero one-channel mask.

        The mask matches the image's batch size, spatial size, dtype and
        device. Only the single channel is allocated, not a full zero copy
        of the image.
        """
        prev_mask = torch.zeros_like(image[:, :1, :, :])
        return image, prev_mask

    def forward(self, image, coarse_mask):
        """Predict a refined mask for ``image`` from ``coarse_mask``.

        ``maps_transform`` expects 3 input channels and ``prev_mask``
        contributes one, so ``coarse_mask`` must have a single channel.

        Returns:
            The dict from ``backbone_forward``, with ``'instances'`` replaced
            by the prediction resized to the image resolution and passed
            through a sigmoid. The other entries are left as produced.
        """
        image, prev_mask = self.prepare_input(image)
        # Channels: zero "previous mask", the coarse mask, and a zeroed copy
        # of the coarse mask.
        coord_features = torch.cat((prev_mask, coarse_mask, coarse_mask * 0.0), dim=1)
        coord_features = self.maps_transform(coord_features)

        outputs = self.backbone_forward(image, coord_features)
        pred = F.interpolate(
            outputs['instances'],
            size=image.size()[2:],
            mode='bilinear', align_corners=True
        )
        outputs['instances'] = torch.sigmoid(pred)
        return outputs