Ruicheng committed on
Commit c9d074f
1 Parent(s): 40ec4f9

update edge removal

app.py CHANGED
@@ -60,7 +60,8 @@ def run(image: np.ndarray, remove_edge: bool = True, max_size: int = 800):
     points, depth, mask = output['points'], output['depth'], output['mask']
 
     if remove_edge:
-        mask = mask & ~utils3d.numpy.depth_edge(depth, mask=mask, rtol=0.02)
+        normals, normals_mask = utils3d.numpy.points_to_normals(points, mask=mask)
+        mask = mask & ~(utils3d.numpy.depth_edge(depth, rtol=0.03, mask=mask) & utils3d.numpy.normals_edge(normals, tol=5, mask=normals_mask))
 
     faces, vertices, vertex_colors, vertex_uvs = utils3d.numpy.image_mesh(
         points,
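
Note: a minimal sketch of what the new edge removal does, using only the calls visible in the diff above (the 0.03 / 5-degree thresholds are the values chosen by this commit; `points`, `depth`, `mask` are assumed to come from a MoGe inference call):

    import utils3d

    # Normals estimated from the point map; normals_mask marks pixels where a
    # normal could be computed from valid neighbors.
    normals, normals_mask = utils3d.numpy.points_to_normals(points, mask=mask)

    # A pixel is now dropped only if it is BOTH a depth discontinuity (relative
    # depth difference > 3%) AND a normal discontinuity (> 5 degrees). The old
    # depth-only test (rtol=0.02) also removed slanted but continuous surfaces.
    edges = (utils3d.numpy.depth_edge(depth, rtol=0.03, mask=mask)
             & utils3d.numpy.normals_edge(normals, tol=5, mask=normals_mask))
    mask = mask & ~edges
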
moge/model/moge_model.py CHANGED
@@ -15,7 +15,7 @@ import torch.version
 import utils3d
 from huggingface_hub import hf_hub_download
 
-from ..utils.geometry_torch import image_plane_uv, point_map_to_depth, gaussian_blur_2d
+from ..utils.geometry_torch import normalized_view_plane_uv, recover_focal_shift, gaussian_blur_2d
 from .utils import wrap_dinov2_attention_with_sdpa, wrap_module_with_gradient_checkpointing, unwrap_module_with_gradient_checkpointing
 from ..utils.tools import timeit
 
@@ -121,7 +121,7 @@ class Head(nn.Module):
         # (patch_h, patch_w) -> (patch_h * 2, patch_w * 2) -> (patch_h * 4, patch_w * 4) -> (patch_h * 8, patch_w * 8)
         for i, block in enumerate(self.upsample_blocks):
             # UV coordinates is for awareness of image aspect ratio
-            uv = image_plane_uv(width=x.shape[-1], height=x.shape[-2], aspect_ratio=img_w / img_h, dtype=x.dtype, device=x.device)
+            uv = normalized_view_plane_uv(width=x.shape[-1], height=x.shape[-2], aspect_ratio=img_w / img_h, dtype=x.dtype, device=x.device)
             uv = uv.permute(2, 0, 1).unsqueeze(0).expand(x.shape[0], -1, -1, -1)
             x = torch.cat([x, uv], dim=1)
             for layer in block:
@@ -129,7 +129,7 @@ class Head(nn.Module):
 
         # (patch_h * 8, patch_w * 8) -> (img_h, img_w)
         x = F.interpolate(x, (img_h, img_w), mode="bilinear", align_corners=False)
-        uv = image_plane_uv(width=x.shape[-1], height=x.shape[-2], aspect_ratio=img_w / img_h, dtype=x.dtype, device=x.device)
+        uv = normalized_view_plane_uv(width=x.shape[-1], height=x.shape[-2], aspect_ratio=img_w / img_h, dtype=x.dtype, device=x.device)
         uv = uv.permute(2, 0, 1).unsqueeze(0).expand(x.shape[0], -1, -1, -1)
         x = torch.cat([x, uv], dim=1)
 
@@ -301,6 +301,7 @@ class MoGeModel(nn.Module):
         force_projection: bool = True,
         resolution_level: int = 9,
         apply_mask: bool = True,
+        fov_x: Union[Number, torch.Tensor] = None
     ) -> Dict[str, torch.Tensor]:
         """
         User-friendly inference function
@@ -308,7 +309,9 @@ class MoGeModel(nn.Module):
         ### Parameters
         - `image`: input image tensor of shape (B, 3, H, W) or (3, H, W)
         - `resolution_level`: the resolution level to use for the output point map in 0-9. Default: 9 (highest)
-        - `interpolation_mode`: interpolation mode for the output points map. Default: 'bilinear'.
+        - `force_projection`: if True, the output point map will be computed using the actual depth map. Default: True
+        - `apply_mask`: if True, the output point map will be masked using the predicted mask. Default: True
+        - `fov_x`: the horizontal camera FoV in degrees. If None, it will be inferred from the predicted point map. Default: None
 
         ### Returns
 
@@ -325,6 +328,7 @@ class MoGeModel(nn.Module):
 
         original_height, original_width = image.shape[-2:]
         area = original_height * original_width
+        aspect_ratio = original_width / original_height
 
         min_area, max_area = self.trained_area_range
         expected_area = min_area + (max_area - min_area) * (resolution_level / 9)
@@ -336,15 +340,24 @@ class MoGeModel(nn.Module):
         output = self.forward(image)
         points, mask = output['points'], output.get('mask', None)
 
-        # Get camera-origin-centered point map
-        depth, fov_x, fov_y, z_shift = point_map_to_depth(points, None if mask is None else mask > 0.5)
-        intrinsics = utils3d.torch.intrinsics_from_fov_xy(fov_x, fov_y)
+        # Get camera-space point map. (Focal here is the focal length relative to half the image diagonal)
+        if fov_x is None:
+            focal, shift = recover_focal_shift(points, None if mask is None else mask > 0.5)
+        else:
+            focal = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5 / torch.tan(torch.deg2rad(torch.as_tensor(fov_x, device=points.device, dtype=points.dtype) / 2))
+            if focal.ndim == 0:
+                focal = focal[None].expand(points.shape[0])
+            _, shift = recover_focal_shift(points, None if mask is None else mask > 0.5, focal=focal)
+        fx = focal / 2 * (1 + aspect_ratio ** 2) ** 0.5 / aspect_ratio
+        fy = focal / 2 * (1 + aspect_ratio ** 2) ** 0.5
+        intrinsics = utils3d.torch.intrinsics_from_focal_center(fx, fy, 0.5, 0.5)
+        depth = points[..., 2] + shift[..., None, None]
 
-        # If projection constraint is forces, recompute the point map using the actual depth map
+        # If projection constraint is forced, recompute the point map using the actual depth map
         if force_projection:
             points = utils3d.torch.unproject_cv(utils3d.torch.image_uv(width=expected_width, height=expected_height, dtype=points.dtype, device=points.device), depth, extrinsics=None, intrinsics=intrinsics[..., None, :, :])
         else:
-            points = points + torch.stack([torch.zeros_like(z_shift), torch.zeros_like(z_shift), z_shift], dim=-1)[..., None, None, :]
+            points = points + torch.stack([torch.zeros_like(shift), torch.zeros_like(shift), shift], dim=-1)[..., None, None, :]
 
         # Resize the output to the original resolution
         if expected_area != area:
@@ -373,4 +386,4 @@ class MoGeModel(nn.Module):
         if self.output_mask:
             return_dict['mask'] = mask > 0.5
 
-        return return_dict
+        return return_dict
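
Note: a self-contained, hedged check of the focal conventions in the `fov_x` branch above (focal is measured relative to half the image diagonal; `fx`/`fy` are normalized by image width/height; the 4:3 aspect ratio and 60-degree FoV are made-up numbers for illustration):

    import math

    aspect_ratio = 4 / 3
    fov_x = math.radians(60)

    # width / diagonal = ar / sqrt(1 + ar^2), and tan(fov_x / 2) = (width / 2) / f_pixels,
    # so the diagonal-normalized focal used by this commit is:
    focal = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5 / math.tan(fov_x / 2)

    # Normalized intrinsics entries (focal lengths divided by width and height):
    fx = focal / 2 * (1 + aspect_ratio ** 2) ** 0.5 / aspect_ratio
    fy = focal / 2 * (1 + aspect_ratio ** 2) ** 0.5

    # Round trip: fx reduces to 1 / (2 * tan(fov_x / 2)), so the horizontal FoV is recovered.
    assert abs(2 * math.atan(0.5 / fx) - fov_x) < 1e-9
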
moge/utils/geometry_numpy.py CHANGED
@@ -23,7 +23,7 @@ def harmonic_mean_numpy(x: np.ndarray, w: np.ndarray = None, axis: Union[int, Tu
     return 1 / (weighted_mean_numpy(1 / (x + eps), w, axis=axis, keepdims=keepdims, eps=eps) + eps)
 
 
-def image_plane_uv_numpy(width: int, height: int, aspect_ratio: float = None, dtype: np.dtype = np.float32) -> np.ndarray:
+def normalized_view_plane_uv_numpy(width: int, height: int, aspect_ratio: float = None, dtype: np.dtype = np.float32) -> np.ndarray:
     "UV with left-top corner as (-width / diagonal, -height / diagonal) and right-bottom corner as (width / diagonal, height / diagonal)"
     if aspect_ratio is None:
         aspect_ratio = width / height
@@ -52,7 +52,27 @@ def intrinsics_to_fov_numpy(intrinsics: np.ndarray) -> Tuple[np.ndarray, np.ndar
     return fov_x, fov_y
 
 
-def solve_optimal_shift_focal(uv: np.ndarray, xyz: np.ndarray, ransac_iters: int = None, ransac_hypothetical_size: float = 0.1, ransac_threshold: float = 0.1):
+def point_map_to_depth_legacy_numpy(points: np.ndarray):
+    height, width = points.shape[-3:-1]
+    diagonal = (height ** 2 + width ** 2) ** 0.5
+    uv = normalized_view_plane_uv_numpy(width, height, dtype=points.dtype)  # (H, W, 2)
+    _, uv = np.broadcast_arrays(points[..., :2], uv)
+
+    # Solve least squares problem
+    b = (uv * points[..., 2:]).reshape(*points.shape[:-3], -1)  # (..., H * W * 2)
+    A = np.stack([points[..., :2], -uv], axis=-1).reshape(*points.shape[:-3], -1, 2)  # (..., H * W * 2, 2)
+
+    M = A.swapaxes(-2, -1) @ A
+    solution = (np.linalg.inv(M + 1e-6 * np.eye(2)) @ (A.swapaxes(-2, -1) @ b[..., None])).squeeze(-1)
+    focal, shift = solution
+
+    depth = points[..., 2] + shift[..., None, None]
+    fov_x = np.arctan(width / diagonal / focal) * 2
+    fov_y = np.arctan(height / diagonal / focal) * 2
+    return depth, fov_x, fov_y, shift
+
+
+def solve_optimal_focal_shift(uv: np.ndarray, xyz: np.ndarray):
     "Solve `min |focal * xy / (z + shift) - uv|` with respect to shift and focal"
     from scipy.optimize import least_squares
     uv, xy, z = uv.reshape(-1, 2), xyz[..., :2].reshape(-1, 2), xyz[..., 2].reshape(-1)
@@ -63,44 +83,39 @@ def solve_optimal_shift_focal(uv: np.ndarray, xyz: np.ndarray, ransac_iters: int
         err = (f * xy_proj - uv).ravel()
         return err
 
-    initial_shift = 0 #-z.min(keepdims=True) + 1.0
-
-    if ransac_iters is None:
-        solution = least_squares(partial(fn, uv, xy, z), x0=initial_shift, ftol=1e-3, method='lm')
-        optim_shift = solution['x'].squeeze().astype(np.float32)
-    else:
-        best_err, best_shift = np.inf, None
-        for _ in range(ransac_iters):
-            maybe_inliers = np.random.choice(len(z), size=int(ransac_hypothetical_size * len(z)), replace=False)
-            solution = least_squares(partial(fn, uv[maybe_inliers], xy[maybe_inliers], z[maybe_inliers]), x0=initial_shift, ftol=1e-3, method='lm')
-            maybe_shift = solution['x'].squeeze().astype(np.float32)
-            confirmed_inliers = np.linalg.norm(fn(uv, xy, z, maybe_shift).reshape(-1, 2), axis=-1) < ransac_threshold
-            if confirmed_inliers.sum() > 10:
-                solution = least_squares(partial(fn, uv[confirmed_inliers], xy[confirmed_inliers], z[confirmed_inliers]), x0=maybe_shift, ftol=1e-3, method='lm')
-                better_shift = solution['x'].squeeze().astype(np.float32)
-            else:
-                better_shift = maybe_shift
-            err = np.linalg.norm(fn(uv, xy, z, better_shift).reshape(-1, 2), axis=-1).clip(max=ransac_threshold).mean()
-            if err < best_err:
-                best_err, best_shift = err, better_shift
-                initial_shift = best_shift
-
-        optim_shift = best_shift
+    solution = least_squares(partial(fn, uv, xy, z), x0=0, ftol=1e-3, method='lm')
+    optim_shift = solution['x'].squeeze().astype(np.float32)
 
     xy_proj = xy / (z + optim_shift)[:, None]
-    optim_focal = (xy_proj * uv).sum() / (xy_proj * xy_proj).sum()
+    optim_focal = (xy_proj * uv).sum() / np.square(xy_proj).sum()
 
     return optim_shift, optim_focal
 
 
-def point_map_to_depth_numpy(points: np.ndarray, mask: np.ndarray = None, downsample_size: Tuple[int, int] = (64, 64)):
+def solve_optimal_shift(uv: np.ndarray, xyz: np.ndarray, focal: float):
+    "Solve `min |focal * xy / (z + shift) - uv|` with respect to shift"
+    from scipy.optimize import least_squares
+    uv, xy, z = uv.reshape(-1, 2), xyz[..., :2].reshape(-1, 2), xyz[..., 2].reshape(-1)
+
+    def fn(uv: np.ndarray, xy: np.ndarray, z: np.ndarray, shift: np.ndarray):
+        xy_proj = xy / (z + shift)[:, None]
+        err = (focal * xy_proj - uv).ravel()
+        return err
+
+    solution = least_squares(partial(fn, uv, xy, z), x0=0, ftol=1e-3, method='lm')
+    optim_shift = solution['x'].squeeze().astype(np.float32)
+
+    return optim_shift
+
+
+def recover_focal_shift_numpy(points: np.ndarray, mask: np.ndarray = None, focal: float = None, downsample_size: Tuple[int, int] = (64, 64)):
     import cv2
     assert points.shape[-1] == 3, "Points should (H, W, 3)"
 
     height, width = points.shape[-3], points.shape[-2]
     diagonal = (height ** 2 + width ** 2) ** 0.5
 
-    uv = image_plane_uv_numpy(width=width, height=height)
+    uv = normalized_view_plane_uv_numpy(width=width, height=height)
 
     if mask is None:
         points_lr = cv2.resize(points, downsample_size, interpolation=cv2.INTER_LINEAR).reshape(-1, 3)
@@ -112,13 +127,12 @@ def point_map_to_depth_numpy(points: np.ndarray, mask: np.ndarray = None, downsa
     if points_lr.size == 0:
         return np.zeros((height, width)), 0, 0, 0
 
-    optim_shift, optim_focal = solve_optimal_shift_focal(uv_lr, points_lr, ransac_iters=None)
+    if focal is None:
+        focal, shift = solve_optimal_focal_shift(uv_lr, points_lr)
+    else:
+        shift = solve_optimal_shift(uv_lr, points_lr, focal)
 
-    fov_x = 2 * np.arctan(width / diagonal / optim_focal)
-    fov_y = 2 * np.arctan(height / diagonal / optim_focal)
-
-    depth = points[:, :, 2] + optim_shift
-    return depth, fov_x, fov_y, optim_shift
+    return focal, shift
 
 
 def mask_aware_nearest_resize_numpy(mask: np.ndarray, target_width: int, target_height: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
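
Note: a hedged synthetic sanity check for the new numpy solvers (assuming the module path `moge/utils/geometry_numpy.py` as in this diff): a point map constructed to project exactly under a known focal/shift should be recovered by `solve_optimal_focal_shift`.

    import numpy as np
    from moge.utils.geometry_numpy import normalized_view_plane_uv_numpy, solve_optimal_focal_shift

    rng = np.random.default_rng(0)
    height, width = 48, 64
    true_focal, true_shift = 1.2, 0.3

    # Build a point map satisfying focal * xy / (z + shift) == uv exactly.
    uv = normalized_view_plane_uv_numpy(width=width, height=height)      # (H, W, 2)
    depth = rng.uniform(1.0, 3.0, size=(height, width))
    points = np.concatenate([uv * depth[..., None] / true_focal,
                             (depth - true_shift)[..., None]], axis=-1)  # (H, W, 3)

    shift, focal = solve_optimal_focal_shift(uv, points)
    assert abs(shift - true_shift) < 1e-2 and abs(focal - true_focal) < 1e-2
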
moge/utils/geometry_torch.py CHANGED
@@ -10,7 +10,7 @@ import torch.types
 import utils3d
 
 from .tools import timeit
-from .geometry_numpy import solve_optimal_shift_focal
+from .geometry_numpy import solve_optimal_focal_shift, solve_optimal_shift
 
 
 def weighted_mean(x: torch.Tensor, w: torch.Tensor = None, dim: Union[int, torch.Size] = None, keepdim: bool = False, eps: float = 1e-7) -> torch.Tensor:
@@ -37,7 +37,7 @@ def geometric_mean(x: torch.Tensor, w: torch.Tensor = None, dim: Union[int, torc
     return weighted_mean(x.add(eps).log(), w, dim=dim, keepdim=keepdim, eps=eps).exp()
 
 
-def image_plane_uv(width: int, height: int, aspect_ratio: float = None, dtype: torch.dtype = None, device: torch.device = None) -> torch.Tensor:
+def normalized_view_plane_uv(width: int, height: int, aspect_ratio: float = None, dtype: torch.dtype = None, device: torch.device = None) -> torch.Tensor:
     "UV with left-top corner as (-width / diagonal, -height / diagonal) and right-bottom corner as (width / diagonal, height / diagonal)"
     if aspect_ratio is None:
         aspect_ratio = width / height
@@ -61,23 +61,6 @@ def gaussian_blur_2d(input: torch.Tensor, kernel_size: int, sigma: float) -> tor
     return input
 
 
-def split_batch_fwd(fn: Callable, chunk_size: int, *args, **kwargs):
-    batch_size = next(x for x in (*args, *kwargs.values()) if isinstance(x, torch.Tensor)).shape[0]
-    n_chunks = batch_size // chunk_size + (batch_size % chunk_size > 0)
-    splited_args = tuple(arg.split(chunk_size, dim=0) if isinstance(arg, torch.Tensor) else [arg] * n_chunks for arg in args)
-    splited_kwargs = {k: [v.split(chunk_size, dim=0) if isinstance(v, torch.Tensor) else [v] * n_chunks] for k, v in kwargs.items()}
-    results = []
-    for i in range(n_chunks):
-        chunk_args = tuple(arg[i] for arg in splited_args)
-        chunk_kwargs = {k: v[i] for k, v in splited_kwargs.items()}
-        results.append(fn(*chunk_args, **chunk_kwargs))
-
-    if isinstance(results[0], tuple):
-        return tuple(torch.cat(r, dim=0) for r in zip(*results))
-    else:
-        return torch.cat(results, dim=0)
-
-
 def focal_to_fov(focal: torch.Tensor):
     return 2 * torch.atan(0.5 / focal)
 
@@ -104,7 +87,7 @@ def intrinsics_to_fov(intrinsics: torch.Tensor):
 def point_map_to_depth_legacy(points: torch.Tensor):
     height, width = points.shape[-3:-1]
     diagonal = (height ** 2 + width ** 2) ** 0.5
-    uv = image_plane_uv(width, height, dtype=points.dtype, device=points.device)  # (H, W, 2)
+    uv = normalized_view_plane_uv(width, height, dtype=points.dtype, device=points.device)  # (H, W, 2)
 
     # Solve least squares problem
     b = (uv * points[..., 2:]).flatten(-3, -1)  # (..., H * W * 2)
@@ -120,7 +103,13 @@ def point_map_to_depth_legacy(points: torch.Tensor):
     return depth, fov_x, fov_y, shift
 
 
-def point_map_to_depth(points: torch.Tensor, mask: torch.Tensor = None, downsample_size: Tuple[int, int] = (64, 64)):
+def view_plane_uv_to_focal(uv: torch.Tensor):
+    normed_uv = normalized_view_plane_uv(width=uv.shape[-2], height=uv.shape[-3], device=uv.device, dtype=uv.dtype)
+    focal = (uv * normed_uv).sum() / uv.square().sum().add(1e-12)
+    return focal
+
+
+def recover_focal_shift(points: torch.Tensor, mask: torch.Tensor = None, focal: torch.Tensor = None, downsample_size: Tuple[int, int] = (64, 64)):
     """
     Recover the depth map and FoV from a point map with unknown z shift and focal.
 
@@ -131,13 +120,13 @@ def point_map_to_depth(points: torch.Tensor, mask: torch.Tensor = None, downsamp
 
     ### Parameters:
    - `points: torch.Tensor` of shape (..., H, W, 3)
+    - `mask: torch.Tensor` of shape (..., H, W). Optional.
+    - `focal: torch.Tensor` of shape (...). Optional.
     - `downsample_size: Tuple[int, int]` in (height, width), the size of the downsampled map. Downsampling produces approximate solution and is efficient for large maps.
 
     ### Returns:
-    - `depth: torch.Tensor` of shape (..., H, W)
-    - `fov_x: torch.Tensor` of shape (...)
-    - `fov_y: torch.Tensor` of shape (...)
-    - `shift: torch.Tensor` of shape (...), the z shift, making `depth = points[..., 2] + shift`
+    - `focal`: torch.Tensor of shape (...) the estimated focal length, relative to the half diagonal of the map
+    - `shift`: torch.Tensor of shape (...) Z-axis shift to translate the point map to camera space
     """
     shape = points.shape
     height, width = points.shape[-3], points.shape[-2]
@@ -145,7 +134,8 @@ def point_map_to_depth(points: torch.Tensor, mask: torch.Tensor = None, downsamp
 
     points = points.reshape(-1, *shape[-3:])
     mask = None if mask is None else mask.reshape(-1, *shape[-3:-1])
-    uv = image_plane_uv(width, height, dtype=points.dtype, device=points.device)  # (H, W, 2)
+    focal = focal.reshape(-1) if focal is not None else None
+    uv = normalized_view_plane_uv(width, height, dtype=points.dtype, device=points.device)  # (H, W, 2)
 
     points_lr = F.interpolate(points.permute(0, 3, 1, 2), downsample_size, mode='nearest').permute(0, 2, 3, 1)
     uv_lr = F.interpolate(uv.unsqueeze(0).permute(0, 3, 1, 2), downsample_size, mode='nearest').squeeze(0).permute(1, 2, 0)
@@ -153,26 +143,26 @@ def point_map_to_depth(points: torch.Tensor, mask: torch.Tensor = None, downsamp
 
     uv_lr_np = uv_lr.cpu().numpy()
     points_lr_np = points_lr.detach().cpu().numpy()
+    focal_np = focal.cpu().numpy() if focal is not None else None
     mask_lr_np = None if mask is None else mask_lr.cpu().numpy()
     optim_shift, optim_focal = [], []
     for i in range(points.shape[0]):
         points_lr_i_np = points_lr_np[i] if mask is None else points_lr_np[i][mask_lr_np[i]]
         uv_lr_i_np = uv_lr_np if mask is None else uv_lr_np[mask_lr_np[i]]
-        optim_shift_i, optim_focal_i = solve_optimal_shift_focal(uv_lr_i_np, points_lr_i_np, ransac_iters=None)
+        if focal is None:
+            optim_shift_i, optim_focal_i = solve_optimal_focal_shift(uv_lr_i_np, points_lr_i_np)
+            optim_focal.append(float(optim_focal_i))
+        else:
+            optim_shift_i = solve_optimal_shift(uv_lr_i_np, points_lr_i_np, focal_np[i])
         optim_shift.append(float(optim_shift_i))
-        optim_focal.append(float(optim_focal_i))
-    optim_shift = torch.tensor(optim_shift, device=points.device, dtype=points.dtype)
-    optim_focal = torch.tensor(optim_focal, device=points.device, dtype=points.dtype)
+    optim_shift = torch.tensor(optim_shift, device=points.device, dtype=points.dtype).reshape(shape[:-3])
 
-    fov_x = 2 * torch.atan(width / diagonal / optim_focal)
-    fov_y = 2 * torch.atan(height / diagonal / optim_focal)
-
-    depth = (points[..., 2] + optim_shift[:, None, None]).reshape(shape[:-1])
-    fov_x = fov_x.reshape(shape[:-3])
-    fov_y = fov_y.reshape(shape[:-3])
-    optim_shift = optim_shift.reshape(shape[:-3])
+    if focal is None:
+        optim_focal = torch.tensor(optim_focal, device=points.device, dtype=points.dtype).reshape(shape[:-3])
+    else:
+        optim_focal = focal.reshape(shape[:-3])
 
-    return depth, fov_x, fov_y, optim_shift
+    return optim_focal, optim_shift
 
 
 def mask_aware_nearest_resize(mask: torch.BoolTensor, target_width: int, target_height: int) -> Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]:
@@ -227,5 +217,3 @@ def mask_aware_nearest_resize(mask: torch.BoolTensor, target_width: int, target_
     batch_indices = [torch.arange(n, device=device).reshape([1] * i + [n] + [1] * (mask.dim() - i - 1)) for i, n in enumerate(mask.shape[:-2])]
 
     return (*batch_indices, nearest_i, nearest_j), target_mask
-
-
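
Note: the torch-side counterpart, sketched under the same assumptions (exact synthetic data; `recover_focal_shift` returns `(focal, shift)` with focal relative to half the diagonal):

    import torch
    from moge.utils.geometry_torch import normalized_view_plane_uv, recover_focal_shift

    H, W = 60, 80
    true_focal, true_shift = 1.1, 0.25

    uv = normalized_view_plane_uv(W, H)                                  # (H, W, 2)
    depth = torch.rand(1, H, W) * 2 + 1
    points = torch.cat([uv * depth[..., None] / true_focal,
                        (depth - true_shift)[..., None]], dim=-1)        # (1, H, W, 3)

    focal, shift = recover_focal_shift(points)                           # each of shape (1,)
    assert abs(focal.item() - true_focal) < 1e-2 and abs(shift.item() - true_shift) < 1e-2

    # With a known focal, only the shift is re-estimated:
    _, shift_only = recover_focal_shift(points, focal=focal)
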
 
 
moge/utils/io.py CHANGED
@@ -345,3 +345,47 @@ def read_rgbxyz(file: Union[IO, str, Path]) -> Tuple[np.ndarray, np.ndarray, np.
     mask = np.ones(image.shape[:2], dtype=bool)
 
     return image, points, mask
+
+
+def save_glb(
+    save_path: Union[str, os.PathLike],
+    vertices: np.ndarray,
+    faces: np.ndarray,
+    vertex_uvs: np.ndarray,
+    texture: np.ndarray,
+):
+    import trimesh
+    import trimesh.visual
+    from PIL import Image
+
+    trimesh.Trimesh(
+        vertices=vertices,
+        faces=faces,
+        visual=trimesh.visual.texture.TextureVisuals(
+            uv=vertex_uvs,
+            material=trimesh.visual.material.PBRMaterial(
+                baseColorTexture=Image.fromarray(texture),
+                metallicFactor=0.5,
+                roughnessFactor=1.0
+            )
+        ),
+        process=False
+    ).export(save_path)
+
+
+def save_ply(
+    save_path: Union[str, os.PathLike],
+    vertices: np.ndarray,
+    faces: np.ndarray,
+    vertex_colors: np.ndarray,
+):
+    import trimesh
+    import trimesh.visual
+    from PIL import Image
+
+    trimesh.Trimesh(
+        vertices=vertices,
+        faces=faces,
+        vertex_colors=vertex_colors,
+        process=False
+    ).export(save_path)
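
Note: the two new writers are thin trimesh wrappers; a minimal usage sketch (the single-quad arrays below are hypothetical; in the demo such arrays come from `utils3d.numpy.image_mesh`):

    import numpy as np
    from moge.utils.io import save_glb, save_ply

    vertices = np.array([[0, 0, 0], [1, 0, 0], [1, 1, 0], [0, 1, 0]], dtype=np.float32)
    faces = np.array([[0, 1, 2], [0, 2, 3]], dtype=np.int32)
    vertex_uvs = np.array([[0, 0], [1, 0], [1, 1], [0, 1]], dtype=np.float32)
    vertex_colors = np.full((4, 3), 200, dtype=np.uint8)
    texture = np.full((8, 8, 3), 128, dtype=np.uint8)   # dummy RGB texture

    save_glb('mesh.glb', vertices, faces, vertex_uvs, texture)   # textured GLB
    save_ply('mesh.ply', vertices, faces, vertex_colors)         # per-vertex-color PLY
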
utils3d/README.md ADDED
@@ -0,0 +1,3 @@
+# utils3d
+
+This is a collection of utility functions for 3D computer vision tasks copied from https://github.com/EasternJournalist/utils3d.
utils3d/io/__init__.py CHANGED
@@ -1,4 +1,3 @@
-from .wavefront_obj import *
+from .obj import *
 from .colmap import *
 from .ply import *
-from .glb import *
utils3d/io/colmap.py CHANGED
@@ -33,7 +33,7 @@ def write_extrinsics_as_colmap(file: Union[str, Path], extrinsics: np.ndarray, i
     with open(file, 'w') as fp:
         print("# IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, NAME", file=fp)
         for i, (quat, t, name, camera_id) in enumerate(zip(quats.tolist(), trans.tolist(), image_names, camera_ids)):
-            # Colmap has wxyz order while scipy.spatial.transform.Rotation has xyzw order. Haha, wcnm.
+            # Colmap has wxyz order while scipy.spatial.transform.Rotation has xyzw order.
             qx, qy, qz, qw = quat
             tx, ty, tz = t
             print(f'{i + 1} {qw:f} {qx:f} {qy:f} {qz:f} {tx:f} {ty:f} {tz:f} {camera_id:d} {name}', file=fp)
utils3d/io/obj.py ADDED
@@ -0,0 +1,146 @@
+from io import TextIOWrapper
+from typing import Dict, Any, Union, Iterable
+import numpy as np
+from pathlib import Path
+
+__all__ = [
+    'read_obj',
+    'write_obj',
+    'simple_write_obj'
+]
+
+def read_obj(
+    file: Union[str, Path, TextIOWrapper],
+    encoding: Union[str, None] = None,
+    ignore_unknown: bool = False
+):
+    """
+    Read a wavefront .obj file, without preprocessing.
+
+    Why bother having this read_obj() when we already have libraries like `trimesh`?
+    This function reads the raw format from the .obj file and keeps the order of vertices and faces,
+    while trimesh applies modifications like merging/splitting vertices, which can break the order of vertices and faces.
+    Those libraries are commonly aimed at geometry processing and rendering, supporting various formats.
+    If you want mesh geometry processing, you may turn to `trimesh` for more features.
+
+    ### Parameters
+    `file` (str, Path, TextIOWrapper): filepath or file object
+    `encoding` (str, optional): text encoding used to open the file
+
+    ### Returns
+    obj (dict): A dict containing .obj components
+    {
+        'mtllib': [],
+        'v': [[0.1, 0.2, 1.0], [1.2, 0.0, 0.0], ...],
+        'vt': [[0.5, 0.5], ...],
+        'vn': [[0., 0.7, 0.7], [0., -0.7, 0.7], ...],
+        'f': [[0, 1, 2], [2, 3, 4], ...],
+        'usemtl': [{'name': 'mtl1', 'f': 7}]
+    }
+    """
+    if hasattr(file, 'read'):
+        lines = file.read().splitlines()
+    else:
+        with open(file, 'r', encoding=encoding) as fp:
+            lines = fp.read().splitlines()
+    mtllib = []
+    v, vt, vn, vp = [], [], [], []  # Vertex coordinates, vertex texture coordinates, vertex normals, vertex parameters
+    f, ft, fn = [], [], []          # Face indices, face texture indices, face normal indices
+    o = []
+    s = []
+    usemtl = []
+
+    def pad(l: list, n: Any):
+        return l + [n] * (3 - len(l))
+
+    for i, line in enumerate(lines):
+        sq = line.strip().split()
+        if len(sq) == 0:
+            continue
+        if sq[0] == 'v':
+            assert 4 <= len(sq) <= 5, f'Invalid format of line {i}: {line}'
+            v.append([float(e) for e in sq[1:]][:3])
+        elif sq[0] == 'vt':
+            assert 3 <= len(sq) <= 4, f'Invalid format of line {i}: {line}'
+            vt.append([float(e) for e in sq[1:]][:2])
+        elif sq[0] == 'vn':
+            assert len(sq) == 4, f'Invalid format of line {i}: {line}'
+            vn.append([float(e) for e in sq[1:]])
+        elif sq[0] == 'vp':
+            assert 2 <= len(sq) <= 4, f'Invalid format of line {i}: {line}'
+            vp.append(pad([float(e) for e in sq[1:]], 0))
+        elif sq[0] == 'f':
+            spliting = [pad([int(j) - 1 for j in e.split('/')], -1) for e in sq[1:]]
+            f.append([e[0] for e in spliting])
+            ft.append([e[1] for e in spliting])
+            fn.append([e[2] for e in spliting])
+        elif sq[0] == 'usemtl':
+            assert len(sq) == 2
+            usemtl.append((sq[1], len(f)))
+        elif sq[0] == 'o':
+            assert len(sq) == 2
+            o.append((sq[1], len(f)))
+        elif sq[0] == 's':
+            s.append((sq[1], len(f)))
+        elif sq[0] == 'mtllib':
+            assert len(sq) == 2
+            mtllib.append(sq[1])
+        elif sq[0][0] == '#':
+            continue
+        else:
+            if not ignore_unknown:
+                raise Exception(f'Unknown keyword {sq[0]}')
+
+    min_poly_vertices = min(len(f) for f in f)
+    max_poly_vertices = max(len(f) for f in f)
+
+    return {
+        'mtllib': mtllib,
+        'v': np.array(v, dtype=np.float32),
+        'vt': np.array(vt, dtype=np.float32),
+        'vn': np.array(vn, dtype=np.float32),
+        'vp': np.array(vp, dtype=np.float32),
+        'f': np.array(f, dtype=np.int32) if min_poly_vertices == max_poly_vertices else f,
+        'ft': np.array(ft, dtype=np.int32) if min_poly_vertices == max_poly_vertices else ft,
+        'fn': np.array(fn, dtype=np.int32) if min_poly_vertices == max_poly_vertices else fn,
+        'o': o,
+        's': s,
+        'usemtl': usemtl,
+    }
+
+
+def write_obj(
+    file: Union[str, Path],
+    obj: Dict[str, Any],
+    encoding: Union[str, None] = None
+):
+    with open(file, 'w', encoding=encoding) as fp:
+        for k in ['v', 'vt', 'vn', 'vp']:
+            if k not in obj:
+                continue
+            for v in obj[k]:
+                print(k, *map(float, v), file=fp)
+        for f in obj['f']:
+            # Each face element is either a bare vertex index or a (v, vt, vn) tuple joined with '/'
+            print('f', *('/'.join(map(str, i)) if isinstance(i, Iterable) else i for i in f), file=fp)
+
+
+def simple_write_obj(
+    file: Union[str, Path],
+    vertices: np.ndarray,
+    faces: np.ndarray,
+    encoding: Union[str, None] = None
+):
+    """
+    Write a wavefront .obj file, without preprocessing.
+
+    Args:
+        vertices (np.ndarray): [N, 3]
+        faces (np.ndarray): [T, 3]
+        file (Any): filepath
+        encoding (str, optional):
+    """
+    with open(file, 'w', encoding=encoding) as fp:
+        for v in vertices:
+            print('v', *map(float, v), file=fp)
+        for f in faces:
+            print('f', *map(int, f + 1), file=fp)
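
Note: a hedged round-trip sketch for the new reader/writer (`read_obj` returns 0-based `f` indices; `simple_write_obj` writes the 1-based indices the .obj format expects):

    import numpy as np
    from utils3d.io.obj import read_obj, simple_write_obj

    vertices = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0]], dtype=np.float32)
    faces = np.array([[0, 1, 2]], dtype=np.int32)

    simple_write_obj('tri.obj', vertices, faces)
    obj = read_obj('tri.obj')
    assert np.allclose(obj['v'], vertices) and (obj['f'] == faces).all()
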
utils3d/numpy/__init__.py CHANGED
@@ -37,6 +37,7 @@ __modules_all__ = {
     'max_pool_2d',
     'max_pool_nd',
     'depth_edge',
+    'normals_edge',
     'depth_aliasing',
     'interpolate',
     'image_scrcoord',
@@ -45,10 +46,11 @@ __modules_all__ = {
     'image_pixel',
     'image_mesh',
     'image_mesh_from_depth',
-    'depth_to_normal',
-    'point_to_normal',
+    'depth_to_normals',
+    'points_to_normals',
     'chessboard',
     'cube',
+    'icosahedron',
     'square',
     'camera_frustum',
     ],
utils3d/numpy/rasterization.py CHANGED
@@ -460,10 +460,8 @@ def test():
         faces,
         attr,
         512, 512,
-        view=view,
-        projection=perspective,
+        transform=perspective @ view,
         cull_backface=True,
-        ssaa=1,
         return_depth=True,
     )
     import cv2
utils3d/numpy/transforms.py CHANGED
@@ -474,7 +474,7 @@ def uv_to_pixel(
     Returns:
         (np.ndarray): [..., 2] pixel coordinates defined in uv space, the range is (0, 1)
     """
-    pixel = uv * np.stack([width, height], axis=-1) - 0.5
+    pixel = uv * np.stack([width, height], axis=-1).astype(uv.dtype) - 0.5
     return pixel
 
 
@@ -645,7 +645,7 @@ def unproject_gl(
 @batched(2,1,2,2)
 def unproject_cv(
     uv_coord: np.ndarray,
-    depth: np.ndarray,
+    depth: np.ndarray = None,
     extrinsics: np.ndarray = None,
     intrinsics: np.ndarray = None
 ) -> np.ndarray:
@@ -665,7 +665,8 @@ def unproject_cv(
     assert intrinsics is not None, "intrinsics matrix is required"
     points = np.concatenate([uv_coord, np.ones_like(uv_coord[..., :1])], axis=-1)
     points = points @ np.linalg.inv(intrinsics).swapaxes(-1, -2)
-    points = points * depth[..., None]
+    if depth is not None:
+        points = points * depth[..., None]
     if extrinsics is not None:
         points = np.concatenate([points, np.ones_like(points[..., :1])], axis=-1)
         points = (points @ np.linalg.inv(extrinsics).swapaxes(-1, -2))[..., :3]
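
Note: with `depth` now optional, `unproject_cv` returns the un-scaled back-projection K^-1 @ [u, v, 1] of each pixel (a ray through the pixel); multiplying by depth gives the usual unprojection. A hedged sketch (the 2x4 pixel grid and normalized intrinsics are made up; it assumes the `batched` decorator passes `None` through, as it already does for `extrinsics`):

    import numpy as np
    import utils3d

    intrinsics = np.array([[1.0, 0.0, 0.5],
                           [0.0, 1.0, 0.5],
                           [0.0, 0.0, 1.0]])
    u, v = np.meshgrid(np.linspace(0.125, 0.875, 4), np.linspace(0.25, 0.75, 2))
    uv = np.stack([u, v], axis=-1)        # (2, 4, 2)
    depth = np.full((2, 4), 2.0)

    rays = utils3d.numpy.unproject_cv(uv, intrinsics=intrinsics)           # rays with z == 1
    points = utils3d.numpy.unproject_cv(uv, depth, intrinsics=intrinsics)  # scaled by depth
    assert np.allclose(points, rays * depth[..., None])
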
utils3d/numpy/utils.py CHANGED
@@ -1,6 +1,8 @@
 import numpy as np
 from typing import *
 from numbers import Number
+import warnings
+import functools
 
 from ._helpers import batched
 from . import transforms
@@ -14,6 +16,7 @@ __all__ = [
     'max_pool_2d',
     'max_pool_nd',
     'depth_edge',
+    'normals_edge',
     'depth_aliasing',
     'interpolate',
     'image_scrcoord',
@@ -22,16 +25,29 @@ __all__ = [
     'image_pixel',
     'image_mesh',
     'image_mesh_from_depth',
-    'depth_to_normal',
-    'point_to_normal',
+    'depth_to_normals',
+    'points_to_normals',
     'chessboard',
     'cube',
+    'icosahedron',
     'square',
     'camera_frustum',
     'to4x4'
 ]
 
 
+def no_runtime_warnings(fn):
+    """
+    Disable runtime warnings in numpy.
+    """
+    @functools.wraps(fn)
+    def wrapper(*args, **kwargs):
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore")
+            return fn(*args, **kwargs)
+    return wrapper
+
+
 def sliding_window_1d(x: np.ndarray, window_size: int, stride: int, axis: int = -1):
     """
     Return a view of the input array with a sliding window of the given kernel size and stride.
@@ -97,9 +113,10 @@ def max_pool_2d(x: np.ndarray, kernel_size: Union[int, Tuple[int, int]], stride:
     return max_pool_nd(x, kernel_size, stride, padding, axis)
 
 
+@no_runtime_warnings
 def depth_edge(depth: np.ndarray, atol: float = None, rtol: float = None, kernel_size: int = 3, mask: np.ndarray = None) -> np.ndarray:
     """
-    Compute the edge mask of a depth map. The edge is defined as the pixels whose neighbors have a large difference in depth.
+    Compute the edge mask from depth map. The edge is defined as the pixels whose neighbors have large difference in depth.
 
     Args:
         depth (np.ndarray): shape (..., height, width), linear depth map
@@ -117,11 +134,15 @@ def depth_edge(depth: np.ndarray, atol: float = None, rtol: float = None, kernel
     edge = np.zeros_like(depth, dtype=bool)
     if atol is not None:
         edge |= diff > atol
-    if rtol is not None:
-        edge |= diff / depth > rtol
+
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore", category=RuntimeWarning)
+        if rtol is not None:
+            edge |= diff / depth > rtol
     return edge
 
 
+@no_runtime_warnings
 def depth_aliasing(depth: np.ndarray, atol: float = None, rtol: float = None, kernel_size: int = 3, mask: np.ndarray = None) -> np.ndarray:
     """
     Compute the map that indicates the aliasing of a depth map. The aliasing is defined as the pixels which are neither close to the maximum nor the minimum of their neighbors.
@@ -148,7 +169,46 @@ def depth_aliasing(depth: np.ndarray, atol: float = None, rtol: float = None, ke
         edge |= diff / depth > rtol
     return edge
 
-def point_to_normal(point: np.ndarray, mask: np.ndarray = None) -> np.ndarray:
+
+@no_runtime_warnings
+def normals_edge(normals: np.ndarray, tol: float, kernel_size: int = 3, mask: np.ndarray = None) -> np.ndarray:
+    """
+    Compute the edge mask from normal map.
+
+    Args:
+        normal (np.ndarray): shape (..., height, width, 3), normal map
+        tol (float): tolerance in degrees
+
+    Returns:
+        edge (np.ndarray): shape (..., height, width) of dtype bool
+    """
+    assert normals.ndim >= 3 and normals.shape[-1] == 3, "normal should be of shape (..., height, width, 3)"
+    normals = normals / (np.linalg.norm(normals, axis=-1, keepdims=True) + 1e-12)
+
+    padding = kernel_size // 2
+    normals_window = sliding_window_2d(
+        np.pad(normals, (*([(0, 0)] * (normals.ndim - 3)), (padding, padding), (padding, padding), (0, 0)), mode='edge'),
+        window_size=kernel_size,
+        stride=1,
+        axis=(-3, -2)
+    )
+    if mask is None:
+        angle_diff = np.acos((normals[..., None, None] * normals_window).sum(axis=-3)).max(axis=(-2, -1))
+    else:
+        mask_window = sliding_window_2d(
+            np.pad(mask, (*([(0, 0)] * (mask.ndim - 3)), (padding, padding), (padding, padding)), mode='edge'),
+            window_size=kernel_size,
+            stride=1,
+            axis=(-3, -2)
+        )
+        angle_diff = np.where(mask_window, np.acos((normals[..., None, None] * normals_window).sum(axis=-3)), 0).max(axis=(-2, -1))
+
+    angle_diff = max_pool_2d(angle_diff, kernel_size, stride=1, padding=kernel_size // 2)
+    edge = angle_diff > np.deg2rad(tol)
+    return edge
+
+
+@no_runtime_warnings
+def points_to_normals(point: np.ndarray, mask: np.ndarray = None) -> np.ndarray:
     """
     Calculate normal map from point map. Value range is [-1, 1]. Normal direction in OpenGL identity camera's coordinate system.
 
@@ -189,12 +249,14 @@ def points_to_normals(point: np.ndarray, mask: np.ndarray = None) -> np.ndarray:
     normal = normal / (np.linalg.norm(normal, axis=-1, keepdims=True) + 1e-12)
 
     if has_mask:
-        return normal, valid.any(axis=0)
+        normal_mask = valid.any(axis=0)
+        normal = np.where(normal_mask[..., None], normal, 0)
+        return normal, normal_mask
     else:
         return normal
 
 
-def depth_to_normal(depth: np.ndarray, intrinsics: np.ndarray, mask: np.ndarray = None) -> np.ndarray:
+def depth_to_normals(depth: np.ndarray, intrinsics: np.ndarray, mask: np.ndarray = None) -> np.ndarray:
     """
     Calculate normal map from depth map. Value range is [-1, 1]. Normal direction in OpenGL identity camera's coordinate system.
 
@@ -213,7 +275,7 @@ def depth_to_normals(depth: np.ndarray, intrinsics: np.ndarray, mask: np.ndarray
     uv = image_uv(width=width, height=height, dtype=np.float32)
     pts = transforms.unproject_cv(uv, depth, intrinsics=intrinsics, extrinsics=None)
 
-    return point_to_normal(pts, mask)
+    return points_to_normals(pts, mask)
 
 def interpolate(bary: np.ndarray, tri_id: np.ndarray, attr: np.ndarray, faces: np.ndarray) -> np.ndarray:
     """Interpolate with given barycentric coordinates and triangle indices
@@ -560,3 +622,18 @@ def camera_frustum(extrinsics: np.ndarray, intrinsics: np.ndarray, depth: float
     ], dtype=np.int32)
     return vertices, edges, faces
 
+
+def icosahedron():
+    A = (1 + 5 ** 0.5) / 2
+    vertices = np.array([
+        [0, 1, A], [0, -1, A], [0, 1, -A], [0, -1, -A],
+        [1, A, 0], [-1, A, 0], [1, -A, 0], [-1, -A, 0],
+        [A, 0, 1], [A, 0, -1], [-A, 0, 1], [-A, 0, -1]
+    ], dtype=np.float32)
+    faces = np.array([
+        [0, 1, 8], [0, 8, 4], [0, 4, 5], [0, 5, 10], [0, 10, 1],
+        [3, 2, 9], [3, 9, 6], [3, 6, 7], [3, 7, 11], [3, 11, 2],
+        [1, 6, 8], [8, 9, 4], [4, 2, 5], [5, 11, 10], [10, 7, 1],
+        [2, 4, 9], [9, 8, 6], [6, 1, 7], [7, 10, 11], [11, 5, 2]
+    ], dtype=np.int32)
+    return vertices, faces
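
Note: a hedged end-to-end sketch of the renamed normal utilities, which is what app.py relies on (it assumes a NumPy version where `np.acos` exists, i.e. NumPy >= 2, since `normals_edge` uses that spelling):

    import numpy as np
    import utils3d

    # A flat plane: every recoverable normal is identical, so there are no normal edges.
    x, y = np.meshgrid(np.arange(8, dtype=np.float32), np.arange(6, dtype=np.float32))
    points = np.stack([x, y, np.ones_like(x)], axis=-1)    # (6, 8, 3)
    mask = np.ones((6, 8), dtype=bool)

    # With a mask, points_to_normals now zeroes invalid normals and also returns
    # the validity mask (the old point_to_normal returned the mask un-applied).
    normals, normals_mask = utils3d.numpy.points_to_normals(points, mask=mask)
    edges = utils3d.numpy.normals_edge(normals, tol=5, mask=normals_mask)
    assert not edges[normals_mask].any()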