harshm121 committed on
Commit
d4ebf73
1 Parent(s): 2bdbad1

Working demo

This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. .DS_Store +0 -0
  2. .gitattributes +67 -0
  3. README.md +2 -2
  4. app.py +161 -0
  5. arial.ttf +0 -0
  6. checkpoints/sid_1-500_m3lteacher.pth +3 -0
  7. checkpoints/sid_1-500_mtteacher.pth +3 -0
  8. classcolors.png +3 -0
  9. colors.pkl +3 -0
  10. datasets/__init__.py +0 -0
  11. datasets/__pycache__/__init__.cpython-36.pyc +0 -0
  12. datasets/__pycache__/__init__.cpython-38.pyc +0 -0
  13. datasets/__pycache__/__init__.cpython-39.pyc +0 -0
  14. datasets/__pycache__/base_dataset.cpython-36.pyc +0 -0
  15. datasets/__pycache__/base_dataset.cpython-38.pyc +0 -0
  16. datasets/__pycache__/base_dataset.cpython-39.pyc +0 -0
  17. datasets/__pycache__/citysundepth.cpython-36.pyc +0 -0
  18. datasets/__pycache__/citysundepth.cpython-39.pyc +0 -0
  19. datasets/__pycache__/citysunrgb.cpython-36.pyc +0 -0
  20. datasets/__pycache__/citysunrgb.cpython-38.pyc +0 -0
  21. datasets/__pycache__/citysunrgb.cpython-39.pyc +0 -0
  22. datasets/__pycache__/citysunrgbd.cpython-36.pyc +0 -0
  23. datasets/__pycache__/citysunrgbd.cpython-38.pyc +0 -0
  24. datasets/__pycache__/get_dataset.cpython-36.pyc +0 -0
  25. datasets/__pycache__/get_dataset.cpython-39.pyc +0 -0
  26. datasets/__pycache__/preprocessors.cpython-36.pyc +0 -0
  27. datasets/__pycache__/preprocessors.cpython-38.pyc +0 -0
  28. datasets/__pycache__/tfnyu.cpython-36.pyc +0 -0
  29. datasets/base_dataset.py +128 -0
  30. datasets/citysunrgbd.py +67 -0
  31. datasets/get_dataset.py +146 -0
  32. datasets/preprocessors.py +144 -0
  33. examples/.DS_Store +0 -0
  34. examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_45_domain.png +3 -0
  35. examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_46_domain.png +3 -0
  36. examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_47_domain.png +3 -0
  37. examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_48_domain.png +3 -0
  38. examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_49_domain.png +3 -0
  39. examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_50_domain.png +3 -0
  40. examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_51_domain.png +3 -0
  41. examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_52_domain.png +3 -0
  42. examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_53_domain.png +3 -0
  43. examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_0_domain.png +3 -0
  44. examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_10_domain.png +3 -0
  45. examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_1_domain.png +3 -0
  46. examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_2_domain.png +3 -0
  47. examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_3_domain.png +3 -0
  48. examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_4_domain.png +3 -0
  49. examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_5_domain.png +3 -0
  50. examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_6_domain.png +3 -0
.DS_Store ADDED
Binary file (6.15 kB). View file
 
.gitattributes CHANGED
@@ -32,3 +32,70 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+checkpoints/sid_1-500_m3lteacher.pth filter=lfs diff=lfs merge=lfs -text
+checkpoints/sid_1-500_mtteacher.pth filter=lfs diff=lfs merge=lfs -text
+examples/rgb/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_10_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/rgb/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_9_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/rgb/camera_0a529ee09fc04037831902b83073862b_office_4_frame_46_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/rgb/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_4_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/rgb/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_8_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/rgb/camera_0a529ee09fc04037831902b83073862b_office_4_frame_53_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/rgb/camera_0a529ee09fc04037831902b83073862b_office_4_frame_49_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/rgb/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_8_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/rgb/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_5_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/rgb/camera_4e7d4edf7c2a40ae93437d1fbde22043_office_6_frame_47_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/rgb/camera_0a529ee09fc04037831902b83073862b_office_4_frame_47_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/rgb/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_1_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/rgb/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_2_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/rgb/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_3_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/rgb/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_9_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/rgb/camera_0a529ee09fc04037831902b83073862b_office_4_frame_48_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/rgb/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_7_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/rgb/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_4_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/rgb/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_6_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/rgb/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_5_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/rgb/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_0_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/rgb/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_3_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/rgb/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_6_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/rgb/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_0_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/rgb/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_10_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/rgb/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_7_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/rgb/camera_0a529ee09fc04037831902b83073862b_office_4_frame_52_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/rgb/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_2_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/rgb/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_1_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/rgb/camera_0a529ee09fc04037831902b83073862b_office_4_frame_50_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/rgb/camera_0a529ee09fc04037831902b83073862b_office_4_frame_51_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/rgb/camera_0a529ee09fc04037831902b83073862b_office_4_frame_45_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_49_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_53_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_3_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_6_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/depth/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_0_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_8_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/depth/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_3_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/depth/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_8_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_46_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_52_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_5_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/depth/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_4_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_0_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_10_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/depth/camera_4e7d4edf7c2a40ae93437d1fbde22043_office_6_frame_47_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_47_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_4_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/depth/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_5_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/depth/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_9_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_48_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_50_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_51_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_1_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_9_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/depth/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_10_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/depth/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_6_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/depth/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_2_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_45_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_2_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_7_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/depth/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_1_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/depth/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_7_domain.png filter=lfs diff=lfs merge=lfs -text
+classcolors.png filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,7 +1,7 @@
 ---
 title: M3L
-emoji: 🐢
-colorFrom: pink
+emoji: 📚
+colorFrom: purple
 colorTo: gray
 sdk: gradio
 sdk_version: 3.23.0
app.py ADDED
@@ -0,0 +1,161 @@
+import torch
+import torch.nn as nn
+import gradio as gr
+import numpy as np
+import os
+import random
+import pickle as pkl
+from models.segmentation_models.linearfusemaskedconsmixbatch.segformer import LinearFusionMaskedConsistencyMixBatch
+from models.segmentation_models.linearfuse.segformer import WeTrLinearFusion
+from datasets.preprocessors import RGBDValPre
+from utils.constants import Constants as C
+
+class Arguments:
+    def __init__(self, ratio):
+        self.ratio = ratio
+        self.masking_ratio = 1.0
+
+colors = pkl.load(open('./colors.pkl', 'rb'))
+args = Arguments(ratio = 0.8)
+
+mtmodel = WeTrLinearFusion("mit_b2", args, num_classes=13, pretrained=False)
+mtmodelpath = './checkpoints/sid_1-500_mtteacher.pth'
+mtmodel.load_state_dict(torch.load(mtmodelpath, map_location=torch.device('cpu')))
+mtmodel.eval()
+
+m3lmodel = LinearFusionMaskedConsistencyMixBatch("mit_b2", args, num_classes=13, pretrained=False)
+m3lmodelpath = './checkpoints/sid_1-500_m3lteacher.pth'
+m3lmodel.load_state_dict(torch.load(m3lmodelpath, map_location=torch.device('cpu')))
+m3lmodel.eval()
+
+
+
+class MaskStudentTeacher(nn.Module):
+
+    def __init__(self, student, teacher, ema_alpha, mode = 'train'):
+        super(MaskStudentTeacher, self).__init__()
+        self.student = student
+        self.teacher = teacher
+        self.teacher = self._detach_teacher(self.teacher)
+        self.ema_alpha = ema_alpha
+        self.mode = mode
+    def forward(self, data, student = True, teacher = True, mask = False, range_batches_to_mask = None, **kwargs):
+        ret = []
+        if student:
+            if self.mode == 'train':
+                ret.append(self.student(data, mask = mask, range_batches_to_mask = range_batches_to_mask, **kwargs))
+            elif self.mode == 'val':
+                ret.append(self.student(data, mask = False, **kwargs))
+            else:
+                raise Exception('Mode not supported')
+        if teacher:
+            ret.append(self.teacher(data, mask = False, **kwargs)) # Loss is never computed for the teacher, but its results are returned as if a loss were included
+        return ret
+    def _detach_teacher(self, model):
+        for param in model.parameters():
+            param.detach_()
+        return model
+    def update_teacher_models(self, global_step):
+        alpha = min(1 - 1 / (global_step + 1), self.ema_alpha)
+        for ema_param, param in zip(self.teacher.parameters(), self.student.parameters()):
+            ema_param.data.mul_(alpha).add_(1 - alpha, param.data)
+        return
+    def copy_student_to_teacher(self):
+        for ema_param, param in zip(self.teacher.parameters(), self.student.parameters()):
+            ema_param.data.mul_(0).add_(param.data)
+        return
+    def get_params(self):
+        student_params = self.student.get_params()
+        teacher_params = self.teacher.get_params()
+        return student_params
+
+
+def preprocess_data(rgb, depth, dataset_settings):
+    # RGB: np.array, RGB
+    # Depth: np.array, min-max normalized, *255
+    preprocess = RGBDValPre(C.pytorch_mean, C.pytorch_std, dataset_settings)
+    rgb, depth = preprocess(rgb, depth)
+    if rgb is not None:
+        rgb = torch.from_numpy(np.ascontiguousarray(rgb)).float()
+    if depth is not None:
+        depth = torch.from_numpy(np.ascontiguousarray(depth)).float()
+    return rgb, depth
+
+
+def visualize(colors, pred, num_classes, dataset_settings):
+    pred = pred.transpose(1, 2, 0)
+    predvis = np.zeros((dataset_settings['orig_height'], dataset_settings['orig_width'], 3))
+    for i in range(num_classes):
+        color = colors[i]
+        predvis = np.where(pred == i, color, predvis)
+    predvis /= 255.0
+    predvis = predvis[:, :, ::-1]
+    return predvis
+
+def predict(rgb, depth, check):
+    dataset_settings = {}
+    dataset_settings['image_height'], dataset_settings['image_width'] = 540, 540
+    dataset_settings['orig_height'], dataset_settings['orig_width'] = 540, 540
+
+    rgb, depth = preprocess_data(rgb, depth, dataset_settings)
+    if rgb is not None:
+        rgb = rgb.unsqueeze(dim = 0)
+    if depth is not None:
+        depth = depth.unsqueeze(dim = 0)
+    ret = [None, None, './classcolors.png']
+    if "Mean Teacher" in check:
+        if rgb is None:
+            rgb = torch.zeros_like(depth)
+        if depth is None:
+            depth = torch.zeros_like(rgb)
+        scores = mtmodel([rgb, depth])[2]
+        scores = torch.nn.functional.interpolate(scores, size = (dataset_settings["orig_height"], dataset_settings["orig_width"]), mode = 'bilinear', align_corners = True)
+        prob = scores.detach()
+        _, pred = torch.max(prob, dim=1)
+        pred = pred.numpy()
+        predvis = visualize(colors, pred, num_classes=13, dataset_settings=dataset_settings)
+        ret[0] = predvis
+    if "M3L" in check:
+        mask = False
+        masking_branch = None
+        if rgb is None:
+            mask = True
+            masking_branch = 0
+        if depth is None:
+            mask = True
+            masking_branch = 1
+        scores = m3lmodel([rgb, depth], mask = mask, masking_branch = masking_branch)[2]
+        scores = torch.nn.functional.interpolate(scores, size = (dataset_settings["orig_height"], dataset_settings["orig_width"]), mode = 'bilinear', align_corners = True)
+        prob = scores.detach()
+        _, pred = torch.max(prob, dim=1)
+        pred = pred.numpy()
+        predvis = visualize(colors, pred, num_classes=13, dataset_settings=dataset_settings)
+        ret[1] = predvis
+
+    return ret
+
+imgs = os.listdir('./examples/rgb')
+random.shuffle(imgs)
+examples = []
+for img in imgs:
+    examples.append([
+        './examples/rgb/'+img, './examples/depth/'+img, ["M3L", "Mean Teacher"]
+    ])
+
+with gr.Blocks() as demo:
+    with gr.Row():
+        rgbinput = gr.Image(label="RGB Input").style(height=256, width=256)
+        depthinput = gr.Image(label="Depth Input").style(height=256, width=256)
+    with gr.Row():
+        modelcheck = gr.CheckboxGroup(["Mean Teacher", "M3L"], label="Predictions from", info="Predict using model trained with:")
+    with gr.Row():
+        submit_btn = gr.Button("Submit")
+    with gr.Row():
+        mtoutput = gr.Image(label="Mean Teacher Output").style(height=384, width=384)
+        m3loutput = gr.Image(label="M3L Output").style(height=384, width=384)
+        classnameoutput = gr.Image(label="Classes").style(height=384, width=384)
+    with gr.Row():
+        examplesRow = gr.Examples(examples=examples, examples_per_page=10, inputs=[rgbinput, depthinput, modelcheck])
+    submit_btn.click(fn = predict, inputs = [rgbinput, depthinput, modelcheck], outputs = [mtoutput, m3loutput, classnameoutput])
+
+demo.launch()
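A quick way to sanity-check the callback above outside Gradio is to call predict directly on one of the bundled example pairs. This is only a sketch: it assumes the Space's models/ and utils/ packages, the checkpoints, and the example images are available locally, and it approximates the 8-bit RGB arrays that gr.Image hands the callback.

# Hypothetical local smoke test for predict() above (not part of the commit).
import numpy as np
from PIL import Image

name = 'camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_0_domain.png'
rgb = np.array(Image.open('./examples/rgb/' + name).convert('RGB'))      # HxWx3 uint8, as Gradio would pass it
depth = np.array(Image.open('./examples/depth/' + name).convert('RGB'))  # depth PNG rendered to 3 channels

mt_vis, m3l_vis, legend = predict(rgb, depth, ["Mean Teacher", "M3L"])
print(mt_vis.shape, m3l_vis.shape, legend)  # (540, 540, 3) (540, 540, 3) ./classcolors.png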
arial.ttf ADDED
Binary file (289 kB). View file
 
checkpoints/sid_1-500_m3lteacher.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d5a23d7e2697b44b18e368e01353c328b13055a05a1cb0946ffb95b692d6facd
+size 99192724
checkpoints/sid_1-500_mtteacher.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7eb24d6275e15376c40ec526281d550e4842ffa71aaa7af58fea54cbf56c2eeb
+size 99186911
classcolors.png ADDED

Git LFS Details

  • SHA256: 54ab7bebebd4982336252535488ad47d42751c06652eb4c47fb0af47c7880aba
  • Pointer size: 130 Bytes
  • Size of remote file: 38.4 kB
colors.pkl ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3c03050cb753b0802781f0ce92893ac22129c15724dd4ece6f4b9b4a352db591
+size 2342
datasets/__init__.py ADDED
File without changes
datasets/__pycache__/__init__.cpython-36.pyc ADDED
Binary file (192 Bytes). View file
 
datasets/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (165 Bytes). View file
 
datasets/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (200 Bytes). View file
 
datasets/__pycache__/base_dataset.cpython-36.pyc ADDED
Binary file (4.77 kB). View file
 
datasets/__pycache__/base_dataset.cpython-38.pyc ADDED
Binary file (4.75 kB). View file
 
datasets/__pycache__/base_dataset.cpython-39.pyc ADDED
Binary file (4.78 kB). View file
 
datasets/__pycache__/citysundepth.cpython-36.pyc ADDED
Binary file (1.65 kB). View file
 
datasets/__pycache__/citysundepth.cpython-39.pyc ADDED
Binary file (1.65 kB). View file
 
datasets/__pycache__/citysunrgb.cpython-36.pyc ADDED
Binary file (2.09 kB). View file
 
datasets/__pycache__/citysunrgb.cpython-38.pyc ADDED
Binary file (2.04 kB). View file
 
datasets/__pycache__/citysunrgb.cpython-39.pyc ADDED
Binary file (2.15 kB). View file
 
datasets/__pycache__/citysunrgbd.cpython-36.pyc ADDED
Binary file (2.01 kB). View file
 
datasets/__pycache__/citysunrgbd.cpython-38.pyc ADDED
Binary file (1.96 kB). View file
 
datasets/__pycache__/get_dataset.cpython-36.pyc ADDED
Binary file (5.37 kB). View file
 
datasets/__pycache__/get_dataset.cpython-39.pyc ADDED
Binary file (5.32 kB). View file
 
datasets/__pycache__/preprocessors.cpython-36.pyc ADDED
Binary file (5.7 kB). View file
 
datasets/__pycache__/preprocessors.cpython-38.pyc ADDED
Binary file (5.3 kB). View file
 
datasets/__pycache__/tfnyu.cpython-36.pyc ADDED
Binary file (2.04 kB). View file
 
datasets/base_dataset.py ADDED
@@ -0,0 +1,128 @@
+import torch
+import torch.utils.data as data
+import numpy as np
+import cv2
+from PIL import Image
+from utils.img_utils import pad_image_to_shape
+
+class BaseDataset(data.Dataset):
+
+    def __init__(self, dataset_settings, mode, unsupervised):
+        self._mode = mode
+        self.unsupervised = unsupervised
+        self._rgb_path = dataset_settings['rgb_root']
+        self._depth_path = dataset_settings['depth_root']
+        self._gt_path = dataset_settings['gt_root']
+        self._train_source = dataset_settings['train_source']
+        self._eval_source = dataset_settings['eval_source']
+        self.modalities = dataset_settings['modalities']
+        # self._file_length = dataset_settings['max_samples']
+        self._required_length = dataset_settings['required_length']
+        self._file_names = self._get_file_names(mode)
+        self.model_input_shape = (dataset_settings['image_height'], dataset_settings['image_width'])
+
+    def __len__(self):
+        if self._required_length is not None:
+            return self._required_length
+        return len(self._file_names) # when mode == "val"
+
+    def _get_file_names(self, mode):
+        assert mode in ['train', 'val']
+        source = self._train_source
+        if mode == "val":
+            source = self._eval_source
+
+        file_names = []
+        with open(source) as f:
+            files = f.readlines()
+
+        for item in files:
+            names = self._process_item_names(item)
+            file_names.append(names)
+
+        if mode == "val":
+            return file_names
+        elif self._required_length <= len(file_names):
+            return file_names[:self._required_length]
+        else:
+            return self._construct_new_file_names(file_names, self._required_length)
+
+    def _construct_new_file_names(self, file_names, length):
+        assert isinstance(length, int)
+        files_len = len(file_names)
+
+        new_file_names = file_names * (length // files_len) # length % files_len items remaining
+
+        rand_indices = torch.randperm(files_len).tolist()
+        new_indices = rand_indices[:length % files_len]
+
+        new_file_names += [file_names[i] for i in new_indices]
+
+        return new_file_names
+
+    def _process_item_names(self, item):
+        item = item.strip()
+        item = item.split('\t')
+        num_modalities = len(self.modalities)
+        num_items = len(item)
+        names = {}
+        if not self.unsupervised:
+            assert num_modalities + 1 == num_items, f"Number of modalities and number of items in file name don't match, len(modalities) = {num_modalities} and len(item) = {num_items}" + item[0]
+            for i, modality in enumerate(self.modalities):
+                names[modality] = item[i]
+            names['gt'] = item[-1]
+        else:
+            assert num_modalities == num_items, f"Number of modalities and number of items in file name don't match, len(modalities) = {num_modalities} and len(item) = {num_items}"
+            for i, modality in enumerate(self.modalities):
+                names[modality] = item[i]
+            names['gt'] = None
+
+        return names
+
+    def _open_rgb(self, rgb_path, dtype = None):
+        bgr = cv2.imread(rgb_path, cv2.IMREAD_COLOR) # cv2 reads in BGR format, HxWxC
+        rgb = np.array(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB), dtype=dtype) # Pretrained PyTorch model accepts images in RGB
+        return rgb
+
+    def _open_depth(self, depth_path, dtype = None): # returns HxWx3 with the same image in all channels
+        img_arr = np.array(Image.open(depth_path))
+        if len(img_arr.shape) == 2: # grayscale
+            img_arr = np.array(np.tile(img_arr, [3, 1, 1]).transpose(1, 2, 0), dtype = dtype)
+        img_arr = (img_arr - img_arr.min()) * 255.0 / (img_arr.max() - img_arr.min())
+        return img_arr
+
+    def _open_depth_tf_nyu(self, depth_path, dtype = None): # returns HxWx3 with the same image in all channels
+        img_arr = np.array(Image.open(depth_path))
+        if len(img_arr.shape) == 2: # grayscale
+            img_arr = np.tile(img_arr, [3, 1, 1]).transpose(1, 2, 0)
+        return img_arr
+
+    def _open_gt(self, gt_path, dtype = None):
+        return np.array(cv2.imread(gt_path, cv2.IMREAD_GRAYSCALE), dtype=dtype)
+
+    def slide_over_image(self, img, crop_size, stride_rate):
+        H, W, C = img.shape
+        long_size = H if H > W else W
+        output = []
+        if long_size <= min(crop_size[0], crop_size[1]):
+            raise Exception("Crop size is greater than the image size itself. Not handled right now")
+
+        else:
+            stride_0 = int(np.ceil(crop_size[0] * stride_rate))
+            stride_1 = int(np.ceil(crop_size[1] * stride_rate))
+            r_grid = int(np.ceil((H - crop_size[0]) / stride_0)) + 1
+            c_grid = int(np.ceil((W - crop_size[1]) / stride_1)) + 1
+
+            for grid_yidx in range(r_grid):
+                for grid_xidx in range(c_grid):
+                    s_x = grid_xidx * stride_1
+                    s_y = grid_yidx * stride_0
+                    e_x = min(s_x + crop_size[1], W)
+                    e_y = min(s_y + crop_size[0], H)
+                    s_x = e_x - crop_size[1]
+                    s_y = e_y - crop_size[0]
+                    img_sub = img[s_y:e_y, s_x: e_x, :]
+                    img_sub, margin = pad_image_to_shape(img_sub, crop_size, cv2.BORDER_CONSTANT, value=0)
+                    output.append((img_sub, np.array([s_y, e_y, s_x, e_x]), margin))
+
+        return output
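For intuition about the window grid that slide_over_image builds, here is a small standalone illustration of the column arithmetic with made-up numbers (a 540x960 image, 540x540 crops, stride_rate = 2/3); the last window is clamped to the image edge exactly as in the loop above.

import numpy as np

H, W = 540, 960                                        # hypothetical image size
crop = (540, 540)                                      # crop_size
stride_rate = 2 / 3
stride_1 = int(np.ceil(crop[1] * stride_rate))         # 360
c_grid = int(np.ceil((W - crop[1]) / stride_1)) + 1    # 3 columns of crops
windows = []
for grid_xidx in range(c_grid):
    e_x = min(grid_xidx * stride_1 + crop[1], W)       # clamp to the right edge
    windows.append((e_x - crop[1], e_x))
print(windows)                                         # [(0, 540), (360, 900), (420, 960)]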
datasets/citysunrgbd.py ADDED
@@ -0,0 +1,67 @@
+import torch
+import numpy as np
+
+from datasets.base_dataset import BaseDataset
+
+
+class CityScapesSunRGBD(BaseDataset):
+
+    def __init__(self, dataset_settings, mode, unsupervised, preprocess, sliding = False, stride_rate = None):
+        super(CityScapesSunRGBD, self).__init__(dataset_settings, mode, unsupervised)
+        self.preprocess = preprocess
+        self.sliding = sliding
+        self.stride_rate = stride_rate
+        if self.sliding and self._mode == 'train':
+            print("Ensure correct preprocessing is being done!")
+
+    def __getitem__(self, index):
+        # if self._file_length is not None:
+        #     names = self._construct_new_file_names(self._file_length)[index]
+        # else:
+        #     names = self._file_names[index]
+        names = self._file_names[index]
+        rgb_path = self._rgb_path+names['rgb']
+        depth_path = self._rgb_path+names['depth']
+        if not self.unsupervised:
+            gt_path = self._gt_path+names['gt']
+        item_name = names['rgb'].split("/")[-1].split(".")[0]
+
+        rgb = self._open_rgb(rgb_path)
+        depth = self._open_depth(depth_path)
+        gt = None
+        if not self.unsupervised:
+            gt = self._open_gt(gt_path)
+
+        if not self.sliding:
+            if self.preprocess is not None:
+                rgb, depth, gt = self.preprocess(rgb, depth, gt)
+
+            if self._mode in ['train', 'val']:
+                rgb = torch.from_numpy(np.ascontiguousarray(rgb)).float()
+                depth = torch.from_numpy(np.ascontiguousarray(depth)).float()
+                if gt is not None:
+                    gt = torch.from_numpy(np.ascontiguousarray(gt)).long()
+            else:
+                raise Exception(f"{self._mode} not supported in CityScapesSunRGBD")
+
+            # output_dict = dict(rgb=rgb, fn=str(item_name),
+            #                    n=len(self._file_names))
+            output_dict = dict(data=[rgb, depth], name = item_name)
+            if gt is not None:
+                output_dict['gt'] = gt
+            return output_dict
+
+        else:
+            sliding_output = self.slide_over_image(rgb, self.model_input_shape, self.stride_rate)
+            output_dict = {}
+            if self._mode in ['train', 'val']:
+                if gt is not None:
+                    gt = torch.from_numpy(np.ascontiguousarray(gt)).long()
+                    output_dict['gt'] = gt
+                output_dict['sliding_output'] = []
+                for img_sub, pos, margin in sliding_output:
+                    if self.preprocess is not None:
+                        img_sub, _ = self.preprocess(img_sub, None)
+                    img_sub = torch.from_numpy(np.ascontiguousarray(img_sub)).float()
+                    output_dict['sliding_output'].append(([img_sub], pos, margin))
+            return output_dict
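The dataset indexes tab-separated split files through _process_item_names (defined in base_dataset.py above); in supervised mode a line is expected to carry one path per modality plus the ground-truth path. A hypothetical line and its parsed form, mirroring that logic (the paths are illustrative, not the Space's real split files):

line = "office_7/rgb/frame_0.png\toffice_7/depth/frame_0.png\toffice_7/gt/frame_0.png\n"
modalities = ['rgb', 'depth']
item = line.strip().split('\t')
assert len(modalities) + 1 == len(item)                       # modalities + gt
names = {modality: item[i] for i, modality in enumerate(modalities)}
names['gt'] = item[-1]
print(names)  # {'rgb': 'office_7/rgb/frame_0.png', 'depth': 'office_7/depth/frame_0.png', 'gt': 'office_7/gt/frame_0.png'}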
datasets/get_dataset.py ADDED
@@ -0,0 +1,146 @@
+import torch
+from torch.utils.data import DataLoader
+from datasets.citysundepth import CityScapesSunDepth
+from datasets.citysunrgb import CityScapesSunRGB
+from datasets.citysunrgbd import CityScapesSunRGBD
+from datasets.preprocessors import DepthTrainPre, DepthValPre, NYURGBDTrainPre, NYURGBDValPre, RGBDTrainPre, RGBDValPre, RGBTrainPre, RGBValPre
+from datasets.tfnyu import TFNYU
+from utils.constants import Constants as C
+
+def get_dataset(args):
+    datasetClass = None
+    if args.data == "nyudv2":
+        return TFNYU
+    if args.data == "city" or args.data == "sunrgbd" or args.data == 'stanford_indoor':
+        if len(args.modalities) == 1 and args.modalities[0] == 'rgb':
+            datasetClass = CityScapesSunRGB
+        elif len(args.modalities) == 1 and args.modalities[0] == 'depth':
+            datasetClass = CityScapesSunDepth
+        elif len(args.modalities) == 2 and args.modalities[0] == 'rgb' and args.modalities[1] == 'depth':
+            datasetClass = CityScapesSunRGBD
+        else:
+            raise Exception(f"{args.modalities} not configured in get_dataset function.")
+    else:
+        raise Exception(f"{args.data} not configured in get_dataset function.")
+    return datasetClass
+
+def get_preprocessors(args, dataset_settings, mode):
+    if args.data == "nyudv2" and len(args.modalities) == 2 and args.modalities[0] == 'rgb' and args.modalities[1] == 'depth':
+        if mode == 'train':
+            return NYURGBDTrainPre(C.pytorch_mean, C.pytorch_std, dataset_settings)
+        elif mode == 'val':
+            return NYURGBDValPre(C.pytorch_mean, C.pytorch_std, dataset_settings)
+
+    if len(args.modalities) == 1 and args.modalities[0] == 'rgb':
+        if mode == 'train':
+            return RGBTrainPre(C.pytorch_mean, C.pytorch_std, dataset_settings)
+        elif mode == 'val':
+            return RGBValPre(C.pytorch_mean, C.pytorch_std, dataset_settings)
+        else:
+            raise Exception("%s mode not defined" % mode)
+    elif len(args.modalities) == 1 and args.modalities[0] == 'depth':
+        if mode == 'train':
+            return DepthTrainPre(dataset_settings)
+        elif mode == 'val':
+            return DepthValPre(dataset_settings)
+        else:
+            raise Exception("%s mode not defined" % mode)
+    elif len(args.modalities) == 2 and args.modalities[0] == 'rgb' and args.modalities[1] == 'depth':
+        if mode == 'train':
+            return RGBDTrainPre(C.pytorch_mean, C.pytorch_std, dataset_settings)
+        elif mode == 'val':
+            return RGBDValPre(C.pytorch_mean, C.pytorch_std, dataset_settings)
+        else:
+            raise Exception("%s mode not defined" % mode)
+    else:
+        raise Exception("%s not configured for preprocessing" % args.modalities)
+
+def get_train_loader(datasetClass, args, train_source, unsupervised = False):
+    dataset_settings = {'rgb_root': args.rgb_root,
+                        'gt_root': args.gt_root,
+                        'depth_root': args.depth_root,
+                        'train_source': train_source,
+                        'eval_source': args.eval_source,
+                        'required_length': args.total_train_imgs, # Every dataloader will have Total Train Images / batch size iterations to be consistent
+                        # 'max_samples': args.max_samples,
+                        'train_scale_array': args.train_scale_array,
+                        'image_height': args.image_height,
+                        'image_width': args.image_width,
+                        'modalities': args.modalities}
+
+    preprocessing = get_preprocessors(args, dataset_settings, "train")
+    train_dataset = datasetClass(dataset_settings, "train", unsupervised, preprocessing)
+    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas = args.world_size, rank = args.rank)
+    if unsupervised and "unsup_batch_size" in args:
+        batch_size = args.unsup_batch_size
+    else:
+        batch_size = args.batch_size
+    train_loader = DataLoader(train_dataset,
+                              batch_size = batch_size // args.world_size,
+                              num_workers = args.num_workers,
+                              drop_last = True,
+                              shuffle = False,
+                              sampler = train_sampler)
+    return train_loader
+
+def get_val_loader(datasetClass, args):
+    dataset_settings = {'rgb_root': args.rgb_root,
+                        'gt_root': args.gt_root,
+                        'depth_root': args.depth_root,
+                        'train_source': None,
+                        'eval_source': args.eval_source,
+                        'required_length': None,
+                        'max_samples': None,
+                        'train_scale_array': args.train_scale_array,
+                        'image_height': args.image_height,
+                        'image_width': args.image_width,
+                        'modalities': args.modalities}
+    if args.data == 'sunrgbd':
+        eval_sources = []
+        for shape in ['427_561', '441_591', '530_730', '531_681']:
+            eval_sources.append(dataset_settings['eval_source'].split('.')[0] + '_' + shape + '.txt')
+    else:
+        eval_sources = [args.eval_source]
+
+    preprocessing = get_preprocessors(args, dataset_settings, "val")
+    if args.sliding_eval:
+        collate_fn = _sliding_collate_fn
+    else:
+        collate_fn = None
+
+    val_loaders = []
+    for eval_source in eval_sources:
+        dataset_settings['eval_source'] = eval_source
+        val_dataset = datasetClass(dataset_settings, "val", False, preprocessing, args.sliding_eval, args.stride_rate)
+        if args.rank is not None: # DDP evaluation
+            val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, num_replicas = args.world_size, rank = args.rank)
+            batch_size = args.val_batch_size // args.world_size
+        else: # DP evaluation
+            val_sampler = None
+            batch_size = args.val_batch_size
+
+        val_loader = DataLoader(val_dataset,
+                                batch_size = batch_size,
+                                num_workers = 4,
+                                drop_last = False,
+                                shuffle = False,
+                                collate_fn = collate_fn,
+                                sampler = val_sampler)
+        val_loaders.append(val_loader)
+    return val_loaders
+
+
+def _sliding_collate_fn(batch):
+    gt = torch.stack([b['gt'] for b in batch])
+    sliding_output = []
+    num_modalities = len(batch[0]['sliding_output'][0][0])
+    for i in range(len(batch[0]['sliding_output'])): # i iterates over positions
+        imgs = [torch.stack([b['sliding_output'][i][0][m] for b in batch]) for m in range(num_modalities)]
+        pos = batch[0]['sliding_output'][i][1]
+        pos_compare = [(b['sliding_output'][i][1] == pos).all() for b in batch]
+        assert all(pos_compare), f"Position not same for all points in the batch: {pos_compare}, {[b['sliding_output'][i][1] for b in batch]}"
+        margin = batch[0]['sliding_output'][i][2]
+        margin_compare = [(b['sliding_output'][i][2] == margin).all() for b in batch]
+        assert all(margin_compare), f"Margin not same for all points in the batch: {margin_compare}, {[b['sliding_output'][i][2] for b in batch]}"
+        sliding_output.append((imgs, pos, margin))
+    return {"gt": gt, "sliding_output": sliding_output}
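The loader helpers expect an argparse-style namespace; the fields below are the ones they actually read. Everything in this sketch is a placeholder (paths, sizes, and split files are illustrative), and it assumes the repository's datasets/ and utils/ packages are importable.

import argparse

args = argparse.Namespace(
    data='stanford_indoor', modalities=['rgb', 'depth'],
    rgb_root='./data/', depth_root='./data/', gt_root='./data/gt/',
    eval_source='./splits/val.txt', total_train_imgs=500,
    train_scale_array=[0.75, 1.0, 1.25, 1.5], image_height=540, image_width=540,
    batch_size=8, num_workers=4, world_size=1, rank=0,
    val_batch_size=1, sliding_eval=False, stride_rate=None,
)

datasetClass = get_dataset(args)        # resolves to CityScapesSunRGBD for rgb+depth
train_loader = get_train_loader(datasetClass, args, train_source='./splits/train_1-500.txt')
val_loaders = get_val_loader(datasetClass, args)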
datasets/preprocessors.py ADDED
@@ -0,0 +1,144 @@
+from utils.img_utils import normalizedepth, random_crop_pad_to_shape, random_mirror, random_scale, normalize, resizedepth, resizergb, tfnyu_normalizedepth
+
+class RGBTrainPre(object):
+    def __init__(self, pytorch_mean, pytorch_std, dataset_settings):
+        self.pytorch_mean = pytorch_mean
+        self.pytorch_std = pytorch_std
+        self.train_scale_array = dataset_settings['train_scale_array']
+        self.crop_size = (dataset_settings['image_height'], dataset_settings['image_width'])
+
+    def __call__(self, rgb, gt):
+        transformed_dict = random_mirror({"rgb":rgb, "gt":gt})
+        if self.train_scale_array is not None:
+            transformed_dict, _ = random_scale(transformed_dict, self.train_scale_array, (rgb.shape[0], rgb.shape[1]))
+
+        transformed_dict, _ = random_crop_pad_to_shape(transformed_dict, transformed_dict['rgb'].shape[:2], self.crop_size) # Makes gt HxWx1
+        rgb = transformed_dict['rgb']
+        gt = transformed_dict['gt']
+        rgb = normalize(rgb, self.pytorch_mean, self.pytorch_std)
+
+        rgb = rgb.transpose(2, 0, 1) # Moves the channel dimension to the front. Final output = CxHxW
+        return rgb, gt
+
+
+class RGBValPre(object):
+    def __init__(self, pytorch_mean, pytorch_std, dataset_settings):
+        self.pytorch_mean = pytorch_mean
+        self.pytorch_std = pytorch_std
+        self.model_input_shape = (dataset_settings['image_height'], dataset_settings['image_width'])
+
+    def __call__(self, rgb, gt):
+        rgb = resizergb(rgb, self.model_input_shape)
+        rgb = normalize(rgb, self.pytorch_mean, self.pytorch_std)
+        rgb = rgb.transpose(2, 0, 1) # Moves the channel dimension to the front. Final output = CxHxW
+        return rgb, gt
+
+
+class RGBDTrainPre(object):
+    def __init__(self, pytorch_mean, pytorch_std, dataset_settings):
+        self.pytorch_mean = pytorch_mean
+        self.pytorch_std = pytorch_std
+        self.train_scale_array = dataset_settings['train_scale_array']
+        self.crop_size = (dataset_settings['image_height'], dataset_settings['image_width'])
+
+    def __call__(self, rgb, depth, gt):
+        transformed_dict = random_mirror({"rgb":rgb, "depth": depth, "gt":gt})
+        if self.train_scale_array is not None:
+            transformed_dict, _ = random_scale(transformed_dict, self.train_scale_array, (rgb.shape[0], rgb.shape[1]))
+
+        transformed_dict, _ = random_crop_pad_to_shape(transformed_dict, transformed_dict['rgb'].shape[:2], self.crop_size) # Makes gt HxWx1
+        rgb = transformed_dict['rgb']
+        depth = transformed_dict['depth']
+        gt = transformed_dict['gt']
+        rgb = normalize(rgb, self.pytorch_mean, self.pytorch_std)
+        depth = normalizedepth(depth)
+        rgb = rgb.transpose(2, 0, 1) # Moves the channel dimension to the front. Final output = CxHxW
+        depth = depth.transpose(2, 0, 1) # Moves the channel dimension to the front. Final output = CxHxW
+        return rgb, depth, gt
+
+
+class RGBDValPre(object):
+    def __init__(self, pytorch_mean, pytorch_std, dataset_settings):
+        self.pytorch_mean = pytorch_mean
+        self.pytorch_std = pytorch_std
+        self.model_input_shape = (dataset_settings['image_height'], dataset_settings['image_width'])
+
+    def __call__(self, rgb, depth):
+        if rgb is not None:
+            rgb = resizergb(rgb, self.model_input_shape)
+            rgb = normalize(rgb, self.pytorch_mean, self.pytorch_std)
+            rgb = rgb.transpose(2, 0, 1) # Moves the channel dimension to the front. Final output = CxHxW
+        if depth is not None:
+            depth = resizedepth(depth, self.model_input_shape)
+            depth = normalizedepth(depth)
+            depth = depth.transpose(2, 0, 1) # Moves the channel dimension to the front. Final output = CxHxW
+
+        return rgb, depth
+
+
+class NYURGBDTrainPre(object):
+    def __init__(self, pytorch_mean, pytorch_std, dataset_settings):
+        self.pytorch_mean = pytorch_mean
+        self.pytorch_std = pytorch_std
+        self.train_scale_array = dataset_settings['train_scale_array']
+        self.crop_size = (dataset_settings['image_height'], dataset_settings['image_width'])
+
+    def __call__(self, rgb, depth, gt):
+        transformed_dict = random_mirror({"rgb":rgb, "depth": depth, "gt":gt})
+        if self.train_scale_array is not None:
+            transformed_dict, _ = random_scale(transformed_dict, self.train_scale_array, (rgb.shape[0], rgb.shape[1]))
+
+        transformed_dict, _ = random_crop_pad_to_shape(transformed_dict, transformed_dict['rgb'].shape[:2], self.crop_size) # Makes gt HxWx1
+        rgb = transformed_dict['rgb']
+        depth = transformed_dict['depth']
+        gt = transformed_dict['gt']
+        rgb = normalize(rgb, self.pytorch_mean, self.pytorch_std)
+        depth = tfnyu_normalizedepth(depth)
+        rgb = rgb.transpose(2, 0, 1) # Moves the channel dimension to the front. Final output = CxHxW
+        depth = depth.transpose(2, 0, 1) # Moves the channel dimension to the front. Final output = CxHxW
+        return rgb, depth, gt
+
+
+class NYURGBDValPre(object):
+    def __init__(self, pytorch_mean, pytorch_std, dataset_settings):
+        self.pytorch_mean = pytorch_mean
+        self.pytorch_std = pytorch_std
+        self.model_input_shape = (dataset_settings['image_height'], dataset_settings['image_width'])
+
+    def __call__(self, rgb, depth, gt):
+        rgb = resizergb(rgb, self.model_input_shape)
+        depth = resizedepth(depth, self.model_input_shape)
+        rgb = normalize(rgb, self.pytorch_mean, self.pytorch_std)
+        depth = tfnyu_normalizedepth(depth)
+        rgb = rgb.transpose(2, 0, 1) # Moves the channel dimension to the front. Final output = CxHxW
+        depth = depth.transpose(2, 0, 1) # Moves the channel dimension to the front. Final output = CxHxW
+        return rgb, depth, gt
+
+
+class DepthTrainPre(object):
+    def __init__(self, dataset_settings):
+        self.train_scale_array = dataset_settings['train_scale_array']
+        self.crop_size = (dataset_settings['image_height'], dataset_settings['image_width'])
+
+    def __call__(self, depth, gt):
+        transformed_dict = random_mirror({"depth": depth, "gt":gt})
+        if self.train_scale_array is not None:
+            transformed_dict, _ = random_scale(transformed_dict, self.train_scale_array, (depth.shape[0], depth.shape[1]))
+
+        transformed_dict, _ = random_crop_pad_to_shape(transformed_dict, transformed_dict['depth'].shape[:2], self.crop_size) # Makes gt HxWx1
+        depth = transformed_dict['depth']
+        gt = transformed_dict['gt']
+        depth = normalizedepth(depth)
+        depth = depth.transpose(2, 0, 1) # Moves the channel dimension to the front. Final output = CxHxW
+        return depth, gt
+
+
+class DepthValPre(object):
+    def __init__(self, dataset_settings):
+        self.model_input_shape = (dataset_settings['image_height'], dataset_settings['image_width'])
+
+    def __call__(self, depth, gt):
+        depth = resizedepth(depth, self.model_input_shape)
+        depth = normalizedepth(depth)
+        depth = depth.transpose(2, 0, 1) # Moves the channel dimension to the front. Final output = CxHxW
+        return depth, gt
examples/.DS_Store ADDED
Binary file (6.15 kB). View file
 
examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_45_domain.png ADDED

Git LFS Details

  • SHA256: 55885dc912a9accfa7bc492e065b323b1a52f82a5117b41fc8efd4a2b12adaf9
  • Pointer size: 130 Bytes
  • Size of remote file: 83.6 kB
examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_46_domain.png ADDED

Git LFS Details

  • SHA256: 2bf577c404957a98fedd04fb458c4361e181b735c1eeef60939c643b0c3a60a3
  • Pointer size: 130 Bytes
  • Size of remote file: 64.6 kB
examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_47_domain.png ADDED

Git LFS Details

  • SHA256: 96cf8016c66212f327fbda0163b8b03a2daf3c7e6c28ae02faf83bd2e3f11cdb
  • Pointer size: 131 Bytes
  • Size of remote file: 233 kB
examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_48_domain.png ADDED

Git LFS Details

  • SHA256: d11dbca1fa89c97e1eaf8302dd37caf49cbe94f9fd6d8331c61c29d2f786c4b3
  • Pointer size: 130 Bytes
  • Size of remote file: 77.4 kB
examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_49_domain.png ADDED

Git LFS Details

  • SHA256: a56e0ea6823606d15bf04c2a642be8df5485883f5be2823eb020d7ba6f2802f1
  • Pointer size: 131 Bytes
  • Size of remote file: 276 kB
examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_50_domain.png ADDED

Git LFS Details

  • SHA256: 98b9da600b500edd19c89cd0ca0b454a49aa964928d88a7836b3abcf04425cb6
  • Pointer size: 131 Bytes
  • Size of remote file: 194 kB
examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_51_domain.png ADDED

Git LFS Details

  • SHA256: 36db03a470d9826b6dfdd49a804c3d23434ab88d231841c2afb74ef33e69c4a2
  • Pointer size: 131 Bytes
  • Size of remote file: 169 kB
examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_52_domain.png ADDED

Git LFS Details

  • SHA256: 411bc3b1f22ddc361cf8c89cec596f12818e930a9cb3d9e39bfb509c1b9e46f4
  • Pointer size: 131 Bytes
  • Size of remote file: 193 kB
examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_53_domain.png ADDED

Git LFS Details

  • SHA256: 0004251bc5fc96e0df7eeb677444aa50b8d8d569175c408ce48b71792c32b64d
  • Pointer size: 131 Bytes
  • Size of remote file: 256 kB
examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_0_domain.png ADDED

Git LFS Details

  • SHA256: 9aabf2840f638a9c52ae8151cb3677c8786e2350574ecf21c59f3fd44fe9bf91
  • Pointer size: 131 Bytes
  • Size of remote file: 171 kB
examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_10_domain.png ADDED

Git LFS Details

  • SHA256: 3fda21d17a465af9cb332f07fd8fc9bb255a03a48bd4b661a5ca789784dbcad7
  • Pointer size: 131 Bytes
  • Size of remote file: 273 kB
examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_1_domain.png ADDED

Git LFS Details

  • SHA256: 2ec6050af0b91ab3fd7784deb0e005c89eec40a40e51a9936ca499bbb9775fa3
  • Pointer size: 131 Bytes
  • Size of remote file: 251 kB
examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_2_domain.png ADDED

Git LFS Details

  • SHA256: b3800cdc565340cba8097902bf374506b898ef2f59864697a89ea2817e714be7
  • Pointer size: 131 Bytes
  • Size of remote file: 106 kB
examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_3_domain.png ADDED

Git LFS Details

  • SHA256: cc28a5d5e9003fe651e5a4da7683512b8d4f9da0b447fa024dfd9ffd9f8ff6c2
  • Pointer size: 130 Bytes
  • Size of remote file: 96.3 kB
examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_4_domain.png ADDED

Git LFS Details

  • SHA256: 8b25e48bd4afdb9e9d60e1637065caa4ec6a3f56bbc5bd3be6255d32d63d724b
  • Pointer size: 131 Bytes
  • Size of remote file: 150 kB
examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_5_domain.png ADDED

Git LFS Details

  • SHA256: abef7219c2a2aca39183172aff749593a6b10a7cc6f7d0c6f47b3a0eb2186832
  • Pointer size: 131 Bytes
  • Size of remote file: 123 kB
examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_6_domain.png ADDED

Git LFS Details

  • SHA256: 7346e13826370bb01b2aee6ee4396d922dad43af5b714f9d0993faf6ed3a340a
  • Pointer size: 130 Bytes
  • Size of remote file: 82 kB