Tej3 committed
Commit 54d726d · 1 Parent(s): 1f8842a

Adding Application, models and ckpt files

app.py ADDED
@@ -0,0 +1,116 @@
+ import gradio as gr
+ import torch
+ from models.pretrained_decv2 import enc_dec_model
+ from models.densenet_v2 import Densenet
+ from models.unet_resnet18 import ResNet18UNet
+ from models.unet_resnet50 import UNetWithResnet50Encoder
+ import numpy as np
+ import cv2
+
+ # kb (KITTI benchmark) cropping: take a 352x1216 window anchored to the
+ # bottom edge and centered horizontally
+ def cropping(img):
+     h_im, w_im = img.shape[:2]
+
+     margin_top = int(h_im - 352)
+     margin_left = int((w_im - 1216) / 2)
+
+     img = img[margin_top: margin_top + 352,
+               margin_left: margin_left + 1216]
+
+     return img
+
+ DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
+ print(DEVICE)
+ CWD = "."
+ CKPT_FILE_NAMES = {
+     'Indoor': {
+         'Resnet_enc': 'resnet_nyu_best.ckpt',
+         'Unet': 'resnet18_unet_epoch_08_model_kitti_and_nyu.ckpt',
+         'Densenet_enc': 'densenet_epoch_15_model.ckpt'
+     },
+     'Outdoor': {
+         'Resnet_enc': 'resnet_encdecmodel_epoch_05_model_nyu_and_kitti.ckpt',
+         'Unet': 'resnet50_unet_epoch_02_model_nyuandkitti.ckpt',
+         'Densenet_enc': 'densenet_nyu_then_kitti_epoch_10_model.ckpt'
+     }
+ }
+ MODEL_CLASSES = {
+     'Indoor': {
+         'Resnet_enc': enc_dec_model,
+         'Unet': ResNet18UNet,
+         'Densenet_enc': Densenet
+     },
+     'Outdoor': {
+         'Resnet_enc': enc_dec_model,
+         'Unet': UNetWithResnet50Encoder,
+         'Densenet_enc': Densenet
+     },
+ }
+
+ def load_model(ckpt, model, optimizer=None):
+     ckpt_dict = torch.load(ckpt, map_location='cpu')
+     # keep backward compatibility with checkpoints that store a bare state_dict
+     if 'model' not in ckpt_dict and 'optimizer' not in ckpt_dict:
+         state_dict = ckpt_dict
+     else:
+         state_dict = ckpt_dict['model']
+     # strip the 'module.' prefix left behind by nn.DataParallel
+     weights = {}
+     for key, value in state_dict.items():
+         if key.startswith('module.'):
+             weights[key[len('module.'):]] = value
+         else:
+             weights[key] = value
+
+     model.load_state_dict(weights)
+
+     if optimizer is not None:
+         optimizer_state = ckpt_dict['optimizer']
+         optimizer.load_state_dict(optimizer_state)
+
+
+ def predict(location, model_name, img):
+     ckpt_dir = f"{CWD}/ckpt/{CKPT_FILE_NAMES[location][model_name]}"
+     # NYU (indoor) scenes are capped at 10 m, KITTI (outdoor) at 80 m
+     if location == 'Indoor':
+         max_depth = 10
+     else:
+         max_depth = 80
+     model = MODEL_CLASSES[location][model_name](max_depth).to(DEVICE)
+     load_model(ckpt_dir, model)
+     model.eval()  # inference mode so BatchNorm uses its running statistics
+     # full-size KITTI frames arrive as 375x1242; crop them to the benchmark size
+     if img.shape == (375, 1242, 3):
+         img = cropping(img)
+     img = torch.tensor(img).permute(2, 0, 1).float().to(DEVICE)
+     input_RGB = img.unsqueeze(0)
+     print(input_RGB.shape)
+     with torch.no_grad():
+         pred = model(input_RGB)
+     pred_d = pred['pred_d']
+     pred_d_numpy = pred_d.squeeze().cpu().numpy()
+     # scale to 0-255, taking the max outside the top 15 rows (often sky/outliers)
+     pred_d_numpy = np.clip((pred_d_numpy / pred_d_numpy[15:, :].max()) * 255, 0, 255)
+     pred_d_numpy = pred_d_numpy.astype(np.uint8)
+     pred_d_color = cv2.applyColorMap(pred_d_numpy, cv2.COLORMAP_RAINBOW)
+     pred_d_color = cv2.cvtColor(pred_d_color, cv2.COLOR_BGR2RGB)
+     return pred_d_color
+
+ with gr.Blocks() as demo:
+     gr.Markdown("# Monocular Depth Estimation")
+     with gr.Row():
+         location = gr.Radio(choices=['Indoor', 'Outdoor'], value='Indoor', label="Select Location Type")
+         model_name = gr.Radio(['Unet', 'Resnet_enc', 'Densenet_enc'], value="Densenet_enc", label="Select model")
+     with gr.Row():
+         with gr.Column():
+             input_image = gr.Image(label="Input Image for Depth Estimation")
+         with gr.Column():
+             output_depth_map = gr.Image(label="Depth prediction Heatmap")
+     with gr.Row():
+         predict_btn = gr.Button("Generate Depthmap")
+         predict_btn.click(fn=predict, inputs=[location, model_name, input_image], outputs=output_depth_map)
+     with gr.Row():
+         gr.Examples(['./demo_data/Bathroom.jpg', './demo_data/Bedroom.jpg', './demo_data/Bookstore.jpg', './demo_data/Classroom.jpg', './demo_data/Computerlab.jpg', './demo_data/kitti_1.png'], inputs=input_image)
+ demo.launch()
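
For quick testing outside the UI, the same inference path can be replayed against the committed modules. A minimal headless sketch (illustrative, not part of the commit; importing app.py directly would call demo.launch(), so the model is built and loaded by hand here, mirroring load_model's 'module.'-stripping):

```python
import cv2
import torch
from models.densenet_v2 import Densenet

# Load a demo image as RGB, the same layout Gradio hands to predict().
img = cv2.cvtColor(cv2.imread('./demo_data/Bedroom.jpg'), cv2.COLOR_BGR2RGB)

model = Densenet(max_depth=10)  # indoor / NYU depth cap of 10 m
ckpt = torch.load('./ckpt/densenet_epoch_15_model.ckpt', map_location='cpu')
state = ckpt['model'] if 'model' in ckpt else ckpt
model.load_state_dict({k.removeprefix('module.'): v for k, v in state.items()})
model.eval()

x = torch.tensor(img).permute(2, 0, 1).float().unsqueeze(0)  # (1, 3, H, W)
with torch.no_grad():
    depth = model(x)['pred_d']  # (1, 1, H, W), metres in [0, 10]
```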
ckpt/densenet_epoch_15_model.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1e50a3c6bb7a24e3ece8f323dc759e0822145e81e2770dad0d00e12ac306c37c
+ size 1748720589
ckpt/densenet_nyu_then_kitti_epoch_10_model.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3bd5ad153cd4363c061d6ca666899b1fe8a2c425fc26423081907cc144d204f5
+ size 1748720589
ckpt/nyudepthv2_swin_base.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a1c748dc3e0add9ee18b43dcfa1f2c8d5734d3e523ab7872398f203d5d36b605
+ size 493044547
ckpt/resnet18_unet_epoch_08_model_kitti_and_nyu.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ae171c2ab1a22570395a284eca5ecd392b653ab4bbf0f9b5edd6a9dbdbd8d2fc
+ size 215834813
ckpt/resnet50_unet_epoch_02_model_nyuandkitti.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9a26301214906364877b1b71bc1dca3e40b4013948a36b8ea1eb8e99cb56ce49
+ size 1774319297
ckpt/resnet_encdecmodel_epoch_05_model_nyu_and_kitti.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:fa4f12fc178424a4205f241fa9081a8769fe10c36bcc7839dda53bacaa3676d1
+ size 174548419
ckpt/resnet_nyu_best.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7fa8bbc121457976cb1c4d2b1396f517befe46a1e578afe53fa9c6ce920ffe48
+ size 210256970
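
The .ckpt entries above are Git LFS pointer files, not the weights themselves: each records the pointer spec version, a sha256 oid, and the payload size in bytes. A small illustrative parser for that three-line format:

```python
# Parse a Git LFS pointer file into its key/value fields.
def parse_lfs_pointer(path):
    fields = {}
    with open(path) as f:
        for line in f:
            key, _, value = line.strip().partition(' ')
            fields[key] = value
    return fields

ptr = parse_lfs_pointer('ckpt/resnet_nyu_best.ckpt')
print(ptr['oid'])                    # sha256:7fa8bbc1...
print(int(ptr['size']) / 1e6, 'MB')  # roughly 210 MB
```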
demo_data/Bathroom.jpg ADDED
demo_data/Bedroom.jpg ADDED
demo_data/Bookstore.jpg ADDED
demo_data/Classroom.jpg ADDED
demo_data/Computerlab.jpg ADDED
demo_data/kitti_1.png ADDED
demo_data/kitti_2.png ADDED
demo_data/kitti_3.png ADDED
models/densenet_v2.py ADDED
@@ -0,0 +1,179 @@
+ import torch
+ import torch.nn as nn
+ import torchvision
+ from torchinfo import summary
+
+ class attention_gate(nn.Module):
+     # NOTE: not used by the Densenet model below; kept for reference
+     def __init__(self, in_c, out_c):
+         super().__init__()
+
+         self.Wg = nn.Sequential(
+             nn.Conv2d(in_c[0], out_c, kernel_size=1, padding=0),
+             nn.BatchNorm2d(out_c)
+         )
+         self.Ws = nn.Sequential(
+             nn.Conv2d(in_c[1], out_c, kernel_size=1, padding=0),
+             nn.BatchNorm2d(out_c)
+         )
+         self.relu = nn.ReLU(inplace=True)
+         self.output = nn.Sequential(
+             nn.Conv2d(out_c, out_c, kernel_size=1, padding=0),
+             nn.Sigmoid()
+         )
+
+     def forward(self, g, s):
+         Wg = self.Wg(g)
+         Ws = self.Ws(s)
+         out = self.relu(Wg + Ws)
+         out = self.output(out)
+         return out
+
+ class Conv_Block(nn.Module):
+     def __init__(self, in_c, out_c, activation_fn=nn.LeakyReLU):
+         super().__init__()
+
+         self.conv1 = nn.Conv2d(in_c, out_c, kernel_size=3, padding=1)
+         self.bn1 = nn.BatchNorm2d(out_c)
+
+         self.conv2 = nn.Conv2d(out_c, out_c, kernel_size=3, padding=1)
+         self.bn2 = nn.BatchNorm2d(out_c)
+
+         self.activfn = activation_fn()
+
+         self.dropout = nn.Dropout(0.25)
+
+     def forward(self, inputs):
+         x = self.conv1(inputs)
+         x = self.bn1(x)
+         x = self.activfn(x)
+         # x = self.dropout(x)
+
+         x = self.conv2(x)
+         x = self.bn2(x)
+         x = self.activfn(x)
+         # x = self.dropout(x)
+
+         return x
+
+ class Encoder_Block(nn.Module):
+     def __init__(self, in_c, out_c):
+         super().__init__()
+
+         self.conv = Conv_Block(in_c, out_c)
+         self.pool = nn.MaxPool2d((2, 2))
+
+     def forward(self, inputs):
+         x = self.conv(inputs)
+         p = self.pool(x)
+
+         return x, p
+
+ class Enc_Dec_Model(nn.Module):
+     # NOTE: not used by the demo (app.py imports Densenet below); as written it
+     # expects a skip-connection Decoder_Block variant, so it is kept only for reference.
+     def __init__(self):
+         super(Enc_Dec_Model, self).__init__()
+         self.encoder1 = Encoder_Block(3, 64)
+         self.encoder2 = Encoder_Block(64, 128)
+         self.encoder3 = Encoder_Block(128, 256)
+         # Bottleneck
+         self.bottleneck = Conv_Block(256, 512)
+
+         # Decoder
+         self.d1 = Decoder_Block([512, 256], 256)
+         self.d2 = Decoder_Block([256, 128], 128)
+         self.d3 = Decoder_Block([128, 64], 64)
+
+         # Classifier
+         self.outputs = nn.Conv2d(64, 1, kernel_size=1, padding=0)
+
+     def forward(self, x):
+         # Encoder
+         s1, p1 = self.encoder1(x)
+         s2, p2 = self.encoder2(p1)
+         s3, p3 = self.encoder3(p2)
+
+         # Bottleneck
+         b = self.bottleneck(p3)
+
+         # Decoder
+         d1 = self.d1(b, s3)
+         d2 = self.d2(d1, s2)
+         d3 = self.d3(d2, s1)
+
+         # Classifier
+         outputs = self.outputs(d3)
+         out_depth = torch.sigmoid(outputs)
+         return out_depth
+
+ class Decoder(nn.Module):
+     def __init__(self):
+         super(Decoder, self).__init__()
+
+         # Decoder: five 2x upsampling stages from the 1920-channel DenseNet201 feature map
+         self.d1 = Decoder_Block(1920, 2048)
+         self.d2 = Decoder_Block(2048, 1024)
+         self.d3 = Decoder_Block(1024, 512)
+         self.d4 = Decoder_Block(512, 256)
+         self.d5 = Decoder_Block(256, 128)
+
+         # Classifier
+         self.outputs = nn.Conv2d(128, 1, kernel_size=1, padding=0)
+
+     def forward(self, x):
+         x = self.d1(x)
+         x = self.d2(x)
+         x = self.d3(x)
+         x = self.d4(x)
+         x = self.d5(x)
+
+         # Classifier
+         outputs = self.outputs(x)
+         out_depth = torch.sigmoid(outputs)
+         return out_depth
+
+ class Decoder_Block(nn.Module):
+     def __init__(self, in_c, out_c, activation_fn=nn.LeakyReLU):
+         super().__init__()
+
+         # kernel_size=2, stride=2 doubles the spatial size exactly
+         self.up = nn.ConvTranspose2d(in_c, out_c, kernel_size=2, stride=2, padding=0)
+         self.conv = Conv_Block(out_c, out_c, activation_fn)
+
+     def forward(self, inputs):
+         x = self.up(inputs)
+         x = self.conv(x)
+
+         return x
+
+
+ class Densenet(nn.Module):
+     def __init__(self, max_depth) -> None:
+         super().__init__()
+         self.densenet = torchvision.models.densenet201(weights=torchvision.models.DenseNet201_Weights.DEFAULT)
+         # freeze the pretrained feature extractor
+         for param in self.densenet.features.parameters():
+             param.requires_grad = False
+
+         # keep only the feature extractor, dropping the classification head
+         self.densenet = torch.nn.Sequential(*(list(self.densenet.children())[:-1]))
+         self.decoder = Decoder()
+         self.max_depth = max_depth
+
+     def forward(self, x):
+         x = self.densenet(x)
+         x = self.decoder(x)
+         # sigmoid output in [0, 1] scaled to metres
+         x = x * self.max_depth
+         return {'pred_d': x}
+
+ if __name__ == "__main__":
+     model = Densenet(max_depth=10).cuda()
+     print(model)
+     summary(model, input_size=(64, 3, 448, 448))
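
As a sanity check on the channel bookkeeping (illustrative; downloads the DenseNet201 weights): the truncated backbone maps a 448x448 input to a 1920-channel map at 1/32 resolution, and the five Decoder_Blocks each double the spatial size, so pred_d comes back at the input resolution:

```python
import torch
from models.densenet_v2 import Densenet

model = Densenet(max_depth=10).eval()
x = torch.zeros(1, 3, 448, 448)
with torch.no_grad():
    feats = model.densenet(x)
    print(feats.shape)        # torch.Size([1, 1920, 14, 14])
    pred = model(x)['pred_d']
    print(pred.shape)         # torch.Size([1, 1, 448, 448])
```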
models/pretrained_decv2.py ADDED
@@ -0,0 +1,126 @@
+ import torch
+ import torch.nn as nn
+ import torchvision
+ from torchinfo import summary
+
+ class conv_block(nn.Module):
+     def __init__(self, in_c, out_c, act):
+         super().__init__()
+
+         self.conv1 = nn.Conv2d(in_c, out_c, kernel_size=3, padding=1)
+         self.bn1 = nn.BatchNorm2d(out_c)
+
+         self.conv2 = nn.Conv2d(out_c, out_c, kernel_size=3, padding=1)
+         self.bn2 = nn.BatchNorm2d(out_c)
+
+         if act == 'relu':
+             self.activation = nn.ReLU()
+         elif act == 'sigmoid':
+             self.activation = nn.Sigmoid()
+         else:
+             self.activation = nn.Identity()
+
+     def forward(self, inputs):
+         x = self.conv1(inputs)
+         x = self.bn1(x)
+         x = self.activation(x)
+
+         x = self.conv2(x)
+         x = self.bn2(x)
+         x = self.activation(x)
+
+         return x
+
+ class Decoder_block(nn.Module):
+     def __init__(self, in_channel, out_channel, kernel, stride, padding=1, out_padding=1, act='relu') -> None:
+         super().__init__()
+         # with kernel=3, stride=2, padding=1, output_padding=1 this doubles H and W
+         self.upsample = nn.ConvTranspose2d(in_channels=in_channel,
+                                            out_channels=out_channel,
+                                            kernel_size=kernel,
+                                            stride=stride,
+                                            padding=padding,
+                                            output_padding=out_padding)
+         if act == 'relu':
+             self.activation = nn.ReLU()
+         elif act == 'sigmoid':
+             self.activation = nn.Sigmoid()
+         else:
+             self.activation = nn.Identity()
+
+     def forward(self, x):
+         return self.activation(self.upsample(x))
+
+ class Decoder(nn.Module):
+     def __init__(self, num_layers, channels, kernels, strides, activations) -> None:
+         super().__init__()
+         assert len(channels) - 1 == len(kernels) and len(strides) == len(kernels) and num_layers == len(strides)
+         assert num_layers == len(activations)
+         self.layers = []
+         for i in range(num_layers):
+             self.layers.append(Decoder_block(in_channel=channels[i],
+                                              out_channel=channels[i + 1],
+                                              kernel=kernels[i],
+                                              stride=strides[i],
+                                              act=activations[i]))
+             # refine each upsampled map with a conv block at the same scale
+             self.layers.append(conv_block(in_c=channels[i + 1], out_c=channels[i + 1], act=activations[i]))
+         self.model = nn.Sequential(*self.layers)
+
+     def forward(self, x):
+         return self.model(x)
+
+ class enc_dec_model(nn.Module):
+     def __init__(self, max_depth=10, backbone='resnet', unfreeze=False) -> None:
+         super().__init__()
+         if backbone == 'resnet':
+             self.encoder = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
+             num_layers = 5
+             channels = [2048, 256, 128, 64, 32, 1]
+             kernels = [3, 3, 3, 3, 3]
+             strides = [2, 2, 2, 2, 2]
+             activations = ['relu', 'relu', 'relu', 'relu', 'sigmoid']
+             if unfreeze:
+                 for param in self.encoder.parameters():
+                     param.requires_grad = True
+             else:
+                 # freeze the backbone except the last bottleneck of layer4
+                 for param in self.encoder.parameters():
+                     param.requires_grad = False
+                 for i, child in enumerate(self.encoder.children()):
+                     if i == 7:  # layer4
+                         for j, child2 in enumerate(child.children()):
+                             if j == 2:  # its final Bottleneck
+                                 for param in child2.parameters():
+                                     param.requires_grad = True
+                     if i >= 8:  # avgpool / fc, dropped below anyway
+                         for param in child.parameters():
+                             param.requires_grad = True
+             # drop avgpool and fc, keeping the convolutional feature extractor
+             self.encoder = torch.nn.Sequential(*(list(self.encoder.children())[:-2]))
+
+         self.decoder = Decoder(num_layers=num_layers,
+                                channels=channels,
+                                kernels=kernels,
+                                strides=strides,
+                                activations=activations)
+         self.max_depth = max_depth
+
+     def forward(self, x):
+         x = self.encoder(x)
+         x = self.decoder(x)
+         # sigmoid output in [0, 1] scaled to metres
+         x = x * self.max_depth
+         return {'pred_d': x}
+
+ if __name__ == "__main__":
+     model = enc_dec_model(unfreeze=True).cuda()
+     print(model)
+     summary(model, input_size=(64, 3, 448, 448))
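
Each Decoder_block doubles the spatial size: for nn.ConvTranspose2d, H_out = (H_in - 1) * stride - 2 * padding + kernel_size + output_padding, which with kernel=3, stride=2, padding=1, output_padding=1 reduces to 2 * H_in. Five such blocks carry the 14x14 map that ResNet-50 produces from a 448x448 input (448 / 32) back up to 448x448. A quick illustrative check:

```python
import torch
from models.pretrained_decv2 import Decoder_block

up = Decoder_block(in_channel=2048, out_channel=256, kernel=3, stride=2).eval()
x = torch.zeros(1, 2048, 14, 14)
print(up(x).shape)  # torch.Size([1, 256, 28, 28]): (14-1)*2 - 2 + 3 + 1 = 28
```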
models/unet_resnet18.py ADDED
@@ -0,0 +1,94 @@
+ import torch.nn as nn
+ from torchinfo import summary
+ import torchvision.models
+ import torch
+
+
+ def convrelu(in_channels, out_channels, kernel, padding):
+     return nn.Sequential(
+         nn.Conv2d(in_channels, out_channels, kernel, padding=padding),
+         nn.ReLU(inplace=True),
+     )
+
+
+ class ResNet18UNet(nn.Module):
+     def __init__(self, max_depth, n_class=1):
+         super().__init__()
+
+         self.base_model = torchvision.models.resnet18(pretrained=True)
+         self.base_layers = list(self.base_model.children())
+
+         self.layer0 = nn.Sequential(*self.base_layers[:3])  # size=(N, 64, x.H/2, x.W/2)
+         self.layer0_1x1 = convrelu(64, 64, 1, 0)
+         self.layer1 = nn.Sequential(*self.base_layers[3:5])  # size=(N, 64, x.H/4, x.W/4)
+         self.layer1_1x1 = convrelu(64, 64, 1, 0)
+         self.layer2 = self.base_layers[5]  # size=(N, 128, x.H/8, x.W/8)
+         self.layer2_1x1 = convrelu(128, 128, 1, 0)
+         self.layer3 = self.base_layers[6]  # size=(N, 256, x.H/16, x.W/16)
+         self.layer3_1x1 = convrelu(256, 256, 1, 0)
+         self.layer4 = self.base_layers[7]  # size=(N, 512, x.H/32, x.W/32)
+         self.layer4_1x1 = convrelu(512, 512, 1, 0)
+
+         self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
+
+         self.conv_up3 = convrelu(256 + 512, 512, 3, 1)
+         self.conv_up2 = convrelu(128 + 512, 256, 3, 1)
+         self.conv_up1 = convrelu(64 + 256, 256, 3, 1)
+         self.conv_up0 = convrelu(64 + 256, 128, 3, 1)
+
+         self.conv_original_size0 = convrelu(3, 64, 3, 1)
+         self.conv_original_size1 = convrelu(64, 64, 3, 1)
+         self.conv_original_size2 = convrelu(64 + 128, 64, 3, 1)
+
+         self.conv_last = nn.Conv2d(64, n_class, 1)
+
+         self.max_depth = max_depth
+
+     def forward(self, input):
+         # full-resolution branch that is concatenated back in at the end
+         x_original = self.conv_original_size0(input)
+         x_original = self.conv_original_size1(x_original)
+
+         layer0 = self.layer0(input)
+         layer1 = self.layer1(layer0)
+         layer2 = self.layer2(layer1)
+         layer3 = self.layer3(layer2)
+         layer4 = self.layer4(layer3)
+
+         # decoder: upsample, then concatenate the matching encoder skip
+         layer4 = self.layer4_1x1(layer4)
+         x = self.upsample(layer4)
+         layer3 = self.layer3_1x1(layer3)
+         x = torch.cat([x, layer3], dim=1)
+         x = self.conv_up3(x)
+
+         x = self.upsample(x)
+         layer2 = self.layer2_1x1(layer2)
+         x = torch.cat([x, layer2], dim=1)
+         x = self.conv_up2(x)
+
+         x = self.upsample(x)
+         layer1 = self.layer1_1x1(layer1)
+         x = torch.cat([x, layer1], dim=1)
+         x = self.conv_up1(x)
+
+         x = self.upsample(x)
+         layer0 = self.layer0_1x1(layer0)
+         x = torch.cat([x, layer0], dim=1)
+         x = self.conv_up0(x)
+
+         x = self.upsample(x)
+         x = torch.cat([x, x_original], dim=1)
+         x = self.conv_original_size2(x)
+
+         out = self.conv_last(x)
+
+         out_depth = torch.sigmoid(out) * self.max_depth
+
+         return {'pred_d': out_depth}
+
+ if __name__ == "__main__":
+     model = ResNet18UNet(max_depth=10).cuda()
+     summary(model, input_size=(1, 3, 256, 256))
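
Because each decoder stage concatenates the upsampled map with an encoder skip at the matching scale, the input height and width must be divisible by 32 (the encoder's total stride) or the torch.cat calls will see mismatched sizes; the kb cropping in app.py to 352x1216 satisfies this. An illustrative check (downloads the ResNet-18 weights):

```python
import torch
from models.unet_resnet18 import ResNet18UNet

model = ResNet18UNet(max_depth=10).eval()
with torch.no_grad():
    out = model(torch.zeros(1, 3, 224, 224))['pred_d']
print(out.shape)  # torch.Size([1, 1, 224, 224]); 224 = 7 * 32, so the skips align
```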
models/unet_resnet50.py ADDED
@@ -0,0 +1,150 @@
+ import torch
+ import torch.nn as nn
+ from torchinfo import summary
+ import torchvision
+
+
+ class ConvBlock(nn.Module):
+     """
+     Helper module that consists of a Conv -> BN -> ReLU
+     """
+
+     def __init__(self, in_channels, out_channels, padding=1, kernel_size=3, stride=1, with_nonlinearity=True):
+         super().__init__()
+         self.conv = nn.Conv2d(in_channels, out_channels, padding=padding, kernel_size=kernel_size, stride=stride)
+         self.bn = nn.BatchNorm2d(out_channels)
+         self.relu = nn.ReLU()
+         self.with_nonlinearity = with_nonlinearity
+
+     def forward(self, x):
+         x = self.conv(x)
+         x = self.bn(x)
+         if self.with_nonlinearity:
+             x = self.relu(x)
+         return x
+
+
+ class Bridge(nn.Module):
+     """
+     This is the middle layer of the UNet, which just consists of two ConvBlocks.
+     """
+
+     def __init__(self, in_channels, out_channels):
+         super().__init__()
+         self.bridge = nn.Sequential(
+             ConvBlock(in_channels, out_channels),
+             ConvBlock(out_channels, out_channels)
+         )
+
+     def forward(self, x):
+         return self.bridge(x)
+
+
+ class UpBlockForUNetWithResNet50(nn.Module):
+     """
+     Up block that encapsulates one up-sampling step which consists of Upsample -> ConvBlock -> ConvBlock
+     """
+
+     def __init__(self, in_channels, out_channels, up_conv_in_channels=None, up_conv_out_channels=None,
+                  upsampling_method="conv_transpose"):
+         super().__init__()
+
+         if up_conv_in_channels is None:
+             up_conv_in_channels = in_channels
+         if up_conv_out_channels is None:
+             up_conv_out_channels = out_channels
+
+         if upsampling_method == "conv_transpose":
+             self.upsample = nn.ConvTranspose2d(up_conv_in_channels, up_conv_out_channels, kernel_size=2, stride=2)
+         elif upsampling_method == "bilinear":
+             self.upsample = nn.Sequential(
+                 nn.Upsample(mode='bilinear', scale_factor=2),
+                 nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1)
+             )
+         self.conv_block_1 = ConvBlock(in_channels, out_channels)
+         self.conv_block_2 = ConvBlock(out_channels, out_channels)
+
+     def forward(self, up_x, down_x):
+         """
+         :param up_x: this is the output from the previous up block
+         :param down_x: this is the output from the down block
+         :return: upsampled feature map
+         """
+         x = self.upsample(up_x)
+         x = torch.cat([x, down_x], 1)
+         x = self.conv_block_1(x)
+         x = self.conv_block_2(x)
+         return x
+
+
+ class UNetWithResnet50Encoder(nn.Module):
+     DEPTH = 6
+
+     def __init__(self, max_depth, n_classes=1):
+         super().__init__()
+         resnet = torchvision.models.resnet.resnet50(pretrained=True)
+         down_blocks = []
+         up_blocks = []
+         self.input_block = nn.Sequential(*list(resnet.children()))[:3]
+         self.input_pool = list(resnet.children())[3]
+         for bottleneck in list(resnet.children()):
+             if isinstance(bottleneck, nn.Sequential):
+                 down_blocks.append(bottleneck)
+         self.down_blocks = nn.ModuleList(down_blocks)
+         self.bridge = Bridge(2048, 2048)
+         up_blocks.append(UpBlockForUNetWithResNet50(2048, 1024))
+         up_blocks.append(UpBlockForUNetWithResNet50(1024, 512))
+         up_blocks.append(UpBlockForUNetWithResNet50(512, 256))
+         up_blocks.append(UpBlockForUNetWithResNet50(in_channels=128 + 64, out_channels=128,
+                                                     up_conv_in_channels=256, up_conv_out_channels=128))
+         up_blocks.append(UpBlockForUNetWithResNet50(in_channels=64 + 3, out_channels=64,
+                                                     up_conv_in_channels=128, up_conv_out_channels=64))
+
+         self.up_blocks = nn.ModuleList(up_blocks)
+
+         self.out = nn.Conv2d(64, n_classes, kernel_size=1, stride=1)
+
+         self.max_depth = max_depth
+
+     def forward(self, x, with_output_feature_map=False):
+         # cache pre-pool feature maps so the up blocks can concatenate them back in
+         pre_pools = dict()
+         pre_pools["layer_0"] = x
+         x = self.input_block(x)
+         pre_pools["layer_1"] = x
+         x = self.input_pool(x)
+
+         for i, block in enumerate(self.down_blocks, 2):
+             x = block(x)
+             if i == (UNetWithResnet50Encoder.DEPTH - 1):
+                 continue
+             pre_pools[f"layer_{i}"] = x
+
+         x = self.bridge(x)
+
+         for i, block in enumerate(self.up_blocks, 1):
+             key = f"layer_{UNetWithResnet50Encoder.DEPTH - 1 - i}"
+             x = block(x, pre_pools[key])
+         output_feature_map = x
+         x = self.out(x)
+         del pre_pools
+         # if with_output_feature_map:
+         #     return x, output_feature_map
+
+         out_depth = torch.sigmoid(x) * self.max_depth
+
+         return {'pred_d': out_depth}
+
+
+ if __name__ == "__main__":
+     model = UNetWithResnet50Encoder(max_depth=10).cuda()
+     summary(model, input_size=(1, 3, 256, 256))
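
The same divisible-by-32 constraint applies here. With DEPTH = 6 the encoder caches five maps (layer_0 through layer_4); the deepest layer4 output is skipped by the `continue` at i == DEPTH - 1 because the bridge consumes it directly, and the five up blocks then pop the cached maps in reverse order. An illustrative check at the kb-cropped KITTI resolution:

```python
import torch
from models.unet_resnet50 import UNetWithResnet50Encoder

model = UNetWithResnet50Encoder(max_depth=80).eval()  # outdoor / KITTI cap of 80 m
with torch.no_grad():
    out = model(torch.zeros(1, 3, 352, 1216))['pred_d']
print(out.shape)  # torch.Size([1, 1, 352, 1216])
```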