hugoycj committed
Commit · cacb27a · 1 Parent(s): c0f6cb5
Initial commit
Browse files
- .gitignore +4 -0
- app.py +192 -0
- engine_mcc.py +587 -0
- main_mcc.py +322 -0
- mcc_model.py +386 -0
- packages.txt +3 -0
- pre-requirements.txt +2 -0
- requirements.txt +11 -0
- util/co3d_dataset.py +122 -0
- util/co3d_utils.py +133 -0
- util/crop.py +42 -0
- util/hypersim_dataset.py +271 -0
- util/hypersim_utils.py +24 -0
- util/lars.py +47 -0
- util/lr_decay.py +76 -0
- util/lr_sched.py +21 -0
- util/misc.py +496 -0
- util/pos_embed.py +117 -0
- weights/co3dv2_all_categories.pth +3 -0
.gitignore ADDED
@@ -0,0 +1,4 @@
+*venv
+flagged
+*examples
+__pycache__
app.py ADDED
@@ -0,0 +1,192 @@
+import gradio as gr
+
+
+import numpy as np
+import cv2
+from tqdm import tqdm
+
+import torch
+from pytorch3d.io.obj_io import load_obj
+import tempfile
+import main_mcc
+import mcc_model
+import util.misc as misc
+from engine_mcc import prepare_data
+from plyfile import PlyData, PlyElement
+
+def run_inference(model, samples, device, temperature, args):
+    model.eval()
+
+    seen_xyz, valid_seen_xyz, unseen_xyz, unseen_rgb, labels, seen_images = prepare_data(
+        samples, device, is_train=False, args=args, is_viz=True
+    )
+    pred_occupy = []
+    pred_colors = []
+
+    max_n_unseen_fwd = 2000
+
+    model.cached_enc_feat = None
+    num_passes = int(np.ceil(unseen_xyz.shape[1] / max_n_unseen_fwd))
+    for p_idx in range(num_passes):
+        p_start = p_idx * max_n_unseen_fwd
+        p_end = (p_idx + 1) * max_n_unseen_fwd
+        cur_unseen_xyz = unseen_xyz[:, p_start:p_end]
+        cur_unseen_rgb = unseen_rgb[:, p_start:p_end].zero_()
+        cur_labels = labels[:, p_start:p_end].zero_()
+
+        with torch.no_grad():
+            _, pred = model(
+                seen_images=seen_images,
+                seen_xyz=seen_xyz,
+                unseen_xyz=cur_unseen_xyz,
+                unseen_rgb=cur_unseen_rgb,
+                unseen_occupy=cur_labels,
+                cache_enc=True,
+                valid_seen_xyz=valid_seen_xyz,
+            )
+        if device == "cuda":
+            pred_occupy.append(pred[..., 0].cuda())
+        else:
+            pred_occupy.append(pred[..., 0].cpu())
+        if args.regress_color:
+            pred_colors.append(pred[..., 1:].reshape((-1, 3)))
+        else:
+            pred_colors.append(
+                (
+                    torch.nn.Softmax(dim=2)(
+                        pred[..., 1:].reshape((-1, 3, 256)) / temperature
+                    ) * torch.linspace(0, 1, 256, device=pred.device)
+                ).sum(axis=2)
+            )
+
+    pred_occupy = torch.cat(pred_occupy, dim=1)
+    pred_occupy = torch.nn.Sigmoid()(pred_occupy)
+    return torch.cat(pred_colors, dim=0).cpu().numpy(), pred_occupy.cpu().numpy(), unseen_xyz.cpu().numpy()
+
+def pad_image(im, value):
+    if im.shape[0] > im.shape[1]:
+        diff = im.shape[0] - im.shape[1]
+        return torch.cat([im, (torch.zeros((im.shape[0], diff, im.shape[2])) + value)], dim=1)
+    else:
+        diff = im.shape[1] - im.shape[0]
+        return torch.cat([im, (torch.zeros((diff, im.shape[1], im.shape[2])) + value)], dim=0)
+
+
+def normalize(seen_xyz):
+    seen_xyz = seen_xyz / (seen_xyz[torch.isfinite(seen_xyz.sum(dim=-1))].var(dim=0) ** 0.5).mean()
+    seen_xyz = seen_xyz - seen_xyz[torch.isfinite(seen_xyz.sum(dim=-1))].mean(axis=0)
+    return seen_xyz
+
+def infer(
+    image,
+    point_cloud,
+    seg,
+    granularity,
+    temperature,
+):
+
+    score_thresholds = [0.1, 0.2, 0.3, 0.4, 0.5]
+
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+
+    parser = main_mcc.get_args_parser()
+    parser.set_defaults(eval=True)
+
+    args = parser.parse_args()
+
+    model = mcc_model.get_mcc_model(
+        occupancy_weight=1.0,
+        rgb_weight=0.01,
+        args=args,
+    )
+
+    if device == "cuda":
+        model = model.cuda()
+
+    misc.load_model(args=args, model_without_ddp=model, optimizer=None, loss_scaler=None)
+
+    rgb = image
+    obj = load_obj(point_cloud.name)
+
+    seen_rgb = (torch.tensor(rgb).float() / 255)[..., [2, 1, 0]]
+    H, W = seen_rgb.shape[:2]
+    seen_rgb = torch.nn.functional.interpolate(
+        seen_rgb.permute(2, 0, 1)[None],
+        size=[H, W],
+        mode="bilinear",
+        align_corners=False,
+    )[0].permute(1, 2, 0)
+
+    seen_xyz = obj[0].reshape(H, W, 3)
+    seg = cv2.imread(seg.name, cv2.IMREAD_UNCHANGED)
+    mask = torch.tensor(cv2.resize(seg, (W, H))).bool()
+    seen_xyz[~mask] = float('inf')
+
+    seen_xyz = normalize(seen_xyz)
+
+    bottom, right = mask.nonzero().max(dim=0)[0]
+    top, left = mask.nonzero().min(dim=0)[0]
+
+    bottom = bottom + 40
+    right = right + 40
+    top = max(top - 40, 0)
+    left = max(left - 40, 0)
+
+    seen_xyz = seen_xyz[top:bottom+1, left:right+1]
+    seen_rgb = seen_rgb[top:bottom+1, left:right+1]
+
+    seen_xyz = pad_image(seen_xyz, float('inf'))
+    seen_rgb = pad_image(seen_rgb, 0)
+
+    seen_rgb = torch.nn.functional.interpolate(
+        seen_rgb.permute(2, 0, 1)[None],
+        size=[800, 800],
+        mode="bilinear",
+        align_corners=False,
+    )
+
+    seen_xyz = torch.nn.functional.interpolate(
+        seen_xyz.permute(2, 0, 1)[None],
+        size=[112, 112],
+        mode="bilinear",
+        align_corners=False,
+    ).permute(0, 2, 3, 1)
+
+    samples = [
+        [seen_xyz, seen_rgb],
+        [torch.zeros((20000, 3)), torch.zeros((20000, 3))],
+    ]
+
+    pred_colors, pred_occupy, unseen_xyz = run_inference(model, samples, device, temperature, args)
+    _masks = pred_occupy > 0.1
+    unseen_xyz = unseen_xyz[_masks]
+    pred_colors = pred_colors[None, ...][_masks] * 255
+
+    # Prepare data for PlyElement
+    vertex = np.core.records.fromarrays(np.hstack((unseen_xyz, pred_colors)).transpose(),
+                                        names='x, y, z, red, green, blue',
+                                        formats='f8, f8, f8, u1, u1, u1')
+
+
+    # Create PlyElement
+    element = PlyElement.describe(vertex, 'vertex')
+
+    # Save point cloud data to a temporary file
+    with tempfile.NamedTemporaryFile(suffix=".ply", delete=False) as f:
+        PlyData([element], text=True).write(f)
+        temp_file_name = f.name
+
+    return temp_file_name
+
+
+demo = gr.Interface(fn=infer,
+                    inputs=[gr.Image(label="Input Image"),
+                            gr.File(label="Pointcloud File"),
+                            gr.File(label="Segmentation File"),
+                            gr.Slider(minimum=0.05, maximum=0.5, step=0.05, value=0.2, label="Granularity"),
+                            gr.Slider(minimum=0, maximum=1.0, step=0.1, value=0.1, label="Temperature")
+                            ],
+                    outputs=[gr.outputs.File(label="Point Cloud File")],
+                    examples=[["demo/quest2.jpg", "demo/quest2.obj", "demo/quest2_seg.png", 0.2, 0.1]],
+                    cache_examples=True)
+demo.launch(server_name="0.0.0.0", server_port=7860)
engine_mcc.py ADDED
@@ -0,0 +1,587 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# References:
+# DeiT: https://github.com/facebookresearch/deit
+# BEiT: https://github.com/microsoft/unilm/tree/master/beit
+# MAE: https://github.com/facebookresearch/mae
+# --------------------------------------------------------
+import math
+from typing import Iterable
+import os
+import matplotlib.pyplot as plt
+import random
+import torch
+import numpy as np
+import time
+import base64
+from io import BytesIO
+
+import util.misc as misc
+import util.lr_sched as lr_sched
+
+from pytorch3d.structures import Pointclouds
+from pytorch3d.vis.plotly_vis import plot_scene
+from pytorch3d.transforms import RotateAxisAngle
+from pytorch3d.io import IO
+
+
+def evaluate_points(predicted_xyz, gt_xyz, dist_thres):
+    if predicted_xyz.shape[0] == 0:
+        return 0.0, 0.0, 0.0
+    slice_size = 1000
+    precision = 0.0
+    for i in range(int(np.ceil(predicted_xyz.shape[0] / slice_size))):
+        start = slice_size * i
+        end = slice_size * (i + 1)
+        dist = ((predicted_xyz[start:end, None] - gt_xyz[None]) ** 2.0).sum(axis=-1) ** 0.5
+        precision += ((dist < dist_thres).sum(axis=1) > 0).sum()
+    precision /= predicted_xyz.shape[0]
+
+    recall = 0.0
+    for i in range(int(np.ceil(predicted_xyz.shape[0] / slice_size))):
+        start = slice_size * i
+        end = slice_size * (i + 1)
+        dist = ((predicted_xyz[:, None] - gt_xyz[None, start:end]) ** 2.0).sum(axis=-1) ** 0.5
+        recall += ((dist < dist_thres).sum(axis=0) > 0).sum()
+    recall /= gt_xyz.shape[0]
+    return precision, recall, get_f1(precision, recall)
+
+def aug_xyz(seen_xyz, unseen_xyz, args, is_train):
+    degree_x = 0
+    degree_y = 0
+    degree_z = 0
+    if is_train:
+        r_delta = args.random_scale_delta
+        scale = torch.tensor([
+            random.uniform(1.0 - r_delta, 1.0 + r_delta),
+            random.uniform(1.0 - r_delta, 1.0 + r_delta),
+            random.uniform(1.0 - r_delta, 1.0 + r_delta),
+        ], device=seen_xyz.device)
+
+        if args.use_hypersim:
+            shift = 0
+        else:
+            degree_x = random.randrange(-args.random_rotate_degree, args.random_rotate_degree + 1)
+            degree_y = random.randrange(-args.random_rotate_degree, args.random_rotate_degree + 1)
+            degree_z = random.randrange(-args.random_rotate_degree, args.random_rotate_degree + 1)
+
+            r_shift = args.random_shift
+            shift = torch.tensor([[[
+                random.uniform(-r_shift, r_shift),
+                random.uniform(-r_shift, r_shift),
+                random.uniform(-r_shift, r_shift),
+            ]]], device=seen_xyz.device)
+        seen_xyz = seen_xyz * scale + shift
+        unseen_xyz = unseen_xyz * scale + shift
+
+    B, H, W, _ = seen_xyz.shape
+    return [
+        rotate(seen_xyz.reshape((B, -1, 3)), degree_x, degree_y, degree_z).reshape((B, H, W, 3)),
+        rotate(unseen_xyz, degree_x, degree_y, degree_z),
+    ]
+
+
+def rotate(sample, degree_x, degree_y, degree_z):
+    for degree, axis in [(degree_x, "X"), (degree_y, "Y"), (degree_z, "Z")]:
+        if degree != 0:
+            sample = RotateAxisAngle(degree, axis=axis).to(sample.device).transform_points(sample)
+    return sample
+
+
+def get_grid(B, device, co3d_world_size, granularity):
+    N = int(np.ceil(2 * co3d_world_size / granularity))
+    grid_unseen_xyz = torch.zeros((N, N, N, 3), device=device)
+    for i in range(N):
+        grid_unseen_xyz[i, :, :, 0] = i
+    for j in range(N):
+        grid_unseen_xyz[:, j, :, 1] = j
+    for k in range(N):
+        grid_unseen_xyz[:, :, k, 2] = k
+    grid_unseen_xyz -= (N / 2.0)
+    grid_unseen_xyz /= (N / 2.0) / co3d_world_size
+    grid_unseen_xyz = grid_unseen_xyz.reshape((1, -1, 3)).repeat(B, 1, 1)
+    return grid_unseen_xyz
+
+
+def run_viz(model, data_loader, device, args, epoch):
+    epoch_start_time = time.time()
+    model.eval()
+    os.system(f'mkdir {args.job_dir}/viz')
+
+    print('Visualization data_loader length:', len(data_loader))
+    dataset = data_loader.dataset
+    for sample_idx, samples in enumerate(data_loader):
+        if sample_idx >= args.max_n_viz_obj:
+            break
+        seen_xyz, valid_seen_xyz, unseen_xyz, unseen_rgb, labels, seen_images = prepare_data(samples, device, is_train=False, args=args, is_viz=True)
+
+        pred_occupy = []
+        pred_colors = []
+        (model.module if hasattr(model, "module") else model).clear_cache()
+
+        # don't forward all at once to avoid oom
+        max_n_queries_fwd = 2000
+
+        total_n_passes = int(np.ceil(unseen_xyz.shape[1] / max_n_queries_fwd))
+        for p_idx in range(total_n_passes):
+            p_start = p_idx * max_n_queries_fwd
+            p_end = (p_idx + 1) * max_n_queries_fwd
+            cur_unseen_xyz = unseen_xyz[:, p_start:p_end]
+            cur_unseen_rgb = unseen_rgb[:, p_start:p_end].zero_()
+            cur_labels = labels[:, p_start:p_end].zero_()
+
+            with torch.no_grad():
+                _, pred = model(
+                    seen_images=seen_images,
+                    seen_xyz=seen_xyz,
+                    unseen_xyz=cur_unseen_xyz,
+                    unseen_rgb=cur_unseen_rgb,
+                    unseen_occupy=cur_labels,
+                    cache_enc=args.run_viz,
+                    valid_seen_xyz=valid_seen_xyz,
+                )
+
+            cur_occupy_out = pred[..., 0]
+
+            if args.regress_color:
+                cur_color_out = pred[..., 1:].reshape((-1, 3))
+            else:
+                cur_color_out = pred[..., 1:].reshape((-1, 3, 256)).max(dim=2)[1] / 255.0
+            pred_occupy.append(cur_occupy_out)
+            pred_colors.append(cur_color_out)
+
+        rank = misc.get_rank()
+        prefix = f'{args.job_dir}/viz/' + dataset.dataset_split + f'_ep{epoch}_rank{rank}_i{sample_idx}'
+
+        img = (seen_images[0].permute(1, 2, 0) * 255).cpu().numpy().copy().astype(np.uint8)
+
+        gt_xyz = samples[1][0].to(device).reshape(-1, 3)
+        gt_rgb = samples[1][1].to(device).reshape(-1, 3)
+        mesh_xyz = samples[2].to(device).reshape(-1, 3) if args.use_hypersim else None
+
+        with open(prefix + '.html', 'a') as f:
+            generate_html(
+                img,
+                seen_xyz, seen_images,
+                torch.cat(pred_occupy, dim=1),
+                torch.cat(pred_colors, dim=0),
+                unseen_xyz,
+                f,
+                gt_xyz=gt_xyz,
+                gt_rgb=gt_rgb,
+                mesh_xyz=mesh_xyz,
+            )
+    print("Visualization epoch time:", time.time() - epoch_start_time)
+
+
+def get_f1(precision, recall):
+    if (precision + recall) == 0:
+        return 0.0
+    return 2.0 * precision * recall / (precision + recall)
+
+
+def generate_plot(img, seen_xyz, seen_rgb, pred_occ, pred_rgb, unseen_xyz,
+                  gt_xyz=None, gt_rgb=None, mesh_xyz=None, score_thresholds=[0.1, 0.3, 0.5, 0.7, 0.9],
+                  pointcloud_marker_size=2,
+                  ):
+    # if img is not None:
+    #     fig = plt.figure()
+    #     plt.imshow(img)
+    #     tmpfile = BytesIO()
+    #     fig.savefig(tmpfile, format='jpg')
+    #     encoded = base64.b64encode(tmpfile.getvalue()).decode('utf-8')
+
+    #     html = '<img src=\'data:image/png;base64,{}\'>'.format(encoded)
+    #     f.write(html)
+    #     plt.close()
+
+    clouds = {"MCC Output": {}}
+    # Seen
+    if seen_xyz is not None:
+        seen_xyz = seen_xyz.reshape((-1, 3)).cpu()
+        seen_rgb = torch.nn.functional.interpolate(seen_rgb, (112, 112)).permute(0, 2, 3, 1).reshape((-1, 3)).cpu()
+        good_seen = seen_xyz[:, 0] != -100
+
+        seen_pc = Pointclouds(
+            points=seen_xyz[good_seen][None],
+            features=seen_rgb[good_seen][None],
+        )
+        clouds["MCC Output"]["seen"] = seen_pc
+
+    # GT points
+    if gt_xyz is not None:
+        subset_gt = random.sample(range(gt_xyz.shape[0]), 10000)
+        gt_pc = Pointclouds(
+            points=gt_xyz[subset_gt][None],
+            features=gt_rgb[subset_gt][None],
+        )
+        clouds["MCC Output"]["GT points"] = gt_pc
+
+    # GT meshes
+    if mesh_xyz is not None:
+        subset_mesh = random.sample(range(mesh_xyz.shape[0]), 10000)
+        mesh_pc = Pointclouds(
+            points=mesh_xyz[subset_mesh][None],
+        )
+        clouds["MCC Output"]["GT mesh"] = mesh_pc
+
+    pred_occ = torch.nn.Sigmoid()(pred_occ).cpu()
+    for t in score_thresholds:
+        pos = pred_occ > t
+
+        points = unseen_xyz[pos].reshape((-1, 3))
+        features = pred_rgb[None][pos].reshape((-1, 3))
+        good_points = points[:, 0] != -100
+
+        if good_points.sum() == 0:
+            continue
+
+        pc = Pointclouds(
+            points=points[good_points][None].cpu(),
+            features=features[good_points][None].cpu(),
+        )
+
+        clouds["MCC Output"][f"pred_{t}"] = pc
+        IO().save_pointcloud(pc, "output_pointcloud.ply")
+
+    plt.figure()
+    try:
+        fig = plot_scene(clouds, pointcloud_marker_size=pointcloud_marker_size, pointcloud_max_points=20000 * 2)
+        fig.update_layout(height=1000, width=1000)
+        return fig
+    except Exception as e:
+        print('writing failed', e)
+    try:
+        plt.close()
+    except:
+        pass
+
+
+def generate_html(img, seen_xyz, seen_rgb, pred_occ, pred_rgb, unseen_xyz, f,
+                  gt_xyz=None, gt_rgb=None, mesh_xyz=None, score_thresholds=[0.1, 0.3, 0.5, 0.7, 0.9],
+                  pointcloud_marker_size=2,
+                  ):
+    if img is not None:
+        fig = plt.figure()
+        plt.imshow(img)
+        tmpfile = BytesIO()
+        fig.savefig(tmpfile, format='jpg')
+        encoded = base64.b64encode(tmpfile.getvalue()).decode('utf-8')
+
+        html = '<img src=\'data:image/png;base64,{}\'>'.format(encoded)
+        f.write(html)
+        plt.close()
+
+    clouds = {"MCC Output": {}}
+    # Seen
+    if seen_xyz is not None:
+        seen_xyz = seen_xyz.reshape((-1, 3)).cpu()
+        seen_rgb = torch.nn.functional.interpolate(seen_rgb, (112, 112)).permute(0, 2, 3, 1).reshape((-1, 3)).cpu()
+        good_seen = seen_xyz[:, 0] != -100
+
+        seen_pc = Pointclouds(
+            points=seen_xyz[good_seen][None],
+            features=seen_rgb[good_seen][None],
+        )
+        clouds["MCC Output"]["seen"] = seen_pc
+
+    # GT points
+    if gt_xyz is not None:
+        subset_gt = random.sample(range(gt_xyz.shape[0]), 10000)
+        gt_pc = Pointclouds(
+            points=gt_xyz[subset_gt][None],
+            features=gt_rgb[subset_gt][None],
+        )
+        clouds["MCC Output"]["GT points"] = gt_pc
+
+    # GT meshes
+    if mesh_xyz is not None:
+        subset_mesh = random.sample(range(mesh_xyz.shape[0]), 10000)
+        mesh_pc = Pointclouds(
+            points=mesh_xyz[subset_mesh][None],
+        )
+        clouds["MCC Output"]["GT mesh"] = mesh_pc
+
+    pred_occ = torch.nn.Sigmoid()(pred_occ).cpu()
+    for t in score_thresholds:
+        pos = pred_occ > t
+
+        points = unseen_xyz[pos].reshape((-1, 3))
+        features = pred_rgb[None][pos].reshape((-1, 3))
+        good_points = points[:, 0] != -100
+
+        if good_points.sum() == 0:
+            continue
+
+        pc = Pointclouds(
+            points=points[good_points][None].cpu(),
+            features=features[good_points][None].cpu(),
+        )
+
+        clouds["MCC Output"][f"pred_{t}"] = pc
+
+    plt.figure()
+    try:
+        fig = plot_scene(clouds, pointcloud_marker_size=pointcloud_marker_size, pointcloud_max_points=20000 * 2)
+        fig.update_layout(height=1000, width=1000)
+        html_string = fig.to_html(full_html=False, include_plotlyjs="cdn")
+        f.write(html_string)
+        return fig, plt
+    except Exception as e:
+        print('writing failed', e)
+    try:
+        plt.close()
+    except:
+        pass
+
+
+def train_one_epoch(model: torch.nn.Module,
+                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
+                    device: torch.device, epoch: int, loss_scaler,
+                    args=None):
+    epoch_start_time = time.time()
+    model.train(True)
+    metric_logger = misc.MetricLogger(delimiter="  ")
+    metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
+
+    accum_iter = args.accum_iter
+
+    optimizer.zero_grad()
+
+    print('Training data_loader length:', len(data_loader))
+    for data_iter_step, samples in enumerate(data_loader):
+        # we use a per iteration (instead of per epoch) lr scheduler
+        if data_iter_step % accum_iter == 0:
+            lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)
+        seen_xyz, valid_seen_xyz, unseen_xyz, unseen_rgb, labels, seen_images = prepare_data(samples, device, is_train=True, args=args)
+
+        with torch.cuda.amp.autocast():
+            loss, _ = model(
+                seen_images=seen_images,
+                seen_xyz=seen_xyz,
+                unseen_xyz=unseen_xyz,
+                unseen_rgb=unseen_rgb,
+                unseen_occupy=labels,
+                valid_seen_xyz=valid_seen_xyz,
+            )
+
+        loss_value = loss.item()
+        if not math.isfinite(loss_value):
+            print("Warning: Loss is {}".format(loss_value))
+            loss *= 0.0
+            loss_value = 100.0
+
+        loss /= accum_iter
+        loss_scaler(loss, optimizer, parameters=model.parameters(),
+                    clip_grad=args.clip_grad,
+                    update_grad=(data_iter_step + 1) % accum_iter == 0,
+                    verbose=(data_iter_step % 100) == 0)
+
+        if (data_iter_step + 1) % accum_iter == 0:
+            optimizer.zero_grad()
+
+        torch.cuda.synchronize()
+
+        metric_logger.update(loss=loss_value)
+
+        lr = optimizer.param_groups[0]["lr"]
+        metric_logger.update(lr=lr)
+
+        if data_iter_step == 30:
+            os.system('nvidia-smi')
+            os.system('free -g')
+        if args.debug and data_iter_step == 5:
+            break
+
+    # gather the stats from all processes
+    metric_logger.synchronize_between_processes()
+    print("Averaged stats:", metric_logger)
+    print("Training epoch time:", time.time() - epoch_start_time)
+    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
+
+
+def eval_one_epoch(
+    model: torch.nn.Module,
+    data_loader: Iterable,
+    device: torch.device,
+    args=None
+):
+    epoch_start_time = time.time()
+    model.train(False)
+
+    metric_logger = misc.MetricLogger(delimiter="  ")
+
+    print('Eval len(data_loader):', len(data_loader))
+
+    for data_iter_step, samples in enumerate(data_loader):
+        seen_xyz, valid_seen_xyz, unseen_xyz, unseen_rgb, labels, seen_images = prepare_data(samples, device, is_train=False, args=args)
+
+        # don't forward all at once to avoid oom
+        max_n_queries_fwd = 5000
+        all_loss, all_preds = [], []
+        for p_idx in range(int(np.ceil(unseen_xyz.shape[1] / max_n_queries_fwd))):
+            p_start = p_idx * max_n_queries_fwd
+            p_end = (p_idx + 1) * max_n_queries_fwd
+            cur_unseen_xyz = unseen_xyz[:, p_start:p_end]
+            cur_unseen_rgb = unseen_rgb[:, p_start:p_end]
+            cur_labels = labels[:, p_start:p_end]
+
+            with torch.no_grad():
+                loss, pred = model(
+                    seen_images=seen_images,
+                    seen_xyz=seen_xyz,
+                    unseen_xyz=cur_unseen_xyz,
+                    unseen_rgb=cur_unseen_rgb,
+                    unseen_occupy=cur_labels,
+                    valid_seen_xyz=valid_seen_xyz,
+                )
+            all_loss.append(loss)
+            all_preds.append(pred)
+
+        loss = sum(all_loss) / len(all_loss)
+        pred = torch.cat(all_preds, dim=1)
+
+        B = pred.shape[0]
+
+        gt_xyz = samples[1][0].to(device).reshape((B, -1, 3))
+        if args.use_hypersim:
+            mesh_xyz = samples[2].to(device).reshape((B, -1, 3))
+
+        s_thres = args.eval_score_threshold
+        d_thres = args.eval_dist_threshold
+
+        for b_idx in range(B):
+            geometry_metrics = {}
+            predicted_idx = torch.nn.Sigmoid()(pred[b_idx, :, 0]) > s_thres
+            predicted_xyz = unseen_xyz[b_idx, predicted_idx]
+
+            precision, recall, f1 = evaluate_points(predicted_xyz, gt_xyz[b_idx], d_thres)
+            geometry_metrics[f'd{d_thres}_s{s_thres}_point_pr'] = precision
+            geometry_metrics[f'd{d_thres}_s{s_thres}_point_rc'] = recall
+            geometry_metrics[f'd{d_thres}_s{s_thres}_point_f1'] = f1
+
+            if args.use_hypersim:
+                precision, recall, f1 = evaluate_points(predicted_xyz, mesh_xyz[b_idx], d_thres)
+                geometry_metrics[f'd{d_thres}_s{s_thres}_mesh_pr'] = precision
+                geometry_metrics[f'd{d_thres}_s{s_thres}_mesh_rc'] = recall
+                geometry_metrics[f'd{d_thres}_s{s_thres}_mesh_f1'] = f1
+
+            metric_logger.update(**geometry_metrics)
+
+        loss_value = loss.item()
+
+        torch.cuda.synchronize()
+        metric_logger.update(loss=loss_value)
+
+        if args.debug and data_iter_step == 5:
+            break
+
+    metric_logger.synchronize_between_processes()
+    print("Validation averaged stats:", metric_logger)
+    print("Val epoch time:", time.time() - epoch_start_time)
+    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
+
+
+def sample_uniform_semisphere(B, N, semisphere_size, device):
+    for _ in range(100):
+        points = torch.empty(B * N * 3, 3, device=device).uniform_(-semisphere_size, semisphere_size)
+        points[..., 2] = points[..., 2].abs()
+        dist = (points ** 2.0).sum(axis=-1) ** 0.5
+        if (dist < semisphere_size).sum() >= B * N:
+            return points[dist < semisphere_size][:B * N].reshape((B, N, 3))
+        else:
+            print('resampling sphere')
+
+
+def get_grid_semisphere(B, granularity, semisphere_size, device):
+    n_grid_pts = int(semisphere_size / granularity) * 2 + 1
+    grid_unseen_xyz = torch.zeros((n_grid_pts, n_grid_pts, n_grid_pts // 2 + 1, 3), device=device)
+    for i in range(n_grid_pts):
+        grid_unseen_xyz[i, :, :, 0] = i
+        grid_unseen_xyz[:, i, :, 1] = i
+    for i in range(n_grid_pts // 2 + 1):
+        grid_unseen_xyz[:, :, i, 2] = i
+    grid_unseen_xyz[..., :2] -= (n_grid_pts // 2.0)
+    grid_unseen_xyz *= granularity
+    dist = (grid_unseen_xyz ** 2.0).sum(axis=-1) ** 0.5
+    grid_unseen_xyz = grid_unseen_xyz[dist <= semisphere_size]
+    return grid_unseen_xyz[None].repeat(B, 1, 1)
+
+
+def get_min_dist(a, b, slice_size=1000):
+    all_min, all_idx = [], []
+    for i in range(int(np.ceil(a.shape[1] / slice_size))):
+        start = slice_size * i
+        end = slice_size * (i + 1)
+        # B, n_queries, n_gt
+        dist = ((a[:, start:end] - b) ** 2.0).sum(axis=-1) ** 0.5
+        # B, n_queries
+        cur_min, cur_idx = dist.min(axis=2)
+        all_min.append(cur_min)
+        all_idx.append(cur_idx)
+    return torch.cat(all_min, dim=1), torch.cat(all_idx, dim=1)
+
+
+def construct_uniform_semisphere(gt_xyz, gt_rgb, semisphere_size, n_queries, dist_threshold, is_train, granularity):
+    B = gt_xyz.shape[0]
+    device = gt_xyz.device
+    if is_train:
+        unseen_xyz = sample_uniform_semisphere(B, n_queries, semisphere_size, device)
+    else:
+        unseen_xyz = get_grid_semisphere(B, granularity, semisphere_size, device)
+    dist, idx_to_gt = get_min_dist(unseen_xyz[:, :, None], gt_xyz[:, None])
+    labels = dist < dist_threshold
+    unseen_rgb = torch.zeros_like(unseen_xyz)
+    unseen_rgb[labels] = torch.gather(gt_rgb, 1, idx_to_gt.unsqueeze(-1).repeat(1, 1, 3))[labels]
+    return unseen_xyz, unseen_rgb, labels.float()
+
+
+def construct_uniform_grid(gt_xyz, gt_rgb, co3d_world_size, n_queries, dist_threshold, is_train, granularity):
+    B = gt_xyz.shape[0]
+    device = gt_xyz.device
+    if is_train:
+        unseen_xyz = torch.empty((B, n_queries, 3), device=device).uniform_(-co3d_world_size, co3d_world_size)
+    else:
+        unseen_xyz = get_grid(B, device, co3d_world_size, granularity)
+    dist, idx_to_gt = get_min_dist(unseen_xyz[:, :, None], gt_xyz[:, None])
+    labels = dist < dist_threshold
+    unseen_rgb = torch.zeros_like(unseen_xyz)
+    unseen_rgb[labels] = torch.gather(gt_rgb, 1, idx_to_gt.unsqueeze(-1).repeat(1, 1, 3))[labels]
+    return unseen_xyz, unseen_rgb, labels.float()
+
+
+def prepare_data(samples, device, is_train, args, is_viz=False):
+    # Seen
+    seen_xyz, seen_rgb = samples[0][0].to(device), samples[0][1].to(device)
+    valid_seen_xyz = torch.isfinite(seen_xyz.sum(axis=-1))
+    seen_xyz[~valid_seen_xyz] = -100
+    B = seen_xyz.shape[0]
+    # Gt
+    gt_xyz, gt_rgb = samples[1][0].to(device).reshape(B, -1, 3), samples[1][1].to(device).reshape(B, -1, 3)
+
+    sampling_func = construct_uniform_semisphere if args.use_hypersim else construct_uniform_grid
+    unseen_xyz, unseen_rgb, labels = sampling_func(
+        gt_xyz, gt_rgb,
+        args.semisphere_size if args.use_hypersim else args.co3d_world_size,
+        args.n_queries,
+        args.train_dist_threshold,
+        is_train,
+        args.viz_granularity if is_viz else args.eval_granularity,
+    )
+
+    if is_train:
+        seen_xyz, unseen_xyz = aug_xyz(seen_xyz, unseen_xyz, args, is_train=is_train)
+
+        # Random Flip
+        if random.random() < 0.5:
+            seen_xyz[..., 0] *= -1
+            unseen_xyz[..., 0] *= -1
+            seen_xyz = torch.flip(seen_xyz, [2])
+            valid_seen_xyz = torch.flip(valid_seen_xyz, [2])
+            seen_rgb = torch.flip(seen_rgb, [3])
+
+    return seen_xyz, valid_seen_xyz, unseen_xyz, unseen_rgb, labels, seen_rgb
main_mcc.py ADDED
@@ -0,0 +1,322 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# References:
+# DeiT: https://github.com/facebookresearch/deit
+# BEiT: https://github.com/microsoft/unilm/tree/master/beit
+# MAE: https://github.com/facebookresearch/mae
+# --------------------------------------------------------
+import argparse
+import datetime
+import json
+import numpy as np
+import os
+import time
+from pathlib import Path
+
+import torch
+import torch.backends.cudnn as cudnn
+import timm.optim.optim_factory as optim_factory
+
+import util.misc as misc
+import mcc_model
+from util.misc import NativeScalerWithGradNormCount as NativeScaler
+from util.hypersim_dataset import HyperSimDataset, hypersim_collate_fn
+from util.co3d_dataset import CO3DV2Dataset, co3dv2_collate_fn
+from engine_mcc import train_one_epoch, run_viz, eval_one_epoch
+from util.co3d_utils import get_all_dataset_maps
+
+
+def get_args_parser():
+    parser = argparse.ArgumentParser('MCC', add_help=False)
+
+    # Model
+    parser.add_argument('--input_size', default=224, type=int,
+                        help='Images input size')
+    parser.add_argument('--occupancy_weight', default=1.0, type=float,
+                        help='A constant to weight the occupancy loss')
+    parser.add_argument('--rgb_weight', default=0.01, type=float,
+                        help='A constant to weight the color prediction loss')
+    parser.add_argument('--n_queries', default=550, type=int,
+                        help='Number of queries used in decoder.')
+    parser.add_argument('--drop_path', default=0.1, type=float,
+                        help='drop_path probability')
+    parser.add_argument('--regress_color', action='store_true',
+                        help='If true, regress color with MSE. Otherwise, 256-way classification for each channel.')
+
+    # Training
+    parser.add_argument('--batch_size', default=16, type=int,
+                        help='Batch size per GPU for training (effective batch size is batch_size * accum_iter * # gpus)')
+    parser.add_argument('--eval_batch_size', default=2, type=int,
+                        help='Batch size per GPU for evaluation (effective batch size is batch_size * accum_iter * # gpus)')
+    parser.add_argument('--epochs', default=100, type=int)
+    parser.add_argument('--accum_iter', default=1, type=int,
+                        help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')
+    parser.add_argument('--weight_decay', type=float, default=0.05,
+                        help='Weight decay (default: 0.05)')
+    parser.add_argument('--lr', type=float, default=None, metavar='LR',
+                        help='Learning rate (absolute lr)')
+    parser.add_argument('--blr', type=float, default=1e-4, metavar='LR',
+                        help='Base learning rate: absolute_lr = base_lr * total_batch_size / 512')
+    parser.add_argument('--min_lr', type=float, default=0., metavar='LR',
+                        help='Lower lr bound for cyclic schedulers that hit 0')
+    parser.add_argument('--warmup_epochs', type=int, default=5, metavar='N',
+                        help='Epochs to warmup LR')
+    parser.add_argument('--clip_grad', type=float, default=1.0,
+                        help='Clip gradient at the specified norm')
+
+    # Job
+    parser.add_argument('--job_dir', default='',
+                        help='Path to where to save, empty for no saving')
+    parser.add_argument('--output_dir', default='./output_dir',
+                        help='Path to where to save, empty for no saving')
+    parser.add_argument('--device', default='cuda',
+                        help='Device to use for training / testing')
+    parser.add_argument('--seed', default=0, type=int,
+                        help='Random seed.')
+    parser.add_argument('--resume', default='weights/co3dv2_all_categories.pth',
+                        help='Resume from checkpoint')
+
+    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
+                        help='Start epoch')
+    parser.add_argument('--num_workers', default=4, type=int,
+                        help='Number of workers for training data loader')
+    parser.add_argument('--num_eval_workers', default=4, type=int,
+                        help='Number of workers for evaluation data loader')
+    parser.add_argument('--pin_mem', action='store_true',
+                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
+    parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
+    parser.set_defaults(pin_mem=True)
+
+    # Distributed training
+    parser.add_argument('--world_size', default=1, type=int,
+                        help='Number of distributed processes')
+    parser.add_argument('--local_rank', default=-1, type=int)
+    parser.add_argument('--dist_on_itp', action='store_true')
+    parser.add_argument('--dist_url', default='env://',
+                        help='Url used to set up distributed training')
+
+    # Experiments
+    parser.add_argument('--debug', action='store_true')
+    parser.add_argument('--run_viz', action='store_true',
+                        help='Specify to run only the visualization/inference given a trained model.')
+    parser.add_argument('--max_n_viz_obj', default=64, type=int,
+                        help='Max number of objects to visualize during training.')
+
+    # Data
+    parser.add_argument('--train_epoch_len_multiplier', default=32, type=int,
+                        help='# examples per training epoch is # objects * train_epoch_len_multiplier')
+    parser.add_argument('--eval_epoch_len_multiplier', default=1, type=int,
+                        help='# examples per eval epoch is # objects * eval_epoch_len_multiplier')
+
+    # CO3D
+    parser.add_argument('--co3d_path', type=str, default='co3d_data',
+                        help='Path to CO3D v2 data.')
+    parser.add_argument('--holdout_categories', action='store_true',
+                        help='If true, hold out 10 categories and train on only the remaining 41 categories.')
+    parser.add_argument('--co3d_world_size', default=3.0, type=float,
+                        help='The world space we consider is \in [-co3d_world_size, co3d_world_size] in each dimension.')
+
+    # Hypersim
+    parser.add_argument('--use_hypersim', action='store_true',
+                        help='If true, use hypersim, else, co3d.')
+    parser.add_argument('--hypersim_path', default="hypersim_data", type=str,
+                        help="Path to Hypersim data.")
+
+    # Data aug
+    parser.add_argument('--random_scale_delta', default=0.2, type=float,
+                        help='Random scaling each example by a scalar \in [1 - random_scale_delta, 1 + random_scale_delta].')
+    parser.add_argument('--random_shift', default=1.0, type=float,
+                        help='Random shifting an example in each axis by an amount \in [-random_shift, random_shift]')
+    parser.add_argument('--random_rotate_degree', default=180, type=int,
+                        help='Random rotation degrees.')
+
+    # Sampling, evaluation, and coordinate system
+    parser.add_argument('--shrink_threshold', default=10.0, type=float,
+                        help='Any points with distance beyond this value will be shrunk.')
+    parser.add_argument('--semisphere_size', default=6.0, type=float,
+                        help='The Hypersim task predicts points in a semisphere in front of the camera. '
+                             'This value specifies the size of the semisphere.')
+    parser.add_argument('--eval_granularity', default=0.1, type=float,
+                        help='Granularity of the evaluation points.')
+    parser.add_argument('--viz_granularity', default=0.1, type=float,
+                        help='Granularity of points in visualization.')
+
+    parser.add_argument('--eval_score_threshold', default=0.1, type=float,
+                        help='Score threshold for evaluation.')
+    parser.add_argument('--eval_dist_threshold', default=0.1, type=float,
+                        help='Points closer than this amount to a ground-truth point are considered correct.')
+    parser.add_argument('--train_dist_threshold', default=0.1, type=float,
+                        help='Points closer than this amount are considered positive in training.')
+    return parser
+
+
+def build_loader(args, num_tasks, global_rank, is_train, dataset_type, collate_fn, dataset_maps):
+    '''Build data loader'''
+    dataset = dataset_type(args, is_train=is_train, dataset_maps=dataset_maps)
+
+    sampler_train = torch.utils.data.DistributedSampler(
+        dataset, num_replicas=num_tasks, rank=global_rank, shuffle=is_train
+    )
+
+    data_loader = torch.utils.data.DataLoader(
+        dataset, batch_size=args.batch_size if is_train else args.eval_batch_size,
+        sampler=sampler_train,
+        num_workers=args.num_workers if is_train else args.num_eval_workers,
+        pin_memory=args.pin_mem,
+        collate_fn=collate_fn,
+    )
+    return data_loader
+
+
+def main(args):
+    misc.init_distributed_mode(args)
+
+    print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
+    print("{}".format(args).replace(', ', ',\n'))
+
+    device = torch.device(args.device)
+
+    # fix the seed for reproducibility
+    seed = args.seed + misc.get_rank()
+    torch.manual_seed(seed)
+    np.random.seed(seed)
+
+    cudnn.benchmark = True
+    num_tasks = misc.get_world_size()
+    global_rank = misc.get_rank()
+
+    # define the model
+    model = mcc_model.get_mcc_model(
+        rgb_weight=args.rgb_weight,
+        occupancy_weight=args.occupancy_weight,
+        args=args,
+    )
+
+    model.to(device)
+
+    model_without_ddp = model
+    print("Model = %s" % str(model_without_ddp))
+
+    eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()
+    if args.lr is None:  # only base_lr is specified
+        args.lr = args.blr * eff_batch_size / 512
+
+    print("base lr: %.2e" % (args.blr))
+    print("actual lr: %.2e" % args.lr)
+
+    print("accumulate grad iterations: %d" % args.accum_iter)
+    print("effective batch size: %d" % eff_batch_size)
+
+    if args.distributed:
+        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
+        model_without_ddp = model.module
+
+    # following timm: set wd as 0 for bias and norm layers
+    param_groups = optim_factory.add_weight_decay(model_without_ddp, args.weight_decay)
+    optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95))
+    print(optimizer)
+    loss_scaler = NativeScaler()
+
+    misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler)
+
+    if args.use_hypersim:
+        dataset_type = HyperSimDataset
+        collate_fn = hypersim_collate_fn
+        dataset_maps = None
+    else:
+        dataset_type = CO3DV2Dataset
+        collate_fn = co3dv2_collate_fn
+        dataset_maps = get_all_dataset_maps(
+            args.co3d_path, args.holdout_categories,
+        )
+
+    dataset_viz = dataset_type(args, is_train=False, is_viz=True, dataset_maps=dataset_maps)
+    sampler_viz = torch.utils.data.DistributedSampler(
+        dataset_viz, num_replicas=num_tasks, rank=global_rank, shuffle=False
+    )
+
+    data_loader_viz = torch.utils.data.DataLoader(
+        dataset_viz, batch_size=1,
+        sampler=sampler_viz,
+        num_workers=args.num_eval_workers,
+        pin_memory=args.pin_mem,
+        collate_fn=collate_fn,
+    )
+
+    if args.run_viz:
+        run_viz(
+            model, data_loader_viz,
+            device, args=args, epoch=0,
+        )
+        exit()
+
+    data_loader_train, data_loader_val = [
+        build_loader(
+            args, num_tasks, global_rank,
+            is_train=is_train,
+            dataset_type=dataset_type, collate_fn=collate_fn, dataset_maps=dataset_maps
+        ) for is_train in [True, False]
+    ]
+
+    print(f"Start training for {args.epochs} epochs")
+    start_time = time.time()
+    for epoch in range(args.start_epoch, args.epochs):
+        print(f'Epoch {epoch}:')
+        if args.distributed:
+            data_loader_train.sampler.set_epoch(epoch)
+        train_stats = train_one_epoch(
+            model, data_loader_train,
+            optimizer, device, epoch, loss_scaler,
+            args=args,
+        )
+
+        val_stats = {}
+        if (epoch % 5 == 4 or epoch + 1 == args.epochs) or args.debug:
+            val_stats = eval_one_epoch(
+                model, data_loader_val,
+                device, args=args,
+            )
+
+        if ((epoch % 10 == 9 or epoch + 1 == args.epochs) or args.debug):
+            run_viz(
+                model, data_loader_viz,
+                device, args=args, epoch=epoch,
+            )
+
+        if args.output_dir and (epoch % 10 == 9 or epoch + 1 == args.epochs):
+            misc.save_model(
+                args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
+                loss_scaler=loss_scaler, epoch=epoch)
+
+        log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
+                     **{f'val_{k}': v for k, v in val_stats.items()},
+                     'epoch': epoch,}
+
+        if args.output_dir and misc.is_main_process():
+            with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
+                f.write(json.dumps(log_stats) + "\n")
+
+    run_viz(
+        model, data_loader_viz,
+        device, args=args, epoch=-1,
+    )
+
+    total_time = time.time() - start_time
+    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+    print('Training time {}'.format(total_time_str))
+
+
+if __name__ == '__main__':
+
+    args = get_args_parser()
+    args = args.parse_args()
+
+    if args.output_dir:
+        Path(args.output_dir).mkdir(parents=True, exist_ok=True)
+
+    main(args)
mcc_model.py ADDED
@@ -0,0 +1,386 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# References:
+# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
+# DeiT: https://github.com/facebookresearch/deit
+# MAE: https://github.com/facebookresearch/mae
+# --------------------------------------------------------
+
+from functools import partial
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from timm.models.vision_transformer import PatchEmbed, Block, Mlp, DropPath
+
+from util.pos_embed import get_2d_sincos_pos_embed
+
+class MCCDecoderAttention(nn.Module):
+    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., args=None):
+        super().__init__()
+        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        self.scale = head_dim ** -0.5
+
+        self.args = args
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(dim, dim)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+    def forward(self, x, unseen_size):
+        B, N, C = x.shape
+        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+        q, k, v = qkv.unbind(0)
+        attn = (q @ k.transpose(-2, -1)) * self.scale
+
+        # Mask out attention to the unseen query tokens (the last unseen_size
+        # positions), except that each query token may still attend to itself.
+        mask = torch.zeros((1, 1, N, N), device=attn.device)
+        mask[:, :, :, -unseen_size:] = float('-inf')
+        for i in range(unseen_size):
+            mask[:, :, -(i + 1), -(i + 1)] = 0
+        attn = attn + mask
+        attn = attn.softmax(dim=-1)
+
+        attn = self.attn_drop(attn)
+
+        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
+
+class MCCDecoderBlock(nn.Module):
+
+    def __init__(
+            self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None,
+            drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, args=None):
+        super().__init__()
+        self.args = args
+        self.norm1 = norm_layer(dim)
+        self.attn = MCCDecoderAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, args=args)
+        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+        self.norm2 = norm_layer(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+    def forward(self, x, unseen_size):
+        x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x), unseen_size)))
+        x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
+        return x
+
+
+class XYZPosEmbed(nn.Module):
+    """ Positional embedding for the seen XYZ points.
+    """
+    def __init__(self, embed_dim):
+        super().__init__()
+        self.embed_dim = embed_dim
+
+        self.two_d_pos_embed = nn.Parameter(
+            torch.zeros(1, 64 + 1, embed_dim), requires_grad=False)  # fixed sin-cos embedding
+
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+        self.win_size = 8
+
+        self.pos_embed = nn.Linear(3, embed_dim)
+
+        self.blocks = nn.ModuleList([
+            Block(embed_dim, num_heads=12, mlp_ratio=2.0, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6))
+            for _ in range(1)
+        ])
+
+        self.invalid_xyz_token = nn.Parameter(torch.zeros(embed_dim,))
+
+        self.initialize_weights()
+
+    def initialize_weights(self):
+        torch.nn.init.normal_(self.cls_token, std=.02)
+
+        two_d_pos_embed = get_2d_sincos_pos_embed(self.two_d_pos_embed.shape[-1], 8, cls_token=True)
+        self.two_d_pos_embed.data.copy_(torch.from_numpy(two_d_pos_embed).float().unsqueeze(0))
+
+        torch.nn.init.normal_(self.invalid_xyz_token, std=.02)
+
+    def forward(self, seen_xyz, valid_seen_xyz):
+        emb = self.pos_embed(seen_xyz)
+
+        emb[~valid_seen_xyz] = 0.0
+        emb[~valid_seen_xyz] += self.invalid_xyz_token
+
+        B, H, W, C = emb.shape
+        emb = emb.view(B, H // self.win_size, self.win_size, W // self.win_size, self.win_size, C)
+        emb = emb.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, self.win_size * self.win_size, C)
+
+        emb = emb + self.two_d_pos_embed[:, 1:, :]
+        cls_token = self.cls_token + self.two_d_pos_embed[:, :1, :]
+
+        cls_tokens = cls_token.expand(emb.shape[0], -1, -1)
+        emb = torch.cat((cls_tokens, emb), dim=1)
+        for _, blk in enumerate(self.blocks):
+            emb = blk(emb)
+        return emb[:, 0].view(B, (H // self.win_size) * (W // self.win_size), -1)
+
+
+class DecodeXYZPosEmbed(nn.Module):
+    """ Positional embedding for the unseen (query) XYZ points.
+    """
+    def __init__(self, embed_dim):
+        super().__init__()
+        self.embed_dim = embed_dim
+        self.pos_embed = nn.Linear(3, embed_dim)
+
+    def forward(self, unseen_xyz):
+        return self.pos_embed(unseen_xyz)
+
+
+class MCC(nn.Module):
+    """ MCC model with a VisionTransformer backbone.
+    """
+    def __init__(self,
+                 img_size=224, patch_size=16, in_chans=3,
+                 embed_dim=1024, depth=24, num_heads=16,
+                 decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
+                 mlp_ratio=4., norm_layer=nn.LayerNorm,
+                 rgb_weight=1.0, occupancy_weight=1.0, args=None):
+        super().__init__()
+
+        self.rgb_weight = rgb_weight
+        self.occupancy_weight = occupancy_weight
+        self.args = args
+
+        # --------------------------------------------------------------------------
+        # encoder specifics
+        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
+        num_patches = self.patch_embed.num_patches
+
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+        self.cls_token_xyz = nn.Parameter(torch.zeros(1, 1, embed_dim))
+        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False)  # fixed sin-cos embedding
+
+        self.xyz_pos_embed = XYZPosEmbed(embed_dim)
+
+        self.blocks = nn.ModuleList([
+            Block(
+                embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer,
+                drop_path=args.drop_path
+            ) for i in range(depth)])
+
+        self.blocks_xyz = nn.ModuleList([
+            Block(
+                embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer,
+                drop_path=args.drop_path
+            ) for i in range(depth)])
+
+        self.norm = norm_layer(embed_dim)
+        self.norm_xyz = norm_layer(embed_dim)
+        self.cached_enc_feat = None
+
+        # --------------------------------------------------------------------------
+        # decoder specifics
+        self.decoder_embed = nn.Linear(
+            embed_dim * 2,
+            decoder_embed_dim,
+            bias=True
+        )
+
+        self.decoder_xyz_pos_embed = DecodeXYZPosEmbed(decoder_embed_dim)
+        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False)  # fixed sin-cos embedding
+
+        self.decoder_blocks = nn.ModuleList([
+            MCCDecoderBlock(
+                decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer,
|
201 |
+
drop_path=args.drop_path,
|
202 |
+
args=args,
|
203 |
+
) for i in range(decoder_depth)])
|
204 |
+
|
205 |
+
self.decoder_norm = norm_layer(decoder_embed_dim)
|
206 |
+
if self.args.regress_color:
|
207 |
+
self.decoder_pred = nn.Linear(decoder_embed_dim, 3 + 1, bias=True) # decoder to patch
|
208 |
+
else:
|
209 |
+
self.decoder_pred = nn.Linear(decoder_embed_dim, 256 * 3 + 1, bias=True) # decoder to patch
|
210 |
+
|
211 |
+
self.loss_occupy = nn.BCEWithLogitsLoss()
|
212 |
+
if self.args.regress_color:
|
213 |
+
self.loss_rgb = nn.MSELoss()
|
214 |
+
else:
|
215 |
+
self.loss_rgb = nn.CrossEntropyLoss()
|
216 |
+
|
217 |
+
self.initialize_weights()
|
218 |
+
|
219 |
+
def initialize_weights(self):
|
220 |
+
|
221 |
+
pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
|
222 |
+
self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
|
223 |
+
|
224 |
+
decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
|
225 |
+
self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))
|
226 |
+
|
227 |
+
# initialize patch_embed like nn.Linear (instead of nn.Conv2d)
|
228 |
+
w = self.patch_embed.proj.weight.data
|
229 |
+
torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
230 |
+
|
231 |
+
# timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
|
232 |
+
torch.nn.init.normal_(self.cls_token, std=.02)
|
233 |
+
torch.nn.init.normal_(self.cls_token_xyz, std=.02)
|
234 |
+
|
235 |
+
# initialize nn.Linear and nn.LayerNorm
|
236 |
+
self.apply(self._init_weights)
|
237 |
+
|
238 |
+
def _init_weights(self, m):
|
239 |
+
if isinstance(m, nn.Linear):
|
240 |
+
# we use xavier_uniform following official JAX ViT:
|
241 |
+
torch.nn.init.xavier_uniform_(m.weight)
|
242 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
243 |
+
nn.init.constant_(m.bias, 0)
|
244 |
+
elif isinstance(m, nn.LayerNorm):
|
245 |
+
nn.init.constant_(m.bias, 0)
|
246 |
+
nn.init.constant_(m.weight, 1.0)
|
247 |
+
|
248 |
+
|
249 |
+
def forward_encoder(self, x, seen_xyz, valid_seen_xyz):
|
250 |
+
|
251 |
+
# get tokens
|
252 |
+
x = self.patch_embed(x)
|
253 |
+
x = x + self.pos_embed[:, 1:, :]
|
254 |
+
y = self.xyz_pos_embed(seen_xyz, valid_seen_xyz)
|
255 |
+
|
256 |
+
##### forward E_XYZ #####
|
257 |
+
# append cls token
|
258 |
+
cls_token_xyz = self.cls_token_xyz
|
259 |
+
cls_tokens_xyz = cls_token_xyz.expand(y.shape[0], -1, -1)
|
260 |
+
|
261 |
+
y = torch.cat((cls_tokens_xyz, y), dim=1)
|
262 |
+
# apply Transformer blocks
|
263 |
+
for blk in self.blocks_xyz:
|
264 |
+
y = blk(y)
|
265 |
+
y = self.norm_xyz(y)
|
266 |
+
|
267 |
+
##### forward E_RGB #####
|
268 |
+
# append cls token
|
269 |
+
cls_token = self.cls_token + self.pos_embed[:, :1, :]
|
270 |
+
cls_tokens = cls_token.expand(x.shape[0], -1, -1)
|
271 |
+
|
272 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
273 |
+
# apply Transformer blocks
|
274 |
+
for blk in self.blocks:
|
275 |
+
x = blk(x)
|
276 |
+
x = self.norm(x)
|
277 |
+
|
278 |
+
# combine encodings
|
279 |
+
x = torch.cat([x, y], dim=2)
|
280 |
+
return x
|
281 |
+
|
282 |
+
def forward_decoder(self, x, unseen_xyz):
|
283 |
+
# embed tokens
|
284 |
+
x = self.decoder_embed(x)
|
285 |
+
x = x + self.decoder_pos_embed
|
286 |
+
|
287 |
+
# 3D pos embed
|
288 |
+
unseen_xyz = self.decoder_xyz_pos_embed(unseen_xyz)
|
289 |
+
x = torch.cat([x, unseen_xyz], dim=1)
|
290 |
+
|
291 |
+
# apply Transformer blocks
|
292 |
+
for blk in self.decoder_blocks:
|
293 |
+
x = blk(x, unseen_xyz.shape[1])
|
294 |
+
|
295 |
+
x = self.decoder_norm(x)
|
296 |
+
|
297 |
+
# predictor projection
|
298 |
+
pred = self.decoder_pred(x)
|
299 |
+
# remove cls & seen token
|
300 |
+
pred = pred[:, -unseen_xyz.shape[1]:, :]
|
301 |
+
|
302 |
+
return pred
|
303 |
+
|
304 |
+
def forward_loss(self, pred, unseen_occupy, unseen_rgb):
|
305 |
+
loss = self.loss_occupy(
|
306 |
+
pred[:, :, :1].reshape((-1, 1)),
|
307 |
+
unseen_occupy.reshape((-1, 1)).float()
|
308 |
+
) * self.occupancy_weight
|
309 |
+
|
310 |
+
if unseen_occupy.sum() > 0:
|
311 |
+
if self.args.regress_color:
|
312 |
+
pred_rgb = pred[:, :, 1:][unseen_occupy.bool()]
|
313 |
+
gt_rgb = unseen_rgb[unseen_occupy.bool()]
|
314 |
+
else:
|
315 |
+
pred_rgb = pred[:, :, 1:][unseen_occupy.bool()].reshape((-1, 256))
|
316 |
+
gt_rgb = torch.round(unseen_rgb[unseen_occupy.bool()] * 255).long().reshape((-1,))
|
317 |
+
|
318 |
+
rgb_loss = self.loss_rgb(pred_rgb, gt_rgb) * self.rgb_weight
|
319 |
+
loss = loss + rgb_loss
|
320 |
+
return loss
|
321 |
+
|
322 |
+
|
323 |
+
def clear_cache(self):
|
324 |
+
self.cached_enc_feat = None
|
325 |
+
|
326 |
+
def forward(self, seen_images, seen_xyz, unseen_xyz, unseen_rgb, unseen_occupy, valid_seen_xyz,
|
327 |
+
cache_enc=False):
|
328 |
+
|
329 |
+
unseen_xyz = shrink_points_beyond_threshold(unseen_xyz, self.args.shrink_threshold)
|
330 |
+
|
331 |
+
if self.cached_enc_feat is None:
|
332 |
+
seen_images = preprocess_img(seen_images)
|
333 |
+
seen_xyz = shrink_points_beyond_threshold(seen_xyz, self.args.shrink_threshold)
|
334 |
+
latent = self.forward_encoder(seen_images, seen_xyz, valid_seen_xyz)
|
335 |
+
|
336 |
+
if cache_enc:
|
337 |
+
if self.cached_enc_feat is None:
|
338 |
+
self.cached_enc_feat = latent
|
339 |
+
else:
|
340 |
+
latent = self.cached_enc_feat
|
341 |
+
|
342 |
+
pred = self.forward_decoder(latent, unseen_xyz)
|
343 |
+
loss = self.forward_loss(pred, unseen_occupy, unseen_rgb)
|
344 |
+
return loss, pred
|
345 |
+
|
346 |
+
|
347 |
+
def get_mcc_model(**kwargs):
|
348 |
+
return MCC(
|
349 |
+
embed_dim=768, depth=12, num_heads=12,
|
350 |
+
decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
|
351 |
+
mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs
|
352 |
+
)
|
353 |
+
|
354 |
+
|
355 |
+
def shrink_points_beyond_threshold(xyz, threshold):
|
356 |
+
xyz = xyz.clone().detach()
|
357 |
+
dist = (xyz ** 2.0).sum(axis=-1) ** 0.5
|
358 |
+
affected = (dist > threshold) * torch.isfinite(dist)
|
359 |
+
xyz[affected] = xyz[affected] * (
|
360 |
+
threshold * (2.0 - threshold / dist[affected]) / dist[affected]
|
361 |
+
)[..., None]
|
362 |
+
return xyz
|
363 |
+
|
364 |
+
|
365 |
+
def preprocess_img(x):
|
366 |
+
if x.shape[2] != 224:
|
367 |
+
assert x.shape[2] == 800
|
368 |
+
x = F.interpolate(
|
369 |
+
x,
|
370 |
+
scale_factor=224./800.,
|
371 |
+
mode="bilinear",
|
372 |
+
)
|
373 |
+
resnet_mean = torch.tensor([0.485, 0.456, 0.406], device=x.device).reshape((1, 3, 1, 1))
|
374 |
+
resnet_std = torch.tensor([0.229, 0.224, 0.225], device=x.device).reshape((1, 3, 1, 1))
|
375 |
+
imgs_normed = (x - resnet_mean) / resnet_std
|
376 |
+
return imgs_normed
|
377 |
+
|
378 |
+
|
379 |
+
class LayerScale(nn.Module):
|
380 |
+
def __init__(self, dim, init_values=1e-5, inplace=False):
|
381 |
+
super().__init__()
|
382 |
+
self.inplace = inplace
|
383 |
+
self.gamma = nn.Parameter(init_values * torch.ones(dim))
|
384 |
+
|
385 |
+
def forward(self, x):
|
386 |
+
return x.mul_(self.gamma) if self.inplace else x * self.gamma
|
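A note on the attention mask built in MCCDecoderAttention.forward above: it lets every token attend to all seen tokens, while each unseen query token can additionally attend only to itself, so the occupancy/color queries are decoded independently of one another. A minimal standalone sketch of the pattern (toy sizes; not part of the commit):

import torch

N, unseen_size = 5, 3  # 2 seen tokens followed by 3 unseen query tokens
mask = torch.zeros((1, 1, N, N))
mask[:, :, :, -unseen_size:] = float('-inf')  # block attention to all unseen keys
for i in range(unseen_size):
    mask[:, :, -(i + 1), -(i + 1)] = 0        # re-open each unseen query's own key
print(mask[0, 0])  # rows = queries, cols = keys; -inf entries vanish in the softmax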
packages.txt
ADDED
@@ -0,0 +1,3 @@
ffmpeg
libsm6
libxext6
pre-requirements.txt
ADDED
@@ -0,0 +1,2 @@
torch==1.13.0
torchvision==0.14.0
requirements.txt
ADDED
@@ -0,0 +1,11 @@
h5py
omegaconf
submitit
timm==0.4.5
opencv-python
matplotlib
plotly
gradio
gradio_client==0.2.7
plyfile
git+https://github.com/facebookresearch/pytorch3d.git
util/co3d_dataset.py
ADDED
@@ -0,0 +1,122 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import random
from typing import cast

import torch
from pytorch3d.implicitron.dataset.dataset_base import FrameData

import util.co3d_utils as co3d_utils


def co3dv2_collate_fn(batch):
    assert len(batch[0]) == 4
    return (
        FrameData.collate([x[0] for x in batch]),
        FrameData.collate([x[1] for x in batch]),
        [x[2] for x in batch],
        [x[3] for x in batch],
    )


def pad_point_cloud(pc, N):
    cur_N = pc._points_list[0].shape[0]
    if cur_N == N:
        return pc

    assert cur_N > 0

    n_pad = N - cur_N
    indices = random.choices(list(range(cur_N)), k=n_pad)
    pc._features_list[0] = torch.cat([pc._features_list[0], pc._features_list[0][indices]], dim=0)
    pc._points_list[0] = torch.cat([pc._points_list[0], pc._points_list[0][indices]], dim=0)
    return pc


class CO3DV2Dataset(torch.utils.data.Dataset):
    def __init__(self, args, is_train, is_viz=False, dataset_maps=None):

        self.args = args
        self.is_train = is_train
        self.is_viz = is_viz

        self.dataset_split = 'train' if is_train else 'val'
        self.all_datasets = dataset_maps[0 if is_train else 1]
        print(len(self.all_datasets), 'categories loaded')

        self.all_example_names = self.get_all_example_names()
        print('containing', len(self.all_example_names), 'examples')

    def get_all_example_names(self):
        all_example_names = []
        for category in self.all_datasets.keys():
            for sequence_name in self.all_datasets[category].seq_name2idx.keys():
                all_example_names.append((category, sequence_name))
        return all_example_names

    def __getitem__(self, index):
        for retry in range(1000):
            try:
                if retry > 9:
                    index = random.choice(range(len(self)))
                    print('retry', retry, 'new index:', index)
                gap = 1 if self.is_train else len(self.all_example_names) // len(self)
                assert gap >= 1
                category, sequence_name = self.all_example_names[(index * gap) % len(self.all_example_names)]

                cat_dataset = self.all_datasets[category]

                frame_data = cat_dataset.__getitem__(
                    random.choice(cat_dataset.seq_name2idx[sequence_name])
                    if self.is_train
                    else cat_dataset.seq_name2idx[sequence_name][
                        hash(sequence_name) % len(cat_dataset.seq_name2idx[sequence_name])
                    ]
                )
                test_frame = None
                seen_idx = None

                frame_data = cat_dataset.frame_data_type.collate([frame_data])
                mask = (
                    (cast(torch.Tensor, frame_data.fg_probability) > 0.5).float()
                    if frame_data.fg_probability is not None
                    else None
                )
                seen_rgb = frame_data.image_rgb.clone().detach()

                # 112, 112, 3
                seen_xyz = co3d_utils.get_rgbd_points(
                    112, 112,
                    frame_data.camera,
                    frame_data.depth_map,
                    mask,
                )

                full_point_cloud = co3d_utils._load_pointcloud(f'{self.args.co3d_path}/{category}/{sequence_name}/pointcloud.ply', max_points=20000)
                full_point_cloud = pad_point_cloud(full_point_cloud, 20000)
                break
            except Exception as e:
                print(category, sequence_name, 'sampling failed', retry, e)

        seen_rgb = seen_rgb.squeeze(0)
        full_rgb = full_point_cloud._features_list[0]

        return (
            (seen_xyz, seen_rgb),
            (full_point_cloud._points_list[0], full_rgb),
            test_frame,
            (category, sequence_name, seen_idx),
        )

    def __len__(self) -> int:
        n_objs = sum([len(cat_dataset.seq_name2idx.keys()) for cat_dataset in self.all_datasets.values()])
        if self.is_train:
            return int(n_objs * self.args.train_epoch_len_multiplier)
        elif self.is_viz:
            return n_objs
        else:
            return int(n_objs * self.args.eval_epoch_len_multiplier)
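For illustration, pad_point_cloud above upsamples a point cloud to a fixed size by duplicating randomly chosen existing points. A small standalone sketch (assumes pytorch3d is installed; the toy point counts are placeholders, not part of the commit):

import torch
from pytorch3d.structures import Pointclouds
from util.co3d_dataset import pad_point_cloud

pc = Pointclouds(points=[torch.randn(1500, 3)], features=[torch.rand(1500, 3)])
pc = pad_point_cloud(pc, 20000)        # resamples existing points up to 20000
print(pc._points_list[0].shape)        # torch.Size([20000, 3])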
util/co3d_utils.py
ADDED
@@ -0,0 +1,133 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import glob
from omegaconf import DictConfig
from typing import Optional

import torch

from pytorch3d.implicitron.dataset.dataset_map_provider import DatasetMap
from pytorch3d.implicitron.dataset.json_index_dataset_map_provider_v2 import (
    JsonIndexDatasetMapProviderV2
)
from pytorch3d.implicitron.tools.config import expand_args_fields
from pytorch3d.io import IO
from pytorch3d.renderer import (
    NDCMultinomialRaysampler,
    ray_bundle_to_ray_points,
)
from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.structures import Pointclouds


HOLDOUT_CATEGORIES = set([
    'apple',
    'baseballglove',
    'cup',
    'ball',
    'toyplane',
    'handbag',
    'book',
    'carrot',
    'suitcase',
    'bowl',
])

def get_dataset_map(
    dataset_root: str,
    category: str,
    subset_name: str,
) -> DatasetMap:
    """
    Obtain the dataset map that contains the train/val/test dataset objects.
    """
    expand_args_fields(JsonIndexDatasetMapProviderV2)
    dataset_map_provider = JsonIndexDatasetMapProviderV2(
        category=category,
        subset_name=subset_name,
        dataset_root=dataset_root,
        test_on_train=False,
        only_test_set=False,
        load_eval_batches=True,
        dataset_JsonIndexDataset_args=DictConfig({"remove_empty_masks": False, "load_point_clouds": False}),
    )
    return dataset_map_provider.get_dataset_map()


def _load_pointcloud(pcl_path, max_points):
    pcl = IO().load_pointcloud(pcl_path)
    if max_points > 0:
        pcl = pcl.subsample(max_points)

    return pcl


def get_all_dataset_maps(co3d_path, holdout_categories):
    all_categories = [c.split('/')[-1] for c in list(glob.glob(co3d_path + '/*')) if not c.endswith('.json')]
    all_categories = sorted(all_categories, key=lambda x: hash(x))

    # Obtain the CO3Dv2 dataset map
    train_dataset_maps = {}
    val_dataset_maps = {}
    for category in all_categories:

        print(f'Loading dataset map ({category})')
        dataset_map = {
            'train': torch.load(f'dataset_cache/{category}_train.pt'),
            'val': torch.load(f'dataset_cache/{category}_val.pt')
        }
        if not holdout_categories or category not in HOLDOUT_CATEGORIES:
            train_dataset_maps[category] = dataset_map['train']
        if not holdout_categories or category in HOLDOUT_CATEGORIES:
            val_dataset_maps[category] = dataset_map['val']

    print('Loaded', len(train_dataset_maps), 'categories for train')
    print('Loaded', len(val_dataset_maps), 'categories for val')
    return train_dataset_maps, val_dataset_maps


def get_rgbd_points(
    imh, imw,
    camera: CamerasBase,
    depth_map: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
    mask_thr: float = 0.5,
) -> Pointclouds:
    """
    Given a batch of images, depths, masks and cameras, generate a colored
    point cloud by unprojecting the depth maps and coloring each point with
    its source pixel color.
    """
    depth_map = torch.nn.functional.interpolate(
        depth_map,
        size=[imh, imw],
        mode="bilinear",
        align_corners=False,
    )
    # convert the depth maps to point clouds using the grid ray sampler
    pts_3d = ray_bundle_to_ray_points(
        NDCMultinomialRaysampler(
            image_width=imw,
            image_height=imh,
            n_pts_per_ray=1,
            min_depth=1.0,
            max_depth=1.0,
        )(camera)._replace(lengths=depth_map[:, 0, ..., None])
    ).squeeze(3)[None]

    pts_mask = depth_map > 0.0
    if mask is not None:
        mask = torch.nn.functional.interpolate(
            mask,
            size=[imh, imw],
            mode="bilinear",
            align_corners=False,
        )
        pts_mask *= mask > mask_thr
    pts_3d[~pts_mask] = float('inf')
    return pts_3d.squeeze(0).squeeze(0)
util/crop.py
ADDED
@@ -0,0 +1,42 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math

import torch

from torchvision import transforms
from torchvision.transforms import functional as F


class RandomResizedCrop(transforms.RandomResizedCrop):
    """
    RandomResizedCrop for matching the TF/TPU implementation: no for-loop is used.
    This may lead to results different from torchvision's version.
    Following BYOL's TF code:
    https://github.com/deepmind/deepmind-research/blob/master/byol/utils/dataset.py#L206
    """
    @staticmethod
    def get_params(img, scale, ratio):
        width, height = F._get_image_size(img)
        area = height * width

        target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
        log_ratio = torch.log(torch.tensor(ratio))
        aspect_ratio = torch.exp(
            torch.empty(1).uniform_(log_ratio[0], log_ratio[1])
        ).item()

        w = int(round(math.sqrt(target_area * aspect_ratio)))
        h = int(round(math.sqrt(target_area / aspect_ratio)))

        w = min(w, width)
        h = min(h, height)

        i = torch.randint(0, height - h + 1, size=(1,)).item()
        j = torch.randint(0, width - w + 1, size=(1,)).item()

        return i, j, h, w
util/hypersim_dataset.py
ADDED
@@ -0,0 +1,271 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import random
import glob

import torch
from pytorch3d.implicitron.dataset.dataset_base import FrameData
from pytorch3d.ops import sample_points_from_meshes

from util.hypersim_utils import read_h5py, read_img


def hypersim_collate_fn(batch):
    assert len(batch[0]) == 4
    return (
        FrameData.collate([x[0] for x in batch]),
        FrameData.collate([x[1] for x in batch]),
        FrameData.collate([x[2] for x in batch]),
        [x[2] for x in batch]
    )


def is_good_xyz(xyz):
    assert len(xyz.shape) == 3
    return (torch.isfinite(xyz.sum(axis=2))).sum() > 2000


def get_camera_pos_file_name_from_frame_name(frame_name):
    tmp = frame_name.split('/')
    tmp[-3] = '_detail'
    tmp[-2] = 'cam_' + tmp[-2].split('_')[2]
    tmp[-1] = 'camera_keyframe_positions.hdf5'
    return '/'.join(tmp)


def get_camera_look_at_file_name_from_frame_name(frame_name):
    tmp = frame_name.split('/')
    tmp[-3] = '_detail'
    tmp[-2] = 'cam_' + tmp[-2].split('_')[2]
    tmp[-1] = 'camera_keyframe_look_at_positions.hdf5'
    return '/'.join(tmp)


def get_camera_orientation_file_name_from_frame_name(frame_name):
    tmp = frame_name.split('/')
    tmp[-3] = '_detail'
    tmp[-2] = 'cam_' + tmp[-2].split('_')[2]
    tmp[-1] = 'camera_keyframe_orientations.hdf5'
    return '/'.join(tmp)


def read_scale_from_frame_name(frame_name):
    tmp = frame_name.split('/')
    with open('/'.join(tmp[:-3] + ['_detail', 'metadata_scene.csv'])) as f:
        for line in f:
            items = line.split(',')
            return float(items[1])


def random_crop(xyz, img, is_train=True):
    assert xyz.shape[0] == img.shape[0]
    assert xyz.shape[1] == img.shape[1]

    width, height = img.shape[0], img.shape[1]
    w = h = min(width, height)
    if is_train:
        i = torch.randint(0, width - w + 1, size=(1,)).item()
        j = torch.randint(0, height - h + 1, size=(1,)).item()
    else:
        i = (width - w) // 2
        j = (height - h) // 2
    xyz = xyz[i:i+w, j:j+h]
    img = img[i:i+w, j:j+h]
    xyz = torch.nn.functional.interpolate(
        xyz[None].permute(0, 3, 1, 2), (112, 112),
        mode='bilinear',
    ).permute(0, 2, 3, 1)[0]
    img = torch.nn.functional.interpolate(
        img[None].permute(0, 3, 1, 2), (224, 224),
        mode='bilinear',
    ).permute(0, 2, 3, 1)[0]
    return xyz, img


class HyperSimDataset(torch.utils.data.Dataset):
    def __init__(self, args, is_train, is_viz=False, **kwargs):

        self.args = args
        self.is_train = is_train
        self.is_viz = is_viz

        self.dataset_split = 'train' if is_train else 'val'
        self.scene_names = self.load_scene_names(is_train)

        if not is_train:
            self.meshes = self.load_meshes()

        self.hypersim_gt = self.load_hypersim_gt()


    def load_hypersim_gt(self):
        gt_filename = 'hypersim_gt_train.pt' if self.dataset_split == 'train' else 'hypersim_gt_val.pt'
        print('loading GT file from', gt_filename)
        gt = torch.load(gt_filename)
        for scene_name in gt.keys():
            good = torch.isfinite(gt[scene_name][0].sum(axis=1)) & torch.isfinite(gt[scene_name][1].sum(axis=1))

            # Subsample GT to reduce memory usage.
            if self.is_train:
                good = good & (torch.rand(good.shape) < 0.5)
            else:
                good = good & (torch.rand(good.shape) < 0.1)
            gt[scene_name] = [gt[scene_name][0][good], gt[scene_name][1][good]]
        return gt

    def load_meshes(self):
        return torch.load('all_hypersim_val_meshes.pt')

    def load_scene_names(self, is_train):
        split = 'train' if is_train else 'test'
        scene_names = []
        with open(os.path.join(
                self.args.hypersim_path,
                'evermotion_dataset/analysis/metadata_images_split_scene_v1.csv'), 'r') as f:
            for line in f:
                items = line.split(',')
                if items[-1].strip() == split:
                    scene_names.append(items[0])
        scene_names = sorted(list(set(scene_names)))
        print(len(scene_names), 'scenes loaded:', scene_names)
        return scene_names

    def is_corrupted_frame(self, frame):
        return (
            ('ai_003_001' in frame and 'cam_00' in frame)
            or ('ai_004_009' in frame and 'cam_01' in frame)
        )

    def get_hypersim_data(self, index):
        for retry in range(1000):
            try:
                if retry < 10:
                    scene_name = self.scene_names[index % len(self.scene_names)]
                else:
                    scene_name = random.choice(self.scene_names)

                frames = glob.glob(os.path.join(self.args.hypersim_path, scene_name, 'images/scene_cam_*_final_preview/*tonemap*'))
                seen_frame = random.choice(frames)

                if self.is_corrupted_frame(seen_frame):
                    continue

                seen_data = self.load_frame_data(seen_frame)
                if not is_good_xyz(seen_data[0]):
                    continue

                cur_gt = self.hypersim_gt[scene_name]
                gt_data = [cur_gt[0], cur_gt[1]]

                if self.is_train:
                    mesh_points = torch.zeros((1,))
                else:
                    mesh_points = sample_points_from_meshes(self.meshes[scene_name], 1000000)

                # get camera positions
                camera_positions = read_h5py(get_camera_pos_file_name_from_frame_name(seen_frame))
                camera_position = camera_positions[int(seen_frame.split('.')[-3])]

                # get camera orientations
                cam_orientations = read_h5py(get_camera_orientation_file_name_from_frame_name(seen_frame))
                cam_orientation = cam_orientations[int(seen_frame.split('.')[-3])]
                cam_orientation = cam_orientation * (-1.0)

                # rotate to camera direction
                seen_data[0] = torch.matmul(seen_data[0], cam_orientation)
                gt_data[0] = torch.matmul(gt_data[0], cam_orientation)

                # shift to camera center
                camera_position = torch.matmul(camera_position, cam_orientation)
                seen_data[0] -= camera_position
                gt_data[0] -= camera_position
                # to meter
                asset_to_meter_scale = read_scale_from_frame_name(seen_frame)
                seen_data[0] = seen_data[0] * asset_to_meter_scale
                gt_data[0] = gt_data[0] * asset_to_meter_scale

                # get points GT
                n_gt = 30000
                in_front_of_cam = (gt_data[0][..., 2] > 0)
                if in_front_of_cam.sum() < 1000:
                    print('Warning! Not enough in front of cam.', in_front_of_cam.sum())
                    continue
                gt_data = [gt_data[0][in_front_of_cam], gt_data[1][in_front_of_cam]]

                if in_front_of_cam.sum() < n_gt:
                    selected = random.choices(range(gt_data[0].shape[0]), k=n_gt)
                else:
                    selected = random.sample(range(gt_data[0].shape[0]), n_gt)
                gt_data = [gt_data[0][selected][None], gt_data[1][selected][None], torch.zeros((1,))]

                if not self.is_train:
                    mesh_points = torch.matmul(mesh_points, cam_orientation)
                    mesh_points -= camera_position * asset_to_meter_scale
                    in_front_of_cam = (mesh_points[..., 2] > 0)
                    if in_front_of_cam.sum() < 1000:
                        print('Warning! Not enough mesh in front of cam.', in_front_of_cam.sum())
                        continue
                    mesh_points = mesh_points[in_front_of_cam]
                    if in_front_of_cam.sum() < n_gt:
                        selected = random.choices(range(mesh_points.shape[0]), k=n_gt)
                    else:
                        selected = random.sample(range(mesh_points.shape[0]), n_gt)
                    mesh_points = mesh_points[selected][None]
                    mesh_points[..., 0] *= -1

                seen_data[0][..., 0] *= -1
                gt_data[0][..., 0] *= -1

                seen_data[1] = seen_data[1].permute(2, 0, 1)

                return seen_data, gt_data, mesh_points, scene_name
            except Exception as e:
                print(scene_name, 'loading failed', retry, e)


    def __getitem__(self, index):

        seen_data, gt_data, mesh_points, scene_name = self.get_hypersim_data(index)

        # normalize the data
        example_std = get_example_std(seen_data[0])
        seen_data[0] = seen_data[0] / example_std
        gt_data[0] = gt_data[0] / example_std
        mesh_points = mesh_points / example_std

        return (
            seen_data,
            gt_data,
            mesh_points,
            scene_name,
        )

    def load_frame_data(self, frame_path):
        frame_xyz_path = frame_path.replace('final_preview/', 'geometry_hdf5/').replace('.tonemap.jpg', '.position.hdf5')
        xyz = read_h5py(frame_xyz_path)
        img = read_img(frame_path)

        xyz, img = random_crop(
            xyz, img,
            is_train=self.is_train,
        )
        return [xyz, img]

    def __len__(self) -> int:
        if self.is_train:
            return int(len(self.scene_names) * self.args.train_epoch_len_multiplier)
        elif self.is_viz:
            return len(self.scene_names)
        else:
            return int(len(self.scene_names) * self.args.eval_epoch_len_multiplier)


def get_example_std(x):
    x = x.reshape(-1, 3)
    x = x[torch.isfinite(x.sum(dim=1))]
    return x.std(dim=0).mean().detach()
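The per-example normalization in __getitem__ above divides every coordinate by get_example_std, the mean per-axis standard deviation of the finite seen points. A quick standalone illustration (toy data; not part of the commit):

import torch
from util.hypersim_dataset import get_example_std

xyz = torch.randn(112, 112, 3) * 5.0
xyz[0, 0] = float('inf')            # invalid pixels are ignored by the finiteness mask
scale = get_example_std(xyz)        # scalar, close to 5.0 for this toy input
xyz_normed = xyz / scale            # brings the scene to roughly unit scale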
util/hypersim_utils.py
ADDED
@@ -0,0 +1,24 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import cv2
import h5py

import torch


def read_h5py(filename):
    with h5py.File(filename, "r") as f:
        data = torch.tensor(f['dataset'][:], dtype=torch.float32)
    return data


def read_img(frame_path):
    for retry in range(100):
        img = cv2.imread(frame_path)
        if img is not None:
            return torch.tensor(img / 255.0, dtype=torch.float32)[..., [2, 1, 0]]
        print('retry loading', retry, frame_path)
util/lars.py
ADDED
@@ -0,0 +1,47 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# LARS optimizer, implementation from MoCo v3:
# https://github.com/facebookresearch/moco-v3
# --------------------------------------------------------

import torch


class LARS(torch.optim.Optimizer):
    """
    LARS optimizer, no rate scaling or weight decay for parameters <= 1D.
    """
    def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001):
        defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, trust_coefficient=trust_coefficient)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self):
        for g in self.param_groups:
            for p in g['params']:
                dp = p.grad

                if dp is None:
                    continue

                if p.ndim > 1:  # if not normalization gamma/beta or bias
                    dp = dp.add(p, alpha=g['weight_decay'])
                    param_norm = torch.norm(p)
                    update_norm = torch.norm(dp)
                    one = torch.ones_like(param_norm)
                    q = torch.where(param_norm > 0.,
                                    torch.where(update_norm > 0,
                                                (g['trust_coefficient'] * param_norm / update_norm), one),
                                    one)
                    dp = dp.mul(q)

                param_state = self.state[p]
                if 'mu' not in param_state:
                    param_state['mu'] = torch.zeros_like(p)
                mu = param_state['mu']
                mu.mul_(g['momentum']).add_(dp)
                p.add_(mu, alpha=-g['lr'])
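A minimal usage sketch for the LARS optimizer above (illustrative; the toy model and hyperparameters are placeholders, not part of the commit):

import torch
import torch.nn as nn
from util.lars import LARS

model = nn.Linear(16, 4)
optimizer = LARS(model.parameters(), lr=0.1, weight_decay=1e-4, momentum=0.9)

x, y = torch.randn(8, 16), torch.randn(8, 4)
loss = ((model(x) - y) ** 2).mean()
loss.backward()
optimizer.step()       # trust-ratio update for >1D params, plain momentum SGD otherwise
optimizer.zero_grad()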
util/lr_decay.py
ADDED
@@ -0,0 +1,76 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# ELECTRA https://github.com/google-research/electra
# BEiT: https://github.com/microsoft/unilm/tree/master/beit
# --------------------------------------------------------

import json


def param_groups_lrd(model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=.75):
    """
    Parameter groups for layer-wise lr decay
    Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58
    """
    param_group_names = {}
    param_groups = {}

    num_layers = len(model.blocks) + 1

    layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1))

    for n, p in model.named_parameters():
        if not p.requires_grad:
            continue

        # no decay: all 1D parameters and model specific ones
        if p.ndim == 1 or n in no_weight_decay_list:
            g_decay = "no_decay"
            this_decay = 0.
        else:
            g_decay = "decay"
            this_decay = weight_decay

        layer_id = get_layer_id_for_vit(n, num_layers)
        group_name = "layer_%d_%s" % (layer_id, g_decay)

        if group_name not in param_group_names:
            this_scale = layer_scales[layer_id]

            param_group_names[group_name] = {
                "lr_scale": this_scale,
                "weight_decay": this_decay,
                "params": [],
            }
            param_groups[group_name] = {
                "lr_scale": this_scale,
                "weight_decay": this_decay,
                "params": [],
            }

        param_group_names[group_name]["params"].append(n)
        param_groups[group_name]["params"].append(p)

    # print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2))

    return list(param_groups.values())


def get_layer_id_for_vit(name, num_layers):
    """
    Assign a parameter with its layer id
    Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33
    """
    if name in ['cls_token', 'pos_embed']:
        return 0
    elif name.startswith('patch_embed'):
        return 0
    elif name.startswith('blocks'):
        return int(name.split('.')[1]) + 1
    else:
        return num_layers
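A usage sketch for the layer-wise lr-decay groups above, in the style of MAE fine-tuning (illustrative; timm's vit_base_patch16_224 stands in for any ViT exposing a .blocks list, and the hyperparameters are placeholders, not part of the commit):

import timm
import torch
from util.lr_decay import param_groups_lrd

model = timm.create_model('vit_base_patch16_224', pretrained=False)
param_groups = param_groups_lrd(model, weight_decay=0.05, layer_decay=0.75)
optimizer = torch.optim.AdamW(param_groups, lr=1e-3)
# each group carries an 'lr_scale' key, consumed per epoch by
# adjust_learning_rate in util/lr_sched.py below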
util/lr_sched.py
ADDED
@@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math

def adjust_learning_rate(optimizer, epoch, args):
    """Decay the learning rate with half-cycle cosine after warmup"""
    if epoch < args.warmup_epochs:
        lr = args.lr * epoch / args.warmup_epochs
    else:
        lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \
            (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs)))
    for param_group in optimizer.param_groups:
        if "lr_scale" in param_group:
            param_group["lr"] = lr * param_group["lr_scale"]
        else:
            param_group["lr"] = lr
    return lr
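A quick sanity check of the warmup-then-cosine schedule implemented above (illustrative; the argparse-style args namespace is mocked, not part of the commit):

import types
import torch
from util.lr_sched import adjust_learning_rate

args = types.SimpleNamespace(lr=1e-3, min_lr=1e-6, warmup_epochs=5, epochs=50)
opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=args.lr)
for epoch in [0, 1, 5, 25, 49]:
    print(epoch, adjust_learning_rate(opt, epoch, args))
# ramps linearly to 1e-3 over 5 epochs, then decays toward 1e-6 on a half-cosine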
util/misc.py
ADDED
@@ -0,0 +1,496 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
# --------------------------------------------------------
|
7 |
+
# References:
|
8 |
+
# DeiT: https://github.com/facebookresearch/deit
|
9 |
+
# BEiT: https://github.com/microsoft/unilm/tree/master/beit
|
10 |
+
# MAE: https://github.com/facebookresearch/mae
|
11 |
+
# --------------------------------------------------------
|
12 |
+
|
13 |
+
import builtins
|
14 |
+
import datetime
|
15 |
+
import os
|
16 |
+
import time
|
17 |
+
from collections import defaultdict, deque
|
18 |
+
from pathlib import Path
|
19 |
+
|
20 |
+
import torch
|
21 |
+
import torch.distributed as dist
|
22 |
+
from torch._six import inf
|
23 |
+
|
24 |
+
|
25 |
+
class SmoothedValue(object):
|
26 |
+
"""Track a series of values and provide access to smoothed values over a
|
27 |
+
window or the global series average.
|
28 |
+
"""
|
29 |
+
|
30 |
+
def __init__(self, window_size=20, fmt=None):
|
31 |
+
if fmt is None:
|
32 |
+
fmt = "{median:.4f} ({global_avg:.4f})"
|
33 |
+
self.deque = deque(maxlen=window_size)
|
34 |
+
self.total = 0.0
|
35 |
+
self.count = 0
|
36 |
+
self.fmt = fmt
|
37 |
+
|
38 |
+
def update(self, value, n=1):
|
39 |
+
self.deque.append(value)
|
40 |
+
self.count += n
|
41 |
+
self.total += value * n
|
42 |
+
|
43 |
+
def synchronize_between_processes(self):
|
44 |
+
"""
|
45 |
+
Warning: does not synchronize the deque!
|
46 |
+
"""
|
47 |
+
if not is_dist_avail_and_initialized():
|
48 |
+
return
|
49 |
+
t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
|
50 |
+
dist.barrier()
|
51 |
+
dist.all_reduce(t)
|
52 |
+
t = t.tolist()
|
53 |
+
self.count = int(t[0])
|
54 |
+
self.total = t[1]
|
55 |
+
|
56 |
+
@property
|
57 |
+
def median(self):
|
58 |
+
d = torch.tensor(list(self.deque))
|
59 |
+
return d.median().item()
|
60 |
+
|
61 |
+
@property
|
62 |
+
def avg(self):
|
63 |
+
d = torch.tensor(list(self.deque), dtype=torch.float32)
|
64 |
+
return d.mean().item()
|
65 |
+
|
66 |
+
@property
|
67 |
+
def global_avg(self):
|
68 |
+
return self.total / self.count
|
69 |
+
|
70 |
+
@property
|
71 |
+
def max(self):
|
72 |
+
return max(self.deque)
|
73 |
+
|
74 |
+
@property
|
75 |
+
def value(self):
|
76 |
+
return self.deque[-1]
|
77 |
+
|
78 |
+
def __str__(self):
|
79 |
+
return self.fmt.format(
|
80 |
+
median=self.median,
|
81 |
+
avg=self.avg,
|
82 |
+
global_avg=self.global_avg,
|
83 |
+
max=self.max,
|
84 |
+
value=self.value)
|
85 |
+
|
86 |
+
|
87 |
+
class MetricLogger(object):
|
88 |
+
def __init__(self, delimiter="\t"):
|
89 |
+
self.meters = defaultdict(SmoothedValue)
|
90 |
+
self.delimiter = delimiter
|
91 |
+
|
92 |
+
def update(self, **kwargs):
|
93 |
+
for k, v in kwargs.items():
|
94 |
+
if v is None:
|
95 |
+
continue
|
96 |
+
if isinstance(v, torch.Tensor):
|
97 |
+
v = v.item()
|
98 |
+
assert isinstance(v, (float, int))
|
99 |
+
self.meters[k].update(v)
|
100 |
+
|
101 |
+
def __getattr__(self, attr):
|
102 |
+
if attr in self.meters:
|
103 |
+
return self.meters[attr]
|
104 |
+
if attr in self.__dict__:
|
105 |
+
return self.__dict__[attr]
|
106 |
+
raise AttributeError("'{}' object has no attribute '{}'".format(
|
107 |
+
type(self).__name__, attr))
|
108 |
+
|
109 |
+
def __str__(self):
|
110 |
+
loss_str = []
|
111 |
+
for name, meter in self.meters.items():
|
112 |
+
loss_str.append(
|
113 |
+
"{}: {}".format(name, str(meter))
|
114 |
+
)
|
115 |
+
return self.delimiter.join(loss_str)
|
116 |
+
|
117 |
+
def synchronize_between_processes(self):
|
118 |
+
for meter in self.meters.values():
|
119 |
+
meter.synchronize_between_processes()
|
120 |
+
|
121 |
+
def add_meter(self, name, meter):
|
122 |
+
self.meters[name] = meter
|
123 |
+
|
124 |
+
def log_every(self, iterable, print_freq, header=None):
|
125 |
+
i = 0
|
126 |
+
if not header:
|
127 |
+
header = ''
|
128 |
+
start_time = time.time()
|
129 |
+
end = time.time()
|
130 |
+
iter_time = SmoothedValue(fmt='{avg:.4f}')
|
131 |
+
data_time = SmoothedValue(fmt='{avg:.4f}')
|
132 |
+
space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
|
133 |
+
log_msg = [
|
134 |
+
header,
|
135 |
+
'[{0' + space_fmt + '}/{1}]',
|
136 |
+
'eta: {eta}',
|
137 |
+
'{meters}',
|
138 |
+
'time: {time}',
|
139 |
+
'data: {data}'
|
140 |
+
]
|
141 |
+
if torch.cuda.is_available():
|
142 |
+
log_msg.append('max mem: {memory:.0f}')
|
143 |
+
log_msg = self.delimiter.join(log_msg)
|
144 |
+
MB = 1024.0 * 1024.0
|
145 |
+
for obj in iterable:
|
146 |
+
data_time.update(time.time() - end)
|
147 |
+
yield obj
|
148 |
+
iter_time.update(time.time() - end)
|
149 |
+
if i % print_freq == 0 or i == len(iterable) - 1:
|
150 |
+
eta_seconds = iter_time.global_avg * (len(iterable) - i)
|
151 |
+
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
|
152 |
+
if torch.cuda.is_available():
|
153 |
+
print(log_msg.format(
|
154 |
+
i, len(iterable), eta=eta_string,
|
155 |
+
meters=str(self),
|
156 |
+
time=str(iter_time), data=str(data_time),
|
157 |
+
memory=torch.cuda.max_memory_allocated() / MB))
|
158 |
+
else:
|
159 |
+
print(log_msg.format(
|
160 |
+
i, len(iterable), eta=eta_string,
|
161 |
+
meters=str(self),
|
162 |
+
time=str(iter_time), data=str(data_time)))
|
163 |
+
i += 1
|
164 |
+
end = time.time()
|
165 |
+
total_time = time.time() - start_time
|
166 |
+
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
167 |
+
print('{} Total time: {} ({:.4f} s / it)'.format(
|
168 |
+
header, total_time_str, total_time / len(iterable)))
|
169 |
+
|
170 |
+
|
171 |
+
def setup_for_distributed(is_master):
|
172 |
+
"""
|
173 |
+
This function disables printing when not in master process
|
174 |
+
"""
|
175 |
+
builtin_print = builtins.print
|
176 |
+
|
177 |
+
def print(*args, **kwargs):
|
178 |
+
force = kwargs.pop('force', False)
|
179 |
+
force = force or (get_world_size() > 8)
|
180 |
+
if is_master or force:
|
181 |
+
now = datetime.datetime.now().time()
|
182 |
+
builtin_print('[{}] '.format(now), end='') # print with time stamp
|
183 |
+
builtin_print(*args, **kwargs)
|
184 |
+
|
185 |
+
builtins.print = print
|
186 |
+
|
187 |
+
|
188 |
+
def is_dist_avail_and_initialized():
|
189 |
+
if not dist.is_available():
|
190 |
+
return False
|
191 |
+
if not dist.is_initialized():
|
192 |
+
return False
|
193 |
+
return True
|
194 |
+
|
195 |
+
|
196 |
+
def get_world_size():
|
197 |
+
if not is_dist_avail_and_initialized():
|
198 |
+
return 1
|
199 |
+
return dist.get_world_size()
|
200 |
+
|
201 |
+
|
202 |
+
def get_rank():
|
203 |
+
if not is_dist_avail_and_initialized():
|
204 |
+
return 0
|
205 |
+
return dist.get_rank()
|
206 |
+
|
207 |
+
|
208 |
+
def is_main_process():
|
209 |
+
return get_rank() == 0
|
210 |
+
|
211 |
+
|
212 |
+
def save_on_master(*args, **kwargs):
|
213 |
+
if is_main_process():
|
214 |
+
torch.save(*args, **kwargs)
|
215 |
+
|
216 |
+
|
217 |
+
def init_distributed_mode(args):
|
218 |
+
if args.dist_on_itp:
|
219 |
+
args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
|
220 |
+
args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
|
221 |
+
args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
|
222 |
+
args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
|
223 |
+
os.environ['LOCAL_RANK'] = str(args.gpu)
|
224 |
+
os.environ['RANK'] = str(args.rank)
|
225 |
+
os.environ['WORLD_SIZE'] = str(args.world_size)
|
226 |
+
# ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
|
227 |
+
elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
|
228 |
+
args.rank = int(os.environ["RANK"])
|
229 |
+
args.world_size = int(os.environ['WORLD_SIZE'])
|
230 |
+
args.gpu = int(os.environ['LOCAL_RANK'])
|
231 |
+
elif 'SLURM_PROCID' in os.environ:
|
232 |
+
args.rank = int(os.environ['SLURM_PROCID'])
|
233 |
+
args.gpu = args.rank % torch.cuda.device_count()
|
234 |
+
else:
|
235 |
+
print('Not using distributed mode')
|
236 |
+
setup_for_distributed(is_master=True) # hack
|
237 |
+
args.distributed = False
|
238 |
+
return
|
239 |
+
|
240 |
+
args.distributed = True
|
241 |
+
|
242 |
+
torch.cuda.set_device(args.gpu)
|
243 |
+
args.dist_backend = 'nccl'
|
244 |
+
print('| distributed init (rank {}): {}, gpu {}'.format(
|
245 |
+
args.rank, args.dist_url, args.gpu), flush=True)
|
246 |
+
torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
|
247 |
+
world_size=args.world_size, rank=args.rank)
|
248 |
+
torch.distributed.barrier()
|
249 |
+
setup_for_distributed(args.rank == 0)
|
250 |
+
|
251 |
+
|
252 |
+
class NativeScalerWithGradNormCount:
|
253 |
+
state_dict_key = "amp_scaler"
|
254 |
+
|
255 |
+
def __init__(self):
|
256 |
+
self._scaler = torch.cuda.amp.GradScaler()
|
257 |
+
|
258 |
+
def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True, verbose=False):
|
259 |
+
self._scaler.scale(loss).backward(create_graph=create_graph)
|
260 |
+
if update_grad:
|
261 |
+
if clip_grad is not None:
|
262 |
+
assert parameters is not None
|
263 |
+
self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place
|
264 |
+
norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
|
265 |
+
else:
|
266 |
+
self._scaler.unscale_(optimizer)
|
267 |
+
norm = get_grad_norm_(parameters)
|
268 |
+
self._scaler.step(optimizer)
|
269 |
+
self._scaler.update()
|
270 |
+
else:
|
271 |
+
norm = None
|
272 |
+
if verbose:
|
273 |
+
print('norm:', norm, 'clip:', clip_grad)
|
274 |
+
return norm
|
275 |
+
|
276 |
+
def state_dict(self):
|
277 |
+
return self._scaler.state_dict()
|
278 |
+
|
279 |
+
def load_state_dict(self, state_dict):
|
280 |
+
self._scaler.load_state_dict(state_dict)
|
281 |
+
|
282 |
+
|
283 |
+
def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
|
284 |
+
if isinstance(parameters, torch.Tensor):
|
285 |
+
parameters = [parameters]
|
286 |
+
parameters = [p for p in parameters if p.grad is not None]
|
287 |
+
norm_type = float(norm_type)
|
288 |
+
if len(parameters) == 0:
|
289 |
+
return torch.tensor(0.)
|
290 |
+
device = parameters[0].grad.device
|
291 |
+
if norm_type == inf:
|
292 |
+
total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
|
293 |
+
else:
|
294 |
+
total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
|
295 |
+
return total_norm
|
296 |
+
|
297 |
+
|
298 |
+
def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler):
|
299 |
+
output_dir = Path(args.output_dir)
|
300 |
+
epoch_name = f'{epoch:05d}'
|
301 |
+
if loss_scaler is not None:
|
302 |
+
checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)]
|
303 |
+
for checkpoint_path in checkpoint_paths:
|
304 |
+
to_save = {
|
305 |
+
'model': model_without_ddp.state_dict(),
|
306 |
+
'optimizer': optimizer.state_dict(),
|
307 |
+
'epoch': epoch,
|
308 |
+
'scaler': loss_scaler.state_dict(),
|
309 |
+
'args': args,
|
310 |
+
}
|
311 |
+
|
312 |
+
save_on_master(to_save, checkpoint_path)
|
313 |
+
else:
|
314 |
+
client_state = {'epoch': epoch}
|
315 |
+
model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % epoch_name, client_state=client_state)
|
316 |
+
|
317 |
+
|
318 |
+
def load_model(args, model_without_ddp, optimizer, loss_scaler):
|
319 |
+
if args.resume:
|
320 |
+
if args.resume.startswith('https'):
|
321 |
+
checkpoint = torch.hub.load_state_dict_from_url(
|
322 |
+
args.resume, map_location='cpu', check_hash=True)
|
323 |
+
else:
|
324 |
+
checkpoint = torch.load(args.resume, map_location='cpu')
|
325 |
+
print("Resume checkpoint %s" % args.resume)
|
326 |
+
print(model_without_ddp.load_state_dict(checkpoint['model'], strict=False))
|
327 |
+
if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval):
|
328 |
+
optimizer.load_state_dict(checkpoint['optimizer'])
|
329 |
+
args.start_epoch = checkpoint['epoch'] + 1
|
330 |
+
if 'scaler' in checkpoint:
|
331 |
+
print(loss_scaler.load_state_dict(checkpoint['scaler']))
|
332 |
+
print("With optim & sched!")
|
333 |
+
print("start epoch:", args.start_epoch)
|
334 |
+
|
335 |
+
|
336 |
+
def all_reduce_mean(x):
|
337 |
+
world_size = get_world_size()
|
338 |
+
if world_size > 1:
|
339 |
+
x_reduce = torch.tensor(x).cuda()
|
340 |
+
dist.all_reduce(x_reduce)
|
341 |
+
x_reduce /= world_size
|
342 |
+
return x_reduce.item()
|
343 |
+
else:
|
344 |
+
return x
|
345 |
+
|
346 |
+
|
347 |
+
+import torch.distributed as dist
+
+def get_world_size():
+    """
+    Return the number of processes in the current distributed job
+    (1 if torch.distributed is unavailable or uninitialized).
+    """
+    if not dist.is_available():
+        return 1
+    if not dist.is_initialized():
+        return 1
+    return dist.get_world_size()
+
+
+# def all_gather_unaligned(data):
+#     """
+#     Run all_gather on arbitrary picklable data (not necessarily tensors).
+#     Args:
+#         data: any picklable object
+#         group: a torch process group. By default, will use a group which
+#             contains all ranks on gloo backend.
+#     Returns:
+#         list[data]: list of data gathered from each rank
+#     """
+#     print('world', get_world_size())
+#     if get_world_size() == 1:
+#         return [data]
+
+#     # receiving Tensor from all ranks
+#     tensor_list = [
+#         torch.zeros_like(data) for _ in range(get_world_size())
+#     ]
+#     dist.all_gather(tensor_list, data)
+#     for tl in tensor_list:
+#         print(tl)
+#         print(tl.shape)
+#     return tensor_list
+
+import pickle
+def _serialize_to_tensor(data, group):
+    """
+    Serialize arbitrary picklable data to a ByteTensor. Note that only the
+    `gloo` and `nccl` backends are supported.
+    Args:
+        data (data): data to be serialized.
+        group (group): pytorch dist group.
+    Returns:
+        tensor (ByteTensor): the serialized tensor.
+    """
+
+    backend = dist.get_backend(group)
+    assert backend in ["gloo", "nccl"]
+    device = torch.device("cpu" if backend == "gloo" else "cuda")
+
+    buffer = pickle.dumps(data)
+    if len(buffer) > 1024 ** 3:
+        print(
+            "Rank {} trying to all-gather {:.2f} GB of data on device {}".format(
+                get_rank(), len(buffer) / (1024 ** 3), device
+            )
+        )
+    storage = torch.ByteStorage.from_buffer(buffer)
+    tensor = torch.ByteTensor(storage).to(device=device)
+    return tensor
+
+import functools
+@functools.lru_cache()
+def _get_global_gloo_group():
+    """
+    Return a process group based on gloo backend, containing all the ranks.
+    The result is cached.
+    Returns:
+        (group): pytorch dist group.
+    """
+    if dist.get_backend() == "nccl":
+        return dist.new_group(backend="gloo")
+    else:
+        return dist.group.WORLD
+
+
+def _pad_to_largest_tensor(tensor, group):
+    """
+    Pad the tensors from different GPUs to the size of the largest one.
+    Args:
+        tensor (tensor): tensor to pad.
+        group (group): pytorch dist group.
+    Returns:
+        list[int]: size of the tensor, on each rank
+        Tensor: padded tensor that has the max size
+    """
+    world_size = dist.get_world_size(group=group)
+    assert (
+        world_size >= 1
+    ), "comm.gather/all_gather must be called from ranks within the given group!"
+    local_size = torch.tensor(
+        [tensor.numel()], dtype=torch.int64, device=tensor.device
+    )
+    size_list = [
+        torch.zeros([1], dtype=torch.int64, device=tensor.device)
+        for _ in range(world_size)
+    ]
+    dist.all_gather(size_list, local_size, group=group)
+    size_list = [int(size.item()) for size in size_list]
+
+    max_size = max(size_list)
+
+    # we pad the tensor because torch all_gather does not support
+    # gathering tensors of different shapes
+    if local_size != max_size:
+        padding = torch.zeros(
+            (max_size - local_size,), dtype=torch.uint8, device=tensor.device
+        )
+        tensor = torch.cat((tensor, padding), dim=0)
+    return size_list, tensor
+
+def all_gather_unaligned(data, group=None):
+    """
+    Run all_gather on arbitrary picklable data (not necessarily tensors).
+    Args:
+        data: any picklable object
+        group: a torch process group. By default, will use a group which
+            contains all ranks on gloo backend.
+    Returns:
+        list[data]: list of data gathered from each rank
+    """
+    if get_world_size() == 1:
+        return [data]
+    if group is None:
+        group = _get_global_gloo_group()
+    if dist.get_world_size(group) == 1:
+        return [data]
+
+    tensor = _serialize_to_tensor(data, group)
+
+    size_list, tensor = _pad_to_largest_tensor(tensor, group)
+    max_size = max(size_list)
+
+    # receiving Tensor from all ranks
+    tensor_list = [
+        torch.empty((max_size,), dtype=torch.uint8, device=tensor.device)
+        for _ in size_list
+    ]
+    dist.all_gather(tensor_list, tensor, group=group)
+
+    data_list = []
+    for size, tensor in zip(size_list, tensor_list):
+        buffer = tensor.cpu().numpy().tobytes()[:size]
+        # note: the trailing .to(data.device) assumes the gathered objects
+        # are tensors, despite the "any picklable object" docstring
+        data_list.append(pickle.loads(buffer).to(data.device))
+
+    return data_list
+
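A sketch of gathering variable-length data across ranks with `all_gather_unaligned`; given the `.to(data.device)` noted above, tensor payloads are the intended use in practice:

    import torch
    import torch.distributed as dist

    # run under torchrun with world_size > 1; each rank contributes a tensor
    # of a different length, which plain dist.all_gather cannot handle
    local = torch.arange(dist.get_rank() + 1)    # rank 0: len 1, rank 1: len 2, ...
    gathered = all_gather_unaligned(local)       # list with one tensor per rank
    print([t.shape for t in gathered])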
util/pos_embed.py
ADDED
@@ -0,0 +1,117 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# Position embedding utils
+# --------------------------------------------------------
+
+import numpy as np
+
+import torch
+
+# --------------------------------------------------------
+# 2D sine-cosine position embedding
+# References:
+# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
+# MoCo v3: https://github.com/facebookresearch/moco-v3
+# --------------------------------------------------------
+def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
+    """
+    grid_size: int of the grid height and width
+    return:
+    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+    """
+    grid_h = np.arange(grid_size, dtype=np.float32)
+    grid_w = np.arange(grid_size, dtype=np.float32)
+    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
+    grid = np.stack(grid, axis=0)
+
+    grid = grid.reshape([2, 1, grid_size, grid_size])
+    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+    if cls_token:
+        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
+    return pos_embed
+
+
+def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+    assert embed_dim % 2 == 0
+
+    # use half of dimensions to encode grid_h
+    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
+    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
+
+    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
+    return emb
+
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+    """
+    embed_dim: output dimension for each position
+    pos: a list of positions to be encoded: size (M,)
+    out: (M, D)
+    """
+    assert embed_dim % 2 == 0
+    omega = np.arange(embed_dim // 2, dtype=np.float64)  # np.float is removed in NumPy >= 1.24
+    omega /= embed_dim / 2.
+    omega = 1. / 10000**omega  # (D/2,)
+
+    pos = pos.reshape(-1)  # (M,)
+    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product
+
+    emb_sin = np.sin(out)  # (M, D/2)
+    emb_cos = np.cos(out)  # (M, D/2)
+
+    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
+    return emb
+
+def get_1d_sincos_pos_embed_from_grid_torch(embed_dim, pos):
+    """
+    embed_dim: output dimension for each position
+    pos: a list of positions to be encoded: size (M,)
+    out: (M, D)
+    """
+    assert embed_dim % 2 == 0
+    omega = torch.arange(embed_dim // 2, device=pos.device).float()
+    omega /= embed_dim / 2.
+    omega = 1. / 10000**omega  # (D/2,)
+
+    pos = pos.reshape(-1)  # (M,)
+    out = torch.einsum('m,d->md', pos, omega)  # (M, D/2), outer product
+
+    emb_sin = torch.sin(out)  # (M, D/2)
+    emb_cos = torch.cos(out)  # (M, D/2)
+
+    emb = torch.cat([emb_sin, emb_cos], axis=1)  # (M, D)
+    return emb
+
+
+
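A quick shape check of the sin-cos helpers above; since the output is deterministic, this doubles as a cheap sanity test:

    import numpy as np

    pos_embed = get_2d_sincos_pos_embed(embed_dim=768, grid_size=14, cls_token=True)
    print(pos_embed.shape)                   # (1 + 14*14, 768) == (197, 768)
    print(np.abs(pos_embed).max() <= 1.0)    # sin/cos values stay in [-1, 1]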
+# --------------------------------------------------------
+# Interpolate position embeddings for high-resolution
+# References:
+# DeiT: https://github.com/facebookresearch/deit
+# --------------------------------------------------------
+def interpolate_pos_embed(model, checkpoint_model):
+    if 'pos_embed' in checkpoint_model:
+        pos_embed_checkpoint = checkpoint_model['pos_embed']
+        embedding_size = pos_embed_checkpoint.shape[-1]
+        num_patches = model.patch_embed.num_patches
+        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
+        # height (== width) for the checkpoint position embedding
+        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
+        # height (== width) for the new position embedding
+        new_size = int(num_patches ** 0.5)
+        # class_token and dist_token are kept unchanged
+        if orig_size != new_size:
+            print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
+            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+            # only the position tokens are interpolated
+            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
+            pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
+            pos_tokens = torch.nn.functional.interpolate(
+                pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
+            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
+            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+            checkpoint_model['pos_embed'] = new_pos_embed
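`interpolate_pos_embed` is meant to run on a checkpoint's state dict before `load_state_dict`, so a model built for a different input resolution can reuse pretrained position embeddings. A hedged sketch (the checkpoint path and `model` are illustrative):

    import torch

    # e.g. checkpoint trained on a 14x14 patch grid, model built for 28x28
    checkpoint = torch.load('pretrained.pth', map_location='cpu')
    checkpoint_model = checkpoint['model']
    interpolate_pos_embed(model, checkpoint_model)   # resizes checkpoint_model['pos_embed'] in place
    model.load_state_dict(checkpoint_model, strict=False)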
weights/co3dv2_all_categories.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ca861bee4c2cb27acc6855da34227ce7026cf9eb275171da3c5a33976b3d86bd
+size 2423688373
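These three lines are a Git LFS pointer, not the checkpoint itself; the ~2.3 GB weight file is materialized by `git lfs pull` (or when downloading through the Hub). A sketch for inspecting it once fetched, assuming it follows the `save_model` layout from util/misc.py above:

    # git lfs pull --include weights/co3dv2_all_categories.pth
    import torch

    ckpt = torch.load('weights/co3dv2_all_categories.pth', map_location='cpu')
    print(list(ckpt.keys()))   # expected to include 'model' if written by save_model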