Kaori1707 commited on
Commit
72db49b
·
1 Parent(s): 9e31c0e

Upload 21 files

Browse files
.gitattributes CHANGED
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
35
+ *.jpg filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import torch
4
+ from torchvision.transforms import Compose
5
+ import cv2
6
+ from dpt.models import DPTDepthModel
7
+ from dpt.transforms import Resize, NormalizeImage, PrepareForNet
8
+ import os
9
+
10
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
+ print("device: %s" % device)
12
+ default_models = {
13
+ "dpt_hybrid": "weights/dpt_hybrid-midas-501f0c75.pt",
14
+ }
15
+ torch.backends.cudnn.enabled = True
16
+ torch.backends.cudnn.benchmark = True
17
+ net_w = net_h = 384
18
+ model = DPTDepthModel(
19
+ path=default_models["dpt_hybrid"],
20
+ backbone="vitb_rn50_384",
21
+ non_negative=True,
22
+ enable_attention_hooks=False,
23
+ )
24
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
25
+ transform = Compose(
26
+ [
27
+ Resize(
28
+ net_w,
29
+ net_h,
30
+ resize_target=None,
31
+ keep_aspect_ratio=True,
32
+ ensure_multiple_of=32,
33
+ resize_method="minimal",
34
+ image_interpolation_method=cv2.INTER_CUBIC,
35
+ ),
36
+ normalization,
37
+ PrepareForNet(),
38
+ ]
39
+ )
40
+
41
+ model.eval()
42
+ model.to(device)
43
+
44
+ def write_depth(depth, bits=1, absolute_depth=False):
45
+ """Write depth map to pfm and png file.
46
+
47
+ Args:
48
+ path (str): filepath without extension
49
+ depth (array): depth
50
+ """
51
+ # write_pfm(path + ".pfm", depth.astype(np.float32))
52
+
53
+ if absolute_depth:
54
+ out = depth
55
+ else:
56
+ depth_min = depth.min()
57
+ depth_max = depth.max()
58
+
59
+ max_val = (2 ** (8 * bits)) - 1
60
+
61
+ if depth_max - depth_min > np.finfo("float").eps:
62
+ out = max_val * (depth - depth_min) / (depth_max - depth_min)
63
+ else:
64
+ out = np.zeros(depth.shape, dtype=depth.dtype)
65
+
66
+ if bits == 1:
67
+ return out.astype("uint8")
68
+ elif bits == 2:
69
+ return out.astype("uint16")
70
+
71
+
72
+ def DPT(image):
73
+ img_input = transform({"image": image})["image"]
74
+ # compute
75
+ with torch.no_grad():
76
+ sample = torch.from_numpy(img_input).to(device).unsqueeze(0)
77
+
78
+ prediction = model.forward(sample)
79
+ prediction = (
80
+ torch.nn.functional.interpolate(
81
+ prediction.unsqueeze(1),
82
+ size=image.shape[:2],
83
+ mode="bicubic",
84
+ align_corners=False,
85
+ )
86
+ .squeeze()
87
+ .cpu()
88
+ .numpy()
89
+ )
90
+
91
+ depth_img = write_depth(prediction, bits=2)
92
+ return depth_img
93
+
94
+ title = " AISeed AI Application Demo "
95
+ description = "# A Demo of Deep Learning for Depth Estimation"
96
+ example_list = [["examples/" + example] for example in os.listdir("examples")]
97
+
98
+ with gr.Blocks() as demo:
99
+ demo.title = title
100
+ gr.Markdown(description)
101
+ with gr.Row():
102
+ im = gr.Image(label="Input Image")
103
+ im_2 = gr.Image(label="Depth Image")
104
+ with gr.Column():
105
+
106
+ btn1 = gr.Button(value="Depth Estimator")
107
+ btn1.click(DPT, inputs=[im], outputs=[im_2])
108
+ gr.Examples(examples=example_list,
109
+ inputs=[im],
110
+ outputs=[im_2],
111
+ fn=DPT)
112
+
113
+ if __name__ == "__main__":
114
+ demo.launch()
dpt/__init__.py ADDED
File without changes
dpt/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (147 Bytes). View file
 
dpt/__pycache__/base_model.cpython-38.pyc ADDED
Binary file (673 Bytes). View file
 
dpt/__pycache__/blocks.cpython-38.pyc ADDED
Binary file (6.72 kB). View file
 
dpt/__pycache__/models.cpython-38.pyc ADDED
Binary file (3.82 kB). View file
 
dpt/__pycache__/transforms.cpython-38.pyc ADDED
Binary file (5.68 kB). View file
 
dpt/__pycache__/vit.cpython-38.pyc ADDED
Binary file (11.2 kB). View file
 
