Spaces (status: Build error)
zxhuang1698 committed
Commit 414b431 • Parent(s): 1a618eb
initial commit
Files changed:
- README.md +5 -6
- app.py +303 -0
- examples/images/armchair.png +0 -0
- examples/images/bolt.png +0 -0
- examples/images/bucket.png +0 -0
- examples/images/case.png +0 -0
- examples/images/dispenser.png +0 -0
- examples/images/hat.png +0 -0
- examples/images/teddy_bear.png +0 -0
- examples/images/tiger.png +0 -0
- examples/images/toy.png +0 -0
- examples/images/wedding_cake.png +0 -0
- examples/masks/armchair.png +0 -0
- examples/masks/bolt.png +0 -0
- examples/masks/bucket.png +0 -0
- examples/masks/case.png +0 -0
- examples/masks/dispenser.png +0 -0
- examples/masks/hat.png +0 -0
- examples/masks/teddy_bear.png +0 -0
- examples/masks/tiger.png +0 -0
- examples/masks/toy.png +0 -0
- examples/masks/wedding_cake.png +0 -0
- model/compute_graph/graph_depth.py +106 -0
- model/compute_graph/graph_shape.py +202 -0
- model/depth/__init__.py +0 -0
- model/depth/base_model.py +17 -0
- model/depth/blocks.py +343 -0
- model/depth/dpt_depth.py +123 -0
- model/depth/midas_loss.py +185 -0
- model/depth/vit.py +492 -0
- model/depth_engine.py +445 -0
- model/shape/implicit.py +288 -0
- model/shape/rgb_enc.py +137 -0
- model/shape/seen_coord_enc.py +195 -0
- model/shape_engine.py +598 -0
- options/depth.yaml +72 -0
- options/shape.yaml +110 -0
- requirements.txt +95 -0
- utils/camera.py +230 -0
- utils/eval_3D.py +133 -0
- utils/eval_depth.py +110 -0
- utils/layers.py +147 -0
- utils/loss.py +42 -0
- utils/options.py +127 -0
- utils/pos_embed.py +118 -0
- utils/util.py +413 -0
- utils/util_vis.py +511 -0
- weights/.gitignore +4 -0
README.md
CHANGED
@@ -1,13 +1,12 @@
 ---
 title: ZeroShape
-emoji:
+emoji: 🔥
 colorFrom: green
 colorTo: blue
 sdk: gradio
-sdk_version: 4.
+sdk_version: 4.5.0
+python_version: 3.10
 app_file: app.py
-pinned:
+pinned: true
 license: mit
----
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+---
app.py
ADDED
@@ -0,0 +1,303 @@
import gradio as gr
import torch
import torchvision.transforms.functional as torchvision_F
import numpy as np
import os
import shutil
import importlib
import trimesh
import tempfile
import subprocess
import utils.options as options
import shlex
import time
import rembg

from utils.util import EasyDict as edict
from PIL import Image
from utils.eval_3D import get_dense_3D_grid, compute_level_grid, convert_to_explicit

def get_1d_bounds(arr):
    nz = np.flatnonzero(arr)
    return nz[0], nz[-1]

def get_bbox_from_mask(mask, thr):
    masks_for_box = (mask > thr).astype(np.float32)
    assert masks_for_box.sum() > 0, "Empty mask!"
    x0, x1 = get_1d_bounds(masks_for_box.sum(axis=-2))
    y0, y1 = get_1d_bounds(masks_for_box.sum(axis=-1))

    return x0, y0, x1, y1

def square_crop(image, bbox, crop_ratio=1.):
    x1, y1, x2, y2 = bbox
    h, w = y2-y1, x2-x1
    yc, xc = (y1+y2)/2, (x1+x2)/2
    S = max(h, w)*1.2
    scale = S*crop_ratio
    image = torchvision_F.crop(image, top=int(yc-scale/2), left=int(xc-scale/2), height=int(scale), width=int(scale))
    return image

def preprocess_image(opt, image, bbox):
    image = square_crop(image, bbox=bbox)
    if image.size[0] != opt.W or image.size[1] != opt.H:
        image = image.resize((opt.W, opt.H))
    image = torchvision_F.to_tensor(image)
    rgb, mask = image[:3], image[3:]
    if opt.data.bgcolor is not None:
        # replace background color using mask
        rgb = rgb * mask + opt.data.bgcolor * (1 - mask)
    mask = (mask > 0.5).float()
    return rgb, mask

def get_image(opt, image_fname, mask_fname):
    image = Image.open(image_fname).convert("RGB")
    mask = Image.open(mask_fname).convert("L")
    mask_np = np.array(mask)

    #binarize
    mask_np[mask_np <= 127] = 0
    mask_np[mask_np >= 127] = 1.0

    image = Image.merge("RGBA", (*image.split(), mask))
    bbox = get_bbox_from_mask(mask_np, 0.5)
    rgb_input_map, mask_input_map = preprocess_image(opt, image, bbox=bbox)
    return rgb_input_map, mask_input_map

def get_intr(opt):
    # load camera
    f = 1.3875
    K = torch.tensor([[f*opt.W, 0, opt.W/2],
                      [0, f*opt.H, opt.H/2],
                      [0, 0, 1]]).float()
    return K

def get_pixel_grid(H, W, device='cuda'):
    y_range = torch.arange(H, dtype=torch.float32).to(device)
    x_range = torch.arange(W, dtype=torch.float32).to(device)
    Y, X = torch.meshgrid(y_range, x_range, indexing='ij')
    Z = torch.ones_like(Y).to(device)
    xyz_grid = torch.stack([X, Y, Z], dim=-1).view(-1, 3)
    return xyz_grid

def unproj_depth(depth, intr):
    '''
    depth: [B, H, W]
    intr: [B, 3, 3]
    '''
    batch_size, H, W = depth.shape
    intr = intr.to(depth.device)

    # [B, 3, 3]
    K_inv = torch.linalg.inv(intr).float()
    # [1, H*W, 3]
    pixel_grid = get_pixel_grid(H, W, depth.device).unsqueeze(0)
    # [B, H*W, 3]
    pixel_grid = pixel_grid.repeat(batch_size, 1, 1)
    # [B, 3, H*W]
    ray_dirs = K_inv @ pixel_grid.permute(0, 2, 1).contiguous()
    # [B, H*W, 3], in camera coordinates
    seen_points = ray_dirs.permute(0, 2, 1).contiguous() * depth.view(batch_size, H*W, 1)
    # [B, H, W, 3]
    seen_points = seen_points.view(batch_size, H, W, 3)
    return seen_points

def prepare_data(opt, image_path, mask_path):
    var = edict()
    rgb_input_map, mask_input_map = get_image(opt, image_path, mask_path)
    intr = get_intr(opt)
    var.rgb_input_map = rgb_input_map.unsqueeze(0).to(opt.device)
    var.mask_input_map = mask_input_map.unsqueeze(0).to(opt.device)
    var.intr = intr.unsqueeze(0).to(opt.device)
    var.idx = torch.tensor([0]).to(opt.device).long()
    var.pose_gt = False
    return var

@torch.no_grad()
def marching_cubes(opt, var, impl_network, visualize_attn=False):
    points_3D = get_dense_3D_grid(opt, var) # [B, N, N, N, 3]
    level_vox, attn_vis = compute_level_grid(opt, impl_network, var.latent_depth, var.latent_semantic,
                                             points_3D, var.rgb_input_map, visualize_attn)
    if attn_vis: var.attn_vis = attn_vis
    # occ_grids: a list of length B, each is [N, N, N]
    *level_grids, = level_vox.cpu().numpy()
    meshes = convert_to_explicit(opt, level_grids, isoval=0.5, to_pointcloud=False)
    var.mesh_pred = meshes
    return var

@torch.no_grad()
def infer_sample(opt, var, graph):
    var = graph.forward(opt, var, training=False, get_loss=False)
    var = marching_cubes(opt, var, graph.impl_network, visualize_attn=True)
    return var.mesh_pred[0]

def infer(input_image_path, input_mask_path):
    opt_cmd = options.parse_arguments(["--yaml=options/shape.yaml", "--datadir=examples", "--eval.vox_res=128", "--ckpt=weights/shape.ckpt"])
    opt = options.set(opt_cmd=opt_cmd, safe_check=False)
    opt.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # build model
    print("Building model...")
    opt.pretrain.depth = None
    opt.arch.depth.pretrained = None
    module = importlib.import_module("model.compute_graph.graph_shape")
    graph = module.Graph(opt).to(opt.device)

    # download checkpoint
    if not os.path.isfile(opt.ckpt):
        print("Downloading checkpoint...")
        subprocess.run(
            shlex.split(
                "wget -q -O weights/shape.ckpt https://www.dropbox.com/scl/fi/hv3w9z59dqytievwviko4/shape.ckpt?rlkey=a2gut89kavrldmnt8b3df92oi&dl=0"
            )
        )

    # wait if the checkpoint is still downloading
    while not os.path.isfile(opt.ckpt):
        time.sleep(1)

    # load checkpoint
    print("Loading checkpoint...")
    checkpoint = torch.load(opt.ckpt, map_location=torch.device(opt.device))
    graph.load_state_dict(checkpoint["graph"], strict=True)
    graph.eval()

    # load the data
    print("Loading data...")
    var = prepare_data(opt, input_image_path, input_mask_path)

    # create the save dir
    save_folder = os.path.join(opt.datadir, 'preds')
    if os.path.isdir(save_folder):
        shutil.rmtree(save_folder)
    os.makedirs(save_folder)
    opt.output_path = opt.datadir

    # inference the model and save the results
    print("Inferencing...")
    mesh_pred = infer_sample(opt, var, graph)
    # rotate the mesh upside down
    mesh_pred.apply_transform(trimesh.transformations.rotation_matrix(np.pi, [1, 0, 0]))
    mesh_path = tempfile.NamedTemporaryFile(suffix=".glb", delete=False)
    mesh_pred.export(mesh_path.name, file_type="glb")
    return mesh_path.name

def infer_wrapper_mask(input_image_path, input_mask_path):
    return infer(input_image_path, input_mask_path)

def infer_wrapper_nomask(input_image_path):
    input = Image.open(input_image_path)
    segmented = rembg.remove(input)
    mask = segmented.split()[-1]
    mask_path = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
    mask.save(mask_path.name)
    return infer(input_image_path, mask_path.name), mask_path.name


def assert_input_image(input_image):
    if input_image is None:
        raise gr.Error("No image selected or uploaded!")

def assert_mask_image(input_mask):
    if input_mask is None:
        raise gr.Error("No mask selected or uploaded! Please check the box if you do not have the mask.")

def demo_gradio():
    with gr.Blocks(analytics_enabled=False) as demo_ui:

        # HEADERS
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown('# ZeroShape: Regression-based Zero-shot Shape Reconstruction')
                gr.Markdown("[\[Arxiv\]](https://arxiv.org/pdf/2312.14198.pdf) | [\[Project\]](https://zixuanh.com/projects/zeroshape.html) | [\[GitHub\]](https://github.com/zxhuang1698/ZeroShape)")
                gr.Markdown("Please switch to the \"Estimated Mask\" tab if you do not have the foreground mask. The demo will try to estimate the mask for you.")

        # with mask
        with gr.Tab("Groundtruth Mask"):
            with gr.Row():
                input_image_tab1 = gr.Image(label="Input Image", image_mode="RGB", sources="upload", type="filepath", elem_id="content_image", width=300)
                mask_tab1 = gr.Image(label="Foreground Mask", image_mode="RGB", sources="upload", type="filepath", elem_id="content_image", width=300)
                output_mesh_tab1 = gr.Model3D(label="Output Mesh")
            with gr.Row():
                submit_tab1 = gr.Button('Reconstruct', elem_id="recon_button_tab1", variant='primary')
            # examples
            with gr.Row():
                examples_tab1 = [
                    ['examples/images/armchair.png', 'examples/masks/armchair.png'],
                    ['examples/images/bolt.png', 'examples/masks/bolt.png'],
                    ['examples/images/bucket.png', 'examples/masks/bucket.png'],
                    ['examples/images/case.png', 'examples/masks/case.png'],
                    ['examples/images/dispenser.png', 'examples/masks/dispenser.png'],
                    ['examples/images/hat.png', 'examples/masks/hat.png'],
                    ['examples/images/teddy_bear.png', 'examples/masks/teddy_bear.png'],
                    ['examples/images/tiger.png', 'examples/masks/tiger.png'],
                    ['examples/images/toy.png', 'examples/masks/toy.png'],
                    ['examples/images/wedding_cake.png', 'examples/masks/wedding_cake.png'],
                ]
                gr.Examples(
                    examples=examples_tab1,
                    inputs=[input_image_tab1, mask_tab1],
                    outputs=[output_mesh_tab1],
                    fn=infer_wrapper_mask,
                    cache_examples=False#os.getenv('SYSTEM') == 'spaces',
                )
        # without mask
        with gr.Tab("Estimated Mask"):
            with gr.Row():
                input_image_tab2 = gr.Image(label="Input Image", image_mode="RGB", sources="upload", type="filepath", elem_id="content_image", width=300)
                mask_tab2 = gr.Image(label="Foreground Mask", image_mode="RGB", sources="upload", type="filepath", elem_id="content_image", width=300)
                output_mesh_tab2 = gr.Model3D(label="Output Mesh")
            with gr.Row():
                submit_tab2 = gr.Button('Reconstruct', elem_id="recon_button_tab2", variant='primary')
            # examples
            with gr.Row():
                examples_tab2 = [
                    ['examples/images/armchair.png'],
                    ['examples/images/bolt.png'],
                    ['examples/images/bucket.png'],
                    ['examples/images/case.png'],
                    ['examples/images/dispenser.png'],
                    ['examples/images/hat.png'],
                    ['examples/images/teddy_bear.png'],
                    ['examples/images/tiger.png'],
                    ['examples/images/toy.png'],
                    ['examples/images/wedding_cake.png'],
                ]
                gr.Examples(
                    examples=examples_tab2,
                    inputs=[input_image_tab2],
                    outputs=[output_mesh_tab2, mask_tab2],
                    fn=infer_wrapper_nomask,
                    cache_examples=False#os.getenv('SYSTEM') == 'spaces',
                )

        submit_tab1.click(
            fn=assert_input_image,
            inputs=[input_image_tab1],
            queue=False
        ).success(
            fn=assert_mask_image,
            inputs=[mask_tab1],
            queue=False
        ).success(
            fn=infer_wrapper_mask,
            inputs=[input_image_tab1, mask_tab1],
            outputs=[output_mesh_tab1],
        )

        submit_tab2.click(
            fn=assert_input_image,
            inputs=[input_image_tab2],
            queue=False
        ).success(
            fn=infer_wrapper_nomask,
            inputs=[input_image_tab2],
            outputs=[output_mesh_tab2, mask_tab2],
        )

    return demo_ui

if __name__ == "__main__":
    demo_ui = demo_gradio()
    demo_ui.queue(max_size=10)
    demo_ui.launch()
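All of the Gradio callbacks above funnel into infer, so the same reconstruction path can also be driven headlessly. A minimal sketch, assuming it runs from the Space's root directory with the checkpoint already at weights/shape.ckpt and the bundled example images in place (the output file name is arbitrary):

# sketch only: reuses the functions defined in app.py above
import shutil
from app import infer_wrapper_nomask

# estimate the mask with rembg, reconstruct, and keep the resulting mesh
mesh_path, mask_path = infer_wrapper_nomask("examples/images/teddy_bear.png")
shutil.copy(mesh_path, "teddy_bear.glb")
print("mesh saved to teddy_bear.glb, estimated mask at", mask_path)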
examples/images/armchair.png
ADDED
examples/images/bolt.png
ADDED
examples/images/bucket.png
ADDED
examples/images/case.png
ADDED
examples/images/dispenser.png
ADDED
examples/images/hat.png
ADDED
examples/images/teddy_bear.png
ADDED
examples/images/tiger.png
ADDED
examples/images/toy.png
ADDED
examples/images/wedding_cake.png
ADDED
examples/masks/armchair.png
ADDED
examples/masks/bolt.png
ADDED
examples/masks/bucket.png
ADDED
examples/masks/case.png
ADDED
examples/masks/dispenser.png
ADDED
examples/masks/hat.png
ADDED
examples/masks/teddy_bear.png
ADDED
examples/masks/tiger.png
ADDED
examples/masks/toy.png
ADDED
examples/masks/wedding_cake.png
ADDED
model/compute_graph/graph_depth.py
ADDED
@@ -0,0 +1,106 @@
import torch
import torch.nn as nn

from utils.util import EasyDict as edict
from utils.loss import Loss
from model.depth.dpt_depth import DPTDepthModel
from utils.layers import Bottleneck_Conv
from utils.camera import unproj_depth, valid_norm_fac

class Graph(nn.Module):

    def __init__(self, opt):
        super().__init__()
        # define the depth pred model based on omnidata
        self.dpt_depth = DPTDepthModel(backbone='vitb_rn50_384')
        if opt.arch.depth.pretrained is not None:
            checkpoint = torch.load(opt.arch.depth.pretrained, map_location="cuda:{}".format(opt.device))
            state_dict = checkpoint['model_state_dict']
            self.dpt_depth.load_state_dict(state_dict)

        if opt.loss_weight.intr is not None:
            self.intr_feat_channels = 768
            self.intr_head = nn.Sequential(
                Bottleneck_Conv(self.intr_feat_channels, kernel_size=3),
                Bottleneck_Conv(self.intr_feat_channels, kernel_size=3),
            )
            self.intr_pool = nn.AdaptiveAvgPool2d((1, 1))
            self.intr_proj = nn.Linear(self.intr_feat_channels, 3)
            # init the last linear layer so it outputs zeros
            nn.init.zeros_(self.intr_proj.weight)
            nn.init.zeros_(self.intr_proj.bias)

        self.loss_fns = Loss(opt)

    def intr_param2mtx(self, opt, intr_params):
        '''
        Parameters:
            opt: config
            intr_params: [B, 3], [scale_f, delta_cx, delta_cy]
        Return:
            intr: [B, 3, 3]
        '''
        batch_size = len(intr_params)
        f = 1.3875
        intr = torch.zeros(3, 3).float().to(intr_params.device).unsqueeze(0).repeat(batch_size, 1, 1)
        intr[:, 2, 2] += 1
        # scale the focal length
        # range: [-1, 1], symmetric
        scale_f = torch.tanh(intr_params[:, 0])
        # range: [1/4, 4], symmetric
        scale_f = torch.pow(4., scale_f)
        intr[:, 0, 0] += f * opt.W * scale_f
        intr[:, 1, 1] += f * opt.H * scale_f
        # shift the optic center, (at most to the image border)
        shift_cx = torch.tanh(intr_params[:, 1]) * opt.W / 2
        shift_cy = torch.tanh(intr_params[:, 2]) * opt.H / 2
        intr[:, 0, 2] += opt.W / 2 + shift_cx
        intr[:, 1, 2] += opt.H / 2 + shift_cy
        return intr

    def forward(self, opt, var, training=False, get_loss=True):
        batch_size = len(var.idx)

        # predict the depth map and feature maps if needed
        if opt.loss_weight.intr is None:
            var.depth_pred = self.dpt_depth(var.rgb_input_map)
        else:
            var.depth_pred, intr_feat = self.dpt_depth(var.rgb_input_map, get_feat=True)
            # predict the intrinsics
            intr_feat = self.intr_head(intr_feat)
            intr_feat = self.intr_pool(intr_feat).squeeze(-1).squeeze(-1)
            intr_params = self.intr_proj(intr_feat)
            # [B, 3, 3]
            var.intr_pred = self.intr_param2mtx(opt, intr_params)

            # project the predicted depth map to 3D points and normalize, [B, H*W, 3]
            seen_points_3D_pred = unproj_depth(opt, var.depth_pred, var.intr_pred)
            seen_points_mean_pred, seen_points_scale_pred = valid_norm_fac(seen_points_3D_pred, var.mask_input_map > 0.5)
            var.seen_points_pred = (seen_points_3D_pred - seen_points_mean_pred.unsqueeze(1)) / seen_points_scale_pred.unsqueeze(-1).unsqueeze(-1)
            var.seen_points_pred[(var.mask_input_map<=0.5).view(batch_size, -1)] = 0

            if 'depth_input_map' in var or training:
                # project the ground truth depth map to 3D points and normalize, [B, H*W, 3]
                seen_points_3D_gt = unproj_depth(opt, var.depth_input_map, var.intr)
                seen_points_mean_gt, seen_points_scale_gt = valid_norm_fac(seen_points_3D_gt, var.mask_input_map > 0.5)
                var.seen_points_gt = (seen_points_3D_gt - seen_points_mean_gt.unsqueeze(1)) / seen_points_scale_gt.unsqueeze(-1).unsqueeze(-1)
                var.seen_points_gt[(var.mask_input_map<=0.5).view(batch_size, -1)] = 0

            # record the validity mask, [B, H*W]
            var.validity_mask = (var.mask_input_map>0.5).float().view(batch_size, -1)

        # calculate the loss if needed
        if get_loss:
            loss = self.compute_loss(opt, var, training)
            return var, loss

        return var

    def compute_loss(self, opt, var, training=False):
        loss = edict()
        if opt.loss_weight.depth is not None:
            loss.depth = self.loss_fns.depth_loss(var.depth_pred, var.depth_input_map, var.mask_input_map)
        if opt.loss_weight.intr is not None:
            loss.intr = self.loss_fns.intr_loss(var.seen_points_pred, var.seen_points_gt, var.validity_mask)
        return loss
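For reference, intr_param2mtx above amounts to the following parameterization of the intrinsics, written out as plain math; p1, p2, p3 are the three raw network outputs and f = 1.3875:

\[
K = \begin{bmatrix}
f W \cdot 4^{\tanh p_1} & 0 & \tfrac{W}{2}\left(1 + \tanh p_2\right) \\
0 & f H \cdot 4^{\tanh p_1} & \tfrac{H}{2}\left(1 + \tanh p_3\right) \\
0 & 0 & 1
\end{bmatrix}
\]

so the predicted focal length can scale the default by at most a factor of 4 in either direction, and the principal point can shift at most to the image border.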
model/compute_graph/graph_shape.py
ADDED
@@ -0,0 +1,202 @@
import torch
import torch.nn as nn

from utils.util import EasyDict as edict
from utils.loss import Loss
from model.shape.implicit import Implicit
from model.shape.seen_coord_enc import CoordEncAtt, CoordEncRes
from model.shape.rgb_enc import RGBEncAtt, RGBEncRes
from model.depth.dpt_depth import DPTDepthModel
from utils.util import toggle_grad, interpolate_coordmap, get_child_state_dict
from utils.camera import unproj_depth, valid_norm_fac
from utils.layers import Bottleneck_Conv

class Graph(nn.Module):

    def __init__(self, opt):
        super().__init__()
        # define the intrinsics head
        self.intr_feat_channels = 768
        self.intr_head = nn.Sequential(
            Bottleneck_Conv(self.intr_feat_channels, kernel_size=3),
            Bottleneck_Conv(self.intr_feat_channels, kernel_size=3),
        )
        self.intr_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.intr_proj = nn.Linear(self.intr_feat_channels, 3)
        # init the last linear layer so it outputs zeros
        nn.init.zeros_(self.intr_proj.weight)
        nn.init.zeros_(self.intr_proj.bias)

        # define the depth pred model based on omnidata
        self.dpt_depth = DPTDepthModel(backbone='vitb_rn50_384')
        # load the pretrained depth model
        # when intrinsics need to be predicted we need to load that part as well
        self.load_pretrained_depth(opt)
        if opt.optim.fix_dpt:
            toggle_grad(self.dpt_depth, False)
            toggle_grad(self.intr_head, False)
            toggle_grad(self.intr_proj, False)

        # encoder that encode seen surface to impl conditioning vec
        if opt.arch.depth.encoder == 'resnet':
            opt.arch.depth.dsp = 1
            self.coord_encoder = CoordEncRes(opt)
        else:
            self.coord_encoder = CoordEncAtt(embed_dim=opt.arch.latent_dim, n_blocks=opt.arch.depth.n_blocks,
                                             num_heads=opt.arch.num_heads, win_size=opt.arch.win_size//opt.arch.depth.dsp)

        # rgb branch (not used in final model, keep here for extension)
        if opt.arch.rgb.encoder == 'resnet':
            self.rgb_encoder = RGBEncRes(opt)
        elif opt.arch.rgb.encoder == 'transformer':
            self.rgb_encoder = RGBEncAtt(img_size=opt.H, embed_dim=opt.arch.latent_dim, n_blocks=opt.arch.rgb.n_blocks,
                                         num_heads=opt.arch.num_heads, win_size=opt.arch.win_size)
        else:
            self.rgb_encoder = None

        # implicit function
        feat_res = opt.H // opt.arch.win_size
        self.impl_network = Implicit(feat_res**2, latent_dim=opt.arch.latent_dim*2 if self.rgb_encoder else opt.arch.latent_dim,
                                     semantic=self.rgb_encoder is not None, n_channels=opt.arch.impl.n_channels,
                                     n_blocks_attn=opt.arch.impl.att_blocks, n_layers_mlp=opt.arch.impl.mlp_layers,
                                     num_heads=opt.arch.num_heads, posenc_3D=opt.arch.impl.posenc_3D,
                                     mlp_ratio=opt.arch.impl.mlp_ratio, skip_in=opt.arch.impl.skip_in,
                                     pos_perlayer=opt.arch.impl.posenc_perlayer)

        # loss functions
        self.loss_fns = Loss(opt)

    def load_pretrained_depth(self, opt):
        if opt.pretrain.depth:
            # loading from our pretrained depth and intr model
            if opt.device == 0:
                print("loading dpt depth from {}...".format(opt.pretrain.depth))
            checkpoint = torch.load(opt.pretrain.depth, map_location="cuda:{}".format(opt.device))
            self.dpt_depth.load_state_dict(get_child_state_dict(checkpoint["graph"], "dpt_depth"))
            # load the intr head
            if opt.device == 0:
                print("loading pretrained intr from {}...".format(opt.pretrain.depth))
            self.intr_head.load_state_dict(get_child_state_dict(checkpoint["graph"], "intr_head"))
            self.intr_proj.load_state_dict(get_child_state_dict(checkpoint["graph"], "intr_proj"))
        elif opt.arch.depth.pretrained:
            # loading from omnidata weights
            if opt.device == 0:
                print("loading dpt depth from {}...".format(opt.arch.depth.pretrained))
            checkpoint = torch.load(opt.arch.depth.pretrained, map_location="cuda:{}".format(opt.device))
            state_dict = checkpoint['model_state_dict']
            self.dpt_depth.load_state_dict(state_dict)

    def intr_param2mtx(self, opt, intr_params):
        '''
        Parameters:
            opt: config
            intr_params: [B, 3], [scale_f, delta_cx, delta_cy]
        Return:
            intr: [B, 3, 3]
        '''
        batch_size = len(intr_params)
        f = 1.3875
        intr = torch.zeros(3, 3).float().to(intr_params.device).unsqueeze(0).repeat(batch_size, 1, 1)
        intr[:, 2, 2] += 1
        # scale the focal length
        # range: [-1, 1], symmetric
        scale_f = torch.tanh(intr_params[:, 0])
        # range: [1/4, 4], symmetric
        scale_f = torch.pow(4., scale_f)
        intr[:, 0, 0] += f * opt.W * scale_f
        intr[:, 1, 1] += f * opt.H * scale_f
        # shift the optic center, (at most to the image border)
        shift_cx = torch.tanh(intr_params[:, 1]) * opt.W / 2
        shift_cy = torch.tanh(intr_params[:, 2]) * opt.H / 2
        intr[:, 0, 2] += opt.W / 2 + shift_cx
        intr[:, 1, 2] += opt.H / 2 + shift_cy
        return intr

    def forward(self, opt, var, training=False, get_loss=True):
        batch_size = len(var.idx)

        # encode the rgb, [B, 3, H, W] -> [B, 1+H/(ws)*W/(ws), C], not used in our final model
        var.latent_semantic = self.rgb_encoder(var.rgb_input_map) if self.rgb_encoder else None

        # predict the depth map and intrinsics
        var.depth_pred, intr_feat = self.dpt_depth(var.rgb_input_map, get_feat=True)
        depth_map = var.depth_pred
        # predict the intrinsics
        intr_feat = self.intr_head(intr_feat)
        intr_feat = self.intr_pool(intr_feat).squeeze(-1).squeeze(-1)
        intr_params = self.intr_proj(intr_feat)
        # [B, 3, 3]
        var.intr_pred = self.intr_param2mtx(opt, intr_params)
        intr_forward = var.intr_pred
        # record the validity mask, [B, H*W]
        var.validity_mask = (var.mask_input_map>0.5).float().view(batch_size, -1)

        # project the depth to 3D points in view-centric frame
        # [B, H*W, 3], in camera coordinates
        seen_points_3D_pred = unproj_depth(opt, depth_map, intr_forward)
        # [B, H*W, 3], [B, 1, H, W] (boolean) -> [B, 3], [B]
        seen_points_mean_pred, seen_points_scale_pred = valid_norm_fac(seen_points_3D_pred, var.mask_input_map > 0.5)
        # normalize the seen surface, [B, H*W, 3]
        var.seen_points = (seen_points_3D_pred - seen_points_mean_pred.unsqueeze(1)) / seen_points_scale_pred.unsqueeze(-1).unsqueeze(-1)
        var.seen_points[(var.mask_input_map<=0.5).view(batch_size, -1)] = 0
        # [B, 3, H, W]
        seen_3D_map = var.seen_points.view(batch_size, opt.H, opt.W, 3).permute(0, 3, 1, 2).contiguous()
        seen_3D_dsp, mask_dsp = interpolate_coordmap(seen_3D_map, var.mask_input_map, (opt.H//opt.arch.depth.dsp, opt.W//opt.arch.depth.dsp))

        # encode the depth, [B, 1, H/k, W/k] -> [B, 1+H/(ws)*W/(ws), C]
        if opt.arch.depth.encoder == 'resnet':
            var.latent_depth = self.coord_encoder(seen_3D_dsp, mask_dsp)
        else:
            var.latent_depth = self.coord_encoder(seen_3D_dsp.permute(0, 2, 3, 1).contiguous(), mask_dsp.squeeze(1)>0.5)

        var.pose = var.pose_gt
        # forward for loss calculation (only during training)
        if 'gt_sample_points' in var and 'gt_sample_sdf' in var:
            with torch.no_grad():
                # get the normalizing fac based on the GT seen surface
                # project the GT depth to 3D points in view-centric frame
                # [B, H*W, 3], in camera coordinates
                seen_points_3D_gt = unproj_depth(opt, var.depth_input_map, var.intr)
                # [B, H*W, 3], [B, 1, H, W] (boolean) -> [B, 3], [B]
                seen_points_mean_gt, seen_points_scale_gt = valid_norm_fac(seen_points_3D_gt, var.mask_input_map > 0.5)
                var.seen_points_gt = (seen_points_3D_gt - seen_points_mean_gt.unsqueeze(1)) / seen_points_scale_gt.unsqueeze(-1).unsqueeze(-1)
                var.seen_points_gt[(var.mask_input_map<=0.5).view(batch_size, -1)] = 0

                # transform the GT points accordingly
                # [B, 3, 3]
                R_gt = var.pose_gt[:, :, :3]
                # [B, 3, 1]
                T_gt = var.pose_gt[:, :, 3:]
                # [B, 3, N]
                gt_sample_points_transposed = var.gt_sample_points.permute(0, 2, 1).contiguous()
                # camera coordinates, [B, N, 3]
                gt_sample_points_cam = (R_gt @ gt_sample_points_transposed + T_gt).permute(0, 2, 1).contiguous()
                # normalize with seen std and mean, [B, N, 3]
                var.gt_points_cam = (gt_sample_points_cam - seen_points_mean_gt.unsqueeze(1)) / seen_points_scale_gt.unsqueeze(-1).unsqueeze(-1)

                # get near-surface points for visualization
                # [B, 100, 3]
                close_surf_idx = torch.topk(var.gt_sample_sdf.abs(), k=100, dim=1, largest=False)[1].unsqueeze(-1).repeat(1, 1, 3)
                # [B, 100, 3]
                var.gt_surf_points = torch.gather(var.gt_points_cam, dim=1, index=close_surf_idx)

            # [B, N], [B, N, 1+feat_res**2], inference the impl_network for 3D loss
            var.pred_sample_occ, attn = self.impl_network(var.latent_depth, var.latent_semantic, var.gt_points_cam)

        # calculate the loss if needed
        if get_loss:
            loss = self.compute_loss(opt, var, training)
            return var, loss

        return var

    def compute_loss(self, opt, var, training=False):
        loss = edict()
        if opt.loss_weight.depth is not None:
            loss.depth = self.loss_fns.depth_loss(var.depth_pred, var.depth_input_map, var.mask_input_map)
        if opt.loss_weight.intr is not None and training:
            loss.intr = self.loss_fns.intr_loss(var.seen_points, var.seen_points_gt, var.validity_mask)
        if opt.loss_weight.shape is not None and training:
            loss.shape = self.loss_fns.shape_loss(var.pred_sample_occ, var.gt_sample_sdf)
        return loss
model/depth/__init__.py
ADDED
File without changes
model/depth/base_model.py
ADDED
@@ -0,0 +1,17 @@
# modified from https://github.com/isl-org/DPT
import torch


class BaseModel(torch.nn.Module):
    def load(self, path):
        """Load model from file.

        Args:
            path (str): file path
        """
        parameters = torch.load(path, map_location=torch.device('cpu'))

        if "optimizer" in parameters:
            parameters = parameters["model"]

        self.load_state_dict(parameters)
model/depth/blocks.py
ADDED
@@ -0,0 +1,343 @@
# modified from https://github.com/isl-org/DPT
import torch
import torch.nn as nn

from .vit import (
    _make_pretrained_vitb_rn50_384,
    _make_pretrained_vitl16_384,
    _make_pretrained_vitb16_384,
    forward_vit,
)

def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",):
    if backbone == "vitl16_384":
        pretrained = _make_pretrained_vitl16_384(
            use_pretrained, hooks=hooks, use_readout=use_readout
        )
        scratch = _make_scratch(
            [256, 512, 1024, 1024], features, groups=groups, expand=expand
        )  # ViT-L/16 - 85.0% Top1 (backbone)
    elif backbone == "vitb_rn50_384":
        pretrained = _make_pretrained_vitb_rn50_384(
            use_pretrained,
            hooks=hooks,
            use_vit_only=use_vit_only,
            use_readout=use_readout,
        )
        scratch = _make_scratch(
            [256, 512, 768, 768], features, groups=groups, expand=expand
        )  # ViT-H/16 - 85.0% Top1 (backbone)
    elif backbone == "vitb16_384":
        pretrained = _make_pretrained_vitb16_384(
            use_pretrained, hooks=hooks, use_readout=use_readout
        )
        scratch = _make_scratch(
            [96, 192, 384, 768], features, groups=groups, expand=expand
        )  # ViT-B/16 - 84.6% Top1 (backbone)
    elif backbone == "resnext101_wsl":
        pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
        scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand)  # efficientnet_lite3
    elif backbone == "efficientnet_lite3":
        pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable)
        scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand)  # efficientnet_lite3
    else:
        print(f"Backbone '{backbone}' not implemented")
        assert False

    return pretrained, scratch


def _make_scratch(in_shape, out_shape, groups=1, expand=False):
    scratch = nn.Module()

    out_shape1 = out_shape
    out_shape2 = out_shape
    out_shape3 = out_shape
    out_shape4 = out_shape
    if expand==True:
        out_shape1 = out_shape
        out_shape2 = out_shape*2
        out_shape3 = out_shape*4
        out_shape4 = out_shape*8

    scratch.layer1_rn = nn.Conv2d(
        in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
    )
    scratch.layer2_rn = nn.Conv2d(
        in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
    )
    scratch.layer3_rn = nn.Conv2d(
        in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
    )
    scratch.layer4_rn = nn.Conv2d(
        in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
    )

    return scratch


def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
    efficientnet = torch.hub.load(
        "rwightman/gen-efficientnet-pytorch",
        "tf_efficientnet_lite3",
        pretrained=use_pretrained,
        exportable=exportable
    )
    return _make_efficientnet_backbone(efficientnet)


def _make_efficientnet_backbone(effnet):
    pretrained = nn.Module()

    pretrained.layer1 = nn.Sequential(
        effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2]
    )
    pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3])
    pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5])
    pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9])

    return pretrained


def _make_resnet_backbone(resnet):
    pretrained = nn.Module()
    pretrained.layer1 = nn.Sequential(
        resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
    )

    pretrained.layer2 = resnet.layer2
    pretrained.layer3 = resnet.layer3
    pretrained.layer4 = resnet.layer4

    return pretrained


def _make_pretrained_resnext101_wsl(use_pretrained):
    resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
    return _make_resnet_backbone(resnet)


class Interpolate(nn.Module):
    """Interpolation module.
    """

    def __init__(self, scale_factor, mode, align_corners=False):
        """Init.

        Args:
            scale_factor (float): scaling
            mode (str): interpolation mode
        """
        super(Interpolate, self).__init__()

        self.interp = nn.functional.interpolate
        self.scale_factor = scale_factor
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): input

        Returns:
            tensor: interpolated data
        """

        x = self.interp(
            x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners
        )

        return x


class ResidualConvUnit(nn.Module):
    """Residual convolution module.
    """

    def __init__(self, features):
        """Init.

        Args:
            features (int): number of features
        """
        super().__init__()

        self.conv1 = nn.Conv2d(
            features, features, kernel_size=3, stride=1, padding=1, bias=True
        )

        self.conv2 = nn.Conv2d(
            features, features, kernel_size=3, stride=1, padding=1, bias=True
        )

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): input

        Returns:
            tensor: output
        """
        out = self.relu(x)
        out = self.conv1(out)
        out = self.relu(out)
        out = self.conv2(out)

        return out + x


class FeatureFusionBlock(nn.Module):
    """Feature fusion block.
    """

    def __init__(self, features):
        """Init.

        Args:
            features (int): number of features
        """
        super(FeatureFusionBlock, self).__init__()

        self.resConfUnit1 = ResidualConvUnit(features)
        self.resConfUnit2 = ResidualConvUnit(features)

    def forward(self, *xs):
        """Forward pass.

        Returns:
            tensor: output
        """
        output = xs[0]

        if len(xs) == 2:
            output += self.resConfUnit1(xs[1])

        output = self.resConfUnit2(output)

        output = nn.functional.interpolate(
            output, scale_factor=2, mode="bilinear", align_corners=True
        )

        return output


class ResidualConvUnit_custom(nn.Module):
    """Residual convolution module.
    """

    def __init__(self, features, activation, bn):
        """Init.

        Args:
            features (int): number of features
        """
        super().__init__()

        self.bn = bn

        self.groups=1

        self.conv1 = nn.Conv2d(
            features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
        )

        self.conv2 = nn.Conv2d(
            features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
        )

        if self.bn==True:
            self.bn1 = nn.BatchNorm2d(features)
            self.bn2 = nn.BatchNorm2d(features)

        self.activation = activation

        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): input

        Returns:
            tensor: output
        """

        out = self.activation(x)
        out = self.conv1(out)
        if self.bn==True:
            out = self.bn1(out)

        out = self.activation(out)
        out = self.conv2(out)
        if self.bn==True:
            out = self.bn2(out)

        if self.groups > 1:
            out = self.conv_merge(out)

        return self.skip_add.add(out, x)

        # return out + x


class FeatureFusionBlock_custom(nn.Module):
    """Feature fusion block.
    """

    def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True):
        """Init.

        Args:
            features (int): number of features
        """
        super(FeatureFusionBlock_custom, self).__init__()

        self.deconv = deconv
        self.align_corners = align_corners

        self.groups=1

        self.expand = expand
        out_features = features
        if self.expand==True:
            out_features = features//2

        self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)

        self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
        self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)

        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, *xs):
        """Forward pass.

        Returns:
            tensor: output
        """
        output = xs[0]

        if len(xs) == 2:
            res = self.resConfUnit1(xs[1])
            output = self.skip_add.add(output, res)
            # output += res

        output = self.resConfUnit2(output)

        output = nn.functional.interpolate(
            output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
        )

        output = self.out_conv(output)

        return output
model/depth/dpt_depth.py
ADDED
@@ -0,0 +1,123 @@
# modified from https://github.com/isl-org/DPT
import torch
import torch.nn as nn
import torch.nn.functional as F

from .base_model import BaseModel
from .blocks import (
    FeatureFusionBlock,
    FeatureFusionBlock_custom,
    Interpolate,
    _make_encoder,
    forward_vit,
)


def _make_fusion_block(features, use_bn):
    return FeatureFusionBlock_custom(
        features,
        nn.ReLU(False),
        deconv=False,
        bn=use_bn,
        expand=False,
        align_corners=True,
    )


class DPT(BaseModel):
    def __init__(
        self,
        head,
        features=256,
        backbone="vitb_rn50_384",
        readout="project",
        channels_last=False,
        use_bn=False,
    ):

        super(DPT, self).__init__()

        self.channels_last = channels_last

        hooks = {
            "vitb_rn50_384": [0, 1, 8, 11],
            "vitb16_384": [2, 5, 8, 11],
            "vitl16_384": [5, 11, 17, 23],
        }

        # Instantiate backbone and reassemble blocks
        self.pretrained, self.scratch = _make_encoder(
            backbone,
            features,
            True, # Set to true of you want to train from scratch, uses ImageNet weights
            groups=1,
            expand=False,
            exportable=False,
            hooks=hooks[backbone],
            use_readout=readout,
        )

        self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
        self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
        self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
        self.scratch.refinenet4 = _make_fusion_block(features, use_bn)

        self.scratch.output_conv = head


    def forward(self, x, get_feat=False):
        if self.channels_last == True:
            x.contiguous(memory_format=torch.channels_last)

        layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)

        # res 8x -> 4x -> 2x -> 1x base_size
        # base_size = H / 32
        # all n_channels same (256 by default) after these
        layer_1_rn = self.scratch.layer1_rn(layer_1)
        layer_2_rn = self.scratch.layer2_rn(layer_2)
        layer_3_rn = self.scratch.layer3_rn(layer_3)
        layer_4_rn = self.scratch.layer4_rn(layer_4)

        # upsample by two with out changing n_channels each time, conv-sum for fusing
        path_4 = self.scratch.refinenet4(layer_4_rn)
        path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
        path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)

        out = self.scratch.output_conv(path_1)

        # save the feat if required
        if get_feat:
            return out, layer_4

        return out

class DPTDepthModel(DPT):
    def __init__(self, path=None, non_negative=True, num_channels=1, **kwargs):
        features = kwargs["features"] if "features" in kwargs else 256

        head = nn.Sequential(
            nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
            Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
            nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(True),
            nn.Conv2d(32, num_channels, kernel_size=1, stride=1, padding=0),
            nn.ReLU(True) if non_negative else nn.Identity(),
            nn.Identity(),
        )
        nn.init.constant_(head[-3].bias, 0.05)
        super().__init__(head, **kwargs)

        if path is not None:
            self.load(path)

    def forward(self, image, get_feat=False):
        x = image * 2 - 1
        if get_feat:
            output, feat = super().forward(x, get_feat=get_feat)
            output = output.clamp(min=0, max=1)
            return output, feat
        else:
            output = super().forward(x, get_feat=get_feat).clamp(min=0, max=1)
            return output
model/depth/midas_loss.py
ADDED
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# modified from https://github.com/EPFL-VILAB/omnidata
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
def masked_l1_loss(preds, target, mask_valid):
|
7 |
+
element_wise_loss = abs(preds - target)
|
8 |
+
element_wise_loss[~mask_valid] = 0
|
9 |
+
return element_wise_loss.sum() / (mask_valid.sum() + 1.e-6)
|
10 |
+
|
11 |
+
def compute_scale_and_shift(prediction, target, mask):
|
12 |
+
# system matrix: A = [[a_00, a_01], [a_10, a_11]]
|
13 |
+
a_00 = torch.sum(mask * prediction * prediction, (1, 2))
|
14 |
+
a_01 = torch.sum(mask * prediction, (1, 2))
|
15 |
+
a_11 = torch.sum(mask, (1, 2))
|
16 |
+
|
17 |
+
# right hand side: b = [b_0, b_1]
|
18 |
+
b_0 = torch.sum(mask * prediction * target, (1, 2))
|
19 |
+
b_1 = torch.sum(mask * target, (1, 2))
|
20 |
+
|
21 |
+
# solution: x = A^-1 . b = [[a_11, -a_01], [-a_10, a_00]] / (a_00 * a_11 - a_01 * a_10) . b
|
22 |
+
x_0 = torch.zeros_like(b_0)
|
23 |
+
x_1 = torch.zeros_like(b_1)
|
24 |
+
|
25 |
+
det = a_00 * a_11 - a_01 * a_01
|
26 |
+
valid = det.nonzero()
|
27 |
+
|
28 |
+
x_0[valid] = (a_11[valid] * b_0[valid] - a_01[valid] * b_1[valid]) / (det[valid] + 1e-6)
|
29 |
+
x_1[valid] = (-a_01[valid] * b_0[valid] + a_00[valid] * b_1[valid]) / (det[valid] + 1e-6)
|
30 |
+
|
31 |
+
return x_0, x_1


def masked_shift_and_scale(depth_preds, depth_gt, mask_valid):
    depth_preds_nan = depth_preds.clone()
    depth_gt_nan = depth_gt.clone()
    depth_preds_nan[~mask_valid] = np.nan
    depth_gt_nan[~mask_valid] = np.nan

    mask_diff = mask_valid.view(mask_valid.size()[:2] + (-1,)).sum(-1, keepdims=True) + 1

    # flatten spatial dimension and take valid median [B, 1, 1, 1]
    t_gt = depth_gt_nan.view(depth_gt_nan.size()[:2] + (-1,)).nanmedian(-1, keepdims=True)[0].unsqueeze(-1)
    t_gt[torch.isnan(t_gt)] = 0
    # subtract median and set invalid position to 0
    diff_gt = torch.abs(depth_gt - t_gt)
    diff_gt[~mask_valid] = 0
    # get the avg abs diff value over valid regions [B, 1, 1, 1]
    s_gt = (diff_gt.view(diff_gt.size()[:2] + (-1,)).sum(-1, keepdims=True) / mask_diff).unsqueeze(-1)
    # normalize
    depth_gt_aligned = (depth_gt - t_gt) / (s_gt + 1e-6)

    # same as gt normalization
    t_pred = depth_preds_nan.view(depth_preds_nan.size()[:2] + (-1,)).nanmedian(-1, keepdims=True)[0].unsqueeze(-1)
    t_pred[torch.isnan(t_pred)] = 0
    diff_pred = torch.abs(depth_preds - t_pred)
    diff_pred[~mask_valid] = 0
    s_pred = (diff_pred.view(diff_pred.size()[:2] + (-1,)).sum(-1, keepdims=True) / mask_diff).unsqueeze(-1)
    depth_pred_aligned = (depth_preds - t_pred) / (s_pred + 1e-6)

    return depth_pred_aligned, depth_gt_aligned


def reduction_batch_based(image_loss, M):
    # average of all valid pixels of the batch

    # avoid division by 0 (if sum(M) = sum(sum(mask)) = 0: sum(image_loss) = 0)
    divisor = torch.sum(M)

    if divisor == 0:
        return 0
    else:
        return torch.sum(image_loss) / divisor


def reduction_image_based(image_loss, M):
    # mean of average of valid pixels of an image

    # avoid division by 0 (if M = sum(mask) = 0: image_loss = 0)
    valid = M.nonzero()

    image_loss[valid] = image_loss[valid] / M[valid]

    return torch.mean(image_loss)


def gradient_loss(prediction, target, mask, reduction=reduction_batch_based):

    M = torch.sum(mask, (1, 2))

    diff = prediction - target
    diff = torch.mul(mask, diff)

    grad_x = torch.abs(diff[:, :, 1:] - diff[:, :, :-1])
    mask_x = torch.mul(mask[:, :, 1:], mask[:, :, :-1])
    grad_x = torch.mul(mask_x, grad_x)

    grad_y = torch.abs(diff[:, 1:, :] - diff[:, :-1, :])
    mask_y = torch.mul(mask[:, 1:, :], mask[:, :-1, :])
    grad_y = torch.mul(mask_y, grad_y)

    image_loss = torch.sum(grad_x, (1, 2)) + torch.sum(grad_y, (1, 2))

    return reduction(image_loss, M)


class SSIMAE(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, depth_preds, depth_gt, mask_valid):
        depth_pred_aligned, depth_gt_aligned = masked_shift_and_scale(depth_preds, depth_gt, mask_valid)
        ssi_mae_loss = masked_l1_loss(depth_pred_aligned, depth_gt_aligned, mask_valid)
        return ssi_mae_loss


class GradientMatchingTerm(nn.Module):
    def __init__(self, scales=4, reduction='batch-based'):
        super().__init__()

        if reduction == 'batch-based':
            self.__reduction = reduction_batch_based
        else:
            self.__reduction = reduction_image_based

        self.__scales = scales

    def forward(self, prediction, target, mask):
        total = 0

        for scale in range(self.__scales):
            step = pow(2, scale)

            total += gradient_loss(prediction[:, ::step, ::step], target[:, ::step, ::step],
                                   mask[:, ::step, ::step], reduction=self.__reduction)

        return total


class MidasLoss(nn.Module):
    def __init__(self, alpha=0.1, scales=4, reduction='image-based', inverse_depth=True, shrink_mask=False):
        super().__init__()

        self.__ssi_mae_loss = SSIMAE()
        self.__gradient_matching_term = GradientMatchingTerm(scales=scales, reduction=reduction)
        self.__alpha = alpha
        self.inverse_depth = inverse_depth
        self.shrink_mask = shrink_mask

    # decrease valid region via min-pooling
    @torch.no_grad()
    def erode_mask(self, mask, max_pool_size=4):
        mask_float = mask.float()
        h, w = mask_float.shape[2], mask_float.shape[3]
        mask_float = 1 - mask_float
        mask_float = torch.nn.functional.max_pool2d(mask_float, kernel_size=max_pool_size)
        mask_float = torch.nn.functional.interpolate(mask_float, (h, w), mode='nearest')
        # only if a 4x4 region is all valid then we make that valid
        mask_valid = mask_float == 0
        return mask_valid

    def forward(self, prediction_raw, target_raw, mask_raw):
        if self.shrink_mask:
            mask = self.erode_mask(mask_raw)
        else:
            mask = mask_raw > 0.5
        ssi_loss = self.__ssi_mae_loss(prediction_raw, target_raw, mask)
        if self.__alpha <= 0:
            return ssi_loss

        if self.inverse_depth:
            prediction = 1 / (prediction_raw.squeeze(1) + 1e-6)
            target = 1 / (target_raw.squeeze(1) + 1e-6)
        else:
            prediction = prediction_raw.squeeze(1)
            target = target_raw.squeeze(1)
        # gradient loss
        scale, shift = compute_scale_and_shift(prediction, target, mask.squeeze(1))
        prediction_ssi = scale.view(-1, 1, 1) * prediction + shift.view(-1, 1, 1)
        reg_loss = self.__gradient_matching_term(prediction_ssi, target, mask.squeeze(1))
        total = ssi_loss + self.__alpha * reg_loss

        return total
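For reference, a minimal sketch of how the MidasLoss module above could be driven, assuming [B, 1, H, W] depth tensors and a binary validity mask; the toy shapes and tensors here are illustrative assumptions and not part of the committed code:

import torch
from model.depth.midas_loss import MidasLoss

# illustrative shapes only: batch of 2 depth maps at 224x224 with an all-valid mask
midas_loss = MidasLoss(alpha=0.1, scales=4, reduction='image-based', inverse_depth=True)
depth_pred = torch.rand(2, 1, 224, 224) + 0.1   # predicted depth
depth_gt = torch.rand(2, 1, 224, 224) + 0.1     # ground-truth depth
mask_valid = torch.ones(2, 1, 224, 224)         # 1 where depth supervision is valid
loss = midas_loss(depth_pred, depth_gt, mask_valid)   # scalar: SSI-MAE + alpha * gradient-matching term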
model/depth/vit.py
ADDED
@@ -0,0 +1,492 @@
# modified from https://github.com/isl-org/DPT
import torch
import torch.nn as nn
import timm
import types
import math
import torch.nn.functional as F


class Slice(nn.Module):
    def __init__(self, start_index=1):
        super(Slice, self).__init__()
        self.start_index = start_index

    def forward(self, x):
        return x[:, self.start_index :]


class AddReadout(nn.Module):
    def __init__(self, start_index=1):
        super(AddReadout, self).__init__()
        self.start_index = start_index

    def forward(self, x):
        if self.start_index == 2:
            readout = (x[:, 0] + x[:, 1]) / 2
        else:
            readout = x[:, 0]
        return x[:, self.start_index :] + readout.unsqueeze(1)


class ProjectReadout(nn.Module):
    def __init__(self, in_features, start_index=1):
        super(ProjectReadout, self).__init__()
        self.start_index = start_index

        self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())

    def forward(self, x):
        readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :])
        features = torch.cat((x[:, self.start_index :], readout), -1)

        return self.project(features)


class Transpose(nn.Module):
    def __init__(self, dim0, dim1):
        super(Transpose, self).__init__()
        self.dim0 = dim0
        self.dim1 = dim1

    def forward(self, x):
        x = x.transpose(self.dim0, self.dim1).contiguous()
        return x


def forward_vit(pretrained, x):
    b, c, h, w = x.shape

    glob = pretrained.model.forward_flex(x)

    layer_1 = pretrained.activations["1"]
    layer_2 = pretrained.activations["2"]
    layer_3 = pretrained.activations["3"]
    layer_4 = pretrained.activations["4"]

    layer_1 = pretrained.act_postprocess1[0:2](layer_1)
    layer_2 = pretrained.act_postprocess2[0:2](layer_2)
    layer_3 = pretrained.act_postprocess3[0:2](layer_3)
    layer_4 = pretrained.act_postprocess4[0:2](layer_4)

    unflatten = nn.Sequential(
        nn.Unflatten(
            2,
            torch.Size(
                [
                    h // pretrained.model.patch_size[1],
                    w // pretrained.model.patch_size[0],
                ]
            ),
        )
    )

    if layer_1.ndim == 3:
        layer_1 = unflatten(layer_1)
    if layer_2.ndim == 3:
        layer_2 = unflatten(layer_2)
    if layer_3.ndim == 3:
        layer_3 = unflatten(layer_3)
    if layer_4.ndim == 3:
        layer_4 = unflatten(layer_4)

    layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1)
    layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2)
    layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3)
    layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4)

    return layer_1, layer_2, layer_3, layer_4


def _resize_pos_embed(self, posemb, gs_h, gs_w):
    posemb_tok, posemb_grid = (
        posemb[:, : self.start_index],
        posemb[0, self.start_index :],
    )

    gs_old = int(math.sqrt(len(posemb_grid)))

    posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2).contiguous()
    posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear", align_corners=False)
    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1).contiguous()

    posemb = torch.cat([posemb_tok, posemb_grid], dim=1)

    return posemb


def forward_flex(self, x):
    b, c, h, w = x.shape

    pos_embed = self._resize_pos_embed(
        self.pos_embed, h // self.patch_size[1], w // self.patch_size[0]
    )

    B = x.shape[0]

    if hasattr(self.patch_embed, "backbone"):
        x = self.patch_embed.backbone(x)
        if isinstance(x, (list, tuple)):
            x = x[-1]  # last feature if backbone outputs list/tuple of features

    x = self.patch_embed.proj(x).flatten(2).transpose(1, 2).contiguous()

    if getattr(self, "dist_token", None) is not None:
        cls_tokens = self.cls_token.expand(
            B, -1, -1
        )  # stole cls_tokens impl from Phil Wang, thanks
        dist_token = self.dist_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, dist_token, x), dim=1)
    else:
        cls_tokens = self.cls_token.expand(
            B, -1, -1
        )  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)

    x = x + pos_embed
    x = self.pos_drop(x)

    for blk in self.blocks:
        x = blk(x)

    x = self.norm(x)

    return x


activations = {}


def get_activation(name):
    def hook(model, input, output):
        activations[name] = output

    return hook


def get_readout_oper(vit_features, features, use_readout, start_index=1):
    if use_readout == "ignore":
        readout_oper = [Slice(start_index)] * len(features)
    elif use_readout == "add":
        readout_oper = [AddReadout(start_index)] * len(features)
    elif use_readout == "project":
        readout_oper = [
            ProjectReadout(vit_features, start_index) for out_feat in features
        ]
    else:
        assert (
            False
        ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"

    return readout_oper


def _make_vit_b16_backbone(
    model,
    features=[96, 192, 384, 768],
    size=[384, 384],
    hooks=[2, 5, 8, 11],
    vit_features=768,
    use_readout="ignore",
    start_index=1,
):
    pretrained = nn.Module()

    pretrained.model = model
    pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
    pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
    pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
    pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))

    pretrained.activations = activations

    readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)

    # 32, 48, 136, 384
    pretrained.act_postprocess1 = nn.Sequential(
        readout_oper[0],
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
        nn.Conv2d(
            in_channels=vit_features,
            out_channels=features[0],
            kernel_size=1,
            stride=1,
            padding=0,
        ),
        nn.ConvTranspose2d(
            in_channels=features[0],
            out_channels=features[0],
            kernel_size=4,
            stride=4,
            padding=0,
            bias=True,
            dilation=1,
            groups=1,
        ),
    )

    pretrained.act_postprocess2 = nn.Sequential(
        readout_oper[1],
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
        nn.Conv2d(
            in_channels=vit_features,
            out_channels=features[1],
            kernel_size=1,
            stride=1,
            padding=0,
        ),
        nn.ConvTranspose2d(
            in_channels=features[1],
            out_channels=features[1],
            kernel_size=2,
            stride=2,
            padding=0,
            bias=True,
            dilation=1,
            groups=1,
        ),
    )

    pretrained.act_postprocess3 = nn.Sequential(
        readout_oper[2],
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
        nn.Conv2d(
            in_channels=vit_features,
            out_channels=features[2],
            kernel_size=1,
            stride=1,
            padding=0,
        ),
    )

    pretrained.act_postprocess4 = nn.Sequential(
        readout_oper[3],
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
        nn.Conv2d(
            in_channels=vit_features,
            out_channels=features[3],
            kernel_size=1,
            stride=1,
            padding=0,
        ),
        nn.Conv2d(
            in_channels=features[3],
            out_channels=features[3],
            kernel_size=3,
            stride=2,
            padding=1,
        ),
    )

    pretrained.model.start_index = start_index
    pretrained.model.patch_size = [16, 16]

    # We inject this function into the VisionTransformer instances so that
    # we can use it with interpolated position embeddings without modifying the library source.
    pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
    pretrained.model._resize_pos_embed = types.MethodType(
        _resize_pos_embed, pretrained.model
    )

    return pretrained


def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None):
    model = timm.create_model("vit_large_patch16_384", pretrained=pretrained)

    hooks = [5, 11, 17, 23] if hooks == None else hooks
    return _make_vit_b16_backbone(
        model,
        features=[256, 512, 1024, 1024],
        hooks=hooks,
        vit_features=1024,
        use_readout=use_readout,
    )


def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None):
    model = timm.create_model("vit_base_patch16_384", pretrained=pretrained)

    hooks = [2, 5, 8, 11] if hooks == None else hooks
    return _make_vit_b16_backbone(
        model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
    )


def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None):
    model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained)

    hooks = [2, 5, 8, 11] if hooks == None else hooks
    return _make_vit_b16_backbone(
        model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
    )


def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None):
    model = timm.create_model(
        "vit_deit_base_distilled_patch16_384", pretrained=pretrained
    )

    hooks = [2, 5, 8, 11] if hooks == None else hooks
    return _make_vit_b16_backbone(
        model,
        features=[96, 192, 384, 768],
        hooks=hooks,
        use_readout=use_readout,
        start_index=2,
    )


def _make_vit_b_rn50_backbone(
    model,
    features=[256, 512, 768, 768],
    size=[384, 384],
    hooks=[0, 1, 8, 11],
    vit_features=768,
    use_vit_only=False,
    use_readout="ignore",
    start_index=1,
):
    pretrained = nn.Module()

    pretrained.model = model

    if use_vit_only == True:
        pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
        pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
    else:
        pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(
            get_activation("1")
        )
        pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(
            get_activation("2")
        )

    pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
    pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))

    pretrained.activations = activations

    readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)

    if use_vit_only == True:
        pretrained.act_postprocess1 = nn.Sequential(
            readout_oper[0],
            Transpose(1, 2),
            nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
            nn.Conv2d(
                in_channels=vit_features,
                out_channels=features[0],
                kernel_size=1,
                stride=1,
                padding=0,
            ),
            nn.ConvTranspose2d(
                in_channels=features[0],
                out_channels=features[0],
                kernel_size=4,
                stride=4,
                padding=0,
                bias=True,
                dilation=1,
                groups=1,
            ),
        )

        pretrained.act_postprocess2 = nn.Sequential(
            readout_oper[1],
            Transpose(1, 2),
            nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
            nn.Conv2d(
                in_channels=vit_features,
                out_channels=features[1],
                kernel_size=1,
                stride=1,
                padding=0,
            ),
            nn.ConvTranspose2d(
                in_channels=features[1],
                out_channels=features[1],
                kernel_size=2,
                stride=2,
                padding=0,
                bias=True,
                dilation=1,
                groups=1,
            ),
        )
    else:
        pretrained.act_postprocess1 = nn.Sequential(
            nn.Identity(), nn.Identity(), nn.Identity()
        )
        pretrained.act_postprocess2 = nn.Sequential(
            nn.Identity(), nn.Identity(), nn.Identity()
        )

    pretrained.act_postprocess3 = nn.Sequential(
        readout_oper[2],
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
        nn.Conv2d(
            in_channels=vit_features,
            out_channels=features[2],
            kernel_size=1,
            stride=1,
            padding=0,
        ),
    )

    pretrained.act_postprocess4 = nn.Sequential(
        readout_oper[3],
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
        nn.Conv2d(
            in_channels=vit_features,
            out_channels=features[3],
            kernel_size=1,
            stride=1,
            padding=0,
        ),
        nn.Conv2d(
            in_channels=features[3],
            out_channels=features[3],
            kernel_size=3,
            stride=2,
            padding=1,
        ),
    )

    pretrained.model.start_index = start_index
    pretrained.model.patch_size = [16, 16]

    # We inject this function into the VisionTransformer instances so that
    # we can use it with interpolated position embeddings without modifying the library source.
    pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)

    # We inject this function into the VisionTransformer instances so that
    # we can use it with interpolated position embeddings without modifying the library source.
    pretrained.model._resize_pos_embed = types.MethodType(
        _resize_pos_embed, pretrained.model
    )

    return pretrained


def _make_pretrained_vitb_rn50_384(
    pretrained, use_readout="ignore", hooks=None, use_vit_only=False
):
    model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)

    hooks = [0, 1, 8, 11] if hooks == None else hooks
    return _make_vit_b_rn50_backbone(
        model,
        features=[256, 512, 768, 768],
        size=[384, 384],
        hooks=hooks,
        use_vit_only=use_vit_only,
        use_readout=use_readout,
    )
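As a usage sketch (not part of the commit; it assumes the timm checkpoint for vit_base_resnet50_384 can be downloaded), the backbone factory above is normally paired with forward_vit to obtain the four intermediate feature maps that a DPT-style decoder fuses:

import torch
from model.depth.vit import _make_pretrained_vitb_rn50_384, forward_vit

# ViT-Base/ResNet-50 hybrid with hooks on two CNN stages and two transformer blocks
backbone = _make_pretrained_vitb_rn50_384(True, use_readout="project")
image = torch.rand(1, 3, 384, 384)   # the 384x384 resolution the position embeddings were trained at
with torch.no_grad():
    layer_1, layer_2, layer_3, layer_4 = forward_vit(backbone, image)
# four multi-scale feature maps, coarsest last, ready for the refinement/fusion decoder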
model/depth_engine.py
ADDED
@@ -0,0 +1,445 @@
import numpy as np
import os, time, datetime
import torch
import torch.utils.tensorboard
import importlib
import shutil
import utils.util as util
import utils.util_vis as util_vis

from torch.nn.parallel import DistributedDataParallel as DDP
from utils.util import print_eval, setup, cleanup
from utils.util import EasyDict as edict
from utils.eval_depth import DepthMetric
from copy import deepcopy
from model.compute_graph import graph_depth

# ============================ main engine for training and evaluation ============================

class Runner():

    def __init__(self, opt):
        super().__init__()
        if os.path.isdir(opt.output_path) and opt.resume == False and opt.device == 0:
            for filename in os.listdir(opt.output_path):
                if "tfevents" in filename: os.remove(os.path.join(opt.output_path, filename))
                if "html" in filename: os.remove(os.path.join(opt.output_path, filename))
                if "vis" in filename: shutil.rmtree(os.path.join(opt.output_path, filename))
                if "dump" in filename: shutil.rmtree(os.path.join(opt.output_path, filename))
                if "embedding" in filename: shutil.rmtree(os.path.join(opt.output_path, filename))
        if opt.device == 0:
            os.makedirs(opt.output_path, exist_ok=True)
        setup(opt.device, opt.world_size, opt.port)
        opt.batch_size = opt.batch_size // opt.world_size

    def get_viz_data(self, opt):
        # get data for visualization
        viz_data_list = []
        sample_range = len(self.viz_loader)
        viz_interval = sample_range // opt.eval.n_vis
        for i in range(sample_range):
            current_batch = next(self.viz_loader_iter)
            if i % viz_interval != 0: continue
            viz_data_list.append(current_batch)
        return viz_data_list

    def load_dataset(self, opt, eval_split="test"):
        data_train = importlib.import_module('data.{}'.format(opt.data.dataset_train))
        data_test = importlib.import_module('data.{}'.format(opt.data.dataset_test))
        if opt.device == 0: print("loading training data...")
        self.batch_order = []
        self.train_data = data_train.Dataset(opt, split="train", load_3D=False)
        self.train_loader = self.train_data.setup_loader(opt, shuffle=True, use_ddp=True, drop_last=True)
        self.num_batches = len(self.train_loader)
        if opt.device == 0: print("loading test data...")
        self.test_data = data_test.Dataset(opt, split=eval_split, load_3D=False)
        self.test_loader = self.test_data.setup_loader(opt, shuffle=False, use_ddp=True, drop_last=True, batch_size=opt.eval.batch_size)
        self.num_batches_test = len(self.test_loader)
        if len(self.test_loader.sampler) * opt.world_size < len(self.test_data):
            self.aux_test_dataset = torch.utils.data.Subset(self.test_data,
                range(len(self.test_loader.sampler) * opt.world_size, len(self.test_data)))
            self.aux_test_loader = torch.utils.data.DataLoader(
                self.aux_test_dataset, batch_size=opt.eval.batch_size, shuffle=False, drop_last=False,
                num_workers=opt.data.num_workers)
        if opt.device == 0:
            print("creating data for visualization...")
            self.viz_loader = self.test_data.setup_loader(opt, shuffle=False, use_ddp=False, drop_last=False, batch_size=1)
            self.viz_loader_iter = iter(self.viz_loader)
            self.viz_data = self.get_viz_data(opt)

    def build_networks(self, opt):
        if opt.device == 0: print("building networks...")
        self.graph = DDP(graph_depth.Graph(opt).to(opt.device), device_ids=[opt.device], find_unused_parameters=True)
        self.depth_metric = DepthMetric(thresholds=opt.eval.d_thresholds, depth_cap=opt.eval.depth_cap)

    # =================================================== set up training =========================================================

    def setup_optimizer(self, opt):
        if opt.device == 0: print("setting up optimizers...")
        param_nodecay = []
        param_decay = []
        for name, param in self.graph.named_parameters():
            # skip and fixed params
            if not param.requires_grad:
                continue
            if param.ndim <= 1 or name.endswith(".bias"):
                # print("{} -> finetune_param_nodecay".format(name))
                param_nodecay.append(param)
            else:
                param_decay.append(param)
                # print("{} -> finetune_param_decay".format(name))
        # create the optim dictionary
        optim_dict = [
            {'params': param_nodecay, 'lr': opt.optim.lr, 'weight_decay': 0.},
            {'params': param_decay, 'lr': opt.optim.lr, 'weight_decay': opt.optim.weight_decay}
        ]

        self.optim = torch.optim.AdamW(optim_dict, betas=(0.9, 0.95))
        if opt.optim.sched:
            self.sched = torch.optim.lr_scheduler.CosineAnnealingLR(self.optim, opt.max_epoch)
        if opt.optim.amp:
            self.scaler = torch.cuda.amp.GradScaler()

    def restore_checkpoint(self, opt, best=False, evaluate=False):
        epoch_start, iter_start = None, None
        if opt.resume:
            if opt.device == 0: print("resuming from previous checkpoint...")
            epoch_start, iter_start, best_val, best_ep = util.restore_checkpoint(opt, self, resume=opt.resume, best=best, evaluate=evaluate)
            self.best_val = best_val
            self.best_ep = best_ep
        elif opt.load is not None:
            if opt.device == 0: print("loading weights from checkpoint {}...".format(opt.load))
            epoch_start, iter_start, best_val, best_ep = util.restore_checkpoint(opt, self, load_name=opt.load)
        else:
            if opt.device == 0: print("initializing weights from scratch...")
        self.epoch_start = epoch_start or 0
        self.iter_start = iter_start or 0

    def setup_visualizer(self, opt, test=False):
        if opt.device == 0:
            print("setting up visualizers...")
            if opt.tb:
                self.tb = torch.utils.tensorboard.SummaryWriter(log_dir=opt.output_path, flush_secs=10)

    def train(self, opt):
        # before training
        torch.cuda.set_device(opt.device)
        torch.cuda.empty_cache()
        if opt.device == 0: print("TRAINING START")
        self.train_metric_logger = util.MetricLogger(delimiter="  ")
        self.train_metric_logger.add_meter('lr', util.SmoothedValue(window_size=1, fmt='{value:.6f}'))
        self.iter_skip = self.iter_start % len(self.train_loader)
        self.it = self.iter_start
        self.skip_dis = False
        if not opt.resume:
            self.best_val = np.inf
            self.best_ep = 1
        # training
        if self.iter_start == 0 and not opt.debug: self.evaluate(opt, ep=0, training=True)
        # if opt.device == 0: self.save_checkpoint(opt, ep=0, it=0, best_val=self.best_val, best_ep=self.best_ep)
        self.ep = self.epoch_start
        for self.ep in range(self.epoch_start, opt.max_epoch):
            self.train_epoch(opt)
        # after training
        if opt.device == 0: self.save_checkpoint(opt, ep=self.ep, it=self.it, best_val=self.best_val, best_ep=self.best_ep)
        if opt.tb and opt.device == 0:
            self.tb.flush()
            self.tb.close()
        if opt.device == 0:
            print("TRAINING DONE")
            print("Best val: %.4f @ epoch %d" % (self.best_val, self.best_ep))
        cleanup()

    def train_epoch(self, opt):
        # before train epoch
        self.train_loader.sampler.set_epoch(self.ep)
        if opt.device == 0:
            print("training epoch {}".format(self.ep+1))
        batch_progress = range(self.num_batches)
        self.graph.train()
        # train epoch
        loader = iter(self.train_loader)

        for batch_id in batch_progress:
            # if resuming from previous checkpoint, skip until the last iteration number is reached
            if self.iter_skip > 0:
                self.iter_skip -= 1
                continue
            batch = next(loader)
            # train iteration
            var = edict(batch)
            opt.H, opt.W = opt.image_size
            var = util.move_to_device(var, opt.device)
            loss = self.train_iteration(opt, var, batch_progress)

        # after train epoch
        lr = self.sched.get_last_lr()[0] if opt.optim.sched else opt.optim.lr
        if opt.optim.sched: self.sched.step()
        if (self.ep + 1) % opt.freq.eval == 0:
            if opt.device == 0: print("validating epoch {}".format(self.ep+1))
            current_val = self.evaluate(opt, ep=self.ep+1, training=True)
            if current_val < self.best_val and opt.device == 0:
                self.best_val = current_val
                self.best_ep = self.ep + 1
                self.save_checkpoint(opt, ep=self.ep, it=self.it, best_val=self.best_val, best_ep=self.best_ep, best=True, latest=True)

    def train_iteration(self, opt, var, loader):
        # before train iteration
        torch.distributed.barrier()

        # train iteration
        with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=opt.optim.amp):
            var, loss = self.graph.forward(opt, var, training=True, get_loss=True)
            loss = self.summarize_loss(opt, var, loss)
            loss_scaled = loss.all / opt.optim.accum

        # backward
        if opt.optim.amp:
            self.scaler.scale(loss_scaled).backward()
            # skip update if accumulating gradient
            if (self.it + 1) % opt.optim.accum == 0:
                self.scaler.unscale_(self.optim)
                # gradient clipping
                if opt.optim.clip_norm:
                    norm = torch.nn.utils.clip_grad_norm_(self.graph.parameters(), opt.optim.clip_norm)
                    if opt.debug: print("Grad norm: {}".format(norm))
                self.scaler.step(self.optim)
                self.scaler.update()
                self.optim.zero_grad()
        else:
            loss_scaled.backward()
            if (self.it + 1) % opt.optim.accum == 0:
                if opt.optim.clip_norm:
                    norm = torch.nn.utils.clip_grad_norm_(self.graph.parameters(), opt.optim.clip_norm)
                    if opt.debug: print("Grad norm: {}".format(norm))
                self.optim.step()
                self.optim.zero_grad()

        # after train iteration
        lr = self.sched.get_last_lr()[0] if opt.optim.sched else opt.optim.lr
        self.train_metric_logger.update(lr=lr)
        self.train_metric_logger.update(loss=loss.all)
        if opt.device == 0:
            if (self.it) % opt.freq.vis == 0 and not opt.debug:
                self.visualize(opt, var, step=self.it, split="train")
            if (self.it+1) % opt.freq.ckpt_latest == 0 and not opt.debug:
                self.save_checkpoint(opt, ep=self.ep, it=self.it+1, best_val=self.best_val, best_ep=self.best_ep, latest=True)
            if (self.it) % opt.freq.scalar == 0 and not opt.debug:
                self.log_scalars(opt, var, loss, step=self.it, split="train")
            if (self.it) % (opt.freq.save_vis * (self.it//10000*10+1)) == 0 and not opt.debug:
                self.vis_train_iter(opt)
            if (self.it) % opt.freq.print == 0:
                print('[{}] '.format(datetime.datetime.now().time()), end='')
                print(f'Train Iter {self.it}/{self.num_batches*opt.max_epoch}: {self.train_metric_logger}')
        self.it += 1
        return loss

    @torch.no_grad()
    def vis_train_iter(self, opt):
        self.graph.eval()
        for i in range(len(self.viz_data)):
            var_viz = edict(deepcopy(self.viz_data[i]))
            var_viz = util.move_to_device(var_viz, opt.device)
            var_viz = self.graph.module(opt, var_viz, training=False, get_loss=False)
            vis_folder = "vis_log/iter_{}".format(self.it)
            os.makedirs("{}/{}".format(opt.output_path, vis_folder), exist_ok=True)
            util_vis.dump_images(opt, var_viz.idx, "image_input", var_viz.rgb_input_map, masks=None, from_range=(0, 1), folder=vis_folder)
            util_vis.dump_images(opt, var_viz.idx, "mask_input", var_viz.mask_input_map, folder=vis_folder)
            util_vis.dump_depths(opt, var_viz.idx, "depth_est", var_viz.depth_pred, var_viz.mask_input_map, rescale=True, folder=vis_folder)
            util_vis.dump_depths(opt, var_viz.idx, "depth_input", var_viz.depth_input_map, var_viz.mask_input_map, rescale=True, folder=vis_folder)
            if 'seen_points_pred' in var_viz and 'seen_points_gt' in var_viz:
                util_vis.dump_pointclouds_compare(opt, var_viz.idx, "seen_surface", var_viz.seen_points_pred, var_viz.seen_points_gt, folder=vis_folder)
        self.graph.train()

    def summarize_loss(self, opt, var, loss, non_act_loss_key=[]):
        loss_all = 0.
        assert("all" not in loss)
        # weigh losses
        for key in loss:
            assert(key in opt.loss_weight)
            if opt.loss_weight[key] is not None:
                assert not torch.isinf(loss[key].mean()), "loss {} is Inf".format(key)
                assert not torch.isnan(loss[key].mean()), "loss {} is NaN".format(key)
                loss_all += float(opt.loss_weight[key])*loss[key].mean() if key not in non_act_loss_key else 0.0*loss[key].mean()
        loss.update(all=loss_all)
        return loss

    # =================================================== set up evaluation =========================================================

    @torch.no_grad()
    def evaluate(self, opt, ep, training=False):
        self.graph.eval()
        loss_eval = edict()

        # metric dictionary
        metric_eval = {}
        for metric_key in self.depth_metric.metric_keys:
            metric_eval[metric_key] = []
        metric_avg = {}
        eval_metric_logger = util.MetricLogger(delimiter="  ")

        # dataloader on the test set
        with torch.cuda.device(opt.device):
            for it, batch in enumerate(self.test_loader):

                # inference the model
                var = edict(batch)
                var = self.evaluate_batch(opt, var, ep, it, single_gpu=False)

                # record foreground mae for evaluation
                sample_metrics, var.depth_pred_aligned = self.depth_metric.compute_metrics(
                    var.depth_pred, var.depth_input_map, var.mask_eroded if 'mask_eroded' in var else var.mask_input_map)
                var.rmse = sample_metrics['rmse']
                curr_metrics = {}
                for metric_key in metric_eval:
                    metric_eval[metric_key].append(sample_metrics[metric_key])
                    curr_metrics[metric_key] = sample_metrics[metric_key].mean()
                eval_metric_logger.update(**curr_metrics)
                # eval_metric_logger.update(metric_key=sample_metrics[metric_key].mean())

                # accumulate the scores
                if opt.device == 0 and it % opt.freq.print_eval == 0:
                    print('[{}] '.format(datetime.datetime.now().time()), end='')
                    print(f'Eval Iter {it}/{len(self.test_loader)} @ EP {ep}: {eval_metric_logger}')

                # dump the result if in eval mode
                if not training:
                    self.dump_results(opt, var, ep, write_new=(it == 0))

                # save the visualization
                if it == 0 and training and opt.device == 0:
                    print("visualizing and saving results...")
                    for i in range(len(self.viz_data)):
                        var_viz = edict(deepcopy(self.viz_data[i]))
                        var_viz = self.evaluate_batch(opt, var_viz, ep, it, single_gpu=True)
                        self.visualize(opt, var_viz, step=ep, split="eval")
                        self.dump_results(opt, var_viz, ep, train=True)

            # collect the eval results into tensors
            for metric_key in metric_eval:
                metric_eval[metric_key] = torch.cat(metric_eval[metric_key], dim=0)

            if opt.world_size > 1:
                metric_gather_dict = {}
                # empty tensors for gathering
                for metric_key in metric_eval:
                    metric_gather_dict[metric_key] = [torch.zeros_like(metric_eval[metric_key]).to(opt.device) for _ in range(opt.world_size)]

                # gather the metrics
                torch.distributed.barrier()
                for metric_key in metric_eval:
                    torch.distributed.all_gather(metric_gather_dict[metric_key], metric_eval[metric_key])
                    metric_gather_dict[metric_key] = torch.cat(metric_gather_dict[metric_key], dim=0)
            else:
                metric_gather_dict = metric_eval

            # handle last batch, if any
            if len(self.test_loader.sampler) * opt.world_size < len(self.test_data):
                for metric_key in metric_eval:
                    metric_gather_dict[metric_key] = [metric_gather_dict[metric_key]]
                for batch in self.aux_test_loader:
                    # inference the model
                    var = edict(batch)
                    var = self.evaluate_batch(opt, var, ep, it, single_gpu=False)

                    # record MAE for evaluation
                    sample_metrics, var.depth_pred_aligned = self.depth_metric.compute_metrics(
                        var.depth_pred, var.depth_input_map, var.mask_eroded if 'mask_eroded' in var else var.mask_input_map)
                    var.rmse = sample_metrics['rmse']
                    for metric_key in metric_eval:
                        metric_gather_dict[metric_key].append(sample_metrics[metric_key])

                    # dump the result if in eval mode
                    if not training and opt.device == 0:
                        self.dump_results(opt, var, ep, write_new=(it == 0))

                for metric_key in metric_eval:
                    metric_gather_dict[metric_key] = torch.cat(metric_gather_dict[metric_key], dim=0)

            assert metric_gather_dict['l1_err'].shape[0] == len(self.test_data)
            # compute the mean of the metrics
            for metric_key in metric_eval:
                metric_avg[metric_key] = metric_gather_dict[metric_key].mean()

        # printout and save the metrics
        if opt.device == 0:
            # print eval info
            print_eval(opt, depth_metrics=metric_avg)
            val_metric = metric_avg['l1_err']

            if training:
                # log/visualize results to tb/vis
                self.log_scalars(opt, var, loss_eval, metric=metric_avg, step=ep, split="eval")

            if not training:
                # write to file
                metrics_file = os.path.join(opt.output_path, 'best_val.txt')
                with open(metrics_file, "w") as outfile:
                    for metric_key in metric_avg:
                        outfile.write('{}: {:.6f}\n'.format(metric_key, metric_avg[metric_key].item()))

            return val_metric.item()
        return float('inf')

    def evaluate_batch(self, opt, var, ep=None, it=None, single_gpu=False):
        var = util.move_to_device(var, opt.device)
        if single_gpu:
            var = self.graph.module(opt, var, training=False, get_loss=False)
        else:
            var = self.graph(opt, var, training=False, get_loss=False)
        return var

    @torch.no_grad()
    def log_scalars(self, opt, var, loss, metric=None, step=0, split="train"):
        if split == "train":
            sample_metrics, _ = self.depth_metric.compute_metrics(
                var.depth_pred, var.depth_input_map, var.mask_eroded if 'mask_eroded' in var else var.mask_input_map)
            metric = dict(L1_ERR=sample_metrics['l1_err'].mean().item())
        for key, value in loss.items():
            if key == "all": continue
            self.tb.add_scalar("{0}/loss_{1}".format(split, key), value.mean(), step)
        if metric is not None:
            for key, value in metric.items():
                self.tb.add_scalar("{0}/{1}".format(split, key), value, step)

    @torch.no_grad()
    def visualize(self, opt, var, step=0, split="train"):
        pass

    @torch.no_grad()
    def dump_results(self, opt, var, ep, write_new=False, train=False):
        # create the dir
        current_folder = "dump" if train == False else "vis_{}".format(ep)
        os.makedirs("{}/{}/".format(opt.output_path, current_folder), exist_ok=True)

        # save the results
        util_vis.dump_images(opt, var.idx, "image_input", var.rgb_input_map, masks=None, from_range=(0, 1), folder=current_folder)
        util_vis.dump_images(opt, var.idx, "mask_input", var.mask_input_map, folder=current_folder)
        util_vis.dump_depths(opt, var.idx, "depth_pred", var.depth_pred, var.mask_input_map, rescale=True, folder=current_folder)
        util_vis.dump_depths(opt, var.idx, "depth_input", var.depth_input_map, var.mask_input_map, rescale=True, folder=current_folder)
        if 'seen_points_pred' in var and 'seen_points_gt' in var:
            util_vis.dump_pointclouds_compare(opt, var.idx, "seen_surface", var.seen_points_pred, var.seen_points_gt, folder=current_folder)

        if "depth_pred_aligned" in var:
            # get the max and min for the depth map
            batch_size = var.depth_input_map.shape[0]
            mask = var.mask_eroded if 'mask_eroded' in var else var.mask_input_map
            masked_depth_far_bg = var.depth_input_map * mask + (1 - mask) * 1000
            depth_min_gt = masked_depth_far_bg.view(batch_size, -1).min(dim=1)[0]
            masked_depth_invalid_bg = var.depth_input_map * mask + (1 - mask) * 0
            depth_max_gt = masked_depth_invalid_bg.view(batch_size, -1).max(dim=1)[0]
            depth_vis_pred = (var.depth_pred_aligned - depth_min_gt.view(batch_size, 1, 1, 1)) / (depth_max_gt - depth_min_gt).view(batch_size, 1, 1, 1)
            depth_vis_pred = depth_vis_pred * mask + (1 - mask)
            depth_vis_gt = (var.depth_input_map - depth_min_gt.view(batch_size, 1, 1, 1)) / (depth_max_gt - depth_min_gt).view(batch_size, 1, 1, 1)
            depth_vis_gt = depth_vis_gt * mask + (1 - mask)
            util_vis.dump_depths(opt, var.idx, "depth_gt_aligned", depth_vis_gt.clamp(max=1, min=0), None, rescale=False, folder=current_folder)
            util_vis.dump_depths(opt, var.idx, "depth_pred_aligned", depth_vis_pred.clamp(max=1, min=0), None, rescale=False, folder=current_folder)
        if "mask_eroded" in var and "rmse" in var:
            util_vis.dump_images(opt, var.idx, "image_eroded", var.rgb_input_map, masks=var.mask_eroded, metrics=var.rmse, from_range=(0, 1), folder=current_folder)

    def save_checkpoint(self, opt, ep=0, it=0, best_val=np.inf, best_ep=1, latest=False, best=False):
        util.save_checkpoint(opt, self, ep=ep, it=it, best_val=best_val, best_ep=best_ep, latest=latest, best=best)
        if not latest:
            print("checkpoint saved: ({0}) {1}, epoch {2} (iteration {3})".format(opt.group, opt.name, ep, it))
        if best:
            print("Saving the current model as the best...")
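The update rule in train_iteration combines torch.autocast, GradScaler and gradient accumulation; stripped of the engine bookkeeping it reduces to roughly the following sketch (model, loss_fn and the hyperparameter values are placeholders, not names from this repository):

import torch

def amp_accum_step(model, optimizer, scaler, loss_fn, batch, it, accum=2, clip_norm=1.0):
    # forward under float16 autocast; divide by accum so accumulated gradients average correctly
    with torch.autocast(device_type='cuda', dtype=torch.float16):
        loss = loss_fn(model, batch) / accum
    scaler.scale(loss).backward()
    # only step the optimizer every `accum` iterations
    if (it + 1) % accum == 0:
        scaler.unscale_(optimizer)            # unscale so clipping sees the true gradient norm
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_norm)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
    return loss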
model/shape/implicit.py
ADDED
@@ -0,0 +1,288 @@
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
|
5 |
+
from functools import partial
|
6 |
+
from utils.layers import get_embedder
|
7 |
+
from utils.layers import LayerScale
|
8 |
+
from timm.models.vision_transformer import Mlp, DropPath
|
9 |
+
from utils.pos_embed import get_2d_sincos_pos_embed
|
10 |
+
|
11 |
+
class ImplFuncAttention(nn.Module):
|
12 |
+
def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., last_layer=False):
|
13 |
+
super().__init__()
|
14 |
+
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
|
15 |
+
self.num_heads = num_heads
|
16 |
+
head_dim = dim // num_heads
|
17 |
+
self.scale = head_dim ** -0.5
|
18 |
+
|
19 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
20 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
21 |
+
self.proj = nn.Linear(dim, dim)
|
22 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
23 |
+
self.last_layer = last_layer
|
24 |
+
|
25 |
+
def forward(self, x, N_points):
|
26 |
+
|
27 |
+
B, N, C = x.shape
|
28 |
+
N_latent = N - N_points
|
29 |
+
# [3, B, num_heads, N, C/num_heads]
|
30 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
31 |
+
# [B, num_heads, N, C/num_heads]
|
32 |
+
q, k, v = qkv.unbind(0)
|
33 |
+
# [B, num_heads, N_latent, C/num_heads]
|
34 |
+
q_latent, k_latent, v_latent = q[:, :, :-N_points], k[:, :, :-N_points], v[:, :, :-N_points]
|
35 |
+
# [B, num_heads, N_points, C/num_heads]
|
36 |
+
q_points, k_points, v_points = q[:, :, -N_points:], k[:, :, -N_points:], v[:, :, -N_points:]
|
37 |
+
|
38 |
+
# attention weight for each point, it's only connected to the latent and itself
|
39 |
+
# [B, num_heads, N_points, N_latent+1]
|
40 |
+
# get the cross attention, [B, num_heads, N_points, N_latent]
|
41 |
+
attn_cross = (q_points @ k_latent.transpose(-2, -1)) * self.scale
|
42 |
+
# get the attention to self feature, [B, num_heads, N_points, 1]
|
43 |
+
attn_self = torch.sum(q_points * k_points, dim=-1, keepdim=True) * self.scale
|
44 |
+
# get the normalized attention, [B, num_heads, N_points, N_latent+1]
|
45 |
+
attn_joint = torch.cat([attn_cross, attn_self], dim=-1)
|
46 |
+
attn_joint = attn_joint.softmax(dim=-1)
|
47 |
+
attn_joint = self.attn_drop(attn_joint)
|
48 |
+
|
49 |
+
# break it down to weigh and sum the values
|
50 |
+
# [B, num_heads, N_points, N_latent] @ [B, num_heads, N_latent, C/num_heads]
|
51 |
+
# -> [B, num_heads, N_points, C/num_heads] -> [B, N_points, C]
|
52 |
+
sum_cross = (attn_joint[:, :, :, :N_latent] @ v_latent).transpose(1, 2).reshape(B, N_points, C)
|
53 |
+
# [B, num_heads, N_points, 1] * [B, num_heads, N_points, C/num_heads]
|
54 |
+
# -> [B, num_heads, N_points, C/num_heads] -> [B, N_points, C]
|
55 |
+
sum_self = (attn_joint[:, :, :, N_latent:] * v_points).transpose(1, 2).reshape(B, N_points, C)
|
56 |
+
# [B, N_points, C]
|
57 |
+
output_points = sum_cross + sum_self
|
58 |
+
|
59 |
+
if self.last_layer:
|
60 |
+
output = self.proj(output_points)
|
61 |
+
output = self.proj_drop(output)
|
62 |
+
# [B, N_points, C], [B, N_points, N_latent]
|
63 |
+
return output, attn_joint[..., :-1].mean(dim=1)
|
64 |
+
|
65 |
+
# attention weight for the latent vec, it's not connected to the points
|
66 |
+
# [B, num_heads, N_latent, N_latent]
|
67 |
+
attn_latent = (q_latent @ k_latent.transpose(-2, -1)) * self.scale
|
68 |
+
attn_latent = attn_latent.softmax(dim=-1)
|
69 |
+
attn_latent = self.attn_drop(attn_latent)
|
70 |
+
# get the output latent, [B, N_latent, C]
|
71 |
+
output_latent = (attn_latent @ v_latent).transpose(1, 2).reshape(B, N_latent, C)
|
72 |
+
|
73 |
+
# concatenate the output and return
|
74 |
+
output = torch.cat([output_latent, output_points], dim=1)
|
75 |
+
output = self.proj(output)
|
76 |
+
output = self.proj_drop(output)
|
77 |
+
|
78 |
+
# [B, N, C], [B, N_points, N_latent+1]
|
79 |
+
return output, attn_joint[..., :-1].mean(dim=1)
|
80 |
+
|
81 |
+
class ImplFuncBlock(nn.Module):
|
82 |
+
|
83 |
+
def __init__(
|
84 |
+
self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None,
|
85 |
+
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, last_layer=False):
|
86 |
+
super().__init__()
|
87 |
+
self.last_layer = last_layer
|
88 |
+
self.norm1 = norm_layer(dim)
|
89 |
+
self.attn = ImplFuncAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, last_layer=last_layer)
|
90 |
+
self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
91 |
+
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
92 |
+
|
93 |
+
self.norm2 = norm_layer(dim)
|
94 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
95 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
96 |
+
self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
97 |
+
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
98 |
+
|
99 |
+
def forward(self, x, unseen_size):
|
100 |
+
if self.last_layer:
|
101 |
+
attn_out, attn_vis = self.attn(self.norm1(x), unseen_size)
|
102 |
+
output = x[:, -unseen_size:] + self.drop_path1(self.ls1(attn_out))
|
103 |
+
output = output + self.drop_path2(self.ls2(self.mlp(self.norm2(output))))
|
104 |
+
return output, attn_vis
|
105 |
+
else:
|
106 |
+
attn_out, attn_vis = self.attn(self.norm1(x), unseen_size)
|
107 |
+
x = x + self.drop_path1(self.ls1(attn_out))
|
108 |
+
x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
|
109 |
+
return x, attn_vis
|
110 |
+
|
111 |
+
class LinearProj3D(nn.Module):
|
112 |
+
"""
|
113 |
+
Linear projection of 3D point into embedding space
|
114 |
+
"""
|
115 |
+
def __init__(self, embed_dim, posenc_res=0):
|
116 |
+
super().__init__()
|
117 |
+
self.embed_dim = embed_dim
|
118 |
+
|
119 |
+
# define positional embedder
|
120 |
+
self.embed_fn = None
|
121 |
+
input_ch = 3
|
122 |
+
if posenc_res > 0:
|
123 |
+
self.embed_fn, input_ch = get_embedder(posenc_res, input_dims=3)
|
124 |
+
|
125 |
+
        # linear proj layer
        self.proj = nn.Linear(input_ch, embed_dim)

    def forward(self, points_3D):
        if self.embed_fn is not None:
            points_3D = self.embed_fn(points_3D)
        return self.proj(points_3D)

class MLPBlocks(nn.Module):
    def __init__(self, num_hidden_layers, n_channels, latent_dim,
                 skip_in=[], posenc_res=0):
        super().__init__()

        # projection to the same number of channels
        self.dims = [3 + latent_dim] + [n_channels] * num_hidden_layers + [1]
        self.num_layers = len(self.dims)
        self.skip_in = skip_in

        # define positional embedder
        self.embed_fn = None
        if posenc_res > 0:
            embed_fn, input_ch = get_embedder(posenc_res, input_dims=3)
            self.embed_fn = embed_fn
            self.dims[0] += (input_ch - 3)

        self.layers = nn.ModuleList([])

        for l in range(0, self.num_layers - 1):
            out_dim = self.dims[l + 1]
            if l in self.skip_in:
                in_dim = self.dims[l] + self.dims[0]
            else:
                in_dim = self.dims[l]

            lin = nn.Linear(in_dim, out_dim)
            self.layers.append(lin)

        # register for param init
        self.posenc_res = posenc_res

        # activation
        self.softplus = nn.Softplus(beta=100)

    def forward(self, points, proj_latent):

        # positional encoding
        if self.embed_fn is not None:
            points = self.embed_fn(points)

        # forward by layer
        # [B, N, posenc+C]
        inputs = torch.cat([points, proj_latent], dim=-1)
        x = inputs
        for l in range(0, self.num_layers - 1):
            if l in self.skip_in:
                x = torch.cat([x, inputs], -1) / np.sqrt(2)
            x = self.layers[l](x)
            if l < self.num_layers - 2:
                x = self.softplus(x)
        return x

class Implicit(nn.Module):
    """
    Implicit function conditioned on depth encodings
    """
    def __init__(self,
                 num_patches, latent_dim=768, semantic=False, n_channels=512,
                 n_blocks_attn=2, n_layers_mlp=6, num_heads=16, posenc_3D=0,
                 mlp_ratio=4., norm_layer=partial(nn.LayerNorm, eps=1e-6), drop_path=0.1,
                 skip_in=[], pos_perlayer=True):
        super().__init__()
        self.num_patches = num_patches
        self.pos_perlayer = pos_perlayer
        self.semantic = semantic

        # projection to the same number of channels, no posenc
        self.point_proj = LinearProj3D(n_channels)
        self.latent_proj = nn.Linear(latent_dim, n_channels, bias=True)

        # positional embedding for the depth latent codes
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, n_channels), requires_grad=False)  # fixed sin-cos embedding

        # multi-head attention blocks
        self.blocks_attn = nn.ModuleList([
            ImplFuncBlock(
                n_channels, num_heads, mlp_ratio,
                qkv_bias=True, norm_layer=norm_layer, drop_path=drop_path
            ) for _ in range(n_blocks_attn-1)])
        self.blocks_attn.append(
            ImplFuncBlock(
                n_channels, num_heads, mlp_ratio,
                qkv_bias=True, norm_layer=norm_layer, drop_path=drop_path, last_layer=True
            )
        )
        self.norm = norm_layer(n_channels)

        self.impl_mlp = None
        # define the impl MLP
        if n_layers_mlp > 0:
            self.impl_mlp = MLPBlocks(n_layers_mlp, n_channels, n_channels,
                                      skip_in=skip_in, posenc_res=posenc_3D)
        else:
            # occ and color prediction
            self.pred_head = nn.Linear(n_channels, 1, bias=True)

        self.initialize_weights()

    def initialize_weights(self):

        # initialize the positional embedding for the depth latent codes
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.num_patches**.5), cls_token=True)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        # initialize nn.Linear and nn.LayerNorm
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, latent_depth, latent_semantic, points_3D):
        # concatenate latent codes if semantic is used
        latent = torch.cat([latent_depth, latent_semantic], dim=-1) if self.semantic else latent_depth

        # project latent code and add posenc
        # [B, 1+n_patches, C]
        latent = self.latent_proj(latent)
        N_latent = latent.shape[1]

        # project query points
        # [B, n_points, C_dec]
        points_feat = self.point_proj(points_3D)

        # concat point feat with latent
        # [B, 1+n_patches+n_points, C_dec]
        output = torch.cat([latent, points_feat], dim=1)

        # apply multi-head attention blocks
        attn_vis = []
        for l, blk in enumerate(self.blocks_attn):
            if self.pos_perlayer or l == 0:
                output[:, :N_latent] = output[:, :N_latent] + self.pos_embed
            output, attn = blk(output, points_feat.shape[1])
            attn_vis.append(attn)
        output = self.norm(output)
        # average of attention weights across layers, [B, N_points, N_latent+1]
        attn_vis = torch.stack(attn_vis, dim=-1).mean(dim=-1)

        if self.impl_mlp:
            # apply mlp blocks
            output = self.impl_mlp(points_3D, output)
        else:
            # predictor projection
            # [B, n_points, 1]
            output = self.pred_head(output)

        # return the occ logit of shape [B, n_points] and the attention weights if needed
        return output.squeeze(-1), attn_vis
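A minimal usage sketch of the Implicit decoder above (illustration only, not part of the commit). It assumes the repository root is on PYTHONPATH, uses random tensors, and picks shapes that follow the comments in forward: 196 patch tokens plus one CLS token and 256-dim latents, as configured in options/shape.yaml.

import torch
from model.shape.implicit import Implicit

# hypothetical settings; only the shapes matter for this sketch
decoder = Implicit(num_patches=196, latent_dim=256, n_channels=256,
                   n_blocks_attn=2, n_layers_mlp=8, num_heads=16, skip_in=[2, 4, 6])
latent_depth = torch.randn(2, 197, 256)        # [B, 1+n_patches, latent_dim]
points_3D = torch.rand(2, 4096, 3) * 2 - 1     # query points in [-1, 1]^3
# semantic=False, so the second argument is unused here
occ_logit, attn = decoder(latent_depth, None, points_3D)
print(occ_logit.shape)                         # expected: torch.Size([2, 4096]) occupancy logits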
model/shape/rgb_enc.py
ADDED
@@ -0,0 +1,137 @@
# This source code is written based on https://github.com/facebookresearch/MCC
# The original code base is licensed under the license found in the LICENSE file in the root directory.

import torch
import torch.nn as nn
import torchvision

from functools import partial
from timm.models.vision_transformer import Block, PatchEmbed
from utils.pos_embed import get_2d_sincos_pos_embed
from utils.layers import Bottleneck_Conv

class RGBEncAtt(nn.Module):
    """
    RGB encoder based on transformer.
    """
    def __init__(self,
                 img_size=224, embed_dim=768, n_blocks=12, num_heads=12, win_size=16,
                 mlp_ratio=4., norm_layer=partial(nn.LayerNorm, eps=1e-6), drop_path=0.1):
        super().__init__()

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.rgb_embed = PatchEmbed(img_size, win_size, 3, embed_dim)

        num_patches = self.rgb_embed.num_patches
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False)

        self.blocks = nn.ModuleList([
            Block(
                embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer,
                drop_path=drop_path
            ) for _ in range(n_blocks)])

        self.norm = norm_layer(embed_dim)

        self.initialize_weights()

    def initialize_weights(self):
        # initialize the pos enc with fixed cos-sin pattern
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.rgb_embed.num_patches**.5), cls_token=True)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        # initialize rgb patch_embed like nn.Linear (instead of nn.Conv2d)
        w = self.rgb_embed.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        torch.nn.init.normal_(self.cls_token, std=.02)

        # initialize nn.Linear and nn.LayerNorm
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, rgb_obj):

        # [B, H/ws*W/ws, C]
        rgb_embedding = self.rgb_embed(rgb_obj)
        rgb_embedding = rgb_embedding + self.pos_embed[:, 1:, :]

        # append cls token
        # [1, 1, C]
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        # [B, 1, C]
        cls_tokens = cls_token.expand(rgb_embedding.shape[0], -1, -1)

        # [B, H/ws*W/ws+1, C]
        rgb_embedding = torch.cat((cls_tokens, rgb_embedding), dim=1)

        # apply Transformer blocks
        for blk in self.blocks:
            rgb_embedding = blk(rgb_embedding)
        rgb_embedding = self.norm(rgb_embedding)

        # [B, H/ws*W/ws+1, C]
        return rgb_embedding

class RGBEncRes(nn.Module):
    """
    RGB encoder based on resnet.
    """
    def __init__(self, opt):
        super().__init__()

        self.encoder = torchvision.models.resnet50(pretrained=True)
        self.encoder.fc = nn.Sequential(
            Bottleneck_Conv(2048),
            Bottleneck_Conv(2048),
            nn.Linear(2048, opt.arch.latent_dim)
        )

        # define hooks
        self.rgb_feature = None
        def feature_hook(model, input, output):
            self.rgb_feature = output

        # attach hooks
        if (opt.arch.win_size) == 16:
            self.encoder.layer3.register_forward_hook(feature_hook)
            self.rgb_feat_proj = nn.Sequential(
                Bottleneck_Conv(1024),
                Bottleneck_Conv(1024),
                nn.Conv2d(1024, opt.arch.latent_dim, 1)
            )
        elif (opt.arch.win_size) == 32:
            self.encoder.layer4.register_forward_hook(feature_hook)
            self.rgb_feat_proj = nn.Sequential(
                Bottleneck_Conv(2048),
                Bottleneck_Conv(2048),
                nn.Conv2d(2048, opt.arch.latent_dim, 1)
            )
        else:
            print('Make sure win_size is 16 or 32 when using resnet backbone!')
            raise NotImplementedError

    def forward(self, rgb_obj):
        batch_size = rgb_obj.shape[0]
        assert len(rgb_obj.shape) == 4

        # [B, 1, C]
        global_feat = self.encoder(rgb_obj).unsqueeze(1)
        # [B, C, H/ws*W/ws]
        local_feat = self.rgb_feat_proj(self.rgb_feature).view(batch_size, global_feat.shape[-1], -1)
        # [B, H/ws*W/ws, C]
        local_feat = local_feat.permute(0, 2, 1).contiguous()
        # [B, 1+H/ws*W/ws, C]
        rgb_embedding = torch.cat([global_feat, local_feat], dim=1)

        return rgb_embedding
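A quick sanity-check sketch for RGBEncAtt (illustration only, not part of the commit; the reduced embed_dim and n_blocks are arbitrary choices to keep it light, and it assumes timm is installed as pinned in requirements.txt):

import torch
from model.shape.rgb_enc import RGBEncAtt

enc = RGBEncAtt(img_size=224, embed_dim=256, n_blocks=2, num_heads=8, win_size=16)
rgb = torch.rand(2, 3, 224, 224)   # [B, 3, H, W] in [0, 1]
tokens = enc(rgb)
print(tokens.shape)                # torch.Size([2, 197, 256]): CLS token + 14*14 patch tokens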
model/shape/seen_coord_enc.py
ADDED
@@ -0,0 +1,195 @@
# This source code is written based on https://github.com/facebookresearch/MCC
# The original code base is licensed under the license found in the LICENSE file in the root directory.

import torch
import torch.nn as nn
import torchvision

from functools import partial
from timm.models.vision_transformer import Block
from utils.pos_embed import get_2d_sincos_pos_embed
from utils.layers import Bottleneck_Conv

class CoordEmb(nn.Module):
    """
    Encode the seen coordinate map to a lower resolution feature map
    Achieved with window-wise attention blocks by dividing the coord map into windows
    Each window is separately encoded into a single CLS token with self-attention and posenc
    """
    def __init__(self, embed_dim, win_size=8, num_heads=8):
        super().__init__()
        self.embed_dim = embed_dim
        self.win_size = win_size

        self.two_d_pos_embed = nn.Parameter(
            torch.zeros(1, self.win_size*self.win_size + 1, embed_dim), requires_grad=False)  # fixed sin-cos embedding

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        self.pos_embed = nn.Linear(3, embed_dim)

        self.blocks = nn.ModuleList([
            # each block is a residual block with layernorm -> attention -> layernorm -> mlp
            Block(embed_dim, num_heads=num_heads, mlp_ratio=2.0, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6))
            for _ in range(1)
        ])

        self.invalid_coord_token = nn.Parameter(torch.zeros(embed_dim,))

        self.initialize_weights()

    def initialize_weights(self):
        torch.nn.init.normal_(self.cls_token, std=.02)

        two_d_pos_embed = get_2d_sincos_pos_embed(self.two_d_pos_embed.shape[-1], self.win_size, cls_token=True)
        self.two_d_pos_embed.data.copy_(torch.from_numpy(two_d_pos_embed).float().unsqueeze(0))

        torch.nn.init.normal_(self.invalid_coord_token, std=.02)

    def forward(self, coord_obj, mask_obj):
        # [B, H, W, C]
        emb = self.pos_embed(coord_obj)

        emb[~mask_obj] = 0.0
        emb[~mask_obj] += self.invalid_coord_token

        B, H, W, C = emb.shape
        # [B, H/ws, ws, W/ws, ws, C]
        emb = emb.view(B, H // self.win_size, self.win_size, W // self.win_size, self.win_size, C)
        # [B * H/ws * W/ws, ws*ws, C]
        emb = emb.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, self.win_size * self.win_size, C)

        # [B * H/ws * W/ws, ws*ws, C], add posenc that is local to each patch
        emb = emb + self.two_d_pos_embed[:, 1:, :]
        # [1, 1, C]
        cls_token = self.cls_token + self.two_d_pos_embed[:, :1, :]

        # [B * H/ws * W/ws, 1, C]
        cls_tokens = cls_token.expand(emb.shape[0], -1, -1)
        # [B * H/ws * W/ws, ws*ws+1, C]
        emb = torch.cat((cls_tokens, emb), dim=1)

        # transformer (single block) that handles each patch separately
        # reasoning is done within each window
        for _, blk in enumerate(self.blocks):
            emb = blk(emb)

        # return the cls token of each window, [B, H/ws*W/ws, C]
        return emb[:, 0].view(B, (H // self.win_size) * (W // self.win_size), -1)

class CoordEncAtt(nn.Module):
    """
    Seen surface encoder based on transformer.
    """
    def __init__(self,
                 embed_dim=768, n_blocks=12, num_heads=12, win_size=8,
                 mlp_ratio=4., norm_layer=partial(nn.LayerNorm, eps=1e-6), drop_path=0.1):
        super().__init__()

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.coord_embed = CoordEmb(embed_dim, win_size, num_heads)

        self.blocks = nn.ModuleList([
            Block(
                embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer,
                drop_path=drop_path
            ) for _ in range(n_blocks)])

        self.norm = norm_layer(embed_dim)

        self.initialize_weights()

    def initialize_weights(self):
        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        torch.nn.init.normal_(self.cls_token, std=.02)

        # initialize nn.Linear and nn.LayerNorm
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, coord_obj, mask_obj):

        # [B, H/ws*W/ws, C]
        coord_embedding = self.coord_embed(coord_obj, mask_obj)

        # append cls token
        # [1, 1, C]
        cls_token = self.cls_token
        # [B, 1, C]
        cls_tokens = cls_token.expand(coord_embedding.shape[0], -1, -1)

        # [B, H/ws*W/ws+1, C]
        coord_embedding = torch.cat((cls_tokens, coord_embedding), dim=1)

        # apply Transformer blocks
        for blk in self.blocks:
            coord_embedding = blk(coord_embedding)
        coord_embedding = self.norm(coord_embedding)

        # [B, H/ws*W/ws+1, C]
        return coord_embedding

class CoordEncRes(nn.Module):
    """
    Seen surface encoder based on resnet.
    """
    def __init__(self, opt):
        super().__init__()

        self.encoder = torchvision.models.resnet50(pretrained=True)
        self.encoder.fc = nn.Sequential(
            Bottleneck_Conv(2048),
            Bottleneck_Conv(2048),
            nn.Linear(2048, opt.arch.latent_dim)
        )

        # define hooks
        self.seen_feature = None
        def feature_hook(model, input, output):
            self.seen_feature = output

        # attach hooks
        assert opt.arch.depth.dsp == 1
        if (opt.arch.win_size) == 16:
            self.encoder.layer3.register_forward_hook(feature_hook)
            self.depth_feat_proj = nn.Sequential(
                Bottleneck_Conv(1024),
                Bottleneck_Conv(1024),
                nn.Conv2d(1024, opt.arch.latent_dim, 1)
            )
        elif (opt.arch.win_size) == 32:
            self.encoder.layer4.register_forward_hook(feature_hook)
            self.depth_feat_proj = nn.Sequential(
                Bottleneck_Conv(2048),
                Bottleneck_Conv(2048),
                nn.Conv2d(2048, opt.arch.latent_dim, 1)
            )
        else:
            print('Make sure win_size is 16 or 32 when using resnet backbone!')
            raise NotImplementedError

    def forward(self, coord_obj, mask_obj):
        batch_size = coord_obj.shape[0]
        assert len(coord_obj.shape) == len(mask_obj.shape) == 4
        mask_obj = mask_obj.float()
        coord_obj = coord_obj * mask_obj

        # [B, 1, C]
        global_feat = self.encoder(coord_obj).unsqueeze(1)
        # [B, C, H/ws*W/ws]
        local_feat = self.depth_feat_proj(self.seen_feature).view(batch_size, global_feat.shape[-1], -1)
        # [B, H/ws*W/ws, C]
        local_feat = local_feat.permute(0, 2, 1).contiguous()
        # [B, 1+H/ws*W/ws, C]
        seen_embedding = torch.cat([global_feat, local_feat], dim=1)

        return seen_embedding
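To make the window bookkeeping in CoordEmb concrete: with a 224x224 coordinate map and win_size=16, each 16x16 window collapses into its CLS token, giving 14*14 = 196 window tokens, and CoordEncAtt then prepends one global CLS token. A small sketch (illustration only; tensor values are random and the shapes are assumptions taken from the comments above):

import torch
from model.shape.seen_coord_enc import CoordEncAtt

enc = CoordEncAtt(embed_dim=256, n_blocks=2, num_heads=8, win_size=16)
coord = torch.randn(2, 224, 224, 3)      # seen 3D coordinate map, [B, H, W, 3]
mask = torch.rand(2, 224, 224) > 0.5     # True where the coordinate is valid
tokens = enc(coord, mask)
print(tokens.shape)                      # torch.Size([2, 197, 256])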
model/shape_engine.py
ADDED
@@ -0,0 +1,598 @@
import numpy as np
import os, time, datetime
import torch
import torch.utils.tensorboard
import torch.profiler
import importlib
import shutil
import utils.util as util
import utils.util_vis as util_vis
import utils.eval_3D as eval_3D

from torch.nn.parallel import DistributedDataParallel as DDP
from utils.util import print_eval, setup, cleanup
from utils.util import EasyDict as edict
from copy import deepcopy
from model.compute_graph import graph_shape

# ============================ main engine for training and evaluation ============================

class Runner():

    def __init__(self, opt):
        super().__init__()
        if os.path.isdir(opt.output_path) and opt.resume == False and opt.device == 0:
            for filename in os.listdir(opt.output_path):
                if "tfevents" in filename: os.remove(os.path.join(opt.output_path, filename))
                if "html" in filename: os.remove(os.path.join(opt.output_path, filename))
                if "vis" in filename: shutil.rmtree(os.path.join(opt.output_path, filename))
                if "embedding" in filename: shutil.rmtree(os.path.join(opt.output_path, filename))
        if opt.device == 0:
            os.makedirs(opt.output_path, exist_ok=True)
        setup(opt.device, opt.world_size, opt.port)
        opt.batch_size = opt.batch_size // opt.world_size

    def get_viz_data(self, opt):
        # get data for visualization
        viz_data_list = []
        sample_range = len(self.viz_loader)
        viz_interval = sample_range // opt.eval.n_vis
        for i in range(sample_range):
            current_batch = next(self.viz_loader_iter)
            if i % viz_interval != 0: continue
            viz_data_list.append(current_batch)
        return viz_data_list

    def load_dataset(self, opt, eval_split="test"):
        data_train = importlib.import_module('data.{}'.format(opt.data.dataset_train))
        data_test = importlib.import_module('data.{}'.format(opt.data.dataset_test))
        if opt.device == 0: print("loading training data...")
        self.train_data = data_train.Dataset(opt, split="train")
        self.train_loader = self.train_data.setup_loader(opt, shuffle=True, use_ddp=True, drop_last=True)
        self.num_batches = len(self.train_loader)
        if opt.device == 0: print("loading test data...")
        self.test_data = data_test.Dataset(opt, split=eval_split)
        self.test_loader = self.test_data.setup_loader(opt, shuffle=False, use_ddp=True, drop_last=True, batch_size=opt.eval.batch_size)
        self.num_batches_test = len(self.test_loader)
        if len(self.test_loader.sampler) * opt.world_size < len(self.test_data):
            self.aux_test_dataset = torch.utils.data.Subset(self.test_data,
                range(len(self.test_loader.sampler) * opt.world_size, len(self.test_data)))
            self.aux_test_loader = torch.utils.data.DataLoader(
                self.aux_test_dataset, batch_size=opt.eval.batch_size, shuffle=False, drop_last=False,
                num_workers=opt.data.num_workers)
        if opt.device == 0:
            print("creating data for visualization...")
            self.viz_loader = self.test_data.setup_loader(opt, shuffle=False, use_ddp=False, drop_last=False, batch_size=1)
            self.viz_loader_iter = iter(self.viz_loader)
            self.viz_data = self.get_viz_data(opt)

    def build_networks(self, opt):
        if opt.device == 0: print("building networks...")
        self.graph = DDP(graph_shape.Graph(opt).to(opt.device), device_ids=[opt.device], find_unused_parameters=(not opt.optim.fix_dpt or not opt.optim.fix_clip))

    # =================================================== set up training =========================================================

    def setup_optimizer(self, opt):
        if opt.device == 0: print("setting up optimizers...")
        if opt.optim.fix_dpt:
            # when we do not need to train the dpt depth, every param will start from scratch
            scratch_param_decay = []
            scratch_param_nodecay = []
            # loop over all params
            for name, param in self.graph.named_parameters():
                # skip any fixed params
                if not param.requires_grad or 'dpt_depth' in name or 'intr_' in name:
                    continue
                # do not add wd on bias or low-dim params
                if param.ndim <= 1 or name.endswith(".bias"):
                    scratch_param_nodecay.append(param)
                    # print("{} -> scratch_param_nodecay".format(name))
                else:
                    scratch_param_decay.append(param)
                    # print("{} -> scratch_param_decay".format(name))
            # create the optim dictionary
            optim_dict = [
                {'params': scratch_param_nodecay, 'lr': opt.optim.lr, 'weight_decay': 0.},
                {'params': scratch_param_decay, 'lr': opt.optim.lr, 'weight_decay': opt.optim.weight_decay}
            ]
        else:
            # when we need to train dpt as well, related params should go to finetune list
            finetune_param_nodecay = []
            scratch_param_nodecay = []
            finetune_param_decay = []
            scratch_param_decay = []
            for name, param in self.graph.named_parameters():
                # skip any fixed params
                if not param.requires_grad:
                    continue
                # put dpt params into finetune list
                if 'dpt_depth' in name or 'intr_' in name:
                    if param.ndim <= 1 or name.endswith(".bias"):
                        # print("{} -> finetune_param_nodecay".format(name))
                        finetune_param_nodecay.append(param)
                    else:
                        finetune_param_decay.append(param)
                        # print("{} -> finetune_param_decay".format(name))
                # all other params go to scratch list
                else:
                    if param.ndim <= 1 or name.endswith(".bias"):
                        scratch_param_nodecay.append(param)
                        # print("{} -> scratch_param_nodecay".format(name))
                    else:
                        scratch_param_decay.append(param)
                        # print("{} -> scratch_param_decay".format(name))
            # create the optim dictionary
            optim_dict = [
                {'params': finetune_param_nodecay, 'lr': opt.optim.lr_ft, 'weight_decay': 0.},
                {'params': finetune_param_decay, 'lr': opt.optim.lr_ft, 'weight_decay': opt.optim.weight_decay},
                {'params': scratch_param_nodecay, 'lr': opt.optim.lr, 'weight_decay': 0.},
                {'params': scratch_param_decay, 'lr': opt.optim.lr, 'weight_decay': opt.optim.weight_decay}
            ]

        self.optim = torch.optim.AdamW(optim_dict, betas=(0.9, 0.95))
        if opt.optim.sched:
            self.sched = torch.optim.lr_scheduler.CosineAnnealingLR(self.optim, opt.max_epoch)
        if opt.optim.amp:
            self.scaler = torch.cuda.amp.GradScaler()

    def restore_checkpoint(self, opt, best=False, evaluate=False):
        epoch_start, iter_start = None, None
        if opt.resume:
            if opt.device == 0: print("resuming from previous checkpoint...")
            epoch_start, iter_start, best_val, best_ep = util.restore_checkpoint(opt, self, resume=opt.resume, best=best, evaluate=evaluate)
            self.best_val = best_val
            self.best_ep = best_ep
        elif opt.load is not None:
            if opt.device == 0: print("loading weights from checkpoint {}...".format(opt.load))
            epoch_start, iter_start, best_val, best_ep = util.restore_checkpoint(opt, self, load_name=opt.load)
        else:
            if opt.device == 0: print("initializing weights from scratch...")
        self.epoch_start = epoch_start or 0
        self.iter_start = iter_start or 0

    def setup_visualizer(self, opt, test=False):
        if opt.device == 0:
            print("setting up visualizers...")
            if opt.tb:
                if test == False:
                    self.tb = torch.utils.tensorboard.SummaryWriter(log_dir=opt.output_path, flush_secs=10)
                else:
                    embedding_folder = os.path.join(opt.output_path, 'embedding')
                    os.makedirs(embedding_folder, exist_ok=True)
                    self.tb = torch.utils.tensorboard.SummaryWriter(log_dir=embedding_folder, flush_secs=10)

    def train(self, opt):
        # before training
        torch.cuda.set_device(opt.device)
        torch.cuda.empty_cache()
        if opt.device == 0: print("TRAINING START")
        self.train_metric_logger = util.MetricLogger(delimiter=" ")
        self.train_metric_logger.add_meter('lr', util.SmoothedValue(window_size=1, fmt='{value:.6f}'))
        self.iter_skip = self.iter_start % len(self.train_loader)
        self.it = self.iter_start
        self.skip_dis = False
        if not opt.resume:
            self.best_val = np.inf
            self.best_ep = 1
        # training
        if self.iter_start == 0 and not opt.debug: self.evaluate(opt, ep=0, training=True)
        for self.ep in range(self.epoch_start, opt.max_epoch):
            self.train_epoch(opt)
        # after training
        if opt.device == 0: self.save_checkpoint(opt, ep=self.ep, it=self.it, best_val=self.best_val, best_ep=self.best_ep)
        if opt.tb and opt.device == 0:
            self.tb.flush()
            self.tb.close()
        if opt.device == 0:
            print("TRAINING DONE")
            print("Best CD: %.4f @ epoch %d" % (self.best_val, self.best_ep))
        cleanup()

    def train_epoch(self, opt):
        # before train epoch
        self.train_loader.sampler.set_epoch(self.ep)
        if opt.device == 0:
            print("training epoch {}".format(self.ep+1))
        batch_progress = range(self.num_batches)
        self.graph.train()
        # train epoch
        loader = iter(self.train_loader)

        if opt.debug and opt.profile:
            with torch.profiler.profile(
                schedule=torch.profiler.schedule(wait=3, warmup=3, active=5, repeat=2),
                on_trace_ready=torch.profiler.tensorboard_trace_handler('debug/profiler_log'),
                record_shapes=True,
                profile_memory=True,
                with_stack=False
            ) as prof:
                for batch_id in batch_progress:
                    if batch_id >= (1 + 1 + 3) * 2:
                        # exit the program after 2 iterations of the warmup, active, and repeat steps
                        exit()

                    # if resuming from previous checkpoint, skip until the last iteration number is reached
                    if self.iter_skip > 0:
                        self.iter_skip -= 1
                        continue
                    batch = next(loader)
                    # train iteration
                    var = edict(batch)
                    opt.H, opt.W = opt.image_size
                    var = util.move_to_device(var, opt.device)
                    loss = self.train_iteration(opt, var, batch_progress)
                    prof.step()
        else:
            for batch_id in batch_progress:
                # if resuming from previous checkpoint, skip until the last iteration number is reached
                if self.iter_skip > 0:
                    self.iter_skip -= 1
                    continue
                batch = next(loader)
                # train iteration
                var = edict(batch)
                opt.H, opt.W = opt.image_size
                var = util.move_to_device(var, opt.device)
                loss = self.train_iteration(opt, var, batch_progress)

        # after train epoch
        if opt.optim.sched: self.sched.step()
        if (self.ep + 1) % opt.freq.eval == 0:
            if opt.device == 0: print("validating epoch {}".format(self.ep+1))
            current_val = self.evaluate(opt, ep=self.ep+1, training=True)
            if current_val < self.best_val and opt.device == 0:
                self.best_val = current_val
                self.best_ep = self.ep + 1
                self.save_checkpoint(opt, ep=self.ep, it=self.it, best_val=self.best_val, best_ep=self.best_ep, best=True, latest=True)

    def train_iteration(self, opt, var, batch_progress):
        # before train iteration
        torch.distributed.barrier()
        # train iteration
        with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=opt.optim.amp):
            var, loss = self.graph.forward(opt, var, training=True, get_loss=True)
            loss = self.summarize_loss(opt, var, loss)
            loss_scaled = loss.all / opt.optim.accum

        # backward
        if opt.optim.amp:
            self.scaler.scale(loss_scaled).backward()
            # skip update if accumulating gradient
            if (self.it + 1) % opt.optim.accum == 0:
                self.scaler.unscale_(self.optim)
                # gradient clipping
                if opt.optim.clip_norm:
                    norm = torch.nn.utils.clip_grad_norm_(self.graph.parameters(), opt.optim.clip_norm)
                    if opt.debug: print("Grad norm: {}".format(norm))
                self.scaler.step(self.optim)
                self.scaler.update()
                self.optim.zero_grad()
        else:
            loss_scaled.backward()
            if (self.it + 1) % opt.optim.accum == 0:
                if opt.optim.clip_norm:
                    norm = torch.nn.utils.clip_grad_norm_(self.graph.parameters(), opt.optim.clip_norm)
                    if opt.debug: print("Grad norm: {}".format(norm))
                self.optim.step()
                self.optim.zero_grad()

        # after train iteration
        lr = self.sched.get_last_lr()[0] if opt.optim.sched else opt.optim.lr
        self.train_metric_logger.update(lr=lr)
        self.train_metric_logger.update(loss=loss.all)
        if opt.device == 0:
            self.graph.eval()
            # if (self.it) % opt.freq.vis == 0: self.visualize(opt, var, step=self.it, split="train")
            if (self.it) % opt.freq.ckpt_latest == 0 and not opt.debug:
                self.save_checkpoint(opt, ep=self.ep, it=self.it, best_val=self.best_val, best_ep=self.best_ep, latest=True)
            if (self.it) % opt.freq.scalar == 0 and not opt.debug:
                self.log_scalars(opt, var, loss, step=self.it, split="train")
            if (self.it) % (opt.freq.save_vis * (self.it//10000*10+1)) == 0 and not opt.debug:
                self.vis_train_iter(opt)
            if (self.it) % opt.freq.print == 0:
                print('[{}] '.format(datetime.datetime.now().time()), end='')
                print(f'Train Iter {self.it}/{self.num_batches*opt.max_epoch}: {self.train_metric_logger}')
            self.graph.train()
        self.it += 1
        return loss

    @torch.no_grad()
    def vis_train_iter(self, opt):
        for i in range(len(self.viz_data)):
            var_viz = edict(deepcopy(self.viz_data[i]))
            var_viz = util.move_to_device(var_viz, opt.device)
            var_viz = self.graph.module(opt, var_viz, training=False, get_loss=False)
            eval_3D.eval_metrics(opt, var_viz, self.graph.module.impl_network, vis_only=True)
            vis_folder = "vis_log/iter_{}".format(self.it)
            os.makedirs("{}/{}".format(opt.output_path, vis_folder), exist_ok=True)
            util_vis.dump_images(opt, var_viz.idx, "image_input", var_viz.rgb_input_map, masks=None, from_range=(0, 1), folder=vis_folder)
            util_vis.dump_images(opt, var_viz.idx, "mask_input", var_viz.mask_input_map, folder=vis_folder)
            util_vis.dump_meshes_viz(opt, var_viz.idx, "mesh_viz", var_viz.mesh_pred, folder=vis_folder)
            if 'depth_pred' in var_viz:
                util_vis.dump_depths(opt, var_viz.idx, "depth_est", var_viz.depth_pred, var_viz.mask_input_map, rescale=True, folder=vis_folder)
            if 'depth_input_map' in var_viz:
                util_vis.dump_depths(opt, var_viz.idx, "depth_input", var_viz.depth_input_map, var_viz.mask_input_map, rescale=True, folder=vis_folder)
            if 'attn_vis' in var_viz:
                util_vis.dump_attentions(opt, var_viz.idx, "attn", var_viz.attn_vis, folder=vis_folder)
            if 'gt_surf_points' in var_viz and 'seen_points' in var_viz:
                util_vis.dump_pointclouds_compare(opt, var_viz.idx, "seen_surface", var_viz.seen_points, var_viz.gt_surf_points, folder=vis_folder)

    def summarize_loss(self, opt, var, loss, non_act_loss_key=[]):
        loss_all = 0.
        assert("all" not in loss)
        # weigh losses
        for key in loss:
            assert(key in opt.loss_weight)
            if opt.loss_weight[key] is not None:
                assert not torch.isinf(loss[key].mean()), "loss {} is Inf".format(key)
                assert not torch.isnan(loss[key].mean()), "loss {} is NaN".format(key)
                loss_all += float(opt.loss_weight[key])*loss[key].mean() if key not in non_act_loss_key else 0.0*loss[key].mean()
        loss.update(all=loss_all)
        return loss

    # =================================================== set up evaluation =========================================================

    @torch.no_grad()
    def evaluate(self, opt, ep, training=False):
        self.graph.eval()

        # lists for metrics
        cd_accs = []
        cd_comps = []
        f_scores = []
        cat_indices = []
        loss_eval = edict()
        metric_eval = dict(dist_acc=0., dist_cov=0.)
        eval_metric_logger = util.MetricLogger(delimiter=" ")

        # result file on the fly
        if not training:
            assert opt.device == 0
            full_results_file = open(os.path.join(opt.output_path, '{}_full_results.txt'.format(opt.data.dataset_test)), 'w')
            full_results_file.write("IND, CD, ACC, COMP, ")
            full_results_file.write(", ".join(["F-score@{:.2f}".format(opt.eval.f_thresholds[i]*100) for i in range(len(opt.eval.f_thresholds))]))

        # dataloader on the test set
        with torch.cuda.device(opt.device):
            for it, batch in enumerate(self.test_loader):

                # run inference with the model
                var = edict(batch)
                var = self.evaluate_batch(opt, var, ep, it, single_gpu=False)

                # record CD for evaluation
                dist_acc, dist_cov = eval_3D.eval_metrics(opt, var, self.graph.module.impl_network)

                # accumulate the scores
                cd_accs.append(var.cd_acc)
                cd_comps.append(var.cd_comp)
                f_scores.append(var.f_score)
                cat_indices.append(var.category_label)
                eval_metric_logger.update(ACC=dist_acc)
                eval_metric_logger.update(COMP=dist_cov)
                eval_metric_logger.update(CD=(dist_acc+dist_cov) / 2)

                if opt.device == 0 and it % opt.freq.print_eval == 0:
                    print('[{}] '.format(datetime.datetime.now().time()), end='')
                    print(f'Eval Iter {it}/{len(self.test_loader)} @ EP {ep}: {eval_metric_logger}')

                # write to file
                if not training:
                    full_results_file.write("\n")
                    full_results_file.write("{:d}".format(var.idx.item()))
                    full_results_file.write("\t{:.4f}".format((var.cd_acc.item() + var.cd_comp.item()) / 2))
                    full_results_file.write("\t{:.4f}".format(var.cd_acc.item()))
                    full_results_file.write("\t{:.4f}".format(var.cd_comp.item()))
                    full_results_file.write("\t" + "\t".join(["{:.4f}".format(var.f_score[0][i].item()) for i in range(len(opt.eval.f_thresholds))]))
                    full_results_file.flush()

                # dump the result if in eval mode
                if not training:
                    self.dump_results(opt, var, ep, write_new=(it == 0))

                # save the predicted mesh for vis data if in train mode
                if it == 0 and training and opt.device == 0:
                    print("visualizing and saving results...")
                    for i in range(len(self.viz_data)):
                        var_viz = edict(deepcopy(self.viz_data[i]))
                        var_viz = self.evaluate_batch(opt, var_viz, ep, it, single_gpu=True)
                        eval_3D.eval_metrics(opt, var_viz, self.graph.module.impl_network, vis_only=True)
                        # self.visualize(opt, var_viz, step=ep, split="eval")
                        self.dump_results(opt, var_viz, ep, train=True)
                    # write html that organizes the results
                    util_vis.create_gif_html(os.path.join(opt.output_path, "vis_{}".format(ep)),
                                             os.path.join(opt.output_path, "results_ep{}.html".format(ep)),
                                             skip_every=1)

            # collect the eval results into tensors
            cd_accs = torch.cat(cd_accs, dim=0)
            cd_comps = torch.cat(cd_comps, dim=0)
            f_scores = torch.cat(f_scores, dim=0)
            cat_indices = torch.cat(cat_indices, dim=0)

            if opt.world_size > 1:
                # empty tensors for gathering
                cd_accs_all = [torch.zeros_like(cd_accs).to(opt.device) for _ in range(opt.world_size)]
                cd_comps_all = [torch.zeros_like(cd_comps).to(opt.device) for _ in range(opt.world_size)]
                f_scores_all = [torch.zeros_like(f_scores).to(opt.device) for _ in range(opt.world_size)]
                cat_indices_all = [torch.zeros_like(cat_indices).long().to(opt.device) for _ in range(opt.world_size)]

                # gather the metrics
                torch.distributed.barrier()
                torch.distributed.all_gather(cd_accs_all, cd_accs)
                torch.distributed.all_gather(cd_comps_all, cd_comps)
                torch.distributed.all_gather(f_scores_all, f_scores)
                torch.distributed.all_gather(cat_indices_all, cat_indices)
                cd_accs_all = torch.cat(cd_accs_all, dim=0)
                cd_comps_all = torch.cat(cd_comps_all, dim=0)
                f_scores_all = torch.cat(f_scores_all, dim=0)
                cat_indices_all = torch.cat(cat_indices_all, dim=0)
            else:
                cd_accs_all = cd_accs
                cd_comps_all = cd_comps
                f_scores_all = f_scores
                cat_indices_all = cat_indices
            # handle last batch, if any
            if len(self.test_loader.sampler) * opt.world_size < len(self.test_data):
                cd_accs_all = [cd_accs_all]
                cd_comps_all = [cd_comps_all]
                f_scores_all = [f_scores_all]
                cat_indices_all = [cat_indices_all]
                for batch in self.aux_test_loader:
                    # run inference with the model
                    var = edict(batch)
                    var = self.evaluate_batch(opt, var, ep, it, single_gpu=False)

                    # record CD for evaluation
                    dist_acc, dist_cov = eval_3D.eval_metrics(opt, var, self.graph.module.impl_network)
                    # accumulate the scores
                    cd_accs_all.append(var.cd_acc)
                    cd_comps_all.append(var.cd_comp)
                    f_scores_all.append(var.f_score)
                    cat_indices_all.append(var.category_label)

                    # dump the result if in eval mode
                    if not training and opt.device == 0:
                        self.dump_results(opt, var, ep, write_new=(it == 0))

                cd_accs_all = torch.cat(cd_accs_all, dim=0)
                cd_comps_all = torch.cat(cd_comps_all, dim=0)
                f_scores_all = torch.cat(f_scores_all, dim=0)
                cat_indices_all = torch.cat(cat_indices_all, dim=0)

            assert cd_accs_all.shape[0] == len(self.test_data)
            if not training:
                full_results_file.close()
            # printout and save the metrics
            if opt.device == 0:
                metric_eval["dist_acc"] = cd_accs_all.mean()
                metric_eval["dist_cov"] = cd_comps_all.mean()

                # print eval info
                print_eval(opt, loss=None, chamfer=(metric_eval["dist_acc"],
                                                    metric_eval["dist_cov"]))
                val_metric = (metric_eval["dist_acc"] + metric_eval["dist_cov"]) / 2

                if training:
                    # log/visualize results to tb/vis
                    self.log_scalars(opt, var, loss_eval, metric=metric_eval, step=ep, split="eval")

                if not training:
                    # save the per-cat evaluation metrics if on shapenet
                    per_cat_cd_file = os.path.join(opt.output_path, 'cd_cat.txt')
                    with open(per_cat_cd_file, "w") as outfile:
                        outfile.write("CD Acc Comp Count Cat\n")
                        for i in range(opt.data.num_classes_test):
                            if (cat_indices_all==i).sum() == 0:
                                continue
                            acc_i = cd_accs_all[cat_indices_all==i].mean().item()
                            comp_i = cd_comps_all[cat_indices_all==i].mean().item()
                            counts_cat = torch.sum(cat_indices_all==i)
                            cd_i = (acc_i + comp_i) / 2
                            outfile.write("%.4f %.4f %.4f %5d %s\n" % (cd_i, acc_i, comp_i, counts_cat, self.test_data.label2cat[i]))

                    # print f_scores
                    f_scores_avg = f_scores_all.mean(dim=0)
                    print('##############################')
                    for i in range(len(opt.eval.f_thresholds)):
                        print('F-score @ %.2f: %.4f' % (opt.eval.f_thresholds[i]*100, f_scores_avg[i].item()))
                    print('##############################')

                    # write to file
                    result_file = os.path.join(opt.output_path, 'quantitative_{}.txt'.format(opt.data.dataset_test))
                    with open(result_file, "w") as outfile:
                        outfile.write('CD Acc Comp \n')
                        outfile.write('%.4f %.4f %.4f\n' % (val_metric, metric_eval["dist_acc"], metric_eval["dist_cov"]))
                        for i in range(len(opt.eval.f_thresholds)):
                            outfile.write('F-score @ %.2f: %.4f\n' % (opt.eval.f_thresholds[i]*100, f_scores_avg[i].item()))

                    # write html that organizes the results
                    util_vis.create_gif_html(os.path.join(opt.output_path, "dump_{}".format(opt.data.dataset_test)),
                                             os.path.join(opt.output_path, "results_test.html"), skip_every=10)

                # torch.cuda.empty_cache()
                return val_metric.item()
            return float('inf')

    def evaluate_batch(self, opt, var, ep=None, it=None, single_gpu=False):
        var = util.move_to_device(var, opt.device)
        if single_gpu:
            var = self.graph.module(opt, var, training=False, get_loss=False)
        else:
            var = self.graph(opt, var, training=False, get_loss=False)
        return var

    @torch.no_grad()
    def log_scalars(self, opt, var, loss, metric=None, step=0, split="train"):
        if split=="train":
            dist_acc, dist_cov = eval_3D.eval_metrics(opt, var, self.graph.module.impl_network)
            metric = dict(dist_acc=dist_acc, dist_cov=dist_cov)
        for key, value in loss.items():
            if key=="all": continue
            self.tb.add_scalar("{0}/loss_{1}".format(split, key), value.mean(), step)
        if metric is not None:
            for key, value in metric.items():
                self.tb.add_scalar("{0}/{1}".format(split, key), value, step)
        # log the attention average values
        if 'attn_geo_avg' in var:
            self.tb.add_scalar("{0}/attn_geo_avg".format(split), var.attn_geo_avg, step)
        if 'attn_geo_seen' in var:
            self.tb.add_scalar("{0}/attn_geo_seen".format(split), var.attn_geo_seen, step)
        if 'attn_geo_occl' in var:
            self.tb.add_scalar("{0}/attn_geo_occl".format(split), var.attn_geo_occl, step)
        if 'attn_geo_bg' in var:
            self.tb.add_scalar("{0}/attn_geo_bg".format(split), var.attn_geo_bg, step)

    @torch.no_grad()
    def visualize(self, opt, var, step=0, split="train"):
        if 'pose_input' in var:
            pose_input = var.pose_input
        elif 'pose_gt' in var:
            pose_input = var.pose_gt
        else:
            pose_input = None
        util_vis.tb_image(opt, self.tb, step, split, "image_input_map", var.rgb_input_map, masks=None, from_range=(0, 1), poses=pose_input)
        util_vis.tb_image(opt, self.tb, step, split, "image_input_map_est", var.rgb_input_map, masks=None, from_range=(0, 1),
                          poses=var.pose_pred if 'pose_pred' in var else var.pose)
        util_vis.tb_image(opt, self.tb, step, split, "mask_input_map", var.mask_input_map)
        if 'depth_pred' in var:
            util_vis.tb_image(opt, self.tb, step, split, "depth_est_map", var.depth_pred)
        if 'depth_input_map' in var:
            util_vis.tb_image(opt, self.tb, step, split, "depth_input_map", var.depth_input_map)

    @torch.no_grad()
    def dump_results(self, opt, var, ep, write_new=False, train=False):
        # create the dir
        current_folder = "dump_{}".format(opt.data.dataset_test) if train == False else "vis_{}".format(ep)
        os.makedirs("{}/{}/".format(opt.output_path, current_folder), exist_ok=True)

        # save the results
        if 'pose_input' in var:
            pose_input = var.pose_input
        elif 'pose_gt' in var:
            pose_input = var.pose_gt
        else:
            pose_input = None
        util_vis.dump_images(opt, var.idx, "image_input", var.rgb_input_map, masks=None, from_range=(0, 1), poses=pose_input, folder=current_folder)
        util_vis.dump_images(opt, var.idx, "mask_input", var.mask_input_map, folder=current_folder)
        util_vis.dump_meshes(opt, var.idx, "mesh", var.mesh_pred, folder=current_folder)
        util_vis.dump_meshes_viz(opt, var.idx, "mesh_viz", var.mesh_pred, folder=current_folder)  # image frames + gifs
        if 'depth_pred' in var:
            util_vis.dump_depths(opt, var.idx, "depth_est", var.depth_pred, var.mask_input_map, rescale=True, folder=current_folder)
        if 'depth_input_map' in var:
            util_vis.dump_depths(opt, var.idx, "depth_input", var.depth_input_map, var.mask_input_map, rescale=True, folder=current_folder)
        if 'gt_surf_points' in var and 'seen_points' in var:
            util_vis.dump_pointclouds_compare(opt, var.idx, "seen_surface", var.seen_points, var.gt_surf_points, folder=current_folder)
        if 'attn_vis' in var:
            util_vis.dump_attentions(opt, var.idx, "attn", var.attn_vis, folder=current_folder)
        if 'attn_pc' in var:
            util_vis.dump_pointclouds(opt, var.idx, "attn_pc", var.attn_pc["points"], var.attn_pc["colors"], folder=current_folder)
        if 'dpc' in var:
            util_vis.dump_pointclouds_compare(opt, var.idx, "pointclouds_comp", var.dpc_pred, var.dpc.points, folder=current_folder)

    def save_checkpoint(self, opt, ep=0, it=0, best_val=np.inf, best_ep=1, latest=False, best=False):
        util.save_checkpoint(opt, self, ep=ep, it=it, best_val=best_val, best_ep=best_ep, latest=latest, best=best)
        if not latest:
            print("checkpoint saved: ({0}) {1}, epoch {2} (iteration {3})".format(opt.group, opt.name, ep, it))
        if best:
            print("Saving the current model as the best...")
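setup_optimizer above applies weight decay only to weight matrices and exempts biases and other 1-D parameters (e.g. LayerNorm scales and tokens). A self-contained sketch of that split on an arbitrary module (illustration only; the function name is made up):

import torch
import torch.nn as nn

def build_adamw(module, lr=3e-5, weight_decay=0.05):
    decay, no_decay = [], []
    for name, param in module.named_parameters():
        if not param.requires_grad:
            continue
        # biases and other 1-D params get no weight decay
        (no_decay if param.ndim <= 1 or name.endswith(".bias") else decay).append(param)
    groups = [
        {'params': no_decay, 'lr': lr, 'weight_decay': 0.},
        {'params': decay, 'lr': lr, 'weight_decay': weight_decay},
    ]
    return torch.optim.AdamW(groups, betas=(0.9, 0.95))

optim = build_adamw(nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8)))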
options/depth.yaml
ADDED
@@ -0,0 +1,72 @@
group: depth
name: depth_est
load:

batch_size: 44
debug: false
image_size: [224,224]
gpu: 0
max_epoch: 15
output_root: output
resume: false
seed: 0
yaml:

arch:
    depth:
        pretrained: model/depth/pretrained_weights/omnidata_dpt_depth_v2.ckpt

eval:
    batch_size: 44
    n_vis: 50
    depth_cap:
    d_thresholds: [1.02,1.05,1.1,1.2]

data:
    num_classes_test: 15
    max_img_cat:
    dataset_train: synthetic
    dataset_test: synthetic
    num_workers: 6
    bgcolor: 1
    pix3d:
        cat:
    ocrtoc:
        cat:
        erode_mask: 10
    synthetic:
        subset: objaverse_LVIS,ShapeNet55
        percentage: 1
    train_sub:
    val_sub:

training:
    n_sdf_points: 4096
    depth_loss:
        grad_reg: 0.1
        depth_inv: true
        mask_shrink: false

loss_weight:
    depth: 1
    intr: 10

optim:
    lr: 3.e-5
    weight_decay: 0.05
    clip_norm:
    amp: false
    accum: 1
    sched: false

tb:
    num_images: [4,8]

freq:
    print: 200
    print_eval: 100
    scalar: 1000 # iterations
    vis: 1000 # iterations
    save_vis: 1000
    ckpt_latest: 1000 # iterations
    eval: 1
options/shape.yaml
ADDED
@@ -0,0 +1,110 @@
group: shape
name: shape_recon
load:

batch_size: 28
debug: false
profile: false
image_size: [224,224]
gpu: 0
max_epoch: 15
output_root: output
resume: false
seed: 0
yaml:

pretrain:
    depth: weights/depth.ckpt

arch:
    # general
    num_heads: 8
    latent_dim: 256
    win_size: 16
    # depth
    depth:
        encoder: resnet
        n_blocks: 12
        dsp: 2
        pretrained: model/depth/pretrained_weights/omnidata_dpt_depth_v2.ckpt
    # rgb
    rgb:
        encoder:
        n_blocks: 12
    # implicit
    impl:
        n_channels: 256
        # attention-related
        att_blocks: 2
        mlp_ratio: 4.
        posenc_perlayer: false
        # mlp-related
        mlp_layers: 8
        posenc_3D: 0
        skip_in: [2,4,6]

eval:
    batch_size: 2
    brute_force: false
    n_vis: 50
    vox_res: 64
    num_points: 10000
    range: [-1.5,1.5]
    icp: false
    f_thresholds: [0.005, 0.01, 0.02, 0.05, 0.1, 0.2]

data:
    num_classes_test: 15
    max_img_cat:
    dataset_train: synthetic
    dataset_test: synthetic
    num_workers: 6
    bgcolor: 1
    pix3d:
        cat:
    ocrtoc:
        cat:
        erode_mask:
    synthetic:
        subset: objaverse_LVIS,ShapeNet55
        percentage: 1
    train_sub:
    val_sub:

training:
    n_sdf_points: 4096
    shape_loss:
        impt_weight: 1
        impt_thres: 0.01
    depth_loss:
        grad_reg: 0.1
        depth_inv: true
        mask_shrink: false

loss_weight:
    shape: 1
    depth:
    intr:

optim:
    lr: 3.e-5
    lr_ft: 1.e-5
    weight_decay: 0.05
    fix_dpt: false
    fix_clip: true
    clip_norm:
    amp: false
    accum: 1
    sched: false

tb:
    num_images: [4,8]

freq:
    print: 200
    print_eval: 100
    scalar: 1000 # iterations
    vis: 1000 # iterations
    save_vis: 1000
    ckpt_latest: 1000 # iterations
    eval: 1
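These YAML files are read by the option loader added in this commit (utils/options.py); its exact API is not shown here, so the snippet below is only a generic PyYAML illustration of how the nested keys resolve (the indentation above is reconstructed, so treat the paths as assumptions):

import yaml

with open('options/shape.yaml') as f:
    opt = yaml.safe_load(f)
print(opt['arch']['impl']['mlp_layers'])   # 8
print(opt['optim']['fix_dpt'])             # False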
requirements.txt
ADDED
@@ -0,0 +1,95 @@
absl-py==2.0.0
anyio==4.2.0
attrs==23.2.0
cachetools==5.3.2
certifi==2023.11.17
chardet==5.2.0
charset-normalizer==3.3.2
colorlog==6.8.0
contourpy==1.2.0
cycler==0.12.1
docopt==0.6.2
embreex==2.17.7.post4
exceptiongroup==1.2.0
filelock==3.13.1
fonttools==4.47.2
freetype-py==2.4.0
fsspec==2023.12.2
google-auth==2.26.2
google-auth-oauthlib==1.2.0
grpcio==1.60.0
h11==0.14.0
httpcore==1.0.2
httpx==0.26.0
huggingface-hub==0.20.2
idna==3.6
imageio==2.33.1
Jinja2==3.1.3
jsonschema==4.20.0
jsonschema-specifications==2023.12.1
kiwisolver==1.4.5
lxml==5.1.0
mapbox-earcut==1.0.1
Markdown==3.5.2
MarkupSafe==2.1.3
matplotlib==3.8.2
mpmath==1.3.0
networkx==3.2.1
ninja==1.11.1.1
numpy==1.26.3
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.18.1
nvidia-nvjitlink-cu12==12.3.101
nvidia-nvtx-cu12==12.1.105
oauthlib==3.2.2
opencv-python==4.9.0.80
packaging==23.2
pillow==10.2.0
pipreqs==0.4.13
protobuf==4.23.4
pyasn1==0.5.1
pyasn1-modules==0.3.0
pycollada==0.8
pyglet==2.0.10
PyMCubes==0.1.4
PyOpenGL==3.1.0
pyparsing==3.1.1
pyrender==0.1.45
python-dateutil==2.8.2
PyYAML==6.0.1
referencing==0.32.1
requests==2.31.0
requests-oauthlib==1.3.1
rpds-py==0.16.2
rsa==4.9
Rtree==1.1.0
safetensors==0.4.1
scipy==1.11.4
shapely==2.0.2
six==1.16.0
sniffio==1.3.0
svg.path==6.3
sympy==1.12
tensorboard==2.15.1
tensorboard-data-server==0.7.2
timm==0.9.12
torch==2.1.2
torchvision==0.16.2
tqdm==4.66.1
trimesh==4.0.9
triton==2.1.0
typing_extensions==4.9.0
urllib3==2.1.0
vhacdx==0.0.5
Werkzeug==3.0.1
xxhash==3.4.1
yarg==0.1.9
rembg
utils/camera.py
ADDED
@@ -0,0 +1,230 @@
# partially from https://github.com/chenhsuanlin/signed-distance-SRN

import numpy as np
import torch

class Pose():
    # a pose class with util methods
    def __call__(self, R=None, t=None):
        assert(R is not None or t is not None)
        if R is None:
            if not isinstance(t, torch.Tensor): t = torch.tensor(t)
            R = torch.eye(3, device=t.device).repeat(*t.shape[:-1], 1, 1)
        elif t is None:
            if not isinstance(R, torch.Tensor): R = torch.tensor(R)
            t = torch.zeros(R.shape[:-1], device=R.device)
        else:
            if not isinstance(R, torch.Tensor): R = torch.tensor(R)
            if not isinstance(t, torch.Tensor): t = torch.tensor(t)
        assert(R.shape[:-1]==t.shape and R.shape[-2:]==(3, 3))
        R = R.float()
        t = t.float()
        pose = torch.cat([R, t[..., None]], dim=-1) # [..., 3, 4]
        assert(pose.shape[-2:]==(3, 4))
        return pose

    def invert(self, pose, use_inverse=False):
        R, t = pose[..., :3], pose[..., 3:]
        R_inv = R.inverse() if use_inverse else R.transpose(-1, -2)
        t_inv = (-R_inv@t)[..., 0]
        pose_inv = self(R=R_inv, t=t_inv)
        return pose_inv

    def compose(self, pose_list):
        # pose_new(x) = poseN(...(pose2(pose1(x)))...)
        pose_new = pose_list[0]
        for pose in pose_list[1:]:
            pose_new = self.compose_pair(pose_new, pose)
        return pose_new

    def compose_pair(self, pose_a, pose_b):
        # pose_new(x) = pose_b(pose_a(x))
        R_a, t_a = pose_a[..., :3], pose_a[..., 3:]
        R_b, t_b = pose_b[..., :3], pose_b[..., 3:]
        R_new = R_b@R_a
        t_new = (R_b@t_a+t_b)[..., 0]
        pose_new = self(R=R_new, t=t_new)
        return pose_new

pose = Pose()

# unit sphere normalization
def valid_norm_fac(seen_points, mask):
    '''
    seen_points: [B, H*W, 3]
    mask: [B, 1, H, W], boolean
    '''
    # get valid points
    batch_size = seen_points.shape[0]
    # [B, H*W]
    mask = mask.view(batch_size, seen_points.shape[1])

    # get mean and variance by sample
    means, max_dists = [], []
    for b in range(batch_size):
        # [N_valid, 3]
        seen_points_valid = seen_points[b][mask[b]]
        # [3]
        xyz_mean = torch.mean(seen_points_valid, dim=0)
        seen_points_valid_zmean = seen_points_valid - xyz_mean
        # scalar
        max_dist = torch.max(seen_points_valid_zmean.norm(dim=1))
        means.append(xyz_mean)
        max_dists.append(max_dist)
    # [B, 3]
    means = torch.stack(means, dim=0)
    # [B]
    max_dists = torch.stack(max_dists, dim=0)
    return means, max_dists

def get_pixel_grid(opt, H, W):
    y_range = torch.arange(H, dtype=torch.float32).to(opt.device)
    x_range = torch.arange(W, dtype=torch.float32).to(opt.device)
    Y, X = torch.meshgrid(y_range, x_range, indexing='ij')
    Z = torch.ones_like(Y)
    xyz_grid = torch.stack([X, Y, Z], dim=-1).view(-1, 3)
    return xyz_grid

def unproj_depth(opt, depth, intr):
    '''
    depth: [B, 1, H, W]
    intr: [B, 3, 3]
    '''
    batch_size, _, H, W = depth.shape
    assert opt.H == H == W
    depth = depth.squeeze(1)

    # [B, 3, 3]
    K_inv = torch.linalg.inv(intr).float()
    # [1, H*W, 3]
    pixel_grid = get_pixel_grid(opt, H, W).unsqueeze(0)
    # [B, H*W, 3]
    pixel_grid = pixel_grid.repeat(batch_size, 1, 1)
    # [B, 3, H*W]
    ray_dirs = K_inv @ pixel_grid.permute(0, 2, 1).contiguous()
    # [B, H*W, 3], in camera coordinates
    seen_points = ray_dirs.permute(0, 2, 1).contiguous() * depth.view(batch_size, H*W, 1)

    return seen_points

def to_hom(X):
    '''
    X: [B, N, 3]
    Returns:
        X_hom: [B, N, 4]
    '''
    X_hom = torch.cat([X, torch.ones_like(X[..., :1])], dim=-1)
    return X_hom

def world2cam(X_world, pose):
    '''
    X_world: [B, N, 3]
    pose: [B, 3, 4]
    Returns:
        X_cam: [B, N, 3]
    '''
    X_hom = to_hom(X_world)
    X_cam = X_hom @ pose.transpose(-1, -2)
    return X_cam

def cam2img(X_cam, cam_intr):
    '''
    X_cam: [B, N, 3]
    cam_intr: [B, 3, 3]
    Returns:
        X_img: [B, N, 3]
    '''
    X_img = X_cam @ cam_intr.transpose(-1, -2)
    return X_img

def proj_points(opt, points, intr, pose):
    '''
    points: [B, N, 3]
    intr: [B, 3, 3]
    pose: [B, 3, 4]
    '''
    # [B, N, 3]
    points_cam = world2cam(points, pose)
    # [B, N]
    depth = points_cam[..., 2]
    # [B, N, 3]
    points_img = cam2img(points_cam, intr)
    # [B, N, 2]
    points_2D = points_img[..., :2] / points_img[..., 2:]
    return points_2D, depth

def azim_to_rotation_matrix(azim, representation='angle'):
    """Azim is angle with vector +X, rotated in XZ plane"""
    if representation == 'rad':
        # [B, ]
        cos, sin = torch.cos(azim), torch.sin(azim)
    elif representation == 'angle':
        # [B, ]
        azim = azim * np.pi / 180
        cos, sin = torch.cos(azim), torch.sin(azim)
    elif representation == 'trig':
        # [B, 2]
        cos, sin = azim[:, 0], azim[:, 1]
    R = torch.eye(3, device=azim.device)[None].repeat(len(azim), 1, 1)
    zeros = torch.zeros(len(azim), device=azim.device)
    R[:, 0, :] = torch.stack([cos, zeros, sin], dim=-1)
    R[:, 2, :] = torch.stack([-sin, zeros, cos], dim=-1)
    return R

def elev_to_rotation_matrix(elev, representation='angle'):
    """Angle with vector +Z in YZ plane"""
    if representation == 'rad':
        # [B, ]
        cos, sin = torch.cos(elev), torch.sin(elev)
    elif representation == 'angle':
        # [B, ]
        elev = elev * np.pi / 180
        cos, sin = torch.cos(elev), torch.sin(elev)
    elif representation == 'trig':
        # [B, 2]
        cos, sin = elev[:, 0], elev[:, 1]
    R = torch.eye(3, device=elev.device)[None].repeat(len(elev), 1, 1)
    R[:, 1, 1:] = torch.stack([cos, -sin], dim=-1)
    R[:, 2, 1:] = torch.stack([sin, cos], dim=-1)
    return R

def roll_to_rotation_matrix(roll, representation='angle'):
    """Angle with vector +X in XY plane"""
    if representation == 'rad':
        # [B, ]
        cos, sin = torch.cos(roll), torch.sin(roll)
    elif representation == 'angle':
        # [B, ]
        roll = roll * np.pi / 180
        cos, sin = torch.cos(roll), torch.sin(roll)
    elif representation == 'trig':
        # [B, 2]
        cos, sin = roll[:, 0], roll[:, 1]
    R = torch.eye(3, device=roll.device)[None].repeat(len(roll), 1, 1)
    R[:, 0, :2] = torch.stack([cos, sin], dim=-1)
    R[:, 1, :2] = torch.stack([-sin, cos], dim=-1)
    return R

def get_rotation_sphere(azim_sample=4, elev_sample=4, roll_sample=4, scales=[1.0], device='cuda'):
    rotations = []
    azim_range = [0, 360]
    elev_range = [0, 360]
    roll_range = [0, 360]
    azims = np.linspace(azim_range[0], azim_range[1], num=azim_sample, endpoint=False)
    elevs = np.linspace(elev_range[0], elev_range[1], num=elev_sample, endpoint=False)
    rolls = np.linspace(roll_range[0], roll_range[1], num=roll_sample, endpoint=False)
    for scale in scales:
        for azim in azims:
            for elev in elevs:
                for roll in rolls:
                    Ry = azim_to_rotation_matrix(torch.tensor([azim]))
                    Rx = elev_to_rotation_matrix(torch.tensor([elev]))
                    Rz = roll_to_rotation_matrix(torch.tensor([roll]))
                    R_permute = torch.tensor([
                        [-1, 0, 0],
                        [0, 0, -1],
                        [0, -1, 0]
                    ]).float().to(Ry.device).unsqueeze(0).expand_as(Ry)
                    R = scale * Rz@Rx@Ry@R_permute
                    rotations.append(R.to(device).float())
    return torch.cat(rotations, dim=0)

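A minimal, self-contained usage sketch of the pose helpers above (dummy tensors, CPU only); it assumes the repo root is on PYTHONPATH so `utils.camera` is importable, and the specific numbers are arbitrary.

```python
# Minimal usage sketch for the pose utilities in utils/camera.py (dummy data).
import torch
from utils.camera import pose, proj_points, azim_to_rotation_matrix

B, N = 2, 100
R = azim_to_rotation_matrix(torch.full((B,), 30.0))   # [B, 3, 3], 30 deg azimuth
t = torch.tensor([[0., 0., 2.]]).repeat(B, 1)         # [B, 3]
P = pose(R=R, t=t)                                    # [B, 3, 4] extrinsics
P_round_trip = pose.compose([P, pose.invert(P)])      # numerically close to [I | 0]

points = torch.rand(B, N, 3)                          # world-space points
intr = torch.eye(3).repeat(B, 1, 1)                   # dummy pinhole intrinsics
uv, depth = proj_points(None, points, intr, P)        # [B, N, 2] pixels, [B, N] depths
print(P_round_trip[0], uv.shape, depth.shape)
```
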
utils/eval_3D.py
ADDED
@@ -0,0 +1,133 @@
import numpy as np
import torch
import threading
import mcubes
import trimesh
from utils.util_vis import show_att_on_image
from utils.camera import get_rotation_sphere

@torch.no_grad()
def get_dense_3D_grid(opt, var, N=None):
    batch_size = len(var.idx)
    N = N or opt.eval.vox_res
    # -0.6, 0.6
    range_min, range_max = opt.eval.range
    grid = torch.linspace(range_min, range_max, N+1, device=opt.device)
    points_3D = torch.stack(torch.meshgrid(grid, grid, grid, indexing='ij'), dim=-1) # [N, N, N, 3]
    # actually N+1 instead of N
    points_3D = points_3D.repeat(batch_size, 1, 1, 1, 1) # [B, N, N, N, 3]
    return points_3D

@torch.no_grad()
def compute_level_grid(opt, impl_network, latent_depth, latent_semantic, points_3D, images, vis_attn=False):
    # needed for amp
    latent_depth = latent_depth.to(torch.float32) if latent_depth is not None else None
    latent_semantic = latent_semantic.to(torch.float32) if latent_semantic is not None else None

    # process points in sliced way
    batch_size = points_3D.shape[0]
    N = points_3D.shape[1]
    assert N == points_3D.shape[2] == points_3D.shape[3]
    assert points_3D.shape[4] == 3

    points_3D = points_3D.view(batch_size, N, N*N, 3)
    occ = []
    attn = []
    for i in range(N):
        # [B, N*N, 3]
        points_slice = points_3D[:, i]
        # [B, N*N, 3] -> [B, N*N], [B, N*N, 1+feat_res**2]
        occ_slice, attn_slice = impl_network(latent_depth, latent_semantic, points_slice)
        occ.append(occ_slice)
        attn.append(attn_slice.detach())
    # [B, N, N*N] -> [B, N, N, N]
    occ = torch.stack(occ, dim=1).view(batch_size, N, N, N)
    occ = torch.sigmoid(occ)
    if vis_attn:
        N_global = 1
        feat_res = opt.H // opt.arch.win_size
        attn = torch.stack(attn, dim=1).view(batch_size, N, N, N, N_global+feat_res**2)
        # average along Z, [B, N, N, N_global+feat_res**2]
        attn = torch.mean(attn, dim=3)
        # [B, N, N, N_global] -> [B, N, N, 1]
        attn_global = attn[:, :, :, :N_global].sum(dim=-1, keepdim=True)
        # [B, N, N, feat_res, feat_res]
        attn_local = attn[:, :, :, N_global:].view(batch_size, N, N, feat_res, feat_res)
        # [B, N, N, feat_res, feat_res]
        attn_vis = attn_global.unsqueeze(-1) + attn_local
        # list of frame lists
        images_vis = []
        for b in range(batch_size):
            images_vis_sample = []
            for row in range(0, N, 8):
                if row % 16 == 0:
                    col_range = range(0, N//8*8+1, 8)
                else:
                    col_range = range(N//8*8, -1, -8)
                for col in col_range:
                    # [feat_res, feat_res], x is col
                    attn_curr = attn_vis[b, col, row]
                    attn_curr = torch.nn.functional.interpolate(
                        attn_curr.unsqueeze(0).unsqueeze(0), size=(opt.H, opt.W),
                        mode='bilinear', align_corners=False
                    ).squeeze(0).squeeze(0).cpu().numpy()
                    attn_curr /= attn_curr.max()
                    # [feat_res, feat_res, 3]
                    image_curr = images[b].permute(1, 2, 0).cpu().numpy()
                    # merge the image and the attention
                    images_vis_sample.append(show_att_on_image(image_curr, attn_curr))
            images_vis.append(images_vis_sample)
    return occ, images_vis if vis_attn else None

@torch.no_grad()
def standardize_pc(pc):
    assert len(pc.shape) == 3
    pc_mean = pc.mean(dim=1, keepdim=True)
    pc_zmean = pc - pc_mean
    origin_distance = (pc_zmean**2).sum(dim=2, keepdim=True).sqrt()
    scale = torch.sqrt(torch.sum(origin_distance**2, dim=1, keepdim=True) / pc.shape[1])
    pc_standardized = pc_zmean / (scale * 2)
    return pc_standardized

@torch.no_grad()
def normalize_pc(pc):
    assert len(pc.shape) == 3
    pc_mean = pc.mean(dim=1, keepdim=True)
    pc_zmean = pc - pc_mean
    length_x = pc_zmean[:, :, 0].max(dim=-1)[0] - pc_zmean[:, :, 0].min(dim=-1)[0]
    length_y = pc_zmean[:, :, 1].max(dim=-1)[0] - pc_zmean[:, :, 1].min(dim=-1)[0]
    length_max = torch.stack([length_x, length_y], dim=-1).max(dim=-1)[0].unsqueeze(-1).unsqueeze(-1)
    pc_normalized = pc_zmean / (length_max + 1.e-7)
    return pc_normalized

def convert_to_explicit(opt, level_grids, isoval=0., to_pointcloud=False):
    N = len(level_grids)
    meshes = [None]*N
    pointclouds = [None]*N if to_pointcloud else None
    threads = [threading.Thread(target=convert_to_explicit_worker,
                                args=(opt, i, level_grids[i], isoval, meshes),
                                kwargs=dict(pointclouds=pointclouds),
                                daemon=False) for i in range(N)]
    for t in threads: t.start()
    for t in threads: t.join()
    if to_pointcloud:
        pointclouds = np.stack(pointclouds, axis=0)
        return meshes, pointclouds
    else: return meshes

def convert_to_explicit_worker(opt, i, level_vox_i, isoval, meshes, pointclouds=None):
    # use marching cubes to convert implicit surface to mesh
    vertices, faces = mcubes.marching_cubes(level_vox_i, isovalue=isoval)
    assert(level_vox_i.shape[0]==level_vox_i.shape[1]==level_vox_i.shape[2])
    S = level_vox_i.shape[0]
    range_min, range_max = opt.eval.range
    # marching cubes treat every cube as unit length
    vertices = vertices/S*(range_max-range_min)+range_min
    mesh = trimesh.Trimesh(vertices, faces)
    meshes[i] = mesh
    if pointclouds is not None:
        # randomly sample on mesh to get uniform dense point cloud
        if len(mesh.triangles)!=0:
            points = mesh.sample(opt.eval.num_points)
        else: points = np.zeros([opt.eval.num_points, 3])
        pointclouds[i] = points

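The marching-cubes step above can be exercised in isolation. Below is a self-contained sketch that mirrors what `convert_to_explicit_worker` does, but on a synthetic occupancy grid of a sphere; the cube extent and sample count stand in for `opt.eval.range` and `opt.eval.num_points` and are arbitrary here.

```python
# Self-contained sketch of the marching-cubes conversion, on a synthetic sphere.
import numpy as np
import mcubes
import trimesh

S = 64
range_min, range_max = -0.6, 0.6                      # stand-in for opt.eval.range
xs = np.linspace(range_min, range_max, S)
X, Y, Z = np.meshgrid(xs, xs, xs, indexing='ij')
occ = (np.sqrt(X**2 + Y**2 + Z**2) < 0.4).astype(np.float32)   # occupancy of a sphere

vertices, faces = mcubes.marching_cubes(occ, 0.5)     # isovalue 0.5 for a {0,1} grid
# marching cubes treats every voxel as unit length; rescale back to world coordinates
vertices = vertices / S * (range_max - range_min) + range_min
mesh = trimesh.Trimesh(vertices, faces)
points = mesh.sample(10000)                           # dense surface point cloud
print(mesh.vertices.shape, points.shape)
```
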
utils/eval_depth.py
ADDED
@@ -0,0 +1,110 @@
# based on https://gist.github.com/ranftlr/45f4c7ddeb1bbb88d606bc600cab6c8d

import torch

class DepthMetric:
    def __init__(self, thresholds=[1.25, 1.25**2, 1.25**3], depth_cap=None, prediction_type='depth'):
        self.thresholds = thresholds
        self.depth_cap = depth_cap
        self.metric_keys = self.get_metric_keys()
        self.prediction_type = prediction_type

    def compute_scale_and_shift(self, prediction, target, mask):
        # system matrix: A = [[a_00, a_01], [a_10, a_11]]
        a_00 = torch.sum(mask * prediction * prediction, (1, 2))
        a_01 = torch.sum(mask * prediction, (1, 2))
        a_11 = torch.sum(mask, (1, 2))

        # right hand side: b = [b_0, b_1]
        b_0 = torch.sum(mask * prediction * target, (1, 2))
        b_1 = torch.sum(mask * target, (1, 2))

        # solution: x = A^-1 . b = [[a_11, -a_01], [-a_10, a_00]] / (a_00 * a_11 - a_01 * a_10) . b
        x_0 = torch.zeros_like(b_0)
        x_1 = torch.zeros_like(b_1)

        det = a_00 * a_11 - a_01 * a_01
        # A needs to be a positive definite matrix.
        valid = det > 0

        x_0[valid] = (a_11[valid] * b_0[valid] - a_01[valid] * b_1[valid]) / det[valid]
        x_1[valid] = (-a_01[valid] * b_0[valid] + a_00[valid] * b_1[valid]) / det[valid]

        return x_0, x_1

    def get_metric_keys(self):
        metric_keys = []
        for threshold in self.thresholds:
            metric_keys.append('d>{}'.format(threshold))
        metric_keys.append('rmse')
        metric_keys.append('l1_err')
        metric_keys.append('abs_rel')
        return metric_keys

    def compute_metrics(self, prediction, target, mask):
        # check inputs
        prediction = prediction.float()
        target = target.float()
        mask = mask.float()
        assert prediction.shape == target.shape == mask.shape
        assert len(prediction.shape) == 4
        assert prediction.shape[1] == 1
        assert prediction.dtype == target.dtype == mask.dtype == torch.float32

        # process inputs
        prediction = prediction.squeeze(1)
        target = target.squeeze(1)
        mask = (mask.squeeze(1) > 0.5).long()

        # output dict
        metrics = {}

        # get the predicted disparity
        prediction_disparity = torch.zeros_like(prediction)
        if self.prediction_type == 'depth':
            prediction_disparity[mask == 1] = 1.0 / (prediction[mask == 1] + 1.e-6)
        elif self.prediction_type == 'disparity':
            prediction_disparity[mask == 1] = prediction[mask == 1]
        else:
            raise ValueError('Unknown prediction type: {}'.format(self.prediction_type))

        # transform predicted disparity to align with depth
        target_disparity = torch.zeros_like(target)
        target_disparity[mask == 1] = 1.0 / target[mask == 1]
        scale, shift = self.compute_scale_and_shift(prediction_disparity, target_disparity, mask)
        prediction_aligned = scale.view(-1, 1, 1) * prediction_disparity + shift.view(-1, 1, 1)

        if self.depth_cap is not None:
            disparity_cap = 1.0 / self.depth_cap
            prediction_aligned[prediction_aligned < disparity_cap] = disparity_cap

        prediciton_depth = 1.0 / prediction_aligned

        # delta > threshold, [batch_size, ]
        for threshold in self.thresholds:
            err = torch.zeros_like(prediciton_depth, dtype=torch.float)
            err[mask == 1] = torch.max(
                prediciton_depth[mask == 1] / target[mask == 1],
                target[mask == 1] / prediciton_depth[mask == 1],
            )
            err[mask == 1] = (err[mask == 1] > threshold).float()
            metrics['d>{}'.format(threshold)] = torch.sum(err, (1, 2)) / torch.sum(mask, (1, 2))

        # rmse, [batch_size, ]
        rmse = torch.zeros_like(prediciton_depth, dtype=torch.float)
        rmse[mask == 1] = (prediciton_depth[mask == 1] - target[mask == 1]) ** 2
        rmse = torch.sum(rmse, (1, 2)) / torch.sum(mask, (1, 2))
        metrics['rmse'] = torch.sqrt(rmse)

        # l1 error, [batch_size, ]
        l1_err = torch.zeros_like(prediciton_depth, dtype=torch.float)
        l1_err[mask == 1] = torch.abs(prediciton_depth[mask == 1] - target[mask == 1])
        metrics['l1_err'] = torch.sum(l1_err, (1, 2)) / torch.sum(mask, (1, 2))

        # abs_rel, [batch_size, ]
        abs_rel = torch.zeros_like(prediciton_depth, dtype=torch.float)
        abs_rel[mask == 1] = torch.abs(prediciton_depth[mask == 1] - target[mask == 1]) / target[mask == 1]
        metrics['abs_rel'] = torch.sum(abs_rel, (1, 2)) / torch.sum(mask, (1, 2))

        return metrics, prediciton_depth.unsqueeze(1)

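The core of the evaluation is the closed-form least-squares alignment in `compute_scale_and_shift`. A quick sanity-check sketch: construct a prediction that differs from the target disparity only by a known affine transform and verify that the solver recovers it (the numbers below are arbitrary).

```python
# Sanity check: recover a known scale and shift from noiseless synthetic disparities.
import torch
from utils.eval_depth import DepthMetric

torch.manual_seed(0)
B, H, W = 2, 8, 8
target = torch.rand(B, H, W) + 0.5            # "ground-truth" disparity
scale_gt, shift_gt = 2.0, 0.3
prediction = (target - shift_gt) / scale_gt   # prediction off by an affine transform
mask = torch.ones(B, H, W)

metric = DepthMetric()
scale, shift = metric.compute_scale_and_shift(prediction, target, mask)
print(scale, shift)   # expected to be close to (2.0, 0.3) for every batch element
```
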
utils/layers.py
ADDED
@@ -0,0 +1,147 @@
import torch
import torch.nn as nn

from functools import partial
from timm.models.vision_transformer import Block

# 3D positional encoding, from https://github.com/bmild/nerf.
class Embedder:
    def __init__(self, **kwargs):
        self.kwargs = kwargs
        self.create_embedding_fn()

    def create_embedding_fn(self):
        embed_fns = []
        d = self.kwargs['input_dims']
        out_dim = 0
        if self.kwargs['include_input']:
            embed_fns.append(lambda x: x)
            out_dim += d

        max_freq = self.kwargs['max_freq_log2']
        N_freqs = self.kwargs['num_freqs']

        if self.kwargs['log_sampling']:
            freq_bands = 2. ** torch.linspace(0., max_freq, N_freqs)
        else:
            freq_bands = torch.linspace(2.**0., 2.**max_freq, N_freqs)

        for freq in freq_bands:
            for p_fn in self.kwargs['periodic_fns']:
                embed_fns.append(lambda x, p_fn=p_fn,
                                 freq=freq: p_fn(x * freq))
                out_dim += d

        self.embed_fns = embed_fns
        self.out_dim = out_dim

    def embed(self, inputs):
        return torch.cat([fn(inputs) for fn in self.embed_fns], -1)

def get_embedder(posenc_res, input_dims=3):
    embed_kwargs = {
        'include_input': True,
        'input_dims': input_dims,
        'max_freq_log2': posenc_res-1,
        'num_freqs': posenc_res,
        'log_sampling': True,
        'periodic_fns': [torch.sin, torch.cos],
    }

    embedder_obj = Embedder(**embed_kwargs)
    def embed(x, eo=embedder_obj): return eo.embed(x)
    return embed, embedder_obj.out_dim

class LayerScale(nn.Module):
    def __init__(self, dim, init_values=1e-5, inplace=False):
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x):
        return x.mul_(self.gamma) if self.inplace else x * self.gamma

class Bottleneck_Linear(nn.Module):
    def __init__(self, n_channels):
        super().__init__()
        self.linear1 = nn.Linear(n_channels, n_channels)
        self.norm = nn.LayerNorm(n_channels)
        self.linear2 = nn.Linear(n_channels, n_channels)
        self.gelu = nn.GELU()

    def forward(self, x):
        x = x + self.linear2(self.gelu(self.linear1(self.norm(x))))
        return x

class Bottleneck_Conv(nn.Module):
    def __init__(self, n_channels, kernel_size=1):
        super().__init__()
        self.linear1 = nn.Conv2d(n_channels, n_channels, kernel_size=kernel_size, padding=kernel_size//2, bias=False)
        self.bn1 = nn.BatchNorm2d(n_channels)
        self.linear2 = nn.Conv2d(n_channels, n_channels, kernel_size=kernel_size, padding=kernel_size//2, bias=False)
        self.bn2 = nn.BatchNorm2d(n_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        assert len(x.shape) in [2, 4]
        input_dims = len(x.shape)
        if input_dims == 2:
            x = x.unsqueeze(-1).unsqueeze(-1)
        residual = x
        out = self.linear1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.linear2(out)
        out = self.bn2(out)
        out += residual
        out = self.relu(out)
        if input_dims == 2:
            out = out.squeeze(-1).squeeze(-1)
        return out

class CLIPFusionBlock_Concat(nn.Module):
    """
    Fuse clip and rgb embeddings via concat-proj
    """
    def __init__(self, n_channels=512, n_layers=1, act=True):
        super().__init__()
        proj = [Bottleneck_Linear(2 * n_channels) for _ in range(n_layers)]
        proj.append(nn.Linear(2 * n_channels, n_channels))
        if act: proj.append(nn.GELU())
        self.proj = nn.Sequential(*proj)

    def forward(self, sem_latent, clip_latent):
        """
        sem_latent: [B, N, C]
        clip_latent: [B, C]
        """
        # [B, N, 2C]
        latent_concat = torch.cat([sem_latent, clip_latent.unsqueeze(1).expand_as(sem_latent)], dim=-1)
        # [B, N, C]
        latent = self.proj(latent_concat)
        return latent

class CLIPFusionBlock_Attn(nn.Module):
    """
    Fuse geometric and semantic embeddings via multi-layer MHA blocks
    """
    def __init__(self, n_channels=512, n_layers=1, act=True):
        super().__init__()
        self.attn_blocks = nn.ModuleList(
            [Block(
                n_channels, 8, 4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), drop_path=0.1
            ) for _ in range(n_layers)]
        )
        if act: self.attn_blocks.append(nn.GELU())

    def forward(self, sem_latent, clip_latent):
        """
        sem_latent: [B, N, C]
        clip_latent: [B, C]
        """
        # [B, 1+N, C], clip first
        latent = torch.cat([clip_latent.unsqueeze(1), sem_latent], dim=1)
        for attn_block in self.attn_blocks:
            latent = attn_block(latent)
        # [B, N, C]
        return latent[:, 1:, :]

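A quick sketch of the NeRF-style positional encoding helper above, checking the output dimensionality (with `include_input` plus sin/cos at `posenc_res` frequencies, the channel count is `input_dims * (1 + 2 * posenc_res)`):

```python
# Shape check for get_embedder (assumes the repo root is on PYTHONPATH).
import torch
from utils.layers import get_embedder

embed_fn, out_dim = get_embedder(posenc_res=6, input_dims=3)
print(out_dim)                        # 3 * (1 + 2*6) = 39
xyz = torch.rand(4, 1024, 3)
print(embed_fn(xyz).shape)            # torch.Size([4, 1024, 39])
```
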
utils/loss.py
ADDED
@@ -0,0 +1,42 @@
import torch
import torch.nn as nn
import torch.nn.functional as torch_F

from copy import deepcopy
from model.depth.midas_loss import MidasLoss

class Loss(nn.Module):

    def __init__(self, opt):
        super().__init__()
        self.opt = deepcopy(opt)
        self.occ_loss = nn.BCEWithLogitsLoss(reduction='none')
        self.midas_loss = MidasLoss(alpha=opt.training.depth_loss.grad_reg,
                                    inverse_depth=opt.training.depth_loss.depth_inv,
                                    shrink_mask=opt.training.depth_loss.mask_shrink)

    def shape_loss(self, pred_occ_raw, gt_sdf):
        assert len(pred_occ_raw.shape) == 2
        assert len(gt_sdf.shape) == 2
        # [B, N]
        gt_occ = (gt_sdf < 0).float()
        loss = self.occ_loss(pred_occ_raw, gt_occ)
        weight_mask = torch.ones_like(loss)
        thres = self.opt.training.shape_loss.impt_thres
        weight_mask[torch.abs(gt_sdf) < thres] = weight_mask[torch.abs(gt_sdf) < thres] * self.opt.training.shape_loss.impt_weight
        loss = loss * weight_mask
        return loss.mean()

    def depth_loss(self, pred_depth, gt_depth, mask):
        assert len(pred_depth.shape) == len(gt_depth.shape) == len(mask.shape) == 4
        assert pred_depth.shape[1] == gt_depth.shape[1] == mask.shape[1] == 1
        loss = self.midas_loss(pred_depth, gt_depth, mask)
        return loss

    def intr_loss(self, seen_pred, seen_gt, mask):
        assert len(seen_pred.shape) == len(seen_gt.shape) == 3
        assert len(mask.shape) == 2
        # [B, HW]
        distance = torch.sum((seen_pred - seen_gt)**2, dim=-1)
        loss = (distance * mask).sum() / (mask.sum() + 1.e-8)
        return loss

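The shape term treats the SDF samples as binary occupancy labels and upweights queries near the surface. A standalone illustration of the same weighting scheme (with made-up tensors and the `impt_thres` / `impt_weight` values standing in for `opt.training.shape_loss`, so it does not need `MidasLoss` or a full config):

```python
# Standalone illustration of the near-surface weighting used in Loss.shape_loss.
import torch
import torch.nn as nn

occ_loss = nn.BCEWithLogitsLoss(reduction='none')
pred_occ_raw = torch.randn(2, 4096)               # raw occupancy logits, [B, N]
gt_sdf = torch.randn(2, 4096) * 0.1               # signed distances of the query points

gt_occ = (gt_sdf < 0).float()                     # inside-surface label
loss = occ_loss(pred_occ_raw, gt_occ)
weight_mask = torch.ones_like(loss)
impt_thres, impt_weight = 0.01, 1.0               # cf. training.shape_loss in shape.yaml
weight_mask[gt_sdf.abs() < impt_thres] *= impt_weight   # upweight near-surface samples
print((loss * weight_mask).mean())
```

With `impt_weight: 1` as in the config above, the weighting is a no-op; raising it emphasizes samples within `impt_thres` of the surface.
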
utils/options.py
ADDED
@@ -0,0 +1,127 @@
import numpy as np
import os, sys, time
import torch
import random
import string
import yaml
import utils.util as util
import time

from utils.util import EasyDict as edict

# torch.backends.cudnn.enabled = False
# torch.backends.cudnn.benchmark = False
# torch.backends.cudnn.deterministic = True

def parse_arguments(args):
    # parse from command line (syntax: --key1.key2.key3=value)
    opt_cmd = {}
    for arg in args:
        assert(arg.startswith("--"))
        if "=" not in arg[2:]: # --key means key=True, --key! means key=False
            key_str, value = (arg[2:-1], "false") if arg[-1]=="!" else (arg[2:], "true")
        else:
            key_str, value = arg[2:].split("=")
        keys_sub = key_str.split(".")
        opt_sub = opt_cmd
        for k in keys_sub[:-1]:
            if k not in opt_sub: opt_sub[k] = {}
            opt_sub = opt_sub[k]
        # if opt_cmd['key1']['key2']['key3'] already exist for key1.key2.key3, print key3 as error msg
        assert keys_sub[-1] not in opt_sub, keys_sub[-1]
        opt_sub[keys_sub[-1]] = yaml.safe_load(value)
    opt_cmd = edict(opt_cmd)
    return opt_cmd

def set(opt_cmd={}, verbose=True, safe_check=True):
    print("setting configurations...")
    fname = opt_cmd.yaml # load from yaml file
    opt_base = load_options(fname)
    # override with command line arguments
    opt = override_options(opt_base, opt_cmd, key_stack=[], safe_check=safe_check)
    process_options(opt)
    if verbose:
        def print_options(opt, level=0):
            for key, value in sorted(opt.items()):
                if isinstance(value, (dict, edict)):
                    print(" "*level+"* "+key+":")
                    print_options(value, level+1)
                else:
                    print(" "*level+"* "+key+":", value)
        print_options(opt)
    return opt

def load_options(fname):
    with open(fname) as file:
        opt = edict(yaml.safe_load(file))
    if "_parent_" in opt:
        # load parent yaml file(s) as base options
        parent_fnames = opt.pop("_parent_")
        if type(parent_fnames) is str:
            parent_fnames = [parent_fnames]
        for parent_fname in parent_fnames:
            opt_parent = load_options(parent_fname)
            opt_parent = override_options(opt_parent, opt, key_stack=[])
            opt = opt_parent
    print("loading {}...".format(fname))
    return opt

def override_options(opt, opt_over, key_stack=None, safe_check=False):
    for key, value in opt_over.items():
        if isinstance(value, dict):
            # parse child options (until leaf nodes are reached)
            opt[key] = override_options(opt.get(key, dict()), value, key_stack=key_stack+[key], safe_check=safe_check)
        else:
            # ensure command line argument to override is also in yaml file
            if safe_check and key not in opt:
                add_new = None
                while add_new not in ["y", "n"]:
                    key_str = ".".join(key_stack+[key])
                    add_new = input("\"{}\" not found in original opt, add? (y/n) ".format(key_str))
                if add_new=="n":
                    print("safe exiting...")
                    exit()
            opt[key] = value
    return opt

def process_options(opt):
    # set seed
    if opt.seed is not None:
        random.seed(opt.seed)
        np.random.seed(opt.seed)
        torch.manual_seed(opt.seed)
        torch.cuda.manual_seed_all(opt.seed)
    else:
        # create random string as run ID
        randkey = "".join(random.choice(string.ascii_uppercase) for _ in range(4))
        opt.name += "_{}".format(randkey)
    # other default options
    opt.output_path = "{0}/{1}/{2}".format(opt.output_root, opt.group, opt.name)
    os.makedirs(opt.output_path, exist_ok=True)
    opt.H, opt.W = opt.image_size
    if opt.freq.eval is None:
        opt.freq.eval = max(opt.max_epoch // 20, 1)
    if 'loss_weight' in opt:
        opt.get_depth = False
        opt.get_normal = False

def save_options_file(opt):
    opt_fname = "{}/options.yaml".format(opt.output_path)
    if os.path.isfile(opt_fname):
        with open(opt_fname) as file:
            opt_old = yaml.safe_load(file)
        if opt!=opt_old:
            # prompt if options are not identical
            opt_new_fname = "{}/options_temp.yaml".format(opt.output_path)
            with open(opt_new_fname, "w") as file:
                yaml.safe_dump(util.to_dict(opt), file, default_flow_style=False, indent=4)
            print("existing options file found (different from current one)...")
            os.system("diff {} {}".format(opt_fname, opt_new_fname))
            os.system("rm {}".format(opt_new_fname))
            if not opt.debug:
                print("please cancel within 10 seconds if you do not want to override...")
                time.sleep(10)
        else: print("existing options file found (identical)")
    else: print("(creating new options file...)")
    with open(opt_fname, "w") as file:
        yaml.safe_dump(util.to_dict(opt), file, default_flow_style=False, indent=4)

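A sketch of how these option utilities are typically wired into an entry point; the repo's actual train/eval scripts may differ in the details, but the `--key1.key2=value` and trailing-`!` syntax follows `parse_arguments` above.

```python
# Entry-point usage sketch for utils/options.py (exact script wiring is an assumption).
import sys
import utils.options as options

# invoked e.g. as:  python train.py --yaml=options/shape.yaml --optim.lr=1.e-4 --optim.sched!
opt_cmd = options.parse_arguments(sys.argv[1:])      # nested dict from the command-line flags
opt = options.set(opt_cmd=opt_cmd, safe_check=True)  # yaml (+ parents) with overrides applied
print(opt.optim.lr, opt.optim.sched)                 # resolved values after overriding
```
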
utils/pos_embed.py
ADDED
@@ -0,0 +1,118 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is copied from https://github.com/facebookresearch/MCC
# The original code base is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# Position embedding utils
# --------------------------------------------------------

import numpy as np

import torch

# --------------------------------------------------------
# 2D sine-cosine position embedding
# References:
# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
# MoCo v3: https://github.com/facebookresearch/moco-v3
# --------------------------------------------------------
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    grid_h = np.arange(grid_size, dtype=np.float32)
    grid_w = np.arange(grid_size, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float32)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb

def get_1d_sincos_pos_embed_from_grid_torch(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = torch.arange(embed_dim // 2, device=pos.device).float()
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = torch.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

    emb_sin = torch.sin(out)  # (M, D/2)
    emb_cos = torch.cos(out)  # (M, D/2)

    emb = torch.cat([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb


# --------------------------------------------------------
# Interpolate position embeddings for high-resolution
# References:
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------
def interpolate_pos_embed(model, checkpoint_model):
    if 'pos_embed' in checkpoint_model:
        pos_embed_checkpoint = checkpoint_model['pos_embed']
        embedding_size = pos_embed_checkpoint.shape[-1]
        num_patches = model.patch_embed.num_patches
        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
        # height (== width) for the checkpoint position embedding
        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
        # height (== width) for the new position embedding
        new_size = int(num_patches ** 0.5)
        # class_token and dist_token are kept unchanged
        if orig_size != new_size:
            print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
            # only the position tokens are interpolated
            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
            pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
            pos_tokens = torch.nn.functional.interpolate(
                pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
            checkpoint_model['pos_embed'] = new_pos_embed

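A quick shape-check sketch for the 2D sine-cosine embedding above (a ViT-base-like setting with a 14x14 patch grid is used purely as an example):

```python
# Shape check for get_2d_sincos_pos_embed (assumes the repo root is on PYTHONPATH).
import numpy as np
from utils.pos_embed import get_2d_sincos_pos_embed

pos = get_2d_sincos_pos_embed(embed_dim=768, grid_size=14, cls_token=True)
print(pos.shape)                 # (1 + 14*14, 768) = (197, 768)
print(np.allclose(pos[0], 0))    # True: the prepended cls-token row is all zeros
```
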
utils/util.py
ADDED
@@ -0,0 +1,413 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, sys, time
|
2 |
+
import shutil
|
3 |
+
import datetime
|
4 |
+
import torch
|
5 |
+
import torch.nn.functional as torch_F
|
6 |
+
import socket
|
7 |
+
import contextlib
|
8 |
+
import socket
|
9 |
+
import torch.distributed as dist
|
10 |
+
from collections import defaultdict, deque
|
11 |
+
|
12 |
+
class SmoothedValue(object):
|
13 |
+
"""Track a series of values and provide access to smoothed values over a
|
14 |
+
window or the global series average.
|
15 |
+
"""
|
16 |
+
|
17 |
+
def __init__(self, window_size=20, fmt=None):
|
18 |
+
if fmt is None:
|
19 |
+
fmt = "{median:.4f} ({global_avg:.4f})"
|
20 |
+
self.deque = deque(maxlen=window_size)
|
21 |
+
self.total = 0.0
|
22 |
+
self.count = 0
|
23 |
+
self.fmt = fmt
|
24 |
+
|
25 |
+
def update(self, value, n=1):
|
26 |
+
self.deque.append(value)
|
27 |
+
self.count += n
|
28 |
+
self.total += value * n
|
29 |
+
|
30 |
+
@property
|
31 |
+
def median(self):
|
32 |
+
d = torch.tensor(list(self.deque))
|
33 |
+
return d.median().item()
|
34 |
+
|
35 |
+
@property
|
36 |
+
def avg(self):
|
37 |
+
d = torch.tensor(list(self.deque), dtype=torch.float32)
|
38 |
+
return d.mean().item()
|
39 |
+
|
40 |
+
@property
|
41 |
+
def global_avg(self):
|
42 |
+
return self.total / self.count
|
43 |
+
|
44 |
+
@property
|
45 |
+
def max(self):
|
46 |
+
return max(self.deque)
|
47 |
+
|
48 |
+
@property
|
49 |
+
def value(self):
|
50 |
+
return self.deque[-1]
|
51 |
+
|
52 |
+
def __str__(self):
|
53 |
+
return self.fmt.format(
|
54 |
+
median=self.median,
|
55 |
+
avg=self.avg,
|
56 |
+
global_avg=self.global_avg,
|
57 |
+
max=self.max,
|
58 |
+
value=self.value)
|
59 |
+
|
60 |
+
|
61 |
+
class MetricLogger(object):
|
62 |
+
def __init__(self, delimiter="\t"):
|
63 |
+
self.meters = defaultdict(SmoothedValue)
|
64 |
+
self.delimiter = delimiter
|
65 |
+
|
66 |
+
def update(self, **kwargs):
|
67 |
+
for k, v in kwargs.items():
|
68 |
+
if v is None:
|
69 |
+
continue
|
70 |
+
if isinstance(v, torch.Tensor):
|
71 |
+
v = v.item()
|
72 |
+
assert isinstance(v, (float, int))
|
73 |
+
self.meters[k].update(v)
|
74 |
+
|
75 |
+
def __getattr__(self, attr):
|
76 |
+
if attr in self.meters:
|
77 |
+
return self.meters[attr]
|
78 |
+
if attr in self.__dict__:
|
79 |
+
return self.__dict__[attr]
|
80 |
+
raise AttributeError("'{}' object has no attribute '{}'".format(
|
81 |
+
type(self).__name__, attr))
|
82 |
+
|
83 |
+
def __str__(self):
|
84 |
+
loss_str = []
|
85 |
+
for name, meter in self.meters.items():
|
86 |
+
loss_str.append(
|
87 |
+
"{}: {}".format(name, str(meter))
|
88 |
+
)
|
89 |
+
return self.delimiter.join(loss_str)
|
90 |
+
|
91 |
+
def add_meter(self, name, meter):
|
92 |
+
self.meters[name] = meter
|
93 |
+
|
94 |
+
def log_every(self, iterable, print_freq, header=None):
|
95 |
+
i = 0
|
96 |
+
if not header:
|
97 |
+
header = ''
|
98 |
+
start_time = time.time()
|
99 |
+
end = time.time()
|
100 |
+
iter_time = SmoothedValue(fmt='{avg:.4f}')
|
101 |
+
data_time = SmoothedValue(fmt='{avg:.4f}')
|
102 |
+
space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
|
103 |
+
log_msg = [
|
104 |
+
header,
|
105 |
+
'[{0' + space_fmt + '}/{1}]',
|
106 |
+
'eta: {eta}',
|
107 |
+
'{meters}',
|
108 |
+
'time: {time}',
|
109 |
+
'data: {data}'
|
110 |
+
]
|
111 |
+
if torch.cuda.is_available():
|
112 |
+
log_msg.append('max mem: {memory:.0f}')
|
113 |
+
log_msg = self.delimiter.join(log_msg)
|
114 |
+
MB = 1024.0 * 1024.0
|
115 |
+
for obj in iterable:
|
116 |
+
data_time.update(time.time() - end)
|
117 |
+
yield obj
|
118 |
+
iter_time.update(time.time() - end)
|
119 |
+
if i % print_freq == 0 or i == len(iterable) - 1:
|
120 |
+
eta_seconds = iter_time.global_avg * (len(iterable) - i)
|
121 |
+
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
|
122 |
+
if torch.cuda.is_available():
|
123 |
+
print(log_msg.format(
|
124 |
+
i, len(iterable), eta=eta_string,
|
125 |
+
meters=str(self),
|
126 |
+
time=str(iter_time), data=str(data_time),
|
127 |
+
memory=torch.cuda.max_memory_allocated() / MB))
|
128 |
+
else:
|
129 |
+
print(log_msg.format(
|
130 |
+
i, len(iterable), eta=eta_string,
|
131 |
+
meters=str(self),
|
132 |
+
time=str(iter_time), data=str(data_time)))
|
133 |
+
i += 1
|
134 |
+
end = time.time()
|
135 |
+
total_time = time.time() - start_time
|
136 |
+
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
137 |
+
print('{} Total time: {} ({:.4f} s / it)'.format(
|
138 |
+
header, total_time_str, total_time / len(iterable)))
|
139 |
+
|
140 |
+
def print_eval(opt, loss=None, chamfer=None, depth_metrics=None):
|
141 |
+
message = "[eval] "
|
142 |
+
if loss is not None: message += "loss:{}".format("{:.3e}".format(loss.all))
|
143 |
+
if chamfer is not None:
|
144 |
+
message += " chamfer:{}|{}|{}".format("{:.4f}".format(chamfer[0]),
|
145 |
+
"{:.4f}".format(chamfer[1]),
|
146 |
+
"{:.4f}".format((chamfer[0]+chamfer[1])/2))
|
147 |
+
if depth_metrics is not None:
|
148 |
+
for k, v in depth_metrics.items():
|
149 |
+
message += "{}:{}, ".format(k, "{:.4f}".format(v))
|
150 |
+
message = message[:-2]
|
151 |
+
print(message)
|
152 |
+
|
153 |
+
def update_timer(opt, timer, ep, it_per_ep):
|
154 |
+
momentum = 0.99
|
155 |
+
timer.elapsed = time.time()-timer.start
|
156 |
+
timer.it = timer.it_end-timer.it_start
|
157 |
+
# compute speed with moving average
|
158 |
+
timer.it_mean = timer.it_mean*momentum+timer.it*(1-momentum) if timer.it_mean is not None else timer.it
|
159 |
+
timer.arrival = timer.it_mean*it_per_ep*(opt.max_epoch-ep)
|
160 |
+
|
161 |
+
# move tensors to device in-place
|
162 |
+
def move_to_device(X, device):
|
163 |
+
if isinstance(X, dict):
|
164 |
+
for k, v in X.items():
|
165 |
+
X[k] = move_to_device(v, device)
|
166 |
+
elif isinstance(X, list):
|
167 |
+
for i, e in enumerate(X):
|
168 |
+
X[i] = move_to_device(e, device)
|
169 |
+
elif isinstance(X, tuple) and hasattr(X, "_fields"): # collections.namedtuple
|
170 |
+
dd = X._asdict()
|
171 |
+
dd = move_to_device(dd, device)
|
172 |
+
return type(X)(**dd)
|
173 |
+
elif isinstance(X, torch.Tensor):
|
174 |
+
return X.to(device=device, non_blocking=True)
|
175 |
+
return X
|
176 |
+
|
177 |
+
# detach tensors
|
178 |
+
def detach_tensors(X):
|
179 |
+
if isinstance(X, dict):
|
180 |
+
for k, v in X.items():
|
181 |
+
X[k] = detach_tensors(v)
|
182 |
+
elif isinstance(X, list):
|
183 |
+
for i, e in enumerate(X):
|
184 |
+
X[i] = detach_tensors(e)
|
185 |
+
elif isinstance(X, tuple) and hasattr(X, "_fields"): # collections.namedtuple
|
186 |
+
dd = X._asdict()
|
187 |
+
dd = detach_tensors(dd)
|
188 |
+
return type(X)(**dd)
|
189 |
+
elif isinstance(X, torch.Tensor):
|
190 |
+
return X.detach()
|
191 |
+
return X
|
192 |
+
|
193 |
+
# this recursion seems to only work for the outer loop when dict_type is not dict
|
194 |
+
def to_dict(D, dict_type=dict):
|
195 |
+
D = dict_type(D)
|
196 |
+
for k, v in D.items():
|
197 |
+
if isinstance(v, dict):
|
198 |
+
D[k] = to_dict(v, dict_type)
|
199 |
+
return D
|
200 |
+
|
201 |
+
def get_child_state_dict(state_dict, key):
|
202 |
+
out_dict = {}
|
203 |
+
for k, v in state_dict.items():
|
204 |
+
if k.startswith("module."):
|
205 |
+
param_name = k[7:]
|
206 |
+
else:
|
207 |
+
param_name = k
|
208 |
+
if param_name.startswith("{}.".format(key)):
|
209 |
+
out_dict[".".join(param_name.split(".")[1:])] = v
|
210 |
+
return out_dict
|
211 |
+
|
212 |
+
def resume_checkpoint(opt, model, best):
|
213 |
+
load_name = "{0}/best.ckpt".format(opt.output_path) if best else "{0}/latest.ckpt".format(opt.output_path)
|
214 |
+
checkpoint = torch.load(load_name, map_location=torch.device(opt.device))
|
215 |
+
model.graph.module.load_state_dict(checkpoint["graph"], strict=True)
|
216 |
+
# load the training stats
|
217 |
+
for key in model.__dict__:
|
218 |
+
if key.split("_")[0] in ["optim", "sched", "scaler"] and key in checkpoint:
|
219 |
+
if opt.device == 0: print("restoring {}...".format(key))
|
220 |
+
getattr(model, key).load_state_dict(checkpoint[key])
|
221 |
+
# also need to record ep, it, best_val if we are returning
|
222 |
+
ep, it = checkpoint["epoch"], checkpoint["iter"]
|
223 |
+
best_val, best_ep = checkpoint["best_val"], checkpoint["best_ep"] if "best_ep" in checkpoint else 0
|
224 |
+
print("resuming from epoch {0} (iteration {1})".format(ep, it))
|
225 |
+
|
226 |
+
return ep, it, best_val, best_ep
|
227 |
+
|
228 |
+
def load_checkpoint(opt, model, load_name):
|
229 |
+
# load_name as to be given
|
230 |
+
checkpoint = torch.load(load_name, map_location=torch.device(opt.device))
|
231 |
+
# load individual (possibly partial) children modules
|
232 |
+
for name, child in model.graph.module.named_children():
|
233 |
+
child_state_dict = get_child_state_dict(checkpoint["graph"], name)
|
234 |
+
if child_state_dict:
|
235 |
+
if opt.device == 0: print("restoring {}...".format(name))
|
236 |
+
child.load_state_dict(child_state_dict, strict=True)
|
237 |
+
else:
|
238 |
+
if opt.device == 0: print("skipping {}...".format(name))
    return None, None, None, None

def restore_checkpoint(opt, model, load_name=None, resume=False, best=False, evaluate=False):
    # we cannot load and resume at the same time
    assert not (load_name is not None and resume)
    # when resuming we want everything to be the same
    if resume:
        ep, it, best_val, best_ep = resume_checkpoint(opt, model, best)
    # loading is more flexible, as we can only load parts of the model
    else:
        ep, it, best_val, best_ep = load_checkpoint(opt, model, load_name)
    return ep, it, best_val, best_ep

def save_checkpoint(opt, model, ep, it, best_val, best_ep, latest=False, best=False, children=None):
    os.makedirs("{0}/checkpoint".format(opt.output_path), exist_ok=True)
    if isinstance(model.graph, torch.nn.DataParallel) or isinstance(model.graph, torch.nn.parallel.DistributedDataParallel):
        graph = model.graph.module
    else:
        graph = model.graph
    if children is not None:
        graph_state_dict = {k: v for k, v in graph.state_dict().items() if k.startswith(children)}
    else:
        graph_state_dict = graph.state_dict()
    checkpoint = dict(
        epoch=ep,
        iter=it,
        best_val=best_val,
        best_ep=best_ep,
        graph=graph_state_dict,
    )
    for key in model.__dict__:
        if key.split("_")[0] in ["optim", "sched", "scaler"]:
            checkpoint.update({key: getattr(model, key).state_dict()})
    torch.save(checkpoint, "{0}/latest.ckpt".format(opt.output_path))
    if best:
        shutil.copy("{0}/latest.ckpt".format(opt.output_path),
                    "{0}/best.ckpt".format(opt.output_path))
    if not latest:
        shutil.copy("{0}/latest.ckpt".format(opt.output_path),
                    "{0}/checkpoint/ep{1}.ckpt".format(opt.output_path, ep))

def check_socket_open(hostname, port):
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    is_open = False
    try:
        s.bind((hostname, port))
    except socket.error:
        is_open = True
    finally:
        s.close()
    return is_open

def get_layer_dims(layers):
    # return a list of tuples (k_in, k_out)
    return list(zip(layers[:-1], layers[1:]))

@contextlib.contextmanager
def suppress(stdout=False, stderr=False):
    with open(os.devnull, "w") as devnull:
        if stdout: old_stdout, sys.stdout = sys.stdout, devnull
        if stderr: old_stderr, sys.stderr = sys.stderr, devnull
        try: yield
        finally:
            if stdout: sys.stdout = old_stdout
            if stderr: sys.stderr = old_stderr

def toggle_grad(model, requires_grad):
    for p in model.parameters():
        p.requires_grad_(requires_grad)

def compute_grad2(d_outs, x_in):
    d_outs = [d_outs] if not isinstance(d_outs, list) else d_outs
    reg = 0
    for d_out in d_outs:
        batch_size = x_in.size(0)
        grad_dout = torch.autograd.grad(
            outputs=d_out.sum(), inputs=x_in,
            create_graph=True, retain_graph=True, only_inputs=True
        )[0]
        grad_dout2 = grad_dout.pow(2)
        assert(grad_dout2.size() == x_in.size())
        reg += grad_dout2.view(batch_size, -1).sum(1)
    return reg / len(d_outs)

def interpolate_depth(depth_input, mask_input, size, bg_depth=20):
    assert len(depth_input.shape) == len(mask_input.shape) == 4
    mask = (mask_input > 0.5).float()
    depth_valid = depth_input * mask
    depth_valid = torch_F.interpolate(depth_valid, size, mode='bilinear', align_corners=False)
    mask = torch_F.interpolate(mask, size, mode='bilinear', align_corners=False)
    depth_out = depth_valid / (mask + 1.e-6)
    mask_binary = (mask > 0.5).float()
    depth_out = depth_out * mask_binary + bg_depth * (1 - mask_binary)
    return depth_out, mask_binary

def interpolate_coordmap(coord_map, mask_input, size, bg_coord=0):
    assert len(coord_map.shape) == len(mask_input.shape) == 4
    mask = (mask_input > 0.5).float()
    coord_valid = coord_map * mask
    coord_valid = torch_F.interpolate(coord_valid, size, mode='bilinear', align_corners=False)
    mask = torch_F.interpolate(mask, size, mode='bilinear', align_corners=False)
    coord_out = coord_valid / (mask + 1.e-6)
    mask_binary = (mask > 0.5).float()
    coord_out = coord_out * mask_binary + bg_coord * (1 - mask_binary)
    return coord_out, mask_binary

def cleanup():
    dist.destroy_process_group()

def is_port_in_use(port):
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        return s.connect_ex(('localhost', port)) == 0

def setup(rank, world_size, port_no):
    full_address = 'tcp://127.0.0.1:' + str(port_no)
    dist.init_process_group("nccl", init_method=full_address, rank=rank, world_size=world_size)

def print_grad(grad, prefix=''):
    print("{} --- Grad Abs Mean, Grad Max, Grad Min: {:.5f} | {:.5f} | {:.5f}".format(prefix, grad.abs().mean().item(), grad.max().item(), grad.min().item()))

class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

class EasyDict(dict):
    def __init__(self, d=None, **kwargs):
        if d is None:
            d = {}
        else:
            d = dict(d)
        if kwargs:
            d.update(**kwargs)
        for k, v in d.items():
            setattr(self, k, v)
        # Class attributes
        for k in self.__class__.__dict__.keys():
            if not (k.startswith('__') and k.endswith('__')) and not k in ('update', 'pop'):
                setattr(self, k, getattr(self, k))

    def __setattr__(self, name, value):
        if isinstance(value, (list, tuple)):
            value = [self.__class__(x)
                     if isinstance(x, dict) else x for x in value]
        elif isinstance(value, dict) and not isinstance(value, self.__class__):
            value = self.__class__(value)
        super(EasyDict, self).__setattr__(name, value)
        super(EasyDict, self).__setitem__(name, value)

    __setitem__ = __setattr__

    def update(self, e=None, **f):
        d = e or dict()
        d.update(f)
        for k in d:
            setattr(self, k, d[k])

    def pop(self, k, d=None):
        delattr(self, k)
        return super(EasyDict, self).pop(k, d)
utils/util_vis.py
ADDED
@@ -0,0 +1,511 @@
import numpy as np
import os
import torch
import torchvision
import torchvision.transforms.functional as torchvision_F
import matplotlib.pyplot as plt
import PIL
import PIL.ImageDraw
from PIL import Image, ImageFont
import trimesh
import pyrender
import cv2
import copy
import base64
import io
import imageio

os.environ['PYOPENGL_PLATFORM'] = 'egl'

@torch.no_grad()
def tb_image(opt, tb, step, group, name, images, masks=None, num_vis=None, from_range=(0, 1), poses=None, cmap="gray", depth=False):
    if not depth:
        images = preprocess_vis_image(opt, images, masks=masks, from_range=from_range, cmap=cmap) # [B, 3, H, W]
    else:
        masks = (masks > 0.5).float()
        images = images * masks + (1 - masks) * ((images * masks).max())
        images = (1 - images).detach().cpu()
    num_H, num_W = num_vis or opt.tb.num_images
    images = images[:num_H*num_W]
    if poses is not None:
        # poses: [B, 3, 4]
        # rots: [max(B, num_images), 3, 3]
        rots = poses[:num_H*num_W, ..., :3]
        images = torch.stack([draw_pose(opt, image, rot, size=20, width=2) for image, rot in zip(images, rots)], dim=0)
    image_grid = torchvision.utils.make_grid(images[:, :3], nrow=num_W, pad_value=1.)
    if images.shape[1]==4:
        mask_grid = torchvision.utils.make_grid(images[:, 3:], nrow=num_W, pad_value=1.)[:1]
        image_grid = torch.cat([image_grid, mask_grid], dim=0)
    tag = "{0}/{1}".format(group, name)
    tb.add_image(tag, image_grid, step)

def preprocess_vis_image(opt, images, masks=None, from_range=(0, 1), cmap="gray"):
    min, max = from_range
    images = (images-min)/(max-min)
    if masks is not None:
        # then the mask is directly the transparency channel of png
        images = torch.cat([images, masks], dim=1)
    images = images.clamp(min=0, max=1).cpu()
    if images.shape[1]==1:
        images = get_heatmap(opt, images[:, 0].cpu(), cmap=cmap)
    return images

def preprocess_depth_image(opt, depth, mask=None, max_depth=1000):
    if mask is not None: depth = depth * mask + (1 - mask) * max_depth # min of this will lead to the minimum of the masked regions
    depth = depth - depth.min()
    if mask is not None: depth = depth * mask # max of this will lead to the maximum of the masked regions
    depth = depth / depth.max()
    return depth

def dump_images(opt, idx, name, images, masks=None, from_range=(0, 1), poses=None, metrics=None, cmap="gray", folder='dump'):
    images = preprocess_vis_image(opt, images, masks=masks, from_range=from_range, cmap=cmap) # [B, 3, H, W]
    if poses is not None:
        rots = poses[..., :3]
        images = torch.stack([draw_pose(opt, image, rot, size=20, width=2) for image, rot in zip(images, rots)], dim=0)
    if metrics is not None:
        images = torch.stack([draw_metric(opt, image, metric.item()) for image, metric in zip(images, metrics)], dim=0)
    images = images.cpu().permute(0, 2, 3, 1).contiguous().numpy() # [B, H, W, 3]
    for i, img in zip(idx, images):
        fname = "{}/{}/{}_{}.png".format(opt.output_path, folder, i, name)
        img = Image.fromarray((img*255).astype(np.uint8))
        img.save(fname)

def dump_depths(opt, idx, name, depths, masks=None, rescale=False, folder='dump'):
    if rescale:
        masks = (masks > 0.5).float()
        depths = depths * masks + (1 - masks) * ((depths * masks).max())
        depths = (1 - depths).detach().cpu()
    for i, depth in zip(idx, depths):
        fname = "{}/{}/{}_{}.png".format(opt.output_path, folder, i, name)
        plt.imsave(fname, depth.squeeze(), cmap='viridis')

# imgs_list is a list of length n_views, where each view is an image tensor of [B, 3, H, W]
def dump_gifs(opt, idx, name, imgs_list, from_range=(0, 1), folder='dump', cmap="gray"):
    for i in range(len(imgs_list)):
        imgs_list[i] = preprocess_vis_image(opt, imgs_list[i], from_range=from_range, cmap=cmap)
    for i in range(len(idx)):
        img_list_np = [imgs[i].cpu().permute(1, 2, 0).contiguous().numpy() for imgs in imgs_list] # list of [H, W, 3], each item is a view of the ith sample
        img_list_pil = [Image.fromarray((img*255).astype(np.uint8)).convert('RGB') for img in img_list_np]
        fname = "{}/{}/{}_{}.gif".format(opt.output_path, folder, idx[i], name)
        img_list_pil[0].save(fname, format='GIF', append_images=img_list_pil[1:], save_all=True, duration=100, loop=0)

# attn_vis is a list of length n_views, where each view is an image tensor of [B, 3, H, W]
def dump_attentions(opt, idx, name, attn_vis, folder='dump'):
    for i in range(len(idx)):
        img_list_pil = [Image.fromarray((img*255).astype(np.uint8)).convert('RGB') for img in attn_vis[i]]
        fname = "{}/{}/{}_{}.gif".format(opt.output_path, folder, idx[i], name)
        img_list_pil[0].save(fname, format='GIF', append_images=img_list_pil[1:], save_all=True, duration=50, loop=0)

def get_heatmap(opt, gray, cmap): # [N, H, W]
    color = plt.get_cmap(cmap)(gray.numpy())
    color = torch.from_numpy(color[..., :3]).permute(0, 3, 1, 2).contiguous().float() # [N, 3, H, W]
    return color

def dump_meshes(opt, idx, name, meshes, folder='dump'):
    for i, mesh in zip(idx, meshes):
        fname = "{}/{}/{}_{}.ply".format(opt.output_path, folder, i, name)
        try:
            mesh.export(fname)
        except:
            print('Mesh is empty!')

def dump_meshes_viz(opt, idx, name, meshes, save_frames=True, folder='dump'):
    for i, mesh in zip(idx, meshes):
        mesh = copy.deepcopy(mesh)
        R = trimesh.transformations.rotation_matrix(np.radians(180), [0, 0, 1])
        mesh.apply_transform(R)
        R = trimesh.transformations.rotation_matrix(np.radians(180), [0, 1, 0])
        mesh.apply_transform(R)
        # our marching cubes outputs inverted normals for some reason so this is necessary
        trimesh.repair.fix_inversion(mesh)

        fname = "{}/{}/{}_{}".format(opt.output_path, folder, i, name)
        try:
            mesh = scale_to_unit_cube(mesh)
            visualize_mesh(mesh, fname, write_frames=save_frames)
        except:
            pass

def dump_seen_surface(opt, idx, obj_name, img_name, seen_projs, folder='dump'):
    # seen_proj: [B, H, W, 3]
    for i, seen_proj in zip(idx, seen_projs):
        out_folder = "{}/{}".format(opt.output_path, folder)
        img_fname = "{}_{}.png".format(i, img_name)
        create_seen_surface(i, img_fname, seen_proj, out_folder, obj_name)

# https://github.com/princeton-vl/oasis/blob/master/utils/vis_mesh.py
def create_seen_surface(sample_ID, img_path, XYZ, output_folder, obj_name, connect_thres=0.005):
    height, width = XYZ.shape[:2]
    XYZ_to_idx = {}
    idx = 1
    with open("{}/{}_{}.mtl".format(output_folder, sample_ID, obj_name), "w") as f:
        f.write("newmtl material_0\n")
        f.write("Ka 0.200000 0.200000 0.200000\n")
        f.write("Kd 0.752941 0.752941 0.752941\n")
        f.write("Ks 1.000000 1.000000 1.000000\n")
        f.write("Tr 1.000000\n")
        f.write("illum 2\n")
        f.write("Ns 0.000000\n")
        f.write("map_Ka %s\n" % img_path)
        f.write("map_Kd %s\n" % img_path)

    with open("{}/{}_{}.obj".format(output_folder, sample_ID, obj_name), "w") as f:
        f.write("mtllib {}_{}.mtl\n".format(sample_ID, obj_name))
        for y in range(height):
            for x in range(width):
                if XYZ[y][x][2] > 0:
                    XYZ_to_idx[(y, x)] = idx
                    idx += 1
                    f.write("v %.4f %.4f %.4f\n" % (XYZ[y][x][0], XYZ[y][x][1], XYZ[y][x][2]))
                    f.write("vt %.8f %.8f\n" % (float(x) / float(width), 1.0 - float(y) / float(height)))
        f.write("usemtl material_0\n")
        for y in range(height-1):
            for x in range(width-1):
                if XYZ[y][x][2] > 0 and XYZ[y][x+1][2] > 0 and XYZ[y+1][x][2] > 0:
                    # if close enough, connect vertices to form a face
                    if torch.norm(XYZ[y][x] - XYZ[y][x+1]).item() < connect_thres and torch.norm(XYZ[y][x] - XYZ[y+1][x]).item() < connect_thres:
                        f.write("f %d/%d %d/%d %d/%d\n" % (XYZ_to_idx[(y, x)], XYZ_to_idx[(y, x)], XYZ_to_idx[(y, x+1)], XYZ_to_idx[(y, x+1)], XYZ_to_idx[(y+1, x)], XYZ_to_idx[(y+1, x)]))
                if XYZ[y][x+1][2] > 0 and XYZ[y+1][x+1][2] > 0 and XYZ[y+1][x][2] > 0:
                    if torch.norm(XYZ[y][x+1] - XYZ[y+1][x+1]).item() < connect_thres and torch.norm(XYZ[y][x+1] - XYZ[y+1][x]).item() < connect_thres:
                        f.write("f %d/%d %d/%d %d/%d\n" % (XYZ_to_idx[(y, x+1)], XYZ_to_idx[(y, x+1)], XYZ_to_idx[(y+1, x+1)], XYZ_to_idx[(y+1, x+1)], XYZ_to_idx[(y+1, x)], XYZ_to_idx[(y+1, x)]))

def dump_pointclouds_compare(opt, idx, name, preds, gts, folder='dump'):
    for i in range(len(idx)):
        pred = preds[i].cpu().numpy() # [N1, 3]
        gt = gts[i].cpu().numpy() # [N2, 3]
        color_pred = np.zeros(pred.shape).astype(np.uint8)
        color_pred[:, 0] = 255
        color_gt = np.zeros(gt.shape).astype(np.uint8)
        color_gt[:, 1] = 255
        pc_vertices = np.vstack([pred, gt])
        colors = np.vstack([color_pred, color_gt])
        pc_color = trimesh.points.PointCloud(vertices=pc_vertices, colors=colors)
        fname = "{}/{}/{}_{}.ply".format(opt.output_path, folder, idx[i], name)
        pc_color.export(fname)

def dump_pointclouds(opt, idx, name, pcs, colors, folder='dump', colormap='jet'):
    for i, pc, color in zip(idx, pcs, colors):
        pc = pc.cpu().numpy() # [N, 3]
        color = color.cpu().numpy() # [N, 3] or [N, 1]
        # convert scalar color to rgb with colormap
        if color.shape[1] == 1:
            # single-channel color in numpy between [0, 1] to rgb
            color = plt.get_cmap(colormap)(color[:, 0])
            color = (color * 255).astype(np.uint8)
        pc_color = trimesh.points.PointCloud(vertices=pc, colors=color)
        fname = "{}/{}/{}_{}.ply".format(opt.output_path, folder, i, name)
        pc_color.export(fname)

@torch.no_grad()
def vis_pointcloud(opt, vis, step, split, pred, GT=None):
    win_name = "{0}/{1}".format(opt.group, opt.name)
    pred, GT = pred.cpu().numpy(), GT.cpu().numpy()
    for i in range(opt.visdom.num_samples):
        # prediction
        data = [dict(
            type="scatter3d",
            x=[float(n) for n in points[i, :opt.visdom.num_points, 0]],
            y=[float(n) for n in points[i, :opt.visdom.num_points, 1]],
            z=[float(n) for n in points[i, :opt.visdom.num_points, 2]],
            mode="markers",
            marker=dict(
                color=color,
                size=1,
            ),
        ) for points, color in zip([pred, GT], ["blue", "magenta"])]
        vis._send(dict(
            data=data,
            win="{0} #{1}".format(split, i),
            eid="{0}/{1}".format(opt.group, opt.name),
            layout=dict(
                title="{0} #{1} ({2})".format(split, i, step),
                autosize=True,
                margin=dict(l=30, r=30, b=30, t=30, ),
                showlegend=False,
                yaxis=dict(
                    scaleanchor="x",
                    scaleratio=1,
                )
            ),
            opts=dict(title="{0} #{1} ({2})".format(win_name, i, step), ),
        ))

@torch.no_grad()
def draw_pose(opt, image, rot_mtrx, size=15, width=1):
    # rot_mtrx: [3, 4]
    mode = "RGBA" if image.shape[0]==4 else "RGB"
    image_pil = torchvision_F.to_pil_image(image.cpu()).convert("RGBA")
    draw_pil = PIL.Image.new("RGBA", image_pil.size, (0, 0, 0, 0))
    draw = PIL.ImageDraw.Draw(draw_pil)
    center = (size, size)
    # first column of the rotation matrix is the rotated vector of [1, 0, 0]'
    # second column of the rotation matrix is the rotated vector of [0, 1, 0]'
    # third column of the rotation matrix is the rotated vector of [0, 0, 1]'
    # taking the first two elements of each column projects it onto the 2D image plane for visualization
    endpoint = [(size+size*p[0], size+size*p[1]) for p in rot_mtrx.t()]
    draw.line([center, endpoint[0]], fill=(255, 0, 0), width=width)
    draw.line([center, endpoint[1]], fill=(0, 255, 0), width=width)
    draw.line([center, endpoint[2]], fill=(0, 0, 255), width=width)
    image_pil.alpha_composite(draw_pil)
    image_drawn = torchvision_F.to_tensor(image_pil.convert(mode))
    return image_drawn

@torch.no_grad()
def draw_metric(opt, image, metric):
    mode = "RGBA" if image.shape[0]==4 else "RGB"
    image_pil = torchvision_F.to_pil_image(image.cpu()).convert("RGBA")
    draw_pil = PIL.Image.new("RGBA", image_pil.size, (0, 0, 0, 0))
    draw = PIL.ImageDraw.Draw(draw_pil)
    font = ImageFont.truetype("DejaVuSans.ttf", 24)
    position = (image_pil.size[0] - 80, image_pil.size[1] - 35)
    draw.text(position, '{:.3f}'.format(metric), fill="red", font=font)
    image_pil.alpha_composite(draw_pil)
    image_drawn = torchvision_F.to_tensor(image_pil.convert(mode))
    return image_drawn

@torch.no_grad()
def show_att_on_image(img, mask):
    """
    Convert the grayscale attention into a heatmap overlaid on the image.

    Parameters
    ----------
    img: np.array, [H, W, 3]
        Original colored image in [0, 1].
    mask: np.array, [H, W]
        Attention map in [0, 1].

    Returns
    ----------
    np image with attention applied.
    """
    # check the validity
    assert np.max(img) <= 1
    assert np.max(mask) <= 1

    # generate heatmap and normalize into [0, 1]
    heatmap = cv2.cvtColor(cv2.applyColorMap(np.uint8(255*mask), cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB)
    heatmap = np.float32(heatmap) / 255

    # add heatmap onto the image
    merged = heatmap + np.float32(img)

    # re-scale the image
    merged = merged / np.max(merged)
    return merged

def look_at(camera_position, camera_target, up_vector):
    vector = camera_position - camera_target
    vector = vector / np.linalg.norm(vector)

    vector2 = np.cross(up_vector, vector)
    vector2 = vector2 / np.linalg.norm(vector2)

    vector3 = np.cross(vector, vector2)
    return np.array([
        [vector2[0], vector3[0], vector[0], 0.0],
        [vector2[1], vector3[1], vector[1], 0.0],
        [vector2[2], vector3[2], vector[2], 0.0],
        [-np.dot(vector2, camera_position), -np.dot(vector3, camera_position), np.dot(vector, camera_position), 1.0]
    ])

def scale_to_unit_cube(mesh):
    if isinstance(mesh, trimesh.Scene):
        mesh = mesh.dump().sum()

    vertices = mesh.vertices - mesh.bounding_box.centroid
    vertices *= 2 / np.max(mesh.bounding_box.extents)
    vertices *= 0.5

    return trimesh.Trimesh(vertices=vertices, faces=mesh.faces)

def get_positions_and_rotations(n_frames=180, r=1.5):
    '''
    n_frames: how many frames
    r: how far the camera should be
    '''
    n_frame_full_circ = n_frames // 3 # frames for a full circle
    n_frame_half_circ = n_frames // 6 # frames for a half circle

    # full circle in the horizontal axes, going from 1 to -1 on the height axis
    pos1 = [np.array([r*np.cos(theta), elev, r*np.sin(theta)])
            for theta, elev in zip(np.linspace(0.5*np.pi, 2.5*np.pi, n_frame_full_circ), np.linspace(1, -1, n_frame_full_circ))]
    # half circle in the horizontal axes at fixed -1 height
    pos2 = [np.array([r*np.cos(theta), -1, r*np.sin(theta)])
            for theta in np.linspace(2.5*np.pi, 3.5*np.pi, n_frame_half_circ)]
    # full circle in the horizontal axes, going from -1 to 1 on the height axis
    pos3 = [np.array([r*np.cos(theta), elev, r*np.sin(theta)])
            for theta, elev in zip(np.linspace(3.5*np.pi, 5.5*np.pi, n_frame_full_circ), np.linspace(-1, 1, n_frame_full_circ))]
    # half circle in the horizontal axes at fixed 1 height
    pos4 = [np.array([r*np.cos(theta), 1, r*np.sin(theta)])
            for theta in np.linspace(3.5*np.pi, 4.5*np.pi, n_frame_half_circ)]

    pos = pos1 + pos2 + pos3 + pos4
    target = np.array([0.0, 0.0, 0.0])
    up = np.array([0.0, 1.0, 0.0])
    rot = [look_at(x, target, up) for x in pos]
    return pos, rot

def visualize_mesh(mesh, output_path, resolution=(200, 200), write_gif=True, write_frames=True, time_per_frame=80, n_frames=180):
    '''
    mesh: Trimesh mesh object
    output_path: absolute path, ".gif" will get added if write_gif, and this will be used as the dirname if write_frames is true
    time_per_frame: how many milliseconds to wait for each frame
    n_frames: how many frames in total
    '''
    # set material
    mat = pyrender.MetallicRoughnessMaterial(
        metallicFactor=0.8,
        roughnessFactor=1.0,
        alphaMode='OPAQUE',
        baseColorFactor=(0.5, 0.5, 0.8, 1.0),
    )
    # define and add scene elements
    mesh = pyrender.Mesh.from_trimesh(mesh, material=mat)
    camera = pyrender.PerspectiveCamera(yfov=np.pi / 3.0, aspectRatio=1.0)
    light = pyrender.SpotLight(color=np.ones(3), intensity=15.0,
                               innerConeAngle=np.pi/4.0,
                               outerConeAngle=np.pi/4.0)

    scene = pyrender.Scene()
    obj = scene.add(mesh)
    cam = scene.add(camera)
    light = scene.add(light)

    positions, rotations = get_positions_and_rotations(n_frames=n_frames)

    r = pyrender.OffscreenRenderer(*resolution)

    # move the camera and generate images
    image_list = []
    for pos, rot in zip(positions, rotations):
        pose = np.eye(4)
        pose[:3, 3] = pos
        pose[:3, :3] = rot[:3, :3]

        scene.set_pose(cam, pose)
        scene.set_pose(light, pose)

        color, depth = r.render(scene)

        img = Image.fromarray(color, mode="RGB")
        image_list.append(img)

    # save to file
    if write_gif:
        image_list[0].save(f"{output_path}.gif", format='GIF', append_images=image_list[1:], save_all=True, duration=time_per_frame, loop=0)

    if write_frames:
        if not os.path.exists(output_path):
            os.makedirs(output_path)

        for i, img in enumerate(image_list):
            img.save(os.path.join(output_path, f"{i:04d}.jpg"))

def get_base64_encoded_image(image_path):
    """
    Returns the base64-encoded image at the given path.

    Args:
        image_path (str): The path to the image file.

    Returns:
        str: The base64-encoded image.
    """
    with open(image_path, "rb") as f:
        img = Image.open(f)
        if img.mode == 'RGBA':
            img = img.convert('RGB')
        # Resize the image to reduce its file size
        img.thumbnail((200, 200))
        buffer = io.BytesIO()
        # Convert the image to JPEG format to reduce its file size
        img.save(buffer, format="JPEG", quality=80)
        return base64.b64encode(buffer.getvalue()).decode("utf-8")

def get_base64_encoded_gif(gif_path):
    """
    Returns the base64-encoded GIF at the given path.

    Args:
        gif_path (str): The path to the GIF file.

    Returns:
        str: The base64-encoded GIF.
    """
    with open(gif_path, "rb") as f:
        frames = imageio.mimread(f)
        # Reduce the number of frames to reduce the file size
        frames = frames[::4]
        buffer = io.BytesIO()
        # Downsample each frame to reduce the file size
        frames = [frame[::2, ::2] for frame in frames]
        # Write the GIF with subrectangles to reduce the file size
        imageio.mimsave(buffer, frames, format="GIF", fps=10, subrectangles=True)
        return base64.b64encode(buffer.getvalue()).decode("utf-8")

def create_gif_html(folder_path, html_file, skip_every=10):
    """
    Creates an HTML file with a grid of sample visualizations.

    Args:
        folder_path (str): The path to the folder containing the sample visualizations.
        html_file (str): The name of the HTML file to create.
    """
    # convert path to absolute path
    folder_path = os.path.abspath(folder_path)

    # Get a list of all the sample IDs
    ids = []
    count = 0
    all_files = sorted(os.listdir(folder_path), key=lambda x: int(x.split("_")[0]))
    for filename in all_files:
        if filename.endswith("_image_input.png"):
            if count % skip_every == 0:
                ids.append(filename.split("_")[0])
            count += 1

    # Write the HTML file
    with open(html_file, "w") as f:
        # Write the HTML header and CSS style
        f.write("<html>\n")
        f.write("<head>\n")
        f.write("<style>\n")
        f.write(".sample-container {\n")
        f.write(" display: inline-block;\n")
        f.write(" margin: 10px;\n")
        f.write(" width: 350px;\n")
        f.write(" height: 150px;\n")
        f.write(" text-align: center;\n")
        f.write("}\n")
        f.write(".sample-container:nth-child(6n+1) {\n")
        f.write(" clear: left;\n")
        f.write("}\n")
        f.write(".image-container, .gif-container {\n")
        f.write(" display: inline-block;\n")
        f.write(" margin: 10px;\n")
        f.write(" width: 90px;\n")
        f.write(" height: 90px;\n")
        f.write(" object-fit: cover;\n")
        f.write("}\n")
        f.write("</style>\n")
        f.write("</head>\n")
        f.write("<body>\n")

        # Write the sample visualizations to the HTML file
        for sample_id in ids:
            try:
                f.write("<div class=\"sample-container\">\n")
                f.write(f"<div class=\"sample-id\"><p>{sample_id}</p></div>\n")
                f.write(f"<div class=\"image-container\"><img src=\"data:image/png;base64,{get_base64_encoded_image(os.path.join(folder_path, sample_id + '_image_input.png'))}\" width=\"90\" height=\"90\"></div>\n")
                f.write(f"<div class=\"image-container\"><img src=\"data:image/png;base64,{get_base64_encoded_image(os.path.join(folder_path, sample_id + '_depth_est.png'))}\" width=\"90\" height=\"90\"></div>\n")
                f.write(f"<div class=\"gif-container\"><img src=\"data:image/gif;base64,{get_base64_encoded_gif(os.path.join(folder_path, sample_id + '_mesh_viz.gif'))}\" width=\"90\" height=\"90\"></div>\n")
                f.write("</div>\n")
            except:
                pass

        # Write the HTML footer
        f.write("</body>\n")
        f.write("</html>\n")
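A brief usage note, not part of the committed file: the dump_* helpers above write per-sample artifacts such as <idx>_image_input.png, <idx>_depth_est.png, and <idx>_mesh_viz.gif into {opt.output_path}/{folder}, and create_gif_html then inlines them (base64-encoded) into a single self-contained HTML report. A hedged sketch, where the dump and output paths are illustrative assumptions:

# Sketch only: assumes an existing dump folder already populated by the dump_* helpers above.
from utils.util_vis import create_gif_html

create_gif_html("output/demo/dump", "output/demo/vis.html", skip_every=1)  # keep every sample in the report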
weights/.gitignore
ADDED
@@ -0,0 +1,4 @@
# Ignore everything in this directory
*
# Except this file
!.gitignore