# Source: model.py by ibaiGorordo, commit 803143d ("Update model.py").
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch.jit import script
class WSConv2d(nn.Conv2d):
    """2-D convolution with Weight Standardization.

    On every forward pass each output filter's weights are re-centred to zero mean
    and divided by their (unbiased) standard deviation plus ``1e-5`` before the
    convolution is applied. The stored ``self.weight`` parameter is not modified.

    The constructor is inherited unchanged from :class:`torch.nn.Conv2d`. An earlier
    override was misspelled ``__init___`` (three underscores), so it was never
    called. It only forwarded the same arguments to ``nn.Conv2d``, so removing it
    changes nothing at runtime.
    """

    def forward(self, x):
        weight = self.weight
        # Zero mean per output filter: reduce over (in_channels / groups, kH, kW).
        weight = weight - weight.mean(dim=(1, 2, 3), keepdim=True)
        # Unbiased std per output filter; the 1e-5 offset guards against division by zero.
        std = weight.view(weight.size(0), -1).std(dim=1).view(-1, 1, 1, 1) + 1e-5
        weight = weight / std.expand_as(weight)
        # F.conv2d is called directly, so a non-'zeros' padding_mode would be ignored.
        return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
def conv_ws(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
    """Return a weight-standardized 2-D convolution (thin factory around ``WSConv2d``).

    All arguments are forwarded unchanged; only the three leading ones are positional.
    """
    layer_options = dict(stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
    return WSConv2d(in_channels, out_channels, kernel_size, **layer_options)
# Mish activation: mish(x) = x * tanh(softplus(x)). The scripted forward/backward pair
# below is wrapped by MishJitAutoFn, so autograd only needs to keep the input tensor.
@script
def _mish_jit_fwd(x):
    # TorchScript-compiled Mish forward: x * tanh(softplus(x)).
    return x.mul(torch.tanh(F.softplus(x)))
@script
def _mish_jit_bwd(x, grad_output):
    # Chain rule for mish(x) = x * tanh(sp), sp = softplus(x):
    #   d/dx = tanh(sp) + x * sigmoid(x) * sech^2(sp)
    # using softplus'(x) = sigmoid(x) and sech^2 = 1 - tanh^2.
    tanh_sp = torch.tanh(F.softplus(x))
    sech2_sp = 1 - tanh_sp * tanh_sp
    local_grad = tanh_sp + x * torch.sigmoid(x) * sech2_sp
    return grad_output.mul(local_grad)
class MishJitAutoFn(torch.autograd.Function):
    """Autograd Function for Mish with a hand-written, TorchScript-compiled backward.

    Only the input ``x`` is saved for backward. The gradient is recomputed from it
    by ``_mish_jit_bwd``.
    """

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return _mish_jit_fwd(x)

    @staticmethod
    def backward(ctx, grad_output):
        # ``saved_tensors`` is the supported accessor; ``saved_variables`` is a
        # long-deprecated alias.
        x, = ctx.saved_tensors
        return _mish_jit_bwd(x, grad_output)
# Functional form of the Mish activation.
def mish(x):
    """Apply Mish, ``x * tanh(softplus(x))``, through the custom autograd Function."""
    activated = MishJitAutoFn.apply(x)
    return activated
class Mish(nn.Module):
    """Module wrapper around the Mish activation ``x * tanh(softplus(x))``.

    Args:
        inplace: Accepted so the constructor mirrors activations like
            ``nn.ReLU(inplace=...)``. It is ignored; the op never runs in place.
    """

    def __init__(self, inplace: bool = False):
        super().__init__()

    def forward(self, x):
        # Delegate to the custom autograd Function (scripted forward + manual backward).
        return MishJitAutoFn.apply(x)
######################################################################################################################
######################################################################################################################
# pre-activation based upsampling conv block
class upConvLayer(nn.Module):
    """Pre-activation upsampling block: norm -> activation -> bilinear resize -> 3x3 WS conv.

    Args:
        in_channels: Channels of the incoming feature map (normalised before the conv).
        out_channels: Channels produced by the 3x3 convolution.
        scale_factor: Resize factor passed to ``F.interpolate``.
        norm: ``'GN'`` selects GroupNorm; any other value selects BatchNorm2d.
        act: ``'ELU'`` or ``'Mish'``; any other value selects an in-place ReLU.
        num_groups: Group count, used only for GroupNorm.
    """

    def __init__(self, in_channels, out_channels, scale_factor, norm, act, num_groups):
        super(upConvLayer, self).__init__()
        if act == 'ELU':
            activation = nn.ELU()
        elif act == 'Mish':
            activation = Mish()
        else:
            activation = nn.ReLU(inplace=True)
        # Attribute assignment order (conv, norm, act) fixes submodule registration order.
        self.conv = conv_ws(in_channels=in_channels, out_channels=out_channels, kernel_size=3,
                            stride=1, padding=1, bias=False)
        self.norm = (nn.GroupNorm(num_groups=num_groups, num_channels=in_channels) if norm == 'GN'
                     else nn.BatchNorm2d(in_channels, eps=0.001, momentum=0.1, affine=True,
                                         track_running_stats=True))
        self.act = activation
        self.scale_factor = scale_factor

    def forward(self, x):
        # Normalise and activate *before* resizing and convolving (pre-activation ordering).
        activated = self.act(self.norm(x))
        resized = F.interpolate(activated, scale_factor=self.scale_factor, mode='bilinear')
        return self.conv(resized)
# pre-activation based conv block
class myConv(nn.Module):
    """Pre-activation conv block: norm -> activation -> weight-standardized conv.

    Args:
        in_ch, out_ch: Input / output channel counts.
        kSize, stride, padding, dilation, bias: Forwarded to the convolution (groups is fixed at 1).
        norm: ``'GN'`` selects GroupNorm over ``in_ch``; anything else selects BatchNorm2d.
        act: ``'ELU'`` or ``'Mish'``; anything else selects an in-place ReLU.
        num_groups: Group count, used only for GroupNorm.
    """

    def __init__(self, in_ch, out_ch, kSize, stride=1,
                 padding=0, dilation=1, bias=True, norm='GN', act='ELU', num_groups=32):
        super(myConv, self).__init__()
        if act == 'ELU':
            activation = nn.ELU()
        elif act == 'Mish':
            activation = Mish()
        else:
            activation = nn.ReLU(True)
        if norm == 'GN':
            normalization = nn.GroupNorm(num_groups=num_groups, num_channels=in_ch)
        else:
            normalization = nn.BatchNorm2d(in_ch, eps=0.001, momentum=0.1, affine=True,
                                           track_running_stats=True)
        convolution = conv_ws(in_ch, out_ch, kernel_size=kSize, stride=stride,
                              padding=padding, dilation=dilation, groups=1, bias=bias)
        # Index order 0/1/2 (norm, act, conv) determines the state_dict keys.
        self.module = nn.Sequential(normalization, activation, convolution)

    def forward(self, x):
        return self.module(x)
# Deep Feature Extractor
class deepFeatureExtractor_ResNext101(nn.Module):
    """ResNeXt-101 (32x8d) backbone that returns a list of intermediate feature maps.

    Spatial size of each collected output relative to the input H x W:
    'relu' -> H/2 x W/2, 'layer1' -> H/4, 'layer2' -> H/8, 'layer3' -> H/16
    (and 'layer4', presumably H/32, when ``lv6`` is True).

    Args:
        lv6: If True, keep ``layer4`` and also return its output. Otherwise
            ``layer4`` and ``fc`` are deleted from the backbone.
    """

    def __init__(self, lv6=False):
        super(deepFeatureExtractor_ResNext101, self).__init__()
        # torchvision's default pretrained weights for this architecture.
        self.encoder = models.resnext101_32x8d(weights=models.ResNeXt101_32X8D_Weights.DEFAULT)
        # Substrings of parameter names to freeze: the first two layer1 blocks and '.bn' params.
        self.fixList = ['layer1.0', 'layer1.1', '.bn']
        self.lv6 = lv6
        if lv6 is not True:
            # Five-level decoder: the last stage and the classifier are never used.
            del self.encoder.layer4
            del self.encoder.fc
            self.layerList = ['relu', 'layer1', 'layer2', 'layer3']
            self.dimList = [64, 256, 512, 1024]
        else:
            self.layerList = ['relu', 'layer1', 'layer2', 'layer3', 'layer4']
            self.dimList = [64, 256, 512, 1024, 2048]
        for param_name, param in self.encoder.named_parameters():
            # The stem convolution is always frozen, as is anything matching fixList.
            if param_name == 'conv1.weight' or any(tag in param_name for tag in self.fixList):
                param.requires_grad = False

    def forward(self, x):
        """Run the backbone's top-level children in order, stopping at 'avgpool'.

        Returns the output of every child whose name contains an entry of
        ``layerList``, in execution order.
        """
        collected = []
        out = x
        for child_name, child in self.encoder._modules.items():
            if child_name == 'avgpool':
                break
            out = child(out)
            if any(tag in child_name for tag in self.layerList):
                collected.append(out)
        return collected

    def freeze_bn(self, enable=False):
        """Toggle every ``nn.BatchNorm2d`` in this module.

        Adapted from https://discuss.pytorch.org/t/how-to-train-with-frozen-batchnorm/12106/8

        With ``enable=False`` (the default), BN layers go to eval mode and their
        weight/bias stop receiving gradients. ``enable=True`` restores training
        mode and gradients.
        """
        for bn in (m for m in self.modules() if isinstance(m, nn.BatchNorm2d)):
            bn.train(bool(enable))
            bn.weight.requires_grad = enable
            bn.bias.requires_grad = enable
# ASPP Module
class Dilated_bottleNeck(nn.Module):
    """Densely connected ASPP block with dilation rates 3, 6, 12 and 18.

    A 1x1 weight-standardized conv first halves the channels. Each dilated branch
    sees the concatenation of that reduced input and every earlier branch's output.
    A final 3x3 block fuses the whole stack back to ``in_feat // 2`` channels.
    Spatial size is preserved (each 3x3 branch uses padding equal to its dilation).

    Args:
        norm: Normalisation name forwarded to ``myConv``.
        act: Activation name forwarded to ``myConv``.
        in_feat: Input channels (1024 for the ResNeXt-101 layer3 feature).
    """

    def __init__(self, norm, act, in_feat):
        super(Dilated_bottleNeck, self).__init__()
        half = in_feat // 2
        quarter = in_feat // 4

        def _branch(branch_in, rate):
            # 1x1 squeeze to `quarter` channels, then a 3x3 conv dilated by `rate`.
            return nn.Sequential(
                myConv(branch_in, quarter, kSize=1, stride=1, padding=0, dilation=1, bias=False,
                       norm=norm, act=act, num_groups=branch_in // 16),
                myConv(quarter, quarter, kSize=3, stride=1, padding=rate, dilation=rate, bias=False,
                       norm=norm, act=act, num_groups=quarter // 16))

        self.reduction1 = conv_ws(in_feat, half, kernel_size=1, stride=1, bias=False, padding=0)
        self.aspp_d3 = _branch(half, 3)
        self.aspp_d6 = _branch(half + quarter, 6)
        self.aspp_d12 = _branch(in_feat, 12)
        self.aspp_d18 = _branch(in_feat + quarter, 18)
        fused = quarter * 4 + half
        self.reduction2 = myConv(fused, half, kSize=3, stride=1, padding=1, bias=False,
                                 norm=norm, act=act, num_groups=fused // 16)

    def forward(self, x):
        # Grow the dense stack: each branch's output is appended to its own input.
        stack = self.reduction1(x)
        for branch in (self.aspp_d3, self.aspp_d6, self.aspp_d12, self.aspp_d18):
            stack = torch.cat([stack, branch(stack)], dim=1)
        # stack is now cat([reduced x, d3, d6, d12, d18]) along channels.
        return self.reduction2(stack)  # in_feat // 2 channels, same H x W as the input
class Dilated_bottleNeck2(nn.Module):
    """Densely connected ASPP variant with a 3x3 input reduction and rates 3/6/12/18/24.

    Same dense scheme as the four-branch block: every dilated branch sees the
    reduced input plus all earlier branch outputs. A final 3x3 block fuses the
    stack to ``in_feat // 2`` channels at unchanged spatial size.

    Args:
        norm: Normalisation name forwarded to ``myConv``.
        act: Activation name forwarded to ``myConv``.
        in_feat: Input channels.
    """

    def __init__(self, norm, act, in_feat):
        super(Dilated_bottleNeck2, self).__init__()
        half = in_feat // 2
        quarter = in_feat // 4

        def _branch(branch_in, rate):
            # 1x1 squeeze to `quarter` channels, then a 3x3 conv dilated by `rate`.
            return nn.Sequential(
                myConv(branch_in, quarter, kSize=1, stride=1, padding=0, dilation=1, bias=False,
                       norm=norm, act=act, num_groups=branch_in // 16),
                myConv(quarter, quarter, kSize=3, stride=1, padding=rate, dilation=rate, bias=False,
                       norm=norm, act=act, num_groups=quarter // 16))

        self.reduction1 = conv_ws(in_feat, half, kernel_size=3, stride=1, padding=1, bias=False)
        self.aspp_d3 = _branch(half, 3)
        self.aspp_d6 = _branch(half + quarter, 6)
        self.aspp_d12 = _branch(in_feat, 12)
        self.aspp_d18 = _branch(in_feat + quarter, 18)
        self.aspp_d24 = _branch(in_feat + half, 24)
        fused = quarter * 5 + half
        self.reduction2 = myConv(fused, half, kSize=3, stride=1, padding=1, bias=False,
                                 norm=norm, act=act, num_groups=fused // 16)

    def forward(self, x):
        # Grow the dense stack: each branch's output is appended to its own input.
        stack = self.reduction1(x)
        for branch in (self.aspp_d3, self.aspp_d6, self.aspp_d12, self.aspp_d18, self.aspp_d24):
            stack = torch.cat([stack, branch(stack)], dim=1)
        # stack is now cat([reduced x, d3, d6, d12, d18, d24]) along channels.
        return self.reduction2(stack)  # in_feat // 2 channels, same H x W as the input
class Dilated_bottleNeck_lv6(nn.Module):
    """Densely connected ASPP (rates 3/6/12/18) used by the six-level decoder.

    ``in_feat`` is halved internally to ``width``. A 3x3 pre-activation block maps
    ``2 * width`` input channels to ``width // 2``. The dense branches follow, and
    the final 3x3 fusion outputs ``width`` channels at unchanged spatial size.

    Args:
        norm: Normalisation name forwarded to ``myConv``.
        act: Activation name forwarded to ``myConv``.
        in_feat: Input channels (2048 for the ResNeXt-101 layer4 feature).
    """

    def __init__(self, norm, act, in_feat):
        super(Dilated_bottleNeck_lv6, self).__init__()
        width = in_feat // 2
        half = width // 2
        quarter = width // 4

        def _branch(branch_in, rate):
            # 1x1 squeeze to `quarter` channels, then a 3x3 conv dilated by `rate`.
            return nn.Sequential(
                myConv(branch_in, quarter, kSize=1, stride=1, padding=0, dilation=1, bias=False,
                       norm=norm, act=act, num_groups=branch_in // 16),
                myConv(quarter, quarter, kSize=3, stride=1, padding=rate, dilation=rate, bias=False,
                       norm=norm, act=act, num_groups=quarter // 16))

        self.reduction1 = myConv(width * 2, half, kSize=3, stride=1, padding=1, bias=False,
                                 norm=norm, act=act, num_groups=width // 16)
        self.aspp_d3 = _branch(half, 3)
        self.aspp_d6 = _branch(half + quarter, 6)
        self.aspp_d12 = _branch(width, 12)
        self.aspp_d18 = _branch(width + quarter, 18)
        fused = quarter * 4 + half
        self.reduction2 = myConv(fused, width, kSize=3, stride=1, padding=1, bias=False,
                                 norm=norm, act=act, num_groups=fused // 16)

    def forward(self, x):
        # Grow the dense stack: each branch's output is appended to its own input.
        stack = self.reduction1(x)
        for branch in (self.aspp_d3, self.aspp_d6, self.aspp_d12, self.aspp_d18):
            stack = torch.cat([stack, branch(stack)], dim=1)
        # stack is now cat([reduced x, d3, d6, d12, d18]) along channels.
        return self.reduction2(stack)  # (in_feat // 2) channels, same H x W as the input
# Laplacian Decoder Network
class Lap_decoder_lv5(nn.Module):
    """Five-level Laplacian-pyramid depth decoder (used by LDRN when ``lv6`` is False).

    The deepest encoder feature is refined by a dense ASPP block and decoded into a
    coarse map at pyramid level 5 (sigmoid, range (0, 1)). At each finer level, the
    decoder upsamples its features 2x and fuses them with an encoder skip feature.
    It also concatenates the upsampled coarser prediction and that level's image
    Laplacian, then predicts a residual in (-1, 1) via tanh. The residuals are summed
    coarse-to-fine ("Laplacian restoration"), passed through a sigmoid and scaled
    by ``max_depth``.

    Args:
        dimList: Encoder channel counts; indices 0-3 are used (LDRN passes
            ``[64, 256, 512, 1024]``). Stored as ``self.dimList`` without copying.
        norm: ``'GN'`` for GroupNorm; anything else uses BatchNorm (forwarded to the blocks).
        rank: Process rank; the norm choice is printed only when ``rank == 0``.
        act: ``'ELU'`` or ``'Mish'``; any other value is treated as ``'ReLU'``.
        max_depth: Scale applied to every returned map.
    """

    def __init__(self, dimList, norm="BN", rank=0, act='ReLU', max_depth=80):
        super(Lap_decoder_lv5, self).__init__()
        conv = conv_ws  # not referenced anywhere in this class
        # Report the normalisation choice (rank 0 only).
        if norm == 'GN':
            if rank == 0:
                print("==> Norm: GN")
        else:
            if rank == 0:
                print("==> Norm: BN")
        # Canonicalise the activation name; unrecognised names become 'ReLU'.
        if act == 'ELU':
            act = 'ELU'
        elif act == 'Mish':
            act = 'Mish'
        else:
            act = 'ReLU'
        kSize = 3
        self.max_depth = max_depth
        # Dense ASPP on the deepest feature; its output has dimList[3] // 2 channels.
        self.ASPP = Dilated_bottleNeck(norm, act, dimList[3])
        self.dimList = dimList
        # NOTE: attribute names and creation order below define the state_dict keys
        # and the order in which conv weights are randomly initialised.

        # ---- Pyramid level 5 ----
        # decoder1 out : 1 x H/16 x W/16 (Level 5)
        self.decoder1 = nn.Sequential(
            myConv(dimList[3] // 2, dimList[3] // 4, kSize, stride=1, padding=kSize // 2, bias=False,
                   norm=norm, act=act, num_groups=(dimList[3] // 2) // 16),
            myConv(dimList[3] // 4, dimList[3] // 8, kSize, stride=1, padding=kSize // 2, bias=False,
                   norm=norm, act=act, num_groups=(dimList[3] // 4) // 16),
            myConv(dimList[3] // 8, dimList[3] // 16, kSize, stride=1, padding=kSize // 2, bias=False,
                   norm=norm, act=act, num_groups=(dimList[3] // 8) // 16),
            myConv(dimList[3] // 16, dimList[3] // 32, kSize, stride=1, padding=kSize // 2, bias=False,
                   norm=norm, act=act, num_groups=(dimList[3] // 16) // 16),
            myConv(dimList[3] // 32, 1, kSize, stride=1, padding=kSize // 2, bias=False,
                   norm=norm, act=act, num_groups=(dimList[3] // 32) // 16)
        )

        # ---- Pyramid level 4 ----
        # decoder2 out : 1 x H/8 x W/8 (Level 4)
        # decoder2_up : (H/16,W/16)->(H/8,W/8)
        self.decoder2_up1 = upConvLayer(dimList[3] // 2, dimList[3] // 4, 2, norm, act, (dimList[3] // 2) // 16)
        # 1x1 fusion with the encoder skip; "- 4" leaves room for the 1-channel upsampled
        # coarser prediction plus the (assumed 3-channel) image Laplacian concatenated in forward.
        self.decoder2_reduc1 = myConv(dimList[3] // 4 + dimList[2], dimList[3] // 4 - 4, kSize=1, stride=1, padding=0,
                                      bias=False,
                                      norm=norm, act=act, num_groups=(dimList[3] // 4 + dimList[2]) // 16)
        self.decoder2_1 = myConv(dimList[3] // 4, dimList[3] // 4, kSize, stride=1, padding=kSize // 2, bias=False,
                                 norm=norm, act=act, num_groups=(dimList[3] // 4) // 16)
        self.decoder2_2 = myConv(dimList[3] // 4, dimList[3] // 8, kSize, stride=1, padding=kSize // 2, bias=False,
                                 norm=norm, act=act, num_groups=(dimList[3] // 4) // 16)
        self.decoder2_3 = myConv(dimList[3] // 8, dimList[3] // 16, kSize, stride=1, padding=kSize // 2, bias=False,
                                 norm=norm, act=act, num_groups=(dimList[3] // 8) // 16)
        self.decoder2_4 = myConv(dimList[3] // 16, 1, kSize, stride=1, padding=kSize // 2, bias=False,
                                 norm=norm, act=act, num_groups=(dimList[3] // 16) // 16)

        # ---- Pyramid level 3 ----
        # decoder2 out2 : 1 x H/4 x W/4 (Level 3)
        # decoder2_1_up2 : (H/8,W/8)->(H/4,W/4)
        self.decoder2_1_up2 = upConvLayer(dimList[3] // 4, dimList[3] // 8, 2, norm, act, (dimList[3] // 4) // 16)
        self.decoder2_1_reduc2 = myConv(dimList[3] // 8 + dimList[1], dimList[3] // 8 - 4, kSize=1, stride=1, padding=0,
                                        bias=False,
                                        norm=norm, act=act, num_groups=(dimList[3] // 8 + dimList[1]) // 16)
        self.decoder2_1_1 = myConv(dimList[3] // 8, dimList[3] // 8, kSize, stride=1, padding=kSize // 2, bias=False,
                                   norm=norm, act=act, num_groups=(dimList[3] // 8) // 16)
        self.decoder2_1_2 = myConv(dimList[3] // 8, dimList[3] // 16, kSize, stride=1, padding=kSize // 2, bias=False,
                                   norm=norm, act=act, num_groups=(dimList[3] // 8) // 16)
        self.decoder2_1_3 = myConv(dimList[3] // 16, 1, kSize, stride=1, padding=kSize // 2, bias=False,
                                   norm=norm, act=act, num_groups=(dimList[3] // 16) // 16)

        # ---- Pyramid level 2 ----
        # decoder2 out3 : 1 x H/2 x W/2 (Level 2)
        # decoder2_1_1_up3 : (H/4,W/4)->(H/2,W/2)
        self.decoder2_1_1_up3 = upConvLayer(dimList[3] // 8, dimList[3] // 16, 2, norm, act, (dimList[3] // 8) // 16)
        self.decoder2_1_1_reduc3 = myConv(dimList[3] // 16 + dimList[0], dimList[3] // 16 - 4, kSize=1, stride=1,
                                          padding=0, bias=False,
                                          norm=norm, act=act, num_groups=(dimList[3] // 16 + dimList[0]) // 16)
        self.decoder2_1_1_1 = myConv(dimList[3] // 16, dimList[3] // 16, kSize, stride=1, padding=kSize // 2,
                                     bias=False,
                                     norm=norm, act=act, num_groups=(dimList[3] // 16) // 16)
        self.decoder2_1_1_2 = myConv(dimList[3] // 16, 1, kSize, stride=1, padding=kSize // 2, bias=False,
                                     norm=norm, act=act, num_groups=(dimList[3] // 16) // 16)

        # ---- Pyramid level 1 ----
        # decoder5 out : 1 x H x W (Level 1)
        # decoder2_1_1_1_up4 : (H/2,W/2)->(H,W); there is no encoder skip at full
        # resolution, so the upsampling conv itself leaves the 4 free channels.
        self.decoder2_1_1_1_up4 = upConvLayer(dimList[3] // 16, dimList[3] // 16 - 4, 2, norm, act,
                                              (dimList[3] // 16) // 16)
        self.decoder2_1_1_1_1 = myConv(dimList[3] // 16, dimList[3] // 16, kSize, stride=1, padding=kSize // 2,
                                       bias=False,
                                       norm=norm, act=act, num_groups=(dimList[3] // 16) // 16)
        self.decoder2_1_1_1_2 = myConv(dimList[3] // 16, dimList[3] // 32, kSize, stride=1, padding=kSize // 2,
                                       bias=False,
                                       norm=norm, act=act, num_groups=(dimList[3] // 16) // 16)
        self.decoder2_1_1_1_3 = myConv(dimList[3] // 32, 1, kSize, stride=1, padding=kSize // 2, bias=False,
                                       norm=norm, act=act, num_groups=(dimList[3] // 32) // 16)

        self.upscale = F.interpolate

    def forward(self, x, rgb):
        """Decode encoder features into per-level Laplacian residuals and a depth map.

        Args:
            x: Encoder feature list ``[cat1, cat2, cat3, dense_feat]``, i.e. the outputs
                the lv5 encoder collects at 'relu', 'layer1', 'layer2', 'layer3'.
            rgb: Six-element list ``[rgb_down32, lap5, lap4, lap3, lap2, lap1]`` as
                built by ``LDRN.forward``; the first two entries are unused here.

        Returns:
            ``(laplacians, depth)``. ``laplacians`` holds the level-5..1 predictions
            and ``depth`` is the reconstructed map; all are multiplied by ``max_depth``.
        """
        cat1, cat2, cat3, dense_feat = x[0], x[1], x[2], x[3]
        # rgb_lv6 and rgb_lv5 are unpacked but not used by the five-level decoder.
        rgb_lv6, rgb_lv5, rgb_lv4, rgb_lv3, rgb_lv2, rgb_lv1 = rgb[0], rgb[1], rgb[2], rgb[3], rgb[4], rgb[5]
        dense_feat = self.ASPP(dense_feat)  # Dense feature for lev 5
        # decoder 1 - Pyramid level 5: coarse map squashed to (0, 1)
        lap_lv5 = torch.sigmoid(self.decoder1(dense_feat))
        lap_lv5_up = self.upscale(lap_lv5, scale_factor=2, mode='bilinear')
        # decoder 2 - Pyramid level 4
        dec2 = self.decoder2_up1(dense_feat)
        dec2 = self.decoder2_reduc1(torch.cat([dec2, cat3], dim=1))  # fuse the encoder skip
        # Condition on the upsampled coarser prediction and this level's image Laplacian.
        dec2_up = self.decoder2_1(torch.cat([dec2, lap_lv5_up, rgb_lv4], dim=1))
        dec2 = self.decoder2_2(dec2_up)
        dec2 = self.decoder2_3(dec2)
        # A 0.1-weighted channel mean of the image Laplacian is added before tanh.
        lap_lv4 = torch.tanh(self.decoder2_4(dec2) + (0.1 * rgb_lv4.mean(dim=1, keepdim=True)))
        # if depth range is (0,1), laplacian of image range is (-1,1)
        lap_lv4_up = self.upscale(lap_lv4, scale_factor=2, mode='bilinear')
        # decoder 2 - Pyramid level 3 (continues from the level-4 features dec2_up)
        dec3 = self.decoder2_1_up2(dec2_up)
        dec3 = self.decoder2_1_reduc2(torch.cat([dec3, cat2], dim=1))
        dec3_up = self.decoder2_1_1(torch.cat([dec3, lap_lv4_up, rgb_lv3], dim=1))
        dec3 = self.decoder2_1_2(dec3_up)
        lap_lv3 = torch.tanh(self.decoder2_1_3(dec3) + (0.1 * rgb_lv3.mean(dim=1, keepdim=True)))
        # if depth range is (0,1), laplacian of image range is (-1,1)
        lap_lv3_up = self.upscale(lap_lv3, scale_factor=2, mode='bilinear')
        # decoder 2 - Pyramid level 2
        dec4 = self.decoder2_1_1_up3(dec3_up)
        dec4 = self.decoder2_1_1_reduc3(torch.cat([dec4, cat1], dim=1))
        dec4_up = self.decoder2_1_1_1(torch.cat([dec4, lap_lv3_up, rgb_lv2], dim=1))
        lap_lv2 = torch.tanh(self.decoder2_1_1_2(dec4_up) + (0.1 * rgb_lv2.mean(dim=1, keepdim=True)))
        # if depth range is (0,1), laplacian of image range is (-1,1)
        lap_lv2_up = self.upscale(lap_lv2, scale_factor=2, mode='bilinear')
        # decoder 2 - Pyramid level 1 (no encoder skip at full resolution)
        dec5 = self.decoder2_1_1_1_up4(dec4_up)
        dec5 = self.decoder2_1_1_1_1(torch.cat([dec5, lap_lv2_up, rgb_lv1], dim=1))
        dec5 = self.decoder2_1_1_1_2(dec5)
        lap_lv1 = torch.tanh(self.decoder2_1_1_1_3(dec5) + (0.1 * rgb_lv1.mean(dim=1, keepdim=True)))
        # if depth range is (0,1), laplacian of image range is (-1,1)
        # Laplacian restoration: add each residual to the 2x-upsampled reconstruction below it.
        lap_lv4_img = lap_lv4 + lap_lv5_up
        lap_lv3_img = lap_lv3 + self.upscale(lap_lv4_img, scale_factor=2, mode='bilinear')
        lap_lv2_img = lap_lv2 + self.upscale(lap_lv3_img, scale_factor=2, mode='bilinear')
        final_depth = lap_lv1 + self.upscale(lap_lv2_img, scale_factor=2, mode='bilinear')
        final_depth = torch.sigmoid(final_depth)
        return [(lap_lv5) * self.max_depth, (lap_lv4) * self.max_depth, (lap_lv3) * self.max_depth,
                (lap_lv2) * self.max_depth, (lap_lv1) * self.max_depth], final_depth * self.max_depth
        # With max_depth=80: Laplacian outputs lie within (-80, 80) and depth within (0, 80).
class Lap_decoder_lv6(nn.Module):
    """Six-level Laplacian-pyramid depth decoder (used by LDRN when ``lv6`` is True).

    The ASPP-refined deepest encoder feature is decoded into a coarse map at pyramid
    level 6 (sigmoid, range (0, 1)). At each finer level, the decoder upsamples its
    features 2x, fuses them with an encoder skip feature (none at full resolution),
    and concatenates the upsampled coarser prediction plus that level's image
    Laplacian. It then predicts a tanh residual in (-1, 1). Residuals are summed
    coarse-to-fine, passed through a sigmoid and scaled by ``max_depth``.

    Args:
        dimList: Encoder channel counts; indices 0-4 are used (LDRN passes
            ``[64, 256, 512, 1024, 2048]``). A private copy is made, and its
            index 4 is halved; the caller's list is left untouched.
        norm: ``'GN'`` for GroupNorm; anything else uses BatchNorm (forwarded to the blocks).
        rank: Process rank; the norm choice is printed only when ``rank == 0``.
        act: ``'ELU'`` or ``'Mish'``; any other value is treated as ``'ReLU'``.
        max_depth: Scale applied to every returned map.
    """

    def __init__(self, dimList, norm="BN", rank=0, act='ReLU', max_depth=80):
        super(Lap_decoder_lv6, self).__init__()
        if rank == 0:
            print("==> Norm: GN" if norm == 'GN' else "==> Norm: BN")
        # Canonicalise the activation name; unrecognised names become 'ReLU'.
        if act not in ('ELU', 'Mish'):
            act = 'ReLU'
        kSize = 3
        self.max_depth = max_depth
        # Work on a copy: index 4 is halved below, and that must not leak back into
        # the caller's list (LDRN passes its encoder's dimList).
        dimList = list(dimList)
        # The ASPP is built from the full (un-halved) deepest width.
        self.ASPP = Dilated_bottleNeck_lv6(norm, act, dimList[4])
        dimList[4] = dimList[4] // 2
        self.dimList = dimList
        # NOTE(review): Dilated_bottleNeck_lv6 outputs (its in_feat // 2) channels, i.e.
        # 1024 for LDRN's 2048-channel layer4. decoder1 and decoder2_up1 below are built
        # for dimList[4] // 2 (512 after the halving above). Confirm the lv6 path runs
        # with these widths.
        # NOTE: attribute names and creation order below define the state_dict keys
        # and the order in which conv weights are randomly initialised.

        # ---- Pyramid level 6 ----
        # decoder1 out : 1 x H/32 x W/32 (Level 6)
        self.decoder1 = nn.Sequential(
            myConv(dimList[4] // 2, dimList[4] // 4, kSize, stride=1, padding=kSize // 2, bias=False,
                   norm=norm, act=act, num_groups=(dimList[4] // 2) // 16),
            myConv(dimList[4] // 4, dimList[4] // 8, kSize, stride=1, padding=kSize // 2, bias=False,
                   norm=norm, act=act, num_groups=(dimList[4] // 4) // 16),
            myConv(dimList[4] // 8, dimList[4] // 16, kSize, stride=1, padding=kSize // 2, bias=False,
                   norm=norm, act=act, num_groups=(dimList[4] // 8) // 16),
            myConv(dimList[4] // 16, dimList[4] // 32, kSize, stride=1, padding=kSize // 2, bias=False,
                   norm=norm, act=act, num_groups=(dimList[4] // 16) // 16),
            myConv(dimList[4] // 32, 1, kSize, stride=1, padding=kSize // 2, bias=False,
                   norm=norm, act=act, num_groups=(dimList[4] // 32) // 8)
        )

        # ---- Pyramid level 5 ----
        # decoder2 out : 1 x H/16 x W/16 (Level 5)
        # decoder2_up : (H/32,W/32)->(H/16,W/16)
        self.decoder2_up1 = upConvLayer(dimList[4] // 2, dimList[4] // 4, 2, norm, act, (dimList[4] // 2) // 16)
        # 1x1 fusion with the encoder skip; "- 4" leaves room for the 1-channel upsampled
        # coarser prediction plus the (assumed 3-channel) image Laplacian concatenated in forward.
        self.decoder2_reduc1 = myConv(dimList[4] // 4 + dimList[3], dimList[4] // 4 - 4, kSize=1, stride=1, padding=0,
                                      bias=False,
                                      norm=norm, act=act, num_groups=(dimList[4] // 4 + dimList[3]) // 16)
        self.decoder2_1 = myConv(dimList[4] // 4, dimList[4] // 4, kSize, stride=1, padding=kSize // 2, bias=False,
                                 norm=norm, act=act, num_groups=(dimList[4] // 4) // 16)
        self.decoder2_2 = myConv(dimList[4] // 4, dimList[4] // 8, kSize, stride=1, padding=kSize // 2, bias=False,
                                 norm=norm, act=act, num_groups=(dimList[4] // 4) // 16)
        self.decoder2_3 = myConv(dimList[4] // 8, dimList[4] // 16, kSize, stride=1, padding=kSize // 2, bias=False,
                                 norm=norm, act=act, num_groups=(dimList[4] // 8) // 16)
        self.decoder2_4 = myConv(dimList[4] // 16, 1, kSize, stride=1, padding=kSize // 2, bias=False,
                                 norm=norm, act=act, num_groups=(dimList[4] // 16) // 16)

        # ---- Pyramid level 4 ----
        # decoder2 out2 : 1 x H/8 x W/8 (Level 4)
        # decoder2_1_up2 : (H/16,W/16)->(H/8,W/8)
        self.decoder2_1_up2 = upConvLayer(dimList[4] // 4, dimList[4] // 8, 2, norm, act, (dimList[4] // 4) // 16)
        self.decoder2_1_reduc2 = myConv(dimList[4] // 8 + dimList[2], dimList[4] // 8 - 4, kSize=1, stride=1, padding=0,
                                        bias=False,
                                        norm=norm, act=act, num_groups=(dimList[4] // 8 + dimList[2]) // 16)
        self.decoder2_1_1 = myConv(dimList[4] // 8, dimList[4] // 8, kSize, stride=1, padding=kSize // 2, bias=False,
                                   norm=norm, act=act, num_groups=(dimList[4] // 8) // 16)
        self.decoder2_1_2 = myConv(dimList[4] // 8, dimList[4] // 16, kSize, stride=1, padding=kSize // 2, bias=False,
                                   norm=norm, act=act, num_groups=(dimList[4] // 8) // 16)
        self.decoder2_1_3 = myConv(dimList[4] // 16, 1, kSize, stride=1, padding=kSize // 2, bias=False,
                                   norm=norm, act=act, num_groups=(dimList[4] // 16) // 16)

        # ---- Pyramid level 3 ----
        # decoder2 out3 : 1 x H/4 x W/4 (Level 3)
        # decoder2_1_1_up3 : (H/8,W/8)->(H/4,W/4)
        self.decoder2_1_1_up3 = upConvLayer(dimList[4] // 8, dimList[4] // 16, 2, norm, act, (dimList[4] // 8) // 16)
        self.decoder2_1_1_reduc3 = myConv(dimList[4] // 16 + dimList[1], dimList[4] // 16 - 4, kSize=1, stride=1,
                                          padding=0, bias=False,
                                          norm=norm, act=act, num_groups=(dimList[4] // 16 + dimList[1]) // 8)
        self.decoder2_1_1_1 = myConv(dimList[4] // 16, dimList[4] // 16, kSize, stride=1, padding=kSize // 2,
                                     bias=False,
                                     norm=norm, act=act, num_groups=(dimList[4] // 16) // 16)
        self.decoder2_1_1_2 = myConv(dimList[4] // 16, 1, kSize, stride=1, padding=kSize // 2, bias=False,
                                     norm=norm, act=act, num_groups=(dimList[4] // 16) // 16)

        # ---- Pyramid level 2 ----
        # decoder2 out4 : 1 x H/2 x W/2 (Level 2)
        # decoder2_1_1_1_up4 : (H/4,W/4)->(H/2,W/2)
        self.decoder2_1_1_1_up4 = upConvLayer(dimList[4] // 16, dimList[4] // 32, 2, norm, act,
                                              (dimList[4] // 16) // 16)
        self.decoder2_1_1_1_reduc4 = myConv(dimList[4] // 32 + dimList[0], dimList[4] // 32 - 4, kSize=1, stride=1,
                                            padding=0, bias=False,
                                            norm=norm, act=act, num_groups=(dimList[4] // 32 + dimList[0]) // 8)
        self.decoder2_1_1_1_1 = myConv(dimList[4] // 32, dimList[4] // 32, kSize, stride=1, padding=kSize // 2,
                                       bias=False,
                                       norm=norm, act=act, num_groups=(dimList[4] // 32) // 8)
        self.decoder2_1_1_1_2 = myConv(dimList[4] // 32, 1, kSize, stride=1, padding=kSize // 2, bias=False,
                                       norm=norm, act=act, num_groups=(dimList[4] // 32) // 8)

        # ---- Pyramid level 1 ----
        # decoder5 out : 1 x H x W (Level 1)
        # decoder2_1_1_1_1_up5 : (H/2,W/2)->(H,W); no encoder skip at full resolution,
        # so the upsampling conv leaves the 4 free channels itself
        # (dimList[4] // 32 -> dimList[4] // 32 - 4 channels).
        self.decoder2_1_1_1_1_up5 = upConvLayer(dimList[4] // 32, dimList[4] // 32 - 4, 2, norm, act,
                                                (dimList[4] // 32) // 8)
        self.decoder2_1_1_1_1_1 = myConv(dimList[4] // 32, dimList[4] // 32, kSize, stride=1, padding=kSize // 2,
                                         bias=False,
                                         norm=norm, act=act, num_groups=(dimList[4] // 32) // 8)
        self.decoder2_1_1_1_1_2 = myConv(dimList[4] // 32, dimList[4] // 64, kSize, stride=1, padding=kSize // 2,
                                         bias=False,
                                         norm=norm, act=act, num_groups=(dimList[4] // 32) // 8)
        self.decoder2_1_1_1_1_3 = myConv(dimList[4] // 64, 1, kSize, stride=1, padding=kSize // 2, bias=False,
                                         norm=norm, act=act, num_groups=(dimList[4] // 64) // 4)

        self.upscale = F.interpolate

    def forward(self, x, rgb):
        """Decode encoder features into per-level Laplacian residuals and a depth map.

        Args:
            x: Encoder feature list ``[cat1, cat2, cat3, cat4, dense_feat]``, i.e. the
                outputs the lv6 encoder collects at 'relu', 'layer1'..'layer4'.
            rgb: Six-element list ``[rgb_down32, lap5, lap4, lap3, lap2, lap1]`` as
                built by ``LDRN.forward``; the first entry is unused here.

        Returns:
            ``(laplacians, depth)``. ``laplacians`` holds the level-6..1 predictions
            and ``depth`` is the reconstructed map; all are multiplied by ``max_depth``.
        """
        cat1, cat2, cat3, cat4, dense_feat = x[0], x[1], x[2], x[3], x[4]
        # rgb_lv6 is unpacked but not used.
        rgb_lv6, rgb_lv5, rgb_lv4, rgb_lv3, rgb_lv2, rgb_lv1 = rgb[0], rgb[1], rgb[2], rgb[3], rgb[4], rgb[5]
        dense_feat = self.ASPP(dense_feat)  # Dense feature for lev 6
        # decoder 1 - Pyramid level 6: coarse map squashed to (0, 1)
        lap_lv6 = torch.sigmoid(self.decoder1(dense_feat))
        lap_lv6_up = self.upscale(lap_lv6, scale_factor=2, mode='bilinear')
        # decoder 2 - Pyramid level 5
        dec2 = self.decoder2_up1(dense_feat)
        dec2 = self.decoder2_reduc1(torch.cat([dec2, cat4], dim=1))
        dec2_up = self.decoder2_1(torch.cat([dec2, lap_lv6_up, rgb_lv5], dim=1))
        dec2 = self.decoder2_2(dec2_up)
        dec2 = self.decoder2_3(dec2)
        # A 0.1-weighted channel mean of the image Laplacian is added before tanh.
        lap_lv5 = torch.tanh(self.decoder2_4(dec2) + (0.1 * rgb_lv5.mean(dim=1, keepdim=True)))
        # if depth range is (0,1), laplacian image range is (-1,1)
        lap_lv5_up = self.upscale(lap_lv5, scale_factor=2, mode='bilinear')
        # decoder 2 - Pyramid level 4
        dec3 = self.decoder2_1_up2(dec2_up)
        dec3 = self.decoder2_1_reduc2(torch.cat([dec3, cat3], dim=1))
        dec3_up = self.decoder2_1_1(torch.cat([dec3, lap_lv5_up, rgb_lv4], dim=1))
        dec3 = self.decoder2_1_2(dec3_up)
        lap_lv4 = torch.tanh(self.decoder2_1_3(dec3) + (0.1 * rgb_lv4.mean(dim=1, keepdim=True)))
        # if depth range is (0,1), laplacian image range is (-1,1)
        lap_lv4_up = self.upscale(lap_lv4, scale_factor=2, mode='bilinear')
        # decoder 2 - Pyramid level 3
        dec4 = self.decoder2_1_1_up3(dec3_up)
        dec4 = self.decoder2_1_1_reduc3(torch.cat([dec4, cat2], dim=1))
        dec4_up = self.decoder2_1_1_1(torch.cat([dec4, lap_lv4_up, rgb_lv3], dim=1))
        lap_lv3 = torch.tanh(self.decoder2_1_1_2(dec4_up) + (0.1 * rgb_lv3.mean(dim=1, keepdim=True)))
        # if depth range is (0,1), laplacian image range is (-1,1)
        lap_lv3_up = self.upscale(lap_lv3, scale_factor=2, mode='bilinear')
        # decoder 2 - Pyramid level 2
        dec5 = self.decoder2_1_1_1_up4(dec4_up)
        dec5 = self.decoder2_1_1_1_reduc4(torch.cat([dec5, cat1], dim=1))
        dec5_up = self.decoder2_1_1_1_1(torch.cat([dec5, lap_lv3_up, rgb_lv2], dim=1))
        lap_lv2 = torch.tanh(self.decoder2_1_1_1_2(dec5_up) + (0.1 * rgb_lv2.mean(dim=1, keepdim=True)))
        # if depth range is (0,1), laplacian image range is (-1,1)
        lap_lv2_up = self.upscale(lap_lv2, scale_factor=2, mode='bilinear')
        # decoder 2 - Pyramid level 1 (no encoder skip at full resolution)
        dec6 = self.decoder2_1_1_1_1_up5(dec5_up)
        dec6 = self.decoder2_1_1_1_1_1(torch.cat([dec6, lap_lv2_up, rgb_lv1], dim=1))
        dec6 = self.decoder2_1_1_1_1_2(dec6)
        lap_lv1 = torch.tanh(self.decoder2_1_1_1_1_3(dec6) + (0.1 * rgb_lv1.mean(dim=1, keepdim=True)))
        # if depth range is (0,1), laplacian image range is (-1,1)
        # Laplacian restoration: add each residual to the 2x-upsampled reconstruction below it.
        lap_lv5_img = lap_lv5 + lap_lv6_up
        lap_lv4_img = lap_lv4 + self.upscale(lap_lv5_img, scale_factor=2, mode='bilinear')
        lap_lv3_img = lap_lv3 + self.upscale(lap_lv4_img, scale_factor=2, mode='bilinear')
        lap_lv2_img = lap_lv2 + self.upscale(lap_lv3_img, scale_factor=2, mode='bilinear')
        final_depth = lap_lv1 + self.upscale(lap_lv2_img, scale_factor=2, mode='bilinear')
        final_depth = torch.sigmoid(final_depth)
        # With max_depth=80: Laplacian outputs lie within (-80, 80) and depth within (0, 80).
        return [(lap_lv6) * self.max_depth, (lap_lv5) * self.max_depth, (lap_lv4) * self.max_depth,
                (lap_lv3) * self.max_depth, (lap_lv2) * self.max_depth,
                (lap_lv1) * self.max_depth], final_depth * self.max_depth
# Laplacian Depth Residual Network
class LDRN(nn.Module):
    """Laplacian Depth Residual Network: ResNeXt-101 encoder + Laplacian-pyramid decoder.

    Args:
        lv6: If True, keep all five encoder stages and use the six-level decoder
            (``Lap_decoder_lv6``). Otherwise drop ``layer4`` and use ``Lap_decoder_lv5``.
        norm: Decoder normalisation (``'GN'`` for GroupNorm, BatchNorm otherwise).
        rank: Process rank; only rank 0 prints the normalisation choice.
        act: Decoder activation name (``'ELU'``, ``'Mish'``, anything else -> ReLU).
        max_depth: Scale applied to the decoder's outputs.
    """

    def __init__(self, lv6=False, norm="BN", rank=0, act='ReLU', max_depth=80):
        super(LDRN, self).__init__()
        self.encoder = deepFeatureExtractor_ResNext101(lv6)
        if lv6 is True:
            self.decoder = Lap_decoder_lv6(self.encoder.dimList, norm, rank, act, max_depth)
        else:
            self.decoder = Lap_decoder_lv5(self.encoder.dimList, norm, rank, act, max_depth)

    def forward(self, x):
        """Predict depth from an image batch.

        Args:
            x: Image batch ``(N, C, H, W)``. It is presumably RGB and sized so that
                the decoder's 2x upsampling steps line up with these pyramid levels
                (e.g. H, W divisible by 32). TODO confirm.

        Returns:
            ``(d_res_list, depth)`` exactly as returned by the decoder.
        """
        out_featList = self.encoder(x)
        # Image pyramid by repeated 2x bilinear downsampling: full, 1/2, 1/4, ..., 1/32.
        pyramid = [x]
        for _ in range(5):
            pyramid.append(F.interpolate(pyramid[-1], scale_factor=0.5, mode='bilinear'))
        # Laplacian level = level minus the next coarser level resized back to its size.
        laplacians = [fine - F.interpolate(coarse, size=fine.shape[2:], mode='bilinear')
                      for fine, coarse in zip(pyramid[:-1], pyramid[1:])]
        # Decoder expects coarsest first: [rgb_down32, lap5, lap4, lap3, lap2, lap1].
        rgb_list = [pyramid[-1]] + laplacians[::-1]
        d_res_list, depth = self.decoder(out_featList, rgb_list)
        return d_res_list, depth

    def train(self, mode=True):
        """Set train/eval mode while keeping the encoder's BatchNorm frozen.

        ``encoder.freeze_bn()`` runs on every call, so the encoder's BN layers stay
        in eval mode even when ``mode=True``. Returns ``self``, as ``nn.Module.train``
        does, so ``model.train()`` / ``model.eval()`` can be chained or assigned.
        """
        super().train(mode)
        self.encoder.freeze_bn()
        return self