dpt/base_model.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
+ class BaseModel(torch.nn.Module):
5
+ def load(self, path):
6
+ """Load model from file.
7
+
8
+ Args:
9
+ path (str): file path
10
+ """
11
+ parameters = torch.load(path, map_location=torch.device("cpu"))
12
+
13
+ if "optimizer" in parameters:
14
+ parameters = parameters["model"]
15
+
16
+ self.load_state_dict(parameters)
dpt/blocks.py ADDED
@@ -0,0 +1,383 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from .vit import (
5
+ _make_pretrained_vitb_rn50_384,
6
+ _make_pretrained_vitl16_384,
7
+ _make_pretrained_vitb16_384,
8
+ forward_vit,
9
+ )
10
+
11
+
12
+ def _make_encoder(
13
+ backbone,
14
+ features,
15
+ use_pretrained,
16
+ groups=1,
17
+ expand=False,
18
+ exportable=True,
19
+ hooks=None,
20
+ use_vit_only=False,
21
+ use_readout="ignore",
22
+ enable_attention_hooks=False,
23
+ ):
24
+ if backbone == "vitl16_384":
25
+ pretrained = _make_pretrained_vitl16_384(
26
+ use_pretrained,
27
+ hooks=hooks,
28
+ use_readout=use_readout,
29
+ enable_attention_hooks=enable_attention_hooks,
30
+ )
31
+ scratch = _make_scratch(
32
+ [256, 512, 1024, 1024], features, groups=groups, expand=expand
33
+ ) # ViT-L/16 - 85.0% Top1 (backbone)
34
+ elif backbone == "vitb_rn50_384":
35
+ pretrained = _make_pretrained_vitb_rn50_384(
36
+ use_pretrained,
37
+ hooks=hooks,
38
+ use_vit_only=use_vit_only,
39
+ use_readout=use_readout,
40
+ enable_attention_hooks=enable_attention_hooks,
41
+ )
42
+ scratch = _make_scratch(
43
+ [256, 512, 768, 768], features, groups=groups, expand=expand
44
+ ) # ViT-H/16 - 85.0% Top1 (backbone)
45
+ elif backbone == "vitb16_384":
46
+ pretrained = _make_pretrained_vitb16_384(
47
+ use_pretrained,
48
+ hooks=hooks,
49
+ use_readout=use_readout,
50
+ enable_attention_hooks=enable_attention_hooks,
51
+ )
52
+ scratch = _make_scratch(
53
+ [96, 192, 384, 768], features, groups=groups, expand=expand
54
+ ) # ViT-B/16 - 84.6% Top1 (backbone)
55
+ elif backbone == "resnext101_wsl":
56
+ pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
57
+ scratch = _make_scratch(
58
+ [256, 512, 1024, 2048], features, groups=groups, expand=expand
59
+ ) # efficientnet_lite3
60
+ else:
61
+ print(f"Backbone '{backbone}' not implemented")
62
+ assert False
63
+
64
+ return pretrained, scratch
65
+
66
+
67
+ def _make_scratch(in_shape, out_shape, groups=1, expand=False):
68
+ scratch = nn.Module()
69
+
70
+ out_shape1 = out_shape
71
+ out_shape2 = out_shape
72
+ out_shape3 = out_shape
73
+ out_shape4 = out_shape
74
+ if expand == True:
75
+ out_shape1 = out_shape
76
+ out_shape2 = out_shape * 2
77
+ out_shape3 = out_shape * 4
78
+ out_shape4 = out_shape * 8
79
+
80
+ scratch.layer1_rn = nn.Conv2d(
81
+ in_shape[0],
82
+ out_shape1,
83
+ kernel_size=3,
84
+ stride=1,
85
+ padding=1,
86
+ bias=False,
87
+ groups=groups,
88
+ )
89
+ scratch.layer2_rn = nn.Conv2d(
90
+ in_shape[1],
91
+ out_shape2,
92
+ kernel_size=3,
93
+ stride=1,
94
+ padding=1,
95
+ bias=False,
96
+ groups=groups,
97
+ )
98
+ scratch.layer3_rn = nn.Conv2d(
99
+ in_shape[2],
100
+ out_shape3,
101
+ kernel_size=3,
102
+ stride=1,
103
+ padding=1,
104
+ bias=False,
105
+ groups=groups,
106
+ )
107
+ scratch.layer4_rn = nn.Conv2d(
108
+ in_shape[3],
109
+ out_shape4,
110
+ kernel_size=3,
111
+ stride=1,
112
+ padding=1,
113
+ bias=False,
114
+ groups=groups,
115
+ )
116
+
117
+ return scratch
118
+
119
+
120
+ def _make_resnet_backbone(resnet):
121
+ pretrained = nn.Module()
122
+ pretrained.layer1 = nn.Sequential(
123
+ resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
124
+ )
125
+
126
+ pretrained.layer2 = resnet.layer2
127
+ pretrained.layer3 = resnet.layer3
128
+ pretrained.layer4 = resnet.layer4
129
+
130
+ return pretrained
131
+
132
+
133
+ def _make_pretrained_resnext101_wsl(use_pretrained):
134
+ resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
135
+ return _make_resnet_backbone(resnet)
136
+
137
+
138
+ class Interpolate(nn.Module):
139
+ """Interpolation module."""
140
+
141
+ def __init__(self, scale_factor, mode, align_corners=False):
142
+ """Init.
143
+
144
+ Args:
145
+ scale_factor (float): scaling
146
+ mode (str): interpolation mode
147
+ """
148
+ super(Interpolate, self).__init__()
149
+
150
+ self.interp = nn.functional.interpolate
151
+ self.scale_factor = scale_factor
152
+ self.mode = mode
153
+ self.align_corners = align_corners
154
+
155
+ def forward(self, x):
156
+ """Forward pass.
157
+
158
+ Args:
159
+ x (tensor): input
160
+
161
+ Returns:
162
+ tensor: interpolated data
163
+ """
164
+
165
+ x = self.interp(
166
+ x,
167
+ scale_factor=self.scale_factor,
168
+ mode=self.mode,
169
+ align_corners=self.align_corners,
170
+ )
171
+
172
+ return x
173
+
174
+
175
+ class ResidualConvUnit(nn.Module):
176
+ """Residual convolution module."""
177
+
178
+ def __init__(self, features):
179
+ """Init.
180
+
181
+ Args:
182
+ features (int): number of features
183
+ """
184
+ super().__init__()
185
+
186
+ self.conv1 = nn.Conv2d(
187
+ features, features, kernel_size=3, stride=1, padding=1, bias=True
188
+ )
189
+
190
+ self.conv2 = nn.Conv2d(
191
+ features, features, kernel_size=3, stride=1, padding=1, bias=True
192
+ )
193
+
194
+ self.relu = nn.ReLU(inplace=True)
195
+
196
+ def forward(self, x):
197
+ """Forward pass.
198
+
199
+ Args:
200
+ x (tensor): input
201
+
202
+ Returns:
203
+ tensor: output
204
+ """
205
+ out = self.relu(x)
206
+ out = self.conv1(out)
207
+ out = self.relu(out)
208
+ out = self.conv2(out)
209
+
210
+ return out + x
211
+
212
+
213
+ class FeatureFusionBlock(nn.Module):
214
+ """Feature fusion block."""
215
+
216
+ def __init__(self, features):
217
+ """Init.
218
+
219
+ Args:
220
+ features (int): number of features
221
+ """
222
+ super(FeatureFusionBlock, self).__init__()
223
+
224
+ self.resConfUnit1 = ResidualConvUnit(features)
225
+ self.resConfUnit2 = ResidualConvUnit(features)
226
+
227
+ def forward(self, *xs):
228
+ """Forward pass.
229
+
230
+ Returns:
231
+ tensor: output
232
+ """
233
+ output = xs[0]
234
+
235
+ if len(xs) == 2:
236
+ output += self.resConfUnit1(xs[1])
237
+
238
+ output = self.resConfUnit2(output)
239
+
240
+ output = nn.functional.interpolate(
241
+ output, scale_factor=2, mode="bilinear", align_corners=True
242
+ )
243
+
244
+ return output
245
+
246
+
247
+ class ResidualConvUnit_custom(nn.Module):
248
+ """Residual convolution module."""
249
+
250
+ def __init__(self, features, activation, bn):
251
+ """Init.
252
+
253
+ Args:
254
+ features (int): number of features
255
+ """
256
+ super().__init__()
257
+
258
+ self.bn = bn
259
+
260
+ self.groups = 1
261
+
262
+ self.conv1 = nn.Conv2d(
263
+ features,
264
+ features,
265
+ kernel_size=3,
266
+ stride=1,
267
+ padding=1,
268
+ bias=not self.bn,
269
+ groups=self.groups,
270
+ )
271
+
272
+ self.conv2 = nn.Conv2d(
273
+ features,
274
+ features,
275
+ kernel_size=3,
276
+ stride=1,
277
+ padding=1,
278
+ bias=not self.bn,
279
+ groups=self.groups,
280
+ )
281
+
282
+ if self.bn == True:
283
+ self.bn1 = nn.BatchNorm2d(features)
284
+ self.bn2 = nn.BatchNorm2d(features)
285
+
286
+ self.activation = activation
287
+
288
+ self.skip_add = nn.quantized.FloatFunctional()
289
+
290
+ def forward(self, x):
291
+ """Forward pass.
292
+
293
+ Args:
294
+ x (tensor): input
295
+
296
+ Returns:
297
+ tensor: output
298
+ """
299
+
300
+ out = self.activation(x)
301
+ out = self.conv1(out)
302
+ if self.bn == True:
303
+ out = self.bn1(out)
304
+
305
+ out = self.activation(out)
306
+ out = self.conv2(out)
307
+ if self.bn == True:
308
+ out = self.bn2(out)
309
+
310
+ if self.groups > 1:
311
+ out = self.conv_merge(out)
312
+
313
+ return self.skip_add.add(out, x)
314
+
315
+ # return out + x
316
+
317
+
318
+ class FeatureFusionBlock_custom(nn.Module):
319
+ """Feature fusion block."""
320
+
321
+ def __init__(
322
+ self,
323
+ features,
324
+ activation,
325
+ deconv=False,
326
+ bn=False,
327
+ expand=False,
328
+ align_corners=True,
329
+ ):
330
+ """Init.
331
+
332
+ Args:
333
+ features (int): number of features
334
+ """
335
+ super(FeatureFusionBlock_custom, self).__init__()
336
+
337
+ self.deconv = deconv
338
+ self.align_corners = align_corners
339
+
340
+ self.groups = 1
341
+
342
+ self.expand = expand
343
+ out_features = features
344
+ if self.expand == True:
345
+ out_features = features // 2
346
+
347
+ self.out_conv = nn.Conv2d(
348
+ features,
349
+ out_features,
350
+ kernel_size=1,
351
+ stride=1,
352
+ padding=0,
353
+ bias=True,
354
+ groups=1,
355
+ )
356
+
357
+ self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
358
+ self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
359
+
360
+ self.skip_add = nn.quantized.FloatFunctional()
361
+
362
+ def forward(self, *xs):
363
+ """Forward pass.
364
+
365
+ Returns:
366
+ tensor: output
367
+ """
368
+ output = xs[0]
369
+
370
+ if len(xs) == 2:
371
+ res = self.resConfUnit1(xs[1])
372
+ output = self.skip_add.add(output, res)
373
+ # output += res
374
+
375
+ output = self.resConfUnit2(output)
376
+
377
+ output = nn.functional.interpolate(
378
+ output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
379
+ )
380
+
381
+ output = self.out_conv(output)
382
+
383
+ return output
dpt/midas_net.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """MidashNet: Network for monocular depth estimation trained by mixing several datasets.
2
+ This file contains code that is adapted from
3
+ https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
4
+ """
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from .base_model import BaseModel
9
+ from .blocks import FeatureFusionBlock, Interpolate, _make_encoder
10
+
11
+
12
+ class MidasNet_large(BaseModel):
13
+ """Network for monocular depth estimation."""
14
+
15
+ def __init__(self, path=None, features=256, non_negative=True):
16
+ """Init.
17
+
18
+ Args:
19
+ path (str, optional): Path to saved model. Defaults to None.
20
+ features (int, optional): Number of features. Defaults to 256.
21
+ backbone (str, optional): Backbone network for encoder. Defaults to resnet50
22
+ """
23
+ print("Loading weights: ", path)
24
+
25
+ super(MidasNet_large, self).__init__()
26
+
27
+ use_pretrained = False if path is None else True
28
+
29
+ self.pretrained, self.scratch = _make_encoder(
30
+ backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained
31
+ )
32
+
33
+ self.scratch.refinenet4 = FeatureFusionBlock(features)
34
+ self.scratch.refinenet3 = FeatureFusionBlock(features)
35
+ self.scratch.refinenet2 = FeatureFusionBlock(features)
36
+ self.scratch.refinenet1 = FeatureFusionBlock(features)
37
+
38
+ self.scratch.output_conv = nn.Sequential(
39
+ nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
40
+ Interpolate(scale_factor=2, mode="bilinear"),
41
+ nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
42
+ nn.ReLU(True),
43
+ nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
44
+ nn.ReLU(True) if non_negative else nn.Identity(),
45
+ )
46
+
47
+ if path:
48
+ self.load(path)
49
+
50
+ def forward(self, x):
51
+ """Forward pass.
52
+
53
+ Args:
54
+ x (tensor): input data (image)
55
+
56
+ Returns:
57
+ tensor: depth
58
+ """
59
+
60
+ layer_1 = self.pretrained.layer1(x)
61
+ layer_2 = self.pretrained.layer2(layer_1)
62
+ layer_3 = self.pretrained.layer3(layer_2)
63
+ layer_4 = self.pretrained.layer4(layer_3)
64
+
65
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
66
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
67
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
68
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
69
+
70
+ path_4 = self.scratch.refinenet4(layer_4_rn)
71
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
72
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
73
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
74
+
75
+ out = self.scratch.output_conv(path_1)
76
+
77
+ return torch.squeeze(out, dim=1)
dpt/models.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from .base_model import BaseModel
6
+ from .blocks import (
7
+ FeatureFusionBlock,
8
+ FeatureFusionBlock_custom,
9
+ Interpolate,
10
+ _make_encoder,
11
+ forward_vit,
12
+ )
13
+
14
+
15
+ def _make_fusion_block(features, use_bn):
16
+ return FeatureFusionBlock_custom(
17
+ features,
18
+ nn.ReLU(False),
19
+ deconv=False,
20
+ bn=use_bn,
21
+ expand=False,
22
+ align_corners=True,
23
+ )
24
+
25
+
26
+ class DPT(BaseModel):
27
+ def __init__(
28
+ self,
29
+ head,
30
+ features=256,
31
+ backbone="vitb_rn50_384",
32
+ readout="project",
33
+ channels_last=False,
34
+ use_bn=False,
35
+ enable_attention_hooks=False,
36
+ ):
37
+
38
+ super(DPT, self).__init__()
39
+
40
+ self.channels_last = channels_last
41
+
42
+ hooks = {
43
+ "vitb_rn50_384": [0, 1, 8, 11],
44
+ "vitb16_384": [2, 5, 8, 11],
45
+ "vitl16_384": [5, 11, 17, 23],
46
+ }
47
+
48
+ # Instantiate backbone and reassemble blocks
49
+ self.pretrained, self.scratch = _make_encoder(
50
+ backbone,
51
+ features,
52
+ False, # Set to true of you want to train from scratch, uses ImageNet weights
53
+ groups=1,
54
+ expand=False,
55
+ exportable=False,
56
+ hooks=hooks[backbone],
57
+ use_readout=readout,
58
+ enable_attention_hooks=enable_attention_hooks,
59
+ )
60
+
61
+ self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
62
+ self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
63
+ self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
64
+ self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
65
+
66
+ self.scratch.output_conv = head
67
+
68
+ def forward(self, x):
69
+ if self.channels_last == True:
70
+ x.contiguous(memory_format=torch.channels_last)
71
+
72
+ layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)
73
+
74
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
75
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
76
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
77
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
78
+
79
+ path_4 = self.scratch.refinenet4(layer_4_rn)
80
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
81
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
82
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
83
+
84
+ out = self.scratch.output_conv(path_1)
85
+
86
+ return out
87
+
88
+
89
+ class DPTDepthModel(DPT):
90
+ def __init__(
91
+ self, path=None, non_negative=True, scale=1.0, shift=0.0, invert=False, **kwargs
92
+ ):
93
+ features = kwargs["features"] if "features" in kwargs else 256
94
+
95
+ self.scale = scale
96
+ self.shift = shift
97
+ self.invert = invert
98
+
99
+ head = nn.Sequential(
100
+ nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
101
+ Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
102
+ nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
103
+ nn.ReLU(True),
104
+ nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
105
+ nn.ReLU(True) if non_negative else nn.Identity(),
106
+ nn.Identity(),
107
+ )
108
+
109
+ super().__init__(head, **kwargs)
110
+
111
+ if path is not None:
112
+ self.load(path)
113
+
114
+ def forward(self, x):
115
+ inv_depth = super().forward(x).squeeze(dim=1)
116
+
117
+ if self.invert:
118
+ depth = self.scale * inv_depth + self.shift
119
+ depth[depth < 1e-8] = 1e-8
120
+ depth = 1.0 / depth
121
+ return depth
122
+ else:
123
+ return inv_depth
124
+
125
+
126
+ class DPTSegmentationModel(DPT):
127
+ def __init__(self, num_classes, path=None, **kwargs):
128
+
129
+ features = kwargs["features"] if "features" in kwargs else 256
130
+
131
+ kwargs["use_bn"] = True
132
+
133
+ head = nn.Sequential(
134
+ nn.Conv2d(features, features, kernel_size=3, padding=1, bias=False),
135
+ nn.BatchNorm2d(features),
136
+ nn.ReLU(True),
137
+ nn.Dropout(0.1, False),
138
+ nn.Conv2d(features, num_classes, kernel_size=1),
139
+ Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
140
+ )
141
+
142
+ super().__init__(head, **kwargs)
143
+
144
+ self.auxlayer = nn.Sequential(
145
+ nn.Conv2d(features, features, kernel_size=3, padding=1, bias=False),
146
+ nn.BatchNorm2d(features),
147
+ nn.ReLU(True),
148
+ nn.Dropout(0.1, False),
149
+ nn.Conv2d(features, num_classes, kernel_size=1),
150
+ )
151
+
152
+ if path is not None:
153
+ self.load(path)
dpt/transforms.py ADDED
@@ -0,0 +1,231 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import cv2
3
+ import math
4
+
5
+
6
+ def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
7
+ """Rezise the sample to ensure the given size. Keeps aspect ratio.
8
+
9
+ Args:
10
+ sample (dict): sample
11
+ size (tuple): image size
12
+
13
+ Returns:
14
+ tuple: new size
15
+ """
16
+ shape = list(sample["disparity"].shape)
17
+
18
+ if shape[0] >= size[0] and shape[1] >= size[1]:
19
+ return sample
20
+
21
+ scale = [0, 0]
22
+ scale[0] = size[0] / shape[0]
23
+ scale[1] = size[1] / shape[1]
24
+
25
+ scale = max(scale)
26
+
27
+ shape[0] = math.ceil(scale * shape[0])
28
+ shape[1] = math.ceil(scale * shape[1])
29
+
30
+ # resize
31
+ sample["image"] = cv2.resize(
32
+ sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
33
+ )
34
+
35
+ sample["disparity"] = cv2.resize(
36
+ sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
37
+ )
38
+ sample["mask"] = cv2.resize(
39
+ sample["mask"].astype(np.float32),
40
+ tuple(shape[::-1]),
41
+ interpolation=cv2.INTER_NEAREST,
42
+ )
43
+ sample["mask"] = sample["mask"].astype(bool)
44
+
45
+ return tuple(shape)
46
+
47
+
48
+ class Resize(object):
49
+ """Resize sample to given size (width, height)."""
50
+
51
+ def __init__(
52
+ self,
53
+ width,
54
+ height,
55
+ resize_target=True,
56
+ keep_aspect_ratio=False,
57
+ ensure_multiple_of=1,
58
+ resize_method="lower_bound",
59
+ image_interpolation_method=cv2.INTER_AREA,
60
+ ):
61
+ """Init.
62
+
63
+ Args:
64
+ width (int): desired output width
65
+ height (int): desired output height
66
+ resize_target (bool, optional):
67
+ True: Resize the full sample (image, mask, target).
68
+ False: Resize image only.
69
+ Defaults to True.
70
+ keep_aspect_ratio (bool, optional):
71
+ True: Keep the aspect ratio of the input sample.
72
+ Output sample might not have the given width and height, and
73
+ resize behaviour depends on the parameter 'resize_method'.
74
+ Defaults to False.
75
+ ensure_multiple_of (int, optional):
76
+ Output width and height is constrained to be multiple of this parameter.
77
+ Defaults to 1.
78
+ resize_method (str, optional):
79
+ "lower_bound": Output will be at least as large as the given size.
80
+ "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
81
+ "minimal": Scale as least as possible. (Output size might be smaller than given size.)
82
+ Defaults to "lower_bound".
83
+ """
84
+ self.__width = width
85
+ self.__height = height
86
+
87
+ self.__resize_target = resize_target
88
+ self.__keep_aspect_ratio = keep_aspect_ratio
89
+ self.__multiple_of = ensure_multiple_of
90
+ self.__resize_method = resize_method
91
+ self.__image_interpolation_method = image_interpolation_method
92
+
93
+ def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
94
+ y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
95
+
96
+ if max_val is not None and y > max_val:
97
+ y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
98
+
99
+ if y < min_val:
100
+ y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
101
+
102
+ return y
103
+
104
+ def get_size(self, width, height):
105
+ # determine new height and width
106
+ scale_height = self.__height / height
107
+ scale_width = self.__width / width
108
+
109
+ if self.__keep_aspect_ratio:
110
+ if self.__resize_method == "lower_bound":
111
+ # scale such that output size is lower bound
112
+ if scale_width > scale_height:
113
+ # fit width
114
+ scale_height = scale_width
115
+ else:
116
+ # fit height
117
+ scale_width = scale_height
118
+ elif self.__resize_method == "upper_bound":
119
+ # scale such that output size is upper bound
120
+ if scale_width < scale_height:
121
+ # fit width
122
+ scale_height = scale_width
123
+ else:
124
+ # fit height
125
+ scale_width = scale_height
126
+ elif self.__resize_method == "minimal":
127
+ # scale as least as possbile
128
+ if abs(1 - scale_width) < abs(1 - scale_height):
129
+ # fit width
130
+ scale_height = scale_width
131
+ else:
132
+ # fit height
133
+ scale_width = scale_height
134
+ else:
135
+ raise ValueError(
136
+ f"resize_method {self.__resize_method} not implemented"
137
+ )
138
+
139
+ if self.__resize_method == "lower_bound":
140
+ new_height = self.constrain_to_multiple_of(
141
+ scale_height * height, min_val=self.__height
142
+ )
143
+ new_width = self.constrain_to_multiple_of(
144
+ scale_width * width, min_val=self.__width
145
+ )
146
+ elif self.__resize_method == "upper_bound":
147
+ new_height = self.constrain_to_multiple_of(
148
+ scale_height * height, max_val=self.__height
149
+ )
150
+ new_width = self.constrain_to_multiple_of(
151
+ scale_width * width, max_val=self.__width
152
+ )
153
+ elif self.__resize_method == "minimal":
154
+ new_height = self.constrain_to_multiple_of(scale_height * height)
155
+ new_width = self.constrain_to_multiple_of(scale_width * width)
156
+ else:
157
+ raise ValueError(f"resize_method {self.__resize_method} not implemented")
158
+
159
+ return (new_width, new_height)
160
+
161
+ def __call__(self, sample):
162
+ width, height = self.get_size(
163
+ sample["image"].shape[1], sample["image"].shape[0]
164
+ )
165
+
166
+ # resize sample
167
+ sample["image"] = cv2.resize(
168
+ sample["image"],
169
+ (width, height),
170
+ interpolation=self.__image_interpolation_method,
171
+ )
172
+
173
+ if self.__resize_target:
174
+ if "disparity" in sample:
175
+ sample["disparity"] = cv2.resize(
176
+ sample["disparity"],
177
+ (width, height),
178
+ interpolation=cv2.INTER_NEAREST,
179
+ )
180
+
181
+ if "depth" in sample:
182
+ sample["depth"] = cv2.resize(
183
+ sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
184
+ )
185
+
186
+ sample["mask"] = cv2.resize(
187
+ sample["mask"].astype(np.float32),
188
+ (width, height),
189
+ interpolation=cv2.INTER_NEAREST,
190
+ )
191
+ sample["mask"] = sample["mask"].astype(bool)
192
+
193
+ return sample
194
+
195
+
196
+ class NormalizeImage(object):
197
+ """Normlize image by given mean and std."""
198
+
199
+ def __init__(self, mean, std):
200
+ self.__mean = mean
201
+ self.__std = std
202
+
203
+ def __call__(self, sample):
204
+ sample["image"] = (sample["image"] - self.__mean) / self.__std
205
+
206
+ return sample
207
+
208
+
209
+ class PrepareForNet(object):
210
+ """Prepare sample for usage as network input."""
211
+
212
+ def __init__(self):
213
+ pass
214
+
215
+ def __call__(self, sample):
216
+ image = np.transpose(sample["image"], (2, 0, 1))
217
+ sample["image"] = np.ascontiguousarray(image).astype(np.float32)
218
+
219
+ if "mask" in sample:
220
+ sample["mask"] = sample["mask"].astype(np.float32)
221
+ sample["mask"] = np.ascontiguousarray(sample["mask"])
222
+
223
+ if "disparity" in sample:
224
+ disparity = sample["disparity"].astype(np.float32)
225
+ sample["disparity"] = np.ascontiguousarray(disparity)
226
+
227
+ if "depth" in sample:
228
+ depth = sample["depth"].astype(np.float32)
229
+ sample["depth"] = np.ascontiguousarray(depth)
230
+
231
+ return sample
dpt/vit.py ADDED
@@ -0,0 +1,576 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import timm
4
+ import types
5
+ import math
6
+ import torch.nn.functional as F
7
+
8
+
9
+ activations = {}
10
+
11
+
12
+ def get_activation(name):
13
+ def hook(model, input, output):
14
+ activations[name] = output
15
+
16
+ return hook
17
+
18
+
19
+ attention = {}
20
+
21
+
22
+ def get_attention(name):
23
+ def hook(module, input, output):
24
+ x = input[0]
25
+ B, N, C = x.shape
26
+ qkv = (
27
+ module.qkv(x)
28
+ .reshape(B, N, 3, module.num_heads, C // module.num_heads)
29
+ .permute(2, 0, 3, 1, 4)
30
+ )
31
+ q, k, v = (
32
+ qkv[0],
33
+ qkv[1],
34
+ qkv[2],
35
+ ) # make torchscript happy (cannot use tensor as tuple)
36
+
37
+ attn = (q @ k.transpose(-2, -1)) * module.scale
38
+
39
+ attn = attn.softmax(dim=-1) # [:,:,1,1:]
40
+ attention[name] = attn
41
+
42
+ return hook
43
+
44
+
45
+ def get_mean_attention_map(attn, token, shape):
46
+ attn = attn[:, :, token, 1:]
47
+ attn = attn.unflatten(2, torch.Size([shape[2] // 16, shape[3] // 16])).float()
48
+ attn = torch.nn.functional.interpolate(
49
+ attn, size=shape[2:], mode="bicubic", align_corners=False
50
+ ).squeeze(0)
51
+
52
+ all_attn = torch.mean(attn, 0)
53
+
54
+ return all_attn
55
+
56
+
57
+ class Slice(nn.Module):
58
+ def __init__(self, start_index=1):
59
+ super(Slice, self).__init__()
60
+ self.start_index = start_index
61
+
62
+ def forward(self, x):
63
+ return x[:, self.start_index :]
64
+
65
+
66
+ class AddReadout(nn.Module):
67
+ def __init__(self, start_index=1):
68
+ super(AddReadout, self).__init__()
69
+ self.start_index = start_index
70
+
71
+ def forward(self, x):
72
+ if self.start_index == 2:
73
+ readout = (x[:, 0] + x[:, 1]) / 2
74
+ else:
75
+ readout = x[:, 0]
76
+ return x[:, self.start_index :] + readout.unsqueeze(1)
77
+
78
+
79
+ class ProjectReadout(nn.Module):
80
+ def __init__(self, in_features, start_index=1):
81
+ super(ProjectReadout, self).__init__()
82
+ self.start_index = start_index
83
+
84
+ self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())
85
+
86
+ def forward(self, x):
87
+ readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :])
88
+ features = torch.cat((x[:, self.start_index :], readout), -1)
89
+
90
+ return self.project(features)
91
+
92
+
93
+ class Transpose(nn.Module):
94
+ def __init__(self, dim0, dim1):
95
+ super(Transpose, self).__init__()
96
+ self.dim0 = dim0
97
+ self.dim1 = dim1
98
+
99
+ def forward(self, x):
100
+ x = x.transpose(self.dim0, self.dim1)
101
+ return x
102
+
103
+
104
+ def forward_vit(pretrained, x):
105
+ b, c, h, w = x.shape
106
+
107
+ glob = pretrained.model.forward_flex(x)
108
+
109
+ layer_1 = pretrained.activations["1"]
110
+ layer_2 = pretrained.activations["2"]
111
+ layer_3 = pretrained.activations["3"]
112
+ layer_4 = pretrained.activations["4"]
113
+
114
+ layer_1 = pretrained.act_postprocess1[0:2](layer_1)
115
+ layer_2 = pretrained.act_postprocess2[0:2](layer_2)
116
+ layer_3 = pretrained.act_postprocess3[0:2](layer_3)
117
+ layer_4 = pretrained.act_postprocess4[0:2](layer_4)
118
+
119
+ unflatten = nn.Sequential(
120
+ nn.Unflatten(
121
+ 2,
122
+ torch.Size(
123
+ [
124
+ h // pretrained.model.patch_size[1],
125
+ w // pretrained.model.patch_size[0],
126
+ ]
127
+ ),
128
+ )
129
+ )
130
+
131
+ if layer_1.ndim == 3:
132
+ layer_1 = unflatten(layer_1)
133
+ if layer_2.ndim == 3:
134
+ layer_2 = unflatten(layer_2)
135
+ if layer_3.ndim == 3:
136
+ layer_3 = unflatten(layer_3)
137
+ if layer_4.ndim == 3:
138
+ layer_4 = unflatten(layer_4)
139
+
140
+ layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1)
141
+ layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2)
142
+ layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3)
143
+ layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4)
144
+
145
+ return layer_1, layer_2, layer_3, layer_4
146
+
147
+
148
+ def _resize_pos_embed(self, posemb, gs_h, gs_w):
149
+ posemb_tok, posemb_grid = (
150
+ posemb[:, : self.start_index],
151
+ posemb[0, self.start_index :],
152
+ )
153
+
154
+ gs_old = int(math.sqrt(len(posemb_grid)))
155
+
156
+ posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
157
+ posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
158
+ posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
159
+
160
+ posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
161
+
162
+ return posemb
163
+
164
+
165
+ def forward_flex(self, x):
166
+ b, c, h, w = x.shape
167
+
168
+ pos_embed = self._resize_pos_embed(
169
+ self.pos_embed, h // self.patch_size[1], w // self.patch_size[0]
170
+ )
171
+
172
+ B = x.shape[0]
173
+
174
+ if hasattr(self.patch_embed, "backbone"):
175
+ x = self.patch_embed.backbone(x)
176
+ if isinstance(x, (list, tuple)):
177
+ x = x[-1] # last feature if backbone outputs list/tuple of features
178
+
179
+ x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)
180
+
181
+ if getattr(self, "dist_token", None) is not None:
182
+ cls_tokens = self.cls_token.expand(
183
+ B, -1, -1
184
+ ) # stole cls_tokens impl from Phil Wang, thanks
185
+ dist_token = self.dist_token.expand(B, -1, -1)
186
+ x = torch.cat((cls_tokens, dist_token, x), dim=1)
187
+ else:
188
+ cls_tokens = self.cls_token.expand(
189
+ B, -1, -1
190
+ ) # stole cls_tokens impl from Phil Wang, thanks
191
+ x = torch.cat((cls_tokens, x), dim=1)
192
+
193
+ x = x + pos_embed
194
+ x = self.pos_drop(x)
195
+
196
+ for blk in self.blocks:
197
+ x = blk(x)
198
+
199
+ x = self.norm(x)
200
+
201
+ return x
202
+
203
+
204
+ def get_readout_oper(vit_features, features, use_readout, start_index=1):
205
+ if use_readout == "ignore":
206
+ readout_oper = [Slice(start_index)] * len(features)
207
+ elif use_readout == "add":
208
+ readout_oper = [AddReadout(start_index)] * len(features)
209
+ elif use_readout == "project":
210
+ readout_oper = [
211
+ ProjectReadout(vit_features, start_index) for out_feat in features
212
+ ]
213
+ else:
214
+ assert (
215
+ False
216
+ ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"
217
+
218
+ return readout_oper
219
+
220
+
221
+ def _make_vit_b16_backbone(
222
+ model,
223
+ features=[96, 192, 384, 768],
224
+ size=[384, 384],
225
+ hooks=[2, 5, 8, 11],
226
+ vit_features=768,
227
+ use_readout="ignore",
228
+ start_index=1,
229
+ enable_attention_hooks=False,
230
+ ):
231
+ pretrained = nn.Module()
232
+
233
+ pretrained.model = model
234
+ pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
235
+ pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
236
+ pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
237
+ pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
238
+
239
+ pretrained.activations = activations
240
+
241
+ if enable_attention_hooks:
242
+ pretrained.model.blocks[hooks[0]].attn.register_forward_hook(
243
+ get_attention("attn_1")
244
+ )
245
+ pretrained.model.blocks[hooks[1]].attn.register_forward_hook(
246
+ get_attention("attn_2")
247
+ )
248
+ pretrained.model.blocks[hooks[2]].attn.register_forward_hook(
249
+ get_attention("attn_3")
250
+ )
251
+ pretrained.model.blocks[hooks[3]].attn.register_forward_hook(
252
+ get_attention("attn_4")
253
+ )
254
+ pretrained.attention = attention
255
+
256
+ readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
257
+
258
+ # 32, 48, 136, 384
259
+ pretrained.act_postprocess1 = nn.Sequential(
260
+ readout_oper[0],
261
+ Transpose(1, 2),
262
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
263
+ nn.Conv2d(
264
+ in_channels=vit_features,
265
+ out_channels=features[0],
266
+ kernel_size=1,
267
+ stride=1,
268
+ padding=0,
269
+ ),
270
+ nn.ConvTranspose2d(
271
+ in_channels=features[0],
272
+ out_channels=features[0],
273
+ kernel_size=4,
274
+ stride=4,
275
+ padding=0,
276
+ bias=True,
277
+ dilation=1,
278
+ groups=1,
279
+ ),
280
+ )
281
+
282
+ pretrained.act_postprocess2 = nn.Sequential(
283
+ readout_oper[1],
284
+ Transpose(1, 2),
285
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
286
+ nn.Conv2d(
287
+ in_channels=vit_features,
288
+ out_channels=features[1],
289
+ kernel_size=1,
290
+ stride=1,
291
+ padding=0,
292
+ ),
293
+ nn.ConvTranspose2d(
294
+ in_channels=features[1],
295
+ out_channels=features[1],
296
+ kernel_size=2,
297
+ stride=2,
298
+ padding=0,
299
+ bias=True,
300
+ dilation=1,
301
+ groups=1,
302
+ ),
303
+ )
304
+
305
+ pretrained.act_postprocess3 = nn.Sequential(
306
+ readout_oper[2],
307
+ Transpose(1, 2),
308
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
309
+ nn.Conv2d(
310
+ in_channels=vit_features,
311
+ out_channels=features[2],
312
+ kernel_size=1,
313
+ stride=1,
314
+ padding=0,
315
+ ),
316
+ )
317
+
318
+ pretrained.act_postprocess4 = nn.Sequential(
319
+ readout_oper[3],
320
+ Transpose(1, 2),
321
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
322
+ nn.Conv2d(
323
+ in_channels=vit_features,
324
+ out_channels=features[3],
325
+ kernel_size=1,
326
+ stride=1,
327
+ padding=0,
328
+ ),
329
+ nn.Conv2d(
330
+ in_channels=features[3],
331
+ out_channels=features[3],
332
+ kernel_size=3,
333
+ stride=2,
334
+ padding=1,
335
+ ),
336
+ )
337
+
338
+ pretrained.model.start_index = start_index
339
+ pretrained.model.patch_size = [16, 16]
340
+
341
+ # We inject this function into the VisionTransformer instances so that
342
+ # we can use it with interpolated position embeddings without modifying the library source.
343
+ pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
344
+ pretrained.model._resize_pos_embed = types.MethodType(
345
+ _resize_pos_embed, pretrained.model
346
+ )
347
+
348
+ return pretrained
349
+
350
+
351
+ def _make_vit_b_rn50_backbone(
352
+ model,
353
+ features=[256, 512, 768, 768],
354
+ size=[384, 384],
355
+ hooks=[0, 1, 8, 11],
356
+ vit_features=768,
357
+ use_vit_only=False,
358
+ use_readout="ignore",
359
+ start_index=1,
360
+ enable_attention_hooks=False,
361
+ ):
362
+ pretrained = nn.Module()
363
+
364
+ pretrained.model = model
365
+
366
+ if use_vit_only == True:
367
+ pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
368
+ pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
369
+ else:
370
+ pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(
371
+ get_activation("1")
372
+ )
373
+ pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(
374
+ get_activation("2")
375
+ )
376
+
377
+ pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
378
+ pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
379
+
380
+ if enable_attention_hooks:
381
+ pretrained.model.blocks[2].attn.register_forward_hook(get_attention("attn_1"))
382
+ pretrained.model.blocks[5].attn.register_forward_hook(get_attention("attn_2"))
383
+ pretrained.model.blocks[8].attn.register_forward_hook(get_attention("attn_3"))
384
+ pretrained.model.blocks[11].attn.register_forward_hook(get_attention("attn_4"))
385
+ pretrained.attention = attention
386
+
387
+ pretrained.activations = activations
388
+
389
+ readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
390
+
391
+ if use_vit_only == True:
392
+ pretrained.act_postprocess1 = nn.Sequential(
393
+ readout_oper[0],
394
+ Transpose(1, 2),
395
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
396
+ nn.Conv2d(
397
+ in_channels=vit_features,
398
+ out_channels=features[0],
399
+ kernel_size=1,
400
+ stride=1,
401
+ padding=0,
402
+ ),
403
+ nn.ConvTranspose2d(
404
+ in_channels=features[0],
405
+ out_channels=features[0],
406
+ kernel_size=4,
407
+ stride=4,
408
+ padding=0,
409
+ bias=True,
410
+ dilation=1,
411
+ groups=1,
412
+ ),
413
+ )
414
+
415
+ pretrained.act_postprocess2 = nn.Sequential(
416
+ readout_oper[1],
417
+ Transpose(1, 2),
418
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
419
+ nn.Conv2d(
420
+ in_channels=vit_features,
421
+ out_channels=features[1],
422
+ kernel_size=1,
423
+ stride=1,
424
+ padding=0,
425
+ ),
426
+ nn.ConvTranspose2d(
427
+ in_channels=features[1],
428
+ out_channels=features[1],
429
+ kernel_size=2,
430
+ stride=2,
431
+ padding=0,
432
+ bias=True,
433
+ dilation=1,
434
+ groups=1,
435
+ ),
436
+ )
437
+ else:
438
+ pretrained.act_postprocess1 = nn.Sequential(
439
+ nn.Identity(), nn.Identity(), nn.Identity()
440
+ )
441
+ pretrained.act_postprocess2 = nn.Sequential(
442
+ nn.Identity(), nn.Identity(), nn.Identity()
443
+ )
444
+
445
+ pretrained.act_postprocess3 = nn.Sequential(
446
+ readout_oper[2],
447
+ Transpose(1, 2),
448
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
449
+ nn.Conv2d(
450
+ in_channels=vit_features,
451
+ out_channels=features[2],
452
+ kernel_size=1,
453
+ stride=1,
454
+ padding=0,
455
+ ),
456
+ )
457
+
458
+ pretrained.act_postprocess4 = nn.Sequential(
459
+ readout_oper[3],
460
+ Transpose(1, 2),
461
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
462
+ nn.Conv2d(
463
+ in_channels=vit_features,
464
+ out_channels=features[3],
465
+ kernel_size=1,
466
+ stride=1,
467
+ padding=0,
468
+ ),
469
+ nn.Conv2d(
470
+ in_channels=features[3],
471
+ out_channels=features[3],
472
+ kernel_size=3,
473
+ stride=2,
474
+ padding=1,
475
+ ),
476
+ )
477
+
478
+ pretrained.model.start_index = start_index
479
+ pretrained.model.patch_size = [16, 16]
480
+
481
+ # We inject this function into the VisionTransformer instances so that
482
+ # we can use it with interpolated position embeddings without modifying the library source.
483
+ pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
484
+
485
+ # We inject this function into the VisionTransformer instances so that
486
+ # we can use it with interpolated position embeddings without modifying the library source.
487
+ pretrained.model._resize_pos_embed = types.MethodType(
488
+ _resize_pos_embed, pretrained.model
489
+ )
490
+
491
+ return pretrained
492
+
493
+
494
+ def _make_pretrained_vitb_rn50_384(
495
+ pretrained,
496
+ use_readout="ignore",
497
+ hooks=None,
498
+ use_vit_only=False,
499
+ enable_attention_hooks=False,
500
+ ):
501
+ model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)
502
+
503
+ hooks = [0, 1, 8, 11] if hooks == None else hooks
504
+ return _make_vit_b_rn50_backbone(
505
+ model,
506
+ features=[256, 512, 768, 768],
507
+ size=[384, 384],
508
+ hooks=hooks,
509
+ use_vit_only=use_vit_only,
510
+ use_readout=use_readout,
511
+ enable_attention_hooks=enable_attention_hooks,
512
+ )
513
+
514
+
515
+ def _make_pretrained_vitl16_384(
516
+ pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False
517
+ ):
518
+ model = timm.create_model("vit_large_patch16_384", pretrained=pretrained)
519
+
520
+ hooks = [5, 11, 17, 23] if hooks == None else hooks
521
+ return _make_vit_b16_backbone(
522
+ model,
523
+ features=[256, 512, 1024, 1024],
524
+ hooks=hooks,
525
+ vit_features=1024,
526
+ use_readout=use_readout,
527
+ enable_attention_hooks=enable_attention_hooks,
528
+ )
529
+
530
+
531
+ def _make_pretrained_vitb16_384(
532
+ pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False
533
+ ):
534
+ model = timm.create_model("vit_base_patch16_384", pretrained=pretrained)
535
+
536
+ hooks = [2, 5, 8, 11] if hooks == None else hooks
537
+ return _make_vit_b16_backbone(
538
+ model,
539
+ features=[96, 192, 384, 768],
540
+ hooks=hooks,
541
+ use_readout=use_readout,
542
+ enable_attention_hooks=enable_attention_hooks,
543
+ )
544
+
545
+
546
+ def _make_pretrained_deitb16_384(
547
+ pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False
548
+ ):
549
+ model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained)
550
+
551
+ hooks = [2, 5, 8, 11] if hooks == None else hooks
552
+ return _make_vit_b16_backbone(
553
+ model,
554
+ features=[96, 192, 384, 768],
555
+ hooks=hooks,
556
+ use_readout=use_readout,
557
+ enable_attention_hooks=enable_attention_hooks,
558
+ )
559
+
560
+
561
+ def _make_pretrained_deitb16_distil_384(
562
+ pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False
563
+ ):
564
+ model = timm.create_model(
565
+ "vit_deit_base_distilled_patch16_384", pretrained=pretrained
566
+ )
567
+
568
+ hooks = [2, 5, 8, 11] if hooks == None else hooks
569
+ return _make_vit_b16_backbone(
570
+ model,
571
+ features=[96, 192, 384, 768],
572
+ hooks=hooks,
573
+ use_readout=use_readout,
574
+ start_index=2,
575
+ enable_attention_hooks=enable_attention_hooks,
576
+ )
examples/1.jpg ADDED
examples/2.jpg ADDED

Git LFS Details

  • SHA256: 52f7466ef97c416d0f4539b58052ba6ef6370e3f6a0193ab1b609fe7ab9fc5a7
  • Pointer size: 132 Bytes
  • Size of remote file: 3.75 MB
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ torch==1.8.1
2
+ torchvision==0.9.1
3
+ opencv-python==4.5.2.54
4
+ timm==0.4.5
weights/.placeholder ADDED
File without changes
weights/dpt_hybrid-midas-501f0c75.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:501f0c75b3bca7daec6b3682c5054c09b366765aef6fa3a09d03a5cb4b230853
3
+ size 492757791