YoonaAI committed
Commit 208e8a7 · 1 Parent(s): 7853fb6

Upload 6 files

lib/common/render.py ADDED
@@ -0,0 +1,392 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
4
+ # holder of all proprietary rights on this computer program.
5
+ # You can only use this computer program if you have closed
6
+ # a license agreement with MPG or you get the right to use the computer
7
+ # program from someone who is authorized to grant you that right.
8
+ # Any use of the computer program without a valid license is prohibited and
9
+ # liable to prosecution.
10
+ #
11
+ # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
12
+ # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
13
+ # for Intelligent Systems. All rights reserved.
14
+ #
15
+ # Contact: ps-license@tuebingen.mpg.de
16
+
17
+ from pytorch3d.renderer import (
18
+ BlendParams,
19
+ blending,
20
+ look_at_view_transform,
21
+ FoVOrthographicCameras,
22
+ PointLights,
23
+ RasterizationSettings,
24
+ PointsRasterizationSettings,
25
+ PointsRenderer,
26
+ AlphaCompositor,
27
+ PointsRasterizer,
28
+ MeshRenderer,
29
+ MeshRasterizer,
30
+ SoftPhongShader,
31
+ SoftSilhouetteShader,
32
+ TexturesVertex,
33
+ )
34
+ from pytorch3d.renderer.mesh import TexturesVertex
35
+ from pytorch3d.structures import Meshes
36
+
37
+ import os, subprocess
38
+
39
+ from lib.dataset.mesh_util import SMPLX, get_visibility
40
+ import lib.common.render_utils as util
41
+ import torch
42
+ import numpy as np
43
+ from PIL import Image
44
+ from tqdm import tqdm
45
+ import cv2
46
+ import math
47
+ from termcolor import colored
48
+
49
+
50
+ def image2vid(images, vid_path):
51
+
52
+ w, h = images[0].size
53
+ videodims = (w, h)
54
+ fourcc = cv2.VideoWriter_fourcc(*'XVID')
55
+ video = cv2.VideoWriter(vid_path, fourcc, 30, videodims)
56
+ for image in images:
57
+ video.write(cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR))
58
+ video.release()
59
+
60
+
61
+ def query_color(verts, faces, image, device):
62
+ """query colors from points and image
63
+
64
+ Args:
65
+ verts ([N, 3]): query vertices
66
+ faces ([M, 3]): [query faces]
67
+ image ([B, 3, H, W]): [full image]
68
+
69
+ Returns:
70
+ colors ([N, 3]): per-vertex colors in [0, 255]
71
+ """
72
+
73
+ verts = verts.float().to(device)
74
+ faces = faces.long().to(device)
75
+
76
+ (xy, z) = verts.split([2, 1], dim=1)
77
+ visibility = get_visibility(xy, z, faces[:, [0, 2, 1]]).flatten()
78
+ uv = xy.unsqueeze(0).unsqueeze(2) # [1, N, 1, 2]
79
+ uv = uv * torch.tensor([1.0, -1.0]).type_as(uv)
80
+ colors = (torch.nn.functional.grid_sample(image, uv, align_corners=True)[
81
+ 0, :, :, 0].permute(1, 0) + 1.0) * 0.5 * 255.0
82
+ colors[visibility == 0.0] = ((Meshes(verts.unsqueeze(0), faces.unsqueeze(
83
+ 0)).verts_normals_padded().squeeze(0) + 1.0) * 0.5 * 255.0)[visibility == 0.0]
84
+
85
+ return colors.detach().cpu()
86
+
87
+
88
+ class cleanShader(torch.nn.Module):
89
+ def __init__(self, device="cpu", cameras=None, blend_params=None):
90
+ super().__init__()
91
+ self.cameras = cameras
92
+ self.blend_params = blend_params if blend_params is not None else BlendParams()
93
+
94
+ def forward(self, fragments, meshes, **kwargs):
95
+ cameras = kwargs.get("cameras", self.cameras)
96
+ if cameras is None:
97
+ msg = "Cameras must be specified either at initialization \
98
+ or in the forward pass of cleanShader"
99
+
100
+ raise ValueError(msg)
101
+
102
+ # get renderer output
103
+ blend_params = kwargs.get("blend_params", self.blend_params)
104
+ texels = meshes.sample_textures(fragments)
105
+ images = blending.softmax_rgb_blend(
106
+ texels, fragments, blend_params, znear=-256, zfar=256
107
+ )
108
+
109
+ return images
110
+
111
+
112
+ class Render:
113
+ def __init__(self, size=512, device=torch.device("cuda:0")):
114
+ self.device = device
115
+ self.mesh_y_center = 100.0
116
+ self.dis = 100.0
117
+ self.scale = 1.0
118
+ self.size = size
119
+ self.cam_pos = [(0, 100, 100)]
120
+
121
+ self.mesh = None
122
+ self.deform_mesh = None
123
+ self.pcd = None
124
+ self.renderer = None
125
+ self.meshRas = None
126
+ self.type = None
127
+ self.knn = None
128
+ self.knn_inverse = None
129
+
130
+ self.smpl_seg = None
131
+ self.smpl_cmap = None
132
+
133
+ self.smplx = SMPLX()
134
+
135
+ self.uv_rasterizer = util.Pytorch3dRasterizer(self.size)
136
+
137
+ def get_camera(self, cam_id):
138
+
139
+ R, T = look_at_view_transform(
140
+ eye=[self.cam_pos[cam_id]],
141
+ at=((0, self.mesh_y_center, 0),),
142
+ up=((0, 1, 0),),
143
+ )
144
+
145
+ camera = FoVOrthographicCameras(
146
+ device=self.device,
147
+ R=R,
148
+ T=T,
149
+ znear=100.0,
150
+ zfar=-100.0,
151
+ max_y=100.0,
152
+ min_y=-100.0,
153
+ max_x=100.0,
154
+ min_x=-100.0,
155
+ scale_xyz=(self.scale * np.ones(3),),
156
+ )
157
+
158
+ return camera
159
+
160
+ def init_renderer(self, camera, type="clean_mesh", bg="gray"):
161
+
162
+ if "mesh" in type:
163
+
164
+ # rasterizer
165
+ self.raster_settings_mesh = RasterizationSettings(
166
+ image_size=self.size,
167
+ blur_radius=np.log(1.0 / 1e-4) * 1e-7,
168
+ faces_per_pixel=30,
169
+ )
170
+ self.meshRas = MeshRasterizer(
171
+ cameras=camera, raster_settings=self.raster_settings_mesh
172
+ )
173
+
174
+ if bg == "black":
175
+ blendparam = BlendParams(1e-4, 1e-4, (0.0, 0.0, 0.0))
176
+ elif bg == "white":
177
+ blendparam = BlendParams(1e-4, 1e-8, (1.0, 1.0, 1.0))
178
+ elif bg == "gray":
179
+ blendparam = BlendParams(1e-4, 1e-8, (0.5, 0.5, 0.5))
180
+
181
+ if type == "ori_mesh":
182
+
183
+ lights = PointLights(
184
+ device=self.device,
185
+ ambient_color=((0.8, 0.8, 0.8),),
186
+ diffuse_color=((0.2, 0.2, 0.2),),
187
+ specular_color=((0.0, 0.0, 0.0),),
188
+ location=[[0.0, 200.0, 0.0]],
189
+ )
190
+
191
+ self.renderer = MeshRenderer(
192
+ rasterizer=self.meshRas,
193
+ shader=SoftPhongShader(
194
+ device=self.device,
195
+ cameras=camera,
196
+ lights=lights,
197
+ blend_params=blendparam,
198
+ ),
199
+ )
200
+
201
+ if type == "silhouette":
202
+ self.raster_settings_silhouette = RasterizationSettings(
203
+ image_size=self.size,
204
+ blur_radius=np.log(1.0 / 1e-4 - 1.0) * 5e-5,
205
+ faces_per_pixel=50,
206
+ cull_backfaces=True,
207
+ )
208
+
209
+ self.silhouetteRas = MeshRasterizer(
210
+ cameras=camera, raster_settings=self.raster_settings_silhouette
211
+ )
212
+ self.renderer = MeshRenderer(
213
+ rasterizer=self.silhouetteRas, shader=SoftSilhouetteShader()
214
+ )
215
+
216
+ if type == "pointcloud":
217
+ self.raster_settings_pcd = PointsRasterizationSettings(
218
+ image_size=self.size, radius=0.006, points_per_pixel=10
219
+ )
220
+
221
+ self.pcdRas = PointsRasterizer(
222
+ cameras=camera, raster_settings=self.raster_settings_pcd
223
+ )
224
+ self.renderer = PointsRenderer(
225
+ rasterizer=self.pcdRas,
226
+ compositor=AlphaCompositor(background_color=(0, 0, 0)),
227
+ )
228
+
229
+ if type == "clean_mesh":
230
+
231
+ self.renderer = MeshRenderer(
232
+ rasterizer=self.meshRas,
233
+ shader=cleanShader(
234
+ device=self.device, cameras=camera, blend_params=blendparam
235
+ ),
236
+ )
237
+
238
+ def VF2Mesh(self, verts, faces):
239
+
240
+ if not torch.is_tensor(verts):
241
+ verts = torch.tensor(verts)
242
+ if not torch.is_tensor(faces):
243
+ faces = torch.tensor(faces)
244
+
245
+ if verts.ndimension() == 2:
246
+ verts = verts.unsqueeze(0).float()
247
+ if faces.ndimension() == 2:
248
+ faces = faces.unsqueeze(0).long()
249
+
250
+ verts = verts.to(self.device)
251
+ faces = faces.to(self.device)
252
+
253
+ mesh = Meshes(verts, faces).to(self.device)
254
+
255
+ mesh.textures = TexturesVertex(
256
+ verts_features=(mesh.verts_normals_padded() + 1.0) * 0.5
257
+ )
258
+
259
+ return mesh
260
+
261
+ def load_meshes(self, verts, faces):
262
+ """load mesh into the pytorch3d renderer
263
+
264
+ Args:
265
+ verts ([N,3]): verts
266
+ faces ([N,3]): faces
268
+ """
269
+
270
+ # camera setting
271
+ self.scale = 100.0
272
+ self.mesh_y_center = 0.0
273
+
274
+ self.cam_pos = [
275
+ (0, self.mesh_y_center, 100.0),
276
+ (100.0, self.mesh_y_center, 0),
277
+ (0, self.mesh_y_center, -100.0),
278
+ (-100.0, self.mesh_y_center, 0),
279
+ ]
280
+
281
+ self.type = "color"
282
+
283
+ if isinstance(verts, list):
284
+ self.meshes = []
285
+ for V, F in zip(verts, faces):
286
+ self.meshes.append(self.VF2Mesh(V, F))
287
+ else:
288
+ self.meshes = [self.VF2Mesh(verts, faces)]
289
+
290
+ def get_depth_map(self, cam_ids=[0, 2]):
291
+
292
+ depth_maps = []
293
+ for cam_id in cam_ids:
294
+ self.init_renderer(self.get_camera(cam_id), "clean_mesh", "gray")
295
+ fragments = self.meshRas(self.meshes[0])
296
+ depth_map = fragments.zbuf[..., 0].squeeze(0)
297
+ if cam_id == 2:
298
+ depth_map = torch.fliplr(depth_map)
299
+ depth_maps.append(depth_map)
300
+
301
+ return depth_maps
302
+
303
+ def get_rgb_image(self, cam_ids=[0, 2]):
304
+
305
+ images = []
306
+ for cam_id in range(len(self.cam_pos)):
307
+ if cam_id in cam_ids:
308
+ self.init_renderer(self.get_camera(
309
+ cam_id), "clean_mesh", "gray")
310
+ if len(cam_ids) == 4:
311
+ rendered_img = (
312
+ self.renderer(self.meshes[0])[
313
+ 0:1, :, :, :3].permute(0, 3, 1, 2)
314
+ - 0.5
315
+ ) * 2.0
316
+ else:
317
+ rendered_img = (
318
+ self.renderer(self.meshes[0])[
319
+ 0:1, :, :, :3].permute(0, 3, 1, 2)
320
+ - 0.5
321
+ ) * 2.0
322
+ if cam_id == 2 and len(cam_ids) == 2:
323
+ rendered_img = torch.flip(rendered_img, dims=[3])
324
+ images.append(rendered_img)
325
+
326
+ return images
327
+
328
+ def get_rendered_video(self, images, save_path):
329
+
330
+ tmp_path = save_path.replace('cloth', 'tmp')
331
+
332
+ self.cam_pos = []
333
+ for angle in range(0, 360, 3):
334
+ self.cam_pos.append(
335
+ (
336
+ 100.0 * math.cos(np.pi / 180 * angle),
337
+ self.mesh_y_center,
338
+ 100.0 * math.sin(np.pi / 180 * angle),
339
+ )
340
+ )
341
+
342
+ old_shape = np.array(images[0].shape[:2])
343
+ new_shape = np.around(
344
+ (self.size / old_shape[0]) * old_shape).astype(int)
345
+
346
+ fourcc = cv2.VideoWriter_fourcc(*"mp4v")
347
+ video = cv2.VideoWriter(
348
+ tmp_path, fourcc, 30, (self.size * len(self.meshes) +
349
+ new_shape[1] * len(images), self.size)
350
+ )
351
+
352
+ pbar = tqdm(range(len(self.cam_pos)))
353
+ pbar.set_description(colored(f"exporting video {os.path.basename(save_path)}...", "blue"))
354
+ for cam_id in pbar:
355
+ self.init_renderer(self.get_camera(cam_id), "clean_mesh", "gray")
356
+
357
+ img_lst = [
358
+ np.array(Image.fromarray(img).resize(new_shape[::-1])).astype(np.uint8)[
359
+ :, :, [2, 1, 0]
360
+ ]
361
+ for img in images
362
+ ]
363
+
364
+ for mesh in self.meshes:
365
+ rendered_img = (
366
+ (self.renderer(mesh)[0, :, :, :3] * 255.0)
367
+ .detach()
368
+ .cpu()
369
+ .numpy()
370
+ .astype(np.uint8)
371
+ )
372
+
373
+ img_lst.append(rendered_img)
374
+ final_img = np.concatenate(img_lst, axis=1)
375
+ video.write(final_img)
376
+
377
+ video.release()
378
+
379
+ os.system(f'ffmpeg -y -loglevel quiet -stats -i {tmp_path} -c:v libx264 {save_path}')
380
+
381
+ def get_silhouette_image(self, cam_ids=[0, 2]):
382
+
383
+ images = []
384
+ for cam_id in range(len(self.cam_pos)):
385
+ if cam_id in cam_ids:
386
+ self.init_renderer(self.get_camera(cam_id), "silhouette")
387
+ rendered_img = self.renderer(self.meshes[0])[0:1, :, :, 3]
388
+ if cam_id == 2 and len(cam_ids) == 2:
389
+ rendered_img = torch.flip(rendered_img, dims=[2])
390
+ images.append(rendered_img)
391
+
392
+ return images
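
Note (not part of the commit): a minimal usage sketch of the `Render` class above. It assumes a triangle mesh is available as NumPy arrays, that a CUDA device and the pytorch3d dependencies are installed, and the file names and shapes below are hypothetical.

    import numpy as np
    import torch
    from lib.common.render import Render

    # Hypothetical inputs: a triangle mesh as NumPy arrays, roughly unit-scaled.
    verts = np.load("mesh_verts.npy")    # (N, 3) float
    faces = np.load("mesh_faces.npy")    # (M, 3) int

    render = Render(size=512, device=torch.device("cuda:0"))
    render.load_meshes(verts, faces)     # sets up 4 orthographic cameras around the mesh

    # Front/back renderings (cam_id 0 = front, 2 = back).
    rgb_front, rgb_back = render.get_rgb_image(cam_ids=[0, 2])          # each [1, 3, H, W] in [-1, 1]
    depth_front, depth_back = render.get_depth_map(cam_ids=[0, 2])      # each [H, W] z-buffer
    sil_front, sil_back = render.get_silhouette_image(cam_ids=[0, 2])   # each [1, H, W] alpha

The flips inside `get_rgb_image` / `get_depth_map` / `get_silhouette_image` mirror the back view so that it is pixel-aligned with the front view.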
lib/common/render_utils.py ADDED
@@ -0,0 +1,221 @@
1
+
2
+ # -*- coding: utf-8 -*-
3
+
4
+ # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
5
+ # holder of all proprietary rights on this computer program.
6
+ # You can only use this computer program if you have closed
7
+ # a license agreement with MPG or you get the right to use the computer
8
+ # program from someone who is authorized to grant you that right.
9
+ # Any use of the computer program without a valid license is prohibited and
10
+ # liable to prosecution.
11
+ #
12
+ # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
13
+ # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
14
+ # for Intelligent Systems. All rights reserved.
15
+ #
16
+ # Contact: ps-license@tuebingen.mpg.de
17
+
18
+ import torch
19
+ from torch import nn
20
+ import trimesh
21
+ import math
22
+ from typing import NewType
23
+ from pytorch3d.structures import Meshes
24
+ from pytorch3d.renderer.mesh import rasterize_meshes
25
+
26
+ Tensor = NewType('Tensor', torch.Tensor)
27
+
28
+
29
+ def solid_angles(points: Tensor,
30
+ triangles: Tensor,
31
+ thresh: float = 1e-8) -> Tensor:
32
+ ''' Compute solid angle between the input points and triangles
33
+ Follows the method described in:
34
+ The Solid Angle of a Plane Triangle
35
+ A. VAN OOSTEROM AND J. STRACKEE
36
+ IEEE TRANSACTIONS ON BIOMEDICAL ENGINEERING,
37
+ VOL. BME-30, NO. 2, FEBRUARY 1983
38
+ Parameters
39
+ -----------
40
+ points: BxQx3
41
+ Tensor of input query points
42
+ triangles: BxFx3x3
43
+ Target triangles
44
+ thresh: float
45
+ float threshold
46
+ Returns
47
+ -------
48
+ solid_angles: BxQxF
49
+ A tensor containing the solid angle between all query points
50
+ and input triangles
51
+ '''
52
+ # Center the triangles on the query points. Size should be BxQxFx3x3
53
+ centered_tris = triangles[:, None] - points[:, :, None, None]
54
+
55
+ # BxQxFx3
56
+ norms = torch.norm(centered_tris, dim=-1)
57
+
58
+ # Should be BxQxFx3
59
+ cross_prod = torch.cross(centered_tris[:, :, :, 1],
60
+ centered_tris[:, :, :, 2],
61
+ dim=-1)
62
+ # Should be BxQxF
63
+ numerator = (centered_tris[:, :, :, 0] * cross_prod).sum(dim=-1)
64
+ del cross_prod
65
+
66
+ dot01 = (centered_tris[:, :, :, 0] * centered_tris[:, :, :, 1]).sum(dim=-1)
67
+ dot12 = (centered_tris[:, :, :, 1] * centered_tris[:, :, :, 2]).sum(dim=-1)
68
+ dot02 = (centered_tris[:, :, :, 0] * centered_tris[:, :, :, 2]).sum(dim=-1)
69
+ del centered_tris
70
+
71
+ denominator = (norms.prod(dim=-1) + dot01 * norms[:, :, :, 2] +
72
+ dot02 * norms[:, :, :, 1] + dot12 * norms[:, :, :, 0])
73
+ del dot01, dot12, dot02, norms
74
+
75
+ # Should be BxQxF
76
+ solid_angle = torch.atan2(numerator, denominator)
77
+ del numerator, denominator
78
+
79
+ torch.cuda.empty_cache()
80
+
81
+ return 2 * solid_angle
82
+
83
+
84
+ def winding_numbers(points: Tensor,
85
+ triangles: Tensor,
86
+ thresh: float = 1e-8) -> Tensor:
87
+ ''' Compute generalized winding numbers for inside/outside queries.
+ References:
+ "Robust Inside-Outside Segmentation using Generalized Winding Numbers",
+ Alec Jacobson, Ladislav Kavan, Olga Sorkine-Hornung, SIGGRAPH 2013.
+ "Fast Winding Numbers for Soups and Clouds",
+ Gavin Barill, Neil G. Dickson, Ryan Schmidt, David I.W. Levin, Alec Jacobson, SIGGRAPH 2018.
98
+ Parameters
99
+ -----------
100
+ points: BxQx3
101
+ Tensor of input query points
102
+ triangles: BxFx3x3
103
+ Target triangles
104
+ thresh: float
105
+ float threshold
106
+ Returns
107
+ -------
108
+ winding_numbers: BxQ
109
+ A tensor containing the Generalized winding numbers
110
+ '''
111
+ # The generalized winding number is the sum of solid angles of the point
112
+ # with respect to all triangles.
113
+ return 1 / (4 * math.pi) * solid_angles(points, triangles,
114
+ thresh=thresh).sum(dim=-1)
115
+
116
+
117
+ def batch_contains(verts, faces, points):
118
+
119
+ B = verts.shape[0]
120
+ N = points.shape[1]
121
+
122
+ verts = verts.detach().cpu()
123
+ faces = faces.detach().cpu()
124
+ points = points.detach().cpu()
125
+ contains = torch.zeros(B, N)
126
+
127
+ for i in range(B):
128
+ contains[i] = torch.as_tensor(
129
+ trimesh.Trimesh(verts[i], faces[i]).contains(points[i]))
130
+
131
+ return 2.0 * (contains - 0.5)
132
+
133
+
134
+ def dict2obj(d):
135
+ # if isinstance(d, list):
136
+ # d = [dict2obj(x) for x in d]
137
+ if not isinstance(d, dict):
138
+ return d
139
+
140
+ class C(object):
141
+ pass
142
+
143
+ o = C()
144
+ for k in d:
145
+ o.__dict__[k] = dict2obj(d[k])
146
+ return o
147
+
148
+
149
+ def face_vertices(vertices, faces):
150
+ """
151
+ :param vertices: [batch size, number of vertices, 3]
152
+ :param faces: [batch size, number of faces, 3]
153
+ :return: [batch size, number of faces, 3, 3]
154
+ """
155
+
156
+ bs, nv = vertices.shape[:2]
157
+ bs, nf = faces.shape[:2]
158
+ device = vertices.device
159
+ faces = faces + (torch.arange(bs, dtype=torch.int32).to(device) *
160
+ nv)[:, None, None]
161
+ vertices = vertices.reshape((bs * nv, vertices.shape[-1]))
162
+
163
+ return vertices[faces.long()]
164
+
165
+
166
+ class Pytorch3dRasterizer(nn.Module):
167
+ """ Borrowed from https://github.com/facebookresearch/pytorch3d
168
+ Notice:
169
+ x,y,z are in image space, normalized
170
+ can only render square images for now
171
+ """
172
+
173
+ def __init__(self, image_size=224):
174
+ """
175
+ use fixed raster_settings for rendering faces
176
+ """
177
+ super().__init__()
178
+ raster_settings = {
179
+ 'image_size': image_size,
180
+ 'blur_radius': 0.0,
181
+ 'faces_per_pixel': 1,
182
+ 'bin_size': None,
183
+ 'max_faces_per_bin': None,
184
+ 'perspective_correct': True,
185
+ 'cull_backfaces': True,
186
+ }
187
+ raster_settings = dict2obj(raster_settings)
188
+ self.raster_settings = raster_settings
189
+
190
+ def forward(self, vertices, faces, attributes=None):
191
+ fixed_vertices = vertices.clone()
192
+ fixed_vertices[..., :2] = -fixed_vertices[..., :2]
193
+ meshes_screen = Meshes(verts=fixed_vertices.float(),
194
+ faces=faces.long())
195
+ raster_settings = self.raster_settings
196
+ pix_to_face, zbuf, bary_coords, dists = rasterize_meshes(
197
+ meshes_screen,
198
+ image_size=raster_settings.image_size,
199
+ blur_radius=raster_settings.blur_radius,
200
+ faces_per_pixel=raster_settings.faces_per_pixel,
201
+ bin_size=raster_settings.bin_size,
202
+ max_faces_per_bin=raster_settings.max_faces_per_bin,
203
+ perspective_correct=raster_settings.perspective_correct,
204
+ )
205
+ vismask = (pix_to_face > -1).float()
206
+ D = attributes.shape[-1]
207
+ attributes = attributes.clone()
208
+ attributes = attributes.view(attributes.shape[0] * attributes.shape[1],
209
+ 3, attributes.shape[-1])
210
+ N, H, W, K, _ = bary_coords.shape
211
+ mask = pix_to_face == -1
212
+ pix_to_face = pix_to_face.clone()
213
+ pix_to_face[mask] = 0
214
+ idx = pix_to_face.view(N * H * W * K, 1, 1).expand(N * H * W * K, 3, D)
215
+ pixel_face_vals = attributes.gather(0, idx).view(N, H, W, K, 3, D)
216
+ pixel_vals = (bary_coords[..., None] * pixel_face_vals).sum(dim=-2)
217
+ pixel_vals[mask] = 0 # Replace masked values in output.
218
+ pixel_vals = pixel_vals[:, :, :, 0].permute(0, 3, 1, 2)
219
+ pixel_vals = torch.cat(
220
+ [pixel_vals, vismask[:, :, :, 0][:, None, :, :]], dim=1)
221
+ return pixel_vals
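
Note (not part of the commit): a short sketch of how `face_vertices` and `winding_numbers` fit together for inside/outside queries. The tensors below are random placeholders used only to illustrate shapes; the "~1 inside / ~0 outside" reading only holds for a genuinely watertight mesh.

    import torch
    from lib.common.render_utils import face_vertices, winding_numbers

    verts = torch.rand(1, 1000, 3)                 # [B, N, 3] mesh vertices
    faces = torch.randint(0, 1000, (1, 2000, 3))   # [B, F, 3] triangle indices
    points = torch.rand(1, 128, 3)                 # [B, Q, 3] query points

    triangles = face_vertices(verts, faces)        # [B, F, 3, 3] per-face vertex triplets
    wn = winding_numbers(points, triangles)        # [B, Q] generalized winding numbers
    inside = wn > 0.5                              # ~1 inside, ~0 outside for a watertight mesh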
lib/common/seg3d_lossless.py ADDED
@@ -0,0 +1,604 @@
1
+
2
+ # -*- coding: utf-8 -*-
3
+
4
+ # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
5
+ # holder of all proprietary rights on this computer program.
6
+ # You can only use this computer program if you have closed
7
+ # a license agreement with MPG or you get the right to use the computer
8
+ # program from someone who is authorized to grant you that right.
9
+ # Any use of the computer program without a valid license is prohibited and
10
+ # liable to prosecution.
11
+ #
12
+ # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
13
+ # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
14
+ # for Intelligent Systems. All rights reserved.
15
+ #
16
+ # Contact: ps-license@tuebingen.mpg.de
17
+
18
+
19
+ from .seg3d_utils import (
20
+ create_grid3D,
21
+ plot_mask3D,
22
+ SmoothConv3D,
23
+ )
24
+
25
+ import torch
26
+ import torch.nn as nn
27
+ import numpy as np
28
+ import torch.nn.functional as F
29
+ import mcubes
30
+ from kaolin.ops.conversions import voxelgrids_to_trianglemeshes
31
+ import logging
32
+
33
+ logging.getLogger("lightning").setLevel(logging.ERROR)
34
+
35
+
36
+ class Seg3dLossless(nn.Module):
37
+ def __init__(self,
38
+ query_func,
39
+ b_min,
40
+ b_max,
41
+ resolutions,
42
+ channels=1,
43
+ balance_value=0.5,
44
+ align_corners=False,
45
+ visualize=False,
46
+ debug=False,
47
+ use_cuda_impl=False,
48
+ faster=False,
49
+ use_shadow=False,
50
+ **kwargs):
51
+ """
52
+ align_corners: same with how you process gt. (grid_sample / interpolate)
53
+ """
54
+ super().__init__()
55
+ self.query_func = query_func
56
+ self.register_buffer(
57
+ 'b_min',
58
+ torch.tensor(b_min).float().unsqueeze(1)) # [bz, 1, 3]
59
+ self.register_buffer(
60
+ 'b_max',
61
+ torch.tensor(b_max).float().unsqueeze(1)) # [bz, 1, 3]
62
+
63
+ # ti.init(arch=ti.cuda)
64
+ # self.mciso_taichi = MCISO(dim=3, N=resolutions[-1]-1)
65
+
66
+ if type(resolutions[0]) is int:
67
+ resolutions = torch.tensor([(res, res, res)
68
+ for res in resolutions])
69
+ else:
70
+ resolutions = torch.tensor(resolutions)
71
+ self.register_buffer('resolutions', resolutions)
72
+ self.batchsize = self.b_min.size(0)
73
+ assert self.batchsize == 1
74
+ self.balance_value = balance_value
75
+ self.channels = channels
76
+ assert self.channels == 1
77
+ self.align_corners = align_corners
78
+ self.visualize = visualize
79
+ self.debug = debug
80
+ self.use_cuda_impl = use_cuda_impl
81
+ self.faster = faster
82
+ self.use_shadow = use_shadow
83
+
84
+ for resolution in resolutions:
85
+ assert resolution[0] % 2 == 1 and resolution[1] % 2 == 1, \
86
+ f"resolution {resolution} need to be odd becuase of align_corner."
87
+
88
+ # init first resolution
89
+ init_coords = create_grid3D(0,
90
+ resolutions[-1] - 1,
91
+ steps=resolutions[0]) # [N, 3]
92
+ init_coords = init_coords.unsqueeze(0).repeat(self.batchsize, 1,
93
+ 1) # [bz, N, 3]
94
+ self.register_buffer('init_coords', init_coords)
95
+
96
+ # some useful tensors
97
+ calculated = torch.zeros(
98
+ (self.resolutions[-1][2], self.resolutions[-1][1],
99
+ self.resolutions[-1][0]),
100
+ dtype=torch.bool)
101
+ self.register_buffer('calculated', calculated)
102
+
103
+ gird8_offsets = torch.stack(
104
+ torch.meshgrid([
105
+ torch.tensor([-1, 0, 1]),
106
+ torch.tensor([-1, 0, 1]),
107
+ torch.tensor([-1, 0, 1])
108
+ ])).int().view(3, -1).t() # [27, 3]
109
+ self.register_buffer('gird8_offsets', gird8_offsets)
110
+
111
+ # smooth convs
112
+ self.smooth_conv3x3 = SmoothConv3D(in_channels=1,
113
+ out_channels=1,
114
+ kernel_size=3)
115
+ self.smooth_conv5x5 = SmoothConv3D(in_channels=1,
116
+ out_channels=1,
117
+ kernel_size=5)
118
+ self.smooth_conv7x7 = SmoothConv3D(in_channels=1,
119
+ out_channels=1,
120
+ kernel_size=7)
121
+ self.smooth_conv9x9 = SmoothConv3D(in_channels=1,
122
+ out_channels=1,
123
+ kernel_size=9)
124
+
125
+ def batch_eval(self, coords, **kwargs):
126
+ """
127
+ coords: in the coordinates of last resolution
128
+ **kwargs: for query_func
129
+ """
130
+ coords = coords.detach()
131
+ # normalize coords to fit in [b_min, b_max]
132
+ if self.align_corners:
133
+ coords2D = coords.float() / (self.resolutions[-1] - 1)
134
+ else:
135
+ step = 1.0 / self.resolutions[-1].float()
136
+ coords2D = coords.float() / self.resolutions[-1] + step / 2
137
+ coords2D = coords2D * (self.b_max - self.b_min) + self.b_min
138
+ # query function
139
+ occupancys = self.query_func(**kwargs, points=coords2D)
140
+ if type(occupancys) is list:
141
+ occupancys = torch.stack(occupancys) # [bz, C, N]
142
+ assert len(occupancys.size()) == 3, \
143
+ "query_func should return a occupancy with shape of [bz, C, N]"
144
+ return occupancys
145
+
146
+ def forward(self, **kwargs):
147
+ if self.faster:
148
+ return self._forward_faster(**kwargs)
149
+ else:
150
+ return self._forward(**kwargs)
151
+
152
+ def _forward_faster(self, **kwargs):
153
+ """
154
+ In faster mode, we make following changes to exchange accuracy for speed:
155
+ 1. no conflict checking: 4.88 fps -> 6.56 fps
156
+ 2. smooth_conv9x9 ~ smooth_conv3x3 for different resolution
157
+ 3. last step no examine
158
+ """
159
+ final_W = self.resolutions[-1][0]
160
+ final_H = self.resolutions[-1][1]
161
+ final_D = self.resolutions[-1][2]
162
+
163
+ for resolution in self.resolutions:
164
+ W, H, D = resolution
165
+ stride = (self.resolutions[-1] - 1) / (resolution - 1)
166
+
167
+ # first step
168
+ if torch.equal(resolution, self.resolutions[0]):
169
+ coords = self.init_coords.clone() # torch.long
170
+ occupancys = self.batch_eval(coords, **kwargs)
171
+ occupancys = occupancys.view(self.batchsize, self.channels, D,
172
+ H, W)
173
+ if (occupancys > 0.5).sum() == 0:
174
+ # return F.interpolate(
175
+ # occupancys, size=(final_D, final_H, final_W),
176
+ # mode="linear", align_corners=True)
177
+ return None
178
+
179
+ if self.visualize:
180
+ self.plot(occupancys, coords, final_D, final_H, final_W)
181
+
182
+ with torch.no_grad():
183
+ coords_accum = coords / stride
184
+
185
+ # last step
186
+ elif torch.equal(resolution, self.resolutions[-1]):
187
+
188
+ with torch.no_grad():
189
+ # here true is correct!
190
+ valid = F.interpolate(
191
+ (occupancys > self.balance_value).float(),
192
+ size=(D, H, W),
193
+ mode="trilinear",
194
+ align_corners=True)
195
+
196
+ # here true is correct!
197
+ occupancys = F.interpolate(occupancys.float(),
198
+ size=(D, H, W),
199
+ mode="trilinear",
200
+ align_corners=True)
201
+
202
+ # is_boundary = (valid > 0.0) & (valid < 1.0)
203
+ is_boundary = valid == 0.5
204
+
205
+ # next steps
206
+ else:
207
+ coords_accum *= 2
208
+
209
+ with torch.no_grad():
210
+ # here true is correct!
211
+ valid = F.interpolate(
212
+ (occupancys > self.balance_value).float(),
213
+ size=(D, H, W),
214
+ mode="trilinear",
215
+ align_corners=True)
216
+
217
+ # here true is correct!
218
+ occupancys = F.interpolate(occupancys.float(),
219
+ size=(D, H, W),
220
+ mode="trilinear",
221
+ align_corners=True)
222
+
223
+ is_boundary = (valid > 0.0) & (valid < 1.0)
224
+
225
+ with torch.no_grad():
226
+ if torch.equal(resolution, self.resolutions[1]):
227
+ is_boundary = (self.smooth_conv9x9(is_boundary.float())
228
+ > 0)[0, 0]
229
+ elif torch.equal(resolution, self.resolutions[2]):
230
+ is_boundary = (self.smooth_conv7x7(is_boundary.float())
231
+ > 0)[0, 0]
232
+ else:
233
+ is_boundary = (self.smooth_conv3x3(is_boundary.float())
234
+ > 0)[0, 0]
235
+
236
+ coords_accum = coords_accum.long()
237
+ is_boundary[coords_accum[0, :, 2], coords_accum[0, :, 1],
238
+ coords_accum[0, :, 0]] = False
239
+ point_coords = is_boundary.permute(
240
+ 2, 1, 0).nonzero(as_tuple=False).unsqueeze(0)
241
+ point_indices = (point_coords[:, :, 2] * H * W +
242
+ point_coords[:, :, 1] * W +
243
+ point_coords[:, :, 0])
244
+
245
+ R, C, D, H, W = occupancys.shape
246
+
247
+ # inferred value
248
+ coords = point_coords * stride
249
+
250
+ if coords.size(1) == 0:
251
+ continue
252
+ occupancys_topk = self.batch_eval(coords, **kwargs)
253
+
254
+ # put mask point predictions to the right places on the upsampled grid.
255
+ R, C, D, H, W = occupancys.shape
256
+ point_indices = point_indices.unsqueeze(1).expand(-1, C, -1)
257
+ occupancys = (occupancys.reshape(R, C, D * H * W).scatter_(
258
+ 2, point_indices, occupancys_topk).view(R, C, D, H, W))
259
+
260
+ with torch.no_grad():
261
+ voxels = coords / stride
262
+ coords_accum = torch.cat([voxels, coords_accum],
263
+ dim=1).unique(dim=1)
264
+
265
+ return occupancys[0, 0]
266
+
267
+ def _forward(self, **kwargs):
268
+ """
269
+ output occupancy field would be:
270
+ (bz, C, res, res)
271
+ """
272
+ final_W = self.resolutions[-1][0]
273
+ final_H = self.resolutions[-1][1]
274
+ final_D = self.resolutions[-1][2]
275
+
276
+ calculated = self.calculated.clone()
277
+
278
+ for resolution in self.resolutions:
279
+ W, H, D = resolution
280
+ stride = (self.resolutions[-1] - 1) / (resolution - 1)
281
+
282
+ if self.visualize:
283
+ this_stage_coords = []
284
+
285
+ # first step
286
+ if torch.equal(resolution, self.resolutions[0]):
287
+ coords = self.init_coords.clone() # torch.long
288
+ occupancys = self.batch_eval(coords, **kwargs)
289
+ occupancys = occupancys.view(self.batchsize, self.channels, D,
290
+ H, W)
291
+
292
+ if self.visualize:
293
+ self.plot(occupancys, coords, final_D, final_H, final_W)
294
+
295
+ with torch.no_grad():
296
+ coords_accum = coords / stride
297
+ calculated[coords[0, :, 2], coords[0, :, 1],
298
+ coords[0, :, 0]] = True
299
+
300
+ # next steps
301
+ else:
302
+ coords_accum *= 2
303
+
304
+ with torch.no_grad():
305
+ # here true is correct!
306
+ valid = F.interpolate(
307
+ (occupancys > self.balance_value).float(),
308
+ size=(D, H, W),
309
+ mode="trilinear",
310
+ align_corners=True)
311
+
312
+ # here true is correct!
313
+ occupancys = F.interpolate(occupancys.float(),
314
+ size=(D, H, W),
315
+ mode="trilinear",
316
+ align_corners=True)
317
+
318
+ is_boundary = (valid > 0.0) & (valid < 1.0)
319
+
320
+ with torch.no_grad():
321
+ # TODO
322
+ if self.use_shadow and torch.equal(resolution,
323
+ self.resolutions[-1]):
324
+ # larger z means smaller depth here
325
+ depth_res = resolution[2].item()
326
+ depth_index = torch.linspace(0,
327
+ depth_res - 1,
328
+ steps=depth_res).type_as(
329
+ occupancys.device)
330
+ depth_index_max = torch.max(
331
+ (occupancys > self.balance_value) *
332
+ (depth_index + 1),
333
+ dim=-1,
334
+ keepdim=True)[0] - 1
335
+ shadow = depth_index < depth_index_max
336
+ is_boundary[shadow] = False
337
+ is_boundary = is_boundary[0, 0]
338
+ else:
339
+ is_boundary = (self.smooth_conv3x3(is_boundary.float())
340
+ > 0)[0, 0]
341
+ # is_boundary = is_boundary[0, 0]
342
+
343
+ is_boundary[coords_accum[0, :, 2], coords_accum[0, :, 1],
344
+ coords_accum[0, :, 0]] = False
345
+ point_coords = is_boundary.permute(
346
+ 2, 1, 0).nonzero(as_tuple=False).unsqueeze(0)
347
+ point_indices = (point_coords[:, :, 2] * H * W +
348
+ point_coords[:, :, 1] * W +
349
+ point_coords[:, :, 0])
350
+
351
+ R, C, D, H, W = occupancys.shape
352
+ # interpolated value
353
+ occupancys_interp = torch.gather(
354
+ occupancys.reshape(R, C, D * H * W), 2,
355
+ point_indices.unsqueeze(1))
356
+
357
+ # inferred value
358
+ coords = point_coords * stride
359
+
360
+ if coords.size(1) == 0:
361
+ continue
362
+ occupancys_topk = self.batch_eval(coords, **kwargs)
363
+ if self.visualize:
364
+ this_stage_coords.append(coords)
365
+
366
+ # put mask point predictions to the right places on the upsampled grid.
367
+ R, C, D, H, W = occupancys.shape
368
+ point_indices = point_indices.unsqueeze(1).expand(-1, C, -1)
369
+ occupancys = (occupancys.reshape(R, C, D * H * W).scatter_(
370
+ 2, point_indices, occupancys_topk).view(R, C, D, H, W))
371
+
372
+ with torch.no_grad():
373
+ # conflicts
374
+ conflicts = ((occupancys_interp - self.balance_value) *
375
+ (occupancys_topk - self.balance_value) < 0)[0,
376
+ 0]
377
+
378
+ if self.visualize:
379
+ self.plot(occupancys, coords, final_D, final_H,
380
+ final_W)
381
+
382
+ voxels = coords / stride
383
+ coords_accum = torch.cat([voxels, coords_accum],
384
+ dim=1).unique(dim=1)
385
+ calculated[coords[0, :, 2], coords[0, :, 1],
386
+ coords[0, :, 0]] = True
387
+
388
+ while conflicts.sum() > 0:
389
+ if self.use_shadow and torch.equal(resolution,
390
+ self.resolutions[-1]):
391
+ break
392
+
393
+ with torch.no_grad():
394
+ conflicts_coords = coords[0, conflicts, :]
395
+
396
+ if self.debug:
397
+ self.plot(occupancys,
398
+ conflicts_coords.unsqueeze(0),
399
+ final_D,
400
+ final_H,
401
+ final_W,
402
+ title='conflicts')
403
+
404
+ conflicts_boundary = (conflicts_coords.int() +
405
+ self.gird8_offsets.unsqueeze(1) *
406
+ stride.int()).reshape(
407
+ -1, 3).long().unique(dim=0)
408
+ conflicts_boundary[:, 0] = (
409
+ conflicts_boundary[:, 0].clamp(
410
+ 0,
411
+ calculated.size(2) - 1))
412
+ conflicts_boundary[:, 1] = (
413
+ conflicts_boundary[:, 1].clamp(
414
+ 0,
415
+ calculated.size(1) - 1))
416
+ conflicts_boundary[:, 2] = (
417
+ conflicts_boundary[:, 2].clamp(
418
+ 0,
419
+ calculated.size(0) - 1))
420
+
421
+ coords = conflicts_boundary[calculated[
422
+ conflicts_boundary[:, 2], conflicts_boundary[:, 1],
423
+ conflicts_boundary[:, 0]] == False]
424
+
425
+ if self.debug:
426
+ self.plot(occupancys,
427
+ coords.unsqueeze(0),
428
+ final_D,
429
+ final_H,
430
+ final_W,
431
+ title='coords')
432
+
433
+ coords = coords.unsqueeze(0)
434
+ point_coords = coords / stride
435
+ point_indices = (point_coords[:, :, 2] * H * W +
436
+ point_coords[:, :, 1] * W +
437
+ point_coords[:, :, 0])
438
+
439
+ R, C, D, H, W = occupancys.shape
440
+ # interpolated value
441
+ occupancys_interp = torch.gather(
442
+ occupancys.reshape(R, C, D * H * W), 2,
443
+ point_indices.unsqueeze(1))
444
+
445
+ # inferred value
446
+ coords = point_coords * stride
447
+
448
+ if coords.size(1) == 0:
449
+ break
450
+ occupancys_topk = self.batch_eval(coords, **kwargs)
451
+ if self.visualize:
452
+ this_stage_coords.append(coords)
453
+
454
+ with torch.no_grad():
455
+ # conflicts
456
+ conflicts = ((occupancys_interp - self.balance_value) *
457
+ (occupancys_topk - self.balance_value) <
458
+ 0)[0, 0]
459
+
460
+ # put mask point predictions to the right places on the upsampled grid.
461
+ point_indices = point_indices.unsqueeze(1).expand(
462
+ -1, C, -1)
463
+ occupancys = (occupancys.reshape(R, C, D * H * W).scatter_(
464
+ 2, point_indices, occupancys_topk).view(R, C, D, H, W))
465
+
466
+ with torch.no_grad():
467
+ voxels = coords / stride
468
+ coords_accum = torch.cat([voxels, coords_accum],
469
+ dim=1).unique(dim=1)
470
+ calculated[coords[0, :, 2], coords[0, :, 1],
471
+ coords[0, :, 0]] = True
472
+
473
+ if self.visualize:
474
+ this_stage_coords = torch.cat(this_stage_coords, dim=1)
475
+ self.plot(occupancys, this_stage_coords, final_D, final_H,
476
+ final_W)
477
+
478
+ return occupancys[0, 0]
479
+
480
+ def plot(self,
481
+ occupancys,
482
+ coords,
483
+ final_D,
484
+ final_H,
485
+ final_W,
486
+ title='',
487
+ **kwargs):
488
+ final = F.interpolate(occupancys.float(),
489
+ size=(final_D, final_H, final_W),
490
+ mode="trilinear",
491
+ align_corners=True) # here true is correct!
492
+ x = coords[0, :, 0].to("cpu")
493
+ y = coords[0, :, 1].to("cpu")
494
+ z = coords[0, :, 2].to("cpu")
495
+
496
+ plot_mask3D(final[0, 0].to("cpu"), title, (x, y, z), **kwargs)
497
+
498
+ def find_vertices(self, sdf, direction="front"):
499
+ '''
500
+ - direction: "front" | "back" | "left" | "right"
501
+ '''
502
+ resolution = sdf.size(2)
503
+ if direction == "front":
504
+ pass
505
+ elif direction == "left":
506
+ sdf = sdf.permute(2, 1, 0)
507
+ elif direction == "back":
508
+ inv_idx = torch.arange(sdf.size(2) - 1, -1, -1).long()
509
+ sdf = sdf[inv_idx, :, :]
510
+ elif direction == "right":
511
+ inv_idx = torch.arange(sdf.size(2) - 1, -1, -1).long()
512
+ sdf = sdf[:, :, inv_idx]
513
+ sdf = sdf.permute(2, 1, 0)
514
+
515
+ inv_idx = torch.arange(sdf.size(2) - 1, -1, -1).long()
516
+ sdf = sdf[inv_idx, :, :]
517
+ sdf_all = sdf.permute(2, 1, 0)
518
+
519
+ # shadow
520
+ grad_v = (sdf_all > 0.5) * torch.linspace(
521
+ resolution, 1, steps=resolution).to(sdf.device)
522
+ grad_c = torch.ones_like(sdf_all) * torch.linspace(
523
+ 0, resolution - 1, steps=resolution).to(sdf.device)
524
+ max_v, max_c = grad_v.max(dim=2)
525
+ shadow = grad_c > max_c.view(resolution, resolution, 1)
526
+ keep = (sdf_all > 0.5) & (~shadow)
527
+
528
+ p1 = keep.nonzero(as_tuple=False).t() # [3, N]
529
+ p2 = p1.clone() # z
530
+ p2[2, :] = (p2[2, :] - 2).clamp(0, resolution)
531
+ p3 = p1.clone() # y
532
+ p3[1, :] = (p3[1, :] - 2).clamp(0, resolution)
533
+ p4 = p1.clone() # x
534
+ p4[0, :] = (p4[0, :] - 2).clamp(0, resolution)
535
+
536
+ v1 = sdf_all[p1[0, :], p1[1, :], p1[2, :]]
537
+ v2 = sdf_all[p2[0, :], p2[1, :], p2[2, :]]
538
+ v3 = sdf_all[p3[0, :], p3[1, :], p3[2, :]]
539
+ v4 = sdf_all[p4[0, :], p4[1, :], p4[2, :]]
540
+
541
+ X = p1[0, :].long() # [N,]
542
+ Y = p1[1, :].long() # [N,]
543
+ Z = p2[2, :].float() * (0.5 - v1) / (v2 - v1) + \
544
+ p1[2, :].float() * (v2 - 0.5) / (v2 - v1) # [N,]
545
+ Z = Z.clamp(0, resolution)
546
+
547
+ # normal
548
+ norm_z = v2 - v1
549
+ norm_y = v3 - v1
550
+ norm_x = v4 - v1
551
+ # print (v2.min(dim=0)[0], v2.max(dim=0)[0], v3.min(dim=0)[0], v3.max(dim=0)[0])
552
+
553
+ norm = torch.stack([norm_x, norm_y, norm_z], dim=1)
554
+ norm = norm / torch.norm(norm, p=2, dim=1, keepdim=True)
555
+
556
+ return X, Y, Z, norm
557
+
558
+ def render_normal(self, resolution, X, Y, Z, norm):
559
+ image = torch.ones((1, 3, resolution, resolution),
560
+ dtype=torch.float32).to(norm.device)
561
+ color = (norm + 1) / 2.0
562
+ color = color.clamp(0, 1)
563
+ image[0, :, Y, X] = color.t()
564
+ return image
565
+
566
+ def display(self, sdf):
567
+
568
+ # render
569
+ X, Y, Z, norm = self.find_vertices(sdf, direction="front")
570
+ image1 = self.render_normal(self.resolutions[-1, -1], X, Y, Z, norm)
571
+ X, Y, Z, norm = self.find_vertices(sdf, direction="left")
572
+ image2 = self.render_normal(self.resolutions[-1, -1], X, Y, Z, norm)
573
+ X, Y, Z, norm = self.find_vertices(sdf, direction="right")
574
+ image3 = self.render_normal(self.resolutions[-1, -1], X, Y, Z, norm)
575
+ X, Y, Z, norm = self.find_vertices(sdf, direction="back")
576
+ image4 = self.render_normal(self.resolutions[-1, -1], X, Y, Z, norm)
577
+
578
+ image = torch.cat([image1, image2, image3, image4], axis=3)
579
+ image = image.detach().cpu().numpy()[0].transpose(1, 2, 0) * 255.0
580
+
581
+ return np.uint8(image)
582
+
583
+ def export_mesh(self, occupancys):
584
+
585
+ final = occupancys[1:, 1:, 1:].contiguous()
586
+
587
+ if final.shape[0] > 256:
588
+ # for voxelgrid larger than 256^3, the required GPU memory will be > 9GB
589
+ # thus we use CPU marching_cube to avoid "CUDA out of memory"
590
+ occu_arr = final.detach().cpu().numpy() # non-smooth surface
591
+ # occu_arr = mcubes.smooth(final.detach().cpu().numpy()) # smooth surface
592
+ vertices, triangles = mcubes.marching_cubes(
593
+ occu_arr, self.balance_value)
594
+ verts = torch.as_tensor(vertices[:, [2, 1, 0]])
595
+ faces = torch.as_tensor(triangles.astype(
596
+ np.int64), dtype=torch.long)[:, [0, 2, 1]]
597
+ else:
598
+ torch.cuda.empty_cache()
599
+ vertices, triangles = voxelgrids_to_trianglemeshes(
600
+ final.unsqueeze(0))
601
+ verts = vertices[0][:, [2, 1, 0]].cpu()
602
+ faces = triangles[0][:, [0, 2, 1]].cpu()
603
+
604
+ return verts, faces
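
Note (not part of the commit): a rough sketch of driving `Seg3dLossless` with a toy occupancy oracle (a sphere standing in for an implicit network). It assumes a CUDA device and the kaolin dependency imported above are available; all shapes and bounds below are hypothetical.

    import torch
    from lib.common.seg3d_lossless import Seg3dLossless

    device = torch.device("cuda:0")

    def query_func(points, **kwargs):
        # points: [1, N, 3] world coordinates inside [b_min, b_max]
        occ = (points.norm(dim=-1, keepdim=True) < 0.5).float()   # occupied inside a sphere
        return occ.permute(0, 2, 1)                                # [bz, C, N] as required

    engine = Seg3dLossless(
        query_func=query_func,
        b_min=[[-1.0, -1.0, -1.0]],
        b_max=[[1.0, 1.0, 1.0]],
        resolutions=[17, 33, 65, 129, 257],   # each dimension must be odd (see the assert above)
        align_corners=True,
        faster=True,
    ).to(device)

    occ_grid = engine()                            # final-resolution occupancy volume, or None if empty
    verts, faces = engine.export_mesh(occ_grid)    # marching cubes -> mesh tensors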
lib/common/seg3d_utils.py ADDED
@@ -0,0 +1,392 @@
1
+
2
+ # -*- coding: utf-8 -*-
3
+
4
+ # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
5
+ # holder of all proprietary rights on this computer program.
6
+ # You can only use this computer program if you have closed
7
+ # a license agreement with MPG or you get the right to use the computer
8
+ # program from someone who is authorized to grant you that right.
9
+ # Any use of the computer program without a valid license is prohibited and
10
+ # liable to prosecution.
11
+ #
12
+ # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
13
+ # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
14
+ # for Intelligent Systems. All rights reserved.
15
+ #
16
+ # Contact: ps-license@tuebingen.mpg.de
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+ import matplotlib.pyplot as plt
22
+
23
+
24
+ def plot_mask2D(mask,
25
+ title="",
26
+ point_coords=None,
27
+ figsize=10,
28
+ point_marker_size=5):
29
+ '''
30
+ Simple plotting tool to show intermediate mask predictions and points
31
+ where PointRend is applied.
32
+
33
+ Args:
34
+ mask (Tensor): mask prediction of shape HxW
35
+ title (str): title for the plot
36
+ point_coords ((Tensor, Tensor)): x and y point coordinates
37
+ figsize (int): size of the figure to plot
38
+ point_marker_size (int): marker size for points
39
+ '''
40
+
41
+ H, W = mask.shape
42
+ plt.figure(figsize=(figsize, figsize))
43
+ if title:
44
+ title += ", "
45
+ plt.title("{}resolution {}x{}".format(title, H, W), fontsize=30)
46
+ plt.ylabel(H, fontsize=30)
47
+ plt.xlabel(W, fontsize=30)
48
+ plt.xticks([], [])
49
+ plt.yticks([], [])
50
+ plt.imshow(mask.detach(),
51
+ interpolation="nearest",
52
+ cmap=plt.get_cmap('gray'))
53
+ if point_coords is not None:
54
+ plt.scatter(x=point_coords[0],
55
+ y=point_coords[1],
56
+ color="red",
57
+ s=point_marker_size,
58
+ clip_on=True)
59
+ plt.xlim(-0.5, W - 0.5)
60
+ plt.ylim(H - 0.5, -0.5)
61
+ plt.show()
62
+
63
+
64
+ def plot_mask3D(mask=None,
65
+ title="",
66
+ point_coords=None,
67
+ figsize=1500,
68
+ point_marker_size=8,
69
+ interactive=True):
70
+ '''
71
+ Simple plotting tool to show intermediate mask predictions and points
72
+ where PointRend is applied.
73
+
74
+ Args:
75
+ mask (Tensor): mask prediction of shape DxHxW
76
+ title (str): title for the plot
77
+ point_coords ((Tensor, Tensor, Tensor)): x and y and z point coordinates
78
+ figsize (int): size of the figure to plot
79
+ point_marker_size (int): marker size for points
80
+ '''
81
+ import trimesh
82
+ import vtkplotter
83
+ from skimage import measure
84
+
85
+ vp = vtkplotter.Plotter(title=title, size=(figsize, figsize))
86
+ vis_list = []
87
+
88
+ if mask is not None:
89
+ mask = mask.detach().to("cpu").numpy()
90
+ mask = mask.transpose(2, 1, 0)
91
+
92
+ # marching cube to find surface
93
+ verts, faces, normals, values = measure.marching_cubes(
94
+ mask, 0.5, gradient_direction='ascent')
95
+
96
+ # create a mesh
97
+ mesh = trimesh.Trimesh(verts, faces)
98
+ mesh.visual.face_colors = [200, 200, 250, 100]
99
+ vis_list.append(mesh)
100
+
101
+ if point_coords is not None:
102
+ point_coords = torch.stack(point_coords, 1).to("cpu").numpy()
103
+
104
+ # import numpy as np
105
+ # select_x = np.logical_and(point_coords[:, 0] >= 16, point_coords[:, 0] <= 112)
106
+ # select_y = np.logical_and(point_coords[:, 1] >= 48, point_coords[:, 1] <= 272)
107
+ # select_z = np.logical_and(point_coords[:, 2] >= 16, point_coords[:, 2] <= 112)
108
+ # select = np.logical_and(np.logical_and(select_x, select_y), select_z)
109
+ # point_coords = point_coords[select, :]
110
+
111
+ pc = vtkplotter.Points(point_coords, r=point_marker_size, c='red')
112
+ vis_list.append(pc)
113
+
114
+ vp.show(*vis_list,
115
+ bg="white",
116
+ axes=1,
117
+ interactive=interactive,
118
+ azimuth=30,
119
+ elevation=30)
120
+
121
+
122
+ def create_grid3D(min, max, steps):
123
+ if type(min) is int:
124
+ min = (min, min, min) # (x, y, z)
125
+ if type(max) is int:
126
+ max = (max, max, max) # (x, y, z)
127
+ if type(steps) is int:
128
+ steps = (steps, steps, steps) # (x, y, z)
129
+ arrangeX = torch.linspace(min[0], max[0], steps[0]).long()
130
+ arrangeY = torch.linspace(min[1], max[1], steps[1]).long()
131
+ arrangeZ = torch.linspace(min[2], max[2], steps[2]).long()
132
+ gridD, girdH, gridW = torch.meshgrid([arrangeZ, arrangeY, arrangeX])
133
+ coords = torch.stack([gridW, girdH,
134
+ gridD]) # [3, steps[0], steps[1], steps[2]]
135
+ coords = coords.view(3, -1).t() # [N, 3]
136
+ return coords
137
+
138
+
139
+ def create_grid2D(min, max, steps):
140
+ if type(min) is int:
141
+ min = (min, min) # (x, y)
142
+ if type(max) is int:
143
+ max = (max, max) # (x, y)
144
+ if type(steps) is int:
145
+ steps = (steps, steps) # (x, y)
146
+ arrangeX = torch.linspace(min[0], max[0], steps[0]).long()
147
+ arrangeY = torch.linspace(min[1], max[1], steps[1]).long()
148
+ girdH, gridW = torch.meshgrid([arrangeY, arrangeX])
149
+ coords = torch.stack([gridW, girdH]) # [2, steps[0], steps[1]]
150
+ coords = coords.view(2, -1).t() # [N, 2]
151
+ return coords
152
+
153
+
154
+ class SmoothConv2D(nn.Module):
155
+ def __init__(self, in_channels, out_channels, kernel_size=3):
156
+ super().__init__()
157
+ assert kernel_size % 2 == 1, "kernel_size for smooth_conv must be odd: {3, 5, ...}"
158
+ self.padding = (kernel_size - 1) // 2
159
+
160
+ weight = torch.ones(
161
+ (in_channels, out_channels, kernel_size, kernel_size),
162
+ dtype=torch.float32) / (kernel_size**2)
163
+ self.register_buffer('weight', weight)
164
+
165
+ def forward(self, input):
166
+ return F.conv2d(input, self.weight, padding=self.padding)
167
+
168
+
169
+ class SmoothConv3D(nn.Module):
170
+ def __init__(self, in_channels, out_channels, kernel_size=3):
171
+ super().__init__()
172
+ assert kernel_size % 2 == 1, "kernel_size for smooth_conv must be odd: {3, 5, ...}"
173
+ self.padding = (kernel_size - 1) // 2
174
+
175
+ weight = torch.ones(
176
+ (in_channels, out_channels, kernel_size, kernel_size, kernel_size),
177
+ dtype=torch.float32) / (kernel_size**3)
178
+ self.register_buffer('weight', weight)
179
+
180
+ def forward(self, input):
181
+ return F.conv3d(input, self.weight, padding=self.padding)
182
+
183
+
184
+ def build_smooth_conv3D(in_channels=1,
185
+ out_channels=1,
186
+ kernel_size=3,
187
+ padding=1):
188
+ smooth_conv = torch.nn.Conv3d(in_channels=in_channels,
189
+ out_channels=out_channels,
190
+ kernel_size=kernel_size,
191
+ padding=padding)
192
+ smooth_conv.weight.data = torch.ones(
193
+ (in_channels, out_channels, kernel_size, kernel_size, kernel_size),
194
+ dtype=torch.float32) / (kernel_size**3)
195
+ smooth_conv.bias.data = torch.zeros(out_channels)
196
+ return smooth_conv
197
+
198
+
199
+ def build_smooth_conv2D(in_channels=1,
200
+ out_channels=1,
201
+ kernel_size=3,
202
+ padding=1):
203
+ smooth_conv = torch.nn.Conv2d(in_channels=in_channels,
204
+ out_channels=out_channels,
205
+ kernel_size=kernel_size,
206
+ padding=padding)
207
+ smooth_conv.weight.data = torch.ones(
208
+ (in_channels, out_channels, kernel_size, kernel_size),
209
+ dtype=torch.float32) / (kernel_size**2)
210
+ smooth_conv.bias.data = torch.zeros(out_channels)
211
+ return smooth_conv
212
+
213
+
214
+ def get_uncertain_point_coords_on_grid3D(uncertainty_map, num_points,
215
+ **kwargs):
216
+ """
217
+ Find `num_points` most uncertain points from `uncertainty_map` grid.
218
+ Args:
219
+ uncertainty_map (Tensor): A tensor of shape (N, 1, D, H, W) that contains uncertainty
220
+ values for a set of points on a regular D x H x W grid.
221
+ num_points (int): The number of points P to select.
222
+ Returns:
223
+ point_indices (Tensor): A tensor of shape (N, P) that contains indices from
224
+ [0, H x W x D) of the most uncertain points.
225
+ point_coords (Tensor): A tensor of shape (N, P, 3) that contains [0, 1] x [0, 1] normalized
226
+ coordinates of the most uncertain points from the H x W x D grid.
227
+ """
228
+ R, _, D, H, W = uncertainty_map.shape
229
+ # h_step = 1.0 / float(H)
230
+ # w_step = 1.0 / float(W)
231
+ # d_step = 1.0 / float(D)
232
+
233
+ num_points = min(D * H * W, num_points)
234
+ point_scores, point_indices = torch.topk(uncertainty_map.view(
235
+ R, D * H * W),
236
+ k=num_points,
237
+ dim=1)
238
+ point_coords = torch.zeros(R,
239
+ num_points,
240
+ 3,
241
+ dtype=torch.float,
242
+ device=uncertainty_map.device)
243
+ # point_coords[:, :, 0] = h_step / 2.0 + (point_indices // (W * D)).to(torch.float) * h_step
244
+ # point_coords[:, :, 1] = w_step / 2.0 + (point_indices % (W * D) // D).to(torch.float) * w_step
245
+ # point_coords[:, :, 2] = d_step / 2.0 + (point_indices % D).to(torch.float) * d_step
246
+ point_coords[:, :, 0] = (point_indices % W).to(torch.float) # x
247
+ point_coords[:, :, 1] = (point_indices % (H * W) // W).to(torch.float) # y
248
+ point_coords[:, :, 2] = (point_indices // (H * W)).to(torch.float) # z
249
+ print(f"resolution {D} x {H} x {W}", point_scores.min(),
250
+ point_scores.max())
251
+ return point_indices, point_coords
252
+
253
+
254
+ def get_uncertain_point_coords_on_grid3D_faster(uncertainty_map, num_points,
255
+ clip_min):
256
+ """
257
+ Find `num_points` most uncertain points from `uncertainty_map` grid.
258
+ Args:
259
+ uncertainty_map (Tensor): A tensor of shape (N, 1, D, H, W) that contains uncertainty
260
+ values for a set of points on a regular D x H x W grid.
261
+ num_points (int): The number of points P to select.
262
+ Returns:
263
+ point_indices (Tensor): A tensor of shape (N, P) that contains indices from
264
+ [0, H x W x D) of the most uncertain points.
265
+ point_coords (Tensor): A tensor of shape (N, P, 3) that contains [0, 1] x [0, 1] normalized
266
+ coordinates of the most uncertain points from the H x W x D grid.
267
+ """
268
+ R, _, D, H, W = uncertainty_map.shape
269
+ # h_step = 1.0 / float(H)
270
+ # w_step = 1.0 / float(W)
271
+ # d_step = 1.0 / float(D)
272
+
273
+ assert R == 1, "batchsize > 1 is not implemented!"
274
+ uncertainty_map = uncertainty_map.view(D * H * W)
275
+ indices = (uncertainty_map >= clip_min).nonzero().squeeze(1)
276
+ num_points = min(num_points, indices.size(0))
277
+ point_scores, point_indices = torch.topk(uncertainty_map[indices],
278
+ k=num_points,
279
+ dim=0)
280
+ point_indices = indices[point_indices].unsqueeze(0)
281
+
282
+ point_coords = torch.zeros(R,
283
+ num_points,
284
+ 3,
285
+ dtype=torch.float,
286
+ device=uncertainty_map.device)
287
+ # point_coords[:, :, 0] = h_step / 2.0 + (point_indices // (W * D)).to(torch.float) * h_step
288
+ # point_coords[:, :, 1] = w_step / 2.0 + (point_indices % (W * D) // D).to(torch.float) * w_step
289
+ # point_coords[:, :, 2] = d_step / 2.0 + (point_indices % D).to(torch.float) * d_step
290
+ point_coords[:, :, 0] = (point_indices % W).to(torch.float) # x
291
+ point_coords[:, :, 1] = (point_indices % (H * W) // W).to(torch.float) # y
292
+ point_coords[:, :, 2] = (point_indices // (H * W)).to(torch.float) # z
293
+ # print (f"resolution {D} x {H} x {W}", point_scores.min(), point_scores.max())
294
+ return point_indices, point_coords
295
+
296
+
297
+ def get_uncertain_point_coords_on_grid2D(uncertainty_map, num_points,
298
+ **kwargs):
299
+ """
300
+ Find `num_points` most uncertain points from `uncertainty_map` grid.
301
+ Args:
302
+ uncertainty_map (Tensor): A tensor of shape (N, 1, H, W) that contains uncertainty
303
+ values for a set of points on a regular H x W grid.
304
+ num_points (int): The number of points P to select.
305
+ Returns:
306
+ point_indices (Tensor): A tensor of shape (N, P) that contains indices from
307
+ [0, H x W) of the most uncertain points.
308
+ point_coords (Tensor): A tensor of shape (N, P, 2) that contains [0, 1] x [0, 1] normalized
309
+ coordinates of the most uncertain points from the H x W grid.
310
+ """
311
+ R, _, H, W = uncertainty_map.shape
312
+ # h_step = 1.0 / float(H)
313
+ # w_step = 1.0 / float(W)
314
+
315
+ num_points = min(H * W, num_points)
316
+ point_scores, point_indices = torch.topk(uncertainty_map.view(R, H * W),
317
+ k=num_points,
318
+ dim=1)
319
+ point_coords = torch.zeros(R,
320
+ num_points,
321
+ 2,
322
+ dtype=torch.long,
323
+ device=uncertainty_map.device)
324
+ # point_coords[:, :, 0] = w_step / 2.0 + (point_indices % W).to(torch.float) * w_step
325
+ # point_coords[:, :, 1] = h_step / 2.0 + (point_indices // W).to(torch.float) * h_step
326
+ point_coords[:, :, 0] = (point_indices % W).to(torch.long)
327
+ point_coords[:, :, 1] = (point_indices // W).to(torch.long)
328
+ # print (point_scores.min(), point_scores.max())
329
+ return point_indices, point_coords
330
+
331
+
332
+ def get_uncertain_point_coords_on_grid2D_faster(uncertainty_map, num_points,
333
+ clip_min):
334
+ """
335
+ Find `num_points` most uncertain points from `uncertainty_map` grid.
336
+ Args:
337
+ uncertainty_map (Tensor): A tensor of shape (N, 1, H, W) that contains uncertainty
338
+ values for a set of points on a regular H x W grid.
339
+ num_points (int): The number of points P to select.
340
+ Returns:
341
+ point_indices (Tensor): A tensor of shape (N, P) that contains indices from
342
+ [0, H x W) of the most uncertain points.
343
+ point_coords (Tensor): A tensor of shape (N, P, 2) that contains [0, 1] x [0, 1] normalized
344
+ coordinates of the most uncertain points from the H x W grid.
345
+ """
346
+ R, _, H, W = uncertainty_map.shape
347
+ # h_step = 1.0 / float(H)
348
+ # w_step = 1.0 / float(W)
349
+
350
+ assert R == 1, "batchsize > 1 is not implemented!"
351
+ uncertainty_map = uncertainty_map.view(H * W)
352
+ indices = (uncertainty_map >= clip_min).nonzero().squeeze(1)
353
+ num_points = min(num_points, indices.size(0))
354
+ point_scores, point_indices = torch.topk(uncertainty_map[indices],
355
+ k=num_points,
356
+ dim=0)
357
+ point_indices = indices[point_indices].unsqueeze(0)
358
+
359
+ point_coords = torch.zeros(R,
360
+ num_points,
361
+ 2,
362
+ dtype=torch.long,
363
+ device=uncertainty_map.device)
364
+ # point_coords[:, :, 0] = w_step / 2.0 + (point_indices % W).to(torch.float) * w_step
365
+ # point_coords[:, :, 1] = h_step / 2.0 + (point_indices // W).to(torch.float) * h_step
366
+ point_coords[:, :, 0] = (point_indices % W).to(torch.long)
367
+ point_coords[:, :, 1] = (point_indices // W).to(torch.long)
368
+ # print (point_scores.min(), point_scores.max())
369
+ return point_indices, point_coords
370
+
371
+
372
+ def calculate_uncertainty(logits, classes=None, balance_value=0.5):
373
+ """
374
+ We estimate uncertainty as the L1 distance between balance_value and the prediction in 'logits' for the
375
+ foreground class in `classes`.
376
+ Args:
377
+ logits (Tensor): A tensor of shape (R, C, ...) or (R, 1, ...) for class-specific or
378
+ class-agnostic, where R is the total number of predicted masks in all images and C is
379
+ the number of foreground classes. The values are logits.
380
+ classes (list): A list of length R that contains either the predicted or the ground truth class
381
+ for each predicted mask.
382
+ Returns:
383
+ scores (Tensor): A tensor of shape (R, 1, ...) that contains uncertainty scores with
384
+ the most uncertain locations having the highest uncertainty score.
385
+ """
386
+ if logits.shape[1] == 1:
387
+ gt_class_logits = logits
388
+ else:
389
+ gt_class_logits = logits[
390
+ torch.arange(logits.shape[0], device=logits.device),
391
+ classes].unsqueeze(1)
392
+ return -torch.abs(gt_class_logits - balance_value)
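A hedged sketch of exercising `calculate_uncertainty` on class-agnostic logits (shapes invented for illustration): the score peaks wherever the prediction is closest to `balance_value`.

    import torch
    logits = torch.rand(2, 1, 8, 8)         # (R, 1, ...) class-agnostic predictions
    scores = calculate_uncertainty(logits)  # equals -|logits - 0.5| by default
    assert scores.shape == logits.shape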
lib/common/smpl_vert_segmentation.json ADDED
The diff for this file is too large to render. See raw diff
 
lib/common/train_util.py ADDED
@@ -0,0 +1,599 @@
1
+
2
+ # -*- coding: utf-8 -*-
3
+
4
+ # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
5
+ # holder of all proprietary rights on this computer program.
6
+ # You can only use this computer program if you have closed
7
+ # a license agreement with MPG or you get the right to use the computer
8
+ # program from someone who is authorized to grant you that right.
9
+ # Any use of the computer program without a valid license is prohibited and
10
+ # liable to prosecution.
11
+ #
12
+ # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
13
+ # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
14
+ # for Intelligent Systems. All rights reserved.
15
+ #
16
+ # Contact: ps-license@tuebingen.mpg.de
17
+
18
+ import yaml
19
+ import os.path as osp
20
+ import torch
21
+ import numpy as np
22
+ import torch.nn.functional as F
23
+ from ..dataset.mesh_util import *
24
+ from ..net.geometry import orthogonal
25
+ from pytorch3d.renderer.mesh import rasterize_meshes
26
+ from .render_utils import Pytorch3dRasterizer
27
+ from pytorch3d.structures import Meshes
28
+ import cv2
29
+ from PIL import Image
30
+ from tqdm import tqdm
31
+ import os
32
+ from termcolor import colored
33
+
34
+
35
+
36
+
37
+ def reshape_sample_tensor(sample_tensor, num_views):
38
+ if num_views == 1:
39
+ return sample_tensor
40
+ # Need to repeat sample_tensor along the batch dim num_views times
41
+ sample_tensor = sample_tensor.unsqueeze(dim=1)
42
+ sample_tensor = sample_tensor.repeat(1, num_views, 1, 1)
43
+ sample_tensor = sample_tensor.view(
44
+ sample_tensor.shape[0] * sample_tensor.shape[1],
45
+ sample_tensor.shape[2], sample_tensor.shape[3])
46
+ return sample_tensor
47
+
48
+
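A small sketch (toy shapes) of what `reshape_sample_tensor` does for a multi-view setup: each batch element's query points are repeated once per view along the batch dimension.

    import torch
    pts = torch.rand(2, 3, 5000)                    # (batch, 3, num_points)
    pts3 = reshape_sample_tensor(pts, num_views=3)  # -> (batch * 3, 3, num_points)
    assert pts3.shape == (6, 3, 5000)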
49
+ def gen_mesh_eval(opt, net, cuda, data, resolution=None):
50
+ resolution = opt.resolution if resolution is None else resolution
51
+ image_tensor = data['img'].to(device=cuda)
52
+ calib_tensor = data['calib'].to(device=cuda)
53
+
54
+ net.filter(image_tensor)
55
+
56
+ b_min = data['b_min']
57
+ b_max = data['b_max']
58
+ try:
59
+ verts, faces, _, _ = reconstruction_faster(net,
60
+ cuda,
61
+ calib_tensor,
62
+ resolution,
63
+ b_min,
64
+ b_max,
65
+ use_octree=False)
66
+
67
+ except Exception as e:
68
+ print(e)
69
+ print('Cannot run marching cubes at this time.')
70
+ verts, faces = None, None
71
+ return verts, faces
72
+
73
+
74
+ def gen_mesh(opt, net, cuda, data, save_path, resolution=None):
75
+ resolution = opt.resolution if resolution is None else resolution
76
+ image_tensor = data['img'].to(device=cuda)
77
+ calib_tensor = data['calib'].to(device=cuda)
78
+
79
+ net.filter(image_tensor)
80
+
81
+ b_min = data['b_min']
82
+ b_max = data['b_max']
83
+ try:
84
+ save_img_path = save_path[:-4] + '.png'
85
+ save_img_list = []
86
+ for v in range(image_tensor.shape[0]):
87
+ save_img = (np.transpose(image_tensor[v].detach().cpu().numpy(),
88
+ (1, 2, 0)) * 0.5 +
89
+ 0.5)[:, :, ::-1] * 255.0
90
+ save_img_list.append(save_img)
91
+ save_img = np.concatenate(save_img_list, axis=1)
92
+ Image.fromarray(np.uint8(save_img[:, :, ::-1])).save(save_img_path)
93
+
94
+ verts, faces, _, _ = reconstruction_faster(net, cuda, calib_tensor,
95
+ resolution, b_min, b_max)
96
+ verts_tensor = torch.from_numpy(
97
+ verts.T).unsqueeze(0).to(device=cuda).float()
98
+ xyz_tensor = net.projection(verts_tensor, calib_tensor[:1])
99
+ uv = xyz_tensor[:, :2, :]
100
+ color = net.index(image_tensor[:1], uv).detach().cpu().numpy()[0].T
101
+ color = color * 0.5 + 0.5
102
+ save_obj_mesh_with_color(save_path, verts, faces, color)
103
+ except Exception as e:
104
+ print(e)
105
+ print('Cannot run marching cubes at this time.')
106
+ verts, faces, color = None, None, None
107
+ return verts, faces, color
108
+
109
+
110
+ def gen_mesh_color(opt, netG, netC, cuda, data, save_path, use_octree=True):
111
+ image_tensor = data['img'].to(device=cuda)
112
+ calib_tensor = data['calib'].to(device=cuda)
113
+
114
+ netG.filter(image_tensor)
115
+ netC.filter(image_tensor)
116
+ netC.attach(netG.get_im_feat())
117
+
118
+ b_min = data['b_min']
119
+ b_max = data['b_max']
120
+ try:
121
+ save_img_path = save_path[:-4] + '.png'
122
+ save_img_list = []
123
+ for v in range(image_tensor.shape[0]):
124
+ save_img = (np.transpose(image_tensor[v].detach().cpu().numpy(),
125
+ (1, 2, 0)) * 0.5 +
126
+ 0.5)[:, :, ::-1] * 255.0
127
+ save_img_list.append(save_img)
128
+ save_img = np.concatenate(save_img_list, axis=1)
129
+ Image.fromarray(np.uint8(save_img[:, :, ::-1])).save(save_img_path)
130
+
131
+ verts, faces, _, _ = reconstruction_faster(netG,
132
+ cuda,
133
+ calib_tensor,
134
+ opt.resolution,
135
+ b_min,
136
+ b_max,
137
+ use_octree=use_octree)
138
+
139
+ # Now Getting colors
140
+ verts_tensor = torch.from_numpy(
141
+ verts.T).unsqueeze(0).to(device=cuda).float()
142
+ verts_tensor = reshape_sample_tensor(verts_tensor, opt.num_views)
143
+ color = np.zeros(verts.shape)
144
+ interval = 10000
145
+ for i in range(len(color) // interval):
146
+ left = i * interval
147
+ right = i * interval + interval
148
+ if i == len(color) // interval - 1:
149
+ right = -1
150
+ netC.query(verts_tensor[:, :, left:right], calib_tensor)
151
+ rgb = netC.get_preds()[0].detach().cpu().numpy() * 0.5 + 0.5
152
+ color[left:right] = rgb.T
153
+
154
+ save_obj_mesh_with_color(save_path, verts, faces, color)
155
+ except Exception as e:
156
+ print(e)
157
+ print('Cannot run marching cubes at this time.')
158
+ verts, faces, color = None, None, None
159
+ return verts, faces, color
160
+
161
+
162
+ def adjust_learning_rate(optimizer, epoch, lr, schedule, gamma):
163
+ """Sets the learning rate to the initial LR decayed by schedule"""
164
+ if epoch in schedule:
165
+ lr *= gamma
166
+ for param_group in optimizer.param_groups:
167
+ param_group['lr'] = lr
168
+ return lr
169
+
170
+
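For illustration only, a step-decay loop using `adjust_learning_rate` with made-up epochs and gamma:

    import torch
    optimizer = torch.optim.Adam([torch.zeros(1, requires_grad=True)], lr=1e-3)
    lr = 1e-3
    for epoch in range(100):
        lr = adjust_learning_rate(optimizer, epoch, lr, schedule=[40, 60], gamma=0.1)
        # lr drops to 1e-4 at epoch 40 and to 1e-5 at epoch 60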
171
+ def compute_acc(pred, gt, thresh=0.5):
172
+ '''
173
+ return:
174
+ IOU, precision, and recall
175
+ '''
176
+ with torch.no_grad():
177
+ vol_pred = pred > thresh
178
+ vol_gt = gt > thresh
179
+
180
+ union = vol_pred | vol_gt
181
+ inter = vol_pred & vol_gt
182
+
183
+ true_pos = inter.sum().float()
184
+
185
+ union = union.sum().float()
186
+ if union == 0:
187
+ union = 1
188
+ vol_pred = vol_pred.sum().float()
189
+ if vol_pred == 0:
190
+ vol_pred = 1
191
+ vol_gt = vol_gt.sum().float()
192
+ if vol_gt == 0:
193
+ vol_gt = 1
194
+ return true_pos / union, true_pos / vol_pred, true_pos / vol_gt
195
+
196
+
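A quick, illustrative call of `compute_acc` on random occupancy values (shapes are arbitrary):

    import torch
    pred = torch.rand(1, 1, 8000)                # predicted occupancy in [0, 1]
    gt = (torch.rand(1, 1, 8000) > 0.5).float()  # binary ground-truth occupancy
    iou, precision, recall = compute_acc(pred, gt, thresh=0.5)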
197
+ # def calc_metrics(opt, net, cuda, dataset, num_tests,
198
+ # resolution=128, sampled_points=1000, use_kaolin=True):
199
+ # if num_tests > len(dataset):
200
+ # num_tests = len(dataset)
201
+ # with torch.no_grad():
202
+ # chamfer_arr, p2s_arr = [], []
203
+ # for idx in tqdm(range(num_tests)):
204
+ # data = dataset[idx * len(dataset) // num_tests]
205
+
206
+ # verts, faces = gen_mesh_eval(opt, net, cuda, data, resolution)
207
+ # if verts is None:
208
+ # continue
209
+
210
+ # mesh_gt = trimesh.load(data['mesh_path'])
211
+ # mesh_gt = mesh_gt.split(only_watertight=False)
212
+ # comp_num = [mesh.vertices.shape[0] for mesh in mesh_gt]
213
+ # mesh_gt = mesh_gt[comp_num.index(max(comp_num))]
214
+
215
+ # mesh_pred = trimesh.Trimesh(verts, faces)
216
+
217
+ # gt_surface_pts, _ = trimesh.sample.sample_surface_even(
218
+ # mesh_gt, sampled_points)
219
+ # pred_surface_pts, _ = trimesh.sample.sample_surface_even(
220
+ # mesh_pred, sampled_points)
221
+
222
+ # if use_kaolin and has_kaolin:
223
+ # kal_mesh_gt = kal.rep.TriangleMesh.from_tensors(
224
+ # torch.tensor(mesh_gt.vertices).float().to(device=cuda),
225
+ # torch.tensor(mesh_gt.faces).long().to(device=cuda))
226
+ # kal_mesh_pred = kal.rep.TriangleMesh.from_tensors(
227
+ # torch.tensor(mesh_pred.vertices).float().to(device=cuda),
228
+ # torch.tensor(mesh_pred.faces).long().to(device=cuda))
229
+
230
+ # kal_distance_0 = kal.metrics.mesh.point_to_surface(
231
+ # torch.tensor(pred_surface_pts).float().to(device=cuda), kal_mesh_gt)
232
+ # kal_distance_1 = kal.metrics.mesh.point_to_surface(
233
+ # torch.tensor(gt_surface_pts).float().to(device=cuda), kal_mesh_pred)
234
+
235
+ # dist_gt_pred = torch.sqrt(kal_distance_0).cpu().numpy()
236
+ # dist_pred_gt = torch.sqrt(kal_distance_1).cpu().numpy()
237
+ # else:
238
+ # try:
239
+ # _, dist_pred_gt, _ = trimesh.proximity.closest_point(mesh_pred, gt_surface_pts)
240
+ # _, dist_gt_pred, _ = trimesh.proximity.closest_point(mesh_gt, pred_surface_pts)
241
+ # except Exception as e:
242
+ # print (e)
243
+ # continue
244
+
245
+ # chamfer_dist = 0.5 * (dist_pred_gt.mean() + dist_gt_pred.mean())
246
+ # p2s_dist = dist_pred_gt.mean()
247
+
248
+ # chamfer_arr.append(chamfer_dist)
249
+ # p2s_arr.append(p2s_dist)
250
+
251
+ # return np.average(chamfer_arr), np.average(p2s_arr)
252
+
253
+
254
+ def calc_error(opt, net, cuda, dataset, num_tests):
255
+ if num_tests > len(dataset):
256
+ num_tests = len(dataset)
257
+ with torch.no_grad():
258
+ error_arr, IOU_arr, prec_arr, recall_arr = [], [], [], []
259
+ for idx in tqdm(range(num_tests)):
260
+ data = dataset[idx * len(dataset) // num_tests]
261
+ # retrieve the data
262
+ image_tensor = data['img'].to(device=cuda)
263
+ calib_tensor = data['calib'].to(device=cuda)
264
+ sample_tensor = data['samples'].to(device=cuda).unsqueeze(0)
265
+ if opt.num_views > 1:
266
+ sample_tensor = reshape_sample_tensor(sample_tensor,
267
+ opt.num_views)
268
+ label_tensor = data['labels'].to(device=cuda).unsqueeze(0)
269
+
270
+ res, error = net.forward(image_tensor,
271
+ sample_tensor,
272
+ calib_tensor,
273
+ labels=label_tensor)
274
+
275
+ IOU, prec, recall = compute_acc(res, label_tensor)
276
+
277
+ # print(
278
+ # '{0}/{1} | Error: {2:06f} IOU: {3:06f} prec: {4:06f} recall: {5:06f}'
279
+ # .format(idx, num_tests, error.item(), IOU.item(), prec.item(), recall.item()))
280
+ error_arr.append(error.item())
281
+ IOU_arr.append(IOU.item())
282
+ prec_arr.append(prec.item())
283
+ recall_arr.append(recall.item())
284
+
285
+ return np.average(error_arr), np.average(IOU_arr), np.average(
286
+ prec_arr), np.average(recall_arr)
287
+
288
+
289
+ def calc_error_color(opt, netG, netC, cuda, dataset, num_tests):
290
+ if num_tests > len(dataset):
291
+ num_tests = len(dataset)
292
+ with torch.no_grad():
293
+ error_color_arr = []
294
+
295
+ for idx in tqdm(range(num_tests)):
296
+ data = dataset[idx * len(dataset) // num_tests]
297
+ # retrieve the data
298
+ image_tensor = data['img'].to(device=cuda)
299
+ calib_tensor = data['calib'].to(device=cuda)
300
+ color_sample_tensor = data['color_samples'].to(
301
+ device=cuda).unsqueeze(0)
302
+
303
+ if opt.num_views > 1:
304
+ color_sample_tensor = reshape_sample_tensor(
305
+ color_sample_tensor, opt.num_views)
306
+
307
+ rgb_tensor = data['rgbs'].to(device=cuda).unsqueeze(0)
308
+
309
+ netG.filter(image_tensor)
310
+ _, errorC = netC.forward(image_tensor,
311
+ netG.get_im_feat(),
312
+ color_sample_tensor,
313
+ calib_tensor,
314
+ labels=rgb_tensor)
315
+
316
+ # print('{0}/{1} | Error inout: {2:06f} | Error color: {3:06f}'
317
+ # .format(idx, num_tests, errorG.item(), errorC.item()))
318
+ error_color_arr.append(errorC.item())
319
+
320
+ return np.average(error_color_arr)
321
+
322
+
323
+ # PyTorch Lightning training-related functions
324
+
325
+
326
+ def query_func(opt, netG, features, points, proj_matrix=None):
327
+ '''
328
+ - points: size of (bz, N, 3)
329
+ - proj_matrix: size of (bz, 4, 4)
330
+ return: size of (bz, 1, N)
331
+ '''
332
+ assert len(points) == 1
333
+ samples = points.repeat(opt.num_views, 1, 1)
334
+ samples = samples.permute(0, 2, 1) # [bz, 3, N]
335
+
336
+ # view specific query
337
+ if proj_matrix is not None:
338
+ samples = orthogonal(samples, proj_matrix)
339
+
340
+ calib_tensor = torch.stack([torch.eye(4).float()], dim=0).type_as(samples)
341
+
342
+ preds = netG.query(features=features,
343
+ points=samples,
344
+ calibs=calib_tensor,
345
+ regressor=netG.if_regressor)
346
+
347
+ if type(preds) is list:
348
+ preds = preds[0]
349
+
350
+ return preds
351
+
352
+
353
+ def isin(ar1, ar2):
354
+ return (ar1[..., None] == ar2).any(-1)
355
+
356
+
357
+ def in1d(ar1, ar2):
358
+ mask = ar2.new_zeros((max(ar1.max(), ar2.max()) + 1, ), dtype=torch.bool)
359
+ mask[ar2.unique()] = True
360
+ return mask[ar1]
361
+
362
+
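The two helpers above are torch stand-ins for NumPy's `isin` / `in1d`; a small sanity check with toy tensors:

    import torch
    faces = torch.tensor([0, 2, 5, 7])
    visible = torch.tensor([2, 7, 9])
    print(isin(faces, visible))  # tensor([False,  True, False,  True])
    print(in1d(faces, visible))  # same membership test via a boolean lookup table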
363
+ def get_visibility(xy, z, faces):
364
+ """get the visibility of vertices
365
+
366
+ Args:
367
+ xy (torch.tensor): [N,2]
368
+ z (torch.tensor): [N,1]
369
+ faces (torch.tensor): [N,3]
370
+ size (int): resolution of rendered image
371
+ """
372
+
373
+ xyz = torch.cat((xy, -z), dim=1)
374
+ xyz = (xyz + 1.0) / 2.0
375
+ faces = faces.long()
376
+
377
+ rasterizer = Pytorch3dRasterizer(image_size=2**12)
378
+ meshes_screen = Meshes(verts=xyz[None, ...], faces=faces[None, ...])
379
+ raster_settings = rasterizer.raster_settings
380
+
381
+ pix_to_face, zbuf, bary_coords, dists = rasterize_meshes(
382
+ meshes_screen,
383
+ image_size=raster_settings.image_size,
384
+ blur_radius=raster_settings.blur_radius,
385
+ faces_per_pixel=raster_settings.faces_per_pixel,
386
+ bin_size=raster_settings.bin_size,
387
+ max_faces_per_bin=raster_settings.max_faces_per_bin,
388
+ perspective_correct=raster_settings.perspective_correct,
389
+ cull_backfaces=raster_settings.cull_backfaces,
390
+ )
391
+
392
+ vis_vertices_id = torch.unique(faces[torch.unique(pix_to_face), :])
393
+ vis_mask = torch.zeros(size=(z.shape[0], 1))
394
+ vis_mask[vis_vertices_id] = 1.0
395
+
396
+ # print("------------------------\n")
397
+ # print(f"keep points : {vis_mask.sum()/len(vis_mask)}")
398
+
399
+ return vis_mask
400
+
401
+
402
+ def batch_mean(res, key):
403
+ # recursive mean for multilevel dicts
404
+ return torch.stack([
405
+ x[key] if isinstance(x, dict) else batch_mean(x, key) for x in res
406
+ ]).mean()
407
+
408
+
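A hedged illustration of `batch_mean` on a nested list of step outputs (values invented): it recurses into sub-lists and stacks the per-entry means.

    import torch
    res = [{"loss": torch.tensor(0.4)},
           [{"loss": torch.tensor(0.2)}, {"loss": torch.tensor(0.6)}]]
    print(batch_mean(res, "loss"))  # tensor(0.4000)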
409
+ def tf_log_convert(log_dict):
410
+ new_log_dict = log_dict.copy()
411
+ for k, v in log_dict.items():
412
+ new_log_dict[k.replace("_", "/")] = v
413
+ del new_log_dict[k]
414
+
415
+ return new_log_dict
416
+
417
+
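For reference, a toy call of `tf_log_convert` (keys invented); underscores become slashes so TensorBoard groups the scalars:

    print(tf_log_convert({"train_loss": 0.3, "val_iou": 0.8}))
    # {'train/loss': 0.3, 'val/iou': 0.8}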
418
+ def bar_log_convert(log_dict, name=None, rot=None):
419
+ from decimal import Decimal
420
+
421
+ new_log_dict = {}
422
+
423
+ if name is not None:
424
+ new_log_dict['name'] = name[0]
425
+ if rot is not None:
426
+ new_log_dict['rot'] = rot[0]
427
+
428
+ for k, v in log_dict.items():
429
+ color = "yellow"
430
+ if 'loss' in k:
431
+ color = "red"
432
+ k = k.replace("loss", "L")
433
+ elif 'acc' in k:
434
+ color = "green"
435
+ k = k.replace("acc", "A")
436
+ elif 'iou' in k:
437
+ color = "green"
438
+ k = k.replace("iou", "I")
439
+ elif 'prec' in k:
440
+ color = "green"
441
+ k = k.replace("prec", "P")
442
+ elif 'recall' in k:
443
+ color = "green"
444
+ k = k.replace("recall", "R")
445
+
446
+ if 'lr' not in k:
447
+ new_log_dict[colored(k.split("_")[1],
448
+ color)] = colored(f"{v:.3f}", color)
449
+ else:
450
+ new_log_dict[colored(k.split("_")[1],
451
+ color)] = colored(f"{Decimal(str(v)):.1E}",
452
+ color)
453
+
454
+ if 'loss' in new_log_dict.keys():
455
+ del new_log_dict['loss']
456
+
457
+ return new_log_dict
458
+
459
+
460
+ def accumulate(outputs, rot_num, split):
461
+
462
+ hparam_log_dict = {}
463
+
464
+ metrics = outputs[0].keys()
465
+ datasets = split.keys()
466
+
467
+ for dataset in datasets:
468
+ for metric in metrics:
469
+ keyword = f"hparam/{dataset}-{metric}"
470
+ if keyword not in hparam_log_dict.keys():
471
+ hparam_log_dict[keyword] = 0
472
+ for idx in range(split[dataset][0] * rot_num,
473
+ split[dataset][1] * rot_num):
474
+ hparam_log_dict[keyword] += outputs[idx][metric]
475
+ hparam_log_dict[keyword] /= (split[dataset][1] -
476
+ split[dataset][0]) * rot_num
477
+
478
+ print(colored(hparam_log_dict, "green"))
479
+
480
+ return hparam_log_dict
481
+
482
+
483
+ def calc_error_N(outputs, targets):
484
+ """calculate the error of normal (IGR)
485
+
486
+ Args:
487
+ outputs (torch.tensor): [B, 3, N]
488
+ target (torch.tensor): [B, N, 3]
489
+
490
+ # manifold loss and grad_loss in IGR paper
491
+ grad_loss = ((nonmnfld_grad.norm(2, dim=-1) - 1) ** 2).mean()
492
+ normals_loss = ((mnfld_grad - normals).abs()).norm(2, dim=1).mean()
493
+
494
+ Returns:
495
+ torch.tensor: error of valid normals on the surface
496
+ """
497
+ # outputs = torch.tanh(-outputs.permute(0,2,1).reshape(-1,3))
498
+ outputs = -outputs.permute(0, 2, 1).reshape(-1, 1)
499
+ targets = targets.reshape(-1, 3)[:, 2:3]
500
+ with_normals = targets.sum(dim=1).abs() > 0.0
501
+
502
+ # eikonal loss
503
+ grad_loss = ((outputs[with_normals].norm(2, dim=-1) - 1)**2).mean()
504
+ # normals loss
505
+ normal_loss = (outputs - targets)[with_normals].abs().norm(2, dim=1).mean()
506
+
507
+ return grad_loss * 0.0 + normal_loss
508
+
509
+
510
+ def calc_knn_acc(preds, carn_verts, labels, pick_num):
511
+ """calculate knn accuracy
512
+
513
+ Args:
514
+ preds (torch.tensor): [B, 3, N]
515
+ carn_verts (torch.tensor): [SMPLX_V_num, 3]
516
+ labels (torch.tensor): [B, N_knn, N]
517
+ """
518
+ N_knn_full = labels.shape[1]
519
+ preds = preds.permute(0, 2, 1).reshape(-1, 3)
520
+ labels = labels.permute(0, 2, 1).reshape(-1, N_knn_full) # [BxN, num_knn]
521
+ labels = labels[:, :pick_num]
522
+
523
+ dist = torch.cdist(preds, carn_verts, p=2) # [BxN, SMPL_V_num]
524
+ knn = dist.topk(k=pick_num, dim=1, largest=False)[1] # [BxN, num_knn]
525
+ cat_mat = torch.sort(torch.cat((knn, labels), dim=1))[0]
526
+ bool_col = torch.zeros_like(cat_mat)[:, 0]
527
+ for i in range(pick_num * 2 - 1):
528
+ bool_col += cat_mat[:, i] == cat_mat[:, i + 1]
529
+ acc = (bool_col > 0).sum() / len(bool_col)
530
+
531
+ return acc
532
+
533
+
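A hedged sketch of calling `calc_knn_acc` with toy sizes (8 ground-truth neighbours per point, checking the top 4 predicted ones):

    import torch
    preds = torch.rand(1, 3, 100)                # predicted 3D points, (B, 3, N)
    carn_verts = torch.rand(500, 3)              # canonical vertices, (V, 3)
    labels = torch.randint(0, 500, (1, 8, 100))  # ground-truth knn indices, (B, N_knn, N)
    acc = calc_knn_acc(preds, carn_verts, labels, pick_num=4)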
534
+ def calc_acc_seg(output, target, num_multiseg):
535
+ from pytorch_lightning.metrics import Accuracy
536
+ return Accuracy()(output.reshape(-1, num_multiseg).cpu(),
537
+ target.flatten().cpu())
538
+
539
+
540
+ def add_watermark(imgs, titles):
541
+
542
+ # Write some Text
543
+
544
+ font = cv2.FONT_HERSHEY_SIMPLEX
545
+ bottomLeftCornerOfText = (350, 50)
546
+ bottomRightCornerOfText = (800, 50)
547
+ fontScale = 1
548
+ fontColor = (1.0, 1.0, 1.0)
549
+ lineType = 2
550
+
551
+ for i in range(len(imgs)):
552
+
553
+ title = titles[i + 1]
554
+ cv2.putText(imgs[i], title, bottomLeftCornerOfText, font, fontScale,
555
+ fontColor, lineType)
556
+
557
+ if i == 0:
558
+ cv2.putText(imgs[i], str(titles[i][0]), bottomRightCornerOfText,
559
+ font, fontScale, fontColor, lineType)
560
+
561
+ result = np.concatenate(imgs, axis=0).transpose(2, 0, 1)
562
+
563
+ return result
564
+
565
+
566
+ def make_test_gif(img_dir):
567
+
568
+ if img_dir is not None and len(os.listdir(img_dir)) > 0:
569
+ for dataset in os.listdir(img_dir):
570
+ for subject in sorted(os.listdir(osp.join(img_dir, dataset))):
571
+ img_lst = []
572
+ im1 = None
573
+ for file in sorted(
574
+ os.listdir(osp.join(img_dir, dataset, subject))):
575
+ if file[-3:] not in ['obj', 'gif']:
576
+ img_path = os.path.join(img_dir, dataset, subject,
577
+ file)
578
+ if im1 is None:
579
+ im1 = Image.open(img_path)
580
+ else:
581
+ img_lst.append(Image.open(img_path))
582
+
583
+ print(os.path.join(img_dir, dataset, subject, "out.gif"))
584
+ im1.save(os.path.join(img_dir, dataset, subject, "out.gif"),
585
+ save_all=True,
586
+ append_images=img_lst,
587
+ duration=500,
588
+ loop=0)
589
+
590
+
591
+ def export_cfg(logger, cfg):
592
+
593
+ cfg_export_file = osp.join(logger.save_dir, logger.name,
594
+ f"version_{logger.version}", "cfg.yaml")
595
+
596
+ if not osp.exists(cfg_export_file):
597
+ os.makedirs(osp.dirname(cfg_export_file), exist_ok=True)
598
+ with open(cfg_export_file, "w+") as file:
599
+ _ = yaml.dump(cfg, file)