zxhuang1698 commited on
Commit
414b431
1 Parent(s): 1a618eb

initial commit

Browse files
README.md CHANGED
@@ -1,13 +1,12 @@
1
  ---
2
  title: ZeroShape
3
- emoji: 📚
4
  colorFrom: green
5
  colorTo: blue
6
  sdk: gradio
7
- sdk_version: 4.14.0
 
8
  app_file: app.py
9
- pinned: false
10
  license: mit
11
- ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
  title: ZeroShape
3
+ emoji: 🔥
4
  colorFrom: green
5
  colorTo: blue
6
  sdk: gradio
7
+ sdk_version: 4.5.0
8
+ python_version: 3.10
9
  app_file: app.py
10
+ pinned: true
11
  license: mit
12
+ ---
 
 
app.py ADDED
@@ -0,0 +1,303 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torchvision.transforms.functional as torchvision_F
4
+ import numpy as np
5
+ import os
6
+ import shutil
7
+ import importlib
8
+ import trimesh
9
+ import tempfile
10
+ import subprocess
11
+ import utils.options as options
12
+ import shlex
13
+ import time
14
+ import rembg
15
+
16
+ from utils.util import EasyDict as edict
17
+ from PIL import Image
18
+ from utils.eval_3D import get_dense_3D_grid, compute_level_grid, convert_to_explicit
19
+
20
+ def get_1d_bounds(arr):
21
+ nz = np.flatnonzero(arr)
22
+ return nz[0], nz[-1]
23
+
24
+ def get_bbox_from_mask(mask, thr):
25
+ masks_for_box = (mask > thr).astype(np.float32)
26
+ assert masks_for_box.sum() > 0, "Empty mask!"
27
+ x0, x1 = get_1d_bounds(masks_for_box.sum(axis=-2))
28
+ y0, y1 = get_1d_bounds(masks_for_box.sum(axis=-1))
29
+
30
+ return x0, y0, x1, y1
31
+
32
+ def square_crop(image, bbox, crop_ratio=1.):
33
+ x1, y1, x2, y2 = bbox
34
+ h, w = y2-y1, x2-x1
35
+ yc, xc = (y1+y2)/2, (x1+x2)/2
36
+ S = max(h, w)*1.2
37
+ scale = S*crop_ratio
38
+ image = torchvision_F.crop(image, top=int(yc-scale/2), left=int(xc-scale/2), height=int(scale), width=int(scale))
39
+ return image
40
+
41
+ def preprocess_image(opt, image, bbox):
42
+ image = square_crop(image, bbox=bbox)
43
+ if image.size[0] != opt.W or image.size[1] != opt.H:
44
+ image = image.resize((opt.W, opt.H))
45
+ image = torchvision_F.to_tensor(image)
46
+ rgb, mask = image[:3], image[3:]
47
+ if opt.data.bgcolor is not None:
48
+ # replace background color using mask
49
+ rgb = rgb * mask + opt.data.bgcolor * (1 - mask)
50
+ mask = (mask > 0.5).float()
51
+ return rgb, mask
52
+
53
+ def get_image(opt, image_fname, mask_fname):
54
+ image = Image.open(image_fname).convert("RGB")
55
+ mask = Image.open(mask_fname).convert("L")
56
+ mask_np = np.array(mask)
57
+
58
+ #binarize
59
+ mask_np[mask_np <= 127] = 0
60
+ mask_np[mask_np >= 127] = 1.0
61
+
62
+ image = Image.merge("RGBA", (*image.split(), mask))
63
+ bbox = get_bbox_from_mask(mask_np, 0.5)
64
+ rgb_input_map, mask_input_map = preprocess_image(opt, image, bbox=bbox)
65
+ return rgb_input_map, mask_input_map
66
+
67
+ def get_intr(opt):
68
+ # load camera
69
+ f = 1.3875
70
+ K = torch.tensor([[f*opt.W, 0, opt.W/2],
71
+ [0, f*opt.H, opt.H/2],
72
+ [0, 0, 1]]).float()
73
+ return K
74
+
75
+ def get_pixel_grid(H, W, device='cuda'):
76
+ y_range = torch.arange(H, dtype=torch.float32).to(device)
77
+ x_range = torch.arange(W, dtype=torch.float32).to(device)
78
+ Y, X = torch.meshgrid(y_range, x_range, indexing='ij')
79
+ Z = torch.ones_like(Y).to(device)
80
+ xyz_grid = torch.stack([X, Y, Z],dim=-1).view(-1,3)
81
+ return xyz_grid
82
+
83
+ def unproj_depth(depth, intr):
84
+ '''
85
+ depth: [B, H, W]
86
+ intr: [B, 3, 3]
87
+ '''
88
+ batch_size, H, W = depth.shape
89
+ intr = intr.to(depth.device)
90
+
91
+ # [B, 3, 3]
92
+ K_inv = torch.linalg.inv(intr).float()
93
+ # [1, H*W,3]
94
+ pixel_grid = get_pixel_grid(H, W, depth.device).unsqueeze(0)
95
+ # [B, H*W,3]
96
+ pixel_grid = pixel_grid.repeat(batch_size, 1, 1)
97
+ # [B, 3, H*W]
98
+ ray_dirs = K_inv @ pixel_grid.permute(0, 2, 1).contiguous()
99
+ # [B, H*W, 3], in camera coordinates
100
+ seen_points = ray_dirs.permute(0, 2, 1).contiguous() * depth.view(batch_size, H*W, 1)
101
+ # [B, H, W, 3]
102
+ seen_points = seen_points.view(batch_size, H, W, 3)
103
+ return seen_points
104
+
105
+ def prepare_data(opt, image_path, mask_path):
106
+ var = edict()
107
+ rgb_input_map, mask_input_map = get_image(opt, image_path, mask_path)
108
+ intr = get_intr(opt)
109
+ var.rgb_input_map = rgb_input_map.unsqueeze(0).to(opt.device)
110
+ var.mask_input_map = mask_input_map.unsqueeze(0).to(opt.device)
111
+ var.intr = intr.unsqueeze(0).to(opt.device)
112
+ var.idx = torch.tensor([0]).to(opt.device).long()
113
+ var.pose_gt = False
114
+ return var
115
+
116
+ @torch.no_grad()
117
+ def marching_cubes(opt, var, impl_network, visualize_attn=False):
118
+ points_3D = get_dense_3D_grid(opt, var) # [B, N, N, N, 3]
119
+ level_vox, attn_vis = compute_level_grid(opt, impl_network, var.latent_depth, var.latent_semantic,
120
+ points_3D, var.rgb_input_map, visualize_attn)
121
+ if attn_vis: var.attn_vis = attn_vis
122
+ # occ_grids: a list of length B, each is [N, N, N]
123
+ *level_grids, = level_vox.cpu().numpy()
124
+ meshes = convert_to_explicit(opt, level_grids, isoval=0.5, to_pointcloud=False)
125
+ var.mesh_pred = meshes
126
+ return var
127
+
128
+ @torch.no_grad()
129
+ def infer_sample(opt, var, graph):
130
+ var = graph.forward(opt, var, training=False, get_loss=False)
131
+ var = marching_cubes(opt, var, graph.impl_network, visualize_attn=True)
132
+ return var.mesh_pred[0]
133
+
134
+ def infer(input_image_path, input_mask_path):
135
+ opt_cmd = options.parse_arguments(["--yaml=options/shape.yaml", "--datadir=examples", "--eval.vox_res=128", "--ckpt=weights/shape.ckpt"])
136
+ opt = options.set(opt_cmd=opt_cmd, safe_check=False)
137
+ opt.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
138
+
139
+ # build model
140
+ print("Building model...")
141
+ opt.pretrain.depth = None
142
+ opt.arch.depth.pretrained = None
143
+ module = importlib.import_module("model.compute_graph.graph_shape")
144
+ graph = module.Graph(opt).to(opt.device)
145
+
146
+ # download checkpoint
147
+ if not os.path.isfile(opt.ckpt):
148
+ print("Downloading checkpoint...")
149
+ subprocess.run(
150
+ shlex.split(
151
+ "wget -q -O weights/shape.ckpt https://www.dropbox.com/scl/fi/hv3w9z59dqytievwviko4/shape.ckpt?rlkey=a2gut89kavrldmnt8b3df92oi&dl=0"
152
+ )
153
+ )
154
+
155
+ # wait if the checkpoint is still downloading
156
+ while not os.path.isfile(opt.ckpt):
157
+ time.sleep(1)
158
+
159
+ # load checkpoint
160
+ print("Loading checkpoint...")
161
+ checkpoint = torch.load(opt.ckpt, map_location=torch.device(opt.device))
162
+ graph.load_state_dict(checkpoint["graph"], strict=True)
163
+ graph.eval()
164
+
165
+ # load the data
166
+ print("Loading data...")
167
+ var = prepare_data(opt, input_image_path, input_mask_path)
168
+
169
+ # create the save dir
170
+ save_folder = os.path.join(opt.datadir, 'preds')
171
+ if os.path.isdir(save_folder):
172
+ shutil.rmtree(save_folder)
173
+ os.makedirs(save_folder)
174
+ opt.output_path = opt.datadir
175
+
176
+ # inference the model and save the results
177
+ print("Inferencing...")
178
+ mesh_pred = infer_sample(opt, var, graph)
179
+ # rotate the mesh upside down
180
+ mesh_pred.apply_transform(trimesh.transformations.rotation_matrix(np.pi, [1, 0, 0]))
181
+ mesh_path = tempfile.NamedTemporaryFile(suffix=".glb", delete=False)
182
+ mesh_pred.export(mesh_path.name, file_type="glb")
183
+ return mesh_path.name
184
+
185
+ def infer_wrapper_mask(input_image_path, input_mask_path):
186
+ return infer(input_image_path, input_mask_path)
187
+
188
+ def infer_wrapper_nomask(input_image_path):
189
+ input = Image.open(input_image_path)
190
+ segmented = rembg.remove(input)
191
+ mask = segmented.split()[-1]
192
+ mask_path = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
193
+ mask.save(mask_path.name)
194
+ return infer(input_image_path, mask_path.name), mask_path.name
195
+
196
+
197
+ def assert_input_image(input_image):
198
+ if input_image is None:
199
+ raise gr.Error("No image selected or uploaded!")
200
+
201
+ def assert_mask_image(input_mask):
202
+ if input_mask is None:
203
+ raise gr.Error("No mask selected or uploaded! Please check the box if you do not have the mask.")
204
+
205
+ def demo_gradio():
206
+ with gr.Blocks(analytics_enabled=False) as demo_ui:
207
+
208
+ # HEADERS
209
+ with gr.Row():
210
+ with gr.Column(scale=1):
211
+ gr.Markdown('# ZeroShape: Regression-based Zero-shot Shape Reconstruction')
212
+ gr.Markdown("[\[Arxiv\]](https://arxiv.org/pdf/2312.14198.pdf) | [\[Project\]](https://zixuanh.com/projects/zeroshape.html) | [\[GitHub\]](https://github.com/zxhuang1698/ZeroShape)")
213
+ gr.Markdown("Please switch to the \"Estimated Mask\" tab if you do not have the foreground mask. The demo will try to estimate the mask for you.")
214
+
215
+ # with mask
216
+ with gr.Tab("Groundtruth Mask"):
217
+ with gr.Row():
218
+ input_image_tab1 = gr.Image(label="Input Image", image_mode="RGB", sources="upload", type="filepath", elem_id="content_image", width=300)
219
+ mask_tab1 = gr.Image(label="Foreground Mask", image_mode="RGB", sources="upload", type="filepath", elem_id="content_image", width=300)
220
+ output_mesh_tab1 = gr.Model3D(label="Output Mesh")
221
+ with gr.Row():
222
+ submit_tab1 = gr.Button('Reconstruct', elem_id="recon_button_tab1", variant='primary')
223
+ # examples
224
+ with gr.Row():
225
+ examples_tab1 = [
226
+ ['examples/images/armchair.png', 'examples/masks/armchair.png'],
227
+ ['examples/images/bolt.png', 'examples/masks/bolt.png'],
228
+ ['examples/images/bucket.png', 'examples/masks/bucket.png'],
229
+ ['examples/images/case.png', 'examples/masks/case.png'],
230
+ ['examples/images/dispenser.png', 'examples/masks/dispenser.png'],
231
+ ['examples/images/hat.png', 'examples/masks/hat.png'],
232
+ ['examples/images/teddy_bear.png', 'examples/masks/teddy_bear.png'],
233
+ ['examples/images/tiger.png', 'examples/masks/tiger.png'],
234
+ ['examples/images/toy.png', 'examples/masks/toy.png'],
235
+ ['examples/images/wedding_cake.png', 'examples/masks/wedding_cake.png'],
236
+ ]
237
+ gr.Examples(
238
+ examples=examples_tab1,
239
+ inputs=[input_image_tab1, mask_tab1],
240
+ outputs=[output_mesh_tab1],
241
+ fn=infer_wrapper_mask,
242
+ cache_examples=False#os.getenv('SYSTEM') == 'spaces',
243
+ )
244
+ # without mask
245
+ with gr.Tab("Estimated Mask"):
246
+ with gr.Row():
247
+ input_image_tab2 = gr.Image(label="Input Image", image_mode="RGB", sources="upload", type="filepath", elem_id="content_image", width=300)
248
+ mask_tab2 = gr.Image(label="Foreground Mask", image_mode="RGB", sources="upload", type="filepath", elem_id="content_image", width=300)
249
+ output_mesh_tab2 = gr.Model3D(label="Output Mesh")
250
+ with gr.Row():
251
+ submit_tab2 = gr.Button('Reconstruct', elem_id="recon_button_tab2", variant='primary')
252
+ # examples
253
+ with gr.Row():
254
+ examples_tab2 = [
255
+ ['examples/images/armchair.png'],
256
+ ['examples/images/bolt.png'],
257
+ ['examples/images/bucket.png'],
258
+ ['examples/images/case.png'],
259
+ ['examples/images/dispenser.png'],
260
+ ['examples/images/hat.png'],
261
+ ['examples/images/teddy_bear.png'],
262
+ ['examples/images/tiger.png'],
263
+ ['examples/images/toy.png'],
264
+ ['examples/images/wedding_cake.png'],
265
+ ]
266
+ gr.Examples(
267
+ examples=examples_tab2,
268
+ inputs=[input_image_tab2],
269
+ outputs=[output_mesh_tab2, mask_tab2],
270
+ fn=infer_wrapper_nomask,
271
+ cache_examples=False#os.getenv('SYSTEM') == 'spaces',
272
+ )
273
+
274
+ submit_tab1.click(
275
+ fn=assert_input_image,
276
+ inputs=[input_image_tab1],
277
+ queue=False
278
+ ).success(
279
+ fn=assert_mask_image,
280
+ inputs=[mask_tab1],
281
+ queue=False
282
+ ).success(
283
+ fn=infer_wrapper_mask,
284
+ inputs=[input_image_tab1, mask_tab1],
285
+ outputs=[output_mesh_tab1],
286
+ )
287
+
288
+ submit_tab2.click(
289
+ fn=assert_input_image,
290
+ inputs=[input_image_tab2],
291
+ queue=False
292
+ ).success(
293
+ fn=infer_wrapper_nomask,
294
+ inputs=[input_image_tab2],
295
+ outputs=[output_mesh_tab2, mask_tab2],
296
+ )
297
+
298
+ return demo_ui
299
+
300
+ if __name__ == "__main__":
301
+ demo_ui = demo_gradio()
302
+ demo_ui.queue(max_size=10)
303
+ demo_ui.launch()
examples/images/armchair.png ADDED
examples/images/bolt.png ADDED
examples/images/bucket.png ADDED
examples/images/case.png ADDED
examples/images/dispenser.png ADDED
examples/images/hat.png ADDED
examples/images/teddy_bear.png ADDED
examples/images/tiger.png ADDED
examples/images/toy.png ADDED
examples/images/wedding_cake.png ADDED
examples/masks/armchair.png ADDED
examples/masks/bolt.png ADDED
examples/masks/bucket.png ADDED
examples/masks/case.png ADDED
examples/masks/dispenser.png ADDED
examples/masks/hat.png ADDED
examples/masks/teddy_bear.png ADDED
examples/masks/tiger.png ADDED
examples/masks/toy.png ADDED
examples/masks/wedding_cake.png ADDED
model/compute_graph/graph_depth.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from utils.util import EasyDict as edict
5
+ from utils.loss import Loss
6
+ from model.depth.dpt_depth import DPTDepthModel
7
+ from utils.layers import Bottleneck_Conv
8
+ from utils.camera import unproj_depth, valid_norm_fac
9
+
10
+ class Graph(nn.Module):
11
+
12
+ def __init__(self, opt):
13
+ super().__init__()
14
+ # define the depth pred model based on omnidata
15
+ self.dpt_depth = DPTDepthModel(backbone='vitb_rn50_384')
16
+ if opt.arch.depth.pretrained is not None:
17
+ checkpoint = torch.load(opt.arch.depth.pretrained, map_location="cuda:{}".format(opt.device))
18
+ state_dict = checkpoint['model_state_dict']
19
+ self.dpt_depth.load_state_dict(state_dict)
20
+
21
+ if opt.loss_weight.intr is not None:
22
+ self.intr_feat_channels = 768
23
+ self.intr_head = nn.Sequential(
24
+ Bottleneck_Conv(self.intr_feat_channels, kernel_size=3),
25
+ Bottleneck_Conv(self.intr_feat_channels, kernel_size=3),
26
+ )
27
+ self.intr_pool = nn.AdaptiveAvgPool2d((1, 1))
28
+ self.intr_proj = nn.Linear(self.intr_feat_channels, 3)
29
+ # init the last linear layer so it outputs zeros
30
+ nn.init.zeros_(self.intr_proj.weight)
31
+ nn.init.zeros_(self.intr_proj.bias)
32
+
33
+ self.loss_fns = Loss(opt)
34
+
35
+ def intr_param2mtx(self, opt, intr_params):
36
+ '''
37
+ Parameters:
38
+ opt: config
39
+ intr_params: [B, 3], [scale_f, delta_cx, delta_cy]
40
+ Return:
41
+ intr: [B, 3, 3]
42
+ '''
43
+ batch_size = len(intr_params)
44
+ f = 1.3875
45
+ intr = torch.zeros(3, 3).float().to(intr_params.device).unsqueeze(0).repeat(batch_size, 1, 1)
46
+ intr[:, 2, 2] += 1
47
+ # scale the focal length
48
+ # range: [-1, 1], symmetric
49
+ scale_f = torch.tanh(intr_params[:, 0])
50
+ # range: [1/4, 4], symmetric
51
+ scale_f = torch.pow(4. , scale_f)
52
+ intr[:, 0, 0] += f * opt.W * scale_f
53
+ intr[:, 1, 1] += f * opt.H * scale_f
54
+ # shift the optic center, (at most to the image border)
55
+ shift_cx = torch.tanh(intr_params[:, 1]) * opt.W / 2
56
+ shift_cy = torch.tanh(intr_params[:, 2]) * opt.H / 2
57
+ intr[:, 0, 2] += opt.W / 2 + shift_cx
58
+ intr[:, 1, 2] += opt.H / 2 + shift_cy
59
+ return intr
60
+
61
+ def forward(self, opt, var, training=False, get_loss=True):
62
+ batch_size = len(var.idx)
63
+
64
+ # predict the depth map and feature maps if needed
65
+ if opt.loss_weight.intr is None:
66
+ var.depth_pred = self.dpt_depth(var.rgb_input_map)
67
+ else:
68
+ var.depth_pred, intr_feat = self.dpt_depth(var.rgb_input_map, get_feat=True)
69
+ # predict the intrinsics
70
+ intr_feat = self.intr_head(intr_feat)
71
+ intr_feat = self.intr_pool(intr_feat).squeeze(-1).squeeze(-1)
72
+ intr_params = self.intr_proj(intr_feat)
73
+ # [B, 3, 3]
74
+ var.intr_pred = self.intr_param2mtx(opt, intr_params)
75
+
76
+ # project the predicted depth map to 3D points and normalize, [B, H*W, 3]
77
+ seen_points_3D_pred = unproj_depth(opt, var.depth_pred, var.intr_pred)
78
+ seen_points_mean_pred, seen_points_scale_pred = valid_norm_fac(seen_points_3D_pred, var.mask_input_map > 0.5)
79
+ var.seen_points_pred = (seen_points_3D_pred - seen_points_mean_pred.unsqueeze(1)) / seen_points_scale_pred.unsqueeze(-1).unsqueeze(-1)
80
+ var.seen_points_pred[(var.mask_input_map<=0.5).view(batch_size, -1)] = 0
81
+
82
+ if 'depth_input_map' in var or training:
83
+ # project the ground truth depth map to 3D points and normalize, [B, H*W, 3]
84
+ seen_points_3D_gt = unproj_depth(opt, var.depth_input_map, var.intr)
85
+ seen_points_mean_gt, seen_points_scale_gt = valid_norm_fac(seen_points_3D_gt, var.mask_input_map > 0.5)
86
+ var.seen_points_gt = (seen_points_3D_gt - seen_points_mean_gt.unsqueeze(1)) / seen_points_scale_gt.unsqueeze(-1).unsqueeze(-1)
87
+ var.seen_points_gt[(var.mask_input_map<=0.5).view(batch_size, -1)] = 0
88
+
89
+ # record the validity mask, [B, H*W]
90
+ var.validity_mask = (var.mask_input_map>0.5).float().view(batch_size, -1)
91
+
92
+ # calculate the loss if needed
93
+ if get_loss:
94
+ loss = self.compute_loss(opt, var, training)
95
+ return var, loss
96
+
97
+ return var
98
+
99
+ def compute_loss(self, opt, var, training=False):
100
+ loss = edict()
101
+ if opt.loss_weight.depth is not None:
102
+ loss.depth = self.loss_fns.depth_loss(var.depth_pred, var.depth_input_map, var.mask_input_map)
103
+ if opt.loss_weight.intr is not None:
104
+ loss.intr = self.loss_fns.intr_loss(var.seen_points_pred, var.seen_points_gt, var.validity_mask)
105
+ return loss
106
+
model/compute_graph/graph_shape.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from utils.util import EasyDict as edict
5
+ from utils.loss import Loss
6
+ from model.shape.implicit import Implicit
7
+ from model.shape.seen_coord_enc import CoordEncAtt, CoordEncRes
8
+ from model.shape.rgb_enc import RGBEncAtt, RGBEncRes
9
+ from model.depth.dpt_depth import DPTDepthModel
10
+ from utils.util import toggle_grad, interpolate_coordmap, get_child_state_dict
11
+ from utils.camera import unproj_depth, valid_norm_fac
12
+ from utils.layers import Bottleneck_Conv
13
+
14
+ class Graph(nn.Module):
15
+
16
+ def __init__(self, opt):
17
+ super().__init__()
18
+ # define the intrinsics head
19
+ self.intr_feat_channels = 768
20
+ self.intr_head = nn.Sequential(
21
+ Bottleneck_Conv(self.intr_feat_channels, kernel_size=3),
22
+ Bottleneck_Conv(self.intr_feat_channels, kernel_size=3),
23
+ )
24
+ self.intr_pool = nn.AdaptiveAvgPool2d((1, 1))
25
+ self.intr_proj = nn.Linear(self.intr_feat_channels, 3)
26
+ # init the last linear layer so it outputs zeros
27
+ nn.init.zeros_(self.intr_proj.weight)
28
+ nn.init.zeros_(self.intr_proj.bias)
29
+
30
+ # define the depth pred model based on omnidata
31
+ self.dpt_depth = DPTDepthModel(backbone='vitb_rn50_384')
32
+ # load the pretrained depth model
33
+ # when intrinsics need to be predicted we need to load that part as well
34
+ self.load_pretrained_depth(opt)
35
+ if opt.optim.fix_dpt:
36
+ toggle_grad(self.dpt_depth, False)
37
+ toggle_grad(self.intr_head, False)
38
+ toggle_grad(self.intr_proj, False)
39
+
40
+ # encoder that encode seen surface to impl conditioning vec
41
+ if opt.arch.depth.encoder == 'resnet':
42
+ opt.arch.depth.dsp = 1
43
+ self.coord_encoder = CoordEncRes(opt)
44
+ else:
45
+ self.coord_encoder = CoordEncAtt(embed_dim=opt.arch.latent_dim, n_blocks=opt.arch.depth.n_blocks,
46
+ num_heads=opt.arch.num_heads, win_size=opt.arch.win_size//opt.arch.depth.dsp)
47
+
48
+ # rgb branch (not used in final model, keep here for extension)
49
+ if opt.arch.rgb.encoder == 'resnet':
50
+ self.rgb_encoder = RGBEncRes(opt)
51
+ elif opt.arch.rgb.encoder == 'transformer':
52
+ self.rgb_encoder = RGBEncAtt(img_size=opt.H, embed_dim=opt.arch.latent_dim, n_blocks=opt.arch.rgb.n_blocks,
53
+ num_heads=opt.arch.num_heads, win_size=opt.arch.win_size)
54
+ else:
55
+ self.rgb_encoder = None
56
+
57
+ # implicit function
58
+ feat_res = opt.H // opt.arch.win_size
59
+ self.impl_network = Implicit(feat_res**2, latent_dim=opt.arch.latent_dim*2 if self.rgb_encoder else opt.arch.latent_dim,
60
+ semantic=self.rgb_encoder is not None, n_channels=opt.arch.impl.n_channels,
61
+ n_blocks_attn=opt.arch.impl.att_blocks, n_layers_mlp=opt.arch.impl.mlp_layers,
62
+ num_heads=opt.arch.num_heads, posenc_3D=opt.arch.impl.posenc_3D,
63
+ mlp_ratio=opt.arch.impl.mlp_ratio, skip_in=opt.arch.impl.skip_in,
64
+ pos_perlayer=opt.arch.impl.posenc_perlayer)
65
+
66
+ # loss functions
67
+ self.loss_fns = Loss(opt)
68
+
69
+ def load_pretrained_depth(self, opt):
70
+ if opt.pretrain.depth:
71
+ # loading from our pretrained depth and intr model
72
+ if opt.device == 0:
73
+ print("loading dpt depth from {}...".format(opt.pretrain.depth))
74
+ checkpoint = torch.load(opt.pretrain.depth, map_location="cuda:{}".format(opt.device))
75
+ self.dpt_depth.load_state_dict(get_child_state_dict(checkpoint["graph"], "dpt_depth"))
76
+ # load the intr head
77
+ if opt.device == 0:
78
+ print("loading pretrained intr from {}...".format(opt.pretrain.depth))
79
+ self.intr_head.load_state_dict(get_child_state_dict(checkpoint["graph"], "intr_head"))
80
+ self.intr_proj.load_state_dict(get_child_state_dict(checkpoint["graph"], "intr_proj"))
81
+ elif opt.arch.depth.pretrained:
82
+ # loading from omnidata weights
83
+ if opt.device == 0:
84
+ print("loading dpt depth from {}...".format(opt.arch.depth.pretrained))
85
+ checkpoint = torch.load(opt.arch.depth.pretrained, map_location="cuda:{}".format(opt.device))
86
+ state_dict = checkpoint['model_state_dict']
87
+ self.dpt_depth.load_state_dict(state_dict)
88
+
89
+ def intr_param2mtx(self, opt, intr_params):
90
+ '''
91
+ Parameters:
92
+ opt: config
93
+ intr_params: [B, 3], [scale_f, delta_cx, delta_cy]
94
+ Return:
95
+ intr: [B, 3, 3]
96
+ '''
97
+ batch_size = len(intr_params)
98
+ f = 1.3875
99
+ intr = torch.zeros(3, 3).float().to(intr_params.device).unsqueeze(0).repeat(batch_size, 1, 1)
100
+ intr[:, 2, 2] += 1
101
+ # scale the focal length
102
+ # range: [-1, 1], symmetric
103
+ scale_f = torch.tanh(intr_params[:, 0])
104
+ # range: [1/4, 4], symmetric
105
+ scale_f = torch.pow(4. , scale_f)
106
+ intr[:, 0, 0] += f * opt.W * scale_f
107
+ intr[:, 1, 1] += f * opt.H * scale_f
108
+ # shift the optic center, (at most to the image border)
109
+ shift_cx = torch.tanh(intr_params[:, 1]) * opt.W / 2
110
+ shift_cy = torch.tanh(intr_params[:, 2]) * opt.H / 2
111
+ intr[:, 0, 2] += opt.W / 2 + shift_cx
112
+ intr[:, 1, 2] += opt.H / 2 + shift_cy
113
+ return intr
114
+
115
+ def forward(self, opt, var, training=False, get_loss=True):
116
+ batch_size = len(var.idx)
117
+
118
+ # encode the rgb, [B, 3, H, W] -> [B, 1+H/(ws)*W/(ws), C], not used in our final model
119
+ var.latent_semantic = self.rgb_encoder(var.rgb_input_map) if self.rgb_encoder else None
120
+
121
+ # predict the depth map and intrinsics
122
+ var.depth_pred, intr_feat = self.dpt_depth(var.rgb_input_map, get_feat=True)
123
+ depth_map = var.depth_pred
124
+ # predict the intrinsics
125
+ intr_feat = self.intr_head(intr_feat)
126
+ intr_feat = self.intr_pool(intr_feat).squeeze(-1).squeeze(-1)
127
+ intr_params = self.intr_proj(intr_feat)
128
+ # [B, 3, 3]
129
+ var.intr_pred = self.intr_param2mtx(opt, intr_params)
130
+ intr_forward = var.intr_pred
131
+ # record the validity mask, [B, H*W]
132
+ var.validity_mask = (var.mask_input_map>0.5).float().view(batch_size, -1)
133
+
134
+ # project the depth to 3D points in view-centric frame
135
+ # [B, H*W, 3], in camera coordinates
136
+ seen_points_3D_pred = unproj_depth(opt, depth_map, intr_forward)
137
+ # [B, H*W, 3], [B, 1, H, W] (boolean) -> [B, 3], [B]
138
+ seen_points_mean_pred, seen_points_scale_pred = valid_norm_fac(seen_points_3D_pred, var.mask_input_map > 0.5)
139
+ # normalize the seen surface, [B, H*W, 3]
140
+ var.seen_points = (seen_points_3D_pred - seen_points_mean_pred.unsqueeze(1)) / seen_points_scale_pred.unsqueeze(-1).unsqueeze(-1)
141
+ var.seen_points[(var.mask_input_map<=0.5).view(batch_size, -1)] = 0
142
+ # [B, 3, H, W]
143
+ seen_3D_map = var.seen_points.view(batch_size, opt.H, opt.W, 3).permute(0, 3, 1, 2).contiguous()
144
+ seen_3D_dsp, mask_dsp = interpolate_coordmap(seen_3D_map, var.mask_input_map, (opt.H//opt.arch.depth.dsp, opt.W//opt.arch.depth.dsp))
145
+
146
+ # encode the depth, [B, 1, H/k, W/k] -> [B, 1+H/(ws)*W/(ws), C]
147
+ if opt.arch.depth.encoder == 'resnet':
148
+ var.latent_depth = self.coord_encoder(seen_3D_dsp, mask_dsp)
149
+ else:
150
+ var.latent_depth = self.coord_encoder(seen_3D_dsp.permute(0, 2, 3, 1).contiguous(), mask_dsp.squeeze(1)>0.5)
151
+
152
+
153
+ var.pose = var.pose_gt
154
+ # forward for loss calculation (only during training)
155
+ if 'gt_sample_points' in var and 'gt_sample_sdf' in var:
156
+ with torch.no_grad():
157
+ # get the normalizing fac based on the GT seen surface
158
+ # project the GT depth to 3D points in view-centric frame
159
+ # [B, H*W, 3], in camera coordinates
160
+ seen_points_3D_gt = unproj_depth(opt, var.depth_input_map, var.intr)
161
+ # [B, H*W, 3], [B, 1, H, W] (boolean) -> [B, 3], [B]
162
+ seen_points_mean_gt, seen_points_scale_gt = valid_norm_fac(seen_points_3D_gt, var.mask_input_map > 0.5)
163
+ var.seen_points_gt = (seen_points_3D_gt - seen_points_mean_gt.unsqueeze(1)) / seen_points_scale_gt.unsqueeze(-1).unsqueeze(-1)
164
+ var.seen_points_gt[(var.mask_input_map<=0.5).view(batch_size, -1)] = 0
165
+
166
+ # transform the GT points accordingly
167
+ # [B, 3, 3]
168
+ R_gt = var.pose_gt[:, :, :3]
169
+ # [B, 3, 1]
170
+ T_gt = var.pose_gt[:, :, 3:]
171
+ # [B, 3, N]
172
+ gt_sample_points_transposed = var.gt_sample_points.permute(0, 2, 1).contiguous()
173
+ # camera coordinates, [B, N, 3]
174
+ gt_sample_points_cam = (R_gt @ gt_sample_points_transposed + T_gt).permute(0, 2, 1).contiguous()
175
+ # normalize with seen std and mean, [B, N, 3]
176
+ var.gt_points_cam = (gt_sample_points_cam - seen_points_mean_gt.unsqueeze(1)) / seen_points_scale_gt.unsqueeze(-1).unsqueeze(-1)
177
+
178
+ # get near-surface points for visualization
179
+ # [B, 100, 3]
180
+ close_surf_idx = torch.topk(var.gt_sample_sdf.abs(), k=100, dim=1, largest=False)[1].unsqueeze(-1).repeat(1, 1, 3)
181
+ # [B, 100, 3]
182
+ var.gt_surf_points = torch.gather(var.gt_points_cam, dim=1, index=close_surf_idx)
183
+
184
+ # [B, N], [B, N, 1+feat_res**2], inference the impl_network for 3D loss
185
+ var.pred_sample_occ, attn = self.impl_network(var.latent_depth, var.latent_semantic, var.gt_points_cam)
186
+
187
+ # calculate the loss if needed
188
+ if get_loss:
189
+ loss = self.compute_loss(opt, var, training)
190
+ return var, loss
191
+
192
+ return var
193
+
194
+ def compute_loss(self, opt, var, training=False):
195
+ loss = edict()
196
+ if opt.loss_weight.depth is not None:
197
+ loss.depth = self.loss_fns.depth_loss(var.depth_pred, var.depth_input_map, var.mask_input_map)
198
+ if opt.loss_weight.intr is not None and training:
199
+ loss.intr = self.loss_fns.intr_loss(var.seen_points, var.seen_points_gt, var.validity_mask)
200
+ if opt.loss_weight.shape is not None and training:
201
+ loss.shape = self.loss_fns.shape_loss(var.pred_sample_occ, var.gt_sample_sdf)
202
+ return loss
model/depth/__init__.py ADDED
File without changes
model/depth/base_model.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modified from https://github.com/isl-org/DPT
2
+ import torch
3
+
4
+
5
+ class BaseModel(torch.nn.Module):
6
+ def load(self, path):
7
+ """Load model from file.
8
+
9
+ Args:
10
+ path (str): file path
11
+ """
12
+ parameters = torch.load(path, map_location=torch.device('cpu'))
13
+
14
+ if "optimizer" in parameters:
15
+ parameters = parameters["model"]
16
+
17
+ self.load_state_dict(parameters)
model/depth/blocks.py ADDED
@@ -0,0 +1,343 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modified from https://github.com/isl-org/DPT
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+ from .vit import (
6
+ _make_pretrained_vitb_rn50_384,
7
+ _make_pretrained_vitl16_384,
8
+ _make_pretrained_vitb16_384,
9
+ forward_vit,
10
+ )
11
+
12
+ def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",):
13
+ if backbone == "vitl16_384":
14
+ pretrained = _make_pretrained_vitl16_384(
15
+ use_pretrained, hooks=hooks, use_readout=use_readout
16
+ )
17
+ scratch = _make_scratch(
18
+ [256, 512, 1024, 1024], features, groups=groups, expand=expand
19
+ ) # ViT-L/16 - 85.0% Top1 (backbone)
20
+ elif backbone == "vitb_rn50_384":
21
+ pretrained = _make_pretrained_vitb_rn50_384(
22
+ use_pretrained,
23
+ hooks=hooks,
24
+ use_vit_only=use_vit_only,
25
+ use_readout=use_readout,
26
+ )
27
+ scratch = _make_scratch(
28
+ [256, 512, 768, 768], features, groups=groups, expand=expand
29
+ ) # ViT-H/16 - 85.0% Top1 (backbone)
30
+ elif backbone == "vitb16_384":
31
+ pretrained = _make_pretrained_vitb16_384(
32
+ use_pretrained, hooks=hooks, use_readout=use_readout
33
+ )
34
+ scratch = _make_scratch(
35
+ [96, 192, 384, 768], features, groups=groups, expand=expand
36
+ ) # ViT-B/16 - 84.6% Top1 (backbone)
37
+ elif backbone == "resnext101_wsl":
38
+ pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
39
+ scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3
40
+ elif backbone == "efficientnet_lite3":
41
+ pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable)
42
+ scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3
43
+ else:
44
+ print(f"Backbone '{backbone}' not implemented")
45
+ assert False
46
+
47
+ return pretrained, scratch
48
+
49
+
50
+ def _make_scratch(in_shape, out_shape, groups=1, expand=False):
51
+ scratch = nn.Module()
52
+
53
+ out_shape1 = out_shape
54
+ out_shape2 = out_shape
55
+ out_shape3 = out_shape
56
+ out_shape4 = out_shape
57
+ if expand==True:
58
+ out_shape1 = out_shape
59
+ out_shape2 = out_shape*2
60
+ out_shape3 = out_shape*4
61
+ out_shape4 = out_shape*8
62
+
63
+ scratch.layer1_rn = nn.Conv2d(
64
+ in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
65
+ )
66
+ scratch.layer2_rn = nn.Conv2d(
67
+ in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
68
+ )
69
+ scratch.layer3_rn = nn.Conv2d(
70
+ in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
71
+ )
72
+ scratch.layer4_rn = nn.Conv2d(
73
+ in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
74
+ )
75
+
76
+ return scratch
77
+
78
+
79
+ def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
80
+ efficientnet = torch.hub.load(
81
+ "rwightman/gen-efficientnet-pytorch",
82
+ "tf_efficientnet_lite3",
83
+ pretrained=use_pretrained,
84
+ exportable=exportable
85
+ )
86
+ return _make_efficientnet_backbone(efficientnet)
87
+
88
+
89
+ def _make_efficientnet_backbone(effnet):
90
+ pretrained = nn.Module()
91
+
92
+ pretrained.layer1 = nn.Sequential(
93
+ effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2]
94
+ )
95
+ pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3])
96
+ pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5])
97
+ pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9])
98
+
99
+ return pretrained
100
+
101
+
102
+ def _make_resnet_backbone(resnet):
103
+ pretrained = nn.Module()
104
+ pretrained.layer1 = nn.Sequential(
105
+ resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
106
+ )
107
+
108
+ pretrained.layer2 = resnet.layer2
109
+ pretrained.layer3 = resnet.layer3
110
+ pretrained.layer4 = resnet.layer4
111
+
112
+ return pretrained
113
+
114
+
115
+ def _make_pretrained_resnext101_wsl(use_pretrained):
116
+ resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
117
+ return _make_resnet_backbone(resnet)
118
+
119
+
120
+
121
+ class Interpolate(nn.Module):
122
+ """Interpolation module.
123
+ """
124
+
125
+ def __init__(self, scale_factor, mode, align_corners=False):
126
+ """Init.
127
+
128
+ Args:
129
+ scale_factor (float): scaling
130
+ mode (str): interpolation mode
131
+ """
132
+ super(Interpolate, self).__init__()
133
+
134
+ self.interp = nn.functional.interpolate
135
+ self.scale_factor = scale_factor
136
+ self.mode = mode
137
+ self.align_corners = align_corners
138
+
139
+ def forward(self, x):
140
+ """Forward pass.
141
+
142
+ Args:
143
+ x (tensor): input
144
+
145
+ Returns:
146
+ tensor: interpolated data
147
+ """
148
+
149
+ x = self.interp(
150
+ x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners
151
+ )
152
+
153
+ return x
154
+
155
+
156
+ class ResidualConvUnit(nn.Module):
157
+ """Residual convolution module.
158
+ """
159
+
160
+ def __init__(self, features):
161
+ """Init.
162
+
163
+ Args:
164
+ features (int): number of features
165
+ """
166
+ super().__init__()
167
+
168
+ self.conv1 = nn.Conv2d(
169
+ features, features, kernel_size=3, stride=1, padding=1, bias=True
170
+ )
171
+
172
+ self.conv2 = nn.Conv2d(
173
+ features, features, kernel_size=3, stride=1, padding=1, bias=True
174
+ )
175
+
176
+ self.relu = nn.ReLU(inplace=True)
177
+
178
+ def forward(self, x):
179
+ """Forward pass.
180
+
181
+ Args:
182
+ x (tensor): input
183
+
184
+ Returns:
185
+ tensor: output
186
+ """
187
+ out = self.relu(x)
188
+ out = self.conv1(out)
189
+ out = self.relu(out)
190
+ out = self.conv2(out)
191
+
192
+ return out + x
193
+
194
+
195
+ class FeatureFusionBlock(nn.Module):
196
+ """Feature fusion block.
197
+ """
198
+
199
+ def __init__(self, features):
200
+ """Init.
201
+
202
+ Args:
203
+ features (int): number of features
204
+ """
205
+ super(FeatureFusionBlock, self).__init__()
206
+
207
+ self.resConfUnit1 = ResidualConvUnit(features)
208
+ self.resConfUnit2 = ResidualConvUnit(features)
209
+
210
+ def forward(self, *xs):
211
+ """Forward pass.
212
+
213
+ Returns:
214
+ tensor: output
215
+ """
216
+ output = xs[0]
217
+
218
+ if len(xs) == 2:
219
+ output += self.resConfUnit1(xs[1])
220
+
221
+ output = self.resConfUnit2(output)
222
+
223
+ output = nn.functional.interpolate(
224
+ output, scale_factor=2, mode="bilinear", align_corners=True
225
+ )
226
+
227
+ return output
228
+
229
+
230
+
231
+
232
+ class ResidualConvUnit_custom(nn.Module):
233
+ """Residual convolution module.
234
+ """
235
+
236
+ def __init__(self, features, activation, bn):
237
+ """Init.
238
+
239
+ Args:
240
+ features (int): number of features
241
+ """
242
+ super().__init__()
243
+
244
+ self.bn = bn
245
+
246
+ self.groups=1
247
+
248
+ self.conv1 = nn.Conv2d(
249
+ features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
250
+ )
251
+
252
+ self.conv2 = nn.Conv2d(
253
+ features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
254
+ )
255
+
256
+ if self.bn==True:
257
+ self.bn1 = nn.BatchNorm2d(features)
258
+ self.bn2 = nn.BatchNorm2d(features)
259
+
260
+ self.activation = activation
261
+
262
+ self.skip_add = nn.quantized.FloatFunctional()
263
+
264
+ def forward(self, x):
265
+ """Forward pass.
266
+
267
+ Args:
268
+ x (tensor): input
269
+
270
+ Returns:
271
+ tensor: output
272
+ """
273
+
274
+ out = self.activation(x)
275
+ out = self.conv1(out)
276
+ if self.bn==True:
277
+ out = self.bn1(out)
278
+
279
+ out = self.activation(out)
280
+ out = self.conv2(out)
281
+ if self.bn==True:
282
+ out = self.bn2(out)
283
+
284
+ if self.groups > 1:
285
+ out = self.conv_merge(out)
286
+
287
+ return self.skip_add.add(out, x)
288
+
289
+ # return out + x
290
+
291
+
292
+ class FeatureFusionBlock_custom(nn.Module):
293
+ """Feature fusion block.
294
+ """
295
+
296
+ def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True):
297
+ """Init.
298
+
299
+ Args:
300
+ features (int): number of features
301
+ """
302
+ super(FeatureFusionBlock_custom, self).__init__()
303
+
304
+ self.deconv = deconv
305
+ self.align_corners = align_corners
306
+
307
+ self.groups=1
308
+
309
+ self.expand = expand
310
+ out_features = features
311
+ if self.expand==True:
312
+ out_features = features//2
313
+
314
+ self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
315
+
316
+ self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
317
+ self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
318
+
319
+ self.skip_add = nn.quantized.FloatFunctional()
320
+
321
+ def forward(self, *xs):
322
+ """Forward pass.
323
+
324
+ Returns:
325
+ tensor: output
326
+ """
327
+ output = xs[0]
328
+
329
+ if len(xs) == 2:
330
+ res = self.resConfUnit1(xs[1])
331
+ output = self.skip_add.add(output, res)
332
+ # output += res
333
+
334
+ output = self.resConfUnit2(output)
335
+
336
+ output = nn.functional.interpolate(
337
+ output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
338
+ )
339
+
340
+ output = self.out_conv(output)
341
+
342
+ return output
343
+
model/depth/dpt_depth.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modified from https://github.com/isl-org/DPT
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ from .base_model import BaseModel
7
+ from .blocks import (
8
+ FeatureFusionBlock,
9
+ FeatureFusionBlock_custom,
10
+ Interpolate,
11
+ _make_encoder,
12
+ forward_vit,
13
+ )
14
+
15
+
16
+ def _make_fusion_block(features, use_bn):
17
+ return FeatureFusionBlock_custom(
18
+ features,
19
+ nn.ReLU(False),
20
+ deconv=False,
21
+ bn=use_bn,
22
+ expand=False,
23
+ align_corners=True,
24
+ )
25
+
26
+
27
+ class DPT(BaseModel):
28
+ def __init__(
29
+ self,
30
+ head,
31
+ features=256,
32
+ backbone="vitb_rn50_384",
33
+ readout="project",
34
+ channels_last=False,
35
+ use_bn=False,
36
+ ):
37
+
38
+ super(DPT, self).__init__()
39
+
40
+ self.channels_last = channels_last
41
+
42
+ hooks = {
43
+ "vitb_rn50_384": [0, 1, 8, 11],
44
+ "vitb16_384": [2, 5, 8, 11],
45
+ "vitl16_384": [5, 11, 17, 23],
46
+ }
47
+
48
+ # Instantiate backbone and reassemble blocks
49
+ self.pretrained, self.scratch = _make_encoder(
50
+ backbone,
51
+ features,
52
+ True, # Set to true of you want to train from scratch, uses ImageNet weights
53
+ groups=1,
54
+ expand=False,
55
+ exportable=False,
56
+ hooks=hooks[backbone],
57
+ use_readout=readout,
58
+ )
59
+
60
+ self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
61
+ self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
62
+ self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
63
+ self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
64
+
65
+ self.scratch.output_conv = head
66
+
67
+
68
+ def forward(self, x, get_feat=False):
69
+ if self.channels_last == True:
70
+ x.contiguous(memory_format=torch.channels_last)
71
+
72
+ layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)
73
+
74
+ # res 8x -> 4x -> 2x -> 1x base_size
75
+ # base_size = H / 32
76
+ # all n_channels same (256 by default) after these
77
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
78
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
79
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
80
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
81
+
82
+ # upsample by two with out changing n_channels each time, conv-sum for fusing
83
+ path_4 = self.scratch.refinenet4(layer_4_rn)
84
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
85
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
86
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
87
+
88
+ out = self.scratch.output_conv(path_1)
89
+
90
+ # save the feat if required
91
+ if get_feat:
92
+ return out, layer_4
93
+
94
+ return out
95
+
96
+ class DPTDepthModel(DPT):
97
+ def __init__(self, path=None, non_negative=True, num_channels=1, **kwargs):
98
+ features = kwargs["features"] if "features" in kwargs else 256
99
+
100
+ head = nn.Sequential(
101
+ nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
102
+ Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
103
+ nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
104
+ nn.ReLU(True),
105
+ nn.Conv2d(32, num_channels, kernel_size=1, stride=1, padding=0),
106
+ nn.ReLU(True) if non_negative else nn.Identity(),
107
+ nn.Identity(),
108
+ )
109
+ nn.init.constant_(head[-3].bias, 0.05)
110
+ super().__init__(head, **kwargs)
111
+
112
+ if path is not None:
113
+ self.load(path)
114
+
115
+ def forward(self, image, get_feat=False):
116
+ x = image * 2 - 1
117
+ if get_feat:
118
+ output, feat = super().forward(x, get_feat=get_feat)
119
+ output = output.clamp(min=0, max=1)
120
+ return output, feat
121
+ else:
122
+ output = super().forward(x, get_feat=get_feat).clamp(min=0, max=1)
123
+ return output
model/depth/midas_loss.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modified from https://github.com/EPFL-VILAB/omnidata
2
+ import torch
3
+ import torch.nn as nn
4
+ import numpy as np
5
+
6
+ def masked_l1_loss(preds, target, mask_valid):
7
+ element_wise_loss = abs(preds - target)
8
+ element_wise_loss[~mask_valid] = 0
9
+ return element_wise_loss.sum() / (mask_valid.sum() + 1.e-6)
10
+
11
+ def compute_scale_and_shift(prediction, target, mask):
12
+ # system matrix: A = [[a_00, a_01], [a_10, a_11]]
13
+ a_00 = torch.sum(mask * prediction * prediction, (1, 2))
14
+ a_01 = torch.sum(mask * prediction, (1, 2))
15
+ a_11 = torch.sum(mask, (1, 2))
16
+
17
+ # right hand side: b = [b_0, b_1]
18
+ b_0 = torch.sum(mask * prediction * target, (1, 2))
19
+ b_1 = torch.sum(mask * target, (1, 2))
20
+
21
+ # solution: x = A^-1 . b = [[a_11, -a_01], [-a_10, a_00]] / (a_00 * a_11 - a_01 * a_10) . b
22
+ x_0 = torch.zeros_like(b_0)
23
+ x_1 = torch.zeros_like(b_1)
24
+
25
+ det = a_00 * a_11 - a_01 * a_01
26
+ valid = det.nonzero()
27
+
28
+ x_0[valid] = (a_11[valid] * b_0[valid] - a_01[valid] * b_1[valid]) / (det[valid] + 1e-6)
29
+ x_1[valid] = (-a_01[valid] * b_0[valid] + a_00[valid] * b_1[valid]) / (det[valid] + 1e-6)
30
+
31
+ return x_0, x_1
32
+
33
+
34
+ def masked_shift_and_scale(depth_preds, depth_gt, mask_valid):
35
+ depth_preds_nan = depth_preds.clone()
36
+ depth_gt_nan = depth_gt.clone()
37
+ depth_preds_nan[~mask_valid] = np.nan
38
+ depth_gt_nan[~mask_valid] = np.nan
39
+
40
+ mask_diff = mask_valid.view(mask_valid.size()[:2] + (-1,)).sum(-1, keepdims=True) + 1
41
+
42
+ # flatten spatial dimension and take valid median [B, 1, 1, 1]
43
+ t_gt = depth_gt_nan.view(depth_gt_nan.size()[:2] + (-1,)).nanmedian(-1, keepdims=True)[0].unsqueeze(-1)
44
+ t_gt[torch.isnan(t_gt)] = 0
45
+ # subtract median and set invalid position to 0
46
+ diff_gt = torch.abs(depth_gt - t_gt)
47
+ diff_gt[~mask_valid] = 0
48
+ # get the avg abs diff value over valid regions [B, 1, 1, 1]
49
+ s_gt = (diff_gt.view(diff_gt.size()[:2] + (-1,)).sum(-1, keepdims=True) / mask_diff).unsqueeze(-1)
50
+ # normalize
51
+ depth_gt_aligned = (depth_gt - t_gt) / (s_gt + 1e-6)
52
+
53
+ # same as gt normalization
54
+ t_pred = depth_preds_nan.view(depth_preds_nan.size()[:2] + (-1,)).nanmedian(-1, keepdims=True)[0].unsqueeze(-1)
55
+ t_pred[torch.isnan(t_pred)] = 0
56
+ diff_pred = torch.abs(depth_preds - t_pred)
57
+ diff_pred[~mask_valid] = 0
58
+ s_pred = (diff_pred.view(diff_pred.size()[:2] + (-1,)).sum(-1, keepdims=True) / mask_diff).unsqueeze(-1)
59
+ depth_pred_aligned = (depth_preds - t_pred) / (s_pred + 1e-6)
60
+
61
+ return depth_pred_aligned, depth_gt_aligned
62
+
63
+
64
+ def reduction_batch_based(image_loss, M):
65
+ # average of all valid pixels of the batch
66
+
67
+ # avoid division by 0 (if sum(M) = sum(sum(mask)) = 0: sum(image_loss) = 0)
68
+ divisor = torch.sum(M)
69
+
70
+ if divisor == 0:
71
+ return 0
72
+ else:
73
+ return torch.sum(image_loss) / divisor
74
+
75
+
76
+ def reduction_image_based(image_loss, M):
77
+ # mean of average of valid pixels of an image
78
+
79
+ # avoid division by 0 (if M = sum(mask) = 0: image_loss = 0)
80
+ valid = M.nonzero()
81
+
82
+ image_loss[valid] = image_loss[valid] / M[valid]
83
+
84
+ return torch.mean(image_loss)
85
+
86
+
87
+
88
+ def gradient_loss(prediction, target, mask, reduction=reduction_batch_based):
89
+
90
+ M = torch.sum(mask, (1, 2))
91
+
92
+ diff = prediction - target
93
+ diff = torch.mul(mask, diff)
94
+
95
+ grad_x = torch.abs(diff[:, :, 1:] - diff[:, :, :-1])
96
+ mask_x = torch.mul(mask[:, :, 1:], mask[:, :, :-1])
97
+ grad_x = torch.mul(mask_x, grad_x)
98
+
99
+ grad_y = torch.abs(diff[:, 1:, :] - diff[:, :-1, :])
100
+ mask_y = torch.mul(mask[:, 1:, :], mask[:, :-1, :])
101
+ grad_y = torch.mul(mask_y, grad_y)
102
+
103
+ image_loss = torch.sum(grad_x, (1, 2)) + torch.sum(grad_y, (1, 2))
104
+
105
+ return reduction(image_loss, M)
106
+
107
+
108
+
109
+ class SSIMAE(nn.Module):
110
+ def __init__(self):
111
+ super().__init__()
112
+
113
+ def forward(self, depth_preds, depth_gt, mask_valid):
114
+ depth_pred_aligned, depth_gt_aligned = masked_shift_and_scale(depth_preds, depth_gt, mask_valid)
115
+ ssi_mae_loss = masked_l1_loss(depth_pred_aligned, depth_gt_aligned, mask_valid)
116
+ return ssi_mae_loss
117
+
118
+
119
+ class GradientMatchingTerm(nn.Module):
120
+ def __init__(self, scales=4, reduction='batch-based'):
121
+ super().__init__()
122
+
123
+ if reduction == 'batch-based':
124
+ self.__reduction = reduction_batch_based
125
+ else:
126
+ self.__reduction = reduction_image_based
127
+
128
+ self.__scales = scales
129
+
130
+ def forward(self, prediction, target, mask):
131
+ total = 0
132
+
133
+ for scale in range(self.__scales):
134
+ step = pow(2, scale)
135
+
136
+ total += gradient_loss(prediction[:, ::step, ::step], target[:, ::step, ::step],
137
+ mask[:, ::step, ::step], reduction=self.__reduction)
138
+
139
+ return total
140
+
141
+
142
+ class MidasLoss(nn.Module):
143
+ def __init__(self, alpha=0.1, scales=4, reduction='image-based', inverse_depth=True, shrink_mask=False):
144
+ super().__init__()
145
+
146
+ self.__ssi_mae_loss = SSIMAE()
147
+ self.__gradient_matching_term = GradientMatchingTerm(scales=scales, reduction=reduction)
148
+ self.__alpha = alpha
149
+ self.inverse_depth = inverse_depth
150
+ self.shrink_mask = shrink_mask
151
+
152
+ # decrease valid region via min-pooling
153
+ @torch.no_grad()
154
+ def erode_mask(self, mask, max_pool_size=4):
155
+ mask_float = mask.float()
156
+ h, w = mask_float.shape[2], mask_float.shape[3]
157
+ mask_float = 1 - mask_float
158
+ mask_float = torch.nn.functional.max_pool2d(mask_float, kernel_size=max_pool_size)
159
+ mask_float = torch.nn.functional.interpolate(mask_float, (h, w), mode='nearest')
160
+ # only if a 4x4 region is all valid then we make that valid
161
+ mask_valid = mask_float == 0
162
+ return mask_valid
163
+
164
+ def forward(self, prediction_raw, target_raw, mask_raw):
165
+ if self.shrink_mask:
166
+ mask = self.erode_mask(mask_raw)
167
+ else:
168
+ mask = mask_raw > 0.5
169
+ ssi_loss = self.__ssi_mae_loss(prediction_raw, target_raw, mask)
170
+ if self.__alpha <= 0:
171
+ return ssi_loss
172
+
173
+ if self.inverse_depth:
174
+ prediction = 1 / (prediction_raw.squeeze(1) + 1e-6)
175
+ target = 1 / (target_raw.squeeze(1) + 1e-6)
176
+ else:
177
+ prediction = prediction_raw.squeeze(1)
178
+ target = target_raw.squeeze(1)
179
+ # gradient loss
180
+ scale, shift = compute_scale_and_shift(prediction, target, mask.squeeze(1))
181
+ prediction_ssi = scale.view(-1, 1, 1) * prediction + shift.view(-1, 1, 1)
182
+ reg_loss = self.__gradient_matching_term(prediction_ssi, target, mask.squeeze(1))
183
+ total = ssi_loss + self.__alpha * reg_loss
184
+
185
+ return total
model/depth/vit.py ADDED
@@ -0,0 +1,492 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modified from https://github.com/isl-org/DPT
2
+ import torch
3
+ import torch.nn as nn
4
+ import timm
5
+ import types
6
+ import math
7
+ import torch.nn.functional as F
8
+
9
+
10
+ class Slice(nn.Module):
11
+ def __init__(self, start_index=1):
12
+ super(Slice, self).__init__()
13
+ self.start_index = start_index
14
+
15
+ def forward(self, x):
16
+ return x[:, self.start_index :]
17
+
18
+
19
+ class AddReadout(nn.Module):
20
+ def __init__(self, start_index=1):
21
+ super(AddReadout, self).__init__()
22
+ self.start_index = start_index
23
+
24
+ def forward(self, x):
25
+ if self.start_index == 2:
26
+ readout = (x[:, 0] + x[:, 1]) / 2
27
+ else:
28
+ readout = x[:, 0]
29
+ return x[:, self.start_index :] + readout.unsqueeze(1)
30
+
31
+
32
+ class ProjectReadout(nn.Module):
33
+ def __init__(self, in_features, start_index=1):
34
+ super(ProjectReadout, self).__init__()
35
+ self.start_index = start_index
36
+
37
+ self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())
38
+
39
+ def forward(self, x):
40
+ readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :])
41
+ features = torch.cat((x[:, self.start_index :], readout), -1)
42
+
43
+ return self.project(features)
44
+
45
+
46
+ class Transpose(nn.Module):
47
+ def __init__(self, dim0, dim1):
48
+ super(Transpose, self).__init__()
49
+ self.dim0 = dim0
50
+ self.dim1 = dim1
51
+
52
+ def forward(self, x):
53
+ x = x.transpose(self.dim0, self.dim1).contiguous()
54
+ return x
55
+
56
+
57
+ def forward_vit(pretrained, x):
58
+ b, c, h, w = x.shape
59
+
60
+ glob = pretrained.model.forward_flex(x)
61
+
62
+ layer_1 = pretrained.activations["1"]
63
+ layer_2 = pretrained.activations["2"]
64
+ layer_3 = pretrained.activations["3"]
65
+ layer_4 = pretrained.activations["4"]
66
+
67
+ layer_1 = pretrained.act_postprocess1[0:2](layer_1)
68
+ layer_2 = pretrained.act_postprocess2[0:2](layer_2)
69
+ layer_3 = pretrained.act_postprocess3[0:2](layer_3)
70
+ layer_4 = pretrained.act_postprocess4[0:2](layer_4)
71
+
72
+ unflatten = nn.Sequential(
73
+ nn.Unflatten(
74
+ 2,
75
+ torch.Size(
76
+ [
77
+ h // pretrained.model.patch_size[1],
78
+ w // pretrained.model.patch_size[0],
79
+ ]
80
+ ),
81
+ )
82
+ )
83
+
84
+ if layer_1.ndim == 3:
85
+ layer_1 = unflatten(layer_1)
86
+ if layer_2.ndim == 3:
87
+ layer_2 = unflatten(layer_2)
88
+ if layer_3.ndim == 3:
89
+ layer_3 = unflatten(layer_3)
90
+ if layer_4.ndim == 3:
91
+ layer_4 = unflatten(layer_4)
92
+
93
+ layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1)
94
+ layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2)
95
+ layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3)
96
+ layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4)
97
+
98
+ return layer_1, layer_2, layer_3, layer_4
99
+
100
+
101
+ def _resize_pos_embed(self, posemb, gs_h, gs_w):
102
+ posemb_tok, posemb_grid = (
103
+ posemb[:, : self.start_index],
104
+ posemb[0, self.start_index :],
105
+ )
106
+
107
+ gs_old = int(math.sqrt(len(posemb_grid)))
108
+
109
+ posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2).contiguous()
110
+ posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear", align_corners=False)
111
+ posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1).contiguous()
112
+
113
+ posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
114
+
115
+ return posemb
116
+
117
+
118
+ def forward_flex(self, x):
119
+ b, c, h, w = x.shape
120
+
121
+ pos_embed = self._resize_pos_embed(
122
+ self.pos_embed, h // self.patch_size[1], w // self.patch_size[0]
123
+ )
124
+
125
+ B = x.shape[0]
126
+
127
+ if hasattr(self.patch_embed, "backbone"):
128
+ x = self.patch_embed.backbone(x)
129
+ if isinstance(x, (list, tuple)):
130
+ x = x[-1] # last feature if backbone outputs list/tuple of features
131
+
132
+ x = self.patch_embed.proj(x).flatten(2).transpose(1, 2).contiguous()
133
+
134
+ if getattr(self, "dist_token", None) is not None:
135
+ cls_tokens = self.cls_token.expand(
136
+ B, -1, -1
137
+ ) # stole cls_tokens impl from Phil Wang, thanks
138
+ dist_token = self.dist_token.expand(B, -1, -1)
139
+ x = torch.cat((cls_tokens, dist_token, x), dim=1)
140
+ else:
141
+ cls_tokens = self.cls_token.expand(
142
+ B, -1, -1
143
+ ) # stole cls_tokens impl from Phil Wang, thanks
144
+ x = torch.cat((cls_tokens, x), dim=1)
145
+
146
+ x = x + pos_embed
147
+ x = self.pos_drop(x)
148
+
149
+ for blk in self.blocks:
150
+ x = blk(x)
151
+
152
+ x = self.norm(x)
153
+
154
+ return x
155
+
156
+
157
+ activations = {}
158
+
159
+
160
+ def get_activation(name):
161
+ def hook(model, input, output):
162
+ activations[name] = output
163
+
164
+ return hook
165
+
166
+
167
+ def get_readout_oper(vit_features, features, use_readout, start_index=1):
168
+ if use_readout == "ignore":
169
+ readout_oper = [Slice(start_index)] * len(features)
170
+ elif use_readout == "add":
171
+ readout_oper = [AddReadout(start_index)] * len(features)
172
+ elif use_readout == "project":
173
+ readout_oper = [
174
+ ProjectReadout(vit_features, start_index) for out_feat in features
175
+ ]
176
+ else:
177
+ assert (
178
+ False
179
+ ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"
180
+
181
+ return readout_oper
182
+
183
+
184
+ def _make_vit_b16_backbone(
185
+ model,
186
+ features=[96, 192, 384, 768],
187
+ size=[384, 384],
188
+ hooks=[2, 5, 8, 11],
189
+ vit_features=768,
190
+ use_readout="ignore",
191
+ start_index=1,
192
+ ):
193
+ pretrained = nn.Module()
194
+
195
+ pretrained.model = model
196
+ pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
197
+ pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
198
+ pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
199
+ pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
200
+
201
+ pretrained.activations = activations
202
+
203
+ readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
204
+
205
+ # 32, 48, 136, 384
206
+ pretrained.act_postprocess1 = nn.Sequential(
207
+ readout_oper[0],
208
+ Transpose(1, 2),
209
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
210
+ nn.Conv2d(
211
+ in_channels=vit_features,
212
+ out_channels=features[0],
213
+ kernel_size=1,
214
+ stride=1,
215
+ padding=0,
216
+ ),
217
+ nn.ConvTranspose2d(
218
+ in_channels=features[0],
219
+ out_channels=features[0],
220
+ kernel_size=4,
221
+ stride=4,
222
+ padding=0,
223
+ bias=True,
224
+ dilation=1,
225
+ groups=1,
226
+ ),
227
+ )
228
+
229
+ pretrained.act_postprocess2 = nn.Sequential(
230
+ readout_oper[1],
231
+ Transpose(1, 2),
232
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
233
+ nn.Conv2d(
234
+ in_channels=vit_features,
235
+ out_channels=features[1],
236
+ kernel_size=1,
237
+ stride=1,
238
+ padding=0,
239
+ ),
240
+ nn.ConvTranspose2d(
241
+ in_channels=features[1],
242
+ out_channels=features[1],
243
+ kernel_size=2,
244
+ stride=2,
245
+ padding=0,
246
+ bias=True,
247
+ dilation=1,
248
+ groups=1,
249
+ ),
250
+ )
251
+
252
+ pretrained.act_postprocess3 = nn.Sequential(
253
+ readout_oper[2],
254
+ Transpose(1, 2),
255
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
256
+ nn.Conv2d(
257
+ in_channels=vit_features,
258
+ out_channels=features[2],
259
+ kernel_size=1,
260
+ stride=1,
261
+ padding=0,
262
+ ),
263
+ )
264
+
265
+ pretrained.act_postprocess4 = nn.Sequential(
266
+ readout_oper[3],
267
+ Transpose(1, 2),
268
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
269
+ nn.Conv2d(
270
+ in_channels=vit_features,
271
+ out_channels=features[3],
272
+ kernel_size=1,
273
+ stride=1,
274
+ padding=0,
275
+ ),
276
+ nn.Conv2d(
277
+ in_channels=features[3],
278
+ out_channels=features[3],
279
+ kernel_size=3,
280
+ stride=2,
281
+ padding=1,
282
+ ),
283
+ )
284
+
285
+ pretrained.model.start_index = start_index
286
+ pretrained.model.patch_size = [16, 16]
287
+
288
+ # We inject this function into the VisionTransformer instances so that
289
+ # we can use it with interpolated position embeddings without modifying the library source.
290
+ pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
291
+ pretrained.model._resize_pos_embed = types.MethodType(
292
+ _resize_pos_embed, pretrained.model
293
+ )
294
+
295
+ return pretrained
296
+
297
+
298
+ def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None):
299
+ model = timm.create_model("vit_large_patch16_384", pretrained=pretrained)
300
+
301
+ hooks = [5, 11, 17, 23] if hooks == None else hooks
302
+ return _make_vit_b16_backbone(
303
+ model,
304
+ features=[256, 512, 1024, 1024],
305
+ hooks=hooks,
306
+ vit_features=1024,
307
+ use_readout=use_readout,
308
+ )
309
+
310
+
311
+ def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None):
312
+ model = timm.create_model("vit_base_patch16_384", pretrained=pretrained)
313
+
314
+ hooks = [2, 5, 8, 11] if hooks == None else hooks
315
+ return _make_vit_b16_backbone(
316
+ model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
317
+ )
318
+
319
+
320
+ def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None):
321
+ model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained)
322
+
323
+ hooks = [2, 5, 8, 11] if hooks == None else hooks
324
+ return _make_vit_b16_backbone(
325
+ model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
326
+ )
327
+
328
+
329
+ def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None):
330
+ model = timm.create_model(
331
+ "vit_deit_base_distilled_patch16_384", pretrained=pretrained
332
+ )
333
+
334
+ hooks = [2, 5, 8, 11] if hooks == None else hooks
335
+ return _make_vit_b16_backbone(
336
+ model,
337
+ features=[96, 192, 384, 768],
338
+ hooks=hooks,
339
+ use_readout=use_readout,
340
+ start_index=2,
341
+ )
342
+
343
+
344
+ def _make_vit_b_rn50_backbone(
345
+ model,
346
+ features=[256, 512, 768, 768],
347
+ size=[384, 384],
348
+ hooks=[0, 1, 8, 11],
349
+ vit_features=768,
350
+ use_vit_only=False,
351
+ use_readout="ignore",
352
+ start_index=1,
353
+ ):
354
+ pretrained = nn.Module()
355
+
356
+ pretrained.model = model
357
+
358
+ if use_vit_only == True:
359
+ pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
360
+ pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
361
+ else:
362
+ pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(
363
+ get_activation("1")
364
+ )
365
+ pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(
366
+ get_activation("2")
367
+ )
368
+
369
+ pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
370
+ pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
371
+
372
+ pretrained.activations = activations
373
+
374
+ readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
375
+
376
+ if use_vit_only == True:
377
+ pretrained.act_postprocess1 = nn.Sequential(
378
+ readout_oper[0],
379
+ Transpose(1, 2),
380
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
381
+ nn.Conv2d(
382
+ in_channels=vit_features,
383
+ out_channels=features[0],
384
+ kernel_size=1,
385
+ stride=1,
386
+ padding=0,
387
+ ),
388
+ nn.ConvTranspose2d(
389
+ in_channels=features[0],
390
+ out_channels=features[0],
391
+ kernel_size=4,
392
+ stride=4,
393
+ padding=0,
394
+ bias=True,
395
+ dilation=1,
396
+ groups=1,
397
+ ),
398
+ )
399
+
400
+ pretrained.act_postprocess2 = nn.Sequential(
401
+ readout_oper[1],
402
+ Transpose(1, 2),
403
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
404
+ nn.Conv2d(
405
+ in_channels=vit_features,
406
+ out_channels=features[1],
407
+ kernel_size=1,
408
+ stride=1,
409
+ padding=0,
410
+ ),
411
+ nn.ConvTranspose2d(
412
+ in_channels=features[1],
413
+ out_channels=features[1],
414
+ kernel_size=2,
415
+ stride=2,
416
+ padding=0,
417
+ bias=True,
418
+ dilation=1,
419
+ groups=1,
420
+ ),
421
+ )
422
+ else:
423
+ pretrained.act_postprocess1 = nn.Sequential(
424
+ nn.Identity(), nn.Identity(), nn.Identity()
425
+ )
426
+ pretrained.act_postprocess2 = nn.Sequential(
427
+ nn.Identity(), nn.Identity(), nn.Identity()
428
+ )
429
+
430
+ pretrained.act_postprocess3 = nn.Sequential(
431
+ readout_oper[2],
432
+ Transpose(1, 2),
433
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
434
+ nn.Conv2d(
435
+ in_channels=vit_features,
436
+ out_channels=features[2],
437
+ kernel_size=1,
438
+ stride=1,
439
+ padding=0,
440
+ ),
441
+ )
442
+
443
+ pretrained.act_postprocess4 = nn.Sequential(
444
+ readout_oper[3],
445
+ Transpose(1, 2),
446
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
447
+ nn.Conv2d(
448
+ in_channels=vit_features,
449
+ out_channels=features[3],
450
+ kernel_size=1,
451
+ stride=1,
452
+ padding=0,
453
+ ),
454
+ nn.Conv2d(
455
+ in_channels=features[3],
456
+ out_channels=features[3],
457
+ kernel_size=3,
458
+ stride=2,
459
+ padding=1,
460
+ ),
461
+ )
462
+
463
+ pretrained.model.start_index = start_index
464
+ pretrained.model.patch_size = [16, 16]
465
+
466
+ # We inject this function into the VisionTransformer instances so that
467
+ # we can use it with interpolated position embeddings without modifying the library source.
468
+ pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
469
+
470
+ # We inject this function into the VisionTransformer instances so that
471
+ # we can use it with interpolated position embeddings without modifying the library source.
472
+ pretrained.model._resize_pos_embed = types.MethodType(
473
+ _resize_pos_embed, pretrained.model
474
+ )
475
+
476
+ return pretrained
477
+
478
+
479
+ def _make_pretrained_vitb_rn50_384(
480
+ pretrained, use_readout="ignore", hooks=None, use_vit_only=False
481
+ ):
482
+ model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)
483
+
484
+ hooks = [0, 1, 8, 11] if hooks == None else hooks
485
+ return _make_vit_b_rn50_backbone(
486
+ model,
487
+ features=[256, 512, 768, 768],
488
+ size=[384, 384],
489
+ hooks=hooks,
490
+ use_vit_only=use_vit_only,
491
+ use_readout=use_readout,
492
+ )
model/depth_engine.py ADDED
@@ -0,0 +1,445 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import os, time, datetime
3
+ import torch
4
+ import torch.utils.tensorboard
5
+ import importlib
6
+ import shutil
7
+ import utils.util as util
8
+ import utils.util_vis as util_vis
9
+
10
+ from torch.nn.parallel import DistributedDataParallel as DDP
11
+ from utils.util import print_eval, setup, cleanup
12
+ from utils.util import EasyDict as edict
13
+ from utils.eval_depth import DepthMetric
14
+ from copy import deepcopy
15
+ from model.compute_graph import graph_depth
16
+
17
+ # ============================ main engine for training and evaluation ============================
18
+
19
+ class Runner():
20
+
21
+ def __init__(self, opt):
22
+ super().__init__()
23
+ if os.path.isdir(opt.output_path) and opt.resume == False and opt.device == 0:
24
+ for filename in os.listdir(opt.output_path):
25
+ if "tfevents" in filename: os.remove(os.path.join(opt.output_path, filename))
26
+ if "html" in filename: os.remove(os.path.join(opt.output_path, filename))
27
+ if "vis" in filename: shutil.rmtree(os.path.join(opt.output_path, filename))
28
+ if "dump" in filename: shutil.rmtree(os.path.join(opt.output_path, filename))
29
+ if "embedding" in filename: shutil.rmtree(os.path.join(opt.output_path, filename))
30
+ if opt.device == 0:
31
+ os.makedirs(opt.output_path,exist_ok=True)
32
+ setup(opt.device, opt.world_size, opt.port)
33
+ opt.batch_size = opt.batch_size // opt.world_size
34
+
35
+ def get_viz_data(self, opt):
36
+ # get data for visualization
37
+ viz_data_list = []
38
+ sample_range = len(self.viz_loader)
39
+ viz_interval = sample_range // opt.eval.n_vis
40
+ for i in range(sample_range):
41
+ current_batch = next(self.viz_loader_iter)
42
+ if i % viz_interval != 0: continue
43
+ viz_data_list.append(current_batch)
44
+ return viz_data_list
45
+
46
+ def load_dataset(self, opt, eval_split="test"):
47
+ data_train = importlib.import_module('data.{}'.format(opt.data.dataset_train))
48
+ data_test = importlib.import_module('data.{}'.format(opt.data.dataset_test))
49
+ if opt.device == 0: print("loading training data...")
50
+ self.batch_order = []
51
+ self.train_data = data_train.Dataset(opt, split="train", load_3D=False)
52
+ self.train_loader = self.train_data.setup_loader(opt, shuffle=True, use_ddp=True, drop_last=True)
53
+ self.num_batches = len(self.train_loader)
54
+ if opt.device == 0: print("loading test data...")
55
+ self.test_data = data_test.Dataset(opt, split=eval_split, load_3D=False)
56
+ self.test_loader = self.test_data.setup_loader(opt, shuffle=False, use_ddp=True, drop_last=True, batch_size=opt.eval.batch_size)
57
+ self.num_batches_test = len(self.test_loader)
58
+ if len(self.test_loader.sampler) * opt.world_size < len(self.test_data):
59
+ self.aux_test_dataset = torch.utils.data.Subset(self.test_data,
60
+ range(len(self.test_loader.sampler) * opt.world_size, len(self.test_data)))
61
+ self.aux_test_loader = torch.utils.data.DataLoader(
62
+ self.aux_test_dataset, batch_size=opt.eval.batch_size, shuffle=False, drop_last=False,
63
+ num_workers=opt.data.num_workers)
64
+ if opt.device == 0:
65
+ print("creating data for visualization...")
66
+ self.viz_loader = self.test_data.setup_loader(opt, shuffle=False, use_ddp=False, drop_last=False, batch_size=1)
67
+ self.viz_loader_iter = iter(self.viz_loader)
68
+ self.viz_data = self.get_viz_data(opt)
69
+
70
+ def build_networks(self, opt):
71
+ if opt.device == 0: print("building networks...")
72
+ self.graph = DDP(graph_depth.Graph(opt).to(opt.device), device_ids=[opt.device], find_unused_parameters=True)
73
+ self.depth_metric = DepthMetric(thresholds=opt.eval.d_thresholds, depth_cap=opt.eval.depth_cap)
74
+
75
+ # =================================================== set up training =========================================================
76
+
77
+ def setup_optimizer(self, opt):
78
+ if opt.device == 0: print("setting up optimizers...")
79
+ param_nodecay = []
80
+ param_decay = []
81
+ for name, param in self.graph.named_parameters():
82
+ # skip and fixed params
83
+ if not param.requires_grad:
84
+ continue
85
+ if param.ndim <= 1 or name.endswith(".bias"):
86
+ # print("{} -> finetune_param_nodecay".format(name))
87
+ param_nodecay.append(param)
88
+ else:
89
+ param_decay.append(param)
90
+ # print("{} -> finetune_param_decay".format(name))
91
+ # create the optim dictionary
92
+ optim_dict = [
93
+ {'params': param_nodecay, 'lr': opt.optim.lr, 'weight_decay': 0.},
94
+ {'params': param_decay, 'lr': opt.optim.lr, 'weight_decay': opt.optim.weight_decay}
95
+ ]
96
+
97
+ self.optim = torch.optim.AdamW(optim_dict, betas=(0.9, 0.95))
98
+ if opt.optim.sched:
99
+ self.sched = torch.optim.lr_scheduler.CosineAnnealingLR(self.optim, opt.max_epoch)
100
+ if opt.optim.amp:
101
+ self.scaler = torch.cuda.amp.GradScaler()
102
+
103
+ def restore_checkpoint(self, opt, best=False, evaluate=False):
104
+ epoch_start, iter_start = None, None
105
+ if opt.resume:
106
+ if opt.device == 0: print("resuming from previous checkpoint...")
107
+ epoch_start, iter_start, best_val, best_ep = util.restore_checkpoint(opt, self, resume=opt.resume, best=best, evaluate=evaluate)
108
+ self.best_val = best_val
109
+ self.best_ep = best_ep
110
+ elif opt.load is not None:
111
+ if opt.device == 0: print("loading weights from checkpoint {}...".format(opt.load))
112
+ epoch_start, iter_start, best_val, best_ep = util.restore_checkpoint(opt, self, load_name=opt.load)
113
+ else:
114
+ if opt.device == 0: print("initializing weights from scratch...")
115
+ self.epoch_start = epoch_start or 0
116
+ self.iter_start = iter_start or 0
117
+
118
+ def setup_visualizer(self, opt, test=False):
119
+ if opt.device == 0:
120
+ print("setting up visualizers...")
121
+ if opt.tb:
122
+ self.tb = torch.utils.tensorboard.SummaryWriter(log_dir=opt.output_path, flush_secs=10)
123
+
124
+ def train(self, opt):
125
+ # before training
126
+ torch.cuda.set_device(opt.device)
127
+ torch.cuda.empty_cache()
128
+ if opt.device == 0: print("TRAINING START")
129
+ self.train_metric_logger = util.MetricLogger(delimiter=" ")
130
+ self.train_metric_logger.add_meter('lr', util.SmoothedValue(window_size=1, fmt='{value:.6f}'))
131
+ self.iter_skip = self.iter_start % len(self.train_loader)
132
+ self.it = self.iter_start
133
+ self.skip_dis = False
134
+ if not opt.resume:
135
+ self.best_val = np.inf
136
+ self.best_ep = 1
137
+ # training
138
+ if self.iter_start == 0 and not opt.debug: self.evaluate(opt, ep=0, training=True)
139
+ # if opt.device == 0: self.save_checkpoint(opt, ep=0, it=0, best_val=self.best_val, best_ep=self.best_ep)
140
+ self.ep = self.epoch_start
141
+ for self.ep in range(self.epoch_start, opt.max_epoch):
142
+ self.train_epoch(opt)
143
+ # after training
144
+ if opt.device == 0: self.save_checkpoint(opt, ep=self.ep, it=self.it, best_val=self.best_val, best_ep=self.best_ep)
145
+ if opt.tb and opt.device == 0:
146
+ self.tb.flush()
147
+ self.tb.close()
148
+ if opt.device == 0:
149
+ print("TRAINING DONE")
150
+ print("Best val: %.4f @ epoch %d" % (self.best_val, self.best_ep))
151
+ cleanup()
152
+
153
+ def train_epoch(self, opt):
154
+ # before train epoch
155
+ self.train_loader.sampler.set_epoch(self.ep)
156
+ if opt.device == 0:
157
+ print("training epoch {}".format(self.ep+1))
158
+ batch_progress = range(self.num_batches)
159
+ self.graph.train()
160
+ # train epoch
161
+ loader = iter(self.train_loader)
162
+
163
+ for batch_id in batch_progress:
164
+ # if resuming from previous checkpoint, skip until the last iteration number is reached
165
+ if self.iter_skip>0:
166
+ self.iter_skip -= 1
167
+ continue
168
+ batch = next(loader)
169
+ # train iteration
170
+ var = edict(batch)
171
+ opt.H, opt.W = opt.image_size
172
+ var = util.move_to_device(var, opt.device)
173
+ loss = self.train_iteration(opt, var, batch_progress)
174
+
175
+ # after train epoch
176
+ lr = self.sched.get_last_lr()[0] if opt.optim.sched else opt.optim.lr
177
+ if opt.optim.sched: self.sched.step()
178
+ if (self.ep + 1) % opt.freq.eval == 0:
179
+ if opt.device == 0: print("validating epoch {}".format(self.ep+1))
180
+ current_val = self.evaluate(opt, ep=self.ep+1, training=True)
181
+ if current_val < self.best_val and opt.device == 0:
182
+ self.best_val = current_val
183
+ self.best_ep = self.ep + 1
184
+ self.save_checkpoint(opt, ep=self.ep, it=self.it, best_val=self.best_val, best_ep=self.best_ep, best=True, latest=True)
185
+
186
+ def train_iteration(self, opt, var, loader):
187
+ # before train iteration
188
+ torch.distributed.barrier()
189
+
190
+ # train iteration
191
+ with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=opt.optim.amp):
192
+ var, loss = self.graph.forward(opt, var, training=True, get_loss=True)
193
+ loss = self.summarize_loss(opt, var, loss)
194
+ loss_scaled = loss.all / opt.optim.accum
195
+
196
+ # backward
197
+ if opt.optim.amp:
198
+ self.scaler.scale(loss_scaled).backward()
199
+ # skip update if accumulating gradient
200
+ if (self.it + 1) % opt.optim.accum == 0:
201
+ self.scaler.unscale_(self.optim)
202
+ # gradient clipping
203
+ if opt.optim.clip_norm:
204
+ norm = torch.nn.utils.clip_grad_norm_(self.graph.parameters(), opt.optim.clip_norm)
205
+ if opt.debug: print("Grad norm: {}".format(norm))
206
+ self.scaler.step(self.optim)
207
+ self.scaler.update()
208
+ self.optim.zero_grad()
209
+ else:
210
+ loss_scaled.backward()
211
+ if (self.it + 1) % opt.optim.accum == 0:
212
+ if opt.optim.clip_norm:
213
+ norm = torch.nn.utils.clip_grad_norm_(self.graph.parameters(), opt.optim.clip_norm)
214
+ if opt.debug: print("Grad norm: {}".format(norm))
215
+ self.optim.step()
216
+ self.optim.zero_grad()
217
+
218
+ # after train iteration
219
+ lr = self.sched.get_last_lr()[0] if opt.optim.sched else opt.optim.lr
220
+ self.train_metric_logger.update(lr=lr)
221
+ self.train_metric_logger.update(loss=loss.all)
222
+ if opt.device == 0:
223
+ if (self.it) % opt.freq.vis == 0 and not opt.debug:
224
+ self.visualize(opt, var, step=self.it, split="train")
225
+ if (self.it+1) % opt.freq.ckpt_latest == 0 and not opt.debug:
226
+ self.save_checkpoint(opt, ep=self.ep, it=self.it+1, best_val=self.best_val, best_ep=self.best_ep, latest=True)
227
+ if (self.it) % opt.freq.scalar == 0 and not opt.debug:
228
+ self.log_scalars(opt, var, loss, step=self.it, split="train")
229
+ if (self.it) % (opt.freq.save_vis * (self.it//10000*10+1)) == 0 and not opt.debug:
230
+ self.vis_train_iter(opt)
231
+ if (self.it) % opt.freq.print == 0:
232
+ print('[{}] '.format(datetime.datetime.now().time()), end='')
233
+ print(f'Train Iter {self.it}/{self.num_batches*opt.max_epoch}: {self.train_metric_logger}')
234
+ self.it += 1
235
+ return loss
236
+
237
+ @torch.no_grad()
238
+ def vis_train_iter(self, opt):
239
+ self.graph.eval()
240
+ for i in range(len(self.viz_data)):
241
+ var_viz = edict(deepcopy(self.viz_data[i]))
242
+ var_viz = util.move_to_device(var_viz, opt.device)
243
+ var_viz = self.graph.module(opt, var_viz, training=False, get_loss=False)
244
+ vis_folder = "vis_log/iter_{}".format(self.it)
245
+ os.makedirs("{}/{}".format(opt.output_path, vis_folder), exist_ok=True)
246
+ util_vis.dump_images(opt, var_viz.idx, "image_input", var_viz.rgb_input_map, masks=None, from_range=(0, 1), folder=vis_folder)
247
+ util_vis.dump_images(opt, var_viz.idx, "mask_input", var_viz.mask_input_map, folder=vis_folder)
248
+ util_vis.dump_depths(opt, var_viz.idx, "depth_est", var_viz.depth_pred, var_viz.mask_input_map, rescale=True, folder=vis_folder)
249
+ util_vis.dump_depths(opt, var_viz.idx, "depth_input", var_viz.depth_input_map, var_viz.mask_input_map, rescale=True, folder=vis_folder)
250
+ if 'seen_points_pred' in var_viz and 'seen_points_gt' in var_viz:
251
+ util_vis.dump_pointclouds_compare(opt, var_viz.idx, "seen_surface", var_viz.seen_points_pred, var_viz.seen_points_gt, folder=vis_folder)
252
+ self.graph.train()
253
+
254
+ def summarize_loss(self, opt, var, loss, non_act_loss_key=[]):
255
+ loss_all = 0.
256
+ assert("all" not in loss)
257
+ # weigh losses
258
+ for key in loss:
259
+ assert(key in opt.loss_weight)
260
+ if opt.loss_weight[key] is not None:
261
+ assert not torch.isinf(loss[key].mean()), "loss {} is Inf".format(key)
262
+ assert not torch.isnan(loss[key].mean()), "loss {} is NaN".format(key)
263
+ loss_all += float(opt.loss_weight[key])*loss[key].mean() if key not in non_act_loss_key else 0.0*loss[key].mean()
264
+ loss.update(all=loss_all)
265
+ return loss
266
+
267
+ # =================================================== set up evaluation =========================================================
268
+
269
+ @torch.no_grad()
270
+ def evaluate(self, opt, ep, training=False):
271
+ self.graph.eval()
272
+ loss_eval = edict()
273
+
274
+ # metric dictionary
275
+ metric_eval = {}
276
+ for metric_key in self.depth_metric.metric_keys:
277
+ metric_eval[metric_key] = []
278
+ metric_avg = {}
279
+ eval_metric_logger = util.MetricLogger(delimiter=" ")
280
+
281
+ # dataloader on the test set
282
+ with torch.cuda.device(opt.device):
283
+ for it, batch in enumerate(self.test_loader):
284
+
285
+ # inference the model
286
+ var = edict(batch)
287
+ var = self.evaluate_batch(opt, var, ep, it, single_gpu=False)
288
+
289
+ # record foreground mae for evaluation
290
+ sample_metrics, var.depth_pred_aligned = self.depth_metric.compute_metrics(
291
+ var.depth_pred, var.depth_input_map, var.mask_eroded if 'mask_eroded' in var else var.mask_input_map)
292
+ var.rmse = sample_metrics['rmse']
293
+ curr_metrics = {}
294
+ for metric_key in metric_eval:
295
+ metric_eval[metric_key].append(sample_metrics[metric_key])
296
+ curr_metrics[metric_key] = sample_metrics[metric_key].mean()
297
+ eval_metric_logger.update(**curr_metrics)
298
+ # eval_metric_logger.update(metric_key=sample_metrics[metric_key].mean())
299
+
300
+ # accumulate the scores
301
+ if opt.device == 0 and it % opt.freq.print_eval == 0:
302
+ print('[{}] '.format(datetime.datetime.now().time()), end='')
303
+ print(f'Eval Iter {it}/{len(self.test_loader)} @ EP {ep}: {eval_metric_logger}')
304
+
305
+ # dump the result if in eval mode
306
+ if not training:
307
+ self.dump_results(opt, var, ep, write_new=(it == 0))
308
+
309
+ # save the visualization
310
+ if it == 0 and training and opt.device == 0:
311
+ print("visualizing and saving results...")
312
+ for i in range(len(self.viz_data)):
313
+ var_viz = edict(deepcopy(self.viz_data[i]))
314
+ var_viz = self.evaluate_batch(opt, var_viz, ep, it, single_gpu=True)
315
+ self.visualize(opt, var_viz, step=ep, split="eval")
316
+ self.dump_results(opt, var_viz, ep, train=True)
317
+
318
+ # collect the eval results into tensors
319
+ for metric_key in metric_eval:
320
+ metric_eval[metric_key] = torch.cat(metric_eval[metric_key], dim=0)
321
+
322
+ if opt.world_size > 1:
323
+ metric_gather_dict = {}
324
+ # empty tensors for gathering
325
+ for metric_key in metric_eval:
326
+ metric_gather_dict[metric_key] = [torch.zeros_like(metric_eval[metric_key]).to(opt.device) for _ in range(opt.world_size)]
327
+
328
+ # gather the metrics
329
+ torch.distributed.barrier()
330
+ for metric_key in metric_eval:
331
+ torch.distributed.all_gather(metric_gather_dict[metric_key], metric_eval[metric_key])
332
+ metric_gather_dict[metric_key] = torch.cat(metric_gather_dict[metric_key], dim=0)
333
+ else:
334
+ metric_gather_dict = metric_eval
335
+
336
+ # handle last batch, if any
337
+ if len(self.test_loader.sampler) * opt.world_size < len(self.test_data):
338
+ for metric_key in metric_eval:
339
+ metric_gather_dict[metric_key] = [metric_gather_dict[metric_key]]
340
+ for batch in self.aux_test_loader:
341
+ # inference the model
342
+ var = edict(batch)
343
+ var = self.evaluate_batch(opt, var, ep, it, single_gpu=False)
344
+
345
+ # record MAE for evaluation
346
+ sample_metrics, var.depth_pred_aligned = self.depth_metric.compute_metrics(
347
+ var.depth_pred, var.depth_input_map, var.mask_eroded if 'mask_eroded' in var else var.mask_input_map)
348
+ var.rmse = sample_metrics['rmse']
349
+ for metric_key in metric_eval:
350
+ metric_gather_dict[metric_key].append(sample_metrics[metric_key])
351
+
352
+ # dump the result if in eval mode
353
+ if not training and opt.device == 0:
354
+ self.dump_results(opt, var, ep, write_new=(it == 0))
355
+
356
+ for metric_key in metric_eval:
357
+ metric_gather_dict[metric_key] = torch.cat(metric_gather_dict[metric_key], dim=0)
358
+
359
+ assert metric_gather_dict['l1_err'].shape[0] == len(self.test_data)
360
+ # compute the mean of the metrics
361
+ for metric_key in metric_eval:
362
+ metric_avg[metric_key] = metric_gather_dict[metric_key].mean()
363
+
364
+ # printout and save the metrics
365
+ if opt.device == 0:
366
+ # print eval info
367
+ print_eval(opt, depth_metrics=metric_avg)
368
+ val_metric = metric_avg['l1_err']
369
+
370
+ if training:
371
+ # log/visualize results to tb/vis
372
+ self.log_scalars(opt, var, loss_eval, metric=metric_avg, step=ep, split="eval")
373
+
374
+ if not training:
375
+ # write to file
376
+ metrics_file = os.path.join(opt.output_path, 'best_val.txt')
377
+ with open(metrics_file, "w") as outfile:
378
+ for metric_key in metric_avg:
379
+ outfile.write('{}: {:.6f}\n'.format(metric_key, metric_avg[metric_key].item()))
380
+
381
+ return val_metric.item()
382
+ return float('inf')
383
+
384
+ def evaluate_batch(self, opt, var, ep=None, it=None, single_gpu=False):
385
+ var = util.move_to_device(var, opt.device)
386
+ if single_gpu:
387
+ var = self.graph.module(opt, var, training=False, get_loss=False)
388
+ else:
389
+ var = self.graph(opt, var, training=False, get_loss=False)
390
+ return var
391
+
392
+ @torch.no_grad()
393
+ def log_scalars(self, opt, var, loss, metric=None, step=0, split="train"):
394
+ if split=="train":
395
+ sample_metrics, _ = self.depth_metric.compute_metrics(
396
+ var.depth_pred, var.depth_input_map, var.mask_eroded if 'mask_eroded' in var else var.mask_input_map)
397
+ metric = dict(L1_ERR=sample_metrics['l1_err'].mean().item())
398
+ for key, value in loss.items():
399
+ if key=="all": continue
400
+ self.tb.add_scalar("{0}/loss_{1}".format(split, key), value.mean(), step)
401
+ if metric is not None:
402
+ for key, value in metric.items():
403
+ self.tb.add_scalar("{0}/{1}".format(split, key), value, step)
404
+
405
+ @torch.no_grad()
406
+ def visualize(self, opt, var, step=0, split="train"):
407
+ pass
408
+
409
+ @torch.no_grad()
410
+ def dump_results(self, opt, var, ep, write_new=False, train=False):
411
+ # create the dir
412
+ current_folder = "dump" if train == False else "vis_{}".format(ep)
413
+ os.makedirs("{}/{}/".format(opt.output_path, current_folder), exist_ok=True)
414
+
415
+ # save the results
416
+ util_vis.dump_images(opt, var.idx, "image_input", var.rgb_input_map, masks=None, from_range=(0, 1), folder=current_folder)
417
+ util_vis.dump_images(opt, var.idx, "mask_input", var.mask_input_map, folder=current_folder)
418
+ util_vis.dump_depths(opt, var.idx, "depth_pred", var.depth_pred, var.mask_input_map, rescale=True, folder=current_folder)
419
+ util_vis.dump_depths(opt, var.idx, "depth_input", var.depth_input_map, var.mask_input_map, rescale=True, folder=current_folder)
420
+ if 'seen_points_pred' in var and 'seen_points_gt' in var:
421
+ util_vis.dump_pointclouds_compare(opt, var.idx, "seen_surface", var.seen_points_pred, var.seen_points_gt, folder=current_folder)
422
+
423
+ if "depth_pred_aligned" in var:
424
+ # get the max and min for the depth map
425
+ batch_size = var.depth_input_map.shape[0]
426
+ mask = var.mask_eroded if 'mask_eroded' in var else var.mask_input_map
427
+ masked_depth_far_bg = var.depth_input_map * mask + (1 - mask) * 1000
428
+ depth_min_gt = masked_depth_far_bg.view(batch_size, -1).min(dim=1)[0]
429
+ masked_depth_invalid_bg = var.depth_input_map * mask + (1 - mask) * 0
430
+ depth_max_gt = masked_depth_invalid_bg.view(batch_size, -1).max(dim=1)[0]
431
+ depth_vis_pred = (var.depth_pred_aligned - depth_min_gt.view(batch_size, 1, 1, 1)) / (depth_max_gt - depth_min_gt).view(batch_size, 1, 1, 1)
432
+ depth_vis_pred = depth_vis_pred * mask + (1 - mask)
433
+ depth_vis_gt = (var.depth_input_map - depth_min_gt.view(batch_size, 1, 1, 1)) / (depth_max_gt - depth_min_gt).view(batch_size, 1, 1, 1)
434
+ depth_vis_gt = depth_vis_gt * mask + (1 - mask)
435
+ util_vis.dump_depths(opt, var.idx, "depth_gt_aligned", depth_vis_gt.clamp(max=1, min=0), None, rescale=False, folder=current_folder)
436
+ util_vis.dump_depths(opt, var.idx, "depth_pred_aligned", depth_vis_pred.clamp(max=1, min=0), None, rescale=False, folder=current_folder)
437
+ if "mask_eroded" in var and "rmse" in var:
438
+ util_vis.dump_images(opt, var.idx, "image_eroded", var.rgb_input_map, masks=var.mask_eroded, metrics=var.rmse, from_range=(0, 1), folder=current_folder)
439
+
440
+ def save_checkpoint(self, opt, ep=0, it=0, best_val=np.inf, best_ep=1, latest=False, best=False):
441
+ util.save_checkpoint(opt, self, ep=ep, it=it, best_val=best_val, best_ep=best_ep, latest=latest, best=best)
442
+ if not latest:
443
+ print("checkpoint saved: ({0}) {1}, epoch {2} (iteration {3})".format(opt.group, opt.name, ep, it))
444
+ if best:
445
+ print("Saving the current model as the best...")
model/shape/implicit.py ADDED
@@ -0,0 +1,288 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+ from functools import partial
6
+ from utils.layers import get_embedder
7
+ from utils.layers import LayerScale
8
+ from timm.models.vision_transformer import Mlp, DropPath
9
+ from utils.pos_embed import get_2d_sincos_pos_embed
10
+
11
+ class ImplFuncAttention(nn.Module):
12
+ def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., last_layer=False):
13
+ super().__init__()
14
+ assert dim % num_heads == 0, 'dim should be divisible by num_heads'
15
+ self.num_heads = num_heads
16
+ head_dim = dim // num_heads
17
+ self.scale = head_dim ** -0.5
18
+
19
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
20
+ self.attn_drop = nn.Dropout(attn_drop)
21
+ self.proj = nn.Linear(dim, dim)
22
+ self.proj_drop = nn.Dropout(proj_drop)
23
+ self.last_layer = last_layer
24
+
25
+ def forward(self, x, N_points):
26
+
27
+ B, N, C = x.shape
28
+ N_latent = N - N_points
29
+ # [3, B, num_heads, N, C/num_heads]
30
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
31
+ # [B, num_heads, N, C/num_heads]
32
+ q, k, v = qkv.unbind(0)
33
+ # [B, num_heads, N_latent, C/num_heads]
34
+ q_latent, k_latent, v_latent = q[:, :, :-N_points], k[:, :, :-N_points], v[:, :, :-N_points]
35
+ # [B, num_heads, N_points, C/num_heads]
36
+ q_points, k_points, v_points = q[:, :, -N_points:], k[:, :, -N_points:], v[:, :, -N_points:]
37
+
38
+ # attention weight for each point, it's only connected to the latent and itself
39
+ # [B, num_heads, N_points, N_latent+1]
40
+ # get the cross attention, [B, num_heads, N_points, N_latent]
41
+ attn_cross = (q_points @ k_latent.transpose(-2, -1)) * self.scale
42
+ # get the attention to self feature, [B, num_heads, N_points, 1]
43
+ attn_self = torch.sum(q_points * k_points, dim=-1, keepdim=True) * self.scale
44
+ # get the normalized attention, [B, num_heads, N_points, N_latent+1]
45
+ attn_joint = torch.cat([attn_cross, attn_self], dim=-1)
46
+ attn_joint = attn_joint.softmax(dim=-1)
47
+ attn_joint = self.attn_drop(attn_joint)
48
+
49
+ # break it down to weigh and sum the values
50
+ # [B, num_heads, N_points, N_latent] @ [B, num_heads, N_latent, C/num_heads]
51
+ # -> [B, num_heads, N_points, C/num_heads] -> [B, N_points, C]
52
+ sum_cross = (attn_joint[:, :, :, :N_latent] @ v_latent).transpose(1, 2).reshape(B, N_points, C)
53
+ # [B, num_heads, N_points, 1] * [B, num_heads, N_points, C/num_heads]
54
+ # -> [B, num_heads, N_points, C/num_heads] -> [B, N_points, C]
55
+ sum_self = (attn_joint[:, :, :, N_latent:] * v_points).transpose(1, 2).reshape(B, N_points, C)
56
+ # [B, N_points, C]
57
+ output_points = sum_cross + sum_self
58
+
59
+ if self.last_layer:
60
+ output = self.proj(output_points)
61
+ output = self.proj_drop(output)
62
+ # [B, N_points, C], [B, N_points, N_latent]
63
+ return output, attn_joint[..., :-1].mean(dim=1)
64
+
65
+ # attention weight for the latent vec, it's not connected to the points
66
+ # [B, num_heads, N_latent, N_latent]
67
+ attn_latent = (q_latent @ k_latent.transpose(-2, -1)) * self.scale
68
+ attn_latent = attn_latent.softmax(dim=-1)
69
+ attn_latent = self.attn_drop(attn_latent)
70
+ # get the output latent, [B, N_latent, C]
71
+ output_latent = (attn_latent @ v_latent).transpose(1, 2).reshape(B, N_latent, C)
72
+
73
+ # concatenate the output and return
74
+ output = torch.cat([output_latent, output_points], dim=1)
75
+ output = self.proj(output)
76
+ output = self.proj_drop(output)
77
+
78
+ # [B, N, C], [B, N_points, N_latent+1]
79
+ return output, attn_joint[..., :-1].mean(dim=1)
80
+
81
+ class ImplFuncBlock(nn.Module):
82
+
83
+ def __init__(
84
+ self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None,
85
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, last_layer=False):
86
+ super().__init__()
87
+ self.last_layer = last_layer
88
+ self.norm1 = norm_layer(dim)
89
+ self.attn = ImplFuncAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, last_layer=last_layer)
90
+ self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
91
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
92
+
93
+ self.norm2 = norm_layer(dim)
94
+ mlp_hidden_dim = int(dim * mlp_ratio)
95
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
96
+ self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
97
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
98
+
99
+ def forward(self, x, unseen_size):
100
+ if self.last_layer:
101
+ attn_out, attn_vis = self.attn(self.norm1(x), unseen_size)
102
+ output = x[:, -unseen_size:] + self.drop_path1(self.ls1(attn_out))
103
+ output = output + self.drop_path2(self.ls2(self.mlp(self.norm2(output))))
104
+ return output, attn_vis
105
+ else:
106
+ attn_out, attn_vis = self.attn(self.norm1(x), unseen_size)
107
+ x = x + self.drop_path1(self.ls1(attn_out))
108
+ x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
109
+ return x, attn_vis
110
+
111
+ class LinearProj3D(nn.Module):
112
+ """
113
+ Linear projection of 3D point into embedding space
114
+ """
115
+ def __init__(self, embed_dim, posenc_res=0):
116
+ super().__init__()
117
+ self.embed_dim = embed_dim
118
+
119
+ # define positional embedder
120
+ self.embed_fn = None
121
+ input_ch = 3
122
+ if posenc_res > 0:
123
+ self.embed_fn, input_ch = get_embedder(posenc_res, input_dims=3)
124
+
125
+ # linear proj layer
126
+ self.proj = nn.Linear(input_ch, embed_dim)
127
+
128
+ def forward(self, points_3D):
129
+ if self.embed_fn is not None:
130
+ points_3D = self.embed_fn(points_3D)
131
+ return self.proj(points_3D)
132
+
133
+ class MLPBlocks(nn.Module):
134
+ def __init__(self, num_hidden_layers, n_channels, latent_dim,
135
+ skip_in=[], posenc_res=0):
136
+ super().__init__()
137
+
138
+ # projection to the same number of channels
139
+ self.dims = [3 + latent_dim] + [n_channels] * num_hidden_layers + [1]
140
+ self.num_layers = len(self.dims)
141
+ self.skip_in = skip_in
142
+
143
+ # define positional embedder
144
+ self.embed_fn = None
145
+ if posenc_res > 0:
146
+ embed_fn, input_ch = get_embedder(posenc_res, input_dims=3)
147
+ self.embed_fn = embed_fn
148
+ self.dims[0] += (input_ch - 3)
149
+
150
+ self.layers = nn.ModuleList([])
151
+
152
+ for l in range(0, self.num_layers - 1):
153
+ out_dim = self.dims[l + 1]
154
+ if l in self.skip_in:
155
+ in_dim = self.dims[l] + self.dims[0]
156
+ else:
157
+ in_dim = self.dims[l]
158
+
159
+ lin = nn.Linear(in_dim, out_dim)
160
+ self.layers.append(lin)
161
+
162
+ # register for param init
163
+ self.posenc_res = posenc_res
164
+
165
+ # activation
166
+ self.softplus = nn.Softplus(beta=100)
167
+
168
+ def forward(self, points, proj_latent):
169
+
170
+ # positional encoding
171
+ if self.embed_fn is not None:
172
+ points = self.embed_fn(points)
173
+
174
+ # forward by layer
175
+ # [B, N, posenc+C]
176
+ inputs = torch.cat([points, proj_latent], dim=-1)
177
+ x = inputs
178
+ for l in range(0, self.num_layers - 1):
179
+ if l in self.skip_in:
180
+ x = torch.cat([x, inputs], -1) / np.sqrt(2)
181
+ x = self.layers[l](x)
182
+ if l < self.num_layers - 2:
183
+ x = self.softplus(x)
184
+ return x
185
+
186
+ class Implicit(nn.Module):
187
+ """
188
+ Implicit function conditioned on depth encodings
189
+ """
190
+ def __init__(self,
191
+ num_patches, latent_dim=768, semantic=False, n_channels=512,
192
+ n_blocks_attn=2, n_layers_mlp=6, num_heads=16, posenc_3D=0,
193
+ mlp_ratio=4., norm_layer=partial(nn.LayerNorm, eps=1e-6), drop_path=0.1,
194
+ skip_in=[], pos_perlayer=True):
195
+ super().__init__()
196
+ self.num_patches = num_patches
197
+ self.pos_perlayer = pos_perlayer
198
+ self.semantic = semantic
199
+
200
+ # projection to the same number of channels, no posenc
201
+ self.point_proj = LinearProj3D(n_channels)
202
+ self.latent_proj = nn.Linear(latent_dim, n_channels, bias=True)
203
+
204
+ # positional embedding for the depth latent codes
205
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, n_channels), requires_grad=False) # fixed sin-cos embedding
206
+
207
+ # multi-head attention blocks
208
+ self.blocks_attn = nn.ModuleList([
209
+ ImplFuncBlock(
210
+ n_channels, num_heads, mlp_ratio,
211
+ qkv_bias=True, norm_layer=norm_layer, drop_path=drop_path
212
+ ) for _ in range(n_blocks_attn-1)])
213
+ self.blocks_attn.append(
214
+ ImplFuncBlock(
215
+ n_channels, num_heads, mlp_ratio,
216
+ qkv_bias=True, norm_layer=norm_layer, drop_path=drop_path, last_layer=True
217
+ )
218
+ )
219
+ self.norm = norm_layer(n_channels)
220
+
221
+ self.impl_mlp = None
222
+ # define the impl MLP
223
+ if n_layers_mlp > 0:
224
+ self.impl_mlp = MLPBlocks(n_layers_mlp, n_channels, n_channels,
225
+ skip_in=skip_in, posenc_res=posenc_3D)
226
+ else:
227
+ # occ and color prediction
228
+ self.pred_head = nn.Linear(n_channels, 1, bias=True)
229
+
230
+ self.initialize_weights()
231
+
232
+ def initialize_weights(self):
233
+
234
+ # initialize the positional embedding for the depth latent codes
235
+ pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.num_patches**.5), cls_token=True)
236
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
237
+
238
+ # initialize nn.Linear and nn.LayerNorm
239
+ self.apply(self._init_weights)
240
+
241
+ def _init_weights(self, m):
242
+ if isinstance(m, nn.Linear):
243
+ # we use xavier_uniform following official JAX ViT:
244
+ torch.nn.init.xavier_uniform_(m.weight)
245
+ if isinstance(m, nn.Linear) and m.bias is not None:
246
+ nn.init.constant_(m.bias, 0)
247
+ elif isinstance(m, nn.LayerNorm):
248
+ nn.init.constant_(m.bias, 0)
249
+ nn.init.constant_(m.weight, 1.0)
250
+
251
+ def forward(self, latent_depth, latent_semantic, points_3D):
252
+ # concatenate latent codes if semantic is used
253
+ latent = torch.cat([latent_depth, latent_semantic], dim=-1) if self.semantic else latent_depth
254
+
255
+ # project latent code and add posenc
256
+ # [B, 1+n_patches, C]
257
+ latent = self.latent_proj(latent)
258
+ N_latent = latent.shape[1]
259
+
260
+ # project query points
261
+ # [B, n_points, C_dec]
262
+ points_feat = self.point_proj(points_3D)
263
+
264
+ # concat point feat with latent
265
+ # [B, 1+n_patches+n_points, C_dec]
266
+ output = torch.cat([latent, points_feat], dim=1)
267
+
268
+ # apply multi-head attention blocks
269
+ attn_vis = []
270
+ for l, blk in enumerate(self.blocks_attn):
271
+ if self.pos_perlayer or l == 0:
272
+ output[:, :N_latent] = output[:, :N_latent] + self.pos_embed
273
+ output, attn = blk(output, points_feat.shape[1])
274
+ attn_vis.append(attn)
275
+ output = self.norm(output)
276
+ # average of attention weights across layers, [B, N_points, N_latent+1]
277
+ attn_vis = torch.stack(attn_vis, dim=-1).mean(dim=-1)
278
+
279
+ if self.impl_mlp:
280
+ # apply mlp blocks
281
+ output = self.impl_mlp(points_3D, output)
282
+ else:
283
+ # predictor projection
284
+ # [B, n_points, 1]
285
+ output = self.pred_head(output)
286
+
287
+ # return the occ logit of shape [B, n_points] and the attention weights if needed
288
+ return output.squeeze(-1), attn_vis
model/shape/rgb_enc.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This source code is written based on https://github.com/facebookresearch/MCC
2
+ # The original code base is licensed under the license found in the LICENSE file in the root directory.
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torchvision
7
+
8
+ from functools import partial
9
+ from timm.models.vision_transformer import Block, PatchEmbed
10
+ from utils.pos_embed import get_2d_sincos_pos_embed
11
+ from utils.layers import Bottleneck_Conv
12
+
13
+ class RGBEncAtt(nn.Module):
14
+ """
15
+ Seen surface encoder based on transformer.
16
+ """
17
+ def __init__(self,
18
+ img_size=224, embed_dim=768, n_blocks=12, num_heads=12, win_size=16,
19
+ mlp_ratio=4., norm_layer=partial(nn.LayerNorm, eps=1e-6), drop_path=0.1):
20
+ super().__init__()
21
+
22
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
23
+ self.rgb_embed = PatchEmbed(img_size, win_size, 3, embed_dim)
24
+
25
+ num_patches = self.rgb_embed.num_patches
26
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False)
27
+
28
+ self.blocks = nn.ModuleList([
29
+ Block(
30
+ embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer,
31
+ drop_path=drop_path
32
+ ) for _ in range(n_blocks)])
33
+
34
+ self.norm = norm_layer(embed_dim)
35
+
36
+ self.initialize_weights()
37
+
38
+ def initialize_weights(self):
39
+ # initialize the pos enc with fixed cos-sin pattern
40
+ pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.rgb_embed.num_patches**.5), cls_token=True)
41
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
42
+
43
+ # initialize rgb patch_embed like nn.Linear (instead of nn.Conv2d)
44
+ w = self.rgb_embed.proj.weight.data
45
+ torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
46
+
47
+ # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
48
+ torch.nn.init.normal_(self.cls_token, std=.02)
49
+
50
+ # initialize nn.Linear and nn.LayerNorm
51
+ self.apply(self._init_weights)
52
+
53
+ def _init_weights(self, m):
54
+ if isinstance(m, nn.Linear):
55
+ # we use xavier_uniform following official JAX ViT:
56
+ torch.nn.init.xavier_uniform_(m.weight)
57
+ if isinstance(m, nn.Linear) and m.bias is not None:
58
+ nn.init.constant_(m.bias, 0)
59
+ elif isinstance(m, nn.LayerNorm):
60
+ nn.init.constant_(m.bias, 0)
61
+ nn.init.constant_(m.weight, 1.0)
62
+
63
+ def forward(self, rgb_obj):
64
+
65
+ # [B, H/ws*W/ws, C]
66
+ rgb_embedding = self.rgb_embed(rgb_obj)
67
+ rgb_embedding = rgb_embedding + self.pos_embed[:, 1:, :]
68
+
69
+ # append cls token
70
+ # [1, 1, C]
71
+ cls_token = self.cls_token + self.pos_embed[:, :1, :]
72
+ # [B, 1, C]
73
+ cls_tokens = cls_token.expand(rgb_embedding.shape[0], -1, -1)
74
+
75
+ # [B, H/ws*W/ws+1, C]
76
+ rgb_embedding = torch.cat((cls_tokens, rgb_embedding), dim=1)
77
+
78
+ # apply Transformer blocks
79
+ for blk in self.blocks:
80
+ rgb_embedding = blk(rgb_embedding)
81
+ rgb_embedding = self.norm(rgb_embedding)
82
+
83
+ # [B, H/ws*W/ws+1, C]
84
+ return rgb_embedding
85
+
86
+ class RGBEncRes(nn.Module):
87
+ """
88
+ RGB encoder based on resnet.
89
+ """
90
+ def __init__(self, opt):
91
+ super().__init__()
92
+
93
+ self.encoder = torchvision.models.resnet50(pretrained=True)
94
+ self.encoder.fc = nn.Sequential(
95
+ Bottleneck_Conv(2048),
96
+ Bottleneck_Conv(2048),
97
+ nn.Linear(2048, opt.arch.latent_dim)
98
+ )
99
+
100
+ # define hooks
101
+ self.rgb_feature = None
102
+ def feature_hook(model, input, output):
103
+ self.rgb_feature = output
104
+
105
+ # attach hooks
106
+ if (opt.arch.win_size) == 16:
107
+ self.encoder.layer3.register_forward_hook(feature_hook)
108
+ self.rgb_feat_proj = nn.Sequential(
109
+ Bottleneck_Conv(1024),
110
+ Bottleneck_Conv(1024),
111
+ nn.Conv2d(1024, opt.arch.latent_dim, 1)
112
+ )
113
+ elif (opt.arch.win_size) == 32:
114
+ self.encoder.layer4.register_forward_hook(feature_hook)
115
+ self.rgb_feat_proj = nn.Sequential(
116
+ Bottleneck_Conv(2048),
117
+ Bottleneck_Conv(2048),
118
+ nn.Conv2d(2048, opt.arch.latent_dim, 1)
119
+ )
120
+ else:
121
+ print('Make sure win_size is 16 or 32 when using resnet backbone!')
122
+ raise NotImplementedError
123
+
124
+ def forward(self, rgb_obj):
125
+ batch_size = rgb_obj.shape[0]
126
+ assert len(rgb_obj.shape) == 4
127
+
128
+ # [B, 1, C]
129
+ global_feat = self.encoder(rgb_obj).unsqueeze(1)
130
+ # [B, C, H/ws*W/ws]
131
+ local_feat = self.rgb_feat_proj(self.rgb_feature).view(batch_size, global_feat.shape[-1], -1)
132
+ # [B, H/ws*W/ws, C]
133
+ local_feat = local_feat.permute(0, 2, 1).contiguous()
134
+ # [B, 1+H/ws*W/ws, C]
135
+ rgb_embedding = torch.cat([global_feat, local_feat], dim=1)
136
+
137
+ return rgb_embedding
model/shape/seen_coord_enc.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This source code is written based on https://github.com/facebookresearch/MCC
2
+ # The original code base is licensed under the license found in the LICENSE file in the root directory.
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torchvision
7
+
8
+ from functools import partial
9
+ from timm.models.vision_transformer import Block
10
+ from utils.pos_embed import get_2d_sincos_pos_embed
11
+ from utils.layers import Bottleneck_Conv
12
+
13
+ class CoordEmb(nn.Module):
14
+ """
15
+ Encode the seen coordinate map to a lower resolution feature map
16
+ Achieved with window-wise attention block by deviding coord map into windows
17
+ Each window is seperately encoded into a single CLS token with self-attention and posenc
18
+ """
19
+ def __init__(self, embed_dim, win_size=8, num_heads=8):
20
+ super().__init__()
21
+ self.embed_dim = embed_dim
22
+ self.win_size = win_size
23
+
24
+ self.two_d_pos_embed = nn.Parameter(
25
+ torch.zeros(1, self.win_size*self.win_size + 1, embed_dim), requires_grad=False) # fixed sin-cos embedding
26
+
27
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
28
+
29
+ self.pos_embed = nn.Linear(3, embed_dim)
30
+
31
+ self.blocks = nn.ModuleList([
32
+ # each block is a residual block with layernorm -> attention -> layernorm -> mlp
33
+ Block(embed_dim, num_heads=num_heads, mlp_ratio=2.0, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6))
34
+ for _ in range(1)
35
+ ])
36
+
37
+ self.invalid_coord_token = nn.Parameter(torch.zeros(embed_dim,))
38
+
39
+ self.initialize_weights()
40
+
41
+ def initialize_weights(self):
42
+ torch.nn.init.normal_(self.cls_token, std=.02)
43
+
44
+ two_d_pos_embed = get_2d_sincos_pos_embed(self.two_d_pos_embed.shape[-1], self.win_size, cls_token=True)
45
+ self.two_d_pos_embed.data.copy_(torch.from_numpy(two_d_pos_embed).float().unsqueeze(0))
46
+
47
+ torch.nn.init.normal_(self.invalid_coord_token, std=.02)
48
+
49
+ def forward(self, coord_obj, mask_obj):
50
+ # [B, H, W, C]
51
+ emb = self.pos_embed(coord_obj)
52
+
53
+ emb[~mask_obj] = 0.0
54
+ emb[~mask_obj] += self.invalid_coord_token
55
+
56
+ B, H, W, C = emb.shape
57
+ # [B, H/ws, 8, W/ws, W, C]
58
+ emb = emb.view(B, H // self.win_size, self.win_size, W // self.win_size, self.win_size, C)
59
+ # [B * H/ws * W/ws, 64, C]
60
+ emb = emb.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, self.win_size * self.win_size, C)
61
+
62
+ # [B * H/ws * W/ws, 64, C], add posenc that is local to each patch
63
+ emb = emb + self.two_d_pos_embed[:, 1:, :]
64
+ # [1, 1, C]
65
+ cls_token = self.cls_token + self.two_d_pos_embed[:, :1, :]
66
+
67
+ # [B * H/ws * W/ws, 1, C]
68
+ cls_tokens = cls_token.expand(emb.shape[0], -1, -1)
69
+ # [B * H/ws * W/ws, 65, C]
70
+ emb = torch.cat((cls_tokens, emb), dim=1)
71
+
72
+ # transformer (single block) that handle each of the patch seperately
73
+ # reasoning is done within each batch
74
+ for _, blk in enumerate(self.blocks):
75
+ emb = blk(emb)
76
+
77
+ # return the cls token of each window, [B, H/ws*W/ws, C]
78
+ return emb[:, 0].view(B, (H // self.win_size) * (W // self.win_size), -1)
79
+
80
+ class CoordEncAtt(nn.Module):
81
+ """
82
+ Seen surface encoder based on transformer.
83
+ """
84
+ def __init__(self,
85
+ embed_dim=768, n_blocks=12, num_heads=12, win_size=8,
86
+ mlp_ratio=4., norm_layer=partial(nn.LayerNorm, eps=1e-6), drop_path=0.1):
87
+ super().__init__()
88
+
89
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
90
+ self.coord_embed = CoordEmb(embed_dim, win_size, num_heads)
91
+
92
+ self.blocks = nn.ModuleList([
93
+ Block(
94
+ embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer,
95
+ drop_path=drop_path
96
+ ) for _ in range(n_blocks)])
97
+
98
+ self.norm = norm_layer(embed_dim)
99
+
100
+ self.initialize_weights()
101
+
102
+ def initialize_weights(self):
103
+ # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
104
+ torch.nn.init.normal_(self.cls_token, std=.02)
105
+
106
+ # initialize nn.Linear and nn.LayerNorm
107
+ self.apply(self._init_weights)
108
+
109
+ def _init_weights(self, m):
110
+ if isinstance(m, nn.Linear):
111
+ # we use xavier_uniform following official JAX ViT:
112
+ torch.nn.init.xavier_uniform_(m.weight)
113
+ if isinstance(m, nn.Linear) and m.bias is not None:
114
+ nn.init.constant_(m.bias, 0)
115
+ elif isinstance(m, nn.LayerNorm):
116
+ nn.init.constant_(m.bias, 0)
117
+ nn.init.constant_(m.weight, 1.0)
118
+
119
+ def forward(self, coord_obj, mask_obj):
120
+
121
+ # [B, H/ws*W/ws, C]
122
+ coord_embedding = self.coord_embed(coord_obj, mask_obj)
123
+
124
+ # append cls token
125
+ # [1, 1, C]
126
+ cls_token = self.cls_token
127
+ # [B, 1, C]
128
+ cls_tokens = cls_token.expand(coord_embedding.shape[0], -1, -1)
129
+
130
+ # [B, H/ws*W/ws+1, C]
131
+ coord_embedding = torch.cat((cls_tokens, coord_embedding), dim=1)
132
+
133
+ # apply Transformer blocks
134
+ for blk in self.blocks:
135
+ coord_embedding = blk(coord_embedding)
136
+ coord_embedding = self.norm(coord_embedding)
137
+
138
+ # [B, H/ws*W/ws+1, C]
139
+ return coord_embedding
140
+
141
+ class CoordEncRes(nn.Module):
142
+ """
143
+ Seen surface encoder based on resnet.
144
+ """
145
+ def __init__(self, opt):
146
+ super().__init__()
147
+
148
+ self.encoder = torchvision.models.resnet50(pretrained=True)
149
+ self.encoder.fc = nn.Sequential(
150
+ Bottleneck_Conv(2048),
151
+ Bottleneck_Conv(2048),
152
+ nn.Linear(2048, opt.arch.latent_dim)
153
+ )
154
+
155
+ # define hooks
156
+ self.seen_feature = None
157
+ def feature_hook(model, input, output):
158
+ self.seen_feature = output
159
+
160
+ # attach hooks
161
+ assert opt.arch.depth.dsp == 1
162
+ if (opt.arch.win_size) == 16:
163
+ self.encoder.layer3.register_forward_hook(feature_hook)
164
+ self.depth_feat_proj = nn.Sequential(
165
+ Bottleneck_Conv(1024),
166
+ Bottleneck_Conv(1024),
167
+ nn.Conv2d(1024, opt.arch.latent_dim, 1)
168
+ )
169
+ elif (opt.arch.win_size) == 32:
170
+ self.encoder.layer4.register_forward_hook(feature_hook)
171
+ self.depth_feat_proj = nn.Sequential(
172
+ Bottleneck_Conv(2048),
173
+ Bottleneck_Conv(2048),
174
+ nn.Conv2d(2048, opt.arch.latent_dim, 1)
175
+ )
176
+ else:
177
+ print('Make sure win_size is 16 or 32 when using resnet backbone!')
178
+ raise NotImplementedError
179
+
180
+ def forward(self, coord_obj, mask_obj):
181
+ batch_size = coord_obj.shape[0]
182
+ assert len(coord_obj.shape) == len(mask_obj.shape) == 4
183
+ mask_obj = mask_obj.float()
184
+ coord_obj = coord_obj * mask_obj
185
+
186
+ # [B, 1, C]
187
+ global_feat = self.encoder(coord_obj).unsqueeze(1)
188
+ # [B, C, H/ws*W/ws]
189
+ local_feat = self.depth_feat_proj(self.seen_feature).view(batch_size, global_feat.shape[-1], -1)
190
+ # [B, H/ws*W/ws, C]
191
+ local_feat = local_feat.permute(0, 2, 1).contiguous()
192
+ # [B, 1+H/ws*W/ws, C]
193
+ seen_embedding = torch.cat([global_feat, local_feat], dim=1)
194
+
195
+ return seen_embedding
model/shape_engine.py ADDED
@@ -0,0 +1,598 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import os, time, datetime
3
+ import torch
4
+ import torch.utils.tensorboard
5
+ import torch.profiler
6
+ import importlib
7
+ import shutil
8
+ import utils.util as util
9
+ import utils.util_vis as util_vis
10
+ import utils.eval_3D as eval_3D
11
+
12
+ from torch.nn.parallel import DistributedDataParallel as DDP
13
+ from utils.util import print_eval, setup, cleanup
14
+ from utils.util import EasyDict as edict
15
+ from copy import deepcopy
16
+ from model.compute_graph import graph_shape
17
+
18
+ # ============================ main engine for training and evaluation ============================
19
+
20
+ class Runner():
21
+
22
+ def __init__(self, opt):
23
+ super().__init__()
24
+ if os.path.isdir(opt.output_path) and opt.resume == False and opt.device == 0:
25
+ for filename in os.listdir(opt.output_path):
26
+ if "tfevents" in filename: os.remove(os.path.join(opt.output_path, filename))
27
+ if "html" in filename: os.remove(os.path.join(opt.output_path, filename))
28
+ if "vis" in filename: shutil.rmtree(os.path.join(opt.output_path, filename))
29
+ if "embedding" in filename: shutil.rmtree(os.path.join(opt.output_path, filename))
30
+ if opt.device == 0:
31
+ os.makedirs(opt.output_path,exist_ok=True)
32
+ setup(opt.device, opt.world_size, opt.port)
33
+ opt.batch_size = opt.batch_size // opt.world_size
34
+
35
+ def get_viz_data(self, opt):
36
+ # get data for visualization
37
+ viz_data_list = []
38
+ sample_range = len(self.viz_loader)
39
+ viz_interval = sample_range // opt.eval.n_vis
40
+ for i in range(sample_range):
41
+ current_batch = next(self.viz_loader_iter)
42
+ if i % viz_interval != 0: continue
43
+ viz_data_list.append(current_batch)
44
+ return viz_data_list
45
+
46
+ def load_dataset(self, opt, eval_split="test"):
47
+ data_train = importlib.import_module('data.{}'.format(opt.data.dataset_train))
48
+ data_test = importlib.import_module('data.{}'.format(opt.data.dataset_test))
49
+ if opt.device == 0: print("loading training data...")
50
+ self.train_data = data_train.Dataset(opt, split="train")
51
+ self.train_loader = self.train_data.setup_loader(opt, shuffle=True, use_ddp=True, drop_last=True)
52
+ self.num_batches = len(self.train_loader)
53
+ if opt.device == 0: print("loading test data...")
54
+ self.test_data = data_test.Dataset(opt, split=eval_split)
55
+ self.test_loader = self.test_data.setup_loader(opt, shuffle=False, use_ddp=True, drop_last=True, batch_size=opt.eval.batch_size)
56
+ self.num_batches_test = len(self.test_loader)
57
+ if len(self.test_loader.sampler) * opt.world_size < len(self.test_data):
58
+ self.aux_test_dataset = torch.utils.data.Subset(self.test_data,
59
+ range(len(self.test_loader.sampler) * opt.world_size, len(self.test_data)))
60
+ self.aux_test_loader = torch.utils.data.DataLoader(
61
+ self.aux_test_dataset, batch_size=opt.eval.batch_size, shuffle=False, drop_last=False,
62
+ num_workers=opt.data.num_workers)
63
+ if opt.device == 0:
64
+ print("creating data for visualization...")
65
+ self.viz_loader = self.test_data.setup_loader(opt, shuffle=False, use_ddp=False, drop_last=False, batch_size=1)
66
+ self.viz_loader_iter = iter(self.viz_loader)
67
+ self.viz_data = self.get_viz_data(opt)
68
+
69
+ def build_networks(self, opt):
70
+ if opt.device == 0: print("building networks...")
71
+ self.graph = DDP(graph_shape.Graph(opt).to(opt.device), device_ids=[opt.device], find_unused_parameters=(not opt.optim.fix_dpt or not opt.optim.fix_clip))
72
+
73
+ # =================================================== set up training =========================================================
74
+
75
+ def setup_optimizer(self, opt):
76
+ if opt.device == 0: print("setting up optimizers...")
77
+ if opt.optim.fix_dpt:
78
+ # when we do not need to train the dpt depth, every param will start from scratch
79
+ scratch_param_decay = []
80
+ scratch_param_nodecay = []
81
+ # loop over all params
82
+ for name, param in self.graph.named_parameters():
83
+ # skip and fixed params
84
+ if not param.requires_grad or 'dpt_depth' in name or 'intr_' in name:
85
+ continue
86
+ # do not add wd on bias or low-dim params
87
+ if param.ndim <= 1 or name.endswith(".bias"):
88
+ scratch_param_nodecay.append(param)
89
+ # print("{} -> scratch_param_nodecay".format(name))
90
+ else:
91
+ scratch_param_decay.append(param)
92
+ # print("{} -> scratch_param_decay".format(name))
93
+ # create the optim dictionary
94
+ optim_dict = [
95
+ {'params': scratch_param_nodecay, 'lr': opt.optim.lr, 'weight_decay': 0.},
96
+ {'params': scratch_param_decay, 'lr': opt.optim.lr, 'weight_decay': opt.optim.weight_decay}
97
+ ]
98
+ else:
99
+ # when we need to train dpt as well, related params should go to finetune list
100
+ finetune_param_nodecay = []
101
+ scratch_param_nodecay = []
102
+ finetune_param_decay = []
103
+ scratch_param_decay = []
104
+ for name, param in self.graph.named_parameters():
105
+ # skip and fixed params
106
+ if not param.requires_grad:
107
+ continue
108
+ # put dpt params into finetune list
109
+ if 'dpt_depth' in name or 'intr_' in name:
110
+ if param.ndim <= 1 or name.endswith(".bias"):
111
+ # print("{} -> finetune_param_nodecay".format(name))
112
+ finetune_param_nodecay.append(param)
113
+ else:
114
+ finetune_param_decay.append(param)
115
+ # print("{} -> finetune_param_decay".format(name))
116
+ # all other params go to scratch list
117
+ else:
118
+ if param.ndim <= 1 or name.endswith(".bias"):
119
+ scratch_param_nodecay.append(param)
120
+ # print("{} -> scratch_param_nodecay".format(name))
121
+ else:
122
+ scratch_param_decay.append(param)
123
+ # print("{} -> scratch_param_decay".format(name))
124
+ # create the optim dictionary
125
+ optim_dict = [
126
+ {'params': finetune_param_nodecay, 'lr': opt.optim.lr_ft, 'weight_decay': 0.},
127
+ {'params': finetune_param_decay, 'lr': opt.optim.lr_ft, 'weight_decay': opt.optim.weight_decay},
128
+ {'params': scratch_param_nodecay, 'lr': opt.optim.lr, 'weight_decay': 0.},
129
+ {'params': scratch_param_decay, 'lr': opt.optim.lr, 'weight_decay': opt.optim.weight_decay}
130
+ ]
131
+
132
+ self.optim = torch.optim.AdamW(optim_dict, betas=(0.9, 0.95))
133
+ if opt.optim.sched:
134
+ self.sched = torch.optim.lr_scheduler.CosineAnnealingLR(self.optim, opt.max_epoch)
135
+ if opt.optim.amp:
136
+ self.scaler = torch.cuda.amp.GradScaler()
137
+
138
+ def restore_checkpoint(self, opt, best=False, evaluate=False):
139
+ epoch_start, iter_start = None, None
140
+ if opt.resume:
141
+ if opt.device == 0: print("resuming from previous checkpoint...")
142
+ epoch_start, iter_start, best_val, best_ep = util.restore_checkpoint(opt, self, resume=opt.resume, best=best, evaluate=evaluate)
143
+ self.best_val = best_val
144
+ self.best_ep = best_ep
145
+ elif opt.load is not None:
146
+ if opt.device == 0: print("loading weights from checkpoint {}...".format(opt.load))
147
+ epoch_start, iter_start, best_val, best_ep = util.restore_checkpoint(opt, self, load_name=opt.load)
148
+ else:
149
+ if opt.device == 0: print("initializing weights from scratch...")
150
+ self.epoch_start = epoch_start or 0
151
+ self.iter_start = iter_start or 0
152
+
153
+ def setup_visualizer(self, opt, test=False):
154
+ if opt.device == 0:
155
+ print("setting up visualizers...")
156
+ if opt.tb:
157
+ if test == False:
158
+ self.tb = torch.utils.tensorboard.SummaryWriter(log_dir=opt.output_path, flush_secs=10)
159
+ else:
160
+ embedding_folder = os.path.join(opt.output_path, 'embedding')
161
+ os.makedirs(embedding_folder, exist_ok=True)
162
+ self.tb = torch.utils.tensorboard.SummaryWriter(log_dir=embedding_folder, flush_secs=10)
163
+
164
+ def train(self, opt):
165
+ # before training
166
+ torch.cuda.set_device(opt.device)
167
+ torch.cuda.empty_cache()
168
+ if opt.device == 0: print("TRAINING START")
169
+ self.train_metric_logger = util.MetricLogger(delimiter=" ")
170
+ self.train_metric_logger.add_meter('lr', util.SmoothedValue(window_size=1, fmt='{value:.6f}'))
171
+ self.iter_skip = self.iter_start % len(self.train_loader)
172
+ self.it = self.iter_start
173
+ self.skip_dis = False
174
+ if not opt.resume:
175
+ self.best_val = np.inf
176
+ self.best_ep = 1
177
+ # training
178
+ if self.iter_start == 0 and not opt.debug: self.evaluate(opt, ep=0, training=True)
179
+ for self.ep in range(self.epoch_start, opt.max_epoch):
180
+ self.train_epoch(opt)
181
+ # after training
182
+ if opt.device == 0: self.save_checkpoint(opt, ep=self.ep, it=self.it, best_val=self.best_val, best_ep=self.best_ep)
183
+ if opt.tb and opt.device == 0:
184
+ self.tb.flush()
185
+ self.tb.close()
186
+ if opt.device == 0:
187
+ print("TRAINING DONE")
188
+ print("Best CD: %.4f @ epoch %d" % (self.best_val, self.best_ep))
189
+ cleanup()
190
+
191
+ def train_epoch(self, opt):
192
+ # before train epoch
193
+ self.train_loader.sampler.set_epoch(self.ep)
194
+ if opt.device == 0:
195
+ print("training epoch {}".format(self.ep+1))
196
+ batch_progress = range(self.num_batches)
197
+ self.graph.train()
198
+ # train epoch
199
+ loader = iter(self.train_loader)
200
+
201
+ if opt.debug and opt.profile:
202
+ with torch.profiler.profile(
203
+ schedule=torch.profiler.schedule(wait=3, warmup=3, active=5, repeat=2),
204
+ on_trace_ready=torch.profiler.tensorboard_trace_handler('debug/profiler_log'),
205
+ record_shapes=True,
206
+ profile_memory=True,
207
+ with_stack=False
208
+ ) as prof:
209
+ for batch_id in batch_progress:
210
+ if batch_id >= (1 + 1 + 3) * 2:
211
+ # exit the program after 2 iterations of the warmup, active, and repeat steps
212
+ exit()
213
+
214
+ # if resuming from previous checkpoint, skip until the last iteration number is reached
215
+ if self.iter_skip>0:
216
+ self.iter_skip -= 1
217
+ continue
218
+ batch = next(loader)
219
+ # train iteration
220
+ var = edict(batch)
221
+ opt.H, opt.W = opt.image_size
222
+ var = util.move_to_device(var, opt.device)
223
+ loss = self.train_iteration(opt, var, batch_progress)
224
+ prof.step()
225
+ else:
226
+ for batch_id in batch_progress:
227
+ # if resuming from previous checkpoint, skip until the last iteration number is reached
228
+ if self.iter_skip>0:
229
+ self.iter_skip -= 1
230
+ continue
231
+ batch = next(loader)
232
+ # train iteration
233
+ var = edict(batch)
234
+ opt.H, opt.W = opt.image_size
235
+ var = util.move_to_device(var, opt.device)
236
+ loss = self.train_iteration(opt, var, batch_progress)
237
+
238
+ # after train epoch
239
+ if opt.optim.sched: self.sched.step()
240
+ if (self.ep + 1) % opt.freq.eval == 0:
241
+ if opt.device == 0: print("validating epoch {}".format(self.ep+1))
242
+ current_val = self.evaluate(opt, ep=self.ep+1, training=True)
243
+ if current_val < self.best_val and opt.device == 0:
244
+ self.best_val = current_val
245
+ self.best_ep = self.ep + 1
246
+ self.save_checkpoint(opt, ep=self.ep, it=self.it, best_val=self.best_val, best_ep=self.best_ep, best=True, latest=True)
247
+
248
+ def train_iteration(self, opt, var, batch_progress):
249
+ # before train iteration
250
+ torch.distributed.barrier()
251
+ # train iteration
252
+ with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=opt.optim.amp):
253
+ var, loss = self.graph.forward(opt, var, training=True, get_loss=True)
254
+ loss = self.summarize_loss(opt, var, loss)
255
+ loss_scaled = loss.all / opt.optim.accum
256
+
257
+ # backward
258
+ if opt.optim.amp:
259
+ self.scaler.scale(loss_scaled).backward()
260
+ # skip update if accumulating gradient
261
+ if (self.it + 1) % opt.optim.accum == 0:
262
+ self.scaler.unscale_(self.optim)
263
+ # gradient clipping
264
+ if opt.optim.clip_norm:
265
+ norm = torch.nn.utils.clip_grad_norm_(self.graph.parameters(), opt.optim.clip_norm)
266
+ if opt.debug: print("Grad norm: {}".format(norm))
267
+ self.scaler.step(self.optim)
268
+ self.scaler.update()
269
+ self.optim.zero_grad()
270
+ else:
271
+ loss_scaled.backward()
272
+ if (self.it + 1) % opt.optim.accum == 0:
273
+ if opt.optim.clip_norm:
274
+ norm = torch.nn.utils.clip_grad_norm_(self.graph.parameters(), opt.optim.clip_norm)
275
+ if opt.debug: print("Grad norm: {}".format(norm))
276
+ self.optim.step()
277
+ self.optim.zero_grad()
278
+
279
+ # after train iteration
280
+ lr = self.sched.get_last_lr()[0] if opt.optim.sched else opt.optim.lr
281
+ self.train_metric_logger.update(lr=lr)
282
+ self.train_metric_logger.update(loss=loss.all)
283
+ if opt.device == 0:
284
+ self.graph.eval()
285
+ # if (self.it) % opt.freq.vis == 0: self.visualize(opt, var, step=self.it, split="train")
286
+ if (self.it) % opt.freq.ckpt_latest == 0 and not opt.debug:
287
+ self.save_checkpoint(opt, ep=self.ep, it=self.it, best_val=self.best_val, best_ep=self.best_ep, latest=True)
288
+ if (self.it) % opt.freq.scalar == 0 and not opt.debug:
289
+ self.log_scalars(opt, var, loss, step=self.it, split="train")
290
+ if (self.it) % (opt.freq.save_vis * (self.it//10000*10+1)) == 0 and not opt.debug:
291
+ self.vis_train_iter(opt)
292
+ if (self.it) % opt.freq.print == 0:
293
+ print('[{}] '.format(datetime.datetime.now().time()), end='')
294
+ print(f'Train Iter {self.it}/{self.num_batches*opt.max_epoch}: {self.train_metric_logger}')
295
+ self.graph.train()
296
+ self.it += 1
297
+ return loss
298
+
299
+ @torch.no_grad()
300
+ def vis_train_iter(self, opt):
301
+ for i in range(len(self.viz_data)):
302
+ var_viz = edict(deepcopy(self.viz_data[i]))
303
+ var_viz = util.move_to_device(var_viz, opt.device)
304
+ var_viz = self.graph.module(opt, var_viz, training=False, get_loss=False)
305
+ eval_3D.eval_metrics(opt, var_viz, self.graph.module.impl_network, vis_only=True)
306
+ vis_folder = "vis_log/iter_{}".format(self.it)
307
+ os.makedirs("{}/{}".format(opt.output_path, vis_folder), exist_ok=True)
308
+ util_vis.dump_images(opt, var_viz.idx, "image_input", var_viz.rgb_input_map, masks=None, from_range=(0, 1), folder=vis_folder)
309
+ util_vis.dump_images(opt, var_viz.idx, "mask_input", var_viz.mask_input_map, folder=vis_folder)
310
+ util_vis.dump_meshes_viz(opt, var_viz.idx, "mesh_viz", var_viz.mesh_pred, folder=vis_folder)
311
+ if 'depth_pred' in var_viz:
312
+ util_vis.dump_depths(opt, var_viz.idx, "depth_est", var_viz.depth_pred, var_viz.mask_input_map, rescale=True, folder=vis_folder)
313
+ if 'depth_input_map' in var_viz:
314
+ util_vis.dump_depths(opt, var_viz.idx, "depth_input", var_viz.depth_input_map, var_viz.mask_input_map, rescale=True, folder=vis_folder)
315
+ if 'attn_vis' in var_viz:
316
+ util_vis.dump_attentions(opt, var_viz.idx, "attn", var_viz.attn_vis, folder=vis_folder)
317
+ if 'gt_surf_points' in var_viz and 'seen_points' in var_viz:
318
+ util_vis.dump_pointclouds_compare(opt, var_viz.idx, "seen_surface", var_viz.seen_points, var_viz.gt_surf_points, folder=vis_folder)
319
+
320
+ def summarize_loss(self, opt, var, loss, non_act_loss_key=[]):
321
+ loss_all = 0.
322
+ assert("all" not in loss)
323
+ # weigh losses
324
+ for key in loss:
325
+ assert(key in opt.loss_weight)
326
+ if opt.loss_weight[key] is not None:
327
+ assert not torch.isinf(loss[key].mean()), "loss {} is Inf".format(key)
328
+ assert not torch.isnan(loss[key].mean()), "loss {} is NaN".format(key)
329
+ loss_all += float(opt.loss_weight[key])*loss[key].mean() if key not in non_act_loss_key else 0.0*loss[key].mean()
330
+ loss.update(all=loss_all)
331
+ return loss
332
+
333
+ # =================================================== set up evaluation =========================================================
334
+
335
+ @torch.no_grad()
336
+ def evaluate(self, opt, ep, training=False):
337
+ self.graph.eval()
338
+
339
+ # lists for metrics
340
+ cd_accs = []
341
+ cd_comps = []
342
+ f_scores = []
343
+ cat_indices = []
344
+ loss_eval = edict()
345
+ metric_eval = dict(dist_acc=0., dist_cov=0.)
346
+ eval_metric_logger = util.MetricLogger(delimiter=" ")
347
+
348
+ # result file on the fly
349
+ if not training:
350
+ assert opt.device == 0
351
+ full_results_file = open(os.path.join(opt.output_path, '{}_full_results.txt'.format(opt.data.dataset_test)), 'w')
352
+ full_results_file.write("IND, CD, ACC, COMP, ")
353
+ full_results_file.write(", ".join(["F-score@{:.2f}".format(opt.eval.f_thresholds[i]*100) for i in range(len(opt.eval.f_thresholds))]))
354
+
355
+ # dataloader on the test set
356
+ with torch.cuda.device(opt.device):
357
+ for it, batch in enumerate(self.test_loader):
358
+
359
+ # inference the model
360
+ var = edict(batch)
361
+ var = self.evaluate_batch(opt, var, ep, it, single_gpu=False)
362
+
363
+ # record CD for evaluation
364
+ dist_acc, dist_cov = eval_3D.eval_metrics(opt, var, self.graph.module.impl_network)
365
+
366
+ # accumulate the scores
367
+ cd_accs.append(var.cd_acc)
368
+ cd_comps.append(var.cd_comp)
369
+ f_scores.append(var.f_score)
370
+ cat_indices.append(var.category_label)
371
+ eval_metric_logger.update(ACC=dist_acc)
372
+ eval_metric_logger.update(COMP=dist_cov)
373
+ eval_metric_logger.update(CD=(dist_acc+dist_cov) / 2)
374
+
375
+ if opt.device == 0 and it % opt.freq.print_eval == 0:
376
+ print('[{}] '.format(datetime.datetime.now().time()), end='')
377
+ print(f'Eval Iter {it}/{len(self.test_loader)} @ EP {ep}: {eval_metric_logger}')
378
+
379
+ # write to file
380
+ if not training:
381
+ full_results_file.write("\n")
382
+ full_results_file.write("{:d}".format(var.idx.item()))
383
+ full_results_file.write("\t{:.4f}".format((var.cd_acc.item() + var.cd_comp.item()) / 2))
384
+ full_results_file.write("\t{:.4f}".format(var.cd_acc.item()))
385
+ full_results_file.write("\t{:.4f}".format(var.cd_comp.item()))
386
+ full_results_file.write("\t" + "\t".join(["{:.4f}".format(var.f_score[0][i].item()) for i in range(len(opt.eval.f_thresholds))]))
387
+ full_results_file.flush()
388
+
389
+ # dump the result if in eval mode
390
+ if not training:
391
+ self.dump_results(opt, var, ep, write_new=(it == 0))
392
+
393
+ # save the predicted mesh for vis data if in train mode
394
+ if it == 0 and training and opt.device == 0:
395
+ print("visualizing and saving results...")
396
+ for i in range(len(self.viz_data)):
397
+ var_viz = edict(deepcopy(self.viz_data[i]))
398
+ var_viz = self.evaluate_batch(opt, var_viz, ep, it, single_gpu=True)
399
+ eval_3D.eval_metrics(opt, var_viz, self.graph.module.impl_network, vis_only=True)
400
+ # self.visualize(opt, var_viz, step=ep, split="eval")
401
+ self.dump_results(opt, var_viz, ep, train=True)
402
+ # write html that organizes the results
403
+ util_vis.create_gif_html(os.path.join(opt.output_path, "vis_{}".format(ep)),
404
+ os.path.join(opt.output_path, "results_ep{}.html".format(ep)),
405
+ skip_every=1)
406
+
407
+ # collect the eval results into tensors
408
+ cd_accs = torch.cat(cd_accs, dim=0)
409
+ cd_comps = torch.cat(cd_comps, dim=0)
410
+ f_scores = torch.cat(f_scores, dim=0)
411
+ cat_indices = torch.cat(cat_indices, dim=0)
412
+
413
+ if opt.world_size > 1:
414
+ # empty tensors for gathering
415
+ cd_accs_all = [torch.zeros_like(cd_accs).to(opt.device) for _ in range(opt.world_size)]
416
+ cd_comps_all = [torch.zeros_like(cd_comps).to(opt.device) for _ in range(opt.world_size)]
417
+ f_scores_all = [torch.zeros_like(f_scores).to(opt.device) for _ in range(opt.world_size)]
418
+ cat_indices_all = [torch.zeros_like(cat_indices).long().to(opt.device) for _ in range(opt.world_size)]
419
+
420
+ # gather the metrics
421
+ torch.distributed.barrier()
422
+ torch.distributed.all_gather(cd_accs_all, cd_accs)
423
+ torch.distributed.all_gather(cd_comps_all, cd_comps)
424
+ torch.distributed.all_gather(f_scores_all, f_scores)
425
+ torch.distributed.all_gather(cat_indices_all, cat_indices)
426
+ cd_accs_all = torch.cat(cd_accs_all, dim=0)
427
+ cd_comps_all = torch.cat(cd_comps_all, dim=0)
428
+ f_scores_all = torch.cat(f_scores_all, dim=0)
429
+ cat_indices_all = torch.cat(cat_indices_all, dim=0)
430
+ else:
431
+ cd_accs_all = cd_accs
432
+ cd_comps_all = cd_comps
433
+ f_scores_all = f_scores
434
+ cat_indices_all = cat_indices
435
+ # handle last batch, if any
436
+ if len(self.test_loader.sampler) * opt.world_size < len(self.test_data):
437
+ cd_accs_all = [cd_accs_all]
438
+ cd_comps_all = [cd_comps_all]
439
+ f_scores_all = [f_scores_all]
440
+ cat_indices_all = [cat_indices_all]
441
+ for batch in self.aux_test_loader:
442
+ # inference the model
443
+ var = edict(batch)
444
+ var = self.evaluate_batch(opt, var, ep, it, single_gpu=False)
445
+
446
+ # record CD for evaluation
447
+ dist_acc, dist_cov = eval_3D.eval_metrics(opt, var, self.graph.module.impl_network)
448
+ # accumulate the scores
449
+ cd_accs_all.append(var.cd_acc)
450
+ cd_comps_all.append(var.cd_comp)
451
+ f_scores_all.append(var.f_score)
452
+ cat_indices_all.append(var.category_label)
453
+
454
+ # dump the result if in eval mode
455
+ if not training and opt.device == 0:
456
+ self.dump_results(opt, var, ep, write_new=(it == 0))
457
+
458
+ cd_accs_all = torch.cat(cd_accs_all, dim=0)
459
+ cd_comps_all = torch.cat(cd_comps_all, dim=0)
460
+ f_scores_all = torch.cat(f_scores_all, dim=0)
461
+ cat_indices_all = torch.cat(cat_indices_all, dim=0)
462
+
463
+ assert cd_accs_all.shape[0] == len(self.test_data)
464
+ if not training:
465
+ full_results_file.close()
466
+ # printout and save the metrics
467
+ if opt.device == 0:
468
+ metric_eval["dist_acc"] = cd_accs_all.mean()
469
+ metric_eval["dist_cov"] = cd_comps_all.mean()
470
+
471
+ # print eval info
472
+ print_eval(opt, loss=None, chamfer=(metric_eval["dist_acc"],
473
+ metric_eval["dist_cov"]))
474
+ val_metric = (metric_eval["dist_acc"] + metric_eval["dist_cov"]) / 2
475
+
476
+ if training:
477
+ # log/visualize results to tb/vis
478
+ self.log_scalars(opt, var, loss_eval, metric=metric_eval, step=ep, split="eval")
479
+
480
+ if not training:
481
+ # save the per-cat evaluation metrics if on shapenet
482
+ per_cat_cd_file = os.path.join(opt.output_path, 'cd_cat.txt')
483
+ with open(per_cat_cd_file, "w") as outfile:
484
+ outfile.write("CD Acc Comp Count Cat\n")
485
+ for i in range(opt.data.num_classes_test):
486
+ if (cat_indices_all==i).sum() == 0:
487
+ continue
488
+ acc_i = cd_accs_all[cat_indices_all==i].mean().item()
489
+ comp_i = cd_comps_all[cat_indices_all==i].mean().item()
490
+ counts_cat = torch.sum(cat_indices_all==i)
491
+ cd_i = (acc_i + comp_i) / 2
492
+ outfile.write("%.4f %.4f %.4f %5d %s\n" % (cd_i, acc_i, comp_i, counts_cat, self.test_data.label2cat[i]))
493
+
494
+ # print f_scores
495
+ f_scores_avg = f_scores_all.mean(dim=0)
496
+ print('##############################')
497
+ for i in range(len(opt.eval.f_thresholds)):
498
+ print('F-score @ %.2f: %.4f' % (opt.eval.f_thresholds[i]*100, f_scores_avg[i].item()))
499
+ print('##############################')
500
+
501
+ # write to file
502
+ result_file = os.path.join(opt.output_path, 'quantitative_{}.txt'.format(opt.data.dataset_test))
503
+ with open(result_file, "w") as outfile:
504
+ outfile.write('CD Acc Comp \n')
505
+ outfile.write('%.4f %.4f %.4f\n' % (val_metric, metric_eval["dist_acc"], metric_eval["dist_cov"]))
506
+ for i in range(len(opt.eval.f_thresholds)):
507
+ outfile.write('F-score @ %.2f: %.4f\n' % (opt.eval.f_thresholds[i]*100, f_scores_avg[i].item()))
508
+
509
+ # write html that organizes the results
510
+ util_vis.create_gif_html(os.path.join(opt.output_path, "dump_{}".format(opt.data.dataset_test)),
511
+ os.path.join(opt.output_path, "results_test.html"), skip_every=10)
512
+
513
+ # torch.cuda.empty_cache()
514
+ return val_metric.item()
515
+ return float('inf')
516
+
517
+ def evaluate_batch(self, opt, var, ep=None, it=None, single_gpu=False):
518
+ var = util.move_to_device(var, opt.device)
519
+ if single_gpu:
520
+ var = self.graph.module(opt, var, training=False, get_loss=False)
521
+ else:
522
+ var = self.graph(opt, var, training=False, get_loss=False)
523
+ return var
524
+
525
+ @torch.no_grad()
526
+ def log_scalars(self, opt, var, loss, metric=None, step=0, split="train"):
527
+ if split=="train":
528
+ dist_acc, dist_cov = eval_3D.eval_metrics(opt, var, self.graph.module.impl_network)
529
+ metric = dict(dist_acc=dist_acc, dist_cov=dist_cov)
530
+ for key, value in loss.items():
531
+ if key=="all": continue
532
+ self.tb.add_scalar("{0}/loss_{1}".format(split, key), value.mean(), step)
533
+ if metric is not None:
534
+ for key, value in metric.items():
535
+ self.tb.add_scalar("{0}/{1}".format(split, key), value, step)
536
+ # log the attention average values
537
+ if 'attn_geo_avg' in var:
538
+ self.tb.add_scalar("{0}/attn_geo_avg".format(split), var.attn_geo_avg, step)
539
+ if 'attn_geo_seen' in var:
540
+ self.tb.add_scalar("{0}/attn_geo_seen".format(split), var.attn_geo_seen, step)
541
+ if 'attn_geo_occl' in var:
542
+ self.tb.add_scalar("{0}/attn_geo_occl".format(split), var.attn_geo_occl, step)
543
+ if 'attn_geo_bg' in var:
544
+ self.tb.add_scalar("{0}/attn_geo_bg".format(split), var.attn_geo_bg, step)
545
+
546
+ @torch.no_grad()
547
+ def visualize(self, opt, var, step=0, split="train"):
548
+ if 'pose_input' in var:
549
+ pose_input = var.pose_input
550
+ elif 'pose_gt' in var:
551
+ pose_input = var.pose_gt
552
+ else:
553
+ pose_input = None
554
+ util_vis.tb_image(opt, self.tb, step, split, "image_input_map", var.rgb_input_map, masks=None, from_range=(0, 1), poses=pose_input)
555
+ util_vis.tb_image(opt, self.tb, step, split, "image_input_map_est", var.rgb_input_map, masks=None, from_range=(0, 1),
556
+ poses=var.pose_pred if 'pose_pred' in var else var.pose)
557
+ util_vis.tb_image(opt, self.tb, step, split, "mask_input_map", var.mask_input_map)
558
+ if 'depth_pred' in var:
559
+ util_vis.tb_image(opt, self.tb, step, split, "depth_est_map", var.depth_pred)
560
+ if 'depth_input_map' in var:
561
+ util_vis.tb_image(opt, self.tb, step, split, "depth_input_map", var.depth_input_map)
562
+
563
+ @torch.no_grad()
564
+ def dump_results(self, opt, var, ep, write_new=False, train=False):
565
+ # create the dir
566
+ current_folder = "dump_{}".format(opt.data.dataset_test) if train == False else "vis_{}".format(ep)
567
+ os.makedirs("{}/{}/".format(opt.output_path, current_folder), exist_ok=True)
568
+
569
+ # save the results
570
+ if 'pose_input' in var:
571
+ pose_input = var.pose_input
572
+ elif 'pose_gt' in var:
573
+ pose_input = var.pose_gt
574
+ else:
575
+ pose_input = None
576
+ util_vis.dump_images(opt, var.idx, "image_input", var.rgb_input_map, masks=None, from_range=(0, 1), poses=pose_input, folder=current_folder)
577
+ util_vis.dump_images(opt, var.idx, "mask_input", var.mask_input_map, folder=current_folder)
578
+ util_vis.dump_meshes(opt, var.idx, "mesh", var.mesh_pred, folder=current_folder)
579
+ util_vis.dump_meshes_viz(opt, var.idx, "mesh_viz", var.mesh_pred, folder=current_folder) # image frames + gifs
580
+ if 'depth_pred' in var:
581
+ util_vis.dump_depths(opt, var.idx, "depth_est", var.depth_pred, var.mask_input_map, rescale=True, folder=current_folder)
582
+ if 'depth_input_map' in var:
583
+ util_vis.dump_depths(opt, var.idx, "depth_input", var.depth_input_map, var.mask_input_map, rescale=True, folder=current_folder)
584
+ if 'gt_surf_points' in var and 'seen_points' in var:
585
+ util_vis.dump_pointclouds_compare(opt, var.idx, "seen_surface", var.seen_points, var.gt_surf_points, folder=current_folder)
586
+ if 'attn_vis' in var:
587
+ util_vis.dump_attentions(opt, var.idx, "attn", var.attn_vis, folder=current_folder)
588
+ if 'attn_pc' in var:
589
+ util_vis.dump_pointclouds(opt, var.idx, "attn_pc", var.attn_pc["points"], var.attn_pc["colors"], folder=current_folder)
590
+ if 'dpc' in var:
591
+ util_vis.dump_pointclouds_compare(opt, var.idx, "pointclouds_comp", var.dpc_pred, var.dpc.points, folder=current_folder)
592
+
593
+ def save_checkpoint(self, opt, ep=0, it=0, best_val=np.inf, best_ep=1, latest=False, best=False):
594
+ util.save_checkpoint(opt, self, ep=ep, it=it, best_val=best_val, best_ep=best_ep, latest=latest, best=best)
595
+ if not latest:
596
+ print("checkpoint saved: ({0}) {1}, epoch {2} (iteration {3})".format(opt.group, opt.name, ep, it))
597
+ if best:
598
+ print("Saving the current model as the best...")
options/depth.yaml ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ group: depth
2
+ name: depth_est
3
+ load:
4
+
5
+ batch_size: 44
6
+ debug: false
7
+ image_size: [224,224]
8
+ gpu: 0
9
+ max_epoch: 15
10
+ output_root: output
11
+ resume: false
12
+ seed: 0
13
+ yaml:
14
+
15
+ arch:
16
+ depth:
17
+ pretrained: model/depth/pretrained_weights/omnidata_dpt_depth_v2.ckpt
18
+
19
+ eval:
20
+ batch_size: 44
21
+ n_vis: 50
22
+ depth_cap:
23
+ d_thresholds: [1.02,1.05,1.1,1.2]
24
+
25
+ data:
26
+ num_classes_test: 15
27
+ max_img_cat:
28
+ dataset_train: synthetic
29
+ dataset_test: synthetic
30
+ num_workers: 6
31
+ bgcolor: 1
32
+ pix3d:
33
+ cat:
34
+ ocrtoc:
35
+ cat:
36
+ erode_mask: 10
37
+ synthetic:
38
+ subset: objaverse_LVIS,ShapeNet55
39
+ percentage: 1
40
+ train_sub:
41
+ val_sub:
42
+
43
+ training:
44
+ n_sdf_points: 4096
45
+ depth_loss:
46
+ grad_reg: 0.1
47
+ depth_inv: true
48
+ mask_shrink: false
49
+
50
+ loss_weight:
51
+ depth: 1
52
+ intr: 10
53
+
54
+ optim:
55
+ lr: 3.e-5
56
+ weight_decay: 0.05
57
+ clip_norm:
58
+ amp: false
59
+ accum: 1
60
+ sched: false
61
+
62
+ tb:
63
+ num_images: [4,8]
64
+
65
+ freq:
66
+ print: 200
67
+ print_eval: 100
68
+ scalar: 1000 # iterations
69
+ vis: 1000 # iterations
70
+ save_vis: 1000
71
+ ckpt_latest: 1000 # iterations
72
+ eval: 1
options/shape.yaml ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ group: shape
2
+ name: shape_recon
3
+ load:
4
+
5
+ batch_size: 28
6
+ debug: false
7
+ profile: false
8
+ image_size: [224,224]
9
+ gpu: 0
10
+ max_epoch: 15
11
+ output_root: output
12
+ resume: false
13
+ seed: 0
14
+ yaml:
15
+
16
+ pretrain:
17
+ depth: weights/depth.ckpt
18
+
19
+ arch:
20
+ # general
21
+ num_heads: 8
22
+ latent_dim: 256
23
+ win_size: 16
24
+ # depth
25
+ depth:
26
+ encoder: resnet
27
+ n_blocks: 12
28
+ dsp: 2
29
+ pretrained: model/depth/pretrained_weights/omnidata_dpt_depth_v2.ckpt
30
+ # rgb
31
+ rgb:
32
+ encoder:
33
+ n_blocks: 12
34
+ # implicit
35
+ impl:
36
+ n_channels: 256
37
+ # attention-related
38
+ att_blocks: 2
39
+ mlp_ratio: 4.
40
+ posenc_perlayer: false
41
+ # mlp-related
42
+ mlp_layers: 8
43
+ posenc_3D: 0
44
+ skip_in: [2,4,6]
45
+
46
+ eval:
47
+ batch_size: 2
48
+ brute_force: false
49
+ n_vis: 50
50
+ vox_res: 64
51
+ num_points: 10000
52
+ range: [-1.5,1.5]
53
+ icp: false
54
+ f_thresholds: [0.005, 0.01, 0.02, 0.05, 0.1, 0.2]
55
+
56
+ data:
57
+ num_classes_test: 15
58
+ max_img_cat:
59
+ dataset_train: synthetic
60
+ dataset_test: synthetic
61
+ num_workers: 6
62
+ bgcolor: 1
63
+ pix3d:
64
+ cat:
65
+ ocrtoc:
66
+ cat:
67
+ erode_mask:
68
+ synthetic:
69
+ subset: objaverse_LVIS,ShapeNet55
70
+ percentage: 1
71
+ train_sub:
72
+ val_sub:
73
+
74
+ training:
75
+ n_sdf_points: 4096
76
+ shape_loss:
77
+ impt_weight: 1
78
+ impt_thres: 0.01
79
+ depth_loss:
80
+ grad_reg: 0.1
81
+ depth_inv: true
82
+ mask_shrink: false
83
+
84
+ loss_weight:
85
+ shape: 1
86
+ depth:
87
+ intr:
88
+
89
+ optim:
90
+ lr: 3.e-5
91
+ lr_ft: 1.e-5
92
+ weight_decay: 0.05
93
+ fix_dpt: false
94
+ fix_clip: true
95
+ clip_norm:
96
+ amp: false
97
+ accum: 1
98
+ sched: false
99
+
100
+ tb:
101
+ num_images: [4,8]
102
+
103
+ freq:
104
+ print: 200
105
+ print_eval: 100
106
+ scalar: 1000 # iterations
107
+ vis: 1000 # iterations
108
+ save_vis: 1000
109
+ ckpt_latest: 1000 # iterations
110
+ eval: 1
requirements.txt ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ absl-py==2.0.0
2
+ anyio==4.2.0
3
+ attrs==23.2.0
4
+ cachetools==5.3.2
5
+ certifi==2023.11.17
6
+ chardet==5.2.0
7
+ charset-normalizer==3.3.2
8
+ colorlog==6.8.0
9
+ contourpy==1.2.0
10
+ cycler==0.12.1
11
+ docopt==0.6.2
12
+ embreex==2.17.7.post4
13
+ exceptiongroup==1.2.0
14
+ filelock==3.13.1
15
+ fonttools==4.47.2
16
+ freetype-py==2.4.0
17
+ fsspec==2023.12.2
18
+ google-auth==2.26.2
19
+ google-auth-oauthlib==1.2.0
20
+ grpcio==1.60.0
21
+ h11==0.14.0
22
+ httpcore==1.0.2
23
+ httpx==0.26.0
24
+ huggingface-hub==0.20.2
25
+ idna==3.6
26
+ imageio==2.33.1
27
+ Jinja2==3.1.3
28
+ jsonschema==4.20.0
29
+ jsonschema-specifications==2023.12.1
30
+ kiwisolver==1.4.5
31
+ lxml==5.1.0
32
+ mapbox-earcut==1.0.1
33
+ Markdown==3.5.2
34
+ MarkupSafe==2.1.3
35
+ matplotlib==3.8.2
36
+ mpmath==1.3.0
37
+ networkx==3.2.1
38
+ ninja==1.11.1.1
39
+ numpy==1.26.3
40
+ nvidia-cublas-cu12==12.1.3.1
41
+ nvidia-cuda-cupti-cu12==12.1.105
42
+ nvidia-cuda-nvrtc-cu12==12.1.105
43
+ nvidia-cuda-runtime-cu12==12.1.105
44
+ nvidia-cudnn-cu12==8.9.2.26
45
+ nvidia-cufft-cu12==11.0.2.54
46
+ nvidia-curand-cu12==10.3.2.106
47
+ nvidia-cusolver-cu12==11.4.5.107
48
+ nvidia-cusparse-cu12==12.1.0.106
49
+ nvidia-nccl-cu12==2.18.1
50
+ nvidia-nvjitlink-cu12==12.3.101
51
+ nvidia-nvtx-cu12==12.1.105
52
+ oauthlib==3.2.2
53
+ opencv-python==4.9.0.80
54
+ packaging==23.2
55
+ pillow==10.2.0
56
+ pipreqs==0.4.13
57
+ protobuf==4.23.4
58
+ pyasn1==0.5.1
59
+ pyasn1-modules==0.3.0
60
+ pycollada==0.8
61
+ pyglet==2.0.10
62
+ PyMCubes==0.1.4
63
+ PyOpenGL==3.1.0
64
+ pyparsing==3.1.1
65
+ pyrender==0.1.45
66
+ python-dateutil==2.8.2
67
+ PyYAML==6.0.1
68
+ referencing==0.32.1
69
+ requests==2.31.0
70
+ requests-oauthlib==1.3.1
71
+ rpds-py==0.16.2
72
+ rsa==4.9
73
+ Rtree==1.1.0
74
+ safetensors==0.4.1
75
+ scipy==1.11.4
76
+ shapely==2.0.2
77
+ six==1.16.0
78
+ sniffio==1.3.0
79
+ svg.path==6.3
80
+ sympy==1.12
81
+ tensorboard==2.15.1
82
+ tensorboard-data-server==0.7.2
83
+ timm==0.9.12
84
+ torch==2.1.2
85
+ torchvision==0.16.2
86
+ tqdm==4.66.1
87
+ trimesh==4.0.9
88
+ triton==2.1.0
89
+ typing_extensions==4.9.0
90
+ urllib3==2.1.0
91
+ vhacdx==0.0.5
92
+ Werkzeug==3.0.1
93
+ xxhash==3.4.1
94
+ yarg==0.1.9
95
+ rembg
utils/camera.py ADDED
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # partially from https://github.com/chenhsuanlin/signed-distance-SRN
2
+
3
+ import numpy as np
4
+ import torch
5
+
6
+ class Pose():
7
+ # a pose class with util methods
8
+ def __call__(self, R=None, t=None):
9
+ assert(R is not None or t is not None)
10
+ if R is None:
11
+ if not isinstance(t, torch.Tensor): t = torch.tensor(t)
12
+ R = torch.eye(3, device=t.device).repeat(*t.shape[:-1], 1, 1)
13
+ elif t is None:
14
+ if not isinstance(R, torch.Tensor): R = torch.tensor(R)
15
+ t = torch.zeros(R.shape[:-1], device=R.device)
16
+ else:
17
+ if not isinstance(R, torch.Tensor): R = torch.tensor(R)
18
+ if not isinstance(t, torch.Tensor): t = torch.tensor(t)
19
+ assert(R.shape[:-1]==t.shape and R.shape[-2:]==(3, 3))
20
+ R = R.float()
21
+ t = t.float()
22
+ pose = torch.cat([R, t[..., None]], dim=-1) # [..., 3, 4]
23
+ assert(pose.shape[-2:]==(3, 4))
24
+ return pose
25
+
26
+ def invert(self, pose, use_inverse=False):
27
+ R, t = pose[..., :3], pose[..., 3:]
28
+ R_inv = R.inverse() if use_inverse else R.transpose(-1, -2)
29
+ t_inv = (-R_inv@t)[..., 0]
30
+ pose_inv = self(R=R_inv, t=t_inv)
31
+ return pose_inv
32
+
33
+ def compose(self, pose_list):
34
+ # pose_new(x) = poseN(...(pose2(pose1(x)))...)
35
+ pose_new = pose_list[0]
36
+ for pose in pose_list[1:]:
37
+ pose_new = self.compose_pair(pose_new, pose)
38
+ return pose_new
39
+
40
+ def compose_pair(self, pose_a, pose_b):
41
+ # pose_new(x) = pose_b(pose_a(x))
42
+ R_a, t_a = pose_a[..., :3], pose_a[..., 3:]
43
+ R_b, t_b = pose_b[..., :3], pose_b[..., 3:]
44
+ R_new = R_b@R_a
45
+ t_new = (R_b@t_a+t_b)[..., 0]
46
+ pose_new = self(R=R_new, t=t_new)
47
+ return pose_new
48
+
49
+ pose = Pose()
50
+
51
+ # unit sphere normalization
52
+ def valid_norm_fac(seen_points, mask):
53
+ '''
54
+ seen_points: [B, H*W, 3]
55
+ mask: [B, 1, H, W], boolean
56
+ '''
57
+ # get valid points
58
+ batch_size = seen_points.shape[0]
59
+ # [B, H*W]
60
+ mask = mask.view(batch_size, seen_points.shape[1])
61
+
62
+ # get mean and variance by sample
63
+ means, max_dists = [], []
64
+ for b in range(batch_size):
65
+ # [N_valid, 3]
66
+ seen_points_valid = seen_points[b][mask[b]]
67
+ # [3]
68
+ xyz_mean = torch.mean(seen_points_valid, dim=0)
69
+ seen_points_valid_zmean = seen_points_valid - xyz_mean
70
+ # scalar
71
+ max_dist = torch.max(seen_points_valid_zmean.norm(dim=1))
72
+ means.append(xyz_mean)
73
+ max_dists.append(max_dist)
74
+ # [B, 3]
75
+ means = torch.stack(means, dim=0)
76
+ # [B]
77
+ max_dists = torch.stack(max_dists, dim=0)
78
+ return means, max_dists
79
+
80
+ def get_pixel_grid(opt, H, W):
81
+ y_range = torch.arange(H, dtype=torch.float32).to(opt.device)
82
+ x_range = torch.arange(W, dtype=torch.float32).to(opt.device)
83
+ Y, X = torch.meshgrid(y_range, x_range, indexing='ij')
84
+ Z = torch.ones_like(Y)
85
+ xyz_grid = torch.stack([X, Y, Z],dim=-1).view(-1,3)
86
+ return xyz_grid
87
+
88
+ def unproj_depth(opt, depth, intr):
89
+ '''
90
+ depth: [B, 1, H, W]
91
+ intr: [B, 3, 3]
92
+ '''
93
+ batch_size, _, H, W = depth.shape
94
+ assert opt.H == H == W
95
+ depth = depth.squeeze(1)
96
+
97
+ # [B, 3, 3]
98
+ K_inv = torch.linalg.inv(intr).float()
99
+ # [1, H*W,3]
100
+ pixel_grid = get_pixel_grid(opt, H, W).unsqueeze(0)
101
+ # [B, H*W,3]
102
+ pixel_grid = pixel_grid.repeat(batch_size, 1, 1)
103
+ # [B, 3, H*W]
104
+ ray_dirs = K_inv @ pixel_grid.permute(0, 2, 1).contiguous()
105
+ # [B, H*W, 3], in camera coordinates
106
+ seen_points = ray_dirs.permute(0, 2, 1).contiguous() * depth.view(batch_size, H*W, 1)
107
+
108
+ return seen_points
109
+
110
+ def to_hom(X):
111
+ '''
112
+ X: [B, N, 3]
113
+ Returns:
114
+ X_hom: [B, N, 4]
115
+ '''
116
+ X_hom = torch.cat([X, torch.ones_like(X[..., :1])], dim=-1)
117
+ return X_hom
118
+
119
+ def world2cam(X_world, pose):
120
+ '''
121
+ X_world: [B, N, 3]
122
+ pose: [B, 3, 4]
123
+ Returns:
124
+ X_cam: [B, N, 3]
125
+ '''
126
+ X_hom = to_hom(X_world)
127
+ X_cam = X_hom @ pose.transpose(-1, -2)
128
+ return X_cam
129
+
130
+ def cam2img(X_cam, cam_intr):
131
+ '''
132
+ X_cam: [B, N, 3]
133
+ cam_intr: [B, 3, 3]
134
+ Returns:
135
+ X_img: [B, N, 3]
136
+ '''
137
+ X_img = X_cam @ cam_intr.transpose(-1, -2)
138
+ return X_img
139
+
140
+ def proj_points(opt, points, intr, pose):
141
+ '''
142
+ points: [B, N, 3]
143
+ intr: [B, 3, 3]
144
+ pose: [B, 3, 4]
145
+ '''
146
+ # [B, N, 3]
147
+ points_cam = world2cam(points, pose)
148
+ # [B, N]
149
+ depth = points_cam[..., 2]
150
+ # [B, N, 3]
151
+ points_img = cam2img(points_cam, intr)
152
+ # [B, N, 2]
153
+ points_2D = points_img[..., :2] / points_img[..., 2:]
154
+ return points_2D, depth
155
+
156
+ def azim_to_rotation_matrix(azim, representation='angle'):
157
+ """Azim is angle with vector +X, rotated in XZ plane"""
158
+ if representation == 'rad':
159
+ # [B, ]
160
+ cos, sin = torch.cos(azim), torch.sin(azim)
161
+ elif representation == 'angle':
162
+ # [B, ]
163
+ azim = azim * np.pi / 180
164
+ cos, sin = torch.cos(azim), torch.sin(azim)
165
+ elif representation == 'trig':
166
+ # [B, 2]
167
+ cos, sin = azim[:, 0], azim[:, 1]
168
+ R = torch.eye(3, device=azim.device)[None].repeat(len(azim), 1, 1)
169
+ zeros = torch.zeros(len(azim), device=azim.device)
170
+ R[:, 0, :] = torch.stack([cos, zeros, sin], dim=-1)
171
+ R[:, 2, :] = torch.stack([-sin, zeros, cos], dim=-1)
172
+ return R
173
+
174
+ def elev_to_rotation_matrix(elev, representation='angle'):
175
+ """Angle with vector +Z in YZ plane"""
176
+ if representation == 'rad':
177
+ # [B, ]
178
+ cos, sin = torch.cos(elev), torch.sin(elev)
179
+ elif representation == 'angle':
180
+ # [B, ]
181
+ elev = elev * np.pi / 180
182
+ cos, sin = torch.cos(elev), torch.sin(elev)
183
+ elif representation == 'trig':
184
+ # [B, 2]
185
+ cos, sin = elev[:, 0], elev[:, 1]
186
+ R = torch.eye(3, device=elev.device)[None].repeat(len(elev), 1, 1)
187
+ R[:, 1, 1:] = torch.stack([cos, -sin], dim=-1)
188
+ R[:, 2, 1:] = torch.stack([sin, cos], dim=-1)
189
+ return R
190
+
191
+ def roll_to_rotation_matrix(roll, representation='angle'):
192
+ """Angle with vector +X in XY plane"""
193
+ if representation == 'rad':
194
+ # [B, ]
195
+ cos, sin = torch.cos(roll), torch.sin(roll)
196
+ elif representation == 'angle':
197
+ # [B, ]
198
+ roll = roll * np.pi / 180
199
+ cos, sin = torch.cos(roll), torch.sin(roll)
200
+ elif representation == 'trig':
201
+ # [B, 2]
202
+ cos, sin = roll[:, 0], roll[:, 1]
203
+ R = torch.eye(3, device=roll.device)[None].repeat(len(roll), 1, 1)
204
+ R[:, 0, :2] = torch.stack([cos, sin], dim=-1)
205
+ R[:, 1, :2] = torch.stack([-sin, cos], dim=-1)
206
+ return R
207
+
208
+ def get_rotation_sphere(azim_sample=4, elev_sample=4, roll_sample=4, scales=[1.0], device='cuda'):
209
+ rotations = []
210
+ azim_range = [0, 360]
211
+ elev_range = [0, 360]
212
+ roll_range = [0, 360]
213
+ azims = np.linspace(azim_range[0], azim_range[1], num=azim_sample, endpoint=False)
214
+ elevs = np.linspace(elev_range[0], elev_range[1], num=elev_sample, endpoint=False)
215
+ rolls = np.linspace(roll_range[0], roll_range[1], num=roll_sample, endpoint=False)
216
+ for scale in scales:
217
+ for azim in azims:
218
+ for elev in elevs:
219
+ for roll in rolls:
220
+ Ry = azim_to_rotation_matrix(torch.tensor([azim]))
221
+ Rx = elev_to_rotation_matrix(torch.tensor([elev]))
222
+ Rz = roll_to_rotation_matrix(torch.tensor([roll]))
223
+ R_permute = torch.tensor([
224
+ [-1, 0, 0],
225
+ [0, 0, -1],
226
+ [0, -1, 0]
227
+ ]).float().to(Ry.device).unsqueeze(0).expand_as(Ry)
228
+ R = scale * Rz@Rx@Ry@R_permute
229
+ rotations.append(R.to(device).float())
230
+ return torch.cat(rotations, dim=0)
utils/eval_3D.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import threading
4
+ import mcubes
5
+ import trimesh
6
+ from utils.util_vis import show_att_on_image
7
+ from utils.camera import get_rotation_sphere
8
+
9
+ @torch.no_grad()
10
+ def get_dense_3D_grid(opt, var, N=None):
11
+ batch_size = len(var.idx)
12
+ N = N or opt.eval.vox_res
13
+ # -0.6, 0.6
14
+ range_min, range_max = opt.eval.range
15
+ grid = torch.linspace(range_min, range_max, N+1, device=opt.device)
16
+ points_3D = torch.stack(torch.meshgrid(grid, grid, grid, indexing='ij'), dim=-1) # [N, N, N, 3]
17
+ # actually N+1 instead of N
18
+ points_3D = points_3D.repeat(batch_size, 1, 1, 1, 1) # [B, N, N, N, 3]
19
+ return points_3D
20
+
21
+ @torch.no_grad()
22
+ def compute_level_grid(opt, impl_network, latent_depth, latent_semantic, points_3D, images, vis_attn=False):
23
+ # needed for amp
24
+ latent_depth = latent_depth.to(torch.float32) if latent_depth is not None else None
25
+ latent_semantic = latent_semantic.to(torch.float32) if latent_semantic is not None else None
26
+
27
+ # process points in sliced way
28
+ batch_size = points_3D.shape[0]
29
+ N = points_3D.shape[1]
30
+ assert N == points_3D.shape[2] == points_3D.shape[3]
31
+ assert points_3D.shape[4] == 3
32
+
33
+ points_3D = points_3D.view(batch_size, N, N*N, 3)
34
+ occ = []
35
+ attn = []
36
+ for i in range(N):
37
+ # [B, N*N, 3]
38
+ points_slice = points_3D[:, i]
39
+ # [B, N*N, 3] -> [B, N*N], [B, N*N, 1+feat_res**2]
40
+ occ_slice, attn_slice = impl_network(latent_depth, latent_semantic, points_slice)
41
+ occ.append(occ_slice)
42
+ attn.append(attn_slice.detach())
43
+ # [B, N, N*N] -> [B, N, N, N]
44
+ occ = torch.stack(occ, dim=1).view(batch_size, N, N, N)
45
+ occ = torch.sigmoid(occ)
46
+ if vis_attn:
47
+ N_global = 1
48
+ feat_res = opt.H // opt.arch.win_size
49
+ attn = torch.stack(attn, dim=1).view(batch_size, N, N, N, N_global+feat_res**2)
50
+ # average along Z, [B, N, N, N_global+feat_res**2]
51
+ attn = torch.mean(attn, dim=3)
52
+ # [B, N, N, N_global] -> [B, N, N, 1]
53
+ attn_global = attn[:, :, :, :N_global].sum(dim=-1, keepdim=True)
54
+ # [B, N, N, feat_res, feat_res]
55
+ attn_local = attn[:, :, :, N_global:].view(batch_size, N, N, feat_res, feat_res)
56
+ # [B, N, N, feat_res, feat_res]
57
+ attn_vis = attn_global.unsqueeze(-1) + attn_local
58
+ # list of frame lists
59
+ images_vis = []
60
+ for b in range(batch_size):
61
+ images_vis_sample = []
62
+ for row in range(0, N, 8):
63
+ if row % 16 == 0:
64
+ col_range = range(0, N//8*8+1, 8)
65
+ else:
66
+ col_range = range(N//8*8, -1, -8)
67
+ for col in col_range:
68
+ # [feat_res, feat_res], x is col
69
+ attn_curr = attn_vis[b, col, row]
70
+ attn_curr = torch.nn.functional.interpolate(
71
+ attn_curr.unsqueeze(0).unsqueeze(0), size=(opt.H, opt.W),
72
+ mode='bilinear', align_corners=False
73
+ ).squeeze(0).squeeze(0).cpu().numpy()
74
+ attn_curr /= attn_curr.max()
75
+ # [feat_res, feat_res, 3]
76
+ image_curr = images[b].permute(1, 2, 0).cpu().numpy()
77
+ # merge the image and the attention
78
+ images_vis_sample.append(show_att_on_image(image_curr, attn_curr))
79
+ images_vis.append(images_vis_sample)
80
+ return occ, images_vis if vis_attn else None
81
+
82
+ @torch.no_grad()
83
+ def standardize_pc(pc):
84
+ assert len(pc.shape) == 3
85
+ pc_mean = pc.mean(dim=1, keepdim=True)
86
+ pc_zmean = pc - pc_mean
87
+ origin_distance = (pc_zmean**2).sum(dim=2, keepdim=True).sqrt()
88
+ scale = torch.sqrt(torch.sum(origin_distance**2, dim=1, keepdim=True) / pc.shape[1])
89
+ pc_standardized = pc_zmean / (scale * 2)
90
+ return pc_standardized
91
+
92
+ @torch.no_grad()
93
+ def normalize_pc(pc):
94
+ assert len(pc.shape) == 3
95
+ pc_mean = pc.mean(dim=1, keepdim=True)
96
+ pc_zmean = pc - pc_mean
97
+ length_x = pc_zmean[:, :, 0].max(dim=-1)[0] - pc_zmean[:, :, 0].min(dim=-1)[0]
98
+ length_y = pc_zmean[:, :, 1].max(dim=-1)[0] - pc_zmean[:, :, 1].min(dim=-1)[0]
99
+ length_max = torch.stack([length_x, length_y], dim=-1).max(dim=-1)[0].unsqueeze(-1).unsqueeze(-1)
100
+ pc_normalized = pc_zmean / (length_max + 1.e-7)
101
+ return pc_normalized
102
+
103
+ def convert_to_explicit(opt, level_grids, isoval=0., to_pointcloud=False):
104
+ N = len(level_grids)
105
+ meshes = [None]*N
106
+ pointclouds = [None]*N if to_pointcloud else None
107
+ threads = [threading.Thread(target=convert_to_explicit_worker,
108
+ args=(opt, i, level_grids[i], isoval, meshes),
109
+ kwargs=dict(pointclouds=pointclouds),
110
+ daemon=False) for i in range(N)]
111
+ for t in threads: t.start()
112
+ for t in threads: t.join()
113
+ if to_pointcloud:
114
+ pointclouds = np.stack(pointclouds, axis=0)
115
+ return meshes, pointclouds
116
+ else: return meshes
117
+
118
+ def convert_to_explicit_worker(opt, i, level_vox_i, isoval, meshes, pointclouds=None):
119
+ # use marching cubes to convert implicit surface to mesh
120
+ vertices, faces = mcubes.marching_cubes(level_vox_i, isovalue=isoval)
121
+ assert(level_vox_i.shape[0]==level_vox_i.shape[1]==level_vox_i.shape[2])
122
+ S = level_vox_i.shape[0]
123
+ range_min, range_max = opt.eval.range
124
+ # marching cubes treat every cube as unit length
125
+ vertices = vertices/S*(range_max-range_min)+range_min
126
+ mesh = trimesh.Trimesh(vertices, faces)
127
+ meshes[i] = mesh
128
+ if pointclouds is not None:
129
+ # randomly sample on mesh to get uniform dense point cloud
130
+ if len(mesh.triangles)!=0:
131
+ points = mesh.sample(opt.eval.num_points)
132
+ else: points = np.zeros([opt.eval.num_points, 3])
133
+ pointclouds[i] = points
utils/eval_depth.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # based on https://gist.github.com/ranftlr/45f4c7ddeb1bbb88d606bc600cab6c8d
2
+
3
+ import torch
4
+
5
+ class DepthMetric:
6
+ def __init__(self, thresholds=[1.25, 1.25**2, 1.25**3], depth_cap=None, prediction_type='depth'):
7
+ self.thresholds = thresholds
8
+ self.depth_cap = depth_cap
9
+ self.metric_keys = self.get_metric_keys()
10
+ self.prediction_type = prediction_type
11
+
12
+ def compute_scale_and_shift(self, prediction, target, mask):
13
+ # system matrix: A = [[a_00, a_01], [a_10, a_11]]
14
+ a_00 = torch.sum(mask * prediction * prediction, (1, 2))
15
+ a_01 = torch.sum(mask * prediction, (1, 2))
16
+ a_11 = torch.sum(mask, (1, 2))
17
+
18
+ # right hand side: b = [b_0, b_1]
19
+ b_0 = torch.sum(mask * prediction * target, (1, 2))
20
+ b_1 = torch.sum(mask * target, (1, 2))
21
+
22
+ # solution: x = A^-1 . b = [[a_11, -a_01], [-a_10, a_00]] / (a_00 * a_11 - a_01 * a_10) . b
23
+ x_0 = torch.zeros_like(b_0)
24
+ x_1 = torch.zeros_like(b_1)
25
+
26
+ det = a_00 * a_11 - a_01 * a_01
27
+ # A needs to be a positive definite matrix.
28
+ valid = det > 0
29
+
30
+ x_0[valid] = (a_11[valid] * b_0[valid] - a_01[valid] * b_1[valid]) / det[valid]
31
+ x_1[valid] = (-a_01[valid] * b_0[valid] + a_00[valid] * b_1[valid]) / det[valid]
32
+
33
+ return x_0, x_1
34
+
35
+ def get_metric_keys(self):
36
+ metric_keys = []
37
+ for threshold in self.thresholds:
38
+ metric_keys.append('d>{}'.format(threshold))
39
+ metric_keys.append('rmse')
40
+ metric_keys.append('l1_err')
41
+ metric_keys.append('abs_rel')
42
+ return metric_keys
43
+
44
+ def compute_metrics(self, prediction, target, mask):
45
+ # check inputs
46
+ prediction = prediction.float()
47
+ target = target.float()
48
+ mask = mask.float()
49
+ assert prediction.shape == target.shape == mask.shape
50
+ assert len(prediction.shape) == 4
51
+ assert prediction.shape[1] == 1
52
+ assert prediction.dtype == target.dtype == mask.dtype == torch.float32
53
+
54
+ # process inputs
55
+ prediction = prediction.squeeze(1)
56
+ target = target.squeeze(1)
57
+ mask = (mask.squeeze(1) > 0.5).long()
58
+
59
+ # output dict
60
+ metrics = {}
61
+
62
+ # get the predicted disparity
63
+ prediction_disparity = torch.zeros_like(prediction)
64
+ if self.prediction_type == 'depth':
65
+ prediction_disparity[mask == 1] = 1.0 / (prediction[mask == 1] + 1.e-6)
66
+ elif self.prediction_type == 'disparity':
67
+ prediction_disparity[mask == 1] = prediction[mask == 1]
68
+ else:
69
+ raise ValueError('Unknown prediction type: {}'.format(self.prediction_type))
70
+
71
+ # transform predicted disparity to align with depth
72
+ target_disparity = torch.zeros_like(target)
73
+ target_disparity[mask == 1] = 1.0 / target[mask == 1]
74
+ scale, shift = self.compute_scale_and_shift(prediction_disparity, target_disparity, mask)
75
+ prediction_aligned = scale.view(-1, 1, 1) * prediction_disparity + shift.view(-1, 1, 1)
76
+
77
+ if self.depth_cap is not None:
78
+ disparity_cap = 1.0 / self.depth_cap
79
+ prediction_aligned[prediction_aligned < disparity_cap] = disparity_cap
80
+
81
+ prediciton_depth = 1.0 / prediction_aligned
82
+
83
+ # delta > threshold, [batch_size, ]
84
+ for threshold in self.thresholds:
85
+ err = torch.zeros_like(prediciton_depth, dtype=torch.float)
86
+ err[mask == 1] = torch.max(
87
+ prediciton_depth[mask == 1] / target[mask == 1],
88
+ target[mask == 1] / prediciton_depth[mask == 1],
89
+ )
90
+ err[mask == 1] = (err[mask == 1] > threshold).float()
91
+ metrics['d>{}'.format(threshold)] = torch.sum(err, (1, 2)) / torch.sum(mask, (1, 2))
92
+
93
+ # rmse, [batch_size, ]
94
+ rmse = torch.zeros_like(prediciton_depth, dtype=torch.float)
95
+ rmse[mask == 1] = (prediciton_depth[mask == 1] - target[mask == 1]) ** 2
96
+ rmse = torch.sum(rmse, (1, 2)) / torch.sum(mask, (1, 2))
97
+ metrics['rmse'] = torch.sqrt(rmse)
98
+
99
+ # l1 error, [batch_size, ]
100
+ l1_err = torch.zeros_like(prediciton_depth, dtype=torch.float)
101
+ l1_err[mask == 1] = torch.abs(prediciton_depth[mask == 1] - target[mask == 1])
102
+ metrics['l1_err'] = torch.sum(l1_err, (1, 2)) / torch.sum(mask, (1, 2))
103
+
104
+ # abs_rel, [batch_size, ]
105
+ abs_rel = torch.zeros_like(prediciton_depth, dtype=torch.float)
106
+ abs_rel[mask == 1] = torch.abs(prediciton_depth[mask == 1] - target[mask == 1]) / target[mask == 1]
107
+ metrics['abs_rel'] = torch.sum(abs_rel, (1, 2)) / torch.sum(mask, (1, 2))
108
+
109
+ return metrics, prediciton_depth.unsqueeze(1)
110
+
utils/layers.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from functools import partial
5
+ from timm.models.vision_transformer import Block
6
+
7
+ # 3D positional encoding, from https://github.com/bmild/nerf.
8
+ class Embedder:
9
+ def __init__(self, **kwargs):
10
+ self.kwargs = kwargs
11
+ self.create_embedding_fn()
12
+
13
+ def create_embedding_fn(self):
14
+ embed_fns = []
15
+ d = self.kwargs['input_dims']
16
+ out_dim = 0
17
+ if self.kwargs['include_input']:
18
+ embed_fns.append(lambda x: x)
19
+ out_dim += d
20
+
21
+ max_freq = self.kwargs['max_freq_log2']
22
+ N_freqs = self.kwargs['num_freqs']
23
+
24
+ if self.kwargs['log_sampling']:
25
+ freq_bands = 2. ** torch.linspace(0., max_freq, N_freqs)
26
+ else:
27
+ freq_bands = torch.linspace(2.**0., 2.**max_freq, N_freqs)
28
+
29
+ for freq in freq_bands:
30
+ for p_fn in self.kwargs['periodic_fns']:
31
+ embed_fns.append(lambda x, p_fn=p_fn,
32
+ freq=freq: p_fn(x * freq))
33
+ out_dim += d
34
+
35
+ self.embed_fns = embed_fns
36
+ self.out_dim = out_dim
37
+
38
+ def embed(self, inputs):
39
+ return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
40
+
41
+ def get_embedder(posenc_res, input_dims=3):
42
+ embed_kwargs = {
43
+ 'include_input': True,
44
+ 'input_dims': input_dims,
45
+ 'max_freq_log2': posenc_res-1,
46
+ 'num_freqs': posenc_res,
47
+ 'log_sampling': True,
48
+ 'periodic_fns': [torch.sin, torch.cos],
49
+ }
50
+
51
+ embedder_obj = Embedder(**embed_kwargs)
52
+ def embed(x, eo=embedder_obj): return eo.embed(x)
53
+ return embed, embedder_obj.out_dim
54
+
55
+ class LayerScale(nn.Module):
56
+ def __init__(self, dim, init_values=1e-5, inplace=False):
57
+ super().__init__()
58
+ self.inplace = inplace
59
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
60
+
61
+ def forward(self, x):
62
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
63
+
64
+ class Bottleneck_Linear(nn.Module):
65
+ def __init__(self, n_channels):
66
+ super().__init__()
67
+ self.linear1 = nn.Linear(n_channels, n_channels)
68
+ self.norm = nn.LayerNorm(n_channels)
69
+ self.linear2 = nn.Linear(n_channels, n_channels)
70
+ self.gelu = nn.GELU()
71
+
72
+ def forward(self, x):
73
+ x = x + self.linear2(self.gelu(self.linear1(self.norm(x))))
74
+ return x
75
+
76
+ class Bottleneck_Conv(nn.Module):
77
+ def __init__(self, n_channels, kernel_size=1):
78
+ super().__init__()
79
+ self.linear1 = nn.Conv2d(n_channels, n_channels, kernel_size=kernel_size, padding=kernel_size//2, bias=False)
80
+ self.bn1 = nn.BatchNorm2d(n_channels)
81
+ self.linear2 = nn.Conv2d(n_channels, n_channels, kernel_size=kernel_size, padding=kernel_size//2, bias=False)
82
+ self.bn2 = nn.BatchNorm2d(n_channels)
83
+ self.relu = nn.ReLU(inplace=True)
84
+
85
+ def forward(self, x):
86
+ assert len(x.shape) in [2, 4]
87
+ input_dims = len(x.shape)
88
+ if input_dims == 2:
89
+ x = x.unsqueeze(-1).unsqueeze(-1)
90
+ residual = x
91
+ out = self.linear1(x)
92
+ out = self.bn1(out)
93
+ out = self.relu(out)
94
+ out = self.linear2(out)
95
+ out = self.bn2(out)
96
+ out += residual
97
+ out = self.relu(out)
98
+ if input_dims == 2:
99
+ out = out.squeeze(-1).squeeze(-1)
100
+ return out
101
+
102
+ class CLIPFusionBlock_Concat(nn.Module):
103
+ """
104
+ Fuse clip and rgb embeddings via concat-proj
105
+ """
106
+ def __init__(self, n_channels=512, n_layers=1, act=True):
107
+ super().__init__()
108
+ proj = [Bottleneck_Linear(2 * n_channels) for _ in range(n_layers)]
109
+ proj.append(nn.Linear(2 * n_channels, n_channels))
110
+ if act: proj.append(nn.GELU())
111
+ self.proj = nn.Sequential(*proj)
112
+
113
+ def forward(self, sem_latent, clip_latent):
114
+ """
115
+ sem_latent: [B, N, C]
116
+ clip_latent: [B, C]
117
+ """
118
+ # [B, N, 2C]
119
+ latent_concat = torch.cat([sem_latent, clip_latent.unsqueeze(1).expand_as(sem_latent)], dim=-1)
120
+ # [B, N, C]
121
+ latent = self.proj(latent_concat)
122
+ return latent
123
+
124
+ class CLIPFusionBlock_Attn(nn.Module):
125
+ """
126
+ Fuse geometric and semantic embeddings via multi-layer MHA blocks
127
+ """
128
+ def __init__(self, n_channels=512, n_layers=1, act=True):
129
+ super().__init__()
130
+ self.attn_blocks = nn.ModuleList(
131
+ [Block(
132
+ n_channels, 8, 4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), drop_path=0.1
133
+ ) for _ in range(n_layers)]
134
+ )
135
+ if act: self.attn_blocks.append(nn.GELU())
136
+
137
+ def forward(self, sem_latent, clip_latent):
138
+ """
139
+ sem_latent: [B, N, C]
140
+ clip_latent: [B, C]
141
+ """
142
+ # [B, 1+N, C], clip first
143
+ latent = torch.cat([clip_latent.unsqueeze(1), sem_latent], dim=1)
144
+ for attn_block in self.attn_blocks:
145
+ latent = attn_block(latent)
146
+ # [B, N, C]
147
+ return latent[:, 1:, :]
utils/loss.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as torch_F
4
+
5
+ from copy import deepcopy
6
+ from model.depth.midas_loss import MidasLoss
7
+
8
+ class Loss(nn.Module):
9
+
10
+ def __init__(self, opt):
11
+ super().__init__()
12
+ self.opt = deepcopy(opt)
13
+ self.occ_loss = nn.BCEWithLogitsLoss(reduction='none')
14
+ self.midas_loss = MidasLoss(alpha=opt.training.depth_loss.grad_reg,
15
+ inverse_depth=opt.training.depth_loss.depth_inv,
16
+ shrink_mask=opt.training.depth_loss.mask_shrink)
17
+
18
+ def shape_loss(self, pred_occ_raw, gt_sdf):
19
+ assert len(pred_occ_raw.shape) == 2
20
+ assert len(gt_sdf.shape) == 2
21
+ # [B, N]
22
+ gt_occ = (gt_sdf < 0).float()
23
+ loss = self.occ_loss(pred_occ_raw, gt_occ)
24
+ weight_mask = torch.ones_like(loss)
25
+ thres = self.opt.training.shape_loss.impt_thres
26
+ weight_mask[torch.abs(gt_sdf) < thres] = weight_mask[torch.abs(gt_sdf) < thres] * self.opt.training.shape_loss.impt_weight
27
+ loss = loss * weight_mask
28
+ return loss.mean()
29
+
30
+ def depth_loss(self, pred_depth, gt_depth, mask):
31
+ assert len(pred_depth.shape) == len(gt_depth.shape) == len(mask.shape) == 4
32
+ assert pred_depth.shape[1] == gt_depth.shape[1] == mask.shape[1] == 1
33
+ loss = self.midas_loss(pred_depth, gt_depth, mask)
34
+ return loss
35
+
36
+ def intr_loss(self, seen_pred, seen_gt, mask):
37
+ assert len(seen_pred.shape) == len(seen_gt.shape) == 3
38
+ assert len(mask.shape) == 2
39
+ # [B, HW]
40
+ distance = torch.sum((seen_pred - seen_gt)**2, dim=-1)
41
+ loss = (distance * mask).sum() / (mask.sum() + 1.e-8)
42
+ return loss
utils/options.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import os, sys, time
3
+ import torch
4
+ import random
5
+ import string
6
+ import yaml
7
+ import utils.util as util
8
+ import time
9
+
10
+ from utils.util import EasyDict as edict
11
+
12
+ # torch.backends.cudnn.enabled = False
13
+ # torch.backends.cudnn.benchmark = False
14
+ # torch.backends.cudnn.deterministic = True
15
+
16
+ def parse_arguments(args):
17
+ # parse from command line (syntax: --key1.key2.key3=value)
18
+ opt_cmd = {}
19
+ for arg in args:
20
+ assert(arg.startswith("--"))
21
+ if "=" not in arg[2:]: # --key means key=True, --key! means key=False
22
+ key_str, value = (arg[2:-1], "false") if arg[-1]=="!" else (arg[2:], "true")
23
+ else:
24
+ key_str, value = arg[2:].split("=")
25
+ keys_sub = key_str.split(".")
26
+ opt_sub = opt_cmd
27
+ for k in keys_sub[:-1]:
28
+ if k not in opt_sub: opt_sub[k] = {}
29
+ opt_sub = opt_sub[k]
30
+ # if opt_cmd['key1']['key2']['key3'] already exist for key1.key2.key3, print key3 as error msg
31
+ assert keys_sub[-1] not in opt_sub, keys_sub[-1]
32
+ opt_sub[keys_sub[-1]] = yaml.safe_load(value)
33
+ opt_cmd = edict(opt_cmd)
34
+ return opt_cmd
35
+
36
+ def set(opt_cmd={}, verbose=True, safe_check=True):
37
+ print("setting configurations...")
38
+ fname = opt_cmd.yaml # load from yaml file
39
+ opt_base = load_options(fname)
40
+ # override with command line arguments
41
+ opt = override_options(opt_base, opt_cmd, key_stack=[], safe_check=safe_check)
42
+ process_options(opt)
43
+ if verbose:
44
+ def print_options(opt, level=0):
45
+ for key, value in sorted(opt.items()):
46
+ if isinstance(value, (dict, edict)):
47
+ print(" "*level+"* "+key+":")
48
+ print_options(value, level+1)
49
+ else:
50
+ print(" "*level+"* "+key+":", value)
51
+ print_options(opt)
52
+ return opt
53
+
54
+ def load_options(fname):
55
+ with open(fname) as file:
56
+ opt = edict(yaml.safe_load(file))
57
+ if "_parent_" in opt:
58
+ # load parent yaml file(s) as base options
59
+ parent_fnames = opt.pop("_parent_")
60
+ if type(parent_fnames) is str:
61
+ parent_fnames = [parent_fnames]
62
+ for parent_fname in parent_fnames:
63
+ opt_parent = load_options(parent_fname)
64
+ opt_parent = override_options(opt_parent, opt, key_stack=[])
65
+ opt = opt_parent
66
+ print("loading {}...".format(fname))
67
+ return opt
68
+
69
+ def override_options(opt, opt_over, key_stack=None, safe_check=False):
70
+ for key, value in opt_over.items():
71
+ if isinstance(value, dict):
72
+ # parse child options (until leaf nodes are reached)
73
+ opt[key] = override_options(opt.get(key, dict()), value, key_stack=key_stack+[key], safe_check=safe_check)
74
+ else:
75
+ # ensure command line argument to override is also in yaml file
76
+ if safe_check and key not in opt:
77
+ add_new = None
78
+ while add_new not in ["y", "n"]:
79
+ key_str = ".".join(key_stack+[key])
80
+ add_new = input("\"{}\" not found in original opt, add? (y/n) ".format(key_str))
81
+ if add_new=="n":
82
+ print("safe exiting...")
83
+ exit()
84
+ opt[key] = value
85
+ return opt
86
+
87
+ def process_options(opt):
88
+ # set seed
89
+ if opt.seed is not None:
90
+ random.seed(opt.seed)
91
+ np.random.seed(opt.seed)
92
+ torch.manual_seed(opt.seed)
93
+ torch.cuda.manual_seed_all(opt.seed)
94
+ else:
95
+ # create random string as run ID
96
+ randkey = "".join(random.choice(string.ascii_uppercase) for _ in range(4))
97
+ opt.name += "_{}".format(randkey)
98
+ # other default options
99
+ opt.output_path = "{0}/{1}/{2}".format(opt.output_root, opt.group, opt.name)
100
+ os.makedirs(opt.output_path, exist_ok=True)
101
+ opt.H, opt.W = opt.image_size
102
+ if opt.freq.eval is None:
103
+ opt.freq.eval = max(opt.max_epoch // 20, 1)
104
+ if 'loss_weight' in opt:
105
+ opt.get_depth = False
106
+ opt.get_normal = False
107
+
108
+ def save_options_file(opt):
109
+ opt_fname = "{}/options.yaml".format(opt.output_path)
110
+ if os.path.isfile(opt_fname):
111
+ with open(opt_fname) as file:
112
+ opt_old = yaml.safe_load(file)
113
+ if opt!=opt_old:
114
+ # prompt if options are not identical
115
+ opt_new_fname = "{}/options_temp.yaml".format(opt.output_path)
116
+ with open(opt_new_fname, "w") as file:
117
+ yaml.safe_dump(util.to_dict(opt), file, default_flow_style=False, indent=4)
118
+ print("existing options file found (different from current one)...")
119
+ os.system("diff {} {}".format(opt_fname, opt_new_fname))
120
+ os.system("rm {}".format(opt_new_fname))
121
+ if not opt.debug:
122
+ print("please cancel within 10 seconds if you do not want to override...")
123
+ time.sleep(10)
124
+ else: print("existing options file found (identical)")
125
+ else: print("(creating new options file...)")
126
+ with open(opt_fname, "w") as file:
127
+ yaml.safe_dump(util.to_dict(opt), file, default_flow_style=False, indent=4)
utils/pos_embed.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is copied from https://github.com/facebookresearch/MCC
5
+ # The original code base is licensed under the license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+ # --------------------------------------------------------
8
+ # Position embedding utils
9
+ # --------------------------------------------------------
10
+
11
+ import numpy as np
12
+
13
+ import torch
14
+
15
+ # --------------------------------------------------------
16
+ # 2D sine-cosine position embedding
17
+ # References:
18
+ # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
19
+ # MoCo v3: https://github.com/facebookresearch/moco-v3
20
+ # --------------------------------------------------------
21
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
22
+ """
23
+ grid_size: int of the grid height and width
24
+ return:
25
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
26
+ """
27
+ grid_h = np.arange(grid_size, dtype=np.float32)
28
+ grid_w = np.arange(grid_size, dtype=np.float32)
29
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
30
+ grid = np.stack(grid, axis=0)
31
+
32
+ grid = grid.reshape([2, 1, grid_size, grid_size])
33
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
34
+ if cls_token:
35
+ pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
36
+ return pos_embed
37
+
38
+
39
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
40
+ assert embed_dim % 2 == 0
41
+
42
+ # use half of dimensions to encode grid_h
43
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
44
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
45
+
46
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
47
+ return emb
48
+
49
+
50
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
51
+ """
52
+ embed_dim: output dimension for each position
53
+ pos: a list of positions to be encoded: size (M,)
54
+ out: (M, D)
55
+ """
56
+ assert embed_dim % 2 == 0
57
+ omega = np.arange(embed_dim // 2, dtype=np.float32)
58
+ omega /= embed_dim / 2.
59
+ omega = 1. / 10000**omega # (D/2,)
60
+
61
+ pos = pos.reshape(-1) # (M,)
62
+ out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
63
+
64
+ emb_sin = np.sin(out) # (M, D/2)
65
+ emb_cos = np.cos(out) # (M, D/2)
66
+
67
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
68
+ return emb
69
+
70
+ def get_1d_sincos_pos_embed_from_grid_torch(embed_dim, pos):
71
+ """
72
+ embed_dim: output dimension for each position
73
+ pos: a list of positions to be encoded: size (M,)
74
+ out: (M, D)
75
+ """
76
+ assert embed_dim % 2 == 0
77
+ omega = torch.arange(embed_dim // 2, device=pos.device).float()
78
+ omega /= embed_dim / 2.
79
+ omega = 1. / 10000**omega # (D/2,)
80
+
81
+ pos = pos.reshape(-1) # (M,)
82
+ out = torch.einsum('m,d->md', pos, omega) # (M, D/2), outer product
83
+
84
+ emb_sin = torch.sin(out) # (M, D/2)
85
+ emb_cos = torch.cos(out) # (M, D/2)
86
+
87
+ emb = torch.cat([emb_sin, emb_cos], axis=1) # (M, D)
88
+ return emb
89
+
90
+
91
+
92
+ # --------------------------------------------------------
93
+ # Interpolate position embeddings for high-resolution
94
+ # References:
95
+ # DeiT: https://github.com/facebookresearch/deit
96
+ # --------------------------------------------------------
97
+ def interpolate_pos_embed(model, checkpoint_model):
98
+ if 'pos_embed' in checkpoint_model:
99
+ pos_embed_checkpoint = checkpoint_model['pos_embed']
100
+ embedding_size = pos_embed_checkpoint.shape[-1]
101
+ num_patches = model.patch_embed.num_patches
102
+ num_extra_tokens = model.pos_embed.shape[-2] - num_patches
103
+ # height (== width) for the checkpoint position embedding
104
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
105
+ # height (== width) for the new position embedding
106
+ new_size = int(num_patches ** 0.5)
107
+ # class_token and dist_token are kept unchanged
108
+ if orig_size != new_size:
109
+ print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
110
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
111
+ # only the position tokens are interpolated
112
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
113
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
114
+ pos_tokens = torch.nn.functional.interpolate(
115
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
116
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
117
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
118
+ checkpoint_model['pos_embed'] = new_pos_embed
utils/util.py ADDED
@@ -0,0 +1,413 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, sys, time
2
+ import shutil
3
+ import datetime
4
+ import torch
5
+ import torch.nn.functional as torch_F
6
+ import socket
7
+ import contextlib
8
+ import socket
9
+ import torch.distributed as dist
10
+ from collections import defaultdict, deque
11
+
12
+ class SmoothedValue(object):
13
+ """Track a series of values and provide access to smoothed values over a
14
+ window or the global series average.
15
+ """
16
+
17
+ def __init__(self, window_size=20, fmt=None):
18
+ if fmt is None:
19
+ fmt = "{median:.4f} ({global_avg:.4f})"
20
+ self.deque = deque(maxlen=window_size)
21
+ self.total = 0.0
22
+ self.count = 0
23
+ self.fmt = fmt
24
+
25
+ def update(self, value, n=1):
26
+ self.deque.append(value)
27
+ self.count += n
28
+ self.total += value * n
29
+
30
+ @property
31
+ def median(self):
32
+ d = torch.tensor(list(self.deque))
33
+ return d.median().item()
34
+
35
+ @property
36
+ def avg(self):
37
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
38
+ return d.mean().item()
39
+
40
+ @property
41
+ def global_avg(self):
42
+ return self.total / self.count
43
+
44
+ @property
45
+ def max(self):
46
+ return max(self.deque)
47
+
48
+ @property
49
+ def value(self):
50
+ return self.deque[-1]
51
+
52
+ def __str__(self):
53
+ return self.fmt.format(
54
+ median=self.median,
55
+ avg=self.avg,
56
+ global_avg=self.global_avg,
57
+ max=self.max,
58
+ value=self.value)
59
+
60
+
61
+ class MetricLogger(object):
62
+ def __init__(self, delimiter="\t"):
63
+ self.meters = defaultdict(SmoothedValue)
64
+ self.delimiter = delimiter
65
+
66
+ def update(self, **kwargs):
67
+ for k, v in kwargs.items():
68
+ if v is None:
69
+ continue
70
+ if isinstance(v, torch.Tensor):
71
+ v = v.item()
72
+ assert isinstance(v, (float, int))
73
+ self.meters[k].update(v)
74
+
75
+ def __getattr__(self, attr):
76
+ if attr in self.meters:
77
+ return self.meters[attr]
78
+ if attr in self.__dict__:
79
+ return self.__dict__[attr]
80
+ raise AttributeError("'{}' object has no attribute '{}'".format(
81
+ type(self).__name__, attr))
82
+
83
+ def __str__(self):
84
+ loss_str = []
85
+ for name, meter in self.meters.items():
86
+ loss_str.append(
87
+ "{}: {}".format(name, str(meter))
88
+ )
89
+ return self.delimiter.join(loss_str)
90
+
91
+ def add_meter(self, name, meter):
92
+ self.meters[name] = meter
93
+
94
+ def log_every(self, iterable, print_freq, header=None):
95
+ i = 0
96
+ if not header:
97
+ header = ''
98
+ start_time = time.time()
99
+ end = time.time()
100
+ iter_time = SmoothedValue(fmt='{avg:.4f}')
101
+ data_time = SmoothedValue(fmt='{avg:.4f}')
102
+ space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
103
+ log_msg = [
104
+ header,
105
+ '[{0' + space_fmt + '}/{1}]',
106
+ 'eta: {eta}',
107
+ '{meters}',
108
+ 'time: {time}',
109
+ 'data: {data}'
110
+ ]
111
+ if torch.cuda.is_available():
112
+ log_msg.append('max mem: {memory:.0f}')
113
+ log_msg = self.delimiter.join(log_msg)
114
+ MB = 1024.0 * 1024.0
115
+ for obj in iterable:
116
+ data_time.update(time.time() - end)
117
+ yield obj
118
+ iter_time.update(time.time() - end)
119
+ if i % print_freq == 0 or i == len(iterable) - 1:
120
+ eta_seconds = iter_time.global_avg * (len(iterable) - i)
121
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
122
+ if torch.cuda.is_available():
123
+ print(log_msg.format(
124
+ i, len(iterable), eta=eta_string,
125
+ meters=str(self),
126
+ time=str(iter_time), data=str(data_time),
127
+ memory=torch.cuda.max_memory_allocated() / MB))
128
+ else:
129
+ print(log_msg.format(
130
+ i, len(iterable), eta=eta_string,
131
+ meters=str(self),
132
+ time=str(iter_time), data=str(data_time)))
133
+ i += 1
134
+ end = time.time()
135
+ total_time = time.time() - start_time
136
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
137
+ print('{} Total time: {} ({:.4f} s / it)'.format(
138
+ header, total_time_str, total_time / len(iterable)))
139
+
140
+ def print_eval(opt, loss=None, chamfer=None, depth_metrics=None):
141
+ message = "[eval] "
142
+ if loss is not None: message += "loss:{}".format("{:.3e}".format(loss.all))
143
+ if chamfer is not None:
144
+ message += " chamfer:{}|{}|{}".format("{:.4f}".format(chamfer[0]),
145
+ "{:.4f}".format(chamfer[1]),
146
+ "{:.4f}".format((chamfer[0]+chamfer[1])/2))
147
+ if depth_metrics is not None:
148
+ for k, v in depth_metrics.items():
149
+ message += "{}:{}, ".format(k, "{:.4f}".format(v))
150
+ message = message[:-2]
151
+ print(message)
152
+
153
+ def update_timer(opt, timer, ep, it_per_ep):
154
+ momentum = 0.99
155
+ timer.elapsed = time.time()-timer.start
156
+ timer.it = timer.it_end-timer.it_start
157
+ # compute speed with moving average
158
+ timer.it_mean = timer.it_mean*momentum+timer.it*(1-momentum) if timer.it_mean is not None else timer.it
159
+ timer.arrival = timer.it_mean*it_per_ep*(opt.max_epoch-ep)
160
+
161
+ # move tensors to device in-place
162
+ def move_to_device(X, device):
163
+ if isinstance(X, dict):
164
+ for k, v in X.items():
165
+ X[k] = move_to_device(v, device)
166
+ elif isinstance(X, list):
167
+ for i, e in enumerate(X):
168
+ X[i] = move_to_device(e, device)
169
+ elif isinstance(X, tuple) and hasattr(X, "_fields"): # collections.namedtuple
170
+ dd = X._asdict()
171
+ dd = move_to_device(dd, device)
172
+ return type(X)(**dd)
173
+ elif isinstance(X, torch.Tensor):
174
+ return X.to(device=device, non_blocking=True)
175
+ return X
176
+
177
+ # detach tensors
178
+ def detach_tensors(X):
179
+ if isinstance(X, dict):
180
+ for k, v in X.items():
181
+ X[k] = detach_tensors(v)
182
+ elif isinstance(X, list):
183
+ for i, e in enumerate(X):
184
+ X[i] = detach_tensors(e)
185
+ elif isinstance(X, tuple) and hasattr(X, "_fields"): # collections.namedtuple
186
+ dd = X._asdict()
187
+ dd = detach_tensors(dd)
188
+ return type(X)(**dd)
189
+ elif isinstance(X, torch.Tensor):
190
+ return X.detach()
191
+ return X
192
+
193
+ # this recursion seems to only work for the outer loop when dict_type is not dict
194
+ def to_dict(D, dict_type=dict):
195
+ D = dict_type(D)
196
+ for k, v in D.items():
197
+ if isinstance(v, dict):
198
+ D[k] = to_dict(v, dict_type)
199
+ return D
200
+
201
+ def get_child_state_dict(state_dict, key):
202
+ out_dict = {}
203
+ for k, v in state_dict.items():
204
+ if k.startswith("module."):
205
+ param_name = k[7:]
206
+ else:
207
+ param_name = k
208
+ if param_name.startswith("{}.".format(key)):
209
+ out_dict[".".join(param_name.split(".")[1:])] = v
210
+ return out_dict
211
+
212
+ def resume_checkpoint(opt, model, best):
213
+ load_name = "{0}/best.ckpt".format(opt.output_path) if best else "{0}/latest.ckpt".format(opt.output_path)
214
+ checkpoint = torch.load(load_name, map_location=torch.device(opt.device))
215
+ model.graph.module.load_state_dict(checkpoint["graph"], strict=True)
216
+ # load the training stats
217
+ for key in model.__dict__:
218
+ if key.split("_")[0] in ["optim", "sched", "scaler"] and key in checkpoint:
219
+ if opt.device == 0: print("restoring {}...".format(key))
220
+ getattr(model, key).load_state_dict(checkpoint[key])
221
+ # also need to record ep, it, best_val if we are returning
222
+ ep, it = checkpoint["epoch"], checkpoint["iter"]
223
+ best_val, best_ep = checkpoint["best_val"], checkpoint["best_ep"] if "best_ep" in checkpoint else 0
224
+ print("resuming from epoch {0} (iteration {1})".format(ep, it))
225
+
226
+ return ep, it, best_val, best_ep
227
+
228
+ def load_checkpoint(opt, model, load_name):
229
+ # load_name as to be given
230
+ checkpoint = torch.load(load_name, map_location=torch.device(opt.device))
231
+ # load individual (possibly partial) children modules
232
+ for name, child in model.graph.module.named_children():
233
+ child_state_dict = get_child_state_dict(checkpoint["graph"], name)
234
+ if child_state_dict:
235
+ if opt.device == 0: print("restoring {}...".format(name))
236
+ child.load_state_dict(child_state_dict, strict=True)
237
+ else:
238
+ if opt.device == 0: print("skipping {}...".format(name))
239
+ return None, None, None, None
240
+
241
+ def restore_checkpoint(opt, model, load_name=None, resume=False, best=False, evaluate=False):
242
+ # we cannot load and resume at the same time
243
+ assert not (load_name is not None and resume)
244
+ # when resuming we want everything to be the same
245
+ if resume:
246
+ ep, it, best_val, best_ep = resume_checkpoint(opt, model, best)
247
+ # loading is more flexible, as we can only load parts of the model
248
+ else:
249
+ ep, it, best_val, best_ep = load_checkpoint(opt, model, load_name)
250
+ return ep, it, best_val, best_ep
251
+
252
+ def save_checkpoint(opt, model, ep, it, best_val, best_ep, latest=False, best=False, children=None):
253
+ os.makedirs("{0}/checkpoint".format(opt.output_path), exist_ok=True)
254
+ if isinstance(model.graph, torch.nn.DataParallel) or isinstance(model.graph, torch.nn.parallel.DistributedDataParallel):
255
+ graph = model.graph.module
256
+ else:
257
+ graph = model.graph
258
+ if children is not None:
259
+ graph_state_dict = { k: v for k, v in graph.state_dict().items() if k.startswith(children) }
260
+ else: graph_state_dict = graph.state_dict()
261
+ checkpoint = dict(
262
+ epoch=ep,
263
+ iter=it,
264
+ best_val=best_val,
265
+ best_ep=best_ep,
266
+ graph=graph_state_dict,
267
+ )
268
+ for key in model.__dict__:
269
+ if key.split("_")[0] in ["optim", "sched", "scaler"]:
270
+ checkpoint.update({key: getattr(model, key).state_dict()})
271
+ torch.save(checkpoint, "{0}/latest.ckpt".format(opt.output_path))
272
+ if best:
273
+ shutil.copy("{0}/latest.ckpt".format(opt.output_path),
274
+ "{0}/best.ckpt".format(opt.output_path))
275
+ if not latest:
276
+ shutil.copy("{0}/latest.ckpt".format(opt.output_path),
277
+ "{0}/checkpoint/ep{1}.ckpt".format(opt.output_path, ep))
278
+
279
+ def check_socket_open(hostname, port):
280
+ s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
281
+ is_open = False
282
+ try:
283
+ s.bind((hostname, port))
284
+ except socket.error:
285
+ is_open = True
286
+ finally:
287
+ s.close()
288
+ return is_open
289
+
290
+ def get_layer_dims(layers):
291
+ # return a list of tuples (k_in, k_out)
292
+ return list(zip(layers[:-1], layers[1:]))
293
+
294
+ @contextlib.contextmanager
295
+ def suppress(stdout=False, stderr=False):
296
+ with open(os.devnull, "w") as devnull:
297
+ if stdout: old_stdout, sys.stdout = sys.stdout, devnull
298
+ if stderr: old_stderr, sys.stderr = sys.stderr, devnull
299
+ try: yield
300
+ finally:
301
+ if stdout: sys.stdout = old_stdout
302
+ if stderr: sys.stderr = old_stderr
303
+
304
+ def toggle_grad(model, requires_grad):
305
+ for p in model.parameters():
306
+ p.requires_grad_(requires_grad)
307
+
308
+ def compute_grad2(d_outs, x_in):
309
+ d_outs = [d_outs] if not isinstance(d_outs, list) else d_outs
310
+ reg = 0
311
+ for d_out in d_outs:
312
+ batch_size = x_in.size(0)
313
+ grad_dout = torch.autograd.grad(
314
+ outputs=d_out.sum(), inputs=x_in,
315
+ create_graph=True, retain_graph=True, only_inputs=True
316
+ )[0]
317
+ grad_dout2 = grad_dout.pow(2)
318
+ assert(grad_dout2.size() == x_in.size())
319
+ reg += grad_dout2.view(batch_size, -1).sum(1)
320
+ return reg / len(d_outs)
321
+
322
+ # import matplotlib.pyplot as plt
323
+ def interpolate_depth(depth_input, mask_input, size, bg_depth=20):
324
+ assert len(depth_input.shape) == len(mask_input.shape) == 4
325
+ mask = (mask_input > 0.5).float()
326
+ depth_valid = depth_input * mask
327
+ depth_valid = torch_F.interpolate(depth_valid, size, mode='bilinear', align_corners=False)
328
+ mask = torch_F.interpolate(mask, size, mode='bilinear', align_corners=False)
329
+ depth_out = depth_valid / (mask + 1.e-6)
330
+ mask_binary = (mask > 0.5).float()
331
+ depth_out = depth_out * mask_binary + bg_depth * (1 - mask_binary)
332
+ return depth_out, mask_binary
333
+
334
+ # import matplotlib.pyplot as plt
335
+ # import torchvision
336
+ def interpolate_coordmap(coord_map, mask_input, size, bg_coord=0):
337
+ assert len(coord_map.shape) == len(mask_input.shape) == 4
338
+ mask = (mask_input > 0.5).float()
339
+ coord_valid = coord_map * mask
340
+ coord_valid = torch_F.interpolate(coord_valid, size, mode='bilinear', align_corners=False)
341
+ mask = torch_F.interpolate(mask, size, mode='bilinear', align_corners=False)
342
+ coord_out = coord_valid / (mask + 1.e-6)
343
+ mask_binary = (mask > 0.5).float()
344
+ coord_out = coord_out * mask_binary + bg_coord * (1 - mask_binary)
345
+ return coord_out, mask_binary
346
+
347
+ def cleanup():
348
+ dist.destroy_process_group()
349
+
350
+ def is_port_in_use(port):
351
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
352
+ return s.connect_ex(('localhost', port)) == 0
353
+
354
+ def setup(rank, world_size, port_no):
355
+ full_address = 'tcp://127.0.0.1:' + str(port_no)
356
+ dist.init_process_group("nccl", init_method=full_address, rank=rank, world_size=world_size)
357
+
358
+ def print_grad(grad, prefix=''):
359
+ print("{} --- Grad Abs Mean, Grad Max, Grad Min: {:.5f} | {:.5f} | {:.5f}".format(prefix, grad.abs().mean().item(), grad.max().item(), grad.min().item()))
360
+
361
+ class AverageMeter(object):
362
+ """Computes and stores the average and current value"""
363
+ def __init__(self):
364
+ self.reset()
365
+
366
+ def reset(self):
367
+ self.val = 0
368
+ self.avg = 0
369
+ self.sum = 0
370
+ self.count = 0
371
+
372
+ def update(self, val, n=1):
373
+ self.val = val
374
+ self.sum += val * n
375
+ self.count += n
376
+ self.avg = self.sum / self.count
377
+
378
+ class EasyDict(dict):
379
+ def __init__(self, d=None, **kwargs):
380
+ if d is None:
381
+ d = {}
382
+ else:
383
+ d = dict(d)
384
+ if kwargs:
385
+ d.update(**kwargs)
386
+ for k, v in d.items():
387
+ setattr(self, k, v)
388
+ # Class attributes
389
+ for k in self.__class__.__dict__.keys():
390
+ if not (k.startswith('__') and k.endswith('__')) and not k in ('update', 'pop'):
391
+ setattr(self, k, getattr(self, k))
392
+
393
+ def __setattr__(self, name, value):
394
+ if isinstance(value, (list, tuple)):
395
+ value = [self.__class__(x)
396
+ if isinstance(x, dict) else x for x in value]
397
+ elif isinstance(value, dict) and not isinstance(value, self.__class__):
398
+ value = self.__class__(value)
399
+ super(EasyDict, self).__setattr__(name, value)
400
+ super(EasyDict, self).__setitem__(name, value)
401
+
402
+ __setitem__ = __setattr__
403
+
404
+ def update(self, e=None, **f):
405
+ d = e or dict()
406
+ d.update(f)
407
+ for k in d:
408
+ setattr(self, k, d[k])
409
+
410
+ def pop(self, k, d=None):
411
+ delattr(self, k)
412
+ return super(EasyDict, self).pop(k, d)
413
+
utils/util_vis.py ADDED
@@ -0,0 +1,511 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import os
3
+ import torch
4
+ import torchvision
5
+ import torchvision.transforms.functional as torchvision_F
6
+ import matplotlib.pyplot as plt
7
+ import PIL
8
+ import PIL.ImageDraw
9
+ from PIL import Image, ImageFont
10
+ import trimesh
11
+ import pyrender
12
+ import cv2
13
+ import copy
14
+ import base64
15
+ import io
16
+ import imageio
17
+
18
+ os.environ['PYOPENGL_PLATFORM'] = 'egl'
19
+ @torch.no_grad()
20
+ def tb_image(opt, tb, step, group, name, images, masks=None, num_vis=None, from_range=(0, 1), poses=None, cmap="gray", depth=False):
21
+ if not depth:
22
+ images = preprocess_vis_image(opt, images, masks=masks, from_range=from_range, cmap=cmap) # [B, 3, H, W]
23
+ else:
24
+ masks = (masks > 0.5).float()
25
+ images = images * masks + (1 - masks) * ((images * masks).max())
26
+ images = (1 - images).detach().cpu()
27
+ num_H, num_W = num_vis or opt.tb.num_images
28
+ images = images[:num_H*num_W]
29
+ if poses is not None:
30
+ # poses: [B, 3, 4]
31
+ # rots: [max(B, num_images), 3, 3]
32
+ rots = poses[:num_H*num_W, ..., :3]
33
+ images = torch.stack([draw_pose(opt, image, rot, size=20, width=2) for image, rot in zip(images, rots)], dim=0)
34
+ image_grid = torchvision.utils.make_grid(images[:, :3], nrow=num_W, pad_value=1.)
35
+ if images.shape[1]==4:
36
+ mask_grid = torchvision.utils.make_grid(images[:, 3:], nrow=num_W, pad_value=1.)[:1]
37
+ image_grid = torch.cat([image_grid, mask_grid], dim=0)
38
+ tag = "{0}/{1}".format(group, name)
39
+ tb.add_image(tag, image_grid, step)
40
+
41
+ def preprocess_vis_image(opt, images, masks=None, from_range=(0, 1), cmap="gray"):
42
+ min, max = from_range
43
+ images = (images-min)/(max-min)
44
+ if masks is not None:
45
+ # then the mask is directly the transparency channel of png
46
+ images = torch.cat([images, masks], dim=1)
47
+ images = images.clamp(min=0, max=1).cpu()
48
+ if images.shape[1]==1:
49
+ images = get_heatmap(opt, images[:, 0].cpu(), cmap=cmap)
50
+ return images
51
+
52
+ def preprocess_depth_image(opt, depth, mask=None, max_depth=1000):
53
+ if mask is not None: depth = depth * mask + (1 - mask) * max_depth # min of this will leads to minimum of masked regions
54
+ depth = depth - depth.min()
55
+
56
+ if mask is not None: depth = depth * mask # max of this will leads to maximum of masked regions
57
+ depth = depth / depth.max()
58
+ return depth
59
+
60
+ def dump_images(opt, idx, name, images, masks=None, from_range=(0, 1), poses=None, metrics=None, cmap="gray", folder='dump'):
61
+ images = preprocess_vis_image(opt, images, masks=masks, from_range=from_range, cmap=cmap) # [B, 3, H, W]
62
+ if poses is not None:
63
+ rots = poses[..., :3]
64
+ images = torch.stack([draw_pose(opt, image, rot, size=20, width=2) for image, rot in zip(images, rots)], dim=0)
65
+ if metrics is not None:
66
+ images = torch.stack([draw_metric(opt, image, metric.item()) for image, metric in zip(images, metrics)], dim=0)
67
+ images = images.cpu().permute(0, 2, 3, 1).contiguous().numpy() # [B, H, W, 3]
68
+ for i, img in zip(idx, images):
69
+ fname = "{}/{}/{}_{}.png".format(opt.output_path, folder, i, name)
70
+ img = Image.fromarray((img*255).astype(np.uint8))
71
+ img.save(fname)
72
+
73
+ def dump_depths(opt, idx, name, depths, masks=None, rescale=False, folder='dump'):
74
+ if rescale:
75
+ masks = (masks > 0.5).float()
76
+ depths = depths * masks + (1 - masks) * ((depths * masks).max())
77
+ depths = (1 - depths).detach().cpu()
78
+ for i, depth in zip(idx, depths):
79
+ fname = "{}/{}/{}_{}.png".format(opt.output_path, folder, i, name)
80
+ plt.imsave(fname, depth.squeeze(), cmap='viridis')
81
+
82
+ # img_list is a list of length n_views, where each view is a image tensor of [B, 3, H, W]
83
+ def dump_gifs(opt, idx, name, imgs_list, from_range=(0, 1), folder='dump', cmap="gray"):
84
+ for i in range(len(imgs_list)):
85
+ imgs_list[i] = preprocess_vis_image(opt, imgs_list[i], from_range=from_range, cmap=cmap)
86
+ for i in range(len(idx)):
87
+ img_list_np = [imgs[i].cpu().permute(1, 2, 0).contiguous().numpy() for imgs in imgs_list] # list of [H, W, 3], each item is a view of ith sample
88
+ img_list_pil = [Image.fromarray((img*255).astype(np.uint8)).convert('RGB') for img in img_list_np]
89
+ fname = "{}/{}/{}_{}.gif".format(opt.output_path, folder, idx[i], name)
90
+ img_list_pil[0].save(fname, format='GIF', append_images=img_list_pil[1:], save_all=True, duration=100, loop=0)
91
+
92
+ # img_list is a list of length n_views, where each view is a image tensor of [B, 3, H, W]
93
+ def dump_attentions(opt, idx, name, attn_vis, folder='dump'):
94
+ for i in range(len(idx)):
95
+ img_list_pil = [Image.fromarray((img*255).astype(np.uint8)).convert('RGB') for img in attn_vis[i]]
96
+ fname = "{}/{}/{}_{}.gif".format(opt.output_path, folder, idx[i], name)
97
+ img_list_pil[0].save(fname, format='GIF', append_images=img_list_pil[1:], save_all=True, duration=50, loop=0)
98
+
99
+ def get_heatmap(opt, gray, cmap): # [N, H, W]
100
+ color = plt.get_cmap(cmap)(gray.numpy())
101
+ color = torch.from_numpy(color[..., :3]).permute(0, 3, 1, 2).contiguous().float() # [N, 3, H, W]
102
+ return color
103
+
104
+ def dump_meshes(opt, idx, name, meshes, folder='dump'):
105
+ for i, mesh in zip(idx, meshes):
106
+ fname = "{}/{}/{}_{}.ply".format(opt.output_path, folder, i, name)
107
+ try:
108
+ mesh.export(fname)
109
+ except:
110
+ print('Mesh is empty!')
111
+
112
+ def dump_meshes_viz(opt, idx, name, meshes, save_frames=True, folder='dump'):
113
+ for i, mesh in zip(idx, meshes):
114
+ mesh = copy.deepcopy(mesh)
115
+ R = trimesh.transformations.rotation_matrix(np.radians(180), [0,0,1])
116
+ mesh.apply_transform(R)
117
+ R = trimesh.transformations.rotation_matrix(np.radians(180), [0,1,0])
118
+ mesh.apply_transform(R)
119
+ # our marching cubes outputs inverted normals for some reason so this is necessary
120
+ trimesh.repair.fix_inversion(mesh)
121
+
122
+ fname = "{}/{}/{}_{}".format(opt.output_path, folder, i, name)
123
+ try:
124
+ mesh = scale_to_unit_cube(mesh)
125
+ visualize_mesh(mesh, fname, write_frames=save_frames)
126
+ except:
127
+ pass
128
+
129
+ def dump_seen_surface(opt, idx, obj_name, img_name, seen_projs, folder='dump'):
130
+ # seen_proj: [B, H, W, 3]
131
+ for i, seen_proj in zip(idx, seen_projs):
132
+ out_folder = "{}/{}".format(opt.output_path, folder)
133
+ img_fname = "{}_{}.png".format(i, img_name)
134
+ create_seen_surface(i, img_fname, seen_proj, out_folder, obj_name)
135
+
136
+ # https://github.com/princeton-vl/oasis/blob/master/utils/vis_mesh.py
137
+ def create_seen_surface(sample_ID, img_path, XYZ, output_folder, obj_name, connect_thres=0.005):
138
+ height, width = XYZ.shape[:2]
139
+ XYZ_to_idx = {}
140
+ idx = 1
141
+ with open("{}/{}_{}.mtl".format(output_folder, sample_ID, obj_name), "w") as f:
142
+ f.write("newmtl material_0\n")
143
+ f.write("Ka 0.200000 0.200000 0.200000\n")
144
+ f.write("Kd 0.752941 0.752941 0.752941\n")
145
+ f.write("Ks 1.000000 1.000000 1.000000\n")
146
+ f.write("Tr 1.000000\n")
147
+ f.write("illum 2\n")
148
+ f.write("Ns 0.000000\n")
149
+ f.write("map_Ka %s\n" % img_path)
150
+ f.write("map_Kd %s\n" % img_path)
151
+
152
+ with open("{}/{}_{}.obj".format(output_folder, sample_ID, obj_name), "w") as f:
153
+ f.write("mtllib {}_{}.mtl\n".format(sample_ID, obj_name))
154
+ for y in range(height):
155
+ for x in range(width):
156
+ if XYZ[y][x][2] > 0:
157
+ XYZ_to_idx[(y, x)] = idx
158
+ idx += 1
159
+ f.write("v %.4f %.4f %.4f\n" % (XYZ[y][x][0], XYZ[y][x][1], XYZ[y][x][2]))
160
+ f.write("vt %.8f %.8f\n" % ( float(x) / float(width), 1.0 - float(y) / float(height)))
161
+ f.write("usemtl material_0\n")
162
+ for y in range(height-1):
163
+ for x in range(width-1):
164
+ if XYZ[y][x][2] > 0 and XYZ[y][x+1][2] > 0 and XYZ[y+1][x][2] > 0:
165
+ # if close enough, connect vertices to form a face
166
+ if torch.norm(XYZ[y][x] - XYZ[y][x+1]).item() < connect_thres and torch.norm(XYZ[y][x] - XYZ[y+1][x]).item() < connect_thres:
167
+ f.write("f %d/%d %d/%d %d/%d\n" % (XYZ_to_idx[(y, x)], XYZ_to_idx[(y, x)], XYZ_to_idx[(y, x+1)], XYZ_to_idx[(y, x+1)], XYZ_to_idx[(y+1, x)], XYZ_to_idx[(y+1, x)]))
168
+ if XYZ[y][x+1][2] > 0 and XYZ[y+1][x+1][2] > 0 and XYZ[y+1][x][2] > 0:
169
+ if torch.norm(XYZ[y][x+1] - XYZ[y+1][x+1]).item() < connect_thres and torch.norm(XYZ[y][x+1] - XYZ[y+1][x]).item() < connect_thres:
170
+ f.write("f %d/%d %d/%d %d/%d\n" % (XYZ_to_idx[(y, x+1)], XYZ_to_idx[(y, x+1)], XYZ_to_idx[(y+1, x+1)], XYZ_to_idx[(y+1, x+1)], XYZ_to_idx[(y+1, x)], XYZ_to_idx[(y+1, x)]))
171
+
172
+ def dump_pointclouds_compare(opt, idx, name, preds, gts, folder='dump'):
173
+ for i in range(len(idx)):
174
+ pred = preds[i].cpu().numpy() # [N1, 3]
175
+ gt = gts[i].cpu().numpy() # [N2, 3]
176
+ color_pred = np.zeros(pred.shape).astype(np.uint8)
177
+ color_pred[:, 0] = 255
178
+ color_gt = np.zeros(gt.shape).astype(np.uint8)
179
+ color_gt[:, 1] = 255
180
+ pc_vertices = np.vstack([pred, gt])
181
+ colors = np.vstack([color_pred, color_gt])
182
+ pc_color = trimesh.points.PointCloud(vertices=pc_vertices, colors=colors)
183
+ fname = "{}/{}/{}_{}.ply".format(opt.output_path, folder, idx[i], name)
184
+ pc_color.export(fname)
185
+
186
+ def dump_pointclouds(opt, idx, name, pcs, colors, folder='dump', colormap='jet'):
187
+ for i, pc, color in zip(idx, pcs, colors):
188
+ pc = pc.cpu().numpy() # [N, 3]
189
+ color = color.cpu().numpy() # [N, 3] or [N, 1]
190
+ # convert scalar color to rgb with colormap
191
+ if color.shape[1] == 1:
192
+ # single channel color in numpy between [0, 1] to rgb
193
+ color = plt.get_cmap(colormap)(color[:, 0])
194
+ color = (color * 255).astype(np.uint8)
195
+ pc_color = trimesh.points.PointCloud(vertices=pc, colors=color)
196
+ fname = "{}/{}/{}_{}.ply".format(opt.output_path, folder, i, name)
197
+ pc_color.export(fname)
198
+
199
+ @torch.no_grad()
200
+ def vis_pointcloud(opt, vis, step, split, pred, GT=None):
201
+ win_name = "{0}/{1}".format(opt.group, opt.name)
202
+ pred, GT = pred.cpu().numpy(), GT.cpu().numpy()
203
+ for i in range(opt.visdom.num_samples):
204
+ # prediction
205
+ data = [dict(
206
+ type="scatter3d",
207
+ x=[float(n) for n in points[i, :opt.visdom.num_points, 0]],
208
+ y=[float(n) for n in points[i, :opt.visdom.num_points, 1]],
209
+ z=[float(n) for n in points[i, :opt.visdom.num_points, 2]],
210
+ mode="markers",
211
+ marker=dict(
212
+ color=color,
213
+ size=1,
214
+ ),
215
+ ) for points, color in zip([pred, GT], ["blue", "magenta"])]
216
+ vis._send(dict(
217
+ data=data,
218
+ win="{0} #{1}".format(split, i),
219
+ eid="{0}/{1}".format(opt.group, opt.name),
220
+ layout=dict(
221
+ title="{0} #{1} ({2})".format(split, i, step),
222
+ autosize=True,
223
+ margin=dict(l=30, r=30, b=30, t=30, ),
224
+ showlegend=False,
225
+ yaxis=dict(
226
+ scaleanchor="x",
227
+ scaleratio=1,
228
+ )
229
+ ),
230
+ opts=dict(title="{0} #{1} ({2})".format(win_name, i, step), ),
231
+ ))
232
+
233
+ @torch.no_grad()
234
+ def draw_pose(opt, image, rot_mtrx, size=15, width=1):
235
+ # rot_mtrx: [3, 4]
236
+ mode = "RGBA" if image.shape[0]==4 else "RGB"
237
+ image_pil = torchvision_F.to_pil_image(image.cpu()).convert("RGBA")
238
+ draw_pil = PIL.Image.new("RGBA", image_pil.size, (0, 0, 0, 0))
239
+ draw = PIL.ImageDraw.Draw(draw_pil)
240
+ center = (size, size)
241
+ # first column of rotation matrix is the rotated vector of [1, 0, 0]'
242
+ # second column of rotation matrix is the rotated vector of [0, 1, 0]'
243
+ # third column of rotation matrix is the rotated vector of [0, 0, 1]'
244
+ # then always take the first two element of each column is a projection to the 2D plane for visualization
245
+ endpoint = [(size+size*p[0], size+size*p[1]) for p in rot_mtrx.t()]
246
+ draw.line([center, endpoint[0]], fill=(255, 0, 0), width=width)
247
+ draw.line([center, endpoint[1]], fill=(0, 255, 0), width=width)
248
+ draw.line([center, endpoint[2]], fill=(0, 0, 255), width=width)
249
+ image_pil.alpha_composite(draw_pil)
250
+ image_drawn = torchvision_F.to_tensor(image_pil.convert(mode))
251
+ return image_drawn
252
+
253
+ @torch.no_grad()
254
+ def draw_metric(opt, image, metric):
255
+ mode = "RGBA" if image.shape[0]==4 else "RGB"
256
+ image_pil = torchvision_F.to_pil_image(image.cpu()).convert("RGBA")
257
+ draw_pil = PIL.Image.new("RGBA", image_pil.size, (0, 0, 0, 0))
258
+ draw = PIL.ImageDraw.Draw(draw_pil)
259
+ font = ImageFont.truetype("DejaVuSans.ttf", 24)
260
+ position = (image_pil.size[0] - 80, image_pil.size[1] - 35)
261
+ draw.text(position, '{:.3f}'.format(metric), fill="red", font=font)
262
+ image_pil.alpha_composite(draw_pil)
263
+ image_drawn = torchvision_F.to_tensor(image_pil.convert(mode))
264
+ return image_drawn
265
+
266
+ @torch.no_grad()
267
+ def show_att_on_image(img, mask):
268
+ """
269
+ Convert the grayscale attention into heatmap on the image.
270
+ Parameters
271
+ ----------
272
+ img: np.array, [H, W, 3]
273
+ Original colored image in [0, 1].
274
+ mask: np.array, [H, W]
275
+ Attention map in [0, 1].
276
+ Returns
277
+ ----------
278
+ np image with attention applied.
279
+ """
280
+ # check the validity
281
+ assert np.max(img) <= 1
282
+ assert np.max(mask) <= 1
283
+
284
+ # generate heatmap and normalize into [0, 1]
285
+ heatmap = cv2.cvtColor(cv2.applyColorMap(np.uint8(255*mask), cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB)
286
+ heatmap = np.float32(heatmap) / 255
287
+
288
+ # add heatmap onto the image
289
+ merged = heatmap + np.float32(img)
290
+
291
+ # re-scale the image
292
+ merged = merged / np.max(merged)
293
+ return merged
294
+
295
+ def look_at(camera_position, camera_target, up_vector):
296
+ vector = camera_position - camera_target
297
+ vector = vector / np.linalg.norm(vector)
298
+
299
+ vector2 = np.cross(up_vector, vector)
300
+ vector2 = vector2 / np.linalg.norm(vector2)
301
+
302
+ vector3 = np.cross(vector, vector2)
303
+ return np.array([
304
+ [vector2[0], vector3[0], vector[0], 0.0],
305
+ [vector2[1], vector3[1], vector[1], 0.0],
306
+ [vector2[2], vector3[2], vector[2], 0.0],
307
+ [-np.dot(vector2, camera_position), -np.dot(vector3, camera_position), np.dot(vector, camera_position), 1.0]
308
+ ])
309
+
310
+ def scale_to_unit_cube(mesh):
311
+ if isinstance(mesh, trimesh.Scene):
312
+ mesh = mesh.dump().sum()
313
+
314
+ vertices = mesh.vertices - mesh.bounding_box.centroid
315
+ vertices *= 2 / np.max(mesh.bounding_box.extents)
316
+ vertices *= 0.5
317
+
318
+ return trimesh.Trimesh(vertices=vertices, faces=mesh.faces)
319
+
320
+ def get_positions_and_rotations(n_frames=180, r=1.5):
321
+ '''
322
+ n_frames: how many frames
323
+ r: how far should the camera be
324
+ '''
325
+ # test case 1
326
+ n_frame_full_circ = n_frames // 3 # frames for a full circle
327
+ n_frame_half_circ = n_frames // 6 # frames for a half circle
328
+
329
+ # full circle in horizontal axes going from 1 to -1 height axis
330
+ pos1 = [np.array([r*np.cos(theta), elev, r*np.sin(theta)])
331
+ for theta, elev in zip(np.linspace(0.5*np.pi,2.5*np.pi, n_frame_full_circ), np.linspace(1,-1,n_frame_full_circ))]
332
+ # half circle in horizontal axes at fixed -1 height
333
+ pos2 = [np.array([r*np.cos(theta), -1, r*np.sin(theta)])
334
+ for theta in np.linspace(2.5*np.pi,3.5*np.pi, n_frame_half_circ)]
335
+ # full circle in horizontal axes going from -1 to 1 height axis
336
+ pos3 = [np.array([r*np.cos(theta), elev, r*np.sin(theta)])
337
+ for theta, elev in zip(np.linspace(3.5*np.pi,5.5*np.pi, n_frame_full_circ), np.linspace(-1,1,n_frame_full_circ))]
338
+ # half circle in horizontal axes at fixed 1 height
339
+ pos4 = [np.array([r*np.cos(theta), 1, r*np.sin(theta)])
340
+ for theta in np.linspace(3.5*np.pi,4.5*np.pi, n_frame_half_circ)]
341
+
342
+ pos = pos1 + pos2 + pos3 + pos4
343
+ target = np.array([0.0, 0.0, 0.0])
344
+ up = np.array([0.0, 1.0, 0.0])
345
+ rot = [look_at(x, target, up) for x in pos]
346
+ return pos, rot
347
+
348
+ def visualize_mesh(mesh, output_path, resolution=(200,200), write_gif=True, write_frames=True, time_per_frame=80, n_frames=180):
349
+ '''
350
+ mesh: Trimesh mesh object
351
+ output_path: absolute path, ".gif" will get added if write_gif, and this will be used as dirname if write_frames is true
352
+ time_per_frame: how many milliseconds to wait for each frame
353
+ n_frames: how many frames in total
354
+ '''
355
+
356
+ # set material
357
+ mat = pyrender.MetallicRoughnessMaterial(
358
+ metallicFactor=0.8,
359
+ roughnessFactor=1.0,
360
+ alphaMode='OPAQUE',
361
+ baseColorFactor=(0.5, 0.5, 0.8, 1.0),
362
+ )
363
+ # define and add scene elements
364
+ mesh = pyrender.Mesh.from_trimesh(mesh, material=mat)
365
+ camera = pyrender.PerspectiveCamera(yfov=np.pi / 3.0, aspectRatio=1.0)
366
+ light = pyrender.SpotLight(color=np.ones(3), intensity=15.0,
367
+ innerConeAngle=np.pi/4.0,
368
+ outerConeAngle=np.pi/4.0)
369
+
370
+ scene = pyrender.Scene()
371
+ obj = scene.add(mesh)
372
+ cam = scene.add(camera)
373
+ light = scene.add(light)
374
+
375
+ positions, rotations = get_positions_and_rotations(n_frames=n_frames)
376
+
377
+ r = pyrender.OffscreenRenderer(*resolution)
378
+
379
+ # move the camera and generate images
380
+ count = 0
381
+ image_list = []
382
+ for pos, rot in zip(positions, rotations):
383
+
384
+ pose = np.eye(4)
385
+ pose[:3, 3] = pos
386
+ pose[:3,:3] = rot[:3,:3]
387
+
388
+ scene.set_pose(cam, pose)
389
+ scene.set_pose(light, pose)
390
+
391
+ color, depth = r.render(scene)
392
+
393
+ img = Image.fromarray(color, mode="RGB")
394
+ image_list.append(img)
395
+
396
+ # save to file
397
+ if write_gif:
398
+ image_list[0].save(f"{output_path}.gif", format='GIF', append_images=image_list[1:], save_all=True, duration=80, loop=0)
399
+
400
+ if write_frames:
401
+ if not os.path.exists(output_path):
402
+ os.makedirs(output_path)
403
+
404
+ for i, img in enumerate(image_list):
405
+ img.save(os.path.join(output_path, f"{i:04d}.jpg"))
406
+
407
+ def get_base64_encoded_image(image_path):
408
+ """
409
+ Returns the base64-encoded image at the given path.
410
+
411
+ Args:
412
+ image_path (str): The path to the image file.
413
+
414
+ Returns:
415
+ str: The base64-encoded image.
416
+ """
417
+ with open(image_path, "rb") as f:
418
+ img = Image.open(f)
419
+ if img.mode == 'RGBA':
420
+ img = img.convert('RGB')
421
+ # Resize the image to reduce its file size
422
+ img.thumbnail((200, 200))
423
+ buffer = io.BytesIO()
424
+ # Convert the image to JPEG format to reduce its file size
425
+ img.save(buffer, format="JPEG", quality=80)
426
+ return base64.b64encode(buffer.getvalue()).decode("utf-8")
427
+
428
+ def get_base64_encoded_gif(gif_path):
429
+ """
430
+ Returns the base64-encoded GIF at the given path.
431
+
432
+ Args:
433
+ gif_path (str): The path to the GIF file.
434
+
435
+ Returns:
436
+ str: The base64-encoded GIF.
437
+ """
438
+ with open(gif_path, "rb") as f:
439
+ frames = imageio.mimread(f)
440
+ # Reduce the number of frames to reduce the file size
441
+ frames = frames[::4]
442
+ buffer = io.BytesIO()
443
+ # compress each image frame to reduce the file size
444
+ frames = [frame[::2, ::2] for frame in frames]
445
+ # Convert the GIF to a subrectangle format to reduce the file size
446
+ imageio.mimsave(buffer, frames, format="GIF", fps=10, subrectangles=True)
447
+ return base64.b64encode(buffer.getvalue()).decode("utf-8")
448
+
449
+ def create_gif_html(folder_path, html_file, skip_every=10):
450
+ """
451
+ Creates an HTML file with a grid of sample visualizations.
452
+
453
+ Args:
454
+ folder_path (str): The path to the folder containing the sample visualizations.
455
+ html_file (str): The name of the HTML file to create.
456
+ """
457
+ # convert path to absolute path
458
+ folder_path = os.path.abspath(folder_path)
459
+
460
+ # Get a list of all the sample IDs
461
+ ids = []
462
+ count = 0
463
+ all_files = sorted(os.listdir(folder_path), key=lambda x: int(x.split("_")[0]))
464
+ for filename in all_files:
465
+ if filename.endswith("_image_input.png"):
466
+ if count % skip_every == 0:
467
+ ids.append(filename.split("_")[0])
468
+ count += 1
469
+
470
+ # Write the HTML file
471
+ with open(html_file, "w") as f:
472
+ # Write the HTML header and CSS style
473
+ f.write("<html>\n")
474
+ f.write("<head>\n")
475
+ f.write("<style>\n")
476
+ f.write(".sample-container {\n")
477
+ f.write(" display: inline-block;\n")
478
+ f.write(" margin: 10px;\n")
479
+ f.write(" width: 350px;\n")
480
+ f.write(" height: 150px;\n")
481
+ f.write(" text-align: center;\n")
482
+ f.write("}\n")
483
+ f.write(".sample-container:nth-child(6n+1) {\n")
484
+ f.write(" clear: left;\n")
485
+ f.write("}\n")
486
+ f.write(".image-container, .gif-container {\n")
487
+ f.write(" display: inline-block;\n")
488
+ f.write(" margin: 10px;\n")
489
+ f.write(" width: 90px;\n")
490
+ f.write(" height: 90px;\n")
491
+ f.write(" object-fit: cover;\n")
492
+ f.write("}\n")
493
+ f.write("</style>\n")
494
+ f.write("</head>\n")
495
+ f.write("<body>\n")
496
+
497
+ # Write the sample visualizations to the HTML file
498
+ for sample_id in ids:
499
+ try:
500
+ f.write("<div class=\"sample-container\">\n")
501
+ f.write(f"<div class=\"sample-id\"><p>{sample_id}</p></div>\n")
502
+ f.write(f"<div class=\"image-container\"><img src=\"data:image/png;base64,{get_base64_encoded_image(os.path.join(folder_path, sample_id + '_image_input.png'))}\" width=\"90\" height=\"90\"></div>\n")
503
+ f.write(f"<div class=\"image-container\"><img src=\"data:image/png;base64,{get_base64_encoded_image(os.path.join(folder_path, sample_id + '_depth_est.png'))}\" width=\"90\" height=\"90\"></div>\n")
504
+ f.write(f"<div class=\"gif-container\"><img src=\"data:image/gif;base64,{get_base64_encoded_gif(os.path.join(folder_path, sample_id + '_mesh_viz.gif'))}\" width=\"90\" height=\"90\"></div>\n")
505
+ f.write("</div>\n")
506
+ except:
507
+ pass
508
+
509
+ # Write the HTML footer
510
+ f.write("</body>\n")
511
+ f.write("</html>\n")
weights/.gitignore ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # Ignore everything in this directory
2
+ *
3
+ # Except this file
4
+ !.gitignore