Vaishanth Ramaraj commited on
Commit
8166792
1 Parent(s): d55206a

initial commit

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitignore +2 -0
  2. Readme.md +4 -0
  3. app.py +116 -0
  4. image_segmenter.py +91 -0
  5. midas/__init__.py +0 -0
  6. midas/__pycache__/__init__.cpython-38.pyc +0 -0
  7. midas/__pycache__/base_model.cpython-37.pyc +0 -0
  8. midas/__pycache__/base_model.cpython-38.pyc +0 -0
  9. midas/__pycache__/blocks.cpython-37.pyc +0 -0
  10. midas/__pycache__/blocks.cpython-38.pyc +0 -0
  11. midas/__pycache__/dpt_depth.cpython-37.pyc +0 -0
  12. midas/__pycache__/dpt_depth.cpython-38.pyc +0 -0
  13. midas/__pycache__/midas_net.cpython-37.pyc +0 -0
  14. midas/__pycache__/midas_net.cpython-38.pyc +0 -0
  15. midas/__pycache__/midas_net_custom.cpython-37.pyc +0 -0
  16. midas/__pycache__/midas_net_custom.cpython-38.pyc +0 -0
  17. midas/__pycache__/model_loader.cpython-37.pyc +0 -0
  18. midas/__pycache__/model_loader.cpython-38.pyc +0 -0
  19. midas/__pycache__/transforms.cpython-37.pyc +0 -0
  20. midas/__pycache__/transforms.cpython-38.pyc +0 -0
  21. midas/backbones/__pycache__/beit.cpython-37.pyc +0 -0
  22. midas/backbones/__pycache__/beit.cpython-38.pyc +0 -0
  23. midas/backbones/__pycache__/levit.cpython-37.pyc +0 -0
  24. midas/backbones/__pycache__/levit.cpython-38.pyc +0 -0
  25. midas/backbones/__pycache__/swin.cpython-37.pyc +0 -0
  26. midas/backbones/__pycache__/swin.cpython-38.pyc +0 -0
  27. midas/backbones/__pycache__/swin2.cpython-37.pyc +0 -0
  28. midas/backbones/__pycache__/swin2.cpython-38.pyc +0 -0
  29. midas/backbones/__pycache__/swin_common.cpython-37.pyc +0 -0
  30. midas/backbones/__pycache__/swin_common.cpython-38.pyc +0 -0
  31. midas/backbones/__pycache__/utils.cpython-37.pyc +0 -0
  32. midas/backbones/__pycache__/utils.cpython-38.pyc +0 -0
  33. midas/backbones/__pycache__/vit.cpython-37.pyc +0 -0
  34. midas/backbones/__pycache__/vit.cpython-38.pyc +0 -0
  35. midas/backbones/beit.py +196 -0
  36. midas/backbones/levit.py +106 -0
  37. midas/backbones/next_vit.py +39 -0
  38. midas/backbones/swin.py +13 -0
  39. midas/backbones/swin2.py +34 -0
  40. midas/backbones/swin_common.py +52 -0
  41. midas/backbones/utils.py +249 -0
  42. midas/backbones/vit.py +221 -0
  43. midas/base_model.py +16 -0
  44. midas/blocks.py +439 -0
  45. midas/dpt_depth.py +166 -0
  46. midas/midas_net.py +76 -0
  47. midas/midas_net_custom.py +128 -0
  48. midas/model_loader.py +242 -0
  49. midas/transforms.py +234 -0
  50. monocular_depth_estimator.py +175 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ flagged/
