Adding Application, models and ckpt files

- app.py +116 -0
- ckpt/densenet_epoch_15_model.ckpt +3 -0
- ckpt/densenet_nyu_then_kitti_epoch_10_model.ckpt +3 -0
- ckpt/nyudepthv2_swin_base.ckpt +3 -0
- ckpt/resnet18_unet_epoch_08_model_kitti_and_nyu.ckpt +3 -0
- ckpt/resnet50_unet_epoch_02_model_nyuandkitti.ckpt +3 -0
- ckpt/resnet_encdecmodel_epoch_05_model_nyu_and_kitti.ckpt +3 -0
- ckpt/resnet_nyu_best.ckpt +3 -0
- demo_data/Bathroom.jpg +0 -0
- demo_data/Bedroom.jpg +0 -0
- demo_data/Bookstore.jpg +0 -0
- demo_data/Classroom.jpg +0 -0
- demo_data/Computerlab.jpg +0 -0
- demo_data/kitti_1.png +0 -0
- demo_data/kitti_2.png +0 -0
- demo_data/kitti_3.png +0 -0
- models/densenet_v2.py +179 -0
- models/pretrained_decv2.py +126 -0
- models/unet_resnet18.py +94 -0
- models/unet_resnet50.py +150 -0
app.py
ADDED
@@ -0,0 +1,116 @@
import gradio as gr
import torch
from models.pretrained_decv2 import enc_dec_model
from models.densenet_v2 import Densenet
from models.unet_resnet18 import ResNet18UNet
from models.unet_resnet50 import UNetWithResnet50Encoder
import numpy as np
import cv2

# KITTI benchmark ("kb") crop: keep the bottom 352 rows and the central 1216 columns
def cropping(img):
    h_im, w_im = img.shape[:2]

    margin_top = int(h_im - 352)
    margin_left = int((w_im - 1216) / 2)

    img = img[margin_top: margin_top + 352,
              margin_left: margin_left + 1216]

    return img

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(DEVICE)
CWD = "."
CKPT_FILE_NAMES = {
    'Indoor': {
        'Resnet_enc': 'resnet_nyu_best.ckpt',
        'Unet': 'resnet18_unet_epoch_08_model_kitti_and_nyu.ckpt',
        'Densenet_enc': 'densenet_epoch_15_model.ckpt'
    },
    'Outdoor': {
        'Resnet_enc': 'resnet_encdecmodel_epoch_05_model_nyu_and_kitti.ckpt',
        'Unet': 'resnet50_unet_epoch_02_model_nyuandkitti.ckpt',
        'Densenet_enc': 'densenet_nyu_then_kitti_epoch_10_model.ckpt'
    }
}
MODEL_CLASSES = {
    'Indoor': {
        'Resnet_enc': enc_dec_model,
        'Unet': ResNet18UNet,
        'Densenet_enc': Densenet
    },
    'Outdoor': {
        'Resnet_enc': enc_dec_model,
        'Unet': UNetWithResnet50Encoder,
        'Densenet_enc': Densenet
    },
}

def load_model(ckpt, model, optimizer=None):
    ckpt_dict = torch.load(ckpt, map_location='cpu')
    # keep backward compatibility: some checkpoints store the raw state dict directly
    if 'model' not in ckpt_dict and 'optimizer' not in ckpt_dict:
        state_dict = ckpt_dict
    else:
        state_dict = ckpt_dict['model']
    # strip the 'module.' prefix left behind by DataParallel training
    weights = {}
    for key, value in state_dict.items():
        if key.startswith('module.'):
            weights[key[len('module.'):]] = value
        else:
            weights[key] = value

    model.load_state_dict(weights)

    if optimizer is not None:
        optimizer_state = ckpt_dict['optimizer']
        optimizer.load_state_dict(optimizer_state)


def predict(location, model_name, img):
    ckpt_dir = f"{CWD}/ckpt/{CKPT_FILE_NAMES[location][model_name]}"
    # NYU (indoor) depth is capped at 10 m, KITTI (outdoor) at 80 m
    if location == 'Indoor':
        max_depth = 10
    else:
        max_depth = 80
    model = MODEL_CLASSES[location][model_name](max_depth).to(DEVICE)
    load_model(ckpt_dir, model)
    # raw KITTI frames are 375x1242 and need the benchmark crop to 352x1216
    if img.shape == (375, 1242, 3):
        img = cropping(img)
    img = torch.tensor(img).permute(2, 0, 1).float().to(DEVICE)
    input_RGB = img.unsqueeze(0)
    with torch.no_grad():
        pred = model(input_RGB)
    pred_d = pred['pred_d']
    pred_d_numpy = pred_d.squeeze().cpu().numpy()
    # scale to 0-255, taking the max over all but the top 15 rows to dodge outliers
    pred_d_numpy = np.clip((pred_d_numpy / pred_d_numpy[15:, :].max()) * 255, 0, 255)
    pred_d_numpy = pred_d_numpy.astype(np.uint8)
    pred_d_color = cv2.applyColorMap(pred_d_numpy, cv2.COLORMAP_RAINBOW)
    pred_d_color = cv2.cvtColor(pred_d_color, cv2.COLOR_BGR2RGB)
    return pred_d_color

with gr.Blocks() as demo:
    gr.Markdown("# Monocular Depth Estimation")
    with gr.Row():
        location = gr.Radio(choices=['Indoor', 'Outdoor'], value='Indoor', label="Select Location Type")
        model_name = gr.Radio(['Unet', 'Resnet_enc', 'Densenet_enc'], value="Densenet_enc", label="Select model")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Input Image for Depth Estimation")
        with gr.Column():
            output_depth_map = gr.Image(label="Depth prediction Heatmap")
    with gr.Row():
        predict_btn = gr.Button("Generate Depthmap")
        predict_btn.click(fn=predict, inputs=[location, model_name, input_image], outputs=output_depth_map)
    with gr.Row():
        gr.Examples(['./demo_data/Bathroom.jpg', './demo_data/Bedroom.jpg', './demo_data/Bookstore.jpg', './demo_data/Classroom.jpg', './demo_data/Computerlab.jpg', './demo_data/kitti_1.png'], inputs=input_image)
demo.launch()
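
Since demo.launch() runs at module level, importing app starts the UI; for a quick headless check you could guard the launch under if __name__ == "__main__": and call predict directly. A minimal sketch, assuming the ckpt/ and demo_data/ files above are in place and the launch is guarded:

    import cv2
    from app import predict  # assumes demo.launch() is guarded so the import does not block

    # hand predict an RGB array, exactly as the gr.Image component would
    img = cv2.cvtColor(cv2.imread('./demo_data/Bedroom.jpg'), cv2.COLOR_BGR2RGB)
    depth_heatmap = predict('Indoor', 'Densenet_enc', img)  # uint8 RGB colour map
    cv2.imwrite('depth.png', cv2.cvtColor(depth_heatmap, cv2.COLOR_RGB2BGR))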
ckpt/densenet_epoch_15_model.ckpt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1e50a3c6bb7a24e3ece8f323dc759e0822145e81e2770dad0d00e12ac306c37c
+size 1748720589
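
This and the remaining .ckpt entries are Git LFS pointer files: the repository stores only the version/oid/size triplet (here roughly 1.75 GB of weights) and git-lfs fetches the real blob at checkout. A small sketch of reading one, assuming a clone where the pointers have not been smudged into the actual files:

    def parse_lfs_pointer(path):
        # each line is 'key value'; e.g. oid -> 'sha256:1e50a3...'
        fields = dict(line.split(' ', 1) for line in open(path).read().splitlines())
        return fields['oid'], int(fields['size'])

    oid, size = parse_lfs_pointer('ckpt/densenet_epoch_15_model.ckpt')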
ckpt/densenet_nyu_then_kitti_epoch_10_model.ckpt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3bd5ad153cd4363c061d6ca666899b1fe8a2c425fc26423081907cc144d204f5
+size 1748720589
ckpt/nyudepthv2_swin_base.ckpt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a1c748dc3e0add9ee18b43dcfa1f2c8d5734d3e523ab7872398f203d5d36b605
+size 493044547
ckpt/resnet18_unet_epoch_08_model_kitti_and_nyu.ckpt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ae171c2ab1a22570395a284eca5ecd392b653ab4bbf0f9b5edd6a9dbdbd8d2fc
+size 215834813
ckpt/resnet50_unet_epoch_02_model_nyuandkitti.ckpt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9a26301214906364877b1b71bc1dca3e40b4013948a36b8ea1eb8e99cb56ce49
+size 1774319297
ckpt/resnet_encdecmodel_epoch_05_model_nyu_and_kitti.ckpt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fa4f12fc178424a4205f241fa9081a8769fe10c36bcc7839dda53bacaa3676d1
+size 174548419
ckpt/resnet_nyu_best.ckpt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7fa8bbc121457976cb1c4d2b1396f517befe46a1e578afe53fa9c6ce920ffe48
+size 210256970
demo_data/Bathroom.jpg
ADDED
demo_data/Bedroom.jpg
ADDED
demo_data/Bookstore.jpg
ADDED
demo_data/Classroom.jpg
ADDED
demo_data/Computerlab.jpg
ADDED
demo_data/kitti_1.png
ADDED
demo_data/kitti_2.png
ADDED
demo_data/kitti_3.png
ADDED
models/densenet_v2.py
ADDED
@@ -0,0 +1,179 @@
import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F
from torchinfo import summary
from math import sqrt

class attention_gate(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()

        self.Wg = nn.Sequential(
            nn.Conv2d(in_c[0], out_c, kernel_size=1, padding=0),
            nn.BatchNorm2d(out_c)
        )
        self.Ws = nn.Sequential(
            nn.Conv2d(in_c[1], out_c, kernel_size=1, padding=0),
            nn.BatchNorm2d(out_c)
        )
        self.relu = nn.ReLU(inplace=True)
        self.output = nn.Sequential(
            nn.Conv2d(out_c, out_c, kernel_size=1, padding=0),
            nn.Sigmoid()
        )

    def forward(self, g, s):
        Wg = self.Wg(g)
        Ws = self.Ws(s)
        out = self.relu(Wg + Ws)
        out = self.output(out)
        return out

class Conv_Block(nn.Module):
    def __init__(self, in_c, out_c, activation_fn=nn.LeakyReLU):
        super().__init__()

        self.conv1 = nn.Conv2d(in_c, out_c, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_c)

        self.conv2 = nn.Conv2d(out_c, out_c, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_c)

        self.activfn = activation_fn()

        self.dropout = nn.Dropout(0.25)

    def forward(self, inputs):
        x = self.conv1(inputs)
        x = self.bn1(x)
        x = self.activfn(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.activfn(x)

        return x

class Encoder_Block(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()

        self.conv = Conv_Block(in_c, out_c)
        self.pool = nn.MaxPool2d((2, 2))

    def forward(self, inputs):
        x = self.conv(inputs)
        p = self.pool(x)

        return x, p

# NOTE: legacy stand-alone encoder-decoder; not used by the Densenet model below
class Enc_Dec_Model(nn.Module):
    def __init__(self):
        super(Enc_Dec_Model, self).__init__()
        self.encoder1 = Encoder_Block(3, 64)
        self.encoder2 = Encoder_Block(64, 128)
        self.encoder3 = Encoder_Block(128, 256)

        """ Bottleneck """
        self.bottleneck = Conv_Block(256, 512)

        """ Decoder """
        self.d1 = Decoder_Block([512, 256], 256)
        self.d2 = Decoder_Block([256, 128], 128)
        self.d3 = Decoder_Block([128, 64], 64)

        """ Classifier """
        self.outputs = nn.Conv2d(64, 1, kernel_size=1, padding=0)

    def forward(self, x):
        """ Encoder """
        s1, p1 = self.encoder1(x)
        s2, p2 = self.encoder2(p1)
        s3, p3 = self.encoder3(p2)

        """ Bottleneck """
        b = self.bottleneck(p3)

        """ Decoder """
        d1 = self.d1(b, s3)
        d2 = self.d2(d1, s2)
        d3 = self.d3(d2, s1)

        """ Classifier """
        outputs = self.outputs(d3)
        out_depth = torch.sigmoid(outputs)
        return out_depth

class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()

        """ Decoder: five blocks, each doubling the spatial resolution """
        self.d1 = Decoder_Block(1920, 2048)
        self.d2 = Decoder_Block(2048, 1024)
        self.d3 = Decoder_Block(1024, 512)
        self.d4 = Decoder_Block(512, 256)
        self.d5 = Decoder_Block(256, 128)
        # self.d6 = Decoder_Block(128, 64)

        """ Classifier """
        self.outputs = nn.Conv2d(128, 1, kernel_size=1, padding=0)

    def forward(self, x):
        """ Decoder """
        x = self.d1(x)
        x = self.d2(x)
        x = self.d3(x)
        x = self.d4(x)
        x = self.d5(x)
        # x = self.d6(x)

        """ Classifier """
        outputs = self.outputs(x)
        out_depth = torch.sigmoid(outputs)
        return out_depth

class Decoder_Block(nn.Module):
    def __init__(self, in_c, out_c, activation_fn=nn.LeakyReLU):
        super().__init__()

        self.up = nn.ConvTranspose2d(in_c, out_c, kernel_size=2, stride=2, padding=0)
        self.conv = Conv_Block(out_c, out_c, activation_fn)

    def forward(self, inputs):
        x = self.up(inputs)
        x = self.conv(x)

        return x


class Densenet(nn.Module):
    def __init__(self, max_depth) -> None:
        super().__init__()
        self.densenet = torchvision.models.densenet201(weights=torchvision.models.DenseNet201_Weights.DEFAULT)
        for param in self.densenet.features.parameters():
            param.requires_grad = False

        # keep only the convolutional feature extractor, dropping the classifier head
        self.densenet = torch.nn.Sequential(*(list(self.densenet.children())[:-1]))
        self.decoder = Decoder()
        self.max_depth = max_depth

    def forward(self, x):
        x = self.densenet(x)
        x = self.decoder(x)
        # sigmoid output in [0, 1] rescaled to metres
        x = x * self.max_depth
        return {'pred_d': x}

if __name__ == "__main__":
    model = Densenet(max_depth=10).cuda()
    print(model)
    summary(model, input_size=(64, 3, 448, 448))
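
For orientation: DenseNet-201's feature extractor emits 1920 channels at 1/32 of the input resolution, and the five Decoder_Blocks each double the spatial size (kernel 2, stride 2), so the predicted map comes back at the input resolution. A quick sanity check, assuming the classes above (the pretrained weights download on first use):

    import torch
    from models.densenet_v2 import Densenet

    model = Densenet(max_depth=10).eval()
    with torch.no_grad():
        out = model(torch.rand(1, 3, 448, 448))['pred_d']
    print(out.shape)  # torch.Size([1, 1, 448, 448]): 448/32 = 14, then 14 * 2**5 = 448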
models/pretrained_decv2.py
ADDED
@@ -0,0 +1,126 @@
import torch
import torch.nn as nn
import torchvision
from torchinfo import summary

class conv_block(nn.Module):
    def __init__(self, in_c, out_c, act):
        super().__init__()

        self.conv1 = nn.Conv2d(in_c, out_c, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_c)

        self.conv2 = nn.Conv2d(out_c, out_c, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_c)

        if act == 'relu':
            self.activation = nn.ReLU()
        elif act == 'sigmoid':
            self.activation = nn.Sigmoid()
        else:
            self.activation = nn.Identity()

    def forward(self, inputs):
        x = self.conv1(inputs)
        x = self.bn1(x)
        x = self.activation(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.activation(x)

        return x

class Decoder_block(nn.Module):
    def __init__(self, in_channel, out_channel, kernel, stride, padding=1, out_padding=1, act='relu') -> None:
        super().__init__()
        self.upsample = nn.ConvTranspose2d(in_channels=in_channel,
                                           out_channels=out_channel,
                                           kernel_size=kernel,
                                           stride=stride,
                                           padding=padding,
                                           output_padding=out_padding)
        if act == 'relu':
            self.activation = nn.ReLU()
        elif act == 'sigmoid':
            self.activation = nn.Sigmoid()
        else:
            self.activation = nn.Identity()

    def forward(self, x):
        return self.activation(self.upsample(x))

class Decoder(nn.Module):
    def __init__(self, num_layers, channels, kernels, strides, activations) -> None:
        super().__init__()
        assert len(channels) - 1 == len(kernels) and len(strides) == len(kernels) and num_layers == len(strides)
        assert num_layers == len(activations)
        self.layers = []
        for i in range(num_layers):
            # each stage: transposed conv (doubles H and W), then a same-size conv block
            self.layers.append(Decoder_block(in_channel=channels[i],
                                             out_channel=channels[i+1],
                                             kernel=kernels[i],
                                             stride=strides[i],
                                             act=activations[i]))
            self.layers.append(conv_block(in_c=channels[i+1], out_c=channels[i+1], act=activations[i]))
        self.model = nn.Sequential(*self.layers)

    def forward(self, x):
        return self.model(x)

class enc_dec_model(nn.Module):
    def __init__(self, max_depth=10, backbone='resnet', unfreeze=False) -> None:
        super().__init__()
        if backbone == 'resnet':
            self.encoder = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
            num_layers = 5
            channels = [2048, 256, 128, 64, 32, 1]
            kernels = [3, 3, 3, 3, 3]
            strides = [2, 2, 2, 2, 2]
            activations = ['relu', 'relu', 'relu', 'relu', 'sigmoid']
            if unfreeze:
                for param in self.encoder.parameters():
                    param.requires_grad = True
            else:
                # freeze the backbone except the last bottleneck of layer4 (child 7)
                for param in self.encoder.parameters():
                    param.requires_grad = False
                for i, child in enumerate(self.encoder.children()):
                    if i == 7:
                        for j, child2 in enumerate(child.children()):
                            if j == 2:
                                for param in child2.parameters():
                                    param.requires_grad = True
                    if i >= 8:
                        for param in child.parameters():
                            param.requires_grad = True
        # drop the average-pool and fc head, keeping the 2048-channel feature map
        self.encoder = torch.nn.Sequential(*(list(self.encoder.children())[:-2]))

        self.decoder = Decoder(num_layers=num_layers,
                               channels=channels,
                               kernels=kernels,
                               strides=strides,
                               activations=activations)
        self.max_depth = max_depth

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        x = x * self.max_depth
        return {'pred_d': x}

if __name__ == "__main__":
    model = enc_dec_model(unfreeze=True).cuda()
    print(model)
    summary(model, input_size=(64, 3, 448, 448))
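
The decoder geometry follows the usual transposed-convolution size formula, out = (in - 1)*stride - 2*padding + kernel + output_padding; with kernel 3, stride 2, padding 1 and output_padding 1, each Decoder_block exactly doubles height and width, so five of them carry the ResNet-50 1/32-resolution feature map back to full resolution while the channels step down 2048 -> 256 -> 128 -> 64 -> 32 -> 1. As a one-line check:

    # (in - 1)*2 - 2*1 + 3 + 1 == 2*in for the kernel/stride/padding used above
    double = lambda n: (n - 1) * 2 - 2 * 1 + 3 + 1
    assert double(14) == 28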
models/unet_resnet18.py
ADDED
@@ -0,0 +1,94 @@
import torch.nn as nn
from torchinfo import summary
import torchvision.models
import torch


def convrelu(in_channels, out_channels, kernel, padding):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel, padding=padding),
        nn.ReLU(inplace=True),
    )


class ResNet18UNet(nn.Module):
    def __init__(self, max_depth, n_class=1):
        super().__init__()

        self.base_model = torchvision.models.resnet18(pretrained=True)
        self.base_layers = list(self.base_model.children())

        self.layer0 = nn.Sequential(*self.base_layers[:3])  # size=(N, 64, x.H/2, x.W/2)
        self.layer0_1x1 = convrelu(64, 64, 1, 0)
        self.layer1 = nn.Sequential(*self.base_layers[3:5])  # size=(N, 64, x.H/4, x.W/4)
        self.layer1_1x1 = convrelu(64, 64, 1, 0)
        self.layer2 = self.base_layers[5]  # size=(N, 128, x.H/8, x.W/8)
        self.layer2_1x1 = convrelu(128, 128, 1, 0)
        self.layer3 = self.base_layers[6]  # size=(N, 256, x.H/16, x.W/16)
        self.layer3_1x1 = convrelu(256, 256, 1, 0)
        self.layer4 = self.base_layers[7]  # size=(N, 512, x.H/32, x.W/32)
        self.layer4_1x1 = convrelu(512, 512, 1, 0)

        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        self.conv_up3 = convrelu(256 + 512, 512, 3, 1)
        self.conv_up2 = convrelu(128 + 512, 256, 3, 1)
        self.conv_up1 = convrelu(64 + 256, 256, 3, 1)
        self.conv_up0 = convrelu(64 + 256, 128, 3, 1)

        self.conv_original_size0 = convrelu(3, 64, 3, 1)
        self.conv_original_size1 = convrelu(64, 64, 3, 1)
        self.conv_original_size2 = convrelu(64 + 128, 64, 3, 1)

        self.conv_last = nn.Conv2d(64, n_class, 1)

        self.max_depth = max_depth

    def forward(self, input):
        # full-resolution branch used for the final skip connection
        x_original = self.conv_original_size0(input)
        x_original = self.conv_original_size1(x_original)

        layer0 = self.layer0(input)
        layer1 = self.layer1(layer0)
        layer2 = self.layer2(layer1)
        layer3 = self.layer3(layer2)
        layer4 = self.layer4(layer3)

        layer4 = self.layer4_1x1(layer4)
        x = self.upsample(layer4)
        layer3 = self.layer3_1x1(layer3)
        x = torch.cat([x, layer3], dim=1)
        x = self.conv_up3(x)

        x = self.upsample(x)
        layer2 = self.layer2_1x1(layer2)
        x = torch.cat([x, layer2], dim=1)
        x = self.conv_up2(x)

        x = self.upsample(x)
        layer1 = self.layer1_1x1(layer1)
        x = torch.cat([x, layer1], dim=1)
        x = self.conv_up1(x)

        x = self.upsample(x)
        layer0 = self.layer0_1x1(layer0)
        x = torch.cat([x, layer0], dim=1)
        x = self.conv_up0(x)

        x = self.upsample(x)
        x = torch.cat([x, x_original], dim=1)
        x = self.conv_original_size2(x)

        out = self.conv_last(x)

        out_depth = torch.sigmoid(out) * self.max_depth

        return {'pred_d': out_depth}

if __name__ == "__main__":
    model = ResNet18UNet(max_depth=10).cuda()
    summary(model, input_size=(1, 3, 256, 256))
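
One practical constraint worth noting: the five bilinear x2 upsamplings must retrace the encoder's five x2 reductions, so the concatenations only line up when the input height and width are divisible by 32; the 352 x 1216 KITTI crop used in app.py satisfies this. A quick check, assuming the class above:

    import torch
    from models.unet_resnet18 import ResNet18UNet

    model = ResNet18UNet(max_depth=80).eval()
    with torch.no_grad():
        out = model(torch.rand(1, 3, 352, 1216))['pred_d']
    print(out.shape)  # torch.Size([1, 1, 352, 1216])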
models/unet_resnet50.py
ADDED
@@ -0,0 +1,150 @@
import torch
import torch.nn as nn
from torchinfo import summary
import torchvision


class ConvBlock(nn.Module):
    """
    Helper module that consists of a Conv -> BN -> ReLU
    """

    def __init__(self, in_channels, out_channels, padding=1, kernel_size=3, stride=1, with_nonlinearity=True):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, padding=padding, kernel_size=kernel_size, stride=stride)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.with_nonlinearity = with_nonlinearity

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        if self.with_nonlinearity:
            x = self.relu(x)
        return x


class Bridge(nn.Module):
    """
    This is the middle layer of the UNet, which just consists of two ConvBlocks.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.bridge = nn.Sequential(
            ConvBlock(in_channels, out_channels),
            ConvBlock(out_channels, out_channels)
        )

    def forward(self, x):
        return self.bridge(x)


class UpBlockForUNetWithResNet50(nn.Module):
    """
    Up block that encapsulates one up-sampling step which consists of Upsample -> ConvBlock -> ConvBlock
    """

    def __init__(self, in_channels, out_channels, up_conv_in_channels=None, up_conv_out_channels=None,
                 upsampling_method="conv_transpose"):
        super().__init__()

        if up_conv_in_channels is None:
            up_conv_in_channels = in_channels
        if up_conv_out_channels is None:
            up_conv_out_channels = out_channels

        if upsampling_method == "conv_transpose":
            self.upsample = nn.ConvTranspose2d(up_conv_in_channels, up_conv_out_channels, kernel_size=2, stride=2)
        elif upsampling_method == "bilinear":
            self.upsample = nn.Sequential(
                nn.Upsample(mode='bilinear', scale_factor=2),
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1)
            )
        self.conv_block_1 = ConvBlock(in_channels, out_channels)
        self.conv_block_2 = ConvBlock(out_channels, out_channels)

    def forward(self, up_x, down_x):
        """
        :param up_x: this is the output from the previous up block
        :param down_x: this is the output from the corresponding down block (skip connection)
        :return: upsampled feature map
        """
        x = self.upsample(up_x)
        x = torch.cat([x, down_x], 1)
        x = self.conv_block_1(x)
        x = self.conv_block_2(x)
        return x


class UNetWithResnet50Encoder(nn.Module):
    DEPTH = 6

    def __init__(self, max_depth, n_classes=1):
        super().__init__()
        resnet = torchvision.models.resnet.resnet50(pretrained=True)
        down_blocks = []
        up_blocks = []
        self.input_block = nn.Sequential(*list(resnet.children()))[:3]
        self.input_pool = list(resnet.children())[3]
        for bottleneck in list(resnet.children()):
            if isinstance(bottleneck, nn.Sequential):
                down_blocks.append(bottleneck)
        self.down_blocks = nn.ModuleList(down_blocks)
        self.bridge = Bridge(2048, 2048)
        up_blocks.append(UpBlockForUNetWithResNet50(2048, 1024))
        up_blocks.append(UpBlockForUNetWithResNet50(1024, 512))
        up_blocks.append(UpBlockForUNetWithResNet50(512, 256))
        up_blocks.append(UpBlockForUNetWithResNet50(in_channels=128 + 64, out_channels=128,
                                                    up_conv_in_channels=256, up_conv_out_channels=128))
        up_blocks.append(UpBlockForUNetWithResNet50(in_channels=64 + 3, out_channels=64,
                                                    up_conv_in_channels=128, up_conv_out_channels=64))

        self.up_blocks = nn.ModuleList(up_blocks)

        self.out = nn.Conv2d(64, n_classes, kernel_size=1, stride=1)

        self.max_depth = max_depth

    def forward(self, x, with_output_feature_map=False):
        # stash the pre-pooling feature maps for the skip connections
        pre_pools = dict()
        pre_pools["layer_0"] = x
        x = self.input_block(x)
        pre_pools["layer_1"] = x
        x = self.input_pool(x)

        for i, block in enumerate(self.down_blocks, 2):
            x = block(x)
            if i == (UNetWithResnet50Encoder.DEPTH - 1):
                continue
            pre_pools[f"layer_{i}"] = x

        x = self.bridge(x)

        for i, block in enumerate(self.up_blocks, 1):
            key = f"layer_{UNetWithResnet50Encoder.DEPTH - 1 - i}"
            x = block(x, pre_pools[key])
        output_feature_map = x
        x = self.out(x)
        del pre_pools
        # if with_output_feature_map:
        #     return x, output_feature_map
        # else:
        #     return x

        out_depth = torch.sigmoid(x) * self.max_depth

        return {'pred_d': out_depth}


if __name__ == "__main__":
    model = UNetWithResnet50Encoder(max_depth=10).cuda()
    summary(model, input_size=(1, 3, 256, 256))
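
For reference, an annotated summary (not part of the file itself) of how the five up blocks pair with the pre_pools skips, which is where the explicit 128 + 64 and 64 + 3 in_channels come from:

    # up block  transposed conv   skip concatenated         conv blocks in -> out
    # 1         2048 -> 1024      layer_4 (1024 ch)         2048 -> 1024
    # 2         1024 -> 512       layer_3 (512 ch)          1024 -> 512
    # 3         512  -> 256       layer_2 (256 ch)          512  -> 256
    # 4         256  -> 128       layer_1 (64 ch)           128 + 64 -> 128
    # 5         128  -> 64        layer_0 (the 3-ch input)  64 + 3 -> 64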