Spaces:
Build error
Build error
Upload 6 files
Browse files- lib/common/render.py +392 -0
- lib/common/render_utils.py +221 -0
- lib/common/seg3d_lossless.py +604 -0
- lib/common/seg3d_utils.py +392 -0
- lib/common/smpl_vert_segmentation.json +0 -0
- lib/common/train_util.py +599 -0
lib/common/render.py
ADDED
@@ -0,0 +1,392 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
|
4 |
+
# holder of all proprietary rights on this computer program.
|
5 |
+
# You can only use this computer program if you have closed
|
6 |
+
# a license agreement with MPG or you get the right to use the computer
|
7 |
+
# program from someone who is authorized to grant you that right.
|
8 |
+
# Any use of the computer program without a valid license is prohibited and
|
9 |
+
# liable to prosecution.
|
10 |
+
#
|
11 |
+
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
|
12 |
+
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
|
13 |
+
# for Intelligent Systems. All rights reserved.
|
14 |
+
#
|
15 |
+
# Contact: ps-license@tuebingen.mpg.de
|
16 |
+
|
17 |
+
from pytorch3d.renderer import (
|
18 |
+
BlendParams,
|
19 |
+
blending,
|
20 |
+
look_at_view_transform,
|
21 |
+
FoVOrthographicCameras,
|
22 |
+
PointLights,
|
23 |
+
RasterizationSettings,
|
24 |
+
PointsRasterizationSettings,
|
25 |
+
PointsRenderer,
|
26 |
+
AlphaCompositor,
|
27 |
+
PointsRasterizer,
|
28 |
+
MeshRenderer,
|
29 |
+
MeshRasterizer,
|
30 |
+
SoftPhongShader,
|
31 |
+
SoftSilhouetteShader,
|
32 |
+
TexturesVertex,
|
33 |
+
)
|
34 |
+
from pytorch3d.renderer.mesh import TexturesVertex
|
35 |
+
from pytorch3d.structures import Meshes
|
36 |
+
|
37 |
+
import os, subprocess
|
38 |
+
|
39 |
+
from lib.dataset.mesh_util import SMPLX, get_visibility
|
40 |
+
import lib.common.render_utils as util
|
41 |
+
import torch
|
42 |
+
import numpy as np
|
43 |
+
from PIL import Image
|
44 |
+
from tqdm import tqdm
|
45 |
+
import cv2
|
46 |
+
import math
|
47 |
+
from termcolor import colored
|
48 |
+
|
49 |
+
|
50 |
+
def image2vid(images, vid_path):
|
51 |
+
|
52 |
+
w, h = images[0].size
|
53 |
+
videodims = (w, h)
|
54 |
+
fourcc = cv2.VideoWriter_fourcc(*'XVID')
|
55 |
+
video = cv2.VideoWriter(vid_path, fourcc, 30, videodims)
|
56 |
+
for image in images:
|
57 |
+
video.write(cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR))
|
58 |
+
video.release()
|
59 |
+
|
60 |
+
|
61 |
+
def query_color(verts, faces, image, device):
|
62 |
+
"""query colors from points and image
|
63 |
+
|
64 |
+
Args:
|
65 |
+
verts ([B, 3]): [query verts]
|
66 |
+
faces ([M, 3]): [query faces]
|
67 |
+
image ([B, 3, H, W]): [full image]
|
68 |
+
|
69 |
+
Returns:
|
70 |
+
[np.float]: [return colors]
|
71 |
+
"""
|
72 |
+
|
73 |
+
verts = verts.float().to(device)
|
74 |
+
faces = faces.long().to(device)
|
75 |
+
|
76 |
+
(xy, z) = verts.split([2, 1], dim=1)
|
77 |
+
visibility = get_visibility(xy, z, faces[:, [0, 2, 1]]).flatten()
|
78 |
+
uv = xy.unsqueeze(0).unsqueeze(2) # [B, N, 2]
|
79 |
+
uv = uv * torch.tensor([1.0, -1.0]).type_as(uv)
|
80 |
+
colors = (torch.nn.functional.grid_sample(image, uv, align_corners=True)[
|
81 |
+
0, :, :, 0].permute(1, 0) + 1.0) * 0.5 * 255.0
|
82 |
+
colors[visibility == 0.0] = ((Meshes(verts.unsqueeze(0), faces.unsqueeze(
|
83 |
+
0)).verts_normals_padded().squeeze(0) + 1.0) * 0.5 * 255.0)[visibility == 0.0]
|
84 |
+
|
85 |
+
return colors.detach().cpu()
|
86 |
+
|
87 |
+
|
88 |
+
class cleanShader(torch.nn.Module):
|
89 |
+
def __init__(self, device="cpu", cameras=None, blend_params=None):
|
90 |
+
super().__init__()
|
91 |
+
self.cameras = cameras
|
92 |
+
self.blend_params = blend_params if blend_params is not None else BlendParams()
|
93 |
+
|
94 |
+
def forward(self, fragments, meshes, **kwargs):
|
95 |
+
cameras = kwargs.get("cameras", self.cameras)
|
96 |
+
if cameras is None:
|
97 |
+
msg = "Cameras must be specified either at initialization \
|
98 |
+
or in the forward pass of TexturedSoftPhongShader"
|
99 |
+
|
100 |
+
raise ValueError(msg)
|
101 |
+
|
102 |
+
# get renderer output
|
103 |
+
blend_params = kwargs.get("blend_params", self.blend_params)
|
104 |
+
texels = meshes.sample_textures(fragments)
|
105 |
+
images = blending.softmax_rgb_blend(
|
106 |
+
texels, fragments, blend_params, znear=-256, zfar=256
|
107 |
+
)
|
108 |
+
|
109 |
+
return images
|
110 |
+
|
111 |
+
|
112 |
+
class Render:
|
113 |
+
def __init__(self, size=512, device=torch.device("cuda:0")):
|
114 |
+
self.device = device
|
115 |
+
self.mesh_y_center = 100.0
|
116 |
+
self.dis = 100.0
|
117 |
+
self.scale = 1.0
|
118 |
+
self.size = size
|
119 |
+
self.cam_pos = [(0, 100, 100)]
|
120 |
+
|
121 |
+
self.mesh = None
|
122 |
+
self.deform_mesh = None
|
123 |
+
self.pcd = None
|
124 |
+
self.renderer = None
|
125 |
+
self.meshRas = None
|
126 |
+
self.type = None
|
127 |
+
self.knn = None
|
128 |
+
self.knn_inverse = None
|
129 |
+
|
130 |
+
self.smpl_seg = None
|
131 |
+
self.smpl_cmap = None
|
132 |
+
|
133 |
+
self.smplx = SMPLX()
|
134 |
+
|
135 |
+
self.uv_rasterizer = util.Pytorch3dRasterizer(self.size)
|
136 |
+
|
137 |
+
def get_camera(self, cam_id):
|
138 |
+
|
139 |
+
R, T = look_at_view_transform(
|
140 |
+
eye=[self.cam_pos[cam_id]],
|
141 |
+
at=((0, self.mesh_y_center, 0),),
|
142 |
+
up=((0, 1, 0),),
|
143 |
+
)
|
144 |
+
|
145 |
+
camera = FoVOrthographicCameras(
|
146 |
+
device=self.device,
|
147 |
+
R=R,
|
148 |
+
T=T,
|
149 |
+
znear=100.0,
|
150 |
+
zfar=-100.0,
|
151 |
+
max_y=100.0,
|
152 |
+
min_y=-100.0,
|
153 |
+
max_x=100.0,
|
154 |
+
min_x=-100.0,
|
155 |
+
scale_xyz=(self.scale * np.ones(3),),
|
156 |
+
)
|
157 |
+
|
158 |
+
return camera
|
159 |
+
|
160 |
+
def init_renderer(self, camera, type="clean_mesh", bg="gray"):
|
161 |
+
|
162 |
+
if "mesh" in type:
|
163 |
+
|
164 |
+
# rasterizer
|
165 |
+
self.raster_settings_mesh = RasterizationSettings(
|
166 |
+
image_size=self.size,
|
167 |
+
blur_radius=np.log(1.0 / 1e-4) * 1e-7,
|
168 |
+
faces_per_pixel=30,
|
169 |
+
)
|
170 |
+
self.meshRas = MeshRasterizer(
|
171 |
+
cameras=camera, raster_settings=self.raster_settings_mesh
|
172 |
+
)
|
173 |
+
|
174 |
+
if bg == "black":
|
175 |
+
blendparam = BlendParams(1e-4, 1e-4, (0.0, 0.0, 0.0))
|
176 |
+
elif bg == "white":
|
177 |
+
blendparam = BlendParams(1e-4, 1e-8, (1.0, 1.0, 1.0))
|
178 |
+
elif bg == "gray":
|
179 |
+
blendparam = BlendParams(1e-4, 1e-8, (0.5, 0.5, 0.5))
|
180 |
+
|
181 |
+
if type == "ori_mesh":
|
182 |
+
|
183 |
+
lights = PointLights(
|
184 |
+
device=self.device,
|
185 |
+
ambient_color=((0.8, 0.8, 0.8),),
|
186 |
+
diffuse_color=((0.2, 0.2, 0.2),),
|
187 |
+
specular_color=((0.0, 0.0, 0.0),),
|
188 |
+
location=[[0.0, 200.0, 0.0]],
|
189 |
+
)
|
190 |
+
|
191 |
+
self.renderer = MeshRenderer(
|
192 |
+
rasterizer=self.meshRas,
|
193 |
+
shader=SoftPhongShader(
|
194 |
+
device=self.device,
|
195 |
+
cameras=camera,
|
196 |
+
lights=lights,
|
197 |
+
blend_params=blendparam,
|
198 |
+
),
|
199 |
+
)
|
200 |
+
|
201 |
+
if type == "silhouette":
|
202 |
+
self.raster_settings_silhouette = RasterizationSettings(
|
203 |
+
image_size=self.size,
|
204 |
+
blur_radius=np.log(1.0 / 1e-4 - 1.0) * 5e-5,
|
205 |
+
faces_per_pixel=50,
|
206 |
+
cull_backfaces=True,
|
207 |
+
)
|
208 |
+
|
209 |
+
self.silhouetteRas = MeshRasterizer(
|
210 |
+
cameras=camera, raster_settings=self.raster_settings_silhouette
|
211 |
+
)
|
212 |
+
self.renderer = MeshRenderer(
|
213 |
+
rasterizer=self.silhouetteRas, shader=SoftSilhouetteShader()
|
214 |
+
)
|
215 |
+
|
216 |
+
if type == "pointcloud":
|
217 |
+
self.raster_settings_pcd = PointsRasterizationSettings(
|
218 |
+
image_size=self.size, radius=0.006, points_per_pixel=10
|
219 |
+
)
|
220 |
+
|
221 |
+
self.pcdRas = PointsRasterizer(
|
222 |
+
cameras=camera, raster_settings=self.raster_settings_pcd
|
223 |
+
)
|
224 |
+
self.renderer = PointsRenderer(
|
225 |
+
rasterizer=self.pcdRas,
|
226 |
+
compositor=AlphaCompositor(background_color=(0, 0, 0)),
|
227 |
+
)
|
228 |
+
|
229 |
+
if type == "clean_mesh":
|
230 |
+
|
231 |
+
self.renderer = MeshRenderer(
|
232 |
+
rasterizer=self.meshRas,
|
233 |
+
shader=cleanShader(
|
234 |
+
device=self.device, cameras=camera, blend_params=blendparam
|
235 |
+
),
|
236 |
+
)
|
237 |
+
|
238 |
+
def VF2Mesh(self, verts, faces):
|
239 |
+
|
240 |
+
if not torch.is_tensor(verts):
|
241 |
+
verts = torch.tensor(verts)
|
242 |
+
if not torch.is_tensor(faces):
|
243 |
+
faces = torch.tensor(faces)
|
244 |
+
|
245 |
+
if verts.ndimension() == 2:
|
246 |
+
verts = verts.unsqueeze(0).float()
|
247 |
+
if faces.ndimension() == 2:
|
248 |
+
faces = faces.unsqueeze(0).long()
|
249 |
+
|
250 |
+
verts = verts.to(self.device)
|
251 |
+
faces = faces.to(self.device)
|
252 |
+
|
253 |
+
mesh = Meshes(verts, faces).to(self.device)
|
254 |
+
|
255 |
+
mesh.textures = TexturesVertex(
|
256 |
+
verts_features=(mesh.verts_normals_padded() + 1.0) * 0.5
|
257 |
+
)
|
258 |
+
|
259 |
+
return mesh
|
260 |
+
|
261 |
+
def load_meshes(self, verts, faces):
|
262 |
+
"""load mesh into the pytorch3d renderer
|
263 |
+
|
264 |
+
Args:
|
265 |
+
verts ([N,3]): verts
|
266 |
+
faces ([N,3]): faces
|
267 |
+
offset ([N,3]): offset
|
268 |
+
"""
|
269 |
+
|
270 |
+
# camera setting
|
271 |
+
self.scale = 100.0
|
272 |
+
self.mesh_y_center = 0.0
|
273 |
+
|
274 |
+
self.cam_pos = [
|
275 |
+
(0, self.mesh_y_center, 100.0),
|
276 |
+
(100.0, self.mesh_y_center, 0),
|
277 |
+
(0, self.mesh_y_center, -100.0),
|
278 |
+
(-100.0, self.mesh_y_center, 0),
|
279 |
+
]
|
280 |
+
|
281 |
+
self.type = "color"
|
282 |
+
|
283 |
+
if isinstance(verts, list):
|
284 |
+
self.meshes = []
|
285 |
+
for V, F in zip(verts, faces):
|
286 |
+
self.meshes.append(self.VF2Mesh(V, F))
|
287 |
+
else:
|
288 |
+
self.meshes = [self.VF2Mesh(verts, faces)]
|
289 |
+
|
290 |
+
def get_depth_map(self, cam_ids=[0, 2]):
|
291 |
+
|
292 |
+
depth_maps = []
|
293 |
+
for cam_id in cam_ids:
|
294 |
+
self.init_renderer(self.get_camera(cam_id), "clean_mesh", "gray")
|
295 |
+
fragments = self.meshRas(self.meshes[0])
|
296 |
+
depth_map = fragments.zbuf[..., 0].squeeze(0)
|
297 |
+
if cam_id == 2:
|
298 |
+
depth_map = torch.fliplr(depth_map)
|
299 |
+
depth_maps.append(depth_map)
|
300 |
+
|
301 |
+
return depth_maps
|
302 |
+
|
303 |
+
def get_rgb_image(self, cam_ids=[0, 2]):
|
304 |
+
|
305 |
+
images = []
|
306 |
+
for cam_id in range(len(self.cam_pos)):
|
307 |
+
if cam_id in cam_ids:
|
308 |
+
self.init_renderer(self.get_camera(
|
309 |
+
cam_id), "clean_mesh", "gray")
|
310 |
+
if len(cam_ids) == 4:
|
311 |
+
rendered_img = (
|
312 |
+
self.renderer(self.meshes[0])[
|
313 |
+
0:1, :, :, :3].permute(0, 3, 1, 2)
|
314 |
+
- 0.5
|
315 |
+
) * 2.0
|
316 |
+
else:
|
317 |
+
rendered_img = (
|
318 |
+
self.renderer(self.meshes[0])[
|
319 |
+
0:1, :, :, :3].permute(0, 3, 1, 2)
|
320 |
+
- 0.5
|
321 |
+
) * 2.0
|
322 |
+
if cam_id == 2 and len(cam_ids) == 2:
|
323 |
+
rendered_img = torch.flip(rendered_img, dims=[3])
|
324 |
+
images.append(rendered_img)
|
325 |
+
|
326 |
+
return images
|
327 |
+
|
328 |
+
def get_rendered_video(self, images, save_path):
|
329 |
+
|
330 |
+
tmp_path = save_path.replace('cloth', 'tmp')
|
331 |
+
|
332 |
+
self.cam_pos = []
|
333 |
+
for angle in range(0, 360, 3):
|
334 |
+
self.cam_pos.append(
|
335 |
+
(
|
336 |
+
100.0 * math.cos(np.pi / 180 * angle),
|
337 |
+
self.mesh_y_center,
|
338 |
+
100.0 * math.sin(np.pi / 180 * angle),
|
339 |
+
)
|
340 |
+
)
|
341 |
+
|
342 |
+
old_shape = np.array(images[0].shape[:2])
|
343 |
+
new_shape = np.around(
|
344 |
+
(self.size / old_shape[0]) * old_shape).astype(np.int)
|
345 |
+
|
346 |
+
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
|
347 |
+
video = cv2.VideoWriter(
|
348 |
+
tmp_path, fourcc, 30, (self.size * len(self.meshes) +
|
349 |
+
new_shape[1] * len(images), self.size)
|
350 |
+
)
|
351 |
+
|
352 |
+
pbar = tqdm(range(len(self.cam_pos)))
|
353 |
+
pbar.set_description(colored(f"exporting video {os.path.basename(save_path)}...", "blue"))
|
354 |
+
for cam_id in pbar:
|
355 |
+
self.init_renderer(self.get_camera(cam_id), "clean_mesh", "gray")
|
356 |
+
|
357 |
+
img_lst = [
|
358 |
+
np.array(Image.fromarray(img).resize(new_shape[::-1])).astype(np.uint8)[
|
359 |
+
:, :, [2, 1, 0]
|
360 |
+
]
|
361 |
+
for img in images
|
362 |
+
]
|
363 |
+
|
364 |
+
for mesh in self.meshes:
|
365 |
+
rendered_img = (
|
366 |
+
(self.renderer(mesh)[0, :, :, :3] * 255.0)
|
367 |
+
.detach()
|
368 |
+
.cpu()
|
369 |
+
.numpy()
|
370 |
+
.astype(np.uint8)
|
371 |
+
)
|
372 |
+
|
373 |
+
img_lst.append(rendered_img)
|
374 |
+
final_img = np.concatenate(img_lst, axis=1)
|
375 |
+
video.write(final_img)
|
376 |
+
|
377 |
+
video.release()
|
378 |
+
|
379 |
+
os.system(f'ffmpeg -y -loglevel quiet -stats -i {tmp_path} -c:v libx264 {save_path}')
|
380 |
+
|
381 |
+
def get_silhouette_image(self, cam_ids=[0, 2]):
|
382 |
+
|
383 |
+
images = []
|
384 |
+
for cam_id in range(len(self.cam_pos)):
|
385 |
+
if cam_id in cam_ids:
|
386 |
+
self.init_renderer(self.get_camera(cam_id), "silhouette")
|
387 |
+
rendered_img = self.renderer(self.meshes[0])[0:1, :, :, 3]
|
388 |
+
if cam_id == 2 and len(cam_ids) == 2:
|
389 |
+
rendered_img = torch.flip(rendered_img, dims=[2])
|
390 |
+
images.append(rendered_img)
|
391 |
+
|
392 |
+
return images
|
lib/common/render_utils.py
ADDED
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
|
5 |
+
# holder of all proprietary rights on this computer program.
|
6 |
+
# You can only use this computer program if you have closed
|
7 |
+
# a license agreement with MPG or you get the right to use the computer
|
8 |
+
# program from someone who is authorized to grant you that right.
|
9 |
+
# Any use of the computer program without a valid license is prohibited and
|
10 |
+
# liable to prosecution.
|
11 |
+
#
|
12 |
+
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
|
13 |
+
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
|
14 |
+
# for Intelligent Systems. All rights reserved.
|
15 |
+
#
|
16 |
+
# Contact: ps-license@tuebingen.mpg.de
|
17 |
+
|
18 |
+
import torch
|
19 |
+
from torch import nn
|
20 |
+
import trimesh
|
21 |
+
import math
|
22 |
+
from typing import NewType
|
23 |
+
from pytorch3d.structures import Meshes
|
24 |
+
from pytorch3d.renderer.mesh import rasterize_meshes
|
25 |
+
|
26 |
+
Tensor = NewType('Tensor', torch.Tensor)
|
27 |
+
|
28 |
+
|
29 |
+
def solid_angles(points: Tensor,
|
30 |
+
triangles: Tensor,
|
31 |
+
thresh: float = 1e-8) -> Tensor:
|
32 |
+
''' Compute solid angle between the input points and triangles
|
33 |
+
Follows the method described in:
|
34 |
+
The Solid Angle of a Plane Triangle
|
35 |
+
A. VAN OOSTEROM AND J. STRACKEE
|
36 |
+
IEEE TRANSACTIONS ON BIOMEDICAL ENGINEERING,
|
37 |
+
VOL. BME-30, NO. 2, FEBRUARY 1983
|
38 |
+
Parameters
|
39 |
+
-----------
|
40 |
+
points: BxQx3
|
41 |
+
Tensor of input query points
|
42 |
+
triangles: BxFx3x3
|
43 |
+
Target triangles
|
44 |
+
thresh: float
|
45 |
+
float threshold
|
46 |
+
Returns
|
47 |
+
-------
|
48 |
+
solid_angles: BxQxF
|
49 |
+
A tensor containing the solid angle between all query points
|
50 |
+
and input triangles
|
51 |
+
'''
|
52 |
+
# Center the triangles on the query points. Size should be BxQxFx3x3
|
53 |
+
centered_tris = triangles[:, None] - points[:, :, None, None]
|
54 |
+
|
55 |
+
# BxQxFx3
|
56 |
+
norms = torch.norm(centered_tris, dim=-1)
|
57 |
+
|
58 |
+
# Should be BxQxFx3
|
59 |
+
cross_prod = torch.cross(centered_tris[:, :, :, 1],
|
60 |
+
centered_tris[:, :, :, 2],
|
61 |
+
dim=-1)
|
62 |
+
# Should be BxQxF
|
63 |
+
numerator = (centered_tris[:, :, :, 0] * cross_prod).sum(dim=-1)
|
64 |
+
del cross_prod
|
65 |
+
|
66 |
+
dot01 = (centered_tris[:, :, :, 0] * centered_tris[:, :, :, 1]).sum(dim=-1)
|
67 |
+
dot12 = (centered_tris[:, :, :, 1] * centered_tris[:, :, :, 2]).sum(dim=-1)
|
68 |
+
dot02 = (centered_tris[:, :, :, 0] * centered_tris[:, :, :, 2]).sum(dim=-1)
|
69 |
+
del centered_tris
|
70 |
+
|
71 |
+
denominator = (norms.prod(dim=-1) + dot01 * norms[:, :, :, 2] +
|
72 |
+
dot02 * norms[:, :, :, 1] + dot12 * norms[:, :, :, 0])
|
73 |
+
del dot01, dot12, dot02, norms
|
74 |
+
|
75 |
+
# Should be BxQ
|
76 |
+
solid_angle = torch.atan2(numerator, denominator)
|
77 |
+
del numerator, denominator
|
78 |
+
|
79 |
+
torch.cuda.empty_cache()
|
80 |
+
|
81 |
+
return 2 * solid_angle
|
82 |
+
|
83 |
+
|
84 |
+
def winding_numbers(points: Tensor,
|
85 |
+
triangles: Tensor,
|
86 |
+
thresh: float = 1e-8) -> Tensor:
|
87 |
+
''' Uses winding_numbers to compute inside/outside
|
88 |
+
Robust inside-outside segmentation using generalized winding numbers
|
89 |
+
Alec Jacobson,
|
90 |
+
Ladislav Kavan,
|
91 |
+
Olga Sorkine-Hornung
|
92 |
+
Fast Winding Numbers for Soups and Clouds SIGGRAPH 2018
|
93 |
+
Gavin Barill
|
94 |
+
NEIL G. Dickson
|
95 |
+
Ryan Schmidt
|
96 |
+
David I.W. Levin
|
97 |
+
and Alec Jacobson
|
98 |
+
Parameters
|
99 |
+
-----------
|
100 |
+
points: BxQx3
|
101 |
+
Tensor of input query points
|
102 |
+
triangles: BxFx3x3
|
103 |
+
Target triangles
|
104 |
+
thresh: float
|
105 |
+
float threshold
|
106 |
+
Returns
|
107 |
+
-------
|
108 |
+
winding_numbers: BxQ
|
109 |
+
A tensor containing the Generalized winding numbers
|
110 |
+
'''
|
111 |
+
# The generalized winding number is the sum of solid angles of the point
|
112 |
+
# with respect to all triangles.
|
113 |
+
return 1 / (4 * math.pi) * solid_angles(points, triangles,
|
114 |
+
thresh=thresh).sum(dim=-1)
|
115 |
+
|
116 |
+
|
117 |
+
def batch_contains(verts, faces, points):
|
118 |
+
|
119 |
+
B = verts.shape[0]
|
120 |
+
N = points.shape[1]
|
121 |
+
|
122 |
+
verts = verts.detach().cpu()
|
123 |
+
faces = faces.detach().cpu()
|
124 |
+
points = points.detach().cpu()
|
125 |
+
contains = torch.zeros(B, N)
|
126 |
+
|
127 |
+
for i in range(B):
|
128 |
+
contains[i] = torch.as_tensor(
|
129 |
+
trimesh.Trimesh(verts[i], faces[i]).contains(points[i]))
|
130 |
+
|
131 |
+
return 2.0 * (contains - 0.5)
|
132 |
+
|
133 |
+
|
134 |
+
def dict2obj(d):
|
135 |
+
# if isinstance(d, list):
|
136 |
+
# d = [dict2obj(x) for x in d]
|
137 |
+
if not isinstance(d, dict):
|
138 |
+
return d
|
139 |
+
|
140 |
+
class C(object):
|
141 |
+
pass
|
142 |
+
|
143 |
+
o = C()
|
144 |
+
for k in d:
|
145 |
+
o.__dict__[k] = dict2obj(d[k])
|
146 |
+
return o
|
147 |
+
|
148 |
+
|
149 |
+
def face_vertices(vertices, faces):
|
150 |
+
"""
|
151 |
+
:param vertices: [batch size, number of vertices, 3]
|
152 |
+
:param faces: [batch size, number of faces, 3]
|
153 |
+
:return: [batch size, number of faces, 3, 3]
|
154 |
+
"""
|
155 |
+
|
156 |
+
bs, nv = vertices.shape[:2]
|
157 |
+
bs, nf = faces.shape[:2]
|
158 |
+
device = vertices.device
|
159 |
+
faces = faces + (torch.arange(bs, dtype=torch.int32).to(device) *
|
160 |
+
nv)[:, None, None]
|
161 |
+
vertices = vertices.reshape((bs * nv, vertices.shape[-1]))
|
162 |
+
|
163 |
+
return vertices[faces.long()]
|
164 |
+
|
165 |
+
|
166 |
+
class Pytorch3dRasterizer(nn.Module):
|
167 |
+
""" Borrowed from https://github.com/facebookresearch/pytorch3d
|
168 |
+
Notice:
|
169 |
+
x,y,z are in image space, normalized
|
170 |
+
can only render squared image now
|
171 |
+
"""
|
172 |
+
|
173 |
+
def __init__(self, image_size=224):
|
174 |
+
"""
|
175 |
+
use fixed raster_settings for rendering faces
|
176 |
+
"""
|
177 |
+
super().__init__()
|
178 |
+
raster_settings = {
|
179 |
+
'image_size': image_size,
|
180 |
+
'blur_radius': 0.0,
|
181 |
+
'faces_per_pixel': 1,
|
182 |
+
'bin_size': None,
|
183 |
+
'max_faces_per_bin': None,
|
184 |
+
'perspective_correct': True,
|
185 |
+
'cull_backfaces': True,
|
186 |
+
}
|
187 |
+
raster_settings = dict2obj(raster_settings)
|
188 |
+
self.raster_settings = raster_settings
|
189 |
+
|
190 |
+
def forward(self, vertices, faces, attributes=None):
|
191 |
+
fixed_vertices = vertices.clone()
|
192 |
+
fixed_vertices[..., :2] = -fixed_vertices[..., :2]
|
193 |
+
meshes_screen = Meshes(verts=fixed_vertices.float(),
|
194 |
+
faces=faces.long())
|
195 |
+
raster_settings = self.raster_settings
|
196 |
+
pix_to_face, zbuf, bary_coords, dists = rasterize_meshes(
|
197 |
+
meshes_screen,
|
198 |
+
image_size=raster_settings.image_size,
|
199 |
+
blur_radius=raster_settings.blur_radius,
|
200 |
+
faces_per_pixel=raster_settings.faces_per_pixel,
|
201 |
+
bin_size=raster_settings.bin_size,
|
202 |
+
max_faces_per_bin=raster_settings.max_faces_per_bin,
|
203 |
+
perspective_correct=raster_settings.perspective_correct,
|
204 |
+
)
|
205 |
+
vismask = (pix_to_face > -1).float()
|
206 |
+
D = attributes.shape[-1]
|
207 |
+
attributes = attributes.clone()
|
208 |
+
attributes = attributes.view(attributes.shape[0] * attributes.shape[1],
|
209 |
+
3, attributes.shape[-1])
|
210 |
+
N, H, W, K, _ = bary_coords.shape
|
211 |
+
mask = pix_to_face == -1
|
212 |
+
pix_to_face = pix_to_face.clone()
|
213 |
+
pix_to_face[mask] = 0
|
214 |
+
idx = pix_to_face.view(N * H * W * K, 1, 1).expand(N * H * W * K, 3, D)
|
215 |
+
pixel_face_vals = attributes.gather(0, idx).view(N, H, W, K, 3, D)
|
216 |
+
pixel_vals = (bary_coords[..., None] * pixel_face_vals).sum(dim=-2)
|
217 |
+
pixel_vals[mask] = 0 # Replace masked values in output.
|
218 |
+
pixel_vals = pixel_vals[:, :, :, 0].permute(0, 3, 1, 2)
|
219 |
+
pixel_vals = torch.cat(
|
220 |
+
[pixel_vals, vismask[:, :, :, 0][:, None, :, :]], dim=1)
|
221 |
+
return pixel_vals
|
lib/common/seg3d_lossless.py
ADDED
@@ -0,0 +1,604 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
|
5 |
+
# holder of all proprietary rights on this computer program.
|
6 |
+
# You can only use this computer program if you have closed
|
7 |
+
# a license agreement with MPG or you get the right to use the computer
|
8 |
+
# program from someone who is authorized to grant you that right.
|
9 |
+
# Any use of the computer program without a valid license is prohibited and
|
10 |
+
# liable to prosecution.
|
11 |
+
#
|
12 |
+
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
|
13 |
+
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
|
14 |
+
# for Intelligent Systems. All rights reserved.
|
15 |
+
#
|
16 |
+
# Contact: ps-license@tuebingen.mpg.de
|
17 |
+
|
18 |
+
|
19 |
+
from .seg3d_utils import (
|
20 |
+
create_grid3D,
|
21 |
+
plot_mask3D,
|
22 |
+
SmoothConv3D,
|
23 |
+
)
|
24 |
+
|
25 |
+
import torch
|
26 |
+
import torch.nn as nn
|
27 |
+
import numpy as np
|
28 |
+
import torch.nn.functional as F
|
29 |
+
import mcubes
|
30 |
+
from kaolin.ops.conversions import voxelgrids_to_trianglemeshes
|
31 |
+
import logging
|
32 |
+
|
33 |
+
logging.getLogger("lightning").setLevel(logging.ERROR)
|
34 |
+
|
35 |
+
|
36 |
+
class Seg3dLossless(nn.Module):
|
37 |
+
def __init__(self,
|
38 |
+
query_func,
|
39 |
+
b_min,
|
40 |
+
b_max,
|
41 |
+
resolutions,
|
42 |
+
channels=1,
|
43 |
+
balance_value=0.5,
|
44 |
+
align_corners=False,
|
45 |
+
visualize=False,
|
46 |
+
debug=False,
|
47 |
+
use_cuda_impl=False,
|
48 |
+
faster=False,
|
49 |
+
use_shadow=False,
|
50 |
+
**kwargs):
|
51 |
+
"""
|
52 |
+
align_corners: same with how you process gt. (grid_sample / interpolate)
|
53 |
+
"""
|
54 |
+
super().__init__()
|
55 |
+
self.query_func = query_func
|
56 |
+
self.register_buffer(
|
57 |
+
'b_min',
|
58 |
+
torch.tensor(b_min).float().unsqueeze(1)) # [bz, 1, 3]
|
59 |
+
self.register_buffer(
|
60 |
+
'b_max',
|
61 |
+
torch.tensor(b_max).float().unsqueeze(1)) # [bz, 1, 3]
|
62 |
+
|
63 |
+
# ti.init(arch=ti.cuda)
|
64 |
+
# self.mciso_taichi = MCISO(dim=3, N=resolutions[-1]-1)
|
65 |
+
|
66 |
+
if type(resolutions[0]) is int:
|
67 |
+
resolutions = torch.tensor([(res, res, res)
|
68 |
+
for res in resolutions])
|
69 |
+
else:
|
70 |
+
resolutions = torch.tensor(resolutions)
|
71 |
+
self.register_buffer('resolutions', resolutions)
|
72 |
+
self.batchsize = self.b_min.size(0)
|
73 |
+
assert self.batchsize == 1
|
74 |
+
self.balance_value = balance_value
|
75 |
+
self.channels = channels
|
76 |
+
assert self.channels == 1
|
77 |
+
self.align_corners = align_corners
|
78 |
+
self.visualize = visualize
|
79 |
+
self.debug = debug
|
80 |
+
self.use_cuda_impl = use_cuda_impl
|
81 |
+
self.faster = faster
|
82 |
+
self.use_shadow = use_shadow
|
83 |
+
|
84 |
+
for resolution in resolutions:
|
85 |
+
assert resolution[0] % 2 == 1 and resolution[1] % 2 == 1, \
|
86 |
+
f"resolution {resolution} need to be odd becuase of align_corner."
|
87 |
+
|
88 |
+
# init first resolution
|
89 |
+
init_coords = create_grid3D(0,
|
90 |
+
resolutions[-1] - 1,
|
91 |
+
steps=resolutions[0]) # [N, 3]
|
92 |
+
init_coords = init_coords.unsqueeze(0).repeat(self.batchsize, 1,
|
93 |
+
1) # [bz, N, 3]
|
94 |
+
self.register_buffer('init_coords', init_coords)
|
95 |
+
|
96 |
+
# some useful tensors
|
97 |
+
calculated = torch.zeros(
|
98 |
+
(self.resolutions[-1][2], self.resolutions[-1][1],
|
99 |
+
self.resolutions[-1][0]),
|
100 |
+
dtype=torch.bool)
|
101 |
+
self.register_buffer('calculated', calculated)
|
102 |
+
|
103 |
+
gird8_offsets = torch.stack(
|
104 |
+
torch.meshgrid([
|
105 |
+
torch.tensor([-1, 0, 1]),
|
106 |
+
torch.tensor([-1, 0, 1]),
|
107 |
+
torch.tensor([-1, 0, 1])
|
108 |
+
])).int().view(3, -1).t() # [27, 3]
|
109 |
+
self.register_buffer('gird8_offsets', gird8_offsets)
|
110 |
+
|
111 |
+
# smooth convs
|
112 |
+
self.smooth_conv3x3 = SmoothConv3D(in_channels=1,
|
113 |
+
out_channels=1,
|
114 |
+
kernel_size=3)
|
115 |
+
self.smooth_conv5x5 = SmoothConv3D(in_channels=1,
|
116 |
+
out_channels=1,
|
117 |
+
kernel_size=5)
|
118 |
+
self.smooth_conv7x7 = SmoothConv3D(in_channels=1,
|
119 |
+
out_channels=1,
|
120 |
+
kernel_size=7)
|
121 |
+
self.smooth_conv9x9 = SmoothConv3D(in_channels=1,
|
122 |
+
out_channels=1,
|
123 |
+
kernel_size=9)
|
124 |
+
|
125 |
+
def batch_eval(self, coords, **kwargs):
|
126 |
+
"""
|
127 |
+
coords: in the coordinates of last resolution
|
128 |
+
**kwargs: for query_func
|
129 |
+
"""
|
130 |
+
coords = coords.detach()
|
131 |
+
# normalize coords to fit in [b_min, b_max]
|
132 |
+
if self.align_corners:
|
133 |
+
coords2D = coords.float() / (self.resolutions[-1] - 1)
|
134 |
+
else:
|
135 |
+
step = 1.0 / self.resolutions[-1].float()
|
136 |
+
coords2D = coords.float() / self.resolutions[-1] + step / 2
|
137 |
+
coords2D = coords2D * (self.b_max - self.b_min) + self.b_min
|
138 |
+
# query function
|
139 |
+
occupancys = self.query_func(**kwargs, points=coords2D)
|
140 |
+
if type(occupancys) is list:
|
141 |
+
occupancys = torch.stack(occupancys) # [bz, C, N]
|
142 |
+
assert len(occupancys.size()) == 3, \
|
143 |
+
"query_func should return a occupancy with shape of [bz, C, N]"
|
144 |
+
return occupancys
|
145 |
+
|
146 |
+
def forward(self, **kwargs):
|
147 |
+
if self.faster:
|
148 |
+
return self._forward_faster(**kwargs)
|
149 |
+
else:
|
150 |
+
return self._forward(**kwargs)
|
151 |
+
|
152 |
+
def _forward_faster(self, **kwargs):
|
153 |
+
"""
|
154 |
+
In faster mode, we make following changes to exchange accuracy for speed:
|
155 |
+
1. no conflict checking: 4.88 fps -> 6.56 fps
|
156 |
+
2. smooth_conv9x9 ~ smooth_conv3x3 for different resolution
|
157 |
+
3. last step no examine
|
158 |
+
"""
|
159 |
+
final_W = self.resolutions[-1][0]
|
160 |
+
final_H = self.resolutions[-1][1]
|
161 |
+
final_D = self.resolutions[-1][2]
|
162 |
+
|
163 |
+
for resolution in self.resolutions:
|
164 |
+
W, H, D = resolution
|
165 |
+
stride = (self.resolutions[-1] - 1) / (resolution - 1)
|
166 |
+
|
167 |
+
# first step
|
168 |
+
if torch.equal(resolution, self.resolutions[0]):
|
169 |
+
coords = self.init_coords.clone() # torch.long
|
170 |
+
occupancys = self.batch_eval(coords, **kwargs)
|
171 |
+
occupancys = occupancys.view(self.batchsize, self.channels, D,
|
172 |
+
H, W)
|
173 |
+
if (occupancys > 0.5).sum() == 0:
|
174 |
+
# return F.interpolate(
|
175 |
+
# occupancys, size=(final_D, final_H, final_W),
|
176 |
+
# mode="linear", align_corners=True)
|
177 |
+
return None
|
178 |
+
|
179 |
+
if self.visualize:
|
180 |
+
self.plot(occupancys, coords, final_D, final_H, final_W)
|
181 |
+
|
182 |
+
with torch.no_grad():
|
183 |
+
coords_accum = coords / stride
|
184 |
+
|
185 |
+
# last step
|
186 |
+
elif torch.equal(resolution, self.resolutions[-1]):
|
187 |
+
|
188 |
+
with torch.no_grad():
|
189 |
+
# here true is correct!
|
190 |
+
valid = F.interpolate(
|
191 |
+
(occupancys > self.balance_value).float(),
|
192 |
+
size=(D, H, W),
|
193 |
+
mode="trilinear",
|
194 |
+
align_corners=True)
|
195 |
+
|
196 |
+
# here true is correct!
|
197 |
+
occupancys = F.interpolate(occupancys.float(),
|
198 |
+
size=(D, H, W),
|
199 |
+
mode="trilinear",
|
200 |
+
align_corners=True)
|
201 |
+
|
202 |
+
# is_boundary = (valid > 0.0) & (valid < 1.0)
|
203 |
+
is_boundary = valid == 0.5
|
204 |
+
|
205 |
+
# next steps
|
206 |
+
else:
|
207 |
+
coords_accum *= 2
|
208 |
+
|
209 |
+
with torch.no_grad():
|
210 |
+
# here true is correct!
|
211 |
+
valid = F.interpolate(
|
212 |
+
(occupancys > self.balance_value).float(),
|
213 |
+
size=(D, H, W),
|
214 |
+
mode="trilinear",
|
215 |
+
align_corners=True)
|
216 |
+
|
217 |
+
# here true is correct!
|
218 |
+
occupancys = F.interpolate(occupancys.float(),
|
219 |
+
size=(D, H, W),
|
220 |
+
mode="trilinear",
|
221 |
+
align_corners=True)
|
222 |
+
|
223 |
+
is_boundary = (valid > 0.0) & (valid < 1.0)
|
224 |
+
|
225 |
+
with torch.no_grad():
|
226 |
+
if torch.equal(resolution, self.resolutions[1]):
|
227 |
+
is_boundary = (self.smooth_conv9x9(is_boundary.float())
|
228 |
+
> 0)[0, 0]
|
229 |
+
elif torch.equal(resolution, self.resolutions[2]):
|
230 |
+
is_boundary = (self.smooth_conv7x7(is_boundary.float())
|
231 |
+
> 0)[0, 0]
|
232 |
+
else:
|
233 |
+
is_boundary = (self.smooth_conv3x3(is_boundary.float())
|
234 |
+
> 0)[0, 0]
|
235 |
+
|
236 |
+
coords_accum = coords_accum.long()
|
237 |
+
is_boundary[coords_accum[0, :, 2], coords_accum[0, :, 1],
|
238 |
+
coords_accum[0, :, 0]] = False
|
239 |
+
point_coords = is_boundary.permute(
|
240 |
+
2, 1, 0).nonzero(as_tuple=False).unsqueeze(0)
|
241 |
+
point_indices = (point_coords[:, :, 2] * H * W +
|
242 |
+
point_coords[:, :, 1] * W +
|
243 |
+
point_coords[:, :, 0])
|
244 |
+
|
245 |
+
R, C, D, H, W = occupancys.shape
|
246 |
+
|
247 |
+
# inferred value
|
248 |
+
coords = point_coords * stride
|
249 |
+
|
250 |
+
if coords.size(1) == 0:
|
251 |
+
continue
|
252 |
+
occupancys_topk = self.batch_eval(coords, **kwargs)
|
253 |
+
|
254 |
+
# put mask point predictions to the right places on the upsampled grid.
|
255 |
+
R, C, D, H, W = occupancys.shape
|
256 |
+
point_indices = point_indices.unsqueeze(1).expand(-1, C, -1)
|
257 |
+
occupancys = (occupancys.reshape(R, C, D * H * W).scatter_(
|
258 |
+
2, point_indices, occupancys_topk).view(R, C, D, H, W))
|
259 |
+
|
260 |
+
with torch.no_grad():
|
261 |
+
voxels = coords / stride
|
262 |
+
coords_accum = torch.cat([voxels, coords_accum],
|
263 |
+
dim=1).unique(dim=1)
|
264 |
+
|
265 |
+
return occupancys[0, 0]
|
266 |
+
|
267 |
+
def _forward(self, **kwargs):
|
268 |
+
"""
|
269 |
+
output occupancy field would be:
|
270 |
+
(bz, C, res, res)
|
271 |
+
"""
|
272 |
+
final_W = self.resolutions[-1][0]
|
273 |
+
final_H = self.resolutions[-1][1]
|
274 |
+
final_D = self.resolutions[-1][2]
|
275 |
+
|
276 |
+
calculated = self.calculated.clone()
|
277 |
+
|
278 |
+
for resolution in self.resolutions:
|
279 |
+
W, H, D = resolution
|
280 |
+
stride = (self.resolutions[-1] - 1) / (resolution - 1)
|
281 |
+
|
282 |
+
if self.visualize:
|
283 |
+
this_stage_coords = []
|
284 |
+
|
285 |
+
# first step
|
286 |
+
if torch.equal(resolution, self.resolutions[0]):
|
287 |
+
coords = self.init_coords.clone() # torch.long
|
288 |
+
occupancys = self.batch_eval(coords, **kwargs)
|
289 |
+
occupancys = occupancys.view(self.batchsize, self.channels, D,
|
290 |
+
H, W)
|
291 |
+
|
292 |
+
if self.visualize:
|
293 |
+
self.plot(occupancys, coords, final_D, final_H, final_W)
|
294 |
+
|
295 |
+
with torch.no_grad():
|
296 |
+
coords_accum = coords / stride
|
297 |
+
calculated[coords[0, :, 2], coords[0, :, 1],
|
298 |
+
coords[0, :, 0]] = True
|
299 |
+
|
300 |
+
# next steps
|
301 |
+
else:
|
302 |
+
coords_accum *= 2
|
303 |
+
|
304 |
+
with torch.no_grad():
|
305 |
+
# here true is correct!
|
306 |
+
valid = F.interpolate(
|
307 |
+
(occupancys > self.balance_value).float(),
|
308 |
+
size=(D, H, W),
|
309 |
+
mode="trilinear",
|
310 |
+
align_corners=True)
|
311 |
+
|
312 |
+
# here true is correct!
|
313 |
+
occupancys = F.interpolate(occupancys.float(),
|
314 |
+
size=(D, H, W),
|
315 |
+
mode="trilinear",
|
316 |
+
align_corners=True)
|
317 |
+
|
318 |
+
is_boundary = (valid > 0.0) & (valid < 1.0)
|
319 |
+
|
320 |
+
with torch.no_grad():
|
321 |
+
# TODO
|
322 |
+
if self.use_shadow and torch.equal(resolution,
|
323 |
+
self.resolutions[-1]):
|
324 |
+
# larger z means smaller depth here
|
325 |
+
depth_res = resolution[2].item()
|
326 |
+
depth_index = torch.linspace(0,
|
327 |
+
depth_res - 1,
|
328 |
+
steps=depth_res).type_as(
|
329 |
+
occupancys.device)
|
330 |
+
depth_index_max = torch.max(
|
331 |
+
(occupancys > self.balance_value) *
|
332 |
+
(depth_index + 1),
|
333 |
+
dim=-1,
|
334 |
+
keepdim=True)[0] - 1
|
335 |
+
shadow = depth_index < depth_index_max
|
336 |
+
is_boundary[shadow] = False
|
337 |
+
is_boundary = is_boundary[0, 0]
|
338 |
+
else:
|
339 |
+
is_boundary = (self.smooth_conv3x3(is_boundary.float())
|
340 |
+
> 0)[0, 0]
|
341 |
+
# is_boundary = is_boundary[0, 0]
|
342 |
+
|
343 |
+
is_boundary[coords_accum[0, :, 2], coords_accum[0, :, 1],
|
344 |
+
coords_accum[0, :, 0]] = False
|
345 |
+
point_coords = is_boundary.permute(
|
346 |
+
2, 1, 0).nonzero(as_tuple=False).unsqueeze(0)
|
347 |
+
point_indices = (point_coords[:, :, 2] * H * W +
|
348 |
+
point_coords[:, :, 1] * W +
|
349 |
+
point_coords[:, :, 0])
|
350 |
+
|
351 |
+
R, C, D, H, W = occupancys.shape
|
352 |
+
# interpolated value
|
353 |
+
occupancys_interp = torch.gather(
|
354 |
+
occupancys.reshape(R, C, D * H * W), 2,
|
355 |
+
point_indices.unsqueeze(1))
|
356 |
+
|
357 |
+
# inferred value
|
358 |
+
coords = point_coords * stride
|
359 |
+
|
360 |
+
if coords.size(1) == 0:
|
361 |
+
continue
|
362 |
+
occupancys_topk = self.batch_eval(coords, **kwargs)
|
363 |
+
if self.visualize:
|
364 |
+
this_stage_coords.append(coords)
|
365 |
+
|
366 |
+
# put mask point predictions to the right places on the upsampled grid.
|
367 |
+
R, C, D, H, W = occupancys.shape
|
368 |
+
point_indices = point_indices.unsqueeze(1).expand(-1, C, -1)
|
369 |
+
occupancys = (occupancys.reshape(R, C, D * H * W).scatter_(
|
370 |
+
2, point_indices, occupancys_topk).view(R, C, D, H, W))
|
371 |
+
|
372 |
+
with torch.no_grad():
|
373 |
+
# conflicts
|
374 |
+
conflicts = ((occupancys_interp - self.balance_value) *
|
375 |
+
(occupancys_topk - self.balance_value) < 0)[0,
|
376 |
+
0]
|
377 |
+
|
378 |
+
if self.visualize:
|
379 |
+
self.plot(occupancys, coords, final_D, final_H,
|
380 |
+
final_W)
|
381 |
+
|
382 |
+
voxels = coords / stride
|
383 |
+
coords_accum = torch.cat([voxels, coords_accum],
|
384 |
+
dim=1).unique(dim=1)
|
385 |
+
calculated[coords[0, :, 2], coords[0, :, 1],
|
386 |
+
coords[0, :, 0]] = True
|
387 |
+
|
388 |
+
while conflicts.sum() > 0:
|
389 |
+
if self.use_shadow and torch.equal(resolution,
|
390 |
+
self.resolutions[-1]):
|
391 |
+
break
|
392 |
+
|
393 |
+
with torch.no_grad():
|
394 |
+
conflicts_coords = coords[0, conflicts, :]
|
395 |
+
|
396 |
+
if self.debug:
|
397 |
+
self.plot(occupancys,
|
398 |
+
conflicts_coords.unsqueeze(0),
|
399 |
+
final_D,
|
400 |
+
final_H,
|
401 |
+
final_W,
|
402 |
+
title='conflicts')
|
403 |
+
|
404 |
+
conflicts_boundary = (conflicts_coords.int() +
|
405 |
+
self.gird8_offsets.unsqueeze(1) *
|
406 |
+
stride.int()).reshape(
|
407 |
+
-1, 3).long().unique(dim=0)
|
408 |
+
conflicts_boundary[:, 0] = (
|
409 |
+
conflicts_boundary[:, 0].clamp(
|
410 |
+
0,
|
411 |
+
calculated.size(2) - 1))
|
412 |
+
conflicts_boundary[:, 1] = (
|
413 |
+
conflicts_boundary[:, 1].clamp(
|
414 |
+
0,
|
415 |
+
calculated.size(1) - 1))
|
416 |
+
conflicts_boundary[:, 2] = (
|
417 |
+
conflicts_boundary[:, 2].clamp(
|
418 |
+
0,
|
419 |
+
calculated.size(0) - 1))
|
420 |
+
|
421 |
+
coords = conflicts_boundary[calculated[
|
422 |
+
conflicts_boundary[:, 2], conflicts_boundary[:, 1],
|
423 |
+
conflicts_boundary[:, 0]] == False]
|
424 |
+
|
425 |
+
if self.debug:
|
426 |
+
self.plot(occupancys,
|
427 |
+
coords.unsqueeze(0),
|
428 |
+
final_D,
|
429 |
+
final_H,
|
430 |
+
final_W,
|
431 |
+
title='coords')
|
432 |
+
|
433 |
+
coords = coords.unsqueeze(0)
|
434 |
+
point_coords = coords / stride
|
435 |
+
point_indices = (point_coords[:, :, 2] * H * W +
|
436 |
+
point_coords[:, :, 1] * W +
|
437 |
+
point_coords[:, :, 0])
|
438 |
+
|
439 |
+
R, C, D, H, W = occupancys.shape
|
440 |
+
# interpolated value
|
441 |
+
occupancys_interp = torch.gather(
|
442 |
+
occupancys.reshape(R, C, D * H * W), 2,
|
443 |
+
point_indices.unsqueeze(1))
|
444 |
+
|
445 |
+
# inferred value
|
446 |
+
coords = point_coords * stride
|
447 |
+
|
448 |
+
if coords.size(1) == 0:
|
449 |
+
break
|
450 |
+
occupancys_topk = self.batch_eval(coords, **kwargs)
|
451 |
+
if self.visualize:
|
452 |
+
this_stage_coords.append(coords)
|
453 |
+
|
454 |
+
with torch.no_grad():
|
455 |
+
# conflicts
|
456 |
+
conflicts = ((occupancys_interp - self.balance_value) *
|
457 |
+
(occupancys_topk - self.balance_value) <
|
458 |
+
0)[0, 0]
|
459 |
+
|
460 |
+
# put mask point predictions to the right places on the upsampled grid.
|
461 |
+
point_indices = point_indices.unsqueeze(1).expand(
|
462 |
+
-1, C, -1)
|
463 |
+
occupancys = (occupancys.reshape(R, C, D * H * W).scatter_(
|
464 |
+
2, point_indices, occupancys_topk).view(R, C, D, H, W))
|
465 |
+
|
466 |
+
with torch.no_grad():
|
467 |
+
voxels = coords / stride
|
468 |
+
coords_accum = torch.cat([voxels, coords_accum],
|
469 |
+
dim=1).unique(dim=1)
|
470 |
+
calculated[coords[0, :, 2], coords[0, :, 1],
|
471 |
+
coords[0, :, 0]] = True
|
472 |
+
|
473 |
+
if self.visualize:
|
474 |
+
this_stage_coords = torch.cat(this_stage_coords, dim=1)
|
475 |
+
self.plot(occupancys, this_stage_coords, final_D, final_H,
|
476 |
+
final_W)
|
477 |
+
|
478 |
+
return occupancys[0, 0]
|
479 |
+
|
480 |
+
def plot(self,
|
481 |
+
occupancys,
|
482 |
+
coords,
|
483 |
+
final_D,
|
484 |
+
final_H,
|
485 |
+
final_W,
|
486 |
+
title='',
|
487 |
+
**kwargs):
|
488 |
+
final = F.interpolate(occupancys.float(),
|
489 |
+
size=(final_D, final_H, final_W),
|
490 |
+
mode="trilinear",
|
491 |
+
align_corners=True) # here true is correct!
|
492 |
+
x = coords[0, :, 0].to("cpu")
|
493 |
+
y = coords[0, :, 1].to("cpu")
|
494 |
+
z = coords[0, :, 2].to("cpu")
|
495 |
+
|
496 |
+
plot_mask3D(final[0, 0].to("cpu"), title, (x, y, z), **kwargs)
|
497 |
+
|
498 |
+
def find_vertices(self, sdf, direction="front"):
|
499 |
+
'''
|
500 |
+
- direction: "front" | "back" | "left" | "right"
|
501 |
+
'''
|
502 |
+
resolution = sdf.size(2)
|
503 |
+
if direction == "front":
|
504 |
+
pass
|
505 |
+
elif direction == "left":
|
506 |
+
sdf = sdf.permute(2, 1, 0)
|
507 |
+
elif direction == "back":
|
508 |
+
inv_idx = torch.arange(sdf.size(2) - 1, -1, -1).long()
|
509 |
+
sdf = sdf[inv_idx, :, :]
|
510 |
+
elif direction == "right":
|
511 |
+
inv_idx = torch.arange(sdf.size(2) - 1, -1, -1).long()
|
512 |
+
sdf = sdf[:, :, inv_idx]
|
513 |
+
sdf = sdf.permute(2, 1, 0)
|
514 |
+
|
515 |
+
inv_idx = torch.arange(sdf.size(2) - 1, -1, -1).long()
|
516 |
+
sdf = sdf[inv_idx, :, :]
|
517 |
+
sdf_all = sdf.permute(2, 1, 0)
|
518 |
+
|
519 |
+
# shadow
|
520 |
+
grad_v = (sdf_all > 0.5) * torch.linspace(
|
521 |
+
resolution, 1, steps=resolution).to(sdf.device)
|
522 |
+
grad_c = torch.ones_like(sdf_all) * torch.linspace(
|
523 |
+
0, resolution - 1, steps=resolution).to(sdf.device)
|
524 |
+
max_v, max_c = grad_v.max(dim=2)
|
525 |
+
shadow = grad_c > max_c.view(resolution, resolution, 1)
|
526 |
+
keep = (sdf_all > 0.5) & (~shadow)
|
527 |
+
|
528 |
+
p1 = keep.nonzero(as_tuple=False).t() # [3, N]
|
529 |
+
p2 = p1.clone() # z
|
530 |
+
p2[2, :] = (p2[2, :] - 2).clamp(0, resolution)
|
531 |
+
p3 = p1.clone() # y
|
532 |
+
p3[1, :] = (p3[1, :] - 2).clamp(0, resolution)
|
533 |
+
p4 = p1.clone() # x
|
534 |
+
p4[0, :] = (p4[0, :] - 2).clamp(0, resolution)
|
535 |
+
|
536 |
+
v1 = sdf_all[p1[0, :], p1[1, :], p1[2, :]]
|
537 |
+
v2 = sdf_all[p2[0, :], p2[1, :], p2[2, :]]
|
538 |
+
v3 = sdf_all[p3[0, :], p3[1, :], p3[2, :]]
|
539 |
+
v4 = sdf_all[p4[0, :], p4[1, :], p4[2, :]]
|
540 |
+
|
541 |
+
X = p1[0, :].long() # [N,]
|
542 |
+
Y = p1[1, :].long() # [N,]
|
543 |
+
Z = p2[2, :].float() * (0.5 - v1) / (v2 - v1) + \
|
544 |
+
p1[2, :].float() * (v2 - 0.5) / (v2 - v1) # [N,]
|
545 |
+
Z = Z.clamp(0, resolution)
|
546 |
+
|
547 |
+
# normal
|
548 |
+
norm_z = v2 - v1
|
549 |
+
norm_y = v3 - v1
|
550 |
+
norm_x = v4 - v1
|
551 |
+
# print (v2.min(dim=0)[0], v2.max(dim=0)[0], v3.min(dim=0)[0], v3.max(dim=0)[0])
|
552 |
+
|
553 |
+
norm = torch.stack([norm_x, norm_y, norm_z], dim=1)
|
554 |
+
norm = norm / torch.norm(norm, p=2, dim=1, keepdim=True)
|
555 |
+
|
556 |
+
return X, Y, Z, norm
|
557 |
+
|
558 |
+
def render_normal(self, resolution, X, Y, Z, norm):
|
559 |
+
image = torch.ones((1, 3, resolution, resolution),
|
560 |
+
dtype=torch.float32).to(norm.device)
|
561 |
+
color = (norm + 1) / 2.0
|
562 |
+
color = color.clamp(0, 1)
|
563 |
+
image[0, :, Y, X] = color.t()
|
564 |
+
return image
|
565 |
+
|
566 |
+
def display(self, sdf):
|
567 |
+
|
568 |
+
# render
|
569 |
+
X, Y, Z, norm = self.find_vertices(sdf, direction="front")
|
570 |
+
image1 = self.render_normal(self.resolutions[-1, -1], X, Y, Z, norm)
|
571 |
+
X, Y, Z, norm = self.find_vertices(sdf, direction="left")
|
572 |
+
image2 = self.render_normal(self.resolutions[-1, -1], X, Y, Z, norm)
|
573 |
+
X, Y, Z, norm = self.find_vertices(sdf, direction="right")
|
574 |
+
image3 = self.render_normal(self.resolutions[-1, -1], X, Y, Z, norm)
|
575 |
+
X, Y, Z, norm = self.find_vertices(sdf, direction="back")
|
576 |
+
image4 = self.render_normal(self.resolutions[-1, -1], X, Y, Z, norm)
|
577 |
+
|
578 |
+
image = torch.cat([image1, image2, image3, image4], axis=3)
|
579 |
+
image = image.detach().cpu().numpy()[0].transpose(1, 2, 0) * 255.0
|
580 |
+
|
581 |
+
return np.uint8(image)
|
582 |
+
|
583 |
+
def export_mesh(self, occupancys):
|
584 |
+
|
585 |
+
final = occupancys[1:, 1:, 1:].contiguous()
|
586 |
+
|
587 |
+
if final.shape[0] > 256:
|
588 |
+
# for voxelgrid larger than 256^3, the required GPU memory will be > 9GB
|
589 |
+
# thus we use CPU marching_cube to avoid "CUDA out of memory"
|
590 |
+
occu_arr = final.detach().cpu().numpy() # non-smooth surface
|
591 |
+
# occu_arr = mcubes.smooth(final.detach().cpu().numpy()) # smooth surface
|
592 |
+
vertices, triangles = mcubes.marching_cubes(
|
593 |
+
occu_arr, self.balance_value)
|
594 |
+
verts = torch.as_tensor(vertices[:, [2, 1, 0]])
|
595 |
+
faces = torch.as_tensor(triangles.astype(
|
596 |
+
np.long), dtype=torch.long)[:, [0, 2, 1]]
|
597 |
+
else:
|
598 |
+
torch.cuda.empty_cache()
|
599 |
+
vertices, triangles = voxelgrids_to_trianglemeshes(
|
600 |
+
final.unsqueeze(0))
|
601 |
+
verts = vertices[0][:, [2, 1, 0]].cpu()
|
602 |
+
faces = triangles[0][:, [0, 2, 1]].cpu()
|
603 |
+
|
604 |
+
return verts, faces
|
lib/common/seg3d_utils.py
ADDED
@@ -0,0 +1,392 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
|
5 |
+
# holder of all proprietary rights on this computer program.
|
6 |
+
# You can only use this computer program if you have closed
|
7 |
+
# a license agreement with MPG or you get the right to use the computer
|
8 |
+
# program from someone who is authorized to grant you that right.
|
9 |
+
# Any use of the computer program without a valid license is prohibited and
|
10 |
+
# liable to prosecution.
|
11 |
+
#
|
12 |
+
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
|
13 |
+
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
|
14 |
+
# for Intelligent Systems. All rights reserved.
|
15 |
+
#
|
16 |
+
# Contact: ps-license@tuebingen.mpg.de
|
17 |
+
|
18 |
+
import torch
|
19 |
+
import torch.nn as nn
|
20 |
+
import torch.nn.functional as F
|
21 |
+
import matplotlib.pyplot as plt
|
22 |
+
|
23 |
+
|
24 |
+
def plot_mask2D(mask,
|
25 |
+
title="",
|
26 |
+
point_coords=None,
|
27 |
+
figsize=10,
|
28 |
+
point_marker_size=5):
|
29 |
+
'''
|
30 |
+
Simple plotting tool to show intermediate mask predictions and points
|
31 |
+
where PointRend is applied.
|
32 |
+
|
33 |
+
Args:
|
34 |
+
mask (Tensor): mask prediction of shape HxW
|
35 |
+
title (str): title for the plot
|
36 |
+
point_coords ((Tensor, Tensor)): x and y point coordinates
|
37 |
+
figsize (int): size of the figure to plot
|
38 |
+
point_marker_size (int): marker size for points
|
39 |
+
'''
|
40 |
+
|
41 |
+
H, W = mask.shape
|
42 |
+
plt.figure(figsize=(figsize, figsize))
|
43 |
+
if title:
|
44 |
+
title += ", "
|
45 |
+
plt.title("{}resolution {}x{}".format(title, H, W), fontsize=30)
|
46 |
+
plt.ylabel(H, fontsize=30)
|
47 |
+
plt.xlabel(W, fontsize=30)
|
48 |
+
plt.xticks([], [])
|
49 |
+
plt.yticks([], [])
|
50 |
+
plt.imshow(mask.detach(),
|
51 |
+
interpolation="nearest",
|
52 |
+
cmap=plt.get_cmap('gray'))
|
53 |
+
if point_coords is not None:
|
54 |
+
plt.scatter(x=point_coords[0],
|
55 |
+
y=point_coords[1],
|
56 |
+
color="red",
|
57 |
+
s=point_marker_size,
|
58 |
+
clip_on=True)
|
59 |
+
plt.xlim(-0.5, W - 0.5)
|
60 |
+
plt.ylim(H - 0.5, -0.5)
|
61 |
+
plt.show()
|
62 |
+
|
63 |
+
|
64 |
+
def plot_mask3D(mask=None,
|
65 |
+
title="",
|
66 |
+
point_coords=None,
|
67 |
+
figsize=1500,
|
68 |
+
point_marker_size=8,
|
69 |
+
interactive=True):
|
70 |
+
'''
|
71 |
+
Simple plotting tool to show intermediate mask predictions and points
|
72 |
+
where PointRend is applied.
|
73 |
+
|
74 |
+
Args:
|
75 |
+
mask (Tensor): mask prediction of shape DxHxW
|
76 |
+
title (str): title for the plot
|
77 |
+
point_coords ((Tensor, Tensor, Tensor)): x and y and z point coordinates
|
78 |
+
figsize (int): size of the figure to plot
|
79 |
+
point_marker_size (int): marker size for points
|
80 |
+
'''
|
81 |
+
import trimesh
|
82 |
+
import vtkplotter
|
83 |
+
from skimage import measure
|
84 |
+
|
85 |
+
vp = vtkplotter.Plotter(title=title, size=(figsize, figsize))
|
86 |
+
vis_list = []
|
87 |
+
|
88 |
+
if mask is not None:
|
89 |
+
mask = mask.detach().to("cpu").numpy()
|
90 |
+
mask = mask.transpose(2, 1, 0)
|
91 |
+
|
92 |
+
# marching cube to find surface
|
93 |
+
verts, faces, normals, values = measure.marching_cubes_lewiner(
|
94 |
+
mask, 0.5, gradient_direction='ascent')
|
95 |
+
|
96 |
+
# create a mesh
|
97 |
+
mesh = trimesh.Trimesh(verts, faces)
|
98 |
+
mesh.visual.face_colors = [200, 200, 250, 100]
|
99 |
+
vis_list.append(mesh)
|
100 |
+
|
101 |
+
if point_coords is not None:
|
102 |
+
point_coords = torch.stack(point_coords, 1).to("cpu").numpy()
|
103 |
+
|
104 |
+
# import numpy as np
|
105 |
+
# select_x = np.logical_and(point_coords[:, 0] >= 16, point_coords[:, 0] <= 112)
|
106 |
+
# select_y = np.logical_and(point_coords[:, 1] >= 48, point_coords[:, 1] <= 272)
|
107 |
+
# select_z = np.logical_and(point_coords[:, 2] >= 16, point_coords[:, 2] <= 112)
|
108 |
+
# select = np.logical_and(np.logical_and(select_x, select_y), select_z)
|
109 |
+
# point_coords = point_coords[select, :]
|
110 |
+
|
111 |
+
pc = vtkplotter.Points(point_coords, r=point_marker_size, c='red')
|
112 |
+
vis_list.append(pc)
|
113 |
+
|
114 |
+
vp.show(*vis_list,
|
115 |
+
bg="white",
|
116 |
+
axes=1,
|
117 |
+
interactive=interactive,
|
118 |
+
azimuth=30,
|
119 |
+
elevation=30)
|
120 |
+
|
121 |
+
|
122 |
+
def create_grid3D(min, max, steps):
|
123 |
+
if type(min) is int:
|
124 |
+
min = (min, min, min) # (x, y, z)
|
125 |
+
if type(max) is int:
|
126 |
+
max = (max, max, max) # (x, y)
|
127 |
+
if type(steps) is int:
|
128 |
+
steps = (steps, steps, steps) # (x, y, z)
|
129 |
+
arrangeX = torch.linspace(min[0], max[0], steps[0]).long()
|
130 |
+
arrangeY = torch.linspace(min[1], max[1], steps[1]).long()
|
131 |
+
arrangeZ = torch.linspace(min[2], max[2], steps[2]).long()
|
132 |
+
gridD, girdH, gridW = torch.meshgrid([arrangeZ, arrangeY, arrangeX])
|
133 |
+
coords = torch.stack([gridW, girdH,
|
134 |
+
gridD]) # [2, steps[0], steps[1], steps[2]]
|
135 |
+
coords = coords.view(3, -1).t() # [N, 3]
|
136 |
+
return coords
|
137 |
+
|
138 |
+
|
139 |
+
def create_grid2D(min, max, steps):
|
140 |
+
if type(min) is int:
|
141 |
+
min = (min, min) # (x, y)
|
142 |
+
if type(max) is int:
|
143 |
+
max = (max, max) # (x, y)
|
144 |
+
if type(steps) is int:
|
145 |
+
steps = (steps, steps) # (x, y)
|
146 |
+
arrangeX = torch.linspace(min[0], max[0], steps[0]).long()
|
147 |
+
arrangeY = torch.linspace(min[1], max[1], steps[1]).long()
|
148 |
+
girdH, gridW = torch.meshgrid([arrangeY, arrangeX])
|
149 |
+
coords = torch.stack([gridW, girdH]) # [2, steps[0], steps[1]]
|
150 |
+
coords = coords.view(2, -1).t() # [N, 2]
|
151 |
+
return coords
|
152 |
+
|
153 |
+
|
154 |
+
class SmoothConv2D(nn.Module):
|
155 |
+
def __init__(self, in_channels, out_channels, kernel_size=3):
|
156 |
+
super().__init__()
|
157 |
+
assert kernel_size % 2 == 1, "kernel_size for smooth_conv must be odd: {3, 5, ...}"
|
158 |
+
self.padding = (kernel_size - 1) // 2
|
159 |
+
|
160 |
+
weight = torch.ones(
|
161 |
+
(in_channels, out_channels, kernel_size, kernel_size),
|
162 |
+
dtype=torch.float32) / (kernel_size**2)
|
163 |
+
self.register_buffer('weight', weight)
|
164 |
+
|
165 |
+
def forward(self, input):
|
166 |
+
return F.conv2d(input, self.weight, padding=self.padding)
|
167 |
+
|
168 |
+
|
169 |
+
class SmoothConv3D(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3):
        super().__init__()
        assert kernel_size % 2 == 1, "kernel_size for smooth_conv must be odd: {3, 5, ...}"
        self.padding = (kernel_size - 1) // 2

        weight = torch.ones(
            (in_channels, out_channels, kernel_size, kernel_size, kernel_size),
            dtype=torch.float32) / (kernel_size**3)
        self.register_buffer('weight', weight)

    def forward(self, input):
        return F.conv3d(input, self.weight, padding=self.padding)

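# Illustrative usage sketch (not part of the original file): the smoothing
# convolutions above are fixed box filters (all-ones weights divided by the
# kernel volume), useful for blurring/dilating an occupancy grid before
# picking boundary voxels. Tensor sizes below are arbitrary example values;
# `_demo_smooth_conv3d` is our own name.
def _demo_smooth_conv3d():
    occupancy = (torch.rand(1, 1, 16, 16, 16) > 0.5).float()  # [B, C, D, H, W]
    smooth = SmoothConv3D(in_channels=1, out_channels=1, kernel_size=3)
    blurred = smooth(occupancy)  # local 3x3x3 average, same spatial size
    assert blurred.shape == occupancy.shape
    return blurred
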
def build_smooth_conv3D(in_channels=1,
                        out_channels=1,
                        kernel_size=3,
                        padding=1):
    smooth_conv = torch.nn.Conv3d(in_channels=in_channels,
                                  out_channels=out_channels,
                                  kernel_size=kernel_size,
                                  padding=padding)
    smooth_conv.weight.data = torch.ones(
        (in_channels, out_channels, kernel_size, kernel_size, kernel_size),
        dtype=torch.float32) / (kernel_size**3)
    smooth_conv.bias.data = torch.zeros(out_channels)
    return smooth_conv

def build_smooth_conv2D(in_channels=1,
                        out_channels=1,
                        kernel_size=3,
                        padding=1):
    smooth_conv = torch.nn.Conv2d(in_channels=in_channels,
                                  out_channels=out_channels,
                                  kernel_size=kernel_size,
                                  padding=padding)
    smooth_conv.weight.data = torch.ones(
        (in_channels, out_channels, kernel_size, kernel_size),
        dtype=torch.float32) / (kernel_size**2)
    smooth_conv.bias.data = torch.zeros(out_channels)
    return smooth_conv

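# Illustrative usage sketch (not part of the original file): the build_* helpers
# return ordinary Conv modules whose weights are overwritten with the same
# box-filter values as SmoothConv2D/SmoothConv3D, so they can be moved to a
# device along with a parent module. Example numbers are arbitrary;
# `_demo_build_smooth_conv3d` is our own name.
def _demo_build_smooth_conv3d():
    box_filter = build_smooth_conv3D(in_channels=1, out_channels=1,
                                     kernel_size=3, padding=1)
    volume = torch.rand(1, 1, 8, 8, 8)
    with torch.no_grad():
        smoothed = box_filter(volume)
    assert smoothed.shape == volume.shape
    return smoothed
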
def get_uncertain_point_coords_on_grid3D(uncertainty_map, num_points,
                                         **kwargs):
    """
    Find `num_points` most uncertain points from `uncertainty_map` grid.
    Args:
        uncertainty_map (Tensor): A tensor of shape (N, 1, D, H, W) that contains uncertainty
            values for a set of points on a regular D x H x W grid.
        num_points (int): The number of points P to select.
    Returns:
        point_indices (Tensor): A tensor of shape (N, P) that contains indices from
            [0, D x H x W) of the most uncertain points.
        point_coords (Tensor): A tensor of shape (N, P, 3) that contains integer (x, y, z)
            voxel coordinates of the most uncertain points on the D x H x W grid.
    """
    R, _, D, H, W = uncertainty_map.shape
    # h_step = 1.0 / float(H)
    # w_step = 1.0 / float(W)
    # d_step = 1.0 / float(D)

    num_points = min(D * H * W, num_points)
    point_scores, point_indices = torch.topk(uncertainty_map.view(
        R, D * H * W),
                                             k=num_points,
                                             dim=1)
    point_coords = torch.zeros(R,
                               num_points,
                               3,
                               dtype=torch.float,
                               device=uncertainty_map.device)
    # point_coords[:, :, 0] = h_step / 2.0 + (point_indices // (W * D)).to(torch.float) * h_step
    # point_coords[:, :, 1] = w_step / 2.0 + (point_indices % (W * D) // D).to(torch.float) * w_step
    # point_coords[:, :, 2] = d_step / 2.0 + (point_indices % D).to(torch.float) * d_step
    point_coords[:, :, 0] = (point_indices % W).to(torch.float)  # x
    point_coords[:, :, 1] = (point_indices % (H * W) // W).to(torch.float)  # y
    point_coords[:, :, 2] = (point_indices // (H * W)).to(torch.float)  # z
    print(f"resolution {D} x {H} x {W}", point_scores.min(),
          point_scores.max())
    return point_indices, point_coords

def get_uncertain_point_coords_on_grid3D_faster(uncertainty_map, num_points,
                                                clip_min):
    """
    Find `num_points` most uncertain points from `uncertainty_map` grid, keeping only
    locations whose uncertainty is at least `clip_min`.
    Args:
        uncertainty_map (Tensor): A tensor of shape (N, 1, D, H, W) that contains uncertainty
            values for a set of points on a regular D x H x W grid.
        num_points (int): The number of points P to select.
        clip_min (float): Minimum uncertainty value a location must have to be considered.
    Returns:
        point_indices (Tensor): A tensor of shape (N, P) that contains indices from
            [0, D x H x W) of the most uncertain points.
        point_coords (Tensor): A tensor of shape (N, P, 3) that contains integer (x, y, z)
            voxel coordinates of the most uncertain points on the D x H x W grid.
    """
    R, _, D, H, W = uncertainty_map.shape
    # h_step = 1.0 / float(H)
    # w_step = 1.0 / float(W)
    # d_step = 1.0 / float(D)

    assert R == 1, "batchsize > 1 is not implemented!"
    uncertainty_map = uncertainty_map.view(D * H * W)
    indices = (uncertainty_map >= clip_min).nonzero().squeeze(1)
    num_points = min(num_points, indices.size(0))
    point_scores, point_indices = torch.topk(uncertainty_map[indices],
                                             k=num_points,
                                             dim=0)
    point_indices = indices[point_indices].unsqueeze(0)

    point_coords = torch.zeros(R,
                               num_points,
                               3,
                               dtype=torch.float,
                               device=uncertainty_map.device)
    # point_coords[:, :, 0] = h_step / 2.0 + (point_indices // (W * D)).to(torch.float) * h_step
    # point_coords[:, :, 1] = w_step / 2.0 + (point_indices % (W * D) // D).to(torch.float) * w_step
    # point_coords[:, :, 2] = d_step / 2.0 + (point_indices % D).to(torch.float) * d_step
    point_coords[:, :, 0] = (point_indices % W).to(torch.float)  # x
    point_coords[:, :, 1] = (point_indices % (H * W) // W).to(torch.float)  # y
    point_coords[:, :, 2] = (point_indices // (H * W)).to(torch.float)  # z
    # print (f"resolution {D} x {H} x {W}", point_scores.min(), point_scores.max())
    return point_indices, point_coords

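# Illustrative usage sketch (not part of the original file): pick the most
# ambiguous voxels of a toy occupancy grid. Uncertainty here is highest
# (closest to zero) where the predicted occupancy is near 0.5. The grid size,
# point budget, and clip threshold are arbitrary example values;
# `_demo_uncertain_voxels` is our own name.
def _demo_uncertain_voxels():
    occupancy = torch.rand(1, 1, 16, 16, 16)    # fake predictions in [0, 1]
    uncertainty = -torch.abs(occupancy - 0.5)   # 0 = maximally uncertain
    point_indices, point_coords = get_uncertain_point_coords_on_grid3D_faster(
        uncertainty, num_points=128, clip_min=-0.1)
    # point_coords holds integer (x, y, z) voxel positions to re-evaluate
    return point_indices, point_coords
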
def get_uncertain_point_coords_on_grid2D(uncertainty_map, num_points,
                                         **kwargs):
    """
    Find `num_points` most uncertain points from `uncertainty_map` grid.
    Args:
        uncertainty_map (Tensor): A tensor of shape (N, 1, H, W) that contains uncertainty
            values for a set of points on a regular H x W grid.
        num_points (int): The number of points P to select.
    Returns:
        point_indices (Tensor): A tensor of shape (N, P) that contains indices from
            [0, H x W) of the most uncertain points.
        point_coords (Tensor): A tensor of shape (N, P, 2) that contains integer (x, y)
            pixel coordinates of the most uncertain points on the H x W grid.
    """
    R, _, H, W = uncertainty_map.shape
    # h_step = 1.0 / float(H)
    # w_step = 1.0 / float(W)

    num_points = min(H * W, num_points)
    point_scores, point_indices = torch.topk(uncertainty_map.view(R, H * W),
                                             k=num_points,
                                             dim=1)
    point_coords = torch.zeros(R,
                               num_points,
                               2,
                               dtype=torch.long,
                               device=uncertainty_map.device)
    # point_coords[:, :, 0] = w_step / 2.0 + (point_indices % W).to(torch.float) * w_step
    # point_coords[:, :, 1] = h_step / 2.0 + (point_indices // W).to(torch.float) * h_step
    point_coords[:, :, 0] = (point_indices % W).to(torch.long)
    point_coords[:, :, 1] = (point_indices // W).to(torch.long)
    # print (point_scores.min(), point_scores.max())
    return point_indices, point_coords

def get_uncertain_point_coords_on_grid2D_faster(uncertainty_map, num_points,
                                                clip_min):
    """
    Find `num_points` most uncertain points from `uncertainty_map` grid, keeping only
    locations whose uncertainty is at least `clip_min`.
    Args:
        uncertainty_map (Tensor): A tensor of shape (N, 1, H, W) that contains uncertainty
            values for a set of points on a regular H x W grid.
        num_points (int): The number of points P to select.
        clip_min (float): Minimum uncertainty value a location must have to be considered.
    Returns:
        point_indices (Tensor): A tensor of shape (N, P) that contains indices from
            [0, H x W) of the most uncertain points.
        point_coords (Tensor): A tensor of shape (N, P, 2) that contains integer (x, y)
            pixel coordinates of the most uncertain points on the H x W grid.
    """
    R, _, H, W = uncertainty_map.shape
    # h_step = 1.0 / float(H)
    # w_step = 1.0 / float(W)

    assert R == 1, "batchsize > 1 is not implemented!"
    uncertainty_map = uncertainty_map.view(H * W)
    indices = (uncertainty_map >= clip_min).nonzero().squeeze(1)
    num_points = min(num_points, indices.size(0))
    point_scores, point_indices = torch.topk(uncertainty_map[indices],
                                             k=num_points,
                                             dim=0)
    point_indices = indices[point_indices].unsqueeze(0)

    point_coords = torch.zeros(R,
                               num_points,
                               2,
                               dtype=torch.long,
                               device=uncertainty_map.device)
    # point_coords[:, :, 0] = w_step / 2.0 + (point_indices % W).to(torch.float) * w_step
    # point_coords[:, :, 1] = h_step / 2.0 + (point_indices // W).to(torch.float) * h_step
    point_coords[:, :, 0] = (point_indices % W).to(torch.long)
    point_coords[:, :, 1] = (point_indices // W).to(torch.long)
    # print (point_scores.min(), point_scores.max())
    return point_indices, point_coords

def calculate_uncertainty(logits, classes=None, balance_value=0.5):
    """
    We estimate uncertainty as the L1 distance between `balance_value` and the logit
    prediction in `logits` for the foreground class in `classes`.
    Args:
        logits (Tensor): A tensor of shape (R, C, ...) or (R, 1, ...) for class-specific or
            class-agnostic logits, where R is the total number of predicted masks in all
            images and C is the number of foreground classes. The values are logits.
        classes (list): A list of length R that contains either the predicted or the
            ground-truth class for each predicted mask.
    Returns:
        scores (Tensor): A tensor of shape (R, 1, ...) that contains uncertainty scores with
            the most uncertain locations having the highest uncertainty score.
    """
    if logits.shape[1] == 1:
        gt_class_logits = logits
    else:
        gt_class_logits = logits[
            torch.arange(logits.shape[0], device=logits.device),
            classes].unsqueeze(1)
    return -torch.abs(gt_class_logits - balance_value)
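# Illustrative usage sketch (not part of the original file): combine
# calculate_uncertainty with the 2D grid selection above. Predictions near
# balance_value (0.5) receive the highest uncertainty score and are selected
# first. Sizes and the point budget are arbitrary example values;
# `_demo_uncertainty_selection2D` is our own name.
def _demo_uncertainty_selection2D():
    pred = torch.rand(1, 1, 32, 32)            # fake per-pixel occupancy
    uncertainty = calculate_uncertainty(pred)  # shape (1, 1, 32, 32)
    point_indices, point_coords = get_uncertain_point_coords_on_grid2D(
        uncertainty, num_points=64)
    return point_indices, point_coords
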
lib/common/smpl_vert_segmentation.json
ADDED
The diff for this file is too large to render.
See raw diff
lib/common/train_util.py
ADDED
@@ -0,0 +1,599 @@
# -*- coding: utf-8 -*-

# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: ps-license@tuebingen.mpg.de

import yaml
import os.path as osp
import torch
import numpy as np
import torch.nn.functional as F
from ..dataset.mesh_util import *
from ..net.geometry import orthogonal
from pytorch3d.renderer.mesh import rasterize_meshes
from .render_utils import Pytorch3dRasterizer
from pytorch3d.structures import Meshes
import cv2
from PIL import Image
from tqdm import tqdm
import os
from termcolor import colored

def reshape_sample_tensor(sample_tensor, num_views):
    if num_views == 1:
        return sample_tensor
    # Need to repeat sample_tensor along the batch dim num_views times
    sample_tensor = sample_tensor.unsqueeze(dim=1)
    sample_tensor = sample_tensor.repeat(1, num_views, 1, 1)
    sample_tensor = sample_tensor.view(
        sample_tensor.shape[0] * sample_tensor.shape[1],
        sample_tensor.shape[2], sample_tensor.shape[3])
    return sample_tensor

def gen_mesh_eval(opt, net, cuda, data, resolution=None):
    resolution = opt.resolution if resolution is None else resolution
    image_tensor = data['img'].to(device=cuda)
    calib_tensor = data['calib'].to(device=cuda)

    net.filter(image_tensor)

    b_min = data['b_min']
    b_max = data['b_max']
    try:
        verts, faces, _, _ = reconstruction_faster(net,
                                                   cuda,
                                                   calib_tensor,
                                                   resolution,
                                                   b_min,
                                                   b_max,
                                                   use_octree=False)

    except Exception as e:
        print(e)
        print('Can not create marching cubes at this time.')
        verts, faces = None, None
    return verts, faces

def gen_mesh(opt, net, cuda, data, save_path, resolution=None):
    resolution = opt.resolution if resolution is None else resolution
    image_tensor = data['img'].to(device=cuda)
    calib_tensor = data['calib'].to(device=cuda)

    net.filter(image_tensor)

    b_min = data['b_min']
    b_max = data['b_max']
    try:
        save_img_path = save_path[:-4] + '.png'
        save_img_list = []
        for v in range(image_tensor.shape[0]):
            save_img = (np.transpose(image_tensor[v].detach().cpu().numpy(),
                                     (1, 2, 0)) * 0.5 +
                        0.5)[:, :, ::-1] * 255.0
            save_img_list.append(save_img)
        save_img = np.concatenate(save_img_list, axis=1)
        Image.fromarray(np.uint8(save_img[:, :, ::-1])).save(save_img_path)

        verts, faces, _, _ = reconstruction_faster(net, cuda, calib_tensor,
                                                   resolution, b_min, b_max)
        verts_tensor = torch.from_numpy(
            verts.T).unsqueeze(0).to(device=cuda).float()
        xyz_tensor = net.projection(verts_tensor, calib_tensor[:1])
        uv = xyz_tensor[:, :2, :]
        # query vertex colors from the image features of `net`
        color = net.index(image_tensor[:1], uv).detach().cpu().numpy()[0].T
        color = color * 0.5 + 0.5
        save_obj_mesh_with_color(save_path, verts, faces, color)
    except Exception as e:
        print(e)
        print('Can not create marching cubes at this time.')
        verts, faces, color = None, None, None
    return verts, faces, color

def gen_mesh_color(opt, netG, netC, cuda, data, save_path, use_octree=True):
    image_tensor = data['img'].to(device=cuda)
    calib_tensor = data['calib'].to(device=cuda)

    netG.filter(image_tensor)
    netC.filter(image_tensor)
    netC.attach(netG.get_im_feat())

    b_min = data['b_min']
    b_max = data['b_max']
    try:
        save_img_path = save_path[:-4] + '.png'
        save_img_list = []
        for v in range(image_tensor.shape[0]):
            save_img = (np.transpose(image_tensor[v].detach().cpu().numpy(),
                                     (1, 2, 0)) * 0.5 +
                        0.5)[:, :, ::-1] * 255.0
            save_img_list.append(save_img)
        save_img = np.concatenate(save_img_list, axis=1)
        Image.fromarray(np.uint8(save_img[:, :, ::-1])).save(save_img_path)

        verts, faces, _, _ = reconstruction_faster(netG,
                                                   cuda,
                                                   calib_tensor,
                                                   opt.resolution,
                                                   b_min,
                                                   b_max,
                                                   use_octree=use_octree)

        # Now Getting colors
        verts_tensor = torch.from_numpy(
            verts.T).unsqueeze(0).to(device=cuda).float()
        verts_tensor = reshape_sample_tensor(verts_tensor, opt.num_views)
        color = np.zeros(verts.shape)
        interval = 10000
        for i in range(len(color) // interval):
            left = i * interval
            right = i * interval + interval
            if i == len(color) // interval - 1:
                right = -1
            netC.query(verts_tensor[:, :, left:right], calib_tensor)
            rgb = netC.get_preds()[0].detach().cpu().numpy() * 0.5 + 0.5
            color[left:right] = rgb.T

        save_obj_mesh_with_color(save_path, verts, faces, color)
    except Exception as e:
        print(e)
        print('Can not create marching cubes at this time.')
        verts, faces, color = None, None, None
    return verts, faces, color

def adjust_learning_rate(optimizer, epoch, lr, schedule, gamma):
    """Sets the learning rate to the initial LR decayed by schedule"""
    if epoch in schedule:
        lr *= gamma
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
    return lr

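# Illustrative usage sketch (not part of the original file): the caller keeps
# track of the current learning rate and the helper multiplies it by `gamma`
# whenever the epoch hits a milestone in `schedule`. Optimizer, schedule, and
# gamma below are arbitrary example values; `_demo_adjust_learning_rate` is
# our own name.
def _demo_adjust_learning_rate():
    params = [torch.nn.Parameter(torch.zeros(1))]
    optimizer = torch.optim.SGD(params, lr=1e-3)
    lr = 1e-3
    for epoch in range(30):
        lr = adjust_learning_rate(optimizer, epoch, lr, schedule=[10, 20], gamma=0.1)
    return lr  # 1e-5 after decaying at epochs 10 and 20
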
def compute_acc(pred, gt, thresh=0.5):
    '''
    return:
        IOU, precision, and recall
    '''
    with torch.no_grad():
        vol_pred = pred > thresh
        vol_gt = gt > thresh

        union = vol_pred | vol_gt
        inter = vol_pred & vol_gt

        true_pos = inter.sum().float()

        union = union.sum().float()
        if union == 0:
            union = 1
        vol_pred = vol_pred.sum().float()
        if vol_pred == 0:
            vol_pred = 1
        vol_gt = vol_gt.sum().float()
        if vol_gt == 0:
            vol_gt = 1
        return true_pos / union, true_pos / vol_pred, true_pos / vol_gt

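# Illustrative usage sketch (not part of the original file): occupancy
# predictions and labels are thresholded at 0.5 and compared element-wise.
# The toy tensors below are chosen to give IOU = 1/3, precision = 1/2,
# recall = 1/2; `_demo_compute_acc` is our own name.
def _demo_compute_acc():
    pred = torch.tensor([0.9, 0.1, 0.8, 0.3])
    gt = torch.tensor([1.0, 0.0, 0.0, 1.0])
    iou, precision, recall = compute_acc(pred, gt)
    return iou, precision, recall
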
# def calc_metrics(opt, net, cuda, dataset, num_tests,
#                  resolution=128, sampled_points=1000, use_kaolin=True):
#     if num_tests > len(dataset):
#         num_tests = len(dataset)
#     with torch.no_grad():
#         chamfer_arr, p2s_arr = [], []
#         for idx in tqdm(range(num_tests)):
#             data = dataset[idx * len(dataset) // num_tests]

#             verts, faces = gen_mesh_eval(opt, net, cuda, data, resolution)
#             if verts is None:
#                 continue

#             mesh_gt = trimesh.load(data['mesh_path'])
#             mesh_gt = mesh_gt.split(only_watertight=False)
#             comp_num = [mesh.vertices.shape[0] for mesh in mesh_gt]
#             mesh_gt = mesh_gt[comp_num.index(max(comp_num))]

#             mesh_pred = trimesh.Trimesh(verts, faces)

#             gt_surface_pts, _ = trimesh.sample.sample_surface_even(
#                 mesh_gt, sampled_points)
#             pred_surface_pts, _ = trimesh.sample.sample_surface_even(
#                 mesh_pred, sampled_points)

#             if use_kaolin and has_kaolin:
#                 kal_mesh_gt = kal.rep.TriangleMesh.from_tensors(
#                     torch.tensor(mesh_gt.vertices).float().to(device=cuda),
#                     torch.tensor(mesh_gt.faces).long().to(device=cuda))
#                 kal_mesh_pred = kal.rep.TriangleMesh.from_tensors(
#                     torch.tensor(mesh_pred.vertices).float().to(device=cuda),
#                     torch.tensor(mesh_pred.faces).long().to(device=cuda))

#                 kal_distance_0 = kal.metrics.mesh.point_to_surface(
#                     torch.tensor(pred_surface_pts).float().to(device=cuda), kal_mesh_gt)
#                 kal_distance_1 = kal.metrics.mesh.point_to_surface(
#                     torch.tensor(gt_surface_pts).float().to(device=cuda), kal_mesh_pred)

#                 dist_gt_pred = torch.sqrt(kal_distance_0).cpu().numpy()
#                 dist_pred_gt = torch.sqrt(kal_distance_1).cpu().numpy()
#             else:
#                 try:
#                     _, dist_pred_gt, _ = trimesh.proximity.closest_point(mesh_pred, gt_surface_pts)
#                     _, dist_gt_pred, _ = trimesh.proximity.closest_point(mesh_gt, pred_surface_pts)
#                 except Exception as e:
#                     print (e)
#                     continue

#             chamfer_dist = 0.5 * (dist_pred_gt.mean() + dist_gt_pred.mean())
#             p2s_dist = dist_pred_gt.mean()

#             chamfer_arr.append(chamfer_dist)
#             p2s_arr.append(p2s_dist)

#     return np.average(chamfer_arr), np.average(p2s_arr)

def calc_error(opt, net, cuda, dataset, num_tests):
    if num_tests > len(dataset):
        num_tests = len(dataset)
    with torch.no_grad():
        error_arr, IOU_arr, prec_arr, recall_arr = [], [], [], []
        for idx in tqdm(range(num_tests)):
            data = dataset[idx * len(dataset) // num_tests]
            # retrieve the data
            image_tensor = data['img'].to(device=cuda)
            calib_tensor = data['calib'].to(device=cuda)
            sample_tensor = data['samples'].to(device=cuda).unsqueeze(0)
            if opt.num_views > 1:
                sample_tensor = reshape_sample_tensor(sample_tensor,
                                                      opt.num_views)
            label_tensor = data['labels'].to(device=cuda).unsqueeze(0)

            res, error = net.forward(image_tensor,
                                     sample_tensor,
                                     calib_tensor,
                                     labels=label_tensor)

            IOU, prec, recall = compute_acc(res, label_tensor)

            # print(
            #     '{0}/{1} | Error: {2:06f} IOU: {3:06f} prec: {4:06f} recall: {5:06f}'
            #     .format(idx, num_tests, error.item(), IOU.item(), prec.item(), recall.item()))
            error_arr.append(error.item())
            IOU_arr.append(IOU.item())
            prec_arr.append(prec.item())
            recall_arr.append(recall.item())

    return np.average(error_arr), np.average(IOU_arr), np.average(
        prec_arr), np.average(recall_arr)

def calc_error_color(opt, netG, netC, cuda, dataset, num_tests):
    if num_tests > len(dataset):
        num_tests = len(dataset)
    with torch.no_grad():
        error_color_arr = []

        for idx in tqdm(range(num_tests)):
            data = dataset[idx * len(dataset) // num_tests]
            # retrieve the data
            image_tensor = data['img'].to(device=cuda)
            calib_tensor = data['calib'].to(device=cuda)
            color_sample_tensor = data['color_samples'].to(
                device=cuda).unsqueeze(0)

            if opt.num_views > 1:
                color_sample_tensor = reshape_sample_tensor(
                    color_sample_tensor, opt.num_views)

            rgb_tensor = data['rgbs'].to(device=cuda).unsqueeze(0)

            netG.filter(image_tensor)
            _, errorC = netC.forward(image_tensor,
                                     netG.get_im_feat(),
                                     color_sample_tensor,
                                     calib_tensor,
                                     labels=rgb_tensor)

            # print('{0}/{1} | Error inout: {2:06f} | Error color: {3:06f}'
            #       .format(idx, num_tests, errorG.item(), errorC.item()))
            error_color_arr.append(errorC.item())

    return np.average(error_color_arr)

# pytorch lightning training related functions


def query_func(opt, netG, features, points, proj_matrix=None):
    '''
    - points: size of (bz, N, 3)
    - proj_matrix: size of (bz, 4, 4)
    return: size of (bz, 1, N)
    '''
    assert len(points) == 1
    samples = points.repeat(opt.num_views, 1, 1)
    samples = samples.permute(0, 2, 1)  # [bz, 3, N]

    # view specific query
    if proj_matrix is not None:
        samples = orthogonal(samples, proj_matrix)

    calib_tensor = torch.stack([torch.eye(4).float()], dim=0).type_as(samples)

    preds = netG.query(features=features,
                       points=samples,
                       calibs=calib_tensor,
                       regressor=netG.if_regressor)

    if type(preds) is list:
        preds = preds[0]

    return preds

def isin(ar1, ar2):
    return (ar1[..., None] == ar2).any(-1)


def in1d(ar1, ar2):
    mask = ar2.new_zeros((max(ar1.max(), ar2.max()) + 1, ), dtype=torch.bool)
    mask[ar2.unique()] = True
    return mask[ar1]

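# Illustrative usage sketch (not part of the original file): both helpers test
# element membership of `ar1` in `ar2`; `isin` broadcasts a comparison matrix,
# while `in1d` builds a boolean lookup table, which is cheaper for large index
# tensors. The tensors below are arbitrary example values; `_demo_membership`
# is our own name.
def _demo_membership():
    ar1 = torch.tensor([0, 2, 5, 7])
    ar2 = torch.tensor([2, 7, 9])
    assert torch.equal(isin(ar1, ar2), in1d(ar1, ar2))  # [False, True, False, True]
    return in1d(ar1, ar2)
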
def get_visibility(xy, z, faces):
    """get the visibility of vertices

    Args:
        xy (torch.tensor): [N, 2]
        z (torch.tensor): [N, 1]
        faces (torch.tensor): [F, 3]
    """

    xyz = torch.cat((xy, -z), dim=1)
    xyz = (xyz + 1.0) / 2.0
    faces = faces.long()

    rasterizer = Pytorch3dRasterizer(image_size=2**12)
    meshes_screen = Meshes(verts=xyz[None, ...], faces=faces[None, ...])
    raster_settings = rasterizer.raster_settings

    pix_to_face, zbuf, bary_coords, dists = rasterize_meshes(
        meshes_screen,
        image_size=raster_settings.image_size,
        blur_radius=raster_settings.blur_radius,
        faces_per_pixel=raster_settings.faces_per_pixel,
        bin_size=raster_settings.bin_size,
        max_faces_per_bin=raster_settings.max_faces_per_bin,
        perspective_correct=raster_settings.perspective_correct,
        cull_backfaces=raster_settings.cull_backfaces,
    )

    vis_vertices_id = torch.unique(faces[torch.unique(pix_to_face), :])
    vis_mask = torch.zeros(size=(z.shape[0], 1))
    vis_mask[vis_vertices_id] = 1.0

    # print("------------------------\n")
    # print(f"keep points : {vis_mask.sum()/len(vis_mask)}")

    return vis_mask

def batch_mean(res, key):
    # recursive mean for multilevel dicts
    return torch.stack([
        x[key] if isinstance(x, dict) else batch_mean(x, key) for x in res
    ]).mean()

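# Illustrative usage sketch (not part of the original file): average one metric
# across a (possibly nested) list of per-batch output dicts, e.g. what a
# validation loop collects. The values below are toy numbers;
# `_demo_batch_mean` is our own name.
def _demo_batch_mean():
    outputs = [{'loss': torch.tensor(1.0)},
               [{'loss': torch.tensor(2.0)}, {'loss': torch.tensor(4.0)}]]
    return batch_mean(outputs, 'loss')  # tensor(2.0): mean of 1.0 and mean(2.0, 4.0)
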
def tf_log_convert(log_dict):
    new_log_dict = log_dict.copy()
    for k, v in log_dict.items():
        new_log_dict[k.replace("_", "/")] = v
        del new_log_dict[k]

    return new_log_dict

def bar_log_convert(log_dict, name=None, rot=None):
    from decimal import Decimal

    new_log_dict = {}

    if name is not None:
        new_log_dict['name'] = name[0]
    if rot is not None:
        new_log_dict['rot'] = rot[0]

    for k, v in log_dict.items():
        color = "yellow"
        if 'loss' in k:
            color = "red"
            k = k.replace("loss", "L")
        elif 'acc' in k:
            color = "green"
            k = k.replace("acc", "A")
        elif 'iou' in k:
            color = "green"
            k = k.replace("iou", "I")
        elif 'prec' in k:
            color = "green"
            k = k.replace("prec", "P")
        elif 'recall' in k:
            color = "green"
            k = k.replace("recall", "R")

        if 'lr' not in k:
            new_log_dict[colored(k.split("_")[1],
                                 color)] = colored(f"{v:.3f}", color)
        else:
            new_log_dict[colored(k.split("_")[1],
                                 color)] = colored(f"{Decimal(str(v)):.1E}",
                                                   color)

    if 'loss' in new_log_dict.keys():
        del new_log_dict['loss']

    return new_log_dict

def accumulate(outputs, rot_num, split):

    hparam_log_dict = {}

    metrics = outputs[0].keys()
    datasets = split.keys()

    for dataset in datasets:
        for metric in metrics:
            keyword = f"hparam/{dataset}-{metric}"
            if keyword not in hparam_log_dict.keys():
                hparam_log_dict[keyword] = 0
            for idx in range(split[dataset][0] * rot_num,
                             split[dataset][1] * rot_num):
                hparam_log_dict[keyword] += outputs[idx][metric]
            hparam_log_dict[keyword] /= (split[dataset][1] -
                                         split[dataset][0]) * rot_num

    print(colored(hparam_log_dict, "green"))

    return hparam_log_dict

def calc_error_N(outputs, targets):
    """calculate the error of normal (IGR)

    Args:
        outputs (torch.tensor): [B, 3, N]
        targets (torch.tensor): [B, N, 3]

    # manifold loss and grad_loss in IGR paper
    grad_loss = ((nonmnfld_grad.norm(2, dim=-1) - 1) ** 2).mean()
    normals_loss = ((mnfld_grad - normals).abs()).norm(2, dim=1).mean()

    Returns:
        torch.tensor: error of valid normals on the surface
    """
    # outputs = torch.tanh(-outputs.permute(0,2,1).reshape(-1,3))
    outputs = -outputs.permute(0, 2, 1).reshape(-1, 1)
    targets = targets.reshape(-1, 3)[:, 2:3]
    with_normals = targets.sum(dim=1).abs() > 0.0

    # eikonal loss
    grad_loss = ((outputs[with_normals].norm(2, dim=-1) - 1)**2).mean()
    # normals loss
    normal_loss = (outputs - targets)[with_normals].abs().norm(2, dim=1).mean()

    return grad_loss * 0.0 + normal_loss

def calc_knn_acc(preds, carn_verts, labels, pick_num):
    """calculate knn accuracy

    Args:
        preds (torch.tensor): [B, 3, N]
        carn_verts (torch.tensor): [SMPLX_V_num, 3]
        labels (torch.tensor): [B, N_knn, N]
    """
    N_knn_full = labels.shape[1]
    preds = preds.permute(0, 2, 1).reshape(-1, 3)
    labels = labels.permute(0, 2, 1).reshape(-1, N_knn_full)  # [BxN, num_knn]
    labels = labels[:, :pick_num]

    dist = torch.cdist(preds, carn_verts, p=2)  # [BxN, SMPL_V_num]
    knn = dist.topk(k=pick_num, dim=1, largest=False)[1]  # [BxN, num_knn]
    cat_mat = torch.sort(torch.cat((knn, labels), dim=1))[0]
    bool_col = torch.zeros_like(cat_mat)[:, 0]
    for i in range(pick_num * 2 - 1):
        bool_col += cat_mat[:, i] == cat_mat[:, i + 1]
    acc = (bool_col > 0).sum() / len(bool_col)

    return acc

def calc_acc_seg(output, target, num_multiseg):
    from pytorch_lightning.metrics import Accuracy
    return Accuracy()(output.reshape(-1, num_multiseg).cpu(),
                      target.flatten().cpu())

def add_watermark(imgs, titles):

    # Write some Text

    font = cv2.FONT_HERSHEY_SIMPLEX
    bottomLeftCornerOfText = (350, 50)
    bottomRightCornerOfText = (800, 50)
    fontScale = 1
    fontColor = (1.0, 1.0, 1.0)
    lineType = 2

    for i in range(len(imgs)):

        title = titles[i + 1]
        cv2.putText(imgs[i], title, bottomLeftCornerOfText, font, fontScale,
                    fontColor, lineType)

        if i == 0:
            cv2.putText(imgs[i], str(titles[i][0]), bottomRightCornerOfText,
                        font, fontScale, fontColor, lineType)

    result = np.concatenate(imgs, axis=0).transpose(2, 0, 1)

    return result

def make_test_gif(img_dir):

    if img_dir is not None and len(os.listdir(img_dir)) > 0:
        for dataset in os.listdir(img_dir):
            for subject in sorted(os.listdir(osp.join(img_dir, dataset))):
                img_lst = []
                im1 = None
                for file in sorted(
                        os.listdir(osp.join(img_dir, dataset, subject))):
                    if file[-3:] not in ['obj', 'gif']:
                        img_path = os.path.join(img_dir, dataset, subject,
                                                file)
                        if im1 is None:
                            im1 = Image.open(img_path)
                        else:
                            img_lst.append(Image.open(img_path))

                print(os.path.join(img_dir, dataset, subject, "out.gif"))
                im1.save(os.path.join(img_dir, dataset, subject, "out.gif"),
                         save_all=True,
                         append_images=img_lst,
                         duration=500,
                         loop=0)

def export_cfg(logger, cfg):

    cfg_export_file = osp.join(logger.save_dir, logger.name,
                               f"version_{logger.version}", "cfg.yaml")

    if not osp.exists(cfg_export_file):
        os.makedirs(osp.dirname(cfg_export_file), exist_ok=True)
        with open(cfg_export_file, "w+") as file:
            _ = yaml.dump(cfg, file)