2
+ *.pt
Readme.md ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+
2
+
3
+ 1. Pip install Ultralytics: Yolov8 package
4
+ - pip install ultralytics
app.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ultralytics import YOLO
2
+ import cv2
3
+ import gradio as gr
4
+ import numpy as np
5
+ import os
6
+ import torch
7
+
8
+ from image_segmenter import ImageSegmenter
9
+ from monocular_depth_estimator import MonocularDepthEstimator
10
+
11
+ # params
12
+ CANCEL_PROCESSING = False
13
+
14
+ img_seg = ImageSegmenter(model_type='n')
15
+ depth_estimator = MonocularDepthEstimator(side_by_side=False)
16
+
17
+ def process_image(image):
18
+ return img_seg.predict(image), depth_estimator.make_prediction(image)
19
+
20
+ def process_video(vid_path=None):
21
+ vid_cap = cv2.VideoCapture(vid_path)
22
+ while vid_cap.isOpened():
23
+ ret, frame = vid_cap.read()
24
+ if ret:
25
+ print("making predictions ....")
26
+ yield cv2.cvtColor(img_seg.predict(frame), cv2.COLOR_BGR2RGB), depth_estimator.make_prediction(frame)
27
+
28
+ return None
29
+
30
+ def update_segmentation_options(options):
31
+ img_seg.is_show_bounding_boxes = True if 'Show Boundary Box' in options else False
32
+ img_seg.is_show_segmentation = True if 'Show Segmentation Region' in options else False
33
+ img_seg.is_show_segmentation_boundary = True if 'Show Segmentation Boundary' in options else False
34
+
35
+ def update_confidence_threshold(thres_val):
36
+ img_seg.confidence_threshold = thres_val/100
37
+
38
+ def cancel():
39
+ CANCEL_PROCESSING = True
40
+
41
+ if __name__ == "__main__":
42
+ # img_1 = cv2.imread("assets/images/bus.jpg")
43
+ # pred_img = image_segmentation(img_1)
44
+ # cv2.imshow("output", pred_img)
45
+ # cv2.waitKey(0)
46
+ # cv2.destroyAllWindows()
47
+
48
+ # gradio gui app
49
+ with gr.Blocks() as my_app:
50
+
51
+ # title
52
+ gr.Markdown(
53
+ """
54
+ # Object segmentation and depth estimation
55
+ Input an image or Video
56
+ """)
57
+
58
+ # tabs
59
+ with gr.Tab("Image"):
60
+ with gr.Row():
61
+ with gr.Column(scale=1):
62
+ img_input = gr.Image()
63
+ options_checkbox_img = gr.CheckboxGroup(["Show Boundary Box", "Show Segmentation Region", "Show Segmentation Boundary"], label="Options")
64
+ conf_thres_img = gr.Slider(1, 100, value=60, label="Confidence Threshold", info="Choose the threshold above which objects should be detected")
65
+ submit_btn_img = gr.Button(value="Predict")
66
+
67
+ with gr.Column(scale=2):
68
+ with gr.Row():
69
+ segmentation_img_output = gr.Image(height=300, label="Segmentation")
70
+ depth_img_output = gr.Image(height=300, label="Depth Estimation")
71
+
72
+ gr.Markdown("## Sample Images")
73
+ gr.Examples(
74
+ examples=[os.path.join(os.path.dirname(__file__), "assets/images/bus.jpg")],
75
+ inputs=img_input,
76
+ outputs=[segmentation_img_output, depth_img_output],
77
+ fn=process_image,
78
+ cache_examples=True,
79
+ )
80
+
81
+ with gr.Tab("Video"):
82
+ with gr.Row():
83
+ with gr.Column(scale=1):
84
+ vid_input = gr.Video()
85
+ options_checkbox_vid = gr.CheckboxGroup(["Show Boundary Box", "Show Segmentation Region", "Show Segmentation Boundary"], label="Options")
86
+ conf_thres_vid = gr.Slider(1, 100, value=60, label="Confidence Threshold", info="Choose the threshold above which objects should be detected")
87
+ with gr.Row():
88
+ cancel_btn = gr.Button(value="Cancel")
89
+ submit_btn_vid = gr.Button(value="Predict")
90
+
91
+ with gr.Column(scale=2):
92
+ with gr.Row():
93
+ segmentation_vid_output = gr.Image(height=400, label="Segmentation")
94
+ depth_vid_output = gr.Image(height=400, label="Depth Estimation")
95
+
96
+ gr.Markdown("## Sample Videos")
97
+ gr.Examples(
98
+ examples=[os.path.join(os.path.dirname(__file__), "assets/videos/input_video.mp4")],
99
+ inputs=vid_input,
100
+ # outputs=vid_output,
101
+ # fn=vid_segmenation,
102
+ )
103
+
104
+ # image tab logic
105
+ submit_btn_img.click(process_image, inputs=img_input, outputs=[segmentation_img_output, depth_img_output])
106
+ options_checkbox_img.change(update_segmentation_options, options_checkbox_img, [])
107
+ conf_thres_img.change(update_confidence_threshold, conf_thres_img, [])
108
+
109
+ # video tab logic
110
+ submit_btn_vid.click(process_video, inputs=vid_input, outputs=[segmentation_vid_output, depth_vid_output])
111
+ cancel_btn.click(cancel, inputs=[], outputs=[])
112
+ options_checkbox_vid.change(update_segmentation_options, options_checkbox_vid, [])
113
+ conf_thres_vid.change(update_confidence_threshold, conf_thres_vid, [])
114
+
115
+
116
+ my_app.queue(concurrency_count=5, max_size=20).launch()
image_segmenter.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ from ultralytics import YOLO
4
+ from ultralytics.yolo.utils.ops import scale_image
5
+ import random
6
+ import torch
7
+
8
+ class ImageSegmenter:
9
+ def __init__(self, model_type="n") -> None:
10
+
11
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
12
+ self.model = YOLO('models/yolov8'+ model_type +'-seg.pt')
13
+ self.model.to(self.device)
14
+
15
+ self.is_show_bounding_boxes = False
16
+ self.is_show_segmentation_boundary = False
17
+ self.is_show_segmentation = True
18
+ self.confidence_threshold = 0.5
19
+ self.cls_clr = {}
20
+
21
+ # params
22
+ self.bb_thickness = 2
23
+ self.bb_clr = (255, 0, 0)
24
+
25
+
26
+ def get_cls_clr(self, cls_id):
27
+ if cls_id in self.cls_clr:
28
+ return self.cls_clr[cls_id]
29
+
30
+ # gen rand color
31
+ r = random.randint(50, 200)
32
+ g = random.randint(50, 200)
33
+ b = random.randint(50, 200)
34
+ self.cls_clr[cls_id] = (r, g, b)
35
+ return (r, g, b)
36
+
37
+ def predict(self, image):
38
+ # resizing the image for faster prediction
39
+ image = cv2.resize(image, (480, 640))
40
+ predictions = self.model.predict(image)
41
+
42
+ cls_ids = predictions[0].boxes.cls.cpu().numpy()
43
+ bounding_boxes = predictions[0].boxes.xyxy.int().cpu().numpy()
44
+ cls_conf = predictions[0].boxes.conf.cpu().numpy()
45
+ # segmentation
46
+ if predictions[0].masks:
47
+ seg_mask_boundary = predictions[0].masks.xy
48
+ seg_mask = predictions[0].masks.data.cpu().numpy()
49
+ else:
50
+ seg_mask_boundary, seg_mask = [], np.array([])
51
+
52
+ for id, cls in enumerate(cls_ids):
53
+ cls_clr = self.get_cls_clr(cls)
54
+
55
+ # draw bounding box with class name and score
56
+ if self.is_show_bounding_boxes and cls_conf[id] > self.confidence_threshold:
57
+ (x1, y1, x2, y2) = bounding_boxes[id]
58
+ cls_name = self.model.names[cls]
59
+ cls_confidence = cls_conf[id]
60
+ disp_str = cls_name +' '+ str(round(cls_confidence, 2))
61
+ cv2.rectangle(image, (x1, y1), (x2, y2), cls_clr, self.bb_thickness)
62
+ cv2.rectangle(image, (x1, y1), (x1+(len(disp_str)*18), y1+45), cls_clr, -1)
63
+ cv2.putText(image, disp_str, (x1+10, y1+30), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
64
+
65
+
66
+ # draw segmentation boundary
67
+ if len(seg_mask_boundary) and self.is_show_segmentation_boundary and cls_conf[id] > self.confidence_threshold:
68
+ cv2.polylines(image, [np.array(seg_mask_boundary[id], dtype=np.int32)], isClosed=True, color=cls_clr, thickness=2)
69
+
70
+ # draw filled segmentation region
71
+ if seg_mask.any() and self.is_show_segmentation and cls_conf[id] > self.confidence_threshold:
72
+ alpha = 0.8
73
+
74
+ # converting the mask from 1 channel to 3 channels
75
+ colored_mask = np.expand_dims(seg_mask[id], 0).repeat(3, axis=0)
76
+ colored_mask = np.moveaxis(colored_mask, 0, -1)
77
+
78
+ # Resize the mask to match the image size, if necessary
79
+ if image.shape[:2] != seg_mask[id].shape[:2]:
80
+ colored_mask = cv2.resize(colored_mask, (image.shape[1], image.shape[0]))
81
+
82
+ # filling the mased area with class color
83
+ masked = np.ma.MaskedArray(image, mask=colored_mask, fill_value=cls_clr)
84
+ image_overlay = masked.filled()
85
+ image = cv2.addWeighted(image, 1 - alpha, image_overlay, alpha, 0)
86
+
87
+ return image
88
+
89
+
90
+
91
+
midas/__init__.py ADDED
File without changes
midas/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (172 Bytes). View file
 
midas/__pycache__/base_model.cpython-37.pyc ADDED
Binary file (680 Bytes). View file
 
midas/__pycache__/base_model.cpython-38.pyc ADDED
Binary file (728 Bytes). View file
 
midas/__pycache__/blocks.cpython-37.pyc ADDED
Binary file (9.34 kB). View file
 
midas/__pycache__/blocks.cpython-38.pyc ADDED
Binary file (9.11 kB). View file
 
midas/__pycache__/dpt_depth.cpython-37.pyc ADDED
Binary file (4.07 kB). View file
 
midas/__pycache__/dpt_depth.cpython-38.pyc ADDED
Binary file (4.12 kB). View file
 
midas/__pycache__/midas_net.cpython-37.pyc ADDED
Binary file (2.57 kB). View file
 
midas/__pycache__/midas_net.cpython-38.pyc ADDED
Binary file (2.63 kB). View file
 
midas/__pycache__/midas_net_custom.cpython-37.pyc ADDED
Binary file (3.7 kB). View file
 
midas/__pycache__/midas_net_custom.cpython-38.pyc ADDED
Binary file (3.75 kB). View file
 
midas/__pycache__/model_loader.cpython-37.pyc ADDED
Binary file (4.9 kB). View file
 
midas/__pycache__/model_loader.cpython-38.pyc ADDED
Binary file (4.98 kB). View file
 
midas/__pycache__/transforms.cpython-37.pyc ADDED
Binary file (5.65 kB). View file
 
midas/__pycache__/transforms.cpython-38.pyc ADDED
Binary file (5.75 kB). View file
 
midas/backbones/__pycache__/beit.cpython-37.pyc ADDED
Binary file (5.57 kB). View file
 
midas/backbones/__pycache__/beit.cpython-38.pyc ADDED
Binary file (5.61 kB). View file
 
midas/backbones/__pycache__/levit.cpython-37.pyc ADDED
Binary file (3.38 kB). View file
 
midas/backbones/__pycache__/levit.cpython-38.pyc ADDED
Binary file (3.49 kB). View file
 
midas/backbones/__pycache__/swin.cpython-37.pyc ADDED
Binary file (522 Bytes). View file
 
midas/backbones/__pycache__/swin.cpython-38.pyc ADDED
Binary file (568 Bytes). View file
 
midas/backbones/__pycache__/swin2.cpython-37.pyc ADDED
Binary file (1.08 kB). View file
 
midas/backbones/__pycache__/swin2.cpython-38.pyc ADDED
Binary file (1.09 kB). View file
 
midas/backbones/__pycache__/swin_common.cpython-37.pyc ADDED
Binary file (1.35 kB). View file
 
midas/backbones/__pycache__/swin_common.cpython-38.pyc ADDED
Binary file (1.41 kB). View file
 
midas/backbones/__pycache__/utils.cpython-37.pyc ADDED
Binary file (5.9 kB). View file
 
midas/backbones/__pycache__/utils.cpython-38.pyc ADDED
Binary file (5.95 kB). View file
 
midas/backbones/__pycache__/vit.cpython-37.pyc ADDED
Binary file (4.56 kB). View file
 
midas/backbones/__pycache__/vit.cpython-38.pyc ADDED
Binary file (4.64 kB). View file
 
midas/backbones/beit.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import timm
2
+ import torch
3
+ import types
4
+
5
+ import numpy as np
6
+ import torch.nn.functional as F
7
+
8
+ from .utils import forward_adapted_unflatten, make_backbone_default
9
+ from timm.models.beit import gen_relative_position_index
10
+ from torch.utils.checkpoint import checkpoint
11
+ from typing import Optional
12
+
13
+
14
+ def forward_beit(pretrained, x):
15
+ return forward_adapted_unflatten(pretrained, x, "forward_features")
16
+
17
+
18
+ def patch_embed_forward(self, x):
19
+ """
20
+ Modification of timm.models.layers.patch_embed.py: PatchEmbed.forward to support arbitrary window sizes.
21
+ """
22
+ x = self.proj(x)
23
+ if self.flatten:
24
+ x = x.flatten(2).transpose(1, 2)
25
+ x = self.norm(x)
26
+ return x
27
+
28
+
29
+ def _get_rel_pos_bias(self, window_size):
30
+ """
31
+ Modification of timm.models.beit.py: Attention._get_rel_pos_bias to support arbitrary window sizes.
32
+ """
33
+ old_height = 2 * self.window_size[0] - 1
34
+ old_width = 2 * self.window_size[1] - 1
35
+
36
+ new_height = 2 * window_size[0] - 1
37
+ new_width = 2 * window_size[1] - 1
38
+
39
+ old_relative_position_bias_table = self.relative_position_bias_table
40
+
41
+ old_num_relative_distance = self.num_relative_distance
42
+ new_num_relative_distance = new_height * new_width + 3
43
+
44
+ old_sub_table = old_relative_position_bias_table[:old_num_relative_distance - 3]
45
+
46
+ old_sub_table = old_sub_table.reshape(1, old_width, old_height, -1).permute(0, 3, 1, 2)
47
+ new_sub_table = F.interpolate(old_sub_table, size=(new_height, new_width), mode="bilinear")
48
+ new_sub_table = new_sub_table.permute(0, 2, 3, 1).reshape(new_num_relative_distance - 3, -1)
49
+
50
+ new_relative_position_bias_table = torch.cat(
51
+ [new_sub_table, old_relative_position_bias_table[old_num_relative_distance - 3:]])
52
+
53
+ key = str(window_size[1]) + "," + str(window_size[0])
54
+ if key not in self.relative_position_indices.keys():
55
+ self.relative_position_indices[key] = gen_relative_position_index(window_size)
56
+
57
+ relative_position_bias = new_relative_position_bias_table[
58
+ self.relative_position_indices[key].view(-1)].view(
59
+ window_size[0] * window_size[1] + 1,
60
+ window_size[0] * window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
61
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
62
+ return relative_position_bias.unsqueeze(0)
63
+
64
+
65
+ def attention_forward(self, x, resolution, shared_rel_pos_bias: Optional[torch.Tensor] = None):
66
+ """
67
+ Modification of timm.models.beit.py: Attention.forward to support arbitrary window sizes.
68
+ """
69
+ B, N, C = x.shape
70
+
71
+ qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias)) if self.q_bias is not None else None
72
+ qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
73
+ qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
74
+ q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
75
+
76
+ q = q * self.scale
77
+ attn = (q @ k.transpose(-2, -1))
78
+
79
+ if self.relative_position_bias_table is not None:
80
+ window_size = tuple(np.array(resolution) // 16)
81
+ attn = attn + self._get_rel_pos_bias(window_size)
82
+ if shared_rel_pos_bias is not None:
83
+ attn = attn + shared_rel_pos_bias
84
+
85
+ attn = attn.softmax(dim=-1)
86
+ attn = self.attn_drop(attn)
87
+
88
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
89
+ x = self.proj(x)
90
+ x = self.proj_drop(x)
91
+ return x
92
+
93
+
94
+ def block_forward(self, x, resolution, shared_rel_pos_bias: Optional[torch.Tensor] = None):
95
+ """
96
+ Modification of timm.models.beit.py: Block.forward to support arbitrary window sizes.
97
+ """
98
+ if self.gamma_1 is None:
99
+ x = x + self.drop_path(self.attn(self.norm1(x), resolution, shared_rel_pos_bias=shared_rel_pos_bias))
100
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
101
+ else:
102
+ x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), resolution,
103
+ shared_rel_pos_bias=shared_rel_pos_bias))
104
+ x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
105
+ return x
106
+
107
+
108
+ def beit_forward_features(self, x):
109
+ """
110
+ Modification of timm.models.beit.py: Beit.forward_features to support arbitrary window sizes.
111
+ """
112
+ resolution = x.shape[2:]
113
+
114
+ x = self.patch_embed(x)
115
+ x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
116
+ if self.pos_embed is not None:
117
+ x = x + self.pos_embed
118
+ x = self.pos_drop(x)
119
+
120
+ rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
121
+ for blk in self.blocks:
122
+ if self.grad_checkpointing and not torch.jit.is_scripting():
123
+ x = checkpoint(blk, x, shared_rel_pos_bias=rel_pos_bias)
124
+ else:
125
+ x = blk(x, resolution, shared_rel_pos_bias=rel_pos_bias)
126
+ x = self.norm(x)
127
+ return x
128
+
129
+
130
+ def _make_beit_backbone(
131
+ model,
132
+ features=[96, 192, 384, 768],
133
+ size=[384, 384],
134
+ hooks=[0, 4, 8, 11],
135
+ vit_features=768,
136
+ use_readout="ignore",
137
+ start_index=1,
138
+ start_index_readout=1,
139
+ ):
140
+ backbone = make_backbone_default(model, features, size, hooks, vit_features, use_readout, start_index,
141
+ start_index_readout)
142
+
143
+ backbone.model.patch_embed.forward = types.MethodType(patch_embed_forward, backbone.model.patch_embed)
144
+ backbone.model.forward_features = types.MethodType(beit_forward_features, backbone.model)
145
+
146
+ for block in backbone.model.blocks:
147
+ attn = block.attn
148
+ attn._get_rel_pos_bias = types.MethodType(_get_rel_pos_bias, attn)
149
+ attn.forward = types.MethodType(attention_forward, attn)
150
+ attn.relative_position_indices = {}
151
+
152
+ block.forward = types.MethodType(block_forward, block)
153
+
154
+ return backbone
155
+
156
+
157
+ def _make_pretrained_beitl16_512(pretrained, use_readout="ignore", hooks=None):
158
+ model = timm.create_model("beit_large_patch16_512", pretrained=pretrained)
159
+
160
+ hooks = [5, 11, 17, 23] if hooks is None else hooks
161
+
162
+ features = [256, 512, 1024, 1024]
163
+
164
+ return _make_beit_backbone(
165
+ model,
166
+ features=features,
167
+ size=[512, 512],
168
+ hooks=hooks,
169
+ vit_features=1024,
170
+ use_readout=use_readout,
171
+ )
172
+
173
+
174
+ def _make_pretrained_beitl16_384(pretrained, use_readout="ignore", hooks=None):
175
+ model = timm.create_model("beit_large_patch16_384", pretrained=pretrained)
176
+
177
+ hooks = [5, 11, 17, 23] if hooks is None else hooks
178
+ return _make_beit_backbone(
179
+ model,
180
+ features=[256, 512, 1024, 1024],
181
+ hooks=hooks,
182
+ vit_features=1024,
183
+ use_readout=use_readout,
184
+ )
185
+
186
+
187
+ def _make_pretrained_beitb16_384(pretrained, use_readout="ignore", hooks=None):
188
+ model = timm.create_model("beit_base_patch16_384", pretrained=pretrained)
189
+
190
+ hooks = [2, 5, 8, 11] if hooks is None else hooks
191
+ return _make_beit_backbone(
192
+ model,
193
+ features=[96, 192, 384, 768],
194
+ hooks=hooks,
195
+ use_readout=use_readout,
196
+ )
midas/backbones/levit.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import timm
2
+ import torch
3
+ import torch.nn as nn
4
+ import numpy as np
5
+
6
+ from .utils import activations, get_activation, Transpose
7
+
8
+
9
+ def forward_levit(pretrained, x):
10
+ pretrained.model.forward_features(x)
11
+
12
+ layer_1 = pretrained.activations["1"]
13
+ layer_2 = pretrained.activations["2"]
14
+ layer_3 = pretrained.activations["3"]
15
+
16
+ layer_1 = pretrained.act_postprocess1(layer_1)
17
+ layer_2 = pretrained.act_postprocess2(layer_2)
18
+ layer_3 = pretrained.act_postprocess3(layer_3)
19
+
20
+ return layer_1, layer_2, layer_3
21
+
22
+
23
+ def _make_levit_backbone(
24
+ model,
25
+ hooks=[3, 11, 21],
26
+ patch_grid=[14, 14]
27
+ ):
28
+ pretrained = nn.Module()
29
+
30
+ pretrained.model = model
31
+ pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
32
+ pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
33
+ pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
34
+
35
+ pretrained.activations = activations
36
+
37
+ patch_grid_size = np.array(patch_grid, dtype=int)
38
+
39
+ pretrained.act_postprocess1 = nn.Sequential(
40
+ Transpose(1, 2),
41
+ nn.Unflatten(2, torch.Size(patch_grid_size.tolist()))
42
+ )
43
+ pretrained.act_postprocess2 = nn.Sequential(
44
+ Transpose(1, 2),
45
+ nn.Unflatten(2, torch.Size((np.ceil(patch_grid_size / 2).astype(int)).tolist()))
46
+ )
47
+ pretrained.act_postprocess3 = nn.Sequential(
48
+ Transpose(1, 2),
49
+ nn.Unflatten(2, torch.Size((np.ceil(patch_grid_size / 4).astype(int)).tolist()))
50
+ )
51
+
52
+ return pretrained
53
+
54
+
55
+ class ConvTransposeNorm(nn.Sequential):
56
+ """
57
+ Modification of
58
+ https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/levit.py: ConvNorm
59
+ such that ConvTranspose2d is used instead of Conv2d.
60
+ """
61
+
62
+ def __init__(
63
+ self, in_chs, out_chs, kernel_size=1, stride=1, pad=0, dilation=1,
64
+ groups=1, bn_weight_init=1):
65
+ super().__init__()
66
+ self.add_module('c',
67
+ nn.ConvTranspose2d(in_chs, out_chs, kernel_size, stride, pad, dilation, groups, bias=False))
68
+ self.add_module('bn', nn.BatchNorm2d(out_chs))
69
+
70
+ nn.init.constant_(self.bn.weight, bn_weight_init)
71
+
72
+ @torch.no_grad()
73
+ def fuse(self):
74
+ c, bn = self._modules.values()
75
+ w = bn.weight / (bn.running_var + bn.eps) ** 0.5
76
+ w = c.weight * w[:, None, None, None]
77
+ b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5
78
+ m = nn.ConvTranspose2d(
79
+ w.size(1), w.size(0), w.shape[2:], stride=self.c.stride,
80
+ padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups)
81
+ m.weight.data.copy_(w)
82
+ m.bias.data.copy_(b)
83
+ return m
84
+
85
+
86
+ def stem_b4_transpose(in_chs, out_chs, activation):
87
+ """
88
+ Modification of
89
+ https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/levit.py: stem_b16
90
+ such that ConvTranspose2d is used instead of Conv2d and stem is also reduced to the half.
91
+ """
92
+ return nn.Sequential(
93
+ ConvTransposeNorm(in_chs, out_chs, 3, 2, 1),
94
+ activation(),
95
+ ConvTransposeNorm(out_chs, out_chs // 2, 3, 2, 1),
96
+ activation())
97
+
98
+
99
+ def _make_pretrained_levit_384(pretrained, hooks=None):
100
+ model = timm.create_model("levit_384", pretrained=pretrained)
101
+
102
+ hooks = [3, 11, 21] if hooks == None else hooks
103
+ return _make_levit_backbone(
104
+ model,
105
+ hooks=hooks
106
+ )
midas/backbones/next_vit.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import timm
2
+
3
+ import torch.nn as nn
4
+
5
+ from pathlib import Path
6
+ from .utils import activations, forward_default, get_activation
7
+
8
+ from ..external.next_vit.classification.nextvit import *
9
+
10
+
11
+ def forward_next_vit(pretrained, x):
12
+ return forward_default(pretrained, x, "forward")
13
+
14
+
15
+ def _make_next_vit_backbone(
16
+ model,
17
+ hooks=[2, 6, 36, 39],
18
+ ):
19
+ pretrained = nn.Module()
20
+
21
+ pretrained.model = model
22
+ pretrained.model.features[hooks[0]].register_forward_hook(get_activation("1"))
23
+ pretrained.model.features[hooks[1]].register_forward_hook(get_activation("2"))
24
+ pretrained.model.features[hooks[2]].register_forward_hook(get_activation("3"))
25
+ pretrained.model.features[hooks[3]].register_forward_hook(get_activation("4"))
26
+
27
+ pretrained.activations = activations
28
+
29
+ return pretrained
30
+
31
+
32
+ def _make_pretrained_next_vit_large_6m(hooks=None):
33
+ model = timm.create_model("nextvit_large")
34
+
35
+ hooks = [2, 6, 36, 39] if hooks == None else hooks
36
+ return _make_next_vit_backbone(
37
+ model,
38
+ hooks=hooks,
39
+ )
midas/backbones/swin.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import timm
2
+
3
+ from .swin_common import _make_swin_backbone
4
+
5
+
6
+ def _make_pretrained_swinl12_384(pretrained, hooks=None):
7
+ model = timm.create_model("swin_large_patch4_window12_384", pretrained=pretrained)
8
+
9
+ hooks = [1, 1, 17, 1] if hooks == None else hooks
10
+ return _make_swin_backbone(
11
+ model,
12
+ hooks=hooks
13
+ )
midas/backbones/swin2.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import timm
2
+
3
+ from .swin_common import _make_swin_backbone
4
+
5
+
6
+ def _make_pretrained_swin2l24_384(pretrained, hooks=None):
7
+ model = timm.create_model("swinv2_large_window12to24_192to384_22kft1k", pretrained=pretrained)
8
+
9
+ hooks = [1, 1, 17, 1] if hooks == None else hooks
10
+ return _make_swin_backbone(
11
+ model,
12
+ hooks=hooks
13
+ )
14
+
15
+
16
+ def _make_pretrained_swin2b24_384(pretrained, hooks=None):
17
+ model = timm.create_model("swinv2_base_window12to24_192to384_22kft1k", pretrained=pretrained)
18
+
19
+ hooks = [1, 1, 17, 1] if hooks == None else hooks
20
+ return _make_swin_backbone(
21
+ model,
22
+ hooks=hooks
23
+ )
24
+
25
+
26
+ def _make_pretrained_swin2t16_256(pretrained, hooks=None):
27
+ model = timm.create_model("swinv2_tiny_window16_256", pretrained=pretrained)
28
+
29
+ hooks = [1, 1, 5, 1] if hooks == None else hooks
30
+ return _make_swin_backbone(
31
+ model,
32
+ hooks=hooks,
33
+ patch_grid=[64, 64]
34
+ )
midas/backbones/swin_common.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ import torch.nn as nn
4
+ import numpy as np
5
+
6
+ from .utils import activations, forward_default, get_activation, Transpose
7
+
8
+
9
+ def forward_swin(pretrained, x):
10
+ return forward_default(pretrained, x)
11
+
12
+
13
+ def _make_swin_backbone(
14
+ model,
15
+ hooks=[1, 1, 17, 1],
16
+ patch_grid=[96, 96]
17
+ ):
18
+ pretrained = nn.Module()
19
+
20
+ pretrained.model = model
21
+ pretrained.model.layers[0].blocks[hooks[0]].register_forward_hook(get_activation("1"))
22
+ pretrained.model.layers[1].blocks[hooks[1]].register_forward_hook(get_activation("2"))
23
+ pretrained.model.layers[2].blocks[hooks[2]].register_forward_hook(get_activation("3"))
24
+ pretrained.model.layers[3].blocks[hooks[3]].register_forward_hook(get_activation("4"))
25
+
26
+ pretrained.activations = activations
27
+
28
+ if hasattr(model, "patch_grid"):
29
+ used_patch_grid = model.patch_grid
30
+ else:
31
+ used_patch_grid = patch_grid
32
+
33
+ patch_grid_size = np.array(used_patch_grid, dtype=int)
34
+
35
+ pretrained.act_postprocess1 = nn.Sequential(
36
+ Transpose(1, 2),
37
+ nn.Unflatten(2, torch.Size(patch_grid_size.tolist()))
38
+ )
39
+ pretrained.act_postprocess2 = nn.Sequential(
40
+ Transpose(1, 2),
41
+ nn.Unflatten(2, torch.Size((patch_grid_size // 2).tolist()))
42
+ )
43
+ pretrained.act_postprocess3 = nn.Sequential(
44
+ Transpose(1, 2),
45
+ nn.Unflatten(2, torch.Size((patch_grid_size // 4).tolist()))
46
+ )
47
+ pretrained.act_postprocess4 = nn.Sequential(
48
+ Transpose(1, 2),
49
+ nn.Unflatten(2, torch.Size((patch_grid_size // 8).tolist()))
50
+ )
51
+
52
+ return pretrained
midas/backbones/utils.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ import torch.nn as nn
4
+
5
+
6
+ class Slice(nn.Module):
7
+ def __init__(self, start_index=1):
8
+ super(Slice, self).__init__()
9
+ self.start_index = start_index
10
+
11
+ def forward(self, x):
12
+ return x[:, self.start_index:]
13
+
14
+
15
+ class AddReadout(nn.Module):
16
+ def __init__(self, start_index=1):
17
+ super(AddReadout, self).__init__()
18
+ self.start_index = start_index
19
+
20
+ def forward(self, x):
21
+ if self.start_index == 2:
22
+ readout = (x[:, 0] + x[:, 1]) / 2
23
+ else:
24
+ readout = x[:, 0]
25
+ return x[:, self.start_index:] + readout.unsqueeze(1)
26
+
27
+
28
+ class ProjectReadout(nn.Module):
29
+ def __init__(self, in_features, start_index=1):
30
+ super(ProjectReadout, self).__init__()
31
+ self.start_index = start_index
32
+
33
+ self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())
34
+
35
+ def forward(self, x):
36
+ readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index:])
37
+ features = torch.cat((x[:, self.start_index:], readout), -1)
38
+
39
+ return self.project(features)
40
+
41
+
42
+ class Transpose(nn.Module):
43
+ def __init__(self, dim0, dim1):
44
+ super(Transpose, self).__init__()
45
+ self.dim0 = dim0
46
+ self.dim1 = dim1
47
+
48
+ def forward(self, x):
49
+ x = x.transpose(self.dim0, self.dim1)
50
+ return x
51
+
52
+
53
+ activations = {}
54
+
55
+
56
+ def get_activation(name):
57
+ def hook(model, input, output):
58
+ activations[name] = output
59
+
60
+ return hook
61
+
62
+
63
+ def forward_default(pretrained, x, function_name="forward_features"):
64
+ exec(f"pretrained.model.{function_name}(x)")
65
+
66
+ layer_1 = pretrained.activations["1"]
67
+ layer_2 = pretrained.activations["2"]
68
+ layer_3 = pretrained.activations["3"]
69
+ layer_4 = pretrained.activations["4"]
70
+
71
+ if hasattr(pretrained, "act_postprocess1"):
72
+ layer_1 = pretrained.act_postprocess1(layer_1)
73
+ if hasattr(pretrained, "act_postprocess2"):
74
+ layer_2 = pretrained.act_postprocess2(layer_2)
75
+ if hasattr(pretrained, "act_postprocess3"):
76
+ layer_3 = pretrained.act_postprocess3(layer_3)
77
+ if hasattr(pretrained, "act_postprocess4"):
78
+ layer_4 = pretrained.act_postprocess4(layer_4)
79
+
80
+ return layer_1, layer_2, layer_3, layer_4
81
+
82
+
83
+ def forward_adapted_unflatten(pretrained, x, function_name="forward_features"):
84
+ b, c, h, w = x.shape
85
+
86
+ exec(f"glob = pretrained.model.{function_name}(x)")
87
+
88
+ layer_1 = pretrained.activations["1"]
89
+ layer_2 = pretrained.activations["2"]
90
+ layer_3 = pretrained.activations["3"]
91
+ layer_4 = pretrained.activations["4"]
92
+
93
+ layer_1 = pretrained.act_postprocess1[0:2](layer_1)
94
+ layer_2 = pretrained.act_postprocess2[0:2](layer_2)
95
+ layer_3 = pretrained.act_postprocess3[0:2](layer_3)
96
+ layer_4 = pretrained.act_postprocess4[0:2](layer_4)
97
+
98
+ unflatten = nn.Sequential(
99
+ nn.Unflatten(
100
+ 2,
101
+ torch.Size(
102
+ [
103
+ h // pretrained.model.patch_size[1],
104
+ w // pretrained.model.patch_size[0],
105
+ ]
106
+ ),
107
+ )
108
+ )
109
+
110
+ if layer_1.ndim == 3:
111
+ layer_1 = unflatten(layer_1)
112
+ if layer_2.ndim == 3:
113
+ layer_2 = unflatten(layer_2)
114
+ if layer_3.ndim == 3:
115
+ layer_3 = unflatten(layer_3)
116
+ if layer_4.ndim == 3:
117
+ layer_4 = unflatten(layer_4)
118
+
119
+ layer_1 = pretrained.act_postprocess1[3: len(pretrained.act_postprocess1)](layer_1)
120
+ layer_2 = pretrained.act_postprocess2[3: len(pretrained.act_postprocess2)](layer_2)
121
+ layer_3 = pretrained.act_postprocess3[3: len(pretrained.act_postprocess3)](layer_3)
122
+ layer_4 = pretrained.act_postprocess4[3: len(pretrained.act_postprocess4)](layer_4)
123
+
124
+ return layer_1, layer_2, layer_3, layer_4
125
+
126
+
127
+ def get_readout_oper(vit_features, features, use_readout, start_index=1):
128
+ if use_readout == "ignore":
129
+ readout_oper = [Slice(start_index)] * len(features)
130
+ elif use_readout == "add":
131
+ readout_oper = [AddReadout(start_index)] * len(features)
132
+ elif use_readout == "project":
133
+ readout_oper = [
134
+ ProjectReadout(vit_features, start_index) for out_feat in features
135
+ ]
136
+ else:
137
+ assert (
138
+ False
139
+ ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"
140
+
141
+ return readout_oper
142
+
143
+
144
+ def make_backbone_default(
145
+ model,
146
+ features=[96, 192, 384, 768],
147
+ size=[384, 384],
148
+ hooks=[2, 5, 8, 11],
149
+ vit_features=768,
150
+ use_readout="ignore",
151
+ start_index=1,
152
+ start_index_readout=1,
153
+ ):
154
+ pretrained = nn.Module()
155
+
156
+ pretrained.model = model
157
+ pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
158
+ pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
159
+ pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
160
+ pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
161
+
162
+ pretrained.activations = activations
163
+
164
+ readout_oper = get_readout_oper(vit_features, features, use_readout, start_index_readout)
165
+
166
+ # 32, 48, 136, 384
167
+ pretrained.act_postprocess1 = nn.Sequential(
168
+ readout_oper[0],
169
+ Transpose(1, 2),
170
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
171
+ nn.Conv2d(
172
+ in_channels=vit_features,
173
+ out_channels=features[0],
174
+ kernel_size=1,
175
+ stride=1,
176
+ padding=0,
177
+ ),
178
+ nn.ConvTranspose2d(
179
+ in_channels=features[0],
180
+ out_channels=features[0],
181
+ kernel_size=4,
182
+ stride=4,
183
+ padding=0,
184
+ bias=True,
185
+ dilation=1,
186
+ groups=1,
187
+ ),
188
+ )
189
+
190
+ pretrained.act_postprocess2 = nn.Sequential(
191
+ readout_oper[1],
192
+ Transpose(1, 2),
193
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
194
+ nn.Conv2d(
195
+ in_channels=vit_features,
196
+ out_channels=features[1],
197
+ kernel_size=1,
198
+ stride=1,
199
+ padding=0,
200
+ ),
201
+ nn.ConvTranspose2d(
202
+ in_channels=features[1],
203
+ out_channels=features[1],
204
+ kernel_size=2,
205
+ stride=2,
206
+ padding=0,
207
+ bias=True,
208
+ dilation=1,
209
+ groups=1,
210
+ ),
211
+ )
212
+
213
+ pretrained.act_postprocess3 = nn.Sequential(
214
+ readout_oper[2],
215
+ Transpose(1, 2),
216
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
217
+ nn.Conv2d(
218
+ in_channels=vit_features,
219
+ out_channels=features[2],
220
+ kernel_size=1,
221
+ stride=1,
222
+ padding=0,
223
+ ),
224
+ )
225
+
226
+ pretrained.act_postprocess4 = nn.Sequential(
227
+ readout_oper[3],
228
+ Transpose(1, 2),
229
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
230
+ nn.Conv2d(
231
+ in_channels=vit_features,
232
+ out_channels=features[3],
233
+ kernel_size=1,
234
+ stride=1,
235
+ padding=0,
236
+ ),
237
+ nn.Conv2d(
238
+ in_channels=features[3],
239
+ out_channels=features[3],
240
+ kernel_size=3,
241
+ stride=2,
242
+ padding=1,
243
+ ),
244
+ )
245
+
246
+ pretrained.model.start_index = start_index
247
+ pretrained.model.patch_size = [16, 16]
248
+
249
+ return pretrained
midas/backbones/vit.py ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import timm
4
+ import types
5
+ import math
6
+ import torch.nn.functional as F
7
+
8
+ from .utils import (activations, forward_adapted_unflatten, get_activation, get_readout_oper,
9
+ make_backbone_default, Transpose)
10
+
11
+
12
+ def forward_vit(pretrained, x):
13
+ return forward_adapted_unflatten(pretrained, x, "forward_flex")
14
+
15
+
16
+ def _resize_pos_embed(self, posemb, gs_h, gs_w):
17
+ posemb_tok, posemb_grid = (
18
+ posemb[:, : self.start_index],
19
+ posemb[0, self.start_index:],
20
+ )
21
+
22
+ gs_old = int(math.sqrt(len(posemb_grid)))
23
+
24
+ posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
25
+ posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
26
+ posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
27
+
28
+ posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
29
+
30
+ return posemb
31
+
32
+
33
+ def forward_flex(self, x):
34
+ b, c, h, w = x.shape
35
+
36
+ pos_embed = self._resize_pos_embed(
37
+ self.pos_embed, h // self.patch_size[1], w // self.patch_size[0]
38
+ )
39
+
40
+ B = x.shape[0]
41
+
42
+ if hasattr(self.patch_embed, "backbone"):
43
+ x = self.patch_embed.backbone(x)
44
+ if isinstance(x, (list, tuple)):
45
+ x = x[-1] # last feature if backbone outputs list/tuple of features
46
+
47
+ x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)
48
+
49
+ if getattr(self, "dist_token", None) is not None:
50
+ cls_tokens = self.cls_token.expand(
51
+ B, -1, -1
52
+ ) # stole cls_tokens impl from Phil Wang, thanks
53
+ dist_token = self.dist_token.expand(B, -1, -1)
54
+ x = torch.cat((cls_tokens, dist_token, x), dim=1)
55
+ else:
56
+ if self.no_embed_class:
57
+ x = x + pos_embed
58
+ cls_tokens = self.cls_token.expand(
59
+ B, -1, -1
60
+ ) # stole cls_tokens impl from Phil Wang, thanks
61
+ x = torch.cat((cls_tokens, x), dim=1)
62
+
63
+ if not self.no_embed_class:
64
+ x = x + pos_embed
65
+ x = self.pos_drop(x)
66
+
67
+ for blk in self.blocks:
68
+ x = blk(x)
69
+
70
+ x = self.norm(x)
71
+
72
+ return x
73
+
74
+
75
+ def _make_vit_b16_backbone(
76
+ model,
77
+ features=[96, 192, 384, 768],
78
+ size=[384, 384],
79
+ hooks=[2, 5, 8, 11],
80
+ vit_features=768,
81
+ use_readout="ignore",
82
+ start_index=1,
83
+ start_index_readout=1,
84
+ ):
85
+ pretrained = make_backbone_default(model, features, size, hooks, vit_features, use_readout, start_index,
86
+ start_index_readout)
87
+
88
+ # We inject this function into the VisionTransformer instances so that
89
+ # we can use it with interpolated position embeddings without modifying the library source.
90
+ pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
91
+ pretrained.model._resize_pos_embed = types.MethodType(
92
+ _resize_pos_embed, pretrained.model
93
+ )
94
+
95
+ return pretrained
96
+
97
+
98
+ def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None):
99
+ model = timm.create_model("vit_large_patch16_384", pretrained=pretrained)
100
+
101
+ hooks = [5, 11, 17, 23] if hooks == None else hooks
102
+ return _make_vit_b16_backbone(
103
+ model,
104
+ features=[256, 512, 1024, 1024],
105
+ hooks=hooks,
106
+ vit_features=1024,
107
+ use_readout=use_readout,
108
+ )
109
+
110
+
111
+ def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None):
112
+ model = timm.create_model("vit_base_patch16_384", pretrained=pretrained)
113
+
114
+ hooks = [2, 5, 8, 11] if hooks == None else hooks
115
+ return _make_vit_b16_backbone(
116
+ model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
117
+ )
118
+
119
+
120
+ def _make_vit_b_rn50_backbone(
121
+ model,
122
+ features=[256, 512, 768, 768],
123
+ size=[384, 384],
124
+ hooks=[0, 1, 8, 11],
125
+ vit_features=768,
126
+ patch_size=[16, 16],
127
+ number_stages=2,
128
+ use_vit_only=False,
129
+ use_readout="ignore",
130
+ start_index=1,
131
+ ):
132
+ pretrained = nn.Module()
133
+
134
+ pretrained.model = model
135
+
136
+ used_number_stages = 0 if use_vit_only else number_stages
137
+ for s in range(used_number_stages):
138
+ pretrained.model.patch_embed.backbone.stages[s].register_forward_hook(
139
+ get_activation(str(s + 1))
140
+ )
141
+ for s in range(used_number_stages, 4):
142
+ pretrained.model.blocks[hooks[s]].register_forward_hook(get_activation(str(s + 1)))
143
+
144
+ pretrained.activations = activations
145
+
146
+ readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
147
+
148
+ for s in range(used_number_stages):
149
+ value = nn.Sequential(nn.Identity(), nn.Identity(), nn.Identity())
150
+ exec(f"pretrained.act_postprocess{s + 1}=value")
151
+ for s in range(used_number_stages, 4):
152
+ if s < number_stages:
153
+ final_layer = nn.ConvTranspose2d(
154
+ in_channels=features[s],
155
+ out_channels=features[s],
156
+ kernel_size=4 // (2 ** s),
157
+ stride=4 // (2 ** s),
158
+ padding=0,
159
+ bias=True,
160
+ dilation=1,
161
+ groups=1,
162
+ )
163
+ elif s > number_stages:
164
+ final_layer = nn.Conv2d(
165
+ in_channels=features[3],
166
+ out_channels=features[3],
167
+ kernel_size=3,
168
+ stride=2,
169
+ padding=1,
170
+ )
171
+ else:
172
+ final_layer = None
173
+
174
+ layers = [
175
+ readout_oper[s],
176
+ Transpose(1, 2),
177
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
178
+ nn.Conv2d(
179
+ in_channels=vit_features,
180
+ out_channels=features[s],
181
+ kernel_size=1,
182
+ stride=1,
183
+ padding=0,
184
+ ),
185
+ ]
186
+ if final_layer is not None:
187
+ layers.append(final_layer)
188
+
189
+ value = nn.Sequential(*layers)
190
+ exec(f"pretrained.act_postprocess{s + 1}=value")
191
+
192
+ pretrained.model.start_index = start_index
193
+ pretrained.model.patch_size = patch_size
194
+
195
+ # We inject this function into the VisionTransformer instances so that
196
+ # we can use it with interpolated position embeddings without modifying the library source.
197
+ pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
198
+
199
+ # We inject this function into the VisionTransformer instances so that
200
+ # we can use it with interpolated position embeddings without modifying the library source.
201
+ pretrained.model._resize_pos_embed = types.MethodType(
202
+ _resize_pos_embed, pretrained.model
203
+ )
204
+
205
+ return pretrained
206
+
207
+
208
+ def _make_pretrained_vitb_rn50_384(
209
+ pretrained, use_readout="ignore", hooks=None, use_vit_only=False
210
+ ):
211
+ model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)
212
+
213
+ hooks = [0, 1, 8, 11] if hooks == None else hooks
214
+ return _make_vit_b_rn50_backbone(
215
+ model,
216
+ features=[256, 512, 768, 768],
217
+ size=[384, 384],
218
+ hooks=hooks,
219
+ use_vit_only=use_vit_only,
220
+ use_readout=use_readout,
221
+ )
midas/base_model.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
+ class BaseModel(torch.nn.Module):
5
+ def load(self, path):
6
+ """Load model from file.
7
+
8
+ Args:
9
+ path (str): file path
10
+ """
11
+ parameters = torch.load(path, map_location=torch.device('cpu'))
12
+
13
+ if "optimizer" in parameters:
14
+ parameters = parameters["model"]
15
+
16
+ self.load_state_dict(parameters)
midas/blocks.py ADDED
@@ -0,0 +1,439 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from .backbones.beit import (
5
+ _make_pretrained_beitl16_512,
6
+ _make_pretrained_beitl16_384,
7
+ _make_pretrained_beitb16_384,
8
+ forward_beit,
9
+ )
10
+ from .backbones.swin_common import (
11
+ forward_swin,
12
+ )
13
+ from .backbones.swin2 import (
14
+ _make_pretrained_swin2l24_384,
15
+ _make_pretrained_swin2b24_384,
16
+ _make_pretrained_swin2t16_256,
17
+ )
18
+ from .backbones.swin import (
19
+ _make_pretrained_swinl12_384,
20
+ )
21
+ from .backbones.levit import (
22
+ _make_pretrained_levit_384,
23
+ forward_levit,
24
+ )
25
+ from .backbones.vit import (
26
+ _make_pretrained_vitb_rn50_384,
27
+ _make_pretrained_vitl16_384,
28
+ _make_pretrained_vitb16_384,
29
+ forward_vit,
30
+ )
31
+
32
+ def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None,
33
+ use_vit_only=False, use_readout="ignore", in_features=[96, 256, 512, 1024]):
34
+ if backbone == "beitl16_512":
35
+ pretrained = _make_pretrained_beitl16_512(
36
+ use_pretrained, hooks=hooks, use_readout=use_readout
37
+ )
38
+ scratch = _make_scratch(
39
+ [256, 512, 1024, 1024], features, groups=groups, expand=expand
40
+ ) # BEiT_512-L (backbone)
41
+ elif backbone == "beitl16_384":
42
+ pretrained = _make_pretrained_beitl16_384(
43
+ use_pretrained, hooks=hooks, use_readout=use_readout
44
+ )
45
+ scratch = _make_scratch(
46
+ [256, 512, 1024, 1024], features, groups=groups, expand=expand
47
+ ) # BEiT_384-L (backbone)
48
+ elif backbone == "beitb16_384":
49
+ pretrained = _make_pretrained_beitb16_384(
50
+ use_pretrained, hooks=hooks, use_readout=use_readout
51
+ )
52
+ scratch = _make_scratch(
53
+ [96, 192, 384, 768], features, groups=groups, expand=expand
54
+ ) # BEiT_384-B (backbone)
55
+ elif backbone == "swin2l24_384":
56
+ pretrained = _make_pretrained_swin2l24_384(
57
+ use_pretrained, hooks=hooks
58
+ )
59
+ scratch = _make_scratch(
60
+ [192, 384, 768, 1536], features, groups=groups, expand=expand
61
+ ) # Swin2-L/12to24 (backbone)
62
+ elif backbone == "swin2b24_384":
63
+ pretrained = _make_pretrained_swin2b24_384(
64
+ use_pretrained, hooks=hooks
65
+ )
66
+ scratch = _make_scratch(
67
+ [128, 256, 512, 1024], features, groups=groups, expand=expand
68
+ ) # Swin2-B/12to24 (backbone)
69
+ elif backbone == "swin2t16_256":
70
+ pretrained = _make_pretrained_swin2t16_256(
71
+ use_pretrained, hooks=hooks
72
+ )
73
+ scratch = _make_scratch(
74
+ [96, 192, 384, 768], features, groups=groups, expand=expand
75
+ ) # Swin2-T/16 (backbone)
76
+ elif backbone == "swinl12_384":
77
+ pretrained = _make_pretrained_swinl12_384(
78
+ use_pretrained, hooks=hooks
79
+ )
80
+ scratch = _make_scratch(
81
+ [192, 384, 768, 1536], features, groups=groups, expand=expand
82
+ ) # Swin-L/12 (backbone)
83
+ elif backbone == "next_vit_large_6m":
84
+ from .backbones.next_vit import _make_pretrained_next_vit_large_6m
85
+ pretrained = _make_pretrained_next_vit_large_6m(hooks=hooks)
86
+ scratch = _make_scratch(
87
+ in_features, features, groups=groups, expand=expand
88
+ ) # Next-ViT-L on ImageNet-1K-6M (backbone)
89
+ elif backbone == "levit_384":
90
+ pretrained = _make_pretrained_levit_384(
91
+ use_pretrained, hooks=hooks
92
+ )
93
+ scratch = _make_scratch(
94
+ [384, 512, 768], features, groups=groups, expand=expand
95
+ ) # LeViT 384 (backbone)
96
+ elif backbone == "vitl16_384":
97
+ pretrained = _make_pretrained_vitl16_384(
98
+ use_pretrained, hooks=hooks, use_readout=use_readout
99
+ )
100
+ scratch = _make_scratch(
101
+ [256, 512, 1024, 1024], features, groups=groups, expand=expand
102
+ ) # ViT-L/16 - 85.0% Top1 (backbone)
103
+ elif backbone == "vitb_rn50_384":
104
+ pretrained = _make_pretrained_vitb_rn50_384(
105
+ use_pretrained,
106
+ hooks=hooks,
107
+ use_vit_only=use_vit_only,
108
+ use_readout=use_readout,
109
+ )
110
+ scratch = _make_scratch(
111
+ [256, 512, 768, 768], features, groups=groups, expand=expand
112
+ ) # ViT-H/16 - 85.0% Top1 (backbone)
113
+ elif backbone == "vitb16_384":
114
+ pretrained = _make_pretrained_vitb16_384(
115
+ use_pretrained, hooks=hooks, use_readout=use_readout
116
+ )
117
+ scratch = _make_scratch(
118
+ [96, 192, 384, 768], features, groups=groups, expand=expand
119
+ ) # ViT-B/16 - 84.6% Top1 (backbone)
120
+ elif backbone == "resnext101_wsl":
121
+ pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
122
+ scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3
123
+ elif backbone == "efficientnet_lite3":
124
+ pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable)
125
+ scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3
126
+ else:
127
+ print(f"Backbone '{backbone}' not implemented")
128
+ assert False
129
+
130
+ return pretrained, scratch
131
+
132
+
133
+ def _make_scratch(in_shape, out_shape, groups=1, expand=False):
134
+ scratch = nn.Module()
135
+
136
+ out_shape1 = out_shape
137
+ out_shape2 = out_shape
138
+ out_shape3 = out_shape
139
+ if len(in_shape) >= 4:
140
+ out_shape4 = out_shape
141
+
142
+ if expand:
143
+ out_shape1 = out_shape
144
+ out_shape2 = out_shape*2
145
+ out_shape3 = out_shape*4
146
+ if len(in_shape) >= 4:
147
+ out_shape4 = out_shape*8
148
+
149
+ scratch.layer1_rn = nn.Conv2d(
150
+ in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
151
+ )
152
+ scratch.layer2_rn = nn.Conv2d(
153
+ in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
154
+ )
155
+ scratch.layer3_rn = nn.Conv2d(
156
+ in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
157
+ )
158
+ if len(in_shape) >= 4:
159
+ scratch.layer4_rn = nn.Conv2d(
160
+ in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
161
+ )
162
+
163
+ return scratch
164
+
165
+
166
+ def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
167
+ efficientnet = torch.hub.load(
168
+ "rwightman/gen-efficientnet-pytorch",
169
+ "tf_efficientnet_lite3",
170
+ pretrained=use_pretrained,
171
+ exportable=exportable
172
+ )
173
+ return _make_efficientnet_backbone(efficientnet)
174
+
175
+
176
+ def _make_efficientnet_backbone(effnet):
177
+ pretrained = nn.Module()
178
+
179
+ pretrained.layer1 = nn.Sequential(
180
+ effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2]
181
+ )
182
+ pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3])
183
+ pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5])
184
+ pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9])
185
+
186
+ return pretrained
187
+
188
+
189
+ def _make_resnet_backbone(resnet):
190
+ pretrained = nn.Module()
191
+ pretrained.layer1 = nn.Sequential(
192
+ resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
193
+ )
194
+
195
+ pretrained.layer2 = resnet.layer2
196
+ pretrained.layer3 = resnet.layer3
197
+ pretrained.layer4 = resnet.layer4
198
+
199
+ return pretrained
200
+
201
+
202
+ def _make_pretrained_resnext101_wsl(use_pretrained):
203
+ resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
204
+ return _make_resnet_backbone(resnet)
205
+
206
+
207
+
208
+ class Interpolate(nn.Module):
209
+ """Interpolation module.
210
+ """
211
+
212
+ def __init__(self, scale_factor, mode, align_corners=False):
213
+ """Init.
214
+
215
+ Args:
216
+ scale_factor (float): scaling
217
+ mode (str): interpolation mode
218
+ """
219
+ super(Interpolate, self).__init__()
220
+
221
+ self.interp = nn.functional.interpolate
222
+ self.scale_factor = scale_factor
223
+ self.mode = mode
224
+ self.align_corners = align_corners
225
+
226
+ def forward(self, x):
227
+ """Forward pass.
228
+
229
+ Args:
230
+ x (tensor): input
231
+
232
+ Returns:
233
+ tensor: interpolated data
234
+ """
235
+
236
+ x = self.interp(
237
+ x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners
238
+ )
239
+
240
+ return x
241
+
242
+
243
+ class ResidualConvUnit(nn.Module):
244
+ """Residual convolution module.
245
+ """
246
+
247
+ def __init__(self, features):
248
+ """Init.
249
+
250
+ Args:
251
+ features (int): number of features
252
+ """
253
+ super().__init__()
254
+
255
+ self.conv1 = nn.Conv2d(
256
+ features, features, kernel_size=3, stride=1, padding=1, bias=True
257
+ )
258
+
259
+ self.conv2 = nn.Conv2d(
260
+ features, features, kernel_size=3, stride=1, padding=1, bias=True
261
+ )
262
+
263
+ self.relu = nn.ReLU(inplace=True)
264
+
265
+ def forward(self, x):
266
+ """Forward pass.
267
+
268
+ Args:
269
+ x (tensor): input
270
+
271
+ Returns:
272
+ tensor: output
273
+ """
274
+ out = self.relu(x)
275
+ out = self.conv1(out)
276
+ out = self.relu(out)
277
+ out = self.conv2(out)
278
+
279
+ return out + x
280
+
281
+
282
+ class FeatureFusionBlock(nn.Module):
283
+ """Feature fusion block.
284
+ """
285
+
286
+ def __init__(self, features):
287
+ """Init.
288
+
289
+ Args:
290
+ features (int): number of features
291
+ """
292
+ super(FeatureFusionBlock, self).__init__()
293
+
294
+ self.resConfUnit1 = ResidualConvUnit(features)
295
+ self.resConfUnit2 = ResidualConvUnit(features)
296
+
297
+ def forward(self, *xs):
298
+ """Forward pass.
299
+
300
+ Returns:
301
+ tensor: output
302
+ """
303
+ output = xs[0]
304
+
305
+ if len(xs) == 2:
306
+ output += self.resConfUnit1(xs[1])
307
+
308
+ output = self.resConfUnit2(output)
309
+
310
+ output = nn.functional.interpolate(
311
+ output, scale_factor=2, mode="bilinear", align_corners=True
312
+ )
313
+
314
+ return output
315
+
316
+
317
+
318
+
319
+ class ResidualConvUnit_custom(nn.Module):
320
+ """Residual convolution module.
321
+ """
322
+
323
+ def __init__(self, features, activation, bn):
324
+ """Init.
325
+
326
+ Args:
327
+ features (int): number of features
328
+ """
329
+ super().__init__()
330
+
331
+ self.bn = bn
332
+
333
+ self.groups=1
334
+
335
+ self.conv1 = nn.Conv2d(
336
+ features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
337
+ )
338
+
339
+ self.conv2 = nn.Conv2d(
340
+ features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
341
+ )
342
+
343
+ if self.bn==True:
344
+ self.bn1 = nn.BatchNorm2d(features)
345
+ self.bn2 = nn.BatchNorm2d(features)
346
+
347
+ self.activation = activation
348
+
349
+ self.skip_add = nn.quantized.FloatFunctional()
350
+
351
+ def forward(self, x):
352
+ """Forward pass.
353
+
354
+ Args:
355
+ x (tensor): input
356
+
357
+ Returns:
358
+ tensor: output
359
+ """
360
+
361
+ out = self.activation(x)
362
+ out = self.conv1(out)
363
+ if self.bn==True:
364
+ out = self.bn1(out)
365
+
366
+ out = self.activation(out)
367
+ out = self.conv2(out)
368
+ if self.bn==True:
369
+ out = self.bn2(out)
370
+
371
+ if self.groups > 1:
372
+ out = self.conv_merge(out)
373
+
374
+ return self.skip_add.add(out, x)
375
+
376
+ # return out + x
377
+
378
+
379
+ class FeatureFusionBlock_custom(nn.Module):
380
+ """Feature fusion block.
381
+ """
382
+
383
+ def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True, size=None):
384
+ """Init.
385
+
386
+ Args:
387
+ features (int): number of features
388
+ """
389
+ super(FeatureFusionBlock_custom, self).__init__()
390
+
391
+ self.deconv = deconv
392
+ self.align_corners = align_corners
393
+
394
+ self.groups=1
395
+
396
+ self.expand = expand
397
+ out_features = features
398
+ if self.expand==True:
399
+ out_features = features//2
400
+
401
+ self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
402
+
403
+ self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
404
+ self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
405
+
406
+ self.skip_add = nn.quantized.FloatFunctional()
407
+
408
+ self.size=size
409
+
410
+ def forward(self, *xs, size=None):
411
+ """Forward pass.
412
+
413
+ Returns:
414
+ tensor: output
415
+ """
416
+ output = xs[0]
417
+
418
+ if len(xs) == 2:
419
+ res = self.resConfUnit1(xs[1])
420
+ output = self.skip_add.add(output, res)
421
+ # output += res
422
+
423
+ output = self.resConfUnit2(output)
424
+
425
+ if (size is None) and (self.size is None):
426
+ modifier = {"scale_factor": 2}
427
+ elif size is None:
428
+ modifier = {"size": self.size}
429
+ else:
430
+ modifier = {"size": size}
431
+
432
+ output = nn.functional.interpolate(
433
+ output, **modifier, mode="bilinear", align_corners=self.align_corners
434
+ )
435
+
436
+ output = self.out_conv(output)
437
+
438
+ return output
439
+
midas/dpt_depth.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from .base_model import BaseModel
5
+ from .blocks import (
6
+ FeatureFusionBlock_custom,
7
+ Interpolate,
8
+ _make_encoder,
9
+ forward_beit,
10
+ forward_swin,
11
+ forward_levit,
12
+ forward_vit,
13
+ )
14
+ from .backbones.levit import stem_b4_transpose
15
+ from timm.models.layers import get_act_layer
16
+
17
+
18
+ def _make_fusion_block(features, use_bn, size = None):
19
+ return FeatureFusionBlock_custom(
20
+ features,
21
+ nn.ReLU(False),
22
+ deconv=False,
23
+ bn=use_bn,
24
+ expand=False,
25
+ align_corners=True,
26
+ size=size,
27
+ )
28
+
29
+
30
+ class DPT(BaseModel):
31
+ def __init__(
32
+ self,
33
+ head,
34
+ features=256,
35
+ backbone="vitb_rn50_384",
36
+ readout="project",
37
+ channels_last=False,
38
+ use_bn=False,
39
+ **kwargs
40
+ ):
41
+
42
+ super(DPT, self).__init__()
43
+
44
+ self.channels_last = channels_last
45
+
46
+ # For the Swin, Swin 2, LeViT and Next-ViT Transformers, the hierarchical architectures prevent setting the
47
+ # hooks freely. Instead, the hooks have to be chosen according to the ranges specified in the comments.
48
+ hooks = {
49
+ "beitl16_512": [5, 11, 17, 23],
50
+ "beitl16_384": [5, 11, 17, 23],
51
+ "beitb16_384": [2, 5, 8, 11],
52
+ "swin2l24_384": [1, 1, 17, 1], # Allowed ranges: [0, 1], [0, 1], [ 0, 17], [ 0, 1]
53
+ "swin2b24_384": [1, 1, 17, 1], # [0, 1], [0, 1], [ 0, 17], [ 0, 1]
54
+ "swin2t16_256": [1, 1, 5, 1], # [0, 1], [0, 1], [ 0, 5], [ 0, 1]
55
+ "swinl12_384": [1, 1, 17, 1], # [0, 1], [0, 1], [ 0, 17], [ 0, 1]
56
+ "next_vit_large_6m": [2, 6, 36, 39], # [0, 2], [3, 6], [ 7, 36], [37, 39]
57
+ "levit_384": [3, 11, 21], # [0, 3], [6, 11], [14, 21]
58
+ "vitb_rn50_384": [0, 1, 8, 11],
59
+ "vitb16_384": [2, 5, 8, 11],
60
+ "vitl16_384": [5, 11, 17, 23],
61
+ }[backbone]
62
+
63
+ if "next_vit" in backbone:
64
+ in_features = {
65
+ "next_vit_large_6m": [96, 256, 512, 1024],
66
+ }[backbone]
67
+ else:
68
+ in_features = None
69
+
70
+ # Instantiate backbone and reassemble blocks
71
+ self.pretrained, self.scratch = _make_encoder(
72
+ backbone,
73
+ features,
74
+ False, # Set to true of you want to train from scratch, uses ImageNet weights
75
+ groups=1,
76
+ expand=False,
77
+ exportable=False,
78
+ hooks=hooks,
79
+ use_readout=readout,
80
+ in_features=in_features,
81
+ )
82
+
83
+ self.number_layers = len(hooks) if hooks is not None else 4
84
+ size_refinenet3 = None
85
+ self.scratch.stem_transpose = None
86
+
87
+ if "beit" in backbone:
88
+ self.forward_transformer = forward_beit
89
+ elif "swin" in backbone:
90
+ self.forward_transformer = forward_swin
91
+ elif "next_vit" in backbone:
92
+ from .backbones.next_vit import forward_next_vit
93
+ self.forward_transformer = forward_next_vit
94
+ elif "levit" in backbone:
95
+ self.forward_transformer = forward_levit
96
+ size_refinenet3 = 7
97
+ self.scratch.stem_transpose = stem_b4_transpose(256, 128, get_act_layer("hard_swish"))
98
+ else:
99
+ self.forward_transformer = forward_vit
100
+
101
+ self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
102
+ self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
103
+ self.scratch.refinenet3 = _make_fusion_block(features, use_bn, size_refinenet3)
104
+ if self.number_layers >= 4:
105
+ self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
106
+
107
+ self.scratch.output_conv = head
108
+
109
+
110
+ def forward(self, x):
111
+ if self.channels_last == True:
112
+ x.contiguous(memory_format=torch.channels_last)
113
+
114
+ layers = self.forward_transformer(self.pretrained, x)
115
+ if self.number_layers == 3:
116
+ layer_1, layer_2, layer_3 = layers
117
+ else:
118
+ layer_1, layer_2, layer_3, layer_4 = layers
119
+
120
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
121
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
122
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
123
+ if self.number_layers >= 4:
124
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
125
+
126
+ if self.number_layers == 3:
127
+ path_3 = self.scratch.refinenet3(layer_3_rn, size=layer_2_rn.shape[2:])
128
+ else:
129
+ path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
130
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
131
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
132
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
133
+
134
+ if self.scratch.stem_transpose is not None:
135
+ path_1 = self.scratch.stem_transpose(path_1)
136
+
137
+ out = self.scratch.output_conv(path_1)
138
+
139
+ return out
140
+
141
+
142
+ class DPTDepthModel(DPT):
143
+ def __init__(self, path=None, non_negative=True, **kwargs):
144
+ features = kwargs["features"] if "features" in kwargs else 256
145
+ head_features_1 = kwargs["head_features_1"] if "head_features_1" in kwargs else features
146
+ head_features_2 = kwargs["head_features_2"] if "head_features_2" in kwargs else 32
147
+ kwargs.pop("head_features_1", None)
148
+ kwargs.pop("head_features_2", None)
149
+
150
+ head = nn.Sequential(
151
+ nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1),
152
+ Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
153
+ nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1),
154
+ nn.ReLU(True),
155
+ nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0),
156
+ nn.ReLU(True) if non_negative else nn.Identity(),
157
+ nn.Identity(),
158
+ )
159
+
160
+ super().__init__(head, **kwargs)
161
+
162
+ if path is not None:
163
+ self.load(path)
164
+
165
+ def forward(self, x):
166
+ return super().forward(x).squeeze(dim=1)
midas/midas_net.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """MidashNet: Network for monocular depth estimation trained by mixing several datasets.
2
+ This file contains code that is adapted from
3
+ https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
4
+ """
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from .base_model import BaseModel
9
+ from .blocks import FeatureFusionBlock, Interpolate, _make_encoder
10
+
11
+
12
+ class MidasNet(BaseModel):
13
+ """Network for monocular depth estimation.
14
+ """
15
+
16
+ def __init__(self, path=None, features=256, non_negative=True):
17
+ """Init.
18
+
19
+ Args:
20
+ path (str, optional): Path to saved model. Defaults to None.
21
+ features (int, optional): Number of features. Defaults to 256.
22
+ backbone (str, optional): Backbone network for encoder. Defaults to resnet50
23
+ """
24
+ print("Loading weights: ", path)
25
+
26
+ super(MidasNet, self).__init__()
27
+
28
+ use_pretrained = False if path is None else True
29
+
30
+ self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained)
31
+
32
+ self.scratch.refinenet4 = FeatureFusionBlock(features)
33
+ self.scratch.refinenet3 = FeatureFusionBlock(features)
34
+ self.scratch.refinenet2 = FeatureFusionBlock(features)
35
+ self.scratch.refinenet1 = FeatureFusionBlock(features)
36
+
37
+ self.scratch.output_conv = nn.Sequential(
38
+ nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
39
+ Interpolate(scale_factor=2, mode="bilinear"),
40
+ nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
41
+ nn.ReLU(True),
42
+ nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
43
+ nn.ReLU(True) if non_negative else nn.Identity(),
44
+ )
45
+
46
+ if path:
47
+ self.load(path)
48
+
49
+ def forward(self, x):
50
+ """Forward pass.
51
+
52
+ Args:
53
+ x (tensor): input data (image)
54
+
55
+ Returns:
56
+ tensor: depth
57
+ """
58
+
59
+ layer_1 = self.pretrained.layer1(x)
60
+ layer_2 = self.pretrained.layer2(layer_1)
61
+ layer_3 = self.pretrained.layer3(layer_2)
62
+ layer_4 = self.pretrained.layer4(layer_3)
63
+
64
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
65
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
66
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
67
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
68
+
69
+ path_4 = self.scratch.refinenet4(layer_4_rn)
70
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
71
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
72
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
73
+
74
+ out = self.scratch.output_conv(path_1)
75
+
76
+ return torch.squeeze(out, dim=1)
midas/midas_net_custom.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """MidashNet: Network for monocular depth estimation trained by mixing several datasets.
2
+ This file contains code that is adapted from
3
+ https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
4
+ """
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from .base_model import BaseModel
9
+ from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder
10
+
11
+
12
+ class MidasNet_small(BaseModel):
13
+ """Network for monocular depth estimation.
14
+ """
15
+
16
+ def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True,
17
+ blocks={'expand': True}):
18
+ """Init.
19
+
20
+ Args:
21
+ path (str, optional): Path to saved model. Defaults to None.
22
+ features (int, optional): Number of features. Defaults to 256.
23
+ backbone (str, optional): Backbone network for encoder. Defaults to resnet50
24
+ """
25
+ print("Loading weights: ", path)
26
+
27
+ super(MidasNet_small, self).__init__()
28
+
29
+ use_pretrained = False if path else True
30
+
31
+ self.channels_last = channels_last
32
+ self.blocks = blocks
33
+ self.backbone = backbone
34
+
35
+ self.groups = 1
36
+
37
+ features1=features
38
+ features2=features
39
+ features3=features
40
+ features4=features
41
+ self.expand = False
42
+ if "expand" in self.blocks and self.blocks['expand'] == True:
43
+ self.expand = True
44
+ features1=features
45
+ features2=features*2
46
+ features3=features*4
47
+ features4=features*8
48
+
49
+ self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable)
50
+
51
+ self.scratch.activation = nn.ReLU(False)
52
+
53
+ self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
54
+ self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
55
+ self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
56
+ self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners)
57
+
58
+
59
+ self.scratch.output_conv = nn.Sequential(
60
+ nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups),
61
+ Interpolate(scale_factor=2, mode="bilinear"),
62
+ nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1),
63
+ self.scratch.activation,
64
+ nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
65
+ nn.ReLU(True) if non_negative else nn.Identity(),
66
+ nn.Identity(),
67
+ )
68
+
69
+ if path:
70
+ self.load(path)
71
+
72
+
73
+ def forward(self, x):
74
+ """Forward pass.
75
+
76
+ Args:
77
+ x (tensor): input data (image)
78
+
79
+ Returns:
80
+ tensor: depth
81
+ """
82
+ if self.channels_last==True:
83
+ print("self.channels_last = ", self.channels_last)
84
+ x.contiguous(memory_format=torch.channels_last)
85
+
86
+
87
+ layer_1 = self.pretrained.layer1(x)
88
+ layer_2 = self.pretrained.layer2(layer_1)
89
+ layer_3 = self.pretrained.layer3(layer_2)
90
+ layer_4 = self.pretrained.layer4(layer_3)
91
+
92
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
93
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
94
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
95
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
96
+
97
+
98
+ path_4 = self.scratch.refinenet4(layer_4_rn)
99
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
100
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
101
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
102
+
103
+ out = self.scratch.output_conv(path_1)
104
+
105
+ return torch.squeeze(out, dim=1)
106
+
107
+
108
+
109
+ def fuse_model(m):
110
+ prev_previous_type = nn.Identity()
111
+ prev_previous_name = ''
112
+ previous_type = nn.Identity()
113
+ previous_name = ''
114
+ for name, module in m.named_modules():
115
+ if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU:
116
+ # print("FUSED ", prev_previous_name, previous_name, name)
117
+ torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True)
118
+ elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d:
119
+ # print("FUSED ", prev_previous_name, previous_name)
120
+ torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True)
121
+ # elif previous_type == nn.Conv2d and type(module) == nn.ReLU:
122
+ # print("FUSED ", previous_name, name)
123
+ # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True)
124
+
125
+ prev_previous_type = previous_type
126
+ prev_previous_name = previous_name
127
+ previous_type = type(module)
128
+ previous_name = name
midas/model_loader.py ADDED
@@ -0,0 +1,242 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import torch
3
+
4
+ from midas.dpt_depth import DPTDepthModel
5
+ from midas.midas_net import MidasNet
6
+ from midas.midas_net_custom import MidasNet_small
7
+ from midas.transforms import Resize, NormalizeImage, PrepareForNet
8
+
9
+ from torchvision.transforms import Compose
10
+
11
+ default_models = {
12
+ "dpt_beit_large_512": "weights/dpt_beit_large_512.pt",
13
+ "dpt_beit_large_384": "weights/dpt_beit_large_384.pt",
14
+ "dpt_beit_base_384": "weights/dpt_beit_base_384.pt",
15
+ "dpt_swin2_large_384": "weights/dpt_swin2_large_384.pt",
16
+ "dpt_swin2_base_384": "weights/dpt_swin2_base_384.pt",
17
+ "dpt_swin2_tiny_256": "weights/dpt_swin2_tiny_256.pt",
18
+ "dpt_swin_large_384": "weights/dpt_swin_large_384.pt",
19
+ "dpt_next_vit_large_384": "weights/dpt_next_vit_large_384.pt",
20
+ "dpt_levit_224": "weights/dpt_levit_224.pt",
21
+ "dpt_large_384": "weights/dpt_large_384.pt",
22
+ "dpt_hybrid_384": "weights/dpt_hybrid_384.pt",
23
+ "midas_v21_384": "weights/midas_v21_384.pt",
24
+ "midas_v21_small_256": "weights/midas_v21_small_256.pt",
25
+ "openvino_midas_v21_small_256": "weights/openvino_midas_v21_small_256.xml",
26
+ }
27
+
28
+
29
+ def load_model(device, model_path, model_type="dpt_large_384", optimize=True, height=None, square=False):
30
+ """Load the specified network.
31
+
32
+ Args:
33
+ device (device): the torch device used
34
+ model_path (str): path to saved model
35
+ model_type (str): the type of the model to be loaded
36
+ optimize (bool): optimize the model to half-integer on CUDA?
37
+ height (int): inference encoder image height
38
+ square (bool): resize to a square resolution?
39
+
40
+ Returns:
41
+ The loaded network, the transform which prepares images as input to the network and the dimensions of the
42
+ network input
43
+ """
44
+ if "openvino" in model_type:
45
+ from openvino.runtime import Core
46
+
47
+ keep_aspect_ratio = not square
48
+
49
+ if model_type == "dpt_beit_large_512":
50
+ model = DPTDepthModel(
51
+ path=model_path,
52
+ backbone="beitl16_512",
53
+ non_negative=True,
54
+ )
55
+ net_w, net_h = 512, 512
56
+ resize_mode = "minimal"
57
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
58
+
59
+ elif model_type == "dpt_beit_large_384":
60
+ model = DPTDepthModel(
61
+ path=model_path,
62
+ backbone="beitl16_384",
63
+ non_negative=True,
64
+ )
65
+ net_w, net_h = 384, 384
66
+ resize_mode = "minimal"
67
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
68
+
69
+ elif model_type == "dpt_beit_base_384":
70
+ model = DPTDepthModel(
71
+ path=model_path,
72
+ backbone="beitb16_384",
73
+ non_negative=True,
74
+ )
75
+ net_w, net_h = 384, 384
76
+ resize_mode = "minimal"
77
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
78
+
79
+ elif model_type == "dpt_swin2_large_384":
80
+ model = DPTDepthModel(
81
+ path=model_path,
82
+ backbone="swin2l24_384",
83
+ non_negative=True,
84
+ )
85
+ net_w, net_h = 384, 384
86
+ keep_aspect_ratio = False
87
+ resize_mode = "minimal"
88
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
89
+
90
+ elif model_type == "dpt_swin2_base_384":
91
+ model = DPTDepthModel(
92
+ path=model_path,
93
+ backbone="swin2b24_384",
94
+ non_negative=True,
95
+ )
96
+ net_w, net_h = 384, 384
97
+ keep_aspect_ratio = False
98
+ resize_mode = "minimal"
99
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
100
+
101
+ elif model_type == "dpt_swin2_tiny_256":
102
+ model = DPTDepthModel(
103
+ path=model_path,
104
+ backbone="swin2t16_256",
105
+ non_negative=True,
106
+ )
107
+ net_w, net_h = 256, 256
108
+ keep_aspect_ratio = False
109
+ resize_mode = "minimal"
110
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
111
+
112
+ elif model_type == "dpt_swin_large_384":
113
+ model = DPTDepthModel(
114
+ path=model_path,
115
+ backbone="swinl12_384",
116
+ non_negative=True,
117
+ )
118
+ net_w, net_h = 384, 384
119
+ keep_aspect_ratio = False
120
+ resize_mode = "minimal"
121
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
122
+
123
+ elif model_type == "dpt_next_vit_large_384":
124
+ model = DPTDepthModel(
125
+ path=model_path,
126
+ backbone="next_vit_large_6m",
127
+ non_negative=True,
128
+ )
129
+ net_w, net_h = 384, 384
130
+ resize_mode = "minimal"
131
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
132
+
133
+ # We change the notation from dpt_levit_224 (MiDaS notation) to levit_384 (timm notation) here, where the 224 refers
134
+ # to the resolution 224x224 used by LeViT and 384 is the first entry of the embed_dim, see _cfg and model_cfgs of
135
+ # https://github.com/rwightman/pytorch-image-models/blob/main/timm/models/levit.py
136
+ # (commit id: 927f031293a30afb940fff0bee34b85d9c059b0e)
137
+ elif model_type == "dpt_levit_224":
138
+ model = DPTDepthModel(
139
+ path=model_path,
140
+ backbone="levit_384",
141
+ non_negative=True,
142
+ head_features_1=64,
143
+ head_features_2=8,
144
+ )
145
+ net_w, net_h = 224, 224
146
+ keep_aspect_ratio = False
147
+ resize_mode = "minimal"
148
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
149
+
150
+ elif model_type == "dpt_large_384":
151
+ model = DPTDepthModel(
152
+ path=model_path,
153
+ backbone="vitl16_384",
154
+ non_negative=True,
155
+ )
156
+ net_w, net_h = 384, 384
157
+ resize_mode = "minimal"
158
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
159
+
160
+ elif model_type == "dpt_hybrid_384":
161
+ model = DPTDepthModel(
162
+ path=model_path,
163
+ backbone="vitb_rn50_384",
164
+ non_negative=True,
165
+ )
166
+ net_w, net_h = 384, 384
167
+ resize_mode = "minimal"
168
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
169
+
170
+ elif model_type == "midas_v21_384":
171
+ model = MidasNet(model_path, non_negative=True)
172
+ net_w, net_h = 384, 384
173
+ resize_mode = "upper_bound"
174
+ normalization = NormalizeImage(
175
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
176
+ )
177
+
178
+ elif model_type == "midas_v21_small_256":
179
+ model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True,
180
+ non_negative=True, blocks={'expand': True})
181
+ net_w, net_h = 256, 256
182
+ resize_mode = "upper_bound"
183
+ normalization = NormalizeImage(
184
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
185
+ )
186
+
187
+ elif model_type == "openvino_midas_v21_small_256":
188
+ ie = Core()
189
+ uncompiled_model = ie.read_model(model=model_path)
190
+ model = ie.compile_model(uncompiled_model, "CPU")
191
+ net_w, net_h = 256, 256
192
+ resize_mode = "upper_bound"
193
+ normalization = NormalizeImage(
194
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
195
+ )
196
+
197
+ else:
198
+ print(f"model_type '{model_type}' not implemented, use: --model_type large")
199
+ assert False
200
+
201
+ if not "openvino" in model_type:
202
+ print("Model loaded, number of parameters = {:.0f}M".format(sum(p.numel() for p in model.parameters()) / 1e6))
203
+ else:
204
+ print("Model loaded, optimized with OpenVINO")
205
+
206
+ if "openvino" in model_type:
207
+ keep_aspect_ratio = False
208
+
209
+ if height is not None:
210
+ net_w, net_h = height, height
211
+
212
+ transform = Compose(
213
+ [
214
+ Resize(
215
+ net_w,
216
+ net_h,
217
+ resize_target=None,
218
+ keep_aspect_ratio=keep_aspect_ratio,
219
+ ensure_multiple_of=32,
220
+ resize_method=resize_mode,
221
+ image_interpolation_method=cv2.INTER_CUBIC,
222
+ ),
223
+ normalization,
224
+ PrepareForNet(),
225
+ ]
226
+ )
227
+
228
+ if not "openvino" in model_type:
229
+ model.eval()
230
+
231
+ if optimize and (device == torch.device("cuda")):
232
+ if not "openvino" in model_type:
233
+ model = model.to(memory_format=torch.channels_last)
234
+ model = model.half()
235
+ else:
236
+ print("Error: OpenVINO models are already optimized. No optimization to half-float possible.")
237
+ exit()
238
+
239
+ if not "openvino" in model_type:
240
+ model.to(device)
241
+
242
+ return model, transform, net_w, net_h
midas/transforms.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import cv2
3
+ import math
4
+
5
+
6
+ def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
7
+ """Rezise the sample to ensure the given size. Keeps aspect ratio.
8
+
9
+ Args:
10
+ sample (dict): sample
11
+ size (tuple): image size
12
+
13
+ Returns:
14
+ tuple: new size
15
+ """
16
+ shape = list(sample["disparity"].shape)
17
+
18
+ if shape[0] >= size[0] and shape[1] >= size[1]:
19
+ return sample
20
+
21
+ scale = [0, 0]
22
+ scale[0] = size[0] / shape[0]
23
+ scale[1] = size[1] / shape[1]
24
+
25
+ scale = max(scale)
26
+
27
+ shape[0] = math.ceil(scale * shape[0])
28
+ shape[1] = math.ceil(scale * shape[1])
29
+
30
+ # resize
31
+ sample["image"] = cv2.resize(
32
+ sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
33
+ )
34
+
35
+ sample["disparity"] = cv2.resize(
36
+ sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
37
+ )
38
+ sample["mask"] = cv2.resize(
39
+ sample["mask"].astype(np.float32),
40
+ tuple(shape[::-1]),
41
+ interpolation=cv2.INTER_NEAREST,
42
+ )
43
+ sample["mask"] = sample["mask"].astype(bool)
44
+
45
+ return tuple(shape)
46
+
47
+
48
+ class Resize(object):
49
+ """Resize sample to given size (width, height).
50
+ """
51
+
52
+ def __init__(
53
+ self,
54
+ width,
55
+ height,
56
+ resize_target=True,
57
+ keep_aspect_ratio=False,
58
+ ensure_multiple_of=1,
59
+ resize_method="lower_bound",
60
+ image_interpolation_method=cv2.INTER_AREA,
61
+ ):
62
+ """Init.
63
+
64
+ Args:
65
+ width (int): desired output width
66
+ height (int): desired output height
67
+ resize_target (bool, optional):
68
+ True: Resize the full sample (image, mask, target).
69
+ False: Resize image only.
70
+ Defaults to True.
71
+ keep_aspect_ratio (bool, optional):
72
+ True: Keep the aspect ratio of the input sample.
73
+ Output sample might not have the given width and height, and
74
+ resize behaviour depends on the parameter 'resize_method'.
75
+ Defaults to False.
76
+ ensure_multiple_of (int, optional):
77
+ Output width and height is constrained to be multiple of this parameter.
78
+ Defaults to 1.
79
+ resize_method (str, optional):
80
+ "lower_bound": Output will be at least as large as the given size.
81
+ "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
82
+ "minimal": Scale as least as possible. (Output size might be smaller than given size.)
83
+ Defaults to "lower_bound".
84
+ """
85
+ self.__width = width
86
+ self.__height = height
87
+
88
+ self.__resize_target = resize_target
89
+ self.__keep_aspect_ratio = keep_aspect_ratio
90
+ self.__multiple_of = ensure_multiple_of
91
+ self.__resize_method = resize_method
92
+ self.__image_interpolation_method = image_interpolation_method
93
+
94
+ def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
95
+ y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
96
+
97
+ if max_val is not None and y > max_val:
98
+ y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
99
+
100
+ if y < min_val:
101
+ y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
102
+
103
+ return y
104
+
105
+ def get_size(self, width, height):
106
+ # determine new height and width
107
+ scale_height = self.__height / height
108
+ scale_width = self.__width / width
109
+
110
+ if self.__keep_aspect_ratio:
111
+ if self.__resize_method == "lower_bound":
112
+ # scale such that output size is lower bound
113
+ if scale_width > scale_height:
114
+ # fit width
115
+ scale_height = scale_width
116
+ else:
117
+ # fit height
118
+ scale_width = scale_height
119
+ elif self.__resize_method == "upper_bound":
120
+ # scale such that output size is upper bound
121
+ if scale_width < scale_height:
122
+ # fit width
123
+ scale_height = scale_width
124
+ else:
125
+ # fit height
126
+ scale_width = scale_height
127
+ elif self.__resize_method == "minimal":
128
+ # scale as least as possbile
129
+ if abs(1 - scale_width) < abs(1 - scale_height):
130
+ # fit width
131
+ scale_height = scale_width
132
+ else:
133
+ # fit height
134
+ scale_width = scale_height
135
+ else:
136
+ raise ValueError(
137
+ f"resize_method {self.__resize_method} not implemented"
138
+ )
139
+
140
+ if self.__resize_method == "lower_bound":
141
+ new_height = self.constrain_to_multiple_of(
142
+ scale_height * height, min_val=self.__height
143
+ )
144
+ new_width = self.constrain_to_multiple_of(
145
+ scale_width * width, min_val=self.__width
146
+ )
147
+ elif self.__resize_method == "upper_bound":
148
+ new_height = self.constrain_to_multiple_of(
149
+ scale_height * height, max_val=self.__height
150
+ )
151
+ new_width = self.constrain_to_multiple_of(
152
+ scale_width * width, max_val=self.__width
153
+ )
154
+ elif self.__resize_method == "minimal":
155
+ new_height = self.constrain_to_multiple_of(scale_height * height)
156
+ new_width = self.constrain_to_multiple_of(scale_width * width)
157
+ else:
158
+ raise ValueError(f"resize_method {self.__resize_method} not implemented")
159
+
160
+ return (new_width, new_height)
161
+
162
+ def __call__(self, sample):
163
+ width, height = self.get_size(
164
+ sample["image"].shape[1], sample["image"].shape[0]
165
+ )
166
+
167
+ # resize sample
168
+ sample["image"] = cv2.resize(
169
+ sample["image"],
170
+ (width, height),
171
+ interpolation=self.__image_interpolation_method,
172
+ )
173
+
174
+ if self.__resize_target:
175
+ if "disparity" in sample:
176
+ sample["disparity"] = cv2.resize(
177
+ sample["disparity"],
178
+ (width, height),
179
+ interpolation=cv2.INTER_NEAREST,
180
+ )
181
+
182
+ if "depth" in sample:
183
+ sample["depth"] = cv2.resize(
184
+ sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
185
+ )
186
+
187
+ sample["mask"] = cv2.resize(
188
+ sample["mask"].astype(np.float32),
189
+ (width, height),
190
+ interpolation=cv2.INTER_NEAREST,
191
+ )
192
+ sample["mask"] = sample["mask"].astype(bool)
193
+
194
+ return sample
195
+
196
+
197
+ class NormalizeImage(object):
198
+ """Normlize image by given mean and std.
199
+ """
200
+
201
+ def __init__(self, mean, std):
202
+ self.__mean = mean
203
+ self.__std = std
204
+
205
+ def __call__(self, sample):
206
+ sample["image"] = (sample["image"] - self.__mean) / self.__std
207
+
208
+ return sample
209
+
210
+
211
+ class PrepareForNet(object):
212
+ """Prepare sample for usage as network input.
213
+ """
214
+
215
+ def __init__(self):
216
+ pass
217
+
218
+ def __call__(self, sample):
219
+ image = np.transpose(sample["image"], (2, 0, 1))
220
+ sample["image"] = np.ascontiguousarray(image).astype(np.float32)
221
+
222
+ if "mask" in sample:
223
+ sample["mask"] = sample["mask"].astype(np.float32)
224
+ sample["mask"] = np.ascontiguousarray(sample["mask"])
225
+
226
+ if "disparity" in sample:
227
+ disparity = sample["disparity"].astype(np.float32)
228
+ sample["disparity"] = np.ascontiguousarray(disparity)
229
+
230
+ if "depth" in sample:
231
+ depth = sample["depth"].astype(np.float32)
232
+ sample["depth"] = np.ascontiguousarray(depth)
233
+
234
+ return sample
monocular_depth_estimator.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import torch
3
+ import numpy as np
4
+ import time
5
+ from midas.model_loader import default_models, load_model
6
+ import os
7
+ import urllib.request
8
+
9
+ class MonocularDepthEstimator:
10
+ def __init__(self,
11
+ model_type="midas_v21_small_256",
12
+ model_weights_path="models/midas_v21_small_256.pt",
13
+ optimize=False,
14
+ side_by_side=True,
15
+ height=None,
16
+ square=False,
17
+ grayscale=False):
18
+
19
+ # model type
20
+ # MiDaS 3.1:
21
+ # For highest quality: dpt_beit_large_512
22
+ # For moderately less quality, but better speed-performance trade-off: dpt_swin2_large_384
23
+ # For embedded devices: dpt_swin2_tiny_256, dpt_levit_224
24
+ # For inference on Intel CPUs, OpenVINO may be used for the small legacy model: openvino_midas_v21_small .xml, .bin
25
+
26
+ # MiDaS 3.0:
27
+ # Legacy transformer models dpt_large_384 and dpt_hybrid_384
28
+
29
+ # MiDaS 2.1:
30
+ # Legacy convolutional models midas_v21_384 and midas_v21_small_256
31
+
32
+ # params
33
+ print("Initializing parameters and model...")
34
+ self.is_optimize = optimize
35
+ self.is_square = square
36
+ self.is_grayscale = grayscale
37
+ self.height = height
38
+ self.side_by_side = side_by_side
39
+
40
+ # select device
41
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
42
+ print("Running inference on : %s" % self.device)
43
+ model_file_url = "https://github.com/isl-org/MiDaS/releases/download/v2_1/midas_v21_small_256.pt"
44
+
45
+ # loading model
46
+ if not os.path.exists(model_weights_path):
47
+ print("Model file not found. Downloading...")
48
+ # Download the model file
49
+ urllib.request.urlretrieve(model_file_url, model_weights_path)
50
+ print("Model file downloaded successfully.")
51
+
52
+ self.model, self.transform, self.net_w, self.net_h = load_model(self.device, model_weights_path,
53
+ model_type, optimize, height, square)
54
+ print("Net width and height: ", (self.net_w, self.net_h))
55
+
56
+
57
+ def predict(self, image, model, target_size):
58
+
59
+
60
+ # convert img to tensor and load to gpu
61
+ img_tensor = torch.from_numpy(image).to(self.device).unsqueeze(0)
62
+
63
+ if self.is_optimize and self.device == torch.device("cuda"):
64
+ img_tensor = img_tensor.to(memory_format=torch.channels_last)
65
+ img_tensor = img_tensor.half()
66
+
67
+ prediction = model.forward(img_tensor)
68
+ prediction = (
69
+ torch.nn.functional.interpolate(
70
+ prediction.unsqueeze(1),
71
+ size=target_size[::-1],
72
+ mode="bicubic",
73
+ align_corners=False,
74
+ )
75
+ .squeeze()
76
+ .cpu()
77
+ .numpy()
78
+ )
79
+
80
+ return prediction
81
+
82
+ def process_prediction(self, original_img, depth_img, is_grayscale=False, side_by_side=False):
83
+ """
84
+ Take an RGB image and depth map and place them side by side. This includes a proper normalization of the depth map
85
+ for better visibility.
86
+ Args:
87
+ original_img: the RGB image
88
+ depth_img: the depth map
89
+ is_grayscale: use a grayscale colormap?
90
+ Returns:
91
+ the image and depth map place side by side
92
+ """
93
+
94
+ # normalizing depth image
95
+ depth_min = depth_img.min()
96
+ depth_max = depth_img.max()
97
+ normalized_depth = 255 * (depth_img - depth_min) / (depth_max - depth_min)
98
+ normalized_depth *= 3
99
+
100
+ depth_side = np.repeat(np.expand_dims(normalized_depth, 2), 3, axis=2) / 3
101
+ if not is_grayscale:
102
+ depth_side = cv2.applyColorMap(np.uint8(depth_side), cv2.COLORMAP_INFERNO)
103
+
104
+ if side_by_side:
105
+ return np.concatenate((original_img, depth_side), axis=1)/255
106
+
107
+ return depth_side/255
108
+
109
+ def make_prediction(self, image):
110
+ with torch.no_grad():
111
+ original_image_rgb = np.flip(image, 2) # in [0, 255] (flip required to get RGB)
112
+ # resizing the image to feed to the model
113
+ image_tranformed = self.transform({"image": original_image_rgb/255})["image"]
114
+
115
+ # monocular depth prediction
116
+ prediction = self.predict(image_tranformed, self.model, target_size=original_image_rgb.shape[1::-1])
117
+ original_image_bgr = np.flip(original_image_rgb, 2) if self.side_by_side else None
118
+
119
+ # process the model predictions
120
+ output = self.process_prediction(original_image_bgr, prediction, is_grayscale=self.is_grayscale, side_by_side=self.side_by_side)
121
+ return output
122
+
123
+ def run(self, input_path):
124
+
125
+ # input video
126
+ cap = cv2.VideoCapture(input_path)
127
+
128
+ # Check if camera opened successfully
129
+ if not cap.isOpened():
130
+ print("Error opening video file")
131
+
132
+ with torch.no_grad():
133
+ while cap.isOpened():
134
+
135
+ # Capture frame-by-frame
136
+ inference_start_time = time.time()
137
+ ret, frame = cap.read()
138
+
139
+ if ret == True:
140
+ output = self.make_prediction(frame)
141
+ inference_end_time = time.time()
142
+ fps = round(1/(inference_end_time - inference_start_time))
143
+ cv2.putText(output, f'FPS: {fps}', (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (10, 255, 100), 2)
144
+ cv2.imshow('MiDaS Depth Estimation - Press Escape to close window ', output)
145
+
146
+ # Press ESC on keyboard to exit
147
+ if cv2.waitKey(1) == 27: # Escape key
148
+ break
149
+
150
+ else:
151
+ break
152
+
153
+
154
+ # When everything done, release
155
+ # the video capture object
156
+ cap.release()
157
+
158
+ # Closes all the frames
159
+ cv2.destroyAllWindows()
160
+
161
+
162
+
163
+ if __name__ == "__main__":
164
+ # params
165
+ INPUT_PATH = "assets/videos/testvideo2.mp4"
166
+
167
+ os.environ['CUDA_VISIBLE_DEVICES'] = '0'
168
+
169
+ # set torch options
170
+ torch.backends.cudnn.enabled = True
171
+ torch.backends.cudnn.benchmark = True
172
+
173
+ depth_estimator = MonocularDepthEstimator(side_by_side=False)
174
+ depth_estimator.run(INPUT_PATH)
175
+