hugoycj commited on
Commit
cacb27a
·
1 Parent(s): c0f6cb5

Initial commit

Browse files
.gitignore ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ *venv
2
+ flagged
3
+ *examples
4
+ __pycache__
app.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+
4
+ import numpy as np
5
+ import cv2
6
+ from tqdm import tqdm
7
+
8
+ import torch
9
+ from pytorch3d.io.obj_io import load_obj
10
+ import tempfile
11
+ import main_mcc
12
+ import mcc_model
13
+ import util.misc as misc
14
+ from engine_mcc import prepare_data
15
+ from plyfile import PlyData, PlyElement
16
+
17
+ def run_inference(model, samples, device, temperature, args):
18
+ model.eval()
19
+
20
+ seen_xyz, valid_seen_xyz, unseen_xyz, unseen_rgb, labels, seen_images = prepare_data(
21
+ samples, device, is_train=False, args=args, is_viz=True
22
+ )
23
+ pred_occupy = []
24
+ pred_colors = []
25
+
26
+ max_n_unseen_fwd = 2000
27
+
28
+ model.cached_enc_feat = None
29
+ num_passes = int(np.ceil(unseen_xyz.shape[1] / max_n_unseen_fwd))
30
+ for p_idx in range(num_passes):
31
+ p_start = p_idx * max_n_unseen_fwd
32
+ p_end = (p_idx + 1) * max_n_unseen_fwd
33
+ cur_unseen_xyz = unseen_xyz[:, p_start:p_end]
34
+ cur_unseen_rgb = unseen_rgb[:, p_start:p_end].zero_()
35
+ cur_labels = labels[:, p_start:p_end].zero_()
36
+
37
+ with torch.no_grad():
38
+ _, pred = model(
39
+ seen_images=seen_images,
40
+ seen_xyz=seen_xyz,
41
+ unseen_xyz=cur_unseen_xyz,
42
+ unseen_rgb=cur_unseen_rgb,
43
+ unseen_occupy=cur_labels,
44
+ cache_enc=True,
45
+ valid_seen_xyz=valid_seen_xyz,
46
+ )
47
+ if device == "cuda":
48
+ pred_occupy.append(pred[..., 0].cuda())
49
+ else:
50
+ pred_occupy.append(pred[..., 0].cpu())
51
+ if args.regress_color:
52
+ pred_colors.append(pred[..., 1:].reshape((-1, 3)))
53
+ else:
54
+ pred_colors.append(
55
+ (
56
+ torch.nn.Softmax(dim=2)(
57
+ pred[..., 1:].reshape((-1, 3, 256)) / temperature
58
+ ) * torch.linspace(0, 1, 256, device=pred.device)
59
+ ).sum(axis=2)
60
+ )
61
+
62
+ pred_occupy = torch.cat(pred_occupy, dim=1)
63
+ pred_occupy = torch.nn.Sigmoid()(pred_occupy)
64
+ return torch.cat(pred_colors, dim=0).cpu().numpy(), pred_occupy.cpu().numpy(), unseen_xyz.cpu().numpy()
65
+
66
+ def pad_image(im, value):
67
+ if im.shape[0] > im.shape[1]:
68
+ diff = im.shape[0] - im.shape[1]
69
+ return torch.cat([im, (torch.zeros((im.shape[0], diff, im.shape[2])) + value)], dim=1)
70
+ else:
71
+ diff = im.shape[1] - im.shape[0]
72
+ return torch.cat([im, (torch.zeros((diff, im.shape[1], im.shape[2])) + value)], dim=0)
73
+
74
+
75
+ def normalize(seen_xyz):
76
+ seen_xyz = seen_xyz / (seen_xyz[torch.isfinite(seen_xyz.sum(dim=-1))].var(dim=0) ** 0.5).mean()
77
+ seen_xyz = seen_xyz - seen_xyz[torch.isfinite(seen_xyz.sum(dim=-1))].mean(axis=0)
78
+ return seen_xyz
79
+
80
+ def infer(
81
+ image,
82
+ point_cloud,
83
+ seg,
84
+ granularity,
85
+ temperature,
86
+ ):
87
+
88
+ score_thresholds = [0.1, 0.2, 0.3, 0.4, 0.5]
89
+
90
+ device = "cuda" if torch.cuda.is_available() else "cpu"
91
+
92
+ parser = main_mcc.get_args_parser()
93
+ parser.set_defaults(eval=True)
94
+
95
+ args = parser.parse_args()
96
+
97
+ model = mcc_model.get_mcc_model(
98
+ occupancy_weight=1.0,
99
+ rgb_weight=0.01,
100
+ args=args,
101
+ )
102
+
103
+ if device == "cuda":
104
+ model = model.cuda()
105
+
106
+ misc.load_model(args=args, model_without_ddp=model, optimizer=None, loss_scaler=None)
107
+
108
+ rgb = image
109
+ obj = load_obj(point_cloud.name)
110
+
111
+ seen_rgb = (torch.tensor(rgb).float() / 255)[..., [2, 1, 0]]
112
+ H, W = seen_rgb.shape[:2]
113
+ seen_rgb = torch.nn.functional.interpolate(
114
+ seen_rgb.permute(2, 0, 1)[None],
115
+ size=[H, W],
116
+ mode="bilinear",
117
+ align_corners=False,
118
+ )[0].permute(1, 2, 0)
119
+
120
+ seen_xyz = obj[0].reshape(H, W, 3)
121
+ seg = cv2.imread(seg.name, cv2.IMREAD_UNCHANGED)
122
+ mask = torch.tensor(cv2.resize(seg, (W, H))).bool()
123
+ seen_xyz[~mask] = float('inf')
124
+
125
+ seen_xyz = normalize(seen_xyz)
126
+
127
+ bottom, right = mask.nonzero().max(dim=0)[0]
128
+ top, left = mask.nonzero().min(dim=0)[0]
129
+
130
+ bottom = bottom + 40
131
+ right = right + 40
132
+ top = max(top - 40, 0)
133
+ left = max(left - 40, 0)
134
+
135
+ seen_xyz = seen_xyz[top:bottom+1, left:right+1]
136
+ seen_rgb = seen_rgb[top:bottom+1, left:right+1]
137
+
138
+ seen_xyz = pad_image(seen_xyz, float('inf'))
139
+ seen_rgb = pad_image(seen_rgb, 0)
140
+
141
+ seen_rgb = torch.nn.functional.interpolate(
142
+ seen_rgb.permute(2, 0, 1)[None],
143
+ size=[800, 800],
144
+ mode="bilinear",
145
+ align_corners=False,
146
+ )
147
+
148
+ seen_xyz = torch.nn.functional.interpolate(
149
+ seen_xyz.permute(2, 0, 1)[None],
150
+ size=[112, 112],
151
+ mode="bilinear",
152
+ align_corners=False,
153
+ ).permute(0, 2, 3, 1)
154
+
155
+ samples = [
156
+ [seen_xyz, seen_rgb],
157
+ [torch.zeros((20000, 3)), torch.zeros((20000, 3))],
158
+ ]
159
+
160
+ pred_colors, pred_occupy, unseen_xyz = run_inference(model, samples, device, temperature, args)
161
+ _masks = pred_occupy > 0.1
162
+ unseen_xyz = unseen_xyz[_masks]
163
+ pred_colors = pred_colors[None, ...][_masks] * 255
164
+
165
+ # Prepare data for PlyElement
166
+ vertex = np.core.records.fromarrays(np.hstack((unseen_xyz, pred_colors)).transpose(),
167
+ names='x, y, z, red, green, blue',
168
+ formats='f8, f8, f8, u1, u1, u1')
169
+
170
+
171
+ # Create PlyElement
172
+ element = PlyElement.describe(vertex, 'vertex')
173
+
174
+ # Save point cloud data to a temporary file
175
+ with tempfile.NamedTemporaryFile(suffix=".ply", delete=False) as f:
176
+ PlyData([element], text=True).write(f)
177
+ temp_file_name = f.name
178
+
179
+ return temp_file_name
180
+
181
+
182
+ demo = gr.Interface(fn=infer,
183
+ inputs=[gr.Image(label="Input Image"),
184
+ gr.File(label="Pointcloud File"),
185
+ gr.File(label="Segmentation File"),
186
+ gr.Slider(minimum=0.05, maximum=0.5, step=0.05, value=0.2, label="Granularity"),
187
+ gr.Slider(minimum=0, maximum=1.0, step=0.1, value=0.1, label="Temperature")
188
+ ],
189
+ outputs=[gr.outputs.File(label="Point Cloud Json")],
190
+ examples=[["demo/quest2.jpg", "demo/quest2.obj", "demo/quest2_seg.png", 0.2, 0.1]],
191
+ cache_examples=True)
192
+ demo.launch(server_name="0.0.0.0", server_port=7860)
engine_mcc.py ADDED
@@ -0,0 +1,587 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # References:
8
+ # DeiT: https://github.com/facebookresearch/deit
9
+ # BEiT: https://github.com/microsoft/unilm/tree/master/beit
10
+ # MAE: https://github.com/facebookresearch/mae
11
+ # --------------------------------------------------------
12
+ import math
13
+ from typing import Iterable
14
+ import os
15
+ import matplotlib.pyplot as plt
16
+ import random
17
+ import torch
18
+ import numpy as np
19
+ import time
20
+ import base64
21
+ from io import BytesIO
22
+
23
+ import util.misc as misc
24
+ import util.lr_sched as lr_sched
25
+
26
+ from pytorch3d.structures import Pointclouds
27
+ from pytorch3d.vis.plotly_vis import plot_scene
28
+ from pytorch3d.transforms import RotateAxisAngle
29
+ from pytorch3d.io import IO
30
+
31
+
32
+ def evaluate_points(predicted_xyz, gt_xyz, dist_thres):
33
+ if predicted_xyz.shape[0] == 0:
34
+ return 0.0, 0.0, 0.0
35
+ slice_size = 1000
36
+ precision = 0.0
37
+ for i in range(int(np.ceil(predicted_xyz.shape[0] / slice_size))):
38
+ start = slice_size * i
39
+ end = slice_size * (i + 1)
40
+ dist = ((predicted_xyz[start:end, None] - gt_xyz[None]) ** 2.0).sum(axis=-1) ** 0.5
41
+ precision += ((dist < dist_thres).sum(axis=1) > 0).sum()
42
+ precision /= predicted_xyz.shape[0]
43
+
44
+ recall = 0.0
45
+ for i in range(int(np.ceil(predicted_xyz.shape[0] / slice_size))):
46
+ start = slice_size * i
47
+ end = slice_size * (i + 1)
48
+ dist = ((predicted_xyz[:, None] - gt_xyz[None, start:end]) ** 2.0).sum(axis=-1) ** 0.5
49
+ recall += ((dist < dist_thres).sum(axis=0) > 0).sum()
50
+ recall /= gt_xyz.shape[0]
51
+ return precision, recall, get_f1(precision, recall)
52
+
53
+ def aug_xyz(seen_xyz, unseen_xyz, args, is_train):
54
+ degree_x = 0
55
+ degree_y = 0
56
+ degree_z = 0
57
+ if is_train:
58
+ r_delta = args.random_scale_delta
59
+ scale = torch.tensor([
60
+ random.uniform(1.0 - r_delta, 1.0 + r_delta),
61
+ random.uniform(1.0 - r_delta, 1.0 + r_delta),
62
+ random.uniform(1.0 - r_delta, 1.0 + r_delta),
63
+ ], device=seen_xyz.device)
64
+
65
+ if args.use_hypersim:
66
+ shift = 0
67
+ else:
68
+ degree_x = random.randrange(-args.random_rotate_degree, args.random_rotate_degree + 1)
69
+ degree_y = random.randrange(-args.random_rotate_degree, args.random_rotate_degree + 1)
70
+ degree_z = random.randrange(-args.random_rotate_degree, args.random_rotate_degree + 1)
71
+
72
+ r_shift = args.random_shift
73
+ shift = torch.tensor([[[
74
+ random.uniform(-r_shift, r_shift),
75
+ random.uniform(-r_shift, r_shift),
76
+ random.uniform(-r_shift, r_shift),
77
+ ]]], device=seen_xyz.device)
78
+ seen_xyz = seen_xyz * scale + shift
79
+ unseen_xyz = unseen_xyz * scale + shift
80
+
81
+ B, H, W, _ = seen_xyz.shape
82
+ return [
83
+ rotate(seen_xyz.reshape((B, -1, 3)), degree_x, degree_y, degree_z).reshape((B, H, W, 3)),
84
+ rotate(unseen_xyz, degree_x, degree_y, degree_z),
85
+ ]
86
+
87
+
88
+ def rotate(sample, degree_x, degree_y, degree_z):
89
+ for degree, axis in [(degree_x, "X"), (degree_y, "Y"), (degree_z, "Z")]:
90
+ if degree != 0:
91
+ sample = RotateAxisAngle(degree, axis=axis).to(sample.device).transform_points(sample)
92
+ return sample
93
+
94
+
95
+ def get_grid(B, device, co3d_world_size, granularity):
96
+ N = int(np.ceil(2 * co3d_world_size / granularity))
97
+ grid_unseen_xyz = torch.zeros((N, N, N, 3), device=device)
98
+ for i in range(N):
99
+ grid_unseen_xyz[i, :, :, 0] = i
100
+ for j in range(N):
101
+ grid_unseen_xyz[:, j, :, 1] = j
102
+ for k in range(N):
103
+ grid_unseen_xyz[:, :, k, 2] = k
104
+ grid_unseen_xyz -= (N / 2.0)
105
+ grid_unseen_xyz /= (N / 2.0) / co3d_world_size
106
+ grid_unseen_xyz = grid_unseen_xyz.reshape((1, -1, 3)).repeat(B, 1, 1)
107
+ return grid_unseen_xyz
108
+
109
+
110
+ def run_viz(model, data_loader, device, args, epoch):
111
+ epoch_start_time = time.time()
112
+ model.eval()
113
+ os.system(f'mkdir {args.job_dir}/viz')
114
+
115
+ print('Visualization data_loader length:', len(data_loader))
116
+ dataset = data_loader.dataset
117
+ for sample_idx, samples in enumerate(data_loader):
118
+ if sample_idx >= args.max_n_viz_obj:
119
+ break
120
+ seen_xyz, valid_seen_xyz, unseen_xyz, unseen_rgb, labels, seen_images = prepare_data(samples, device, is_train=False, args=args, is_viz=True)
121
+
122
+ pred_occupy = []
123
+ pred_colors = []
124
+ (model.module if hasattr(model, "module") else model).clear_cache()
125
+
126
+ # don't forward all at once to avoid oom
127
+ max_n_queries_fwd = 2000
128
+
129
+ total_n_passes = int(np.ceil(unseen_xyz.shape[1] / max_n_queries_fwd))
130
+ for p_idx in range(total_n_passes):
131
+ p_start = p_idx * max_n_queries_fwd
132
+ p_end = (p_idx + 1) * max_n_queries_fwd
133
+ cur_unseen_xyz = unseen_xyz[:, p_start:p_end]
134
+ cur_unseen_rgb = unseen_rgb[:, p_start:p_end].zero_()
135
+ cur_labels = labels[:, p_start:p_end].zero_()
136
+
137
+ with torch.no_grad():
138
+ _, pred, = model(
139
+ seen_images=seen_images,
140
+ seen_xyz=seen_xyz,
141
+ unseen_xyz=cur_unseen_xyz,
142
+ unseen_rgb=cur_unseen_rgb,
143
+ unseen_occupy=cur_labels,
144
+ cache_enc=args.run_viz,
145
+ valid_seen_xyz=valid_seen_xyz,
146
+ )
147
+
148
+ cur_occupy_out = pred[..., 0]
149
+
150
+ if args.regress_color:
151
+ cur_color_out = pred[..., 1:].reshape((-1, 3))
152
+ else:
153
+ cur_color_out = pred[..., 1:].reshape((-1, 3, 256)).max(dim=2)[1] / 255.0
154
+ pred_occupy.append(cur_occupy_out)
155
+ pred_colors.append(cur_color_out)
156
+
157
+ rank = misc.get_rank()
158
+ prefix = f'{args.job_dir}/viz/' + dataset.dataset_split + f'_ep{epoch}_rank{rank}_i{sample_idx}'
159
+
160
+ img = (seen_images[0].permute(1, 2, 0) * 255).cpu().numpy().copy().astype(np.uint8)
161
+
162
+ gt_xyz = samples[1][0].to(device).reshape(-1, 3)
163
+ gt_rgb = samples[1][1].to(device).reshape(-1, 3)
164
+ mesh_xyz = samples[2].to(device).reshape(-1, 3) if args.use_hypersim else None
165
+
166
+ with open(prefix + '.html', 'a') as f:
167
+ generate_html(
168
+ img,
169
+ seen_xyz, seen_images,
170
+ torch.cat(pred_occupy, dim=1),
171
+ torch.cat(pred_colors, dim=0),
172
+ unseen_xyz,
173
+ f,
174
+ gt_xyz=gt_xyz,
175
+ gt_rgb=gt_rgb,
176
+ mesh_xyz=mesh_xyz,
177
+ )
178
+ print("Visualization epoch time:", time.time() - epoch_start_time)
179
+
180
+
181
+ def get_f1(precision, recall):
182
+ if (precision + recall) == 0:
183
+ return 0.0
184
+ return 2.0 * precision * recall / (precision + recall)
185
+
186
+
187
+ def generate_plot(img, seen_xyz, seen_rgb, pred_occ, pred_rgb, unseen_xyz,
188
+ gt_xyz=None, gt_rgb=None, mesh_xyz=None, score_thresholds=[0.1, 0.3, 0.5, 0.7, 0.9],
189
+ pointcloud_marker_size=2,
190
+ ):
191
+ # if img is not None:
192
+ # fig = plt.figure()
193
+ # plt.imshow(img)
194
+ # tmpfile = BytesIO()
195
+ # fig.savefig(tmpfile, format='jpg')
196
+ # encoded = base64.b64encode(tmpfile.getvalue()).decode('utf-8')
197
+
198
+ # html = '<img src=\'data:image/png;base64,{}\'>'.format(encoded)
199
+ # f.write(html)
200
+ # plt.close()
201
+
202
+ clouds = {"MCC Output": {}}
203
+ # Seen
204
+ if seen_xyz is not None:
205
+ seen_xyz = seen_xyz.reshape((-1, 3)).cpu()
206
+ seen_rgb = torch.nn.functional.interpolate(seen_rgb, (112, 112)).permute(0, 2, 3, 1).reshape((-1, 3)).cpu()
207
+ good_seen = seen_xyz[:, 0] != -100
208
+
209
+ seen_pc = Pointclouds(
210
+ points=seen_xyz[good_seen][None],
211
+ features=seen_rgb[good_seen][None],
212
+ )
213
+ clouds["MCC Output"]["seen"] = seen_pc
214
+
215
+ # GT points
216
+ if gt_xyz is not None:
217
+ subset_gt = random.sample(range(gt_xyz.shape[0]), 10000)
218
+ gt_pc = Pointclouds(
219
+ points=gt_xyz[subset_gt][None],
220
+ features=gt_rgb[subset_gt][None],
221
+ )
222
+ clouds["MCC Output"]["GT points"] = gt_pc
223
+
224
+ # GT meshes
225
+ if mesh_xyz is not None:
226
+ subset_mesh = random.sample(range(mesh_xyz.shape[0]), 10000)
227
+ mesh_pc = Pointclouds(
228
+ points=mesh_xyz[subset_mesh][None],
229
+ )
230
+ clouds["MCC Output"]["GT mesh"] = mesh_pc
231
+
232
+ pred_occ = torch.nn.Sigmoid()(pred_occ).cpu()
233
+ for t in score_thresholds:
234
+ pos = pred_occ > t
235
+
236
+ points = unseen_xyz[pos].reshape((-1, 3))
237
+ features = pred_rgb[None][pos].reshape((-1, 3))
238
+ good_points = points[:, 0] != -100
239
+
240
+ if good_points.sum() == 0:
241
+ continue
242
+
243
+ pc = Pointclouds(
244
+ points=points[good_points][None].cpu(),
245
+ features=features[good_points][None].cpu(),
246
+ )
247
+
248
+ clouds["MCC Output"][f"pred_{t}"] = pc
249
+ IO().save_pointcloud(pc, "output_pointcloud.ply")
250
+
251
+ plt.figure()
252
+ try:
253
+ fig = plot_scene(clouds, pointcloud_marker_size=pointcloud_marker_size, pointcloud_max_points=20000 * 2)
254
+ fig.update_layout(height=1000, width=1000)
255
+ return fig
256
+ except Exception as e:
257
+ print('writing failed', e)
258
+ try:
259
+ plt.close()
260
+ except:
261
+ pass
262
+
263
+
264
+ def generate_html(img, seen_xyz, seen_rgb, pred_occ, pred_rgb, unseen_xyz, f,
265
+ gt_xyz=None, gt_rgb=None, mesh_xyz=None, score_thresholds=[0.1, 0.3, 0.5, 0.7, 0.9],
266
+ pointcloud_marker_size=2,
267
+ ):
268
+ if img is not None:
269
+ fig = plt.figure()
270
+ plt.imshow(img)
271
+ tmpfile = BytesIO()
272
+ fig.savefig(tmpfile, format='jpg')
273
+ encoded = base64.b64encode(tmpfile.getvalue()).decode('utf-8')
274
+
275
+ html = '<img src=\'data:image/png;base64,{}\'>'.format(encoded)
276
+ f.write(html)
277
+ plt.close()
278
+
279
+ clouds = {"MCC Output": {}}
280
+ # Seen
281
+ if seen_xyz is not None:
282
+ seen_xyz = seen_xyz.reshape((-1, 3)).cpu()
283
+ seen_rgb = torch.nn.functional.interpolate(seen_rgb, (112, 112)).permute(0, 2, 3, 1).reshape((-1, 3)).cpu()
284
+ good_seen = seen_xyz[:, 0] != -100
285
+
286
+ seen_pc = Pointclouds(
287
+ points=seen_xyz[good_seen][None],
288
+ features=seen_rgb[good_seen][None],
289
+ )
290
+ clouds["MCC Output"]["seen"] = seen_pc
291
+
292
+ # GT points
293
+ if gt_xyz is not None:
294
+ subset_gt = random.sample(range(gt_xyz.shape[0]), 10000)
295
+ gt_pc = Pointclouds(
296
+ points=gt_xyz[subset_gt][None],
297
+ features=gt_rgb[subset_gt][None],
298
+ )
299
+ clouds["MCC Output"]["GT points"] = gt_pc
300
+
301
+ # GT meshes
302
+ if mesh_xyz is not None:
303
+ subset_mesh = random.sample(range(mesh_xyz.shape[0]), 10000)
304
+ mesh_pc = Pointclouds(
305
+ points=mesh_xyz[subset_mesh][None],
306
+ )
307
+ clouds["MCC Output"]["GT mesh"] = mesh_pc
308
+
309
+ pred_occ = torch.nn.Sigmoid()(pred_occ).cpu()
310
+ for t in score_thresholds:
311
+ pos = pred_occ > t
312
+
313
+ points = unseen_xyz[pos].reshape((-1, 3))
314
+ features = pred_rgb[None][pos].reshape((-1, 3))
315
+ good_points = points[:, 0] != -100
316
+
317
+ if good_points.sum() == 0:
318
+ continue
319
+
320
+ pc = Pointclouds(
321
+ points=points[good_points][None].cpu(),
322
+ features=features[good_points][None].cpu(),
323
+ )
324
+
325
+ clouds["MCC Output"][f"pred_{t}"] = pc
326
+
327
+ plt.figure()
328
+ try:
329
+ fig = plot_scene(clouds, pointcloud_marker_size=pointcloud_marker_size, pointcloud_max_points=20000 * 2)
330
+ fig.update_layout(height=1000, width=1000)
331
+ html_string = fig.to_html(full_html=False, include_plotlyjs="cnd")
332
+ f.write(html_string)
333
+ return fig, plt
334
+ except Exception as e:
335
+ print('writing failed', e)
336
+ try:
337
+ plt.close()
338
+ except:
339
+ pass
340
+
341
+
342
+ def train_one_epoch(model: torch.nn.Module,
343
+ data_loader: Iterable, optimizer: torch.optim.Optimizer,
344
+ device: torch.device, epoch: int, loss_scaler,
345
+ args=None):
346
+ epoch_start_time = time.time()
347
+ model.train(True)
348
+ metric_logger = misc.MetricLogger(delimiter=" ")
349
+ metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
350
+
351
+ accum_iter = args.accum_iter
352
+
353
+ optimizer.zero_grad()
354
+
355
+ print('Training data_loader length:', len(data_loader))
356
+ for data_iter_step, samples in enumerate(data_loader):
357
+ # we use a per iteration (instead of per epoch) lr scheduler
358
+ if data_iter_step % accum_iter == 0:
359
+ lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)
360
+ seen_xyz, valid_seen_xyz, unseen_xyz, unseen_rgb, labels, seen_images = prepare_data(samples, device, is_train=True, args=args)
361
+
362
+ with torch.cuda.amp.autocast():
363
+ loss, _ = model(
364
+ seen_images=seen_images,
365
+ seen_xyz=seen_xyz,
366
+ unseen_xyz=unseen_xyz,
367
+ unseen_rgb=unseen_rgb,
368
+ unseen_occupy=labels,
369
+ valid_seen_xyz=valid_seen_xyz,
370
+ )
371
+
372
+ loss_value = loss.item()
373
+ if not math.isfinite(loss_value):
374
+ print("Warning: Loss is {}".format(loss_value))
375
+ loss *= 0.0
376
+ loss_value = 100.0
377
+
378
+ loss /= accum_iter
379
+ loss_scaler(loss, optimizer, parameters=model.parameters(),
380
+ clip_grad=args.clip_grad,
381
+ update_grad=(data_iter_step + 1) % accum_iter == 0,
382
+ verbose=(data_iter_step % 100) == 0)
383
+
384
+ if (data_iter_step + 1) % accum_iter == 0:
385
+ optimizer.zero_grad()
386
+
387
+ torch.cuda.synchronize()
388
+
389
+ metric_logger.update(loss=loss_value)
390
+
391
+ lr = optimizer.param_groups[0]["lr"]
392
+ metric_logger.update(lr=lr)
393
+
394
+ if data_iter_step == 30:
395
+ os.system('nvidia-smi')
396
+ os.system('free -g')
397
+ if args.debug and data_iter_step == 5:
398
+ break
399
+
400
+ # gather the stats from all processes
401
+ metric_logger.synchronize_between_processes()
402
+ print("Averaged stats:", metric_logger)
403
+ print("Training epoch time:", time.time() - epoch_start_time)
404
+ return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
405
+
406
+
407
+ def eval_one_epoch(
408
+ model: torch.nn.Module,
409
+ data_loader: Iterable,
410
+ device: torch.device,
411
+ args=None
412
+ ):
413
+ epoch_start_time = time.time()
414
+ model.train(False)
415
+
416
+ metric_logger = misc.MetricLogger(delimiter=" ")
417
+
418
+ print('Eval len(data_loader):', len(data_loader))
419
+
420
+ for data_iter_step, samples in enumerate(data_loader):
421
+ seen_xyz, valid_seen_xyz, unseen_xyz, unseen_rgb, labels, seen_images = prepare_data(samples, device, is_train=False, args=args)
422
+
423
+ # don't forward all at once to avoid oom
424
+ max_n_queries_fwd = 5000
425
+ all_loss, all_preds = [], []
426
+ for p_idx in range(int(np.ceil(unseen_xyz.shape[1] / max_n_queries_fwd))):
427
+ p_start = p_idx * max_n_queries_fwd
428
+ p_end = (p_idx + 1) * max_n_queries_fwd
429
+ cur_unseen_xyz = unseen_xyz[:, p_start:p_end]
430
+ cur_unseen_rgb = unseen_rgb[:, p_start:p_end]
431
+ cur_labels = labels[:, p_start:p_end]
432
+
433
+ with torch.no_grad():
434
+ loss, pred = model(
435
+ seen_images=seen_images,
436
+ seen_xyz=seen_xyz,
437
+ unseen_xyz=cur_unseen_xyz,
438
+ unseen_rgb=cur_unseen_rgb,
439
+ unseen_occupy=cur_labels,
440
+ valid_seen_xyz=valid_seen_xyz,
441
+ )
442
+ all_loss.append(loss)
443
+ all_preds.append(pred)
444
+
445
+ loss = sum(all_loss) / len(all_loss)
446
+ pred = torch.cat(all_preds, dim=1)
447
+
448
+ B = pred.shape[0]
449
+
450
+ gt_xyz = samples[1][0].to(device).reshape((B, -1, 3))
451
+ if args.use_hypersim:
452
+ mesh_xyz = samples[2].to(device).reshape((B, -1, 3))
453
+
454
+ s_thres = args.eval_score_threshold
455
+ d_thres = args.eval_dist_threshold
456
+
457
+ for b_idx in range(B):
458
+ geometry_metrics = {}
459
+ predicted_idx = torch.nn.Sigmoid()(pred[b_idx, :, 0]) > s_thres
460
+ predicted_xyz = unseen_xyz[b_idx, predicted_idx]
461
+
462
+ precision, recall, f1 = evaluate_points(predicted_xyz, gt_xyz[b_idx], d_thres)
463
+ geometry_metrics[f'd{d_thres}_s{s_thres}_point_pr'] = precision
464
+ geometry_metrics[f'd{d_thres}_s{s_thres}_point_rc'] = recall
465
+ geometry_metrics[f'd{d_thres}_s{s_thres}_point_f1'] = f1
466
+
467
+ if args.use_hypersim:
468
+ precision, recall, f1 = evaluate_points(predicted_xyz, mesh_xyz[b_idx], d_thres)
469
+ geometry_metrics[f'd{d_thres}_s{s_thres}_mesh_pr'] = precision
470
+ geometry_metrics[f'd{d_thres}_s{s_thres}_mesh_rc'] = recall
471
+ geometry_metrics[f'd{d_thres}_s{s_thres}_mesh_f1'] = f1
472
+
473
+ metric_logger.update(**geometry_metrics)
474
+
475
+ loss_value = loss.item()
476
+
477
+ torch.cuda.synchronize()
478
+ metric_logger.update(loss=loss_value)
479
+
480
+ if args.debug and data_iter_step == 5:
481
+ break
482
+
483
+ metric_logger.synchronize_between_processes()
484
+ print("Validation averaged stats:", metric_logger)
485
+ print("Val epoch time:", time.time() - epoch_start_time)
486
+ return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
487
+
488
+
489
+ def sample_uniform_semisphere(B, N, semisphere_size, device):
490
+ for _ in range(100):
491
+ points = torch.empty(B * N * 3, 3, device=device).uniform_(-semisphere_size, semisphere_size)
492
+ points[..., 2] = points[..., 2].abs()
493
+ dist = (points ** 2.0).sum(axis=-1) ** 0.5
494
+ if (dist < semisphere_size).sum() >= B * N:
495
+ return points[dist < semisphere_size][:B * N].reshape((B, N, 3))
496
+ else:
497
+ print('resampling sphere')
498
+
499
+
500
+ def get_grid_semisphere(B, granularity, semisphere_size, device):
501
+ n_grid_pts = int(semisphere_size / granularity) * 2 + 1
502
+ grid_unseen_xyz = torch.zeros((n_grid_pts, n_grid_pts, n_grid_pts // 2 + 1, 3), device=device)
503
+ for i in range(n_grid_pts):
504
+ grid_unseen_xyz[i, :, :, 0] = i
505
+ grid_unseen_xyz[:, i, :, 1] = i
506
+ for i in range(n_grid_pts // 2 + 1):
507
+ grid_unseen_xyz[:, :, i, 2] = i
508
+ grid_unseen_xyz[..., :2] -= (n_grid_pts // 2.0)
509
+ grid_unseen_xyz *= granularity
510
+ dist = (grid_unseen_xyz ** 2.0).sum(axis=-1) ** 0.5
511
+ grid_unseen_xyz = grid_unseen_xyz[dist <= semisphere_size]
512
+ return grid_unseen_xyz[None].repeat(B, 1, 1)
513
+
514
+
515
+ def get_min_dist(a, b, slice_size=1000):
516
+ all_min, all_idx = [], []
517
+ for i in range(int(np.ceil(a.shape[1] / slice_size))):
518
+ start = slice_size * i
519
+ end = slice_size * (i + 1)
520
+ # B, n_queries, n_gt
521
+ dist = ((a[:, start:end] - b) ** 2.0).sum(axis=-1) ** 0.5
522
+ # B, n_queries
523
+ cur_min, cur_idx = dist.min(axis=2)
524
+ all_min.append(cur_min)
525
+ all_idx.append(cur_idx)
526
+ return torch.cat(all_min, dim=1), torch.cat(all_idx, dim=1)
527
+
528
+
529
+ def construct_uniform_semisphere(gt_xyz, gt_rgb, semisphere_size, n_queries, dist_threshold, is_train, granularity):
530
+ B = gt_xyz.shape[0]
531
+ device = gt_xyz.device
532
+ if is_train:
533
+ unseen_xyz = sample_uniform_semisphere(B, n_queries, semisphere_size, device)
534
+ else:
535
+ unseen_xyz = get_grid_semisphere(B, granularity, semisphere_size, device)
536
+ dist, idx_to_gt = get_min_dist(unseen_xyz[:, :, None], gt_xyz[:, None])
537
+ labels = dist < dist_threshold
538
+ unseen_rgb = torch.zeros_like(unseen_xyz)
539
+ unseen_rgb[labels] = torch.gather(gt_rgb, 1, idx_to_gt.unsqueeze(-1).repeat(1, 1, 3))[labels]
540
+ return unseen_xyz, unseen_rgb, labels.float()
541
+
542
+
543
+ def construct_uniform_grid(gt_xyz, gt_rgb, co3d_world_size, n_queries, dist_threshold, is_train, granularity):
544
+ B = gt_xyz.shape[0]
545
+ device = gt_xyz.device
546
+ if is_train:
547
+ unseen_xyz = torch.empty((B, n_queries, 3), device=device).uniform_(-co3d_world_size, co3d_world_size)
548
+ else:
549
+ unseen_xyz = get_grid(B, device, co3d_world_size, granularity)
550
+ dist, idx_to_gt = get_min_dist(unseen_xyz[:, :, None], gt_xyz[:, None])
551
+ labels = dist < dist_threshold
552
+ unseen_rgb = torch.zeros_like(unseen_xyz)
553
+ unseen_rgb[labels] = torch.gather(gt_rgb, 1, idx_to_gt.unsqueeze(-1).repeat(1, 1, 3))[labels]
554
+ return unseen_xyz, unseen_rgb, labels.float()
555
+
556
+
557
+ def prepare_data(samples, device, is_train, args, is_viz=False):
558
+ # Seen
559
+ seen_xyz, seen_rgb = samples[0][0].to(device), samples[0][1].to(device)
560
+ valid_seen_xyz = torch.isfinite(seen_xyz.sum(axis=-1))
561
+ seen_xyz[~valid_seen_xyz] = -100
562
+ B = seen_xyz.shape[0]
563
+ # Gt
564
+ gt_xyz, gt_rgb = samples[1][0].to(device).reshape(B, -1, 3), samples[1][1].to(device).reshape(B, -1, 3)
565
+
566
+ sampling_func = construct_uniform_semisphere if args.use_hypersim else construct_uniform_grid
567
+ unseen_xyz, unseen_rgb, labels = sampling_func(
568
+ gt_xyz, gt_rgb,
569
+ args.semisphere_size if args.use_hypersim else args.co3d_world_size,
570
+ args.n_queries,
571
+ args.train_dist_threshold,
572
+ is_train,
573
+ args.viz_granularity if is_viz else args.eval_granularity,
574
+ )
575
+
576
+ if is_train:
577
+ seen_xyz, unseen_xyz = aug_xyz(seen_xyz, unseen_xyz, args, is_train=is_train)
578
+
579
+ # Random Flip
580
+ if random.random() < 0.5:
581
+ seen_xyz[..., 0] *= -1
582
+ unseen_xyz[..., 0] *= -1
583
+ seen_xyz = torch.flip(seen_xyz, [2])
584
+ valid_seen_xyz = torch.flip(valid_seen_xyz, [2])
585
+ seen_rgb = torch.flip(seen_rgb, [3])
586
+
587
+ return seen_xyz, valid_seen_xyz, unseen_xyz, unseen_rgb, labels, seen_rgb
main_mcc.py ADDED
@@ -0,0 +1,322 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # References:
8
+ # DeiT: https://github.com/facebookresearch/deit
9
+ # BEiT: https://github.com/microsoft/unilm/tree/master/beit
10
+ # MAE: https://github.com/facebookresearch/mae
11
+ # --------------------------------------------------------
12
+ import argparse
13
+ import datetime
14
+ import json
15
+ import numpy as np
16
+ import os
17
+ import time
18
+ from pathlib import Path
19
+
20
+ import torch
21
+ import torch.backends.cudnn as cudnn
22
+ import timm.optim.optim_factory as optim_factory
23
+
24
+ import util.misc as misc
25
+ import mcc_model
26
+ from util.misc import NativeScalerWithGradNormCount as NativeScaler
27
+ from util.hypersim_dataset import HyperSimDataset, hypersim_collate_fn
28
+ from util.co3d_dataset import CO3DV2Dataset, co3dv2_collate_fn
29
+ from engine_mcc import train_one_epoch, run_viz, eval_one_epoch
30
+ from util.co3d_utils import get_all_dataset_maps
31
+
32
+
33
+ def get_args_parser():
34
+ parser = argparse.ArgumentParser('MCC', add_help=False)
35
+
36
+ # Model
37
+ parser.add_argument('--input_size', default=224, type=int,
38
+ help='Images input size')
39
+ parser.add_argument('--occupancy_weight', default=1.0, type=float,
40
+ help='A constant to weight the occupancy loss')
41
+ parser.add_argument('--rgb_weight', default=0.01, type=float,
42
+ help='A constant to weight the color prediction loss')
43
+ parser.add_argument('--n_queries', default=550, type=int,
44
+ help='Number of queries used in decoder.')
45
+ parser.add_argument('--drop_path', default=0.1, type=float,
46
+ help='drop_path probability')
47
+ parser.add_argument('--regress_color', action='store_true',
48
+ help='If true, regress color with MSE. Otherwise, 256-way classification for each channel.')
49
+
50
+ # Training
51
+ parser.add_argument('--batch_size', default=16, type=int,
52
+ help='Batch size per GPU for training (effective batch size is batch_size * accum_iter * # gpus')
53
+ parser.add_argument('--eval_batch_size', default=2, type=int,
54
+ help='Batch size per GPU for evaluation (effective batch size is batch_size * accum_iter * # gpus')
55
+ parser.add_argument('--epochs', default=100, type=int)
56
+ parser.add_argument('--accum_iter', default=1, type=int,
57
+ help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')
58
+ parser.add_argument('--weight_decay', type=float, default=0.05,
59
+ help='Weight decay (default: 0.05)')
60
+ parser.add_argument('--lr', type=float, default=None, metavar='LR',
61
+ help='Learning rate (absolute lr)')
62
+ parser.add_argument('--blr', type=float, default=1e-4, metavar='LR',
63
+ help='Base learning rate: absolute_lr = base_lr * total_batch_size / 512')
64
+ parser.add_argument('--min_lr', type=float, default=0., metavar='LR',
65
+ help='Lower lr bound for cyclic schedulers that hit 0')
66
+ parser.add_argument('--warmup_epochs', type=int, default=5, metavar='N',
67
+ help='Epochs to warmup LR')
68
+ parser.add_argument('--clip_grad', type=float, default=1.0,
69
+ help='Clip gradient at the specified norm')
70
+
71
+ # Job
72
+ parser.add_argument('--job_dir', default='',
73
+ help='Path to where to save, empty for no saving')
74
+ parser.add_argument('--output_dir', default='./output_dir',
75
+ help='Path to where to save, empty for no saving')
76
+ parser.add_argument('--device', default='cuda',
77
+ help='Device to use for training / testing')
78
+ parser.add_argument('--seed', default=0, type=int,
79
+ help='Random seed.')
80
+ parser.add_argument('--resume', default='weights/co3dv2_all_categories.pth',
81
+ help='Resume from checkpoint')
82
+
83
+ parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
84
+ help='Start epoch')
85
+ parser.add_argument('--num_workers', default=4, type=int,
86
+ help='Number of workers for training data loader')
87
+ parser.add_argument('--num_eval_workers', default=4, type=int,
88
+ help='Number of workers for evaluation data loader')
89
+ parser.add_argument('--pin_mem', action='store_true',
90
+ help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
91
+ parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
92
+ parser.set_defaults(pin_mem=True)
93
+
94
+ # Distributed training
95
+ parser.add_argument('--world_size', default=1, type=int,
96
+ help='Number of distributed processes')
97
+ parser.add_argument('--local_rank', default=-1, type=int)
98
+ parser.add_argument('--dist_on_itp', action='store_true')
99
+ parser.add_argument('--dist_url', default='env://',
100
+ help='Url used to set up distributed training')
101
+
102
+ # Experiments
103
+ parser.add_argument('--debug', action='store_true')
104
+ parser.add_argument('--run_viz', action='store_true',
105
+ help='Specify to run only the visualization/inference given a trained model.')
106
+ parser.add_argument('--max_n_viz_obj', default=64, type=int,
107
+ help='Max number of objects to visualize during training.')
108
+
109
+ # Data
110
+ parser.add_argument('--train_epoch_len_multiplier', default=32, type=int,
111
+ help='# examples per training epoch is # objects * train_epoch_len_multiplier')
112
+ parser.add_argument('--eval_epoch_len_multiplier', default=1, type=int,
113
+ help='# examples per eval epoch is # objects * eval_epoch_len_multiplier')
114
+
115
+ # CO3D
116
+ parser.add_argument('--co3d_path', type=str, default='co3d_data',
117
+ help='Path to CO3D v2 data.')
118
+ parser.add_argument('--holdout_categories', action='store_true',
119
+ help='If true, hold out 10 categories and train on only the remaining 41 categories.')
120
+ parser.add_argument('--co3d_world_size', default=3.0, type=float,
121
+ help='The world space we consider is \in [-co3d_world_size, co3d_world_size] in each dimension.')
122
+
123
+ # Hypersim
124
+ parser.add_argument('--use_hypersim', action='store_true',
125
+ help='If true, use hypersim, else, co3d.')
126
+ parser.add_argument('--hypersim_path', default="hypersim_data", type=str,
127
+ help="Path to Hypersim data.")
128
+
129
+ # Data aug
130
+ parser.add_argument('--random_scale_delta', default=0.2, type=float,
131
+ help='Random scaling each example by a scaler \in [1 - random_scale_delta, 1 + random_scale_delta].')
132
+ parser.add_argument('--random_shift', default=1.0, type=float,
133
+ help='Random shifting an example in each axis by an amount \in [-random_shift, random_shift]')
134
+ parser.add_argument('--random_rotate_degree', default=180, type=int,
135
+ help='Random rotation degrees.')
136
+
137
+ # Smapling, evaluation, and coordinate system
138
+ parser.add_argument('--shrink_threshold', default=10.0, type=float,
139
+ help='Any points with distance beyond this value will be shrunk.')
140
+ parser.add_argument('--semisphere_size', default=6.0, type=float,
141
+ help='The Hypersim task predicts points in a semisphere in front of the camera.'
142
+ 'This value specifies the size of the semisphere.')
143
+ parser.add_argument('--eval_granularity', default=0.1, type=float,
144
+ help='Granularity of the evaluation points.')
145
+ parser.add_argument('--viz_granularity', default=0.1, type=float,
146
+ help='Granularity of points in visaulizatoin.')
147
+
148
+ parser.add_argument('--eval_score_threshold', default=0.1, type=float,
149
+ help='Score threshold for evaluation.')
150
+ parser.add_argument('--eval_dist_threshold', default=0.1, type=float,
151
+ help='Points closer than this amount to a groud-truth is considered correct.')
152
+ parser.add_argument('--train_dist_threshold', default=0.1, type=float,
153
+ help='Points closer than this amount is considered positive in training.')
154
+ return parser
155
+
156
+
157
+ def build_loader(args, num_tasks, global_rank, is_train, dataset_type, collate_fn, dataset_maps):
158
+ '''Build data loader'''
159
+ dataset = dataset_type(args, is_train=is_train, dataset_maps=dataset_maps)
160
+
161
+ sampler_train = torch.utils.data.DistributedSampler(
162
+ dataset, num_replicas=num_tasks, rank=global_rank, shuffle=is_train
163
+ )
164
+
165
+ data_loader = torch.utils.data.DataLoader(
166
+ dataset, batch_size=args.batch_size if is_train else args.eval_batch_size,
167
+ sampler=sampler_train,
168
+ num_workers=args.num_workers if is_train else args.num_eval_workers,
169
+ pin_memory=args.pin_mem,
170
+ collate_fn=collate_fn,
171
+ )
172
+ return data_loader
173
+
174
+
175
+ def main(args):
176
+ misc.init_distributed_mode(args)
177
+
178
+ print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
179
+ print("{}".format(args).replace(', ', ',\n'))
180
+
181
+ device = torch.device(args.device)
182
+
183
+ # fix the seed for reproducibility
184
+ seed = args.seed + misc.get_rank()
185
+ torch.manual_seed(seed)
186
+ np.random.seed(seed)
187
+
188
+ cudnn.benchmark = True
189
+ num_tasks = misc.get_world_size()
190
+ global_rank = misc.get_rank()
191
+
192
+ # define the model
193
+ model = mcc_model.get_mcc_model(
194
+ rgb_weight=args.rgb_weight,
195
+ occupancy_weight=args.occupancy_weight,
196
+ args=args,
197
+ )
198
+
199
+ model.to(device)
200
+
201
+ model_without_ddp = model
202
+ print("Model = %s" % str(model_without_ddp))
203
+
204
+ eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()
205
+ if args.lr is None: # only base_lr is specified
206
+ args.lr = args.blr * eff_batch_size / 512
207
+
208
+ print("base lr: %.2e" % (args.blr))
209
+ print("actual lr: %.2e" % args.lr)
210
+
211
+ print("accumulate grad iterations: %d" % args.accum_iter)
212
+ print("effective batch size: %d" % eff_batch_size)
213
+
214
+ if args.distributed:
215
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
216
+ model_without_ddp = model.module
217
+
218
+ # following timm: set wd as 0 for bias and norm layers
219
+ param_groups = optim_factory.add_weight_decay(model_without_ddp, args.weight_decay)
220
+ optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95))
221
+ print(optimizer)
222
+ loss_scaler = NativeScaler()
223
+
224
+ misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler)
225
+
226
+ if args.use_hypersim:
227
+ dataset_type = HyperSimDataset
228
+ collate_fn = hypersim_collate_fn
229
+ dataset_maps = None
230
+ else:
231
+ dataset_type = CO3DV2Dataset
232
+ collate_fn = co3dv2_collate_fn
233
+ dataset_maps = get_all_dataset_maps(
234
+ args.co3d_path, args.holdout_categories,
235
+ )
236
+
237
+ dataset_viz = dataset_type(args, is_train=False, is_viz=True, dataset_maps=dataset_maps)
238
+ sampler_viz = torch.utils.data.DistributedSampler(
239
+ dataset_viz, num_replicas=num_tasks, rank=global_rank, shuffle=False
240
+ )
241
+
242
+ data_loader_viz = torch.utils.data.DataLoader(
243
+ dataset_viz, batch_size=1,
244
+ sampler=sampler_viz,
245
+ num_workers=args.num_eval_workers,
246
+ pin_memory=args.pin_mem,
247
+ collate_fn=collate_fn,
248
+ )
249
+
250
+ if args.run_viz:
251
+ run_viz(
252
+ model, data_loader_viz,
253
+ device, args=args, epoch=0,
254
+ )
255
+ exit()
256
+
257
+ data_loader_train, data_loader_val = [
258
+ build_loader(
259
+ args, num_tasks, global_rank,
260
+ is_train=is_train,
261
+ dataset_type=dataset_type, collate_fn=collate_fn, dataset_maps=dataset_maps
262
+ ) for is_train in [True, False]
263
+ ]
264
+
265
+ print(f"Start training for {args.epochs} epochs")
266
+ start_time = time.time()
267
+ for epoch in range(args.start_epoch, args.epochs):
268
+ print(f'Epoch {epoch}:')
269
+ if args.distributed:
270
+ data_loader_train.sampler.set_epoch(epoch)
271
+ train_stats = train_one_epoch(
272
+ model, data_loader_train,
273
+ optimizer, device, epoch, loss_scaler,
274
+ args=args,
275
+ )
276
+
277
+ val_stats = {}
278
+ if (epoch % 5 == 4 or epoch + 1 == args.epochs) or args.debug:
279
+ val_stats = eval_one_epoch(
280
+ model, data_loader_val,
281
+ device, args=args,
282
+ )
283
+
284
+ if ((epoch % 10 == 9 or epoch + 1 == args.epochs) or args.debug):
285
+ run_viz(
286
+ model, data_loader_viz,
287
+ device, args=args, epoch=epoch,
288
+ )
289
+
290
+ if args.output_dir and (epoch % 10 == 9 or epoch + 1 == args.epochs):
291
+ misc.save_model(
292
+ args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
293
+ loss_scaler=loss_scaler, epoch=epoch)
294
+
295
+ log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
296
+ **{f'val_{k}': v for k, v in val_stats.items()},
297
+ 'epoch': epoch,}
298
+
299
+ if args.output_dir and misc.is_main_process():
300
+ with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
301
+ f.write(json.dumps(log_stats) + "\n")
302
+
303
+ run_viz(
304
+ model, data_loader_viz,
305
+ device, args=args, epoch=-1,
306
+ )
307
+
308
+ total_time = time.time() - start_time
309
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
310
+ print('Training time {}'.format(total_time_str))
311
+
312
+
313
+ if __name__ == '__main__':
314
+
315
+ args = get_args_parser()
316
+ args = args.parse_args()
317
+
318
+ if args.output_dir:
319
+ Path(args.output_dir).mkdir(parents=True, exist_ok=True)
320
+
321
+ main(args)
322
+
mcc_model.py ADDED
@@ -0,0 +1,386 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # References:
8
+ # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
9
+ # DeiT: https://github.com/facebookresearch/deit
10
+ # MAE: https://github.com/facebookresearch/mae
11
+ # --------------------------------------------------------
12
+
13
+ from functools import partial
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+
19
+ from timm.models.vision_transformer import PatchEmbed, Block, Mlp, DropPath
20
+
21
+ from util.pos_embed import get_2d_sincos_pos_embed
22
+
23
+ class MCCDecoderAttention(nn.Module):
24
+ def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., args=None):
25
+ super().__init__()
26
+ assert dim % num_heads == 0, 'dim should be divisible by num_heads'
27
+ self.num_heads = num_heads
28
+ head_dim = dim // num_heads
29
+ self.scale = head_dim ** -0.5
30
+
31
+ self.args = args
32
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
33
+ self.attn_drop = nn.Dropout(attn_drop)
34
+ self.proj = nn.Linear(dim, dim)
35
+ self.proj_drop = nn.Dropout(proj_drop)
36
+
37
+ def forward(self, x, unseen_size):
38
+ B, N, C = x.shape
39
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
40
+ q, k, v = qkv.unbind(0)
41
+ attn = (q @ k.transpose(-2, -1)) * self.scale
42
+
43
+ mask = torch.zeros((1, 1, N, N), device=attn.device)
44
+ mask[:, :, :, -unseen_size:] = float('-inf')
45
+ for i in range(unseen_size):
46
+ mask[:, :, -(i + 1), -(i + 1)] = 0
47
+ attn = attn + mask
48
+ attn = attn.softmax(dim=-1)
49
+
50
+ attn = self.attn_drop(attn)
51
+
52
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
53
+ x = self.proj(x)
54
+ x = self.proj_drop(x)
55
+ return x
56
+
57
+ class MCCDecoderBlock(nn.Module):
58
+
59
+ def __init__(
60
+ self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None,
61
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, args=None):
62
+ super().__init__()
63
+ self.args = args
64
+ self.norm1 = norm_layer(dim)
65
+ self.attn = MCCDecoderAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, args=args)
66
+ self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
67
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
68
+
69
+ self.norm2 = norm_layer(dim)
70
+ mlp_hidden_dim = int(dim * mlp_ratio)
71
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
72
+ self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
73
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
74
+
75
+ def forward(self, x, unseen_size):
76
+ x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x), unseen_size)))
77
+ x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
78
+ return x
79
+
80
+
81
+ class XYZPosEmbed(nn.Module):
82
+ """ Masked Autoencoder with VisionTransformer backbone
83
+ """
84
+ def __init__(self, embed_dim):
85
+ super().__init__()
86
+ self.embed_dim = embed_dim
87
+
88
+ self.two_d_pos_embed = nn.Parameter(
89
+ torch.zeros(1, 64 + 1, embed_dim), requires_grad=False) # fixed sin-cos embedding
90
+
91
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
92
+ self.win_size = 8
93
+
94
+ self.pos_embed = nn.Linear(3, embed_dim)
95
+
96
+ self.blocks = nn.ModuleList([
97
+ Block(embed_dim, num_heads=12, mlp_ratio=2.0, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6))
98
+ for _ in range(1)
99
+ ])
100
+
101
+ self.invalid_xyz_token = nn.Parameter(torch.zeros(embed_dim,))
102
+
103
+ self.initialize_weights()
104
+
105
+ def initialize_weights(self):
106
+ torch.nn.init.normal_(self.cls_token, std=.02)
107
+
108
+ two_d_pos_embed = get_2d_sincos_pos_embed(self.two_d_pos_embed.shape[-1], 8, cls_token=True)
109
+ self.two_d_pos_embed.data.copy_(torch.from_numpy(two_d_pos_embed).float().unsqueeze(0))
110
+
111
+ torch.nn.init.normal_(self.invalid_xyz_token, std=.02)
112
+
113
+ def forward(self, seen_xyz, valid_seen_xyz):
114
+ emb = self.pos_embed(seen_xyz)
115
+
116
+ emb[~valid_seen_xyz] = 0.0
117
+ emb[~valid_seen_xyz] += self.invalid_xyz_token
118
+
119
+ B, H, W, C = emb.shape
120
+ emb = emb.view(B, H // self.win_size, self.win_size, W // self.win_size, self.win_size, C)
121
+ emb = emb.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, self.win_size * self.win_size, C)
122
+
123
+ emb = emb + self.two_d_pos_embed[:, 1:, :]
124
+ cls_token = self.cls_token + self.two_d_pos_embed[:, :1, :]
125
+
126
+ cls_tokens = cls_token.expand(emb.shape[0], -1, -1)
127
+ emb = torch.cat((cls_tokens, emb), dim=1)
128
+ for _, blk in enumerate(self.blocks):
129
+ emb = blk(emb)
130
+ return emb[:, 0].view(B, (H // self.win_size) * (W // self.win_size), -1)
131
+
132
+
133
+ class DecodeXYZPosEmbed(nn.Module):
134
+ """ Masked Autoencoder with VisionTransformer backbone
135
+ """
136
+ def __init__(self, embed_dim):
137
+ super().__init__()
138
+ self.embed_dim = embed_dim
139
+ self.pos_embed = nn.Linear(3, embed_dim)
140
+
141
+ def forward(self, unseen_xyz):
142
+ return self.pos_embed(unseen_xyz)
143
+
144
+
145
+ class MCC(nn.Module):
146
+ """ Masked Autoencoder with VisionTransformer backbone
147
+ """
148
+ def __init__(self,
149
+ img_size=224, patch_size=16, in_chans=3,
150
+ embed_dim=1024, depth=24, num_heads=16,
151
+ decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
152
+ mlp_ratio=4., norm_layer=nn.LayerNorm,
153
+ rgb_weight=1.0, occupancy_weight=1.0, args=None):
154
+ super().__init__()
155
+
156
+ self.rgb_weight = rgb_weight
157
+ self.occupancy_weight = occupancy_weight
158
+ self.args = args
159
+
160
+ # --------------------------------------------------------------------------
161
+ # encoder specifics
162
+ self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
163
+ num_patches = self.patch_embed.num_patches
164
+
165
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
166
+ self.cls_token_xyz = nn.Parameter(torch.zeros(1, 1, embed_dim))
167
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False) # fixed sin-cos embedding
168
+
169
+ self.xyz_pos_embed = XYZPosEmbed(embed_dim)
170
+
171
+ self.blocks = nn.ModuleList([
172
+ Block(
173
+ embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer,
174
+ drop_path=args.drop_path
175
+ ) for i in range(depth)])
176
+
177
+ self.blocks_xyz = nn.ModuleList([
178
+ Block(
179
+ embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer,
180
+ drop_path=args.drop_path
181
+ ) for i in range(depth)])
182
+
183
+ self.norm = norm_layer(embed_dim)
184
+ self.norm_xyz = norm_layer(embed_dim)
185
+ self.cached_enc_feat = None
186
+
187
+ # --------------------------------------------------------------------------
188
+ # decoder specifics
189
+ self.decoder_embed = nn.Linear(
190
+ embed_dim * 2,
191
+ decoder_embed_dim,
192
+ bias=True
193
+ )
194
+
195
+ self.decoder_xyz_pos_embed = DecodeXYZPosEmbed(decoder_embed_dim)
196
+ self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False) # fixed sin-cos embedding
197
+
198
+ self.decoder_blocks = nn.ModuleList([
199
+ MCCDecoderBlock(
200
+ decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer,
201
+ drop_path=args.drop_path,
202
+ args=args,
203
+ ) for i in range(decoder_depth)])
204
+
205
+ self.decoder_norm = norm_layer(decoder_embed_dim)
206
+ if self.args.regress_color:
207
+ self.decoder_pred = nn.Linear(decoder_embed_dim, 3 + 1, bias=True) # decoder to patch
208
+ else:
209
+ self.decoder_pred = nn.Linear(decoder_embed_dim, 256 * 3 + 1, bias=True) # decoder to patch
210
+
211
+ self.loss_occupy = nn.BCEWithLogitsLoss()
212
+ if self.args.regress_color:
213
+ self.loss_rgb = nn.MSELoss()
214
+ else:
215
+ self.loss_rgb = nn.CrossEntropyLoss()
216
+
217
+ self.initialize_weights()
218
+
219
+ def initialize_weights(self):
220
+
221
+ pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
222
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
223
+
224
+ decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
225
+ self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))
226
+
227
+ # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
228
+ w = self.patch_embed.proj.weight.data
229
+ torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
230
+
231
+ # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
232
+ torch.nn.init.normal_(self.cls_token, std=.02)
233
+ torch.nn.init.normal_(self.cls_token_xyz, std=.02)
234
+
235
+ # initialize nn.Linear and nn.LayerNorm
236
+ self.apply(self._init_weights)
237
+
238
+ def _init_weights(self, m):
239
+ if isinstance(m, nn.Linear):
240
+ # we use xavier_uniform following official JAX ViT:
241
+ torch.nn.init.xavier_uniform_(m.weight)
242
+ if isinstance(m, nn.Linear) and m.bias is not None:
243
+ nn.init.constant_(m.bias, 0)
244
+ elif isinstance(m, nn.LayerNorm):
245
+ nn.init.constant_(m.bias, 0)
246
+ nn.init.constant_(m.weight, 1.0)
247
+
248
+
249
+ def forward_encoder(self, x, seen_xyz, valid_seen_xyz):
250
+
251
+ # get tokens
252
+ x = self.patch_embed(x)
253
+ x = x + self.pos_embed[:, 1:, :]
254
+ y = self.xyz_pos_embed(seen_xyz, valid_seen_xyz)
255
+
256
+ ##### forward E_XYZ #####
257
+ # append cls token
258
+ cls_token_xyz = self.cls_token_xyz
259
+ cls_tokens_xyz = cls_token_xyz.expand(y.shape[0], -1, -1)
260
+
261
+ y = torch.cat((cls_tokens_xyz, y), dim=1)
262
+ # apply Transformer blocks
263
+ for blk in self.blocks_xyz:
264
+ y = blk(y)
265
+ y = self.norm_xyz(y)
266
+
267
+ ##### forward E_RGB #####
268
+ # append cls token
269
+ cls_token = self.cls_token + self.pos_embed[:, :1, :]
270
+ cls_tokens = cls_token.expand(x.shape[0], -1, -1)
271
+
272
+ x = torch.cat((cls_tokens, x), dim=1)
273
+ # apply Transformer blocks
274
+ for blk in self.blocks:
275
+ x = blk(x)
276
+ x = self.norm(x)
277
+
278
+ # combine encodings
279
+ x = torch.cat([x, y], dim=2)
280
+ return x
281
+
282
+ def forward_decoder(self, x, unseen_xyz):
283
+ # embed tokens
284
+ x = self.decoder_embed(x)
285
+ x = x + self.decoder_pos_embed
286
+
287
+ # 3D pos embed
288
+ unseen_xyz = self.decoder_xyz_pos_embed(unseen_xyz)
289
+ x = torch.cat([x, unseen_xyz], dim=1)
290
+
291
+ # apply Transformer blocks
292
+ for blk in self.decoder_blocks:
293
+ x = blk(x, unseen_xyz.shape[1])
294
+
295
+ x = self.decoder_norm(x)
296
+
297
+ # predictor projection
298
+ pred = self.decoder_pred(x)
299
+ # remove cls & seen token
300
+ pred = pred[:, -unseen_xyz.shape[1]:, :]
301
+
302
+ return pred
303
+
304
+ def forward_loss(self, pred, unseen_occupy, unseen_rgb):
305
+ loss = self.loss_occupy(
306
+ pred[:, :, :1].reshape((-1, 1)),
307
+ unseen_occupy.reshape((-1, 1)).float()
308
+ ) * self.occupancy_weight
309
+
310
+ if unseen_occupy.sum() > 0:
311
+ if self.args.regress_color:
312
+ pred_rgb = pred[:, :, 1:][unseen_occupy.bool()]
313
+ gt_rgb = unseen_rgb[unseen_occupy.bool()]
314
+ else:
315
+ pred_rgb = pred[:, :, 1:][unseen_occupy.bool()].reshape((-1, 256))
316
+ gt_rgb = torch.round(unseen_rgb[unseen_occupy.bool()] * 255).long().reshape((-1,))
317
+
318
+ rgb_loss = self.loss_rgb(pred_rgb, gt_rgb) * self.rgb_weight
319
+ loss = loss + rgb_loss
320
+ return loss
321
+
322
+
323
+ def clear_cache(self):
324
+ self.cached_enc_feat = None
325
+
326
+ def forward(self, seen_images, seen_xyz, unseen_xyz, unseen_rgb, unseen_occupy, valid_seen_xyz,
327
+ cache_enc=False):
328
+
329
+ unseen_xyz = shrink_points_beyond_threshold(unseen_xyz, self.args.shrink_threshold)
330
+
331
+ if self.cached_enc_feat is None:
332
+ seen_images = preprocess_img(seen_images)
333
+ seen_xyz = shrink_points_beyond_threshold(seen_xyz, self.args.shrink_threshold)
334
+ latent = self.forward_encoder(seen_images, seen_xyz, valid_seen_xyz)
335
+
336
+ if cache_enc:
337
+ if self.cached_enc_feat is None:
338
+ self.cached_enc_feat = latent
339
+ else:
340
+ latent = self.cached_enc_feat
341
+
342
+ pred = self.forward_decoder(latent, unseen_xyz)
343
+ loss = self.forward_loss(pred, unseen_occupy, unseen_rgb)
344
+ return loss, pred
345
+
346
+
347
+ def get_mcc_model(**kwargs):
348
+ return MCC(
349
+ embed_dim=768, depth=12, num_heads=12,
350
+ decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
351
+ mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs
352
+ )
353
+
354
+
355
+ def shrink_points_beyond_threshold(xyz, threshold):
356
+ xyz = xyz.clone().detach()
357
+ dist = (xyz ** 2.0).sum(axis=-1) ** 0.5
358
+ affected = (dist > threshold) * torch.isfinite(dist)
359
+ xyz[affected] = xyz[affected] * (
360
+ threshold * (2.0 - threshold / dist[affected]) / dist[affected]
361
+ )[..., None]
362
+ return xyz
363
+
364
+
365
+ def preprocess_img(x):
366
+ if x.shape[2] != 224:
367
+ assert x.shape[2] == 800
368
+ x = F.interpolate(
369
+ x,
370
+ scale_factor=224./800.,
371
+ mode="bilinear",
372
+ )
373
+ resnet_mean = torch.tensor([0.485, 0.456, 0.406], device=x.device).reshape((1, 3, 1, 1))
374
+ resnet_std = torch.tensor([0.229, 0.224, 0.225], device=x.device).reshape((1, 3, 1, 1))
375
+ imgs_normed = (x - resnet_mean) / resnet_std
376
+ return imgs_normed
377
+
378
+
379
+ class LayerScale(nn.Module):
380
+ def __init__(self, dim, init_values=1e-5, inplace=False):
381
+ super().__init__()
382
+ self.inplace = inplace
383
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
384
+
385
+ def forward(self, x):
386
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
packages.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ ffmpeg
2
+ libsm6
3
+ libxext6
pre-requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ torch==1.13.0
2
+ torchvision==0.14.0
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ h5py
2
+ omegaconf
3
+ submitit
4
+ timm==0.4.5
5
+ opencv-python
6
+ matplotlib
7
+ plotly
8
+ gradio
9
+ gradio_client==0.2.7
10
+ plyfile
11
+ git+https://github.com/facebookresearch/pytorch3d.git
util/co3d_dataset.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import random
8
+ from typing import cast
9
+
10
+ import torch
11
+ from pytorch3d.implicitron.dataset.dataset_base import FrameData
12
+
13
+ import util.co3d_utils as co3d_utils
14
+
15
+
16
+ def co3dv2_collate_fn(batch):
17
+ assert len(batch[0]) == 4
18
+ return (
19
+ FrameData.collate([x[0] for x in batch]),
20
+ FrameData.collate([x[1] for x in batch]),
21
+ [x[2] for x in batch],
22
+ [x[3] for x in batch],
23
+ )
24
+
25
+
26
+ def pad_point_cloud(pc, N):
27
+ cur_N = pc._points_list[0].shape[0]
28
+ if cur_N == N:
29
+ return pc
30
+
31
+ assert cur_N > 0
32
+
33
+ n_pad = N - cur_N
34
+ indices = random.choices(list(range(cur_N)), k=n_pad)
35
+ pc._features_list[0] = torch.cat([pc._features_list[0], pc._features_list[0][indices]], dim=0)
36
+ pc._points_list[0] = torch.cat([pc._points_list[0], pc._points_list[0][indices]], dim=0)
37
+ return pc
38
+
39
+
40
+ class CO3DV2Dataset(torch.utils.data.Dataset):
41
+ def __init__(self, args, is_train, is_viz=False, dataset_maps=None):
42
+
43
+ self.args = args
44
+ self.is_train = is_train
45
+ self.is_viz = is_viz
46
+
47
+ self.dataset_split = 'train' if is_train else 'val'
48
+ self.all_datasets = dataset_maps[0 if is_train else 1]
49
+ print(len(self.all_datasets), 'categories loaded')
50
+
51
+ self.all_example_names = self.get_all_example_names()
52
+ print('containing', len(self.all_example_names), 'examples')
53
+
54
+ def get_all_example_names(self):
55
+ all_example_names = []
56
+ for category in self.all_datasets.keys():
57
+ for sequence_name in self.all_datasets[category].seq_name2idx.keys():
58
+ all_example_names.append((category, sequence_name))
59
+ return all_example_names
60
+
61
+ def __getitem__(self, index):
62
+ for retry in range(1000):
63
+ try:
64
+ if retry > 9:
65
+ index = random.choice(range(len(self)))
66
+ print('retry', retry, 'new index:', index)
67
+ gap = 1 if self.is_train else len(self.all_example_names) // len(self)
68
+ assert gap >= 1
69
+ category, sequence_name = self.all_example_names[(index * gap) % len(self.all_example_names)]
70
+
71
+ cat_dataset = self.all_datasets[category]
72
+
73
+ frame_data = cat_dataset.__getitem__(
74
+ random.choice(cat_dataset.seq_name2idx[sequence_name])
75
+ if self.is_train
76
+ else cat_dataset.seq_name2idx[sequence_name][
77
+ hash(sequence_name) % len(cat_dataset.seq_name2idx[sequence_name])
78
+ ]
79
+ )
80
+ test_frame = None
81
+ seen_idx = None
82
+
83
+ frame_data = cat_dataset.frame_data_type.collate([frame_data])
84
+ mask = (
85
+ (cast(torch.Tensor, frame_data.fg_probability) > 0.5).float()
86
+ if frame_data.fg_probability is not None
87
+ else None
88
+ )
89
+ seen_rgb = frame_data.image_rgb.clone().detach()
90
+
91
+ # 112, 112, 3
92
+ seen_xyz = co3d_utils.get_rgbd_points(
93
+ 112, 112,
94
+ frame_data.camera,
95
+ frame_data.depth_map,
96
+ mask,
97
+ )
98
+
99
+ full_point_cloud = co3d_utils._load_pointcloud(f'{self.args.co3d_path}/{category}/{sequence_name}/pointcloud.ply', max_points=20000)
100
+ full_point_cloud = pad_point_cloud(full_point_cloud, 20000)
101
+ break
102
+ except Exception as e:
103
+ print(category, sequence_name, 'sampling failed', retry, e)
104
+
105
+ seen_rgb = seen_rgb.squeeze(0)
106
+ full_rgb = full_point_cloud._features_list[0]
107
+
108
+ return (
109
+ (seen_xyz, seen_rgb),
110
+ (full_point_cloud._points_list[0], full_rgb),
111
+ test_frame,
112
+ (category, sequence_name, seen_idx),
113
+ )
114
+
115
+ def __len__(self) -> int:
116
+ n_objs = sum([len(cat_dataset.seq_name2idx.keys()) for cat_dataset in self.all_datasets.values()])
117
+ if self.is_train:
118
+ return int(n_objs * self.args.train_epoch_len_multiplier)
119
+ elif self.is_viz:
120
+ return n_objs
121
+ else:
122
+ return int(n_objs * self.args.eval_epoch_len_multiplier)
util/co3d_utils.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import glob
8
+ from omegaconf import DictConfig
9
+ from typing import Optional
10
+
11
+ import torch
12
+
13
+ from pytorch3d.implicitron.dataset.dataset_map_provider import DatasetMap
14
+ from pytorch3d.implicitron.dataset.json_index_dataset_map_provider_v2 import (
15
+ JsonIndexDatasetMapProviderV2
16
+ )
17
+ from pytorch3d.implicitron.tools.config import expand_args_fields
18
+ from pytorch3d.io import IO
19
+ from pytorch3d.renderer import (
20
+ NDCMultinomialRaysampler,
21
+ ray_bundle_to_ray_points,
22
+ )
23
+ from pytorch3d.renderer.cameras import CamerasBase
24
+ from pytorch3d.structures import Pointclouds
25
+
26
+
27
+ HOLDOUT_CATEGORIES = set([
28
+ 'apple',
29
+ 'baseballglove',
30
+ 'cup',
31
+ 'ball',
32
+ 'toyplane',
33
+ 'handbag',
34
+ 'book',
35
+ 'carrot',
36
+ 'suitcase',
37
+ 'bowl',
38
+ ])
39
+
40
+ def get_dataset_map(
41
+ dataset_root: str,
42
+ category: str,
43
+ subset_name: str,
44
+ ) -> DatasetMap:
45
+ """
46
+ Obtain the dataset map that contains the train/val/test dataset objects.
47
+ """
48
+ expand_args_fields(JsonIndexDatasetMapProviderV2)
49
+ dataset_map_provider = JsonIndexDatasetMapProviderV2(
50
+ category=category,
51
+ subset_name=subset_name,
52
+ dataset_root=dataset_root,
53
+ test_on_train=False,
54
+ only_test_set=False,
55
+ load_eval_batches=True,
56
+ dataset_JsonIndexDataset_args=DictConfig({"remove_empty_masks": False, "load_point_clouds": False}),
57
+ )
58
+ return dataset_map_provider.get_dataset_map()
59
+
60
+
61
+ def _load_pointcloud(pcl_path, max_points):
62
+ pcl = IO().load_pointcloud(pcl_path)
63
+ if max_points > 0:
64
+ pcl = pcl.subsample(max_points)
65
+
66
+ return pcl
67
+
68
+
69
+ def get_all_dataset_maps(co3d_path, holdout_categories):
70
+ all_categories = [c.split('/')[-1] for c in list(glob.glob(co3d_path + '/*')) if not c.endswith('.json')]
71
+ all_categories = sorted(all_categories, key=lambda x: hash(x))
72
+
73
+ # Obtain the CO3Dv2 dataset map
74
+ train_dataset_maps = {}
75
+ val_dataset_maps = {}
76
+ for category in all_categories:
77
+
78
+ print(f'Loading dataset map ({category})')
79
+ dataset_map = {
80
+ 'train': torch.load(f'dataset_cache/{category}_train.pt'),
81
+ 'val': torch.load(f'dataset_cache/{category}_val.pt')
82
+ }
83
+ if not holdout_categories or category not in HOLDOUT_CATEGORIES:
84
+ train_dataset_maps[category] = dataset_map['train']
85
+ if not holdout_categories or category in HOLDOUT_CATEGORIES:
86
+ val_dataset_maps[category] = dataset_map['val']
87
+
88
+ print('Loaded', len(train_dataset_maps), 'categores for train')
89
+ print('Loaded', len(val_dataset_maps), 'categores for val')
90
+ return train_dataset_maps, val_dataset_maps
91
+
92
+
93
+ def get_rgbd_points(
94
+ imh, imw,
95
+ camera: CamerasBase,
96
+ depth_map: torch.Tensor,
97
+ mask: Optional[torch.Tensor] = None,
98
+ mask_thr: float = 0.5,
99
+ ) -> Pointclouds:
100
+ """
101
+ Given a batch of images, depths, masks and cameras, generate a colored
102
+ point cloud by unprojecting depth maps to the and coloring with the source
103
+ pixel colors.
104
+ """
105
+ depth_map = torch.nn.functional.interpolate(
106
+ depth_map,
107
+ size=[imh, imw],
108
+ mode="bilinear",
109
+ align_corners=False,
110
+ )
111
+ # convert the depth maps to point clouds using the grid ray sampler
112
+ pts_3d = ray_bundle_to_ray_points(
113
+ NDCMultinomialRaysampler(
114
+ image_width=imw,
115
+ image_height=imh,
116
+ n_pts_per_ray=1,
117
+ min_depth=1.0,
118
+ max_depth=1.0,
119
+ )(camera)._replace(lengths=depth_map[:, 0, ..., None])
120
+ ).squeeze(3)[None]
121
+
122
+ pts_mask = depth_map > 0.0
123
+ if mask is not None:
124
+ mask = torch.nn.functional.interpolate(
125
+ mask,
126
+ size=[imh, imw],
127
+ mode="bilinear",
128
+ align_corners=False,
129
+ )
130
+ pts_mask *= mask > mask_thr
131
+ pts_3d[~pts_mask] = float('inf')
132
+ return pts_3d.squeeze(0).squeeze(0)
133
+
util/crop.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+
9
+ import torch
10
+
11
+ from torchvision import transforms
12
+ from torchvision.transforms import functional as F
13
+
14
+
15
+ class RandomResizedCrop(transforms.RandomResizedCrop):
16
+ """
17
+ RandomResizedCrop for matching TF/TPU implementation: no for-loop is used.
18
+ This may lead to results different with torchvision's version.
19
+ Following BYOL's TF code:
20
+ https://github.com/deepmind/deepmind-research/blob/master/byol/utils/dataset.py#L206
21
+ """
22
+ @staticmethod
23
+ def get_params(img, scale, ratio):
24
+ width, height = F._get_image_size(img)
25
+ area = height * width
26
+
27
+ target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
28
+ log_ratio = torch.log(torch.tensor(ratio))
29
+ aspect_ratio = torch.exp(
30
+ torch.empty(1).uniform_(log_ratio[0], log_ratio[1])
31
+ ).item()
32
+
33
+ w = int(round(math.sqrt(target_area * aspect_ratio)))
34
+ h = int(round(math.sqrt(target_area / aspect_ratio)))
35
+
36
+ w = min(w, width)
37
+ h = min(h, height)
38
+
39
+ i = torch.randint(0, height - h + 1, size=(1,)).item()
40
+ j = torch.randint(0, width - w + 1, size=(1,)).item()
41
+
42
+ return i, j, h, w
util/hypersim_dataset.py ADDED
@@ -0,0 +1,271 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import os
8
+ import random
9
+ import glob
10
+
11
+ import torch
12
+ from pytorch3d.implicitron.dataset.dataset_base import FrameData
13
+ from pytorch3d.ops import sample_points_from_meshes
14
+
15
+ from util.hypersim_utils import read_h5py, read_img
16
+
17
+
18
+ def hypersim_collate_fn(batch):
19
+ assert len(batch[0]) == 4
20
+ return (
21
+ FrameData.collate([x[0] for x in batch]),
22
+ FrameData.collate([x[1] for x in batch]),
23
+ FrameData.collate([x[2] for x in batch]),
24
+ [x[2] for x in batch]
25
+ )
26
+
27
+
28
+ def is_good_xyz(xyz):
29
+ assert len(xyz.shape) == 3
30
+ return (torch.isfinite(xyz.sum(axis=2))).sum() > 2000
31
+
32
+
33
+ def get_camera_pos_file_name_from_frame_name(frame_name):
34
+ tmp = frame_name.split('/')
35
+ tmp[-3] = '_detail'
36
+ tmp[-2] = 'cam_' + tmp[-2].split('_')[2]
37
+ tmp[-1] = 'camera_keyframe_positions.hdf5'
38
+ return '/'.join(tmp)
39
+
40
+
41
+ def get_camera_look_at_file_name_from_frame_name(frame_name):
42
+ tmp = frame_name.split('/')
43
+ tmp[-3] = '_detail'
44
+ tmp[-2] = 'cam_' + tmp[-2].split('_')[2]
45
+ tmp[-1] = 'camera_keyframe_look_at_positions.hdf5'
46
+ return '/'.join(tmp)
47
+
48
+
49
+ def get_camera_orientation_file_name_from_frame_name(frame_name):
50
+ tmp = frame_name.split('/')
51
+ tmp[-3] = '_detail'
52
+ tmp[-2] = 'cam_' + tmp[-2].split('_')[2]
53
+ tmp[-1] = 'camera_keyframe_orientations.hdf5'
54
+ return '/'.join(tmp)
55
+
56
+
57
+ def read_scale_from_frame_name(frame_name):
58
+ tmp = frame_name.split('/')
59
+ with open('/'.join(tmp[:-3] + ['_detail', 'metadata_scene.csv'])) as f:
60
+ for line in f:
61
+ items = line.split(',')
62
+ return float(items[1])
63
+
64
+
65
+ def random_crop(xyz, img, is_train=True):
66
+ assert xyz.shape[0] == img.shape[0]
67
+ assert xyz.shape[1] == img.shape[1]
68
+
69
+ width, height = img.shape[0], img.shape[1]
70
+ w = h = min(width, height)
71
+ if is_train:
72
+ i = torch.randint(0, width - w + 1, size=(1,)).item()
73
+ j = torch.randint(0, height - h + 1, size=(1,)).item()
74
+ else:
75
+ i = (width - w) // 2
76
+ j = (height - h) // 2
77
+ xyz = xyz[i:i+w, j:j+h]
78
+ img = img[i:i+w, j:j+h]
79
+ xyz = torch.nn.functional.interpolate(
80
+ xyz[None].permute(0, 3, 1, 2), (112, 112),
81
+ mode='bilinear',
82
+ ).permute(0, 2, 3, 1)[0]
83
+ img = torch.nn.functional.interpolate(
84
+ img[None].permute(0, 3, 1, 2), (224, 224),
85
+ mode='bilinear',
86
+ ).permute(0, 2, 3, 1)[0]
87
+ return xyz, img
88
+
89
+
90
+ class HyperSimDataset(torch.utils.data.Dataset):
91
+ def __init__(self, args, is_train, is_viz=False, **kwargs):
92
+
93
+ self.args = args
94
+ self.is_train = is_train
95
+ self.is_viz = is_viz
96
+
97
+ self.dataset_split = 'train' if is_train else 'val'
98
+ self.scene_names = self.load_scene_names(is_train)
99
+
100
+ if not is_train:
101
+ self.meshes = self.load_meshes()
102
+
103
+ self.hypersim_gt = self.load_hypersim_gt()
104
+
105
+
106
+ def load_hypersim_gt(self):
107
+ gt_filename = 'hypersim_gt_train.pt' if self.dataset_split == 'train' else 'hypersim_gt_val.pt'
108
+ print('loading GT file from', gt_filename)
109
+ gt = torch.load(gt_filename)
110
+ for scene_name in gt.keys():
111
+ good = torch.isfinite(gt[scene_name][0].sum(axis=1)) & torch.isfinite(gt[scene_name][1].sum(axis=1))
112
+
113
+ # Subsample GT to reduce memory usage.
114
+ if self.is_train:
115
+ good = good & (torch.rand(good.shape) < 0.5)
116
+ else:
117
+ good = good & (torch.rand(good.shape) < 0.1)
118
+ gt[scene_name] = [gt[scene_name][0][good], gt[scene_name][1][good]]
119
+ return gt
120
+
121
+ def load_meshes(self):
122
+ return torch.load('all_hypersim_val_meshes.pt')
123
+
124
+ def load_scene_names(self, is_train):
125
+ split = 'train' if is_train else 'test'
126
+ scene_names = []
127
+ with open(os.path.join(
128
+ self.args.hypersim_path,
129
+ 'evermotion_dataset/analysis/metadata_images_split_scene_v1.csv'),'r') as f:
130
+ for line in f:
131
+ items = line.split(',')
132
+ if items[-1].strip() == split:
133
+ scene_names.append(items[0])
134
+ scene_names = sorted(list(set(scene_names)))
135
+ print(len(scene_names), 'scenes loaded:', scene_names)
136
+ return scene_names
137
+
138
+ def is_corrupted_frame(self, frame):
139
+ return (
140
+ ('ai_003_001' in frame and 'cam_00' in frame)
141
+ or ('ai_004_009' in frame and 'cam_01' in frame)
142
+ )
143
+
144
+ def get_hypersim_data(self, index):
145
+ for retry in range(1000):
146
+ try:
147
+ if retry < 10:
148
+ scene_name = self.scene_names[index % len(self.scene_names)]
149
+ else:
150
+ scene_name = random.choice(self.scene_names)
151
+
152
+ frames = glob.glob(os.path.join(self.args.hypersim_path, scene_name, 'images/scene_cam_*_final_preview/*tonemap*'))
153
+ seen_frame = random.choice(frames)
154
+
155
+ if self.is_corrupted_frame(seen_frame):
156
+ continue
157
+
158
+ seen_data = self.load_frame_data(seen_frame)
159
+ if not is_good_xyz(seen_data[0]):
160
+ continue
161
+
162
+ cur_gt = self.hypersim_gt[scene_name]
163
+ gt_data = [cur_gt[0], cur_gt[1]]
164
+
165
+ if self.is_train:
166
+ mesh_points = torch.zeros((1,))
167
+ else:
168
+ mesh_points = sample_points_from_meshes(self.meshes[scene_name], 1000000)
169
+
170
+ # get camera positions
171
+ camera_positions = read_h5py(get_camera_pos_file_name_from_frame_name(seen_frame))
172
+ camera_position = camera_positions[int(seen_frame.split('.')[-3])]
173
+
174
+ # get camera orientations
175
+ cam_orientations = read_h5py(get_camera_orientation_file_name_from_frame_name(seen_frame))
176
+ cam_orientation = cam_orientations[int(seen_frame.split('.')[-3])]
177
+ cam_orientation = cam_orientation * (-1.0)
178
+
179
+ # rotate to camera direction
180
+ seen_data[0] = torch.matmul(seen_data[0], cam_orientation)
181
+ gt_data[0] = torch.matmul(gt_data[0], cam_orientation)
182
+
183
+ # shift to camera center
184
+ camera_position = torch.matmul(camera_position, cam_orientation)
185
+ seen_data[0] -= camera_position
186
+ gt_data[0] -= camera_position
187
+ # to meter
188
+ asset_to_meter_scale = read_scale_from_frame_name(seen_frame)
189
+ seen_data[0] = seen_data[0] * asset_to_meter_scale
190
+ gt_data[0] = gt_data[0] * asset_to_meter_scale
191
+
192
+ # get points GT
193
+ n_gt = 30000
194
+ in_front_of_cam = (gt_data[0][..., 2] > 0)
195
+ if in_front_of_cam.sum() < 1000:
196
+ print('Warning! Not enough in front of cam.', in_front_of_cam.sum())
197
+ continue
198
+ gt_data = [gt_data[0][in_front_of_cam], gt_data[1][in_front_of_cam]]
199
+
200
+ if in_front_of_cam.sum() < n_gt:
201
+ selected = random.choices(range(gt_data[0].shape[0]), k=n_gt)
202
+ else:
203
+ selected = random.sample(range(gt_data[0].shape[0]), n_gt)
204
+ gt_data = [gt_data[0][selected][None], gt_data[1][selected][None], torch.zeros((1,))]
205
+
206
+ if not self.is_train:
207
+ mesh_points = torch.matmul(mesh_points, cam_orientation)
208
+ mesh_points -= camera_position * asset_to_meter_scale
209
+ in_front_of_cam = (mesh_points[..., 2] > 0)
210
+ if in_front_of_cam.sum() < 1000:
211
+ print('Warning! Not enough mesh in front of cam.', in_front_of_cam.sum())
212
+ continue
213
+ mesh_points = mesh_points[in_front_of_cam]
214
+ if in_front_of_cam.sum() < n_gt:
215
+ selected = random.choices(range(mesh_points.shape[0]), k=n_gt)
216
+ else:
217
+ selected = random.sample(range(mesh_points.shape[0]), n_gt)
218
+ mesh_points = mesh_points[selected][None]
219
+ mesh_points[..., 0] *= -1
220
+
221
+ seen_data[0][..., 0] *= -1
222
+ gt_data[0][..., 0] *= -1
223
+
224
+ seen_data[1] = seen_data[1].permute(2, 0, 1)
225
+
226
+ return seen_data, gt_data, mesh_points, scene_name
227
+ except Exception as e:
228
+ print(scene_name, 'loading failed', retry, e)
229
+
230
+
231
+ def __getitem__(self, index):
232
+
233
+ seen_data, gt_data, mesh_points, scene_name = self.get_hypersim_data(index)
234
+
235
+ # normalize the data
236
+ example_std = get_example_std(seen_data[0])
237
+ seen_data[0] = seen_data[0] / example_std
238
+ gt_data[0] = gt_data[0] / example_std
239
+ mesh_points = mesh_points / example_std
240
+
241
+ return (
242
+ seen_data,
243
+ gt_data,
244
+ mesh_points,
245
+ scene_name,
246
+ )
247
+
248
+ def load_frame_data(self, frame_path):
249
+ frame_xyz_path = frame_path.replace('final_preview/', 'geometry_hdf5/').replace('.tonemap.jpg', '.position.hdf5')
250
+ xyz = read_h5py(frame_xyz_path)
251
+ img = read_img(frame_path)
252
+
253
+ xyz, img = random_crop(
254
+ xyz, img,
255
+ is_train=self.is_train,
256
+ )
257
+ return [xyz, img]
258
+
259
+ def __len__(self) -> int:
260
+ if self.is_train:
261
+ return int(len(self.scene_names) * self.args.train_epoch_len_multiplier)
262
+ elif self.is_viz:
263
+ return len(self.scene_names)
264
+ else:
265
+ return int(len(self.scene_names) * self.args.eval_epoch_len_multiplier)
266
+
267
+
268
+ def get_example_std(x):
269
+ x = x.reshape(-1, 3)
270
+ x = x[torch.isfinite(x.sum(dim=1))]
271
+ return x.std(dim=0).mean().detach()
util/hypersim_utils.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import cv2
8
+ import h5py
9
+
10
+ import torch
11
+
12
+
13
+ def read_h5py(filename):
14
+ with h5py.File(filename, "r") as f:
15
+ data = torch.tensor(f['dataset'][:], dtype=torch.float32)
16
+ return data
17
+
18
+
19
+ def read_img(frame_path):
20
+ for retry in range(100):
21
+ img = cv2.imread(frame_path)
22
+ if img is not None:
23
+ return torch.tensor(img / 255.0, dtype=torch.float32)[..., [2, 1, 0]]
24
+ print('retry loading', retry, frame_path)
util/lars.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # LARS optimizer, implementation from MoCo v3:
8
+ # https://github.com/facebookresearch/moco-v3
9
+ # --------------------------------------------------------
10
+
11
+ import torch
12
+
13
+
14
+ class LARS(torch.optim.Optimizer):
15
+ """
16
+ LARS optimizer, no rate scaling or weight decay for parameters <= 1D.
17
+ """
18
+ def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001):
19
+ defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, trust_coefficient=trust_coefficient)
20
+ super().__init__(params, defaults)
21
+
22
+ @torch.no_grad()
23
+ def step(self):
24
+ for g in self.param_groups:
25
+ for p in g['params']:
26
+ dp = p.grad
27
+
28
+ if dp is None:
29
+ continue
30
+
31
+ if p.ndim > 1: # if not normalization gamma/beta or bias
32
+ dp = dp.add(p, alpha=g['weight_decay'])
33
+ param_norm = torch.norm(p)
34
+ update_norm = torch.norm(dp)
35
+ one = torch.ones_like(param_norm)
36
+ q = torch.where(param_norm > 0.,
37
+ torch.where(update_norm > 0,
38
+ (g['trust_coefficient'] * param_norm / update_norm), one),
39
+ one)
40
+ dp = dp.mul(q)
41
+
42
+ param_state = self.state[p]
43
+ if 'mu' not in param_state:
44
+ param_state['mu'] = torch.zeros_like(p)
45
+ mu = param_state['mu']
46
+ mu.mul_(g['momentum']).add_(dp)
47
+ p.add_(mu, alpha=-g['lr'])
util/lr_decay.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # References:
8
+ # ELECTRA https://github.com/google-research/electra
9
+ # BEiT: https://github.com/microsoft/unilm/tree/master/beit
10
+ # --------------------------------------------------------
11
+
12
+ import json
13
+
14
+
15
+ def param_groups_lrd(model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=.75):
16
+ """
17
+ Parameter groups for layer-wise lr decay
18
+ Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58
19
+ """
20
+ param_group_names = {}
21
+ param_groups = {}
22
+
23
+ num_layers = len(model.blocks) + 1
24
+
25
+ layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1))
26
+
27
+ for n, p in model.named_parameters():
28
+ if not p.requires_grad:
29
+ continue
30
+
31
+ # no decay: all 1D parameters and model specific ones
32
+ if p.ndim == 1 or n in no_weight_decay_list:
33
+ g_decay = "no_decay"
34
+ this_decay = 0.
35
+ else:
36
+ g_decay = "decay"
37
+ this_decay = weight_decay
38
+
39
+ layer_id = get_layer_id_for_vit(n, num_layers)
40
+ group_name = "layer_%d_%s" % (layer_id, g_decay)
41
+
42
+ if group_name not in param_group_names:
43
+ this_scale = layer_scales[layer_id]
44
+
45
+ param_group_names[group_name] = {
46
+ "lr_scale": this_scale,
47
+ "weight_decay": this_decay,
48
+ "params": [],
49
+ }
50
+ param_groups[group_name] = {
51
+ "lr_scale": this_scale,
52
+ "weight_decay": this_decay,
53
+ "params": [],
54
+ }
55
+
56
+ param_group_names[group_name]["params"].append(n)
57
+ param_groups[group_name]["params"].append(p)
58
+
59
+ # print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2))
60
+
61
+ return list(param_groups.values())
62
+
63
+
64
+ def get_layer_id_for_vit(name, num_layers):
65
+ """
66
+ Assign a parameter with its layer id
67
+ Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33
68
+ """
69
+ if name in ['cls_token', 'pos_embed']:
70
+ return 0
71
+ elif name.startswith('patch_embed'):
72
+ return 0
73
+ elif name.startswith('blocks'):
74
+ return int(name.split('.')[1]) + 1
75
+ else:
76
+ return num_layers
util/lr_sched.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+
9
+ def adjust_learning_rate(optimizer, epoch, args):
10
+ """Decay the learning rate with half-cycle cosine after warmup"""
11
+ if epoch < args.warmup_epochs:
12
+ lr = args.lr * epoch / args.warmup_epochs
13
+ else:
14
+ lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \
15
+ (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs)))
16
+ for param_group in optimizer.param_groups:
17
+ if "lr_scale" in param_group:
18
+ param_group["lr"] = lr * param_group["lr_scale"]
19
+ else:
20
+ param_group["lr"] = lr
21
+ return lr
util/misc.py ADDED
@@ -0,0 +1,496 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # References:
8
+ # DeiT: https://github.com/facebookresearch/deit
9
+ # BEiT: https://github.com/microsoft/unilm/tree/master/beit
10
+ # MAE: https://github.com/facebookresearch/mae
11
+ # --------------------------------------------------------
12
+
13
+ import builtins
14
+ import datetime
15
+ import os
16
+ import time
17
+ from collections import defaultdict, deque
18
+ from pathlib import Path
19
+
20
+ import torch
21
+ import torch.distributed as dist
22
+ from torch._six import inf
23
+
24
+
25
+ class SmoothedValue(object):
26
+ """Track a series of values and provide access to smoothed values over a
27
+ window or the global series average.
28
+ """
29
+
30
+ def __init__(self, window_size=20, fmt=None):
31
+ if fmt is None:
32
+ fmt = "{median:.4f} ({global_avg:.4f})"
33
+ self.deque = deque(maxlen=window_size)
34
+ self.total = 0.0
35
+ self.count = 0
36
+ self.fmt = fmt
37
+
38
+ def update(self, value, n=1):
39
+ self.deque.append(value)
40
+ self.count += n
41
+ self.total += value * n
42
+
43
+ def synchronize_between_processes(self):
44
+ """
45
+ Warning: does not synchronize the deque!
46
+ """
47
+ if not is_dist_avail_and_initialized():
48
+ return
49
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
50
+ dist.barrier()
51
+ dist.all_reduce(t)
52
+ t = t.tolist()
53
+ self.count = int(t[0])
54
+ self.total = t[1]
55
+
56
+ @property
57
+ def median(self):
58
+ d = torch.tensor(list(self.deque))
59
+ return d.median().item()
60
+
61
+ @property
62
+ def avg(self):
63
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
64
+ return d.mean().item()
65
+
66
+ @property
67
+ def global_avg(self):
68
+ return self.total / self.count
69
+
70
+ @property
71
+ def max(self):
72
+ return max(self.deque)
73
+
74
+ @property
75
+ def value(self):
76
+ return self.deque[-1]
77
+
78
+ def __str__(self):
79
+ return self.fmt.format(
80
+ median=self.median,
81
+ avg=self.avg,
82
+ global_avg=self.global_avg,
83
+ max=self.max,
84
+ value=self.value)
85
+
86
+
87
+ class MetricLogger(object):
88
+ def __init__(self, delimiter="\t"):
89
+ self.meters = defaultdict(SmoothedValue)
90
+ self.delimiter = delimiter
91
+
92
+ def update(self, **kwargs):
93
+ for k, v in kwargs.items():
94
+ if v is None:
95
+ continue
96
+ if isinstance(v, torch.Tensor):
97
+ v = v.item()
98
+ assert isinstance(v, (float, int))
99
+ self.meters[k].update(v)
100
+
101
+ def __getattr__(self, attr):
102
+ if attr in self.meters:
103
+ return self.meters[attr]
104
+ if attr in self.__dict__:
105
+ return self.__dict__[attr]
106
+ raise AttributeError("'{}' object has no attribute '{}'".format(
107
+ type(self).__name__, attr))
108
+
109
+ def __str__(self):
110
+ loss_str = []
111
+ for name, meter in self.meters.items():
112
+ loss_str.append(
113
+ "{}: {}".format(name, str(meter))
114
+ )
115
+ return self.delimiter.join(loss_str)
116
+
117
+ def synchronize_between_processes(self):
118
+ for meter in self.meters.values():
119
+ meter.synchronize_between_processes()
120
+
121
+ def add_meter(self, name, meter):
122
+ self.meters[name] = meter
123
+
124
+ def log_every(self, iterable, print_freq, header=None):
125
+ i = 0
126
+ if not header:
127
+ header = ''
128
+ start_time = time.time()
129
+ end = time.time()
130
+ iter_time = SmoothedValue(fmt='{avg:.4f}')
131
+ data_time = SmoothedValue(fmt='{avg:.4f}')
132
+ space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
133
+ log_msg = [
134
+ header,
135
+ '[{0' + space_fmt + '}/{1}]',
136
+ 'eta: {eta}',
137
+ '{meters}',
138
+ 'time: {time}',
139
+ 'data: {data}'
140
+ ]
141
+ if torch.cuda.is_available():
142
+ log_msg.append('max mem: {memory:.0f}')
143
+ log_msg = self.delimiter.join(log_msg)
144
+ MB = 1024.0 * 1024.0
145
+ for obj in iterable:
146
+ data_time.update(time.time() - end)
147
+ yield obj
148
+ iter_time.update(time.time() - end)
149
+ if i % print_freq == 0 or i == len(iterable) - 1:
150
+ eta_seconds = iter_time.global_avg * (len(iterable) - i)
151
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
152
+ if torch.cuda.is_available():
153
+ print(log_msg.format(
154
+ i, len(iterable), eta=eta_string,
155
+ meters=str(self),
156
+ time=str(iter_time), data=str(data_time),
157
+ memory=torch.cuda.max_memory_allocated() / MB))
158
+ else:
159
+ print(log_msg.format(
160
+ i, len(iterable), eta=eta_string,
161
+ meters=str(self),
162
+ time=str(iter_time), data=str(data_time)))
163
+ i += 1
164
+ end = time.time()
165
+ total_time = time.time() - start_time
166
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
167
+ print('{} Total time: {} ({:.4f} s / it)'.format(
168
+ header, total_time_str, total_time / len(iterable)))
169
+
170
+
171
+ def setup_for_distributed(is_master):
172
+ """
173
+ This function disables printing when not in master process
174
+ """
175
+ builtin_print = builtins.print
176
+
177
+ def print(*args, **kwargs):
178
+ force = kwargs.pop('force', False)
179
+ force = force or (get_world_size() > 8)
180
+ if is_master or force:
181
+ now = datetime.datetime.now().time()
182
+ builtin_print('[{}] '.format(now), end='') # print with time stamp
183
+ builtin_print(*args, **kwargs)
184
+
185
+ builtins.print = print
186
+
187
+
188
+ def is_dist_avail_and_initialized():
189
+ if not dist.is_available():
190
+ return False
191
+ if not dist.is_initialized():
192
+ return False
193
+ return True
194
+
195
+
196
+ def get_world_size():
197
+ if not is_dist_avail_and_initialized():
198
+ return 1
199
+ return dist.get_world_size()
200
+
201
+
202
+ def get_rank():
203
+ if not is_dist_avail_and_initialized():
204
+ return 0
205
+ return dist.get_rank()
206
+
207
+
208
+ def is_main_process():
209
+ return get_rank() == 0
210
+
211
+
212
+ def save_on_master(*args, **kwargs):
213
+ if is_main_process():
214
+ torch.save(*args, **kwargs)
215
+
216
+
217
+ def init_distributed_mode(args):
218
+ if args.dist_on_itp:
219
+ args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
220
+ args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
221
+ args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
222
+ args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
223
+ os.environ['LOCAL_RANK'] = str(args.gpu)
224
+ os.environ['RANK'] = str(args.rank)
225
+ os.environ['WORLD_SIZE'] = str(args.world_size)
226
+ # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
227
+ elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
228
+ args.rank = int(os.environ["RANK"])
229
+ args.world_size = int(os.environ['WORLD_SIZE'])
230
+ args.gpu = int(os.environ['LOCAL_RANK'])
231
+ elif 'SLURM_PROCID' in os.environ:
232
+ args.rank = int(os.environ['SLURM_PROCID'])
233
+ args.gpu = args.rank % torch.cuda.device_count()
234
+ else:
235
+ print('Not using distributed mode')
236
+ setup_for_distributed(is_master=True) # hack
237
+ args.distributed = False
238
+ return
239
+
240
+ args.distributed = True
241
+
242
+ torch.cuda.set_device(args.gpu)
243
+ args.dist_backend = 'nccl'
244
+ print('| distributed init (rank {}): {}, gpu {}'.format(
245
+ args.rank, args.dist_url, args.gpu), flush=True)
246
+ torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
247
+ world_size=args.world_size, rank=args.rank)
248
+ torch.distributed.barrier()
249
+ setup_for_distributed(args.rank == 0)
250
+
251
+
252
+ class NativeScalerWithGradNormCount:
253
+ state_dict_key = "amp_scaler"
254
+
255
+ def __init__(self):
256
+ self._scaler = torch.cuda.amp.GradScaler()
257
+
258
+ def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True, verbose=False):
259
+ self._scaler.scale(loss).backward(create_graph=create_graph)
260
+ if update_grad:
261
+ if clip_grad is not None:
262
+ assert parameters is not None
263
+ self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place
264
+ norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
265
+ else:
266
+ self._scaler.unscale_(optimizer)
267
+ norm = get_grad_norm_(parameters)
268
+ self._scaler.step(optimizer)
269
+ self._scaler.update()
270
+ else:
271
+ norm = None
272
+ if verbose:
273
+ print('norm:', norm, 'clip:', clip_grad)
274
+ return norm
275
+
276
+ def state_dict(self):
277
+ return self._scaler.state_dict()
278
+
279
+ def load_state_dict(self, state_dict):
280
+ self._scaler.load_state_dict(state_dict)
281
+
282
+
283
+ def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
284
+ if isinstance(parameters, torch.Tensor):
285
+ parameters = [parameters]
286
+ parameters = [p for p in parameters if p.grad is not None]
287
+ norm_type = float(norm_type)
288
+ if len(parameters) == 0:
289
+ return torch.tensor(0.)
290
+ device = parameters[0].grad.device
291
+ if norm_type == inf:
292
+ total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
293
+ else:
294
+ total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
295
+ return total_norm
296
+
297
+
298
+ def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler):
299
+ output_dir = Path(args.output_dir)
300
+ epoch_name = f'{epoch:05d}'
301
+ if loss_scaler is not None:
302
+ checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)]
303
+ for checkpoint_path in checkpoint_paths:
304
+ to_save = {
305
+ 'model': model_without_ddp.state_dict(),
306
+ 'optimizer': optimizer.state_dict(),
307
+ 'epoch': epoch,
308
+ 'scaler': loss_scaler.state_dict(),
309
+ 'args': args,
310
+ }
311
+
312
+ save_on_master(to_save, checkpoint_path)
313
+ else:
314
+ client_state = {'epoch': epoch}
315
+ model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % epoch_name, client_state=client_state)
316
+
317
+
318
+ def load_model(args, model_without_ddp, optimizer, loss_scaler):
319
+ if args.resume:
320
+ if args.resume.startswith('https'):
321
+ checkpoint = torch.hub.load_state_dict_from_url(
322
+ args.resume, map_location='cpu', check_hash=True)
323
+ else:
324
+ checkpoint = torch.load(args.resume, map_location='cpu')
325
+ print("Resume checkpoint %s" % args.resume)
326
+ print(model_without_ddp.load_state_dict(checkpoint['model'], strict=False))
327
+ if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval):
328
+ optimizer.load_state_dict(checkpoint['optimizer'])
329
+ args.start_epoch = checkpoint['epoch'] + 1
330
+ if 'scaler' in checkpoint:
331
+ print(loss_scaler.load_state_dict(checkpoint['scaler']))
332
+ print("With optim & sched!")
333
+ print("start epoch:", args.start_epoch)
334
+
335
+
336
+ def all_reduce_mean(x):
337
+ world_size = get_world_size()
338
+ if world_size > 1:
339
+ x_reduce = torch.tensor(x).cuda()
340
+ dist.all_reduce(x_reduce)
341
+ x_reduce /= world_size
342
+ return x_reduce.item()
343
+ else:
344
+ return x
345
+
346
+
347
+ import torch.distributed as dist
348
+
349
+ def get_world_size():
350
+ """
351
+ Get the size of the world.
352
+ """
353
+ if not dist.is_available():
354
+ return 1
355
+ if not dist.is_initialized():
356
+ return 1
357
+ return dist.get_world_size()
358
+
359
+
360
+ # def all_gather_unaligned(data):
361
+ # """
362
+ # Run all_gather on arbitrary picklable data (not necessarily tensors).
363
+ # Args:
364
+ # data: any picklable object
365
+ # group: a torch process group. By default, will use a group which
366
+ # contains all ranks on gloo backend.
367
+ # Returns:
368
+ # list[data]: list of data gathered from each rank
369
+ # """
370
+ # print('world', get_world_size())
371
+ # if get_world_size() == 1:
372
+ # return [data]
373
+
374
+ # # receiving Tensor from all ranks
375
+ # tensor_list = [
376
+ # torch.zeros_like(data) for _ in range(get_world_size())
377
+ # ]
378
+ # dist.all_gather(tensor_list, data)
379
+ # for tl in tensor_list:
380
+ # print(tl)
381
+ # print(tl.shape)
382
+ # return tensor_list
383
+
384
+ import pickle
385
+ def _serialize_to_tensor(data, group):
386
+ """
387
+ Seriialize the tensor to ByteTensor. Note that only `gloo` and `nccl`
388
+ backend is supported.
389
+ Args:
390
+ data (data): data to be serialized.
391
+ group (group): pytorch dist group.
392
+ Returns:
393
+ tensor (ByteTensor): tensor that serialized.
394
+ """
395
+
396
+ backend = dist.get_backend(group)
397
+ assert backend in ["gloo", "nccl"]
398
+ device = torch.device("cpu" if backend == "gloo" else "cuda")
399
+
400
+ buffer = pickle.dumps(data)
401
+ if len(buffer) > 1024 ** 3:
402
+ print(
403
+ "Rank {} trying to all-gather {:.2f} GB of data on device {}".format(
404
+ get_rank(), len(buffer) / (1024 ** 3), device
405
+ )
406
+ )
407
+ storage = torch.ByteStorage.from_buffer(buffer)
408
+ tensor = torch.ByteTensor(storage).to(device=device)
409
+ return tensor
410
+
411
+ import functools
412
+ @functools.lru_cache()
413
+ def _get_global_gloo_group():
414
+ """
415
+ Return a process group based on gloo backend, containing all the ranks
416
+ The result is cached.
417
+ Returns:
418
+ (group): pytorch dist group.
419
+ """
420
+ if dist.get_backend() == "nccl":
421
+ return dist.new_group(backend="gloo")
422
+ else:
423
+ return dist.group.WORLD
424
+
425
+
426
+ def _pad_to_largest_tensor(tensor, group):
427
+ """
428
+ Padding all the tensors from different GPUs to the largest ones.
429
+ Args:
430
+ tensor (tensor): tensor to pad.
431
+ group (group): pytorch dist group.
432
+ Returns:
433
+ list[int]: size of the tensor, on each rank
434
+ Tensor: padded tensor that has the max size
435
+ """
436
+ world_size = dist.get_world_size(group=group)
437
+ assert (
438
+ world_size >= 1
439
+ ), "comm.gather/all_gather must be called from ranks within the given group!"
440
+ local_size = torch.tensor(
441
+ [tensor.numel()], dtype=torch.int64, device=tensor.device
442
+ )
443
+ size_list = [
444
+ torch.zeros([1], dtype=torch.int64, device=tensor.device)
445
+ for _ in range(world_size)
446
+ ]
447
+ dist.all_gather(size_list, local_size, group=group)
448
+ size_list = [int(size.item()) for size in size_list]
449
+
450
+ max_size = max(size_list)
451
+
452
+ # we pad the tensor because torch all_gather does not support
453
+ # gathering tensors of different shapes
454
+ if local_size != max_size:
455
+ padding = torch.zeros(
456
+ (max_size - local_size,), dtype=torch.uint8, device=tensor.device
457
+ )
458
+ tensor = torch.cat((tensor, padding), dim=0)
459
+ return size_list, tensor
460
+
461
+ def all_gather_unaligned(data, group=None):
462
+ """
463
+ Run all_gather on arbitrary picklable data (not necessarily tensors).
464
+ Args:
465
+ data: any picklable object
466
+ group: a torch process group. By default, will use a group which
467
+ contains all ranks on gloo backend.
468
+ Returns:
469
+ list[data]: list of data gathered from each rank
470
+ """
471
+ if get_world_size() == 1:
472
+ return [data]
473
+ if group is None:
474
+ group = _get_global_gloo_group()
475
+ if dist.get_world_size(group) == 1:
476
+ return [data]
477
+
478
+ tensor = _serialize_to_tensor(data, group)
479
+
480
+ size_list, tensor = _pad_to_largest_tensor(tensor, group)
481
+ max_size = max(size_list)
482
+
483
+ # receiving Tensor from all ranks
484
+ tensor_list = [
485
+ torch.empty((max_size,), dtype=torch.uint8, device=tensor.device)
486
+ for _ in size_list
487
+ ]
488
+ dist.all_gather(tensor_list, tensor, group=group)
489
+
490
+ data_list = []
491
+ for size, tensor in zip(size_list, tensor_list):
492
+ buffer = tensor.cpu().numpy().tobytes()[:size]
493
+ data_list.append(pickle.loads(buffer).to(data.device))
494
+
495
+ return data_list
496
+
util/pos_embed.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # Position embedding utils
8
+ # --------------------------------------------------------
9
+
10
+ import numpy as np
11
+
12
+ import torch
13
+
14
+ # --------------------------------------------------------
15
+ # 2D sine-cosine position embedding
16
+ # References:
17
+ # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
18
+ # MoCo v3: https://github.com/facebookresearch/moco-v3
19
+ # --------------------------------------------------------
20
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
21
+ """
22
+ grid_size: int of the grid height and width
23
+ return:
24
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
25
+ """
26
+ grid_h = np.arange(grid_size, dtype=np.float32)
27
+ grid_w = np.arange(grid_size, dtype=np.float32)
28
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
29
+ grid = np.stack(grid, axis=0)
30
+
31
+ grid = grid.reshape([2, 1, grid_size, grid_size])
32
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
33
+ if cls_token:
34
+ pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
35
+ return pos_embed
36
+
37
+
38
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
39
+ assert embed_dim % 2 == 0
40
+
41
+ # use half of dimensions to encode grid_h
42
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
43
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
44
+
45
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
46
+ return emb
47
+
48
+
49
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
50
+ """
51
+ embed_dim: output dimension for each position
52
+ pos: a list of positions to be encoded: size (M,)
53
+ out: (M, D)
54
+ """
55
+ assert embed_dim % 2 == 0
56
+ omega = np.arange(embed_dim // 2, dtype=np.float)
57
+ omega /= embed_dim / 2.
58
+ omega = 1. / 10000**omega # (D/2,)
59
+
60
+ pos = pos.reshape(-1) # (M,)
61
+ out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
62
+
63
+ emb_sin = np.sin(out) # (M, D/2)
64
+ emb_cos = np.cos(out) # (M, D/2)
65
+
66
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
67
+ return emb
68
+
69
+ def get_1d_sincos_pos_embed_from_grid_torch(embed_dim, pos):
70
+ """
71
+ embed_dim: output dimension for each position
72
+ pos: a list of positions to be encoded: size (M,)
73
+ out: (M, D)
74
+ """
75
+ assert embed_dim % 2 == 0
76
+ omega = torch.arange(embed_dim // 2, device=pos.device).float()
77
+ omega /= embed_dim / 2.
78
+ omega = 1. / 10000**omega # (D/2,)
79
+
80
+ pos = pos.reshape(-1) # (M,)
81
+ out = torch.einsum('m,d->md', pos, omega) # (M, D/2), outer product
82
+
83
+ emb_sin = torch.sin(out) # (M, D/2)
84
+ emb_cos = torch.cos(out) # (M, D/2)
85
+
86
+ emb = torch.cat([emb_sin, emb_cos], axis=1) # (M, D)
87
+ return emb
88
+
89
+
90
+
91
+ # --------------------------------------------------------
92
+ # Interpolate position embeddings for high-resolution
93
+ # References:
94
+ # DeiT: https://github.com/facebookresearch/deit
95
+ # --------------------------------------------------------
96
+ def interpolate_pos_embed(model, checkpoint_model):
97
+ if 'pos_embed' in checkpoint_model:
98
+ pos_embed_checkpoint = checkpoint_model['pos_embed']
99
+ embedding_size = pos_embed_checkpoint.shape[-1]
100
+ num_patches = model.patch_embed.num_patches
101
+ num_extra_tokens = model.pos_embed.shape[-2] - num_patches
102
+ # height (== width) for the checkpoint position embedding
103
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
104
+ # height (== width) for the new position embedding
105
+ new_size = int(num_patches ** 0.5)
106
+ # class_token and dist_token are kept unchanged
107
+ if orig_size != new_size:
108
+ print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
109
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
110
+ # only the position tokens are interpolated
111
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
112
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
113
+ pos_tokens = torch.nn.functional.interpolate(
114
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
115
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
116
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
117
+ checkpoint_model['pos_embed'] = new_pos_embed
weights/co3dv2_all_categories.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ca861bee4c2cb27acc6855da34227ce7026cf9eb275171da3c5a33976b3d86bd
3
+ size 2423688373