from typing import Optional import numpy as np import torch as th import torch.nn.functional as F import torch.nn as nn from sklearn.neighbors import KDTree import logging logger = logging.getLogger(__name__) # NOTE: we need pytorch3d primarily for UV rasterization things from pytorch3d.renderer.mesh.rasterize_meshes import rasterize_meshes from pytorch3d.structures import Meshes from typing import Union, Optional, Tuple import trimesh from trimesh import Trimesh from trimesh.triangles import points_to_barycentric try: # pyre-fixme[21]: Could not find module `igl`. from igl import point_mesh_squared_distance # @manual # pyre-fixme[3]: Return type must be annotated. # pyre-fixme[2]: Parameter must be annotated. def closest_point(mesh, points): """Helper function that mimics trimesh.proximity.closest_point but uses IGL for faster queries.""" v = mesh.vertices vi = mesh.faces dist, face_idxs, p = point_mesh_squared_distance(points, v, vi) return p, dist, face_idxs except ImportError: from trimesh.proximity import closest_point def closest_point_barycentrics(v, vi, points): """Given a 3D mesh and a set of query points, return closest point barycentrics Args: v: np.array (float) [N, 3] mesh vertices vi: np.array (int) [N, 3] mesh triangle indices points: np.array (float) [M, 3] query points Returns: Tuple[approx, barys, interp_idxs, face_idxs] approx: [M, 3] approximated (closest) points on the mesh barys: [M, 3] barycentric weights that produce "approx" interp_idxs: [M, 3] vertex indices for barycentric interpolation face_idxs: [M] face indices for barycentric interpolation. interp_idxs = vi[face_idxs] """ mesh = Trimesh(vertices=v, faces=vi, process=False) p, _, face_idxs = closest_point(mesh, points) p = p.reshape((points.shape[0], 3)) face_idxs = face_idxs.reshape((points.shape[0],)) barys = points_to_barycentric(mesh.triangles[face_idxs], p) b0, b1, b2 = np.split(barys, 3, axis=1) interp_idxs = vi[face_idxs] v0 = v[interp_idxs[:, 0]] v1 = v[interp_idxs[:, 1]] v2 = v[interp_idxs[:, 2]] approx = b0 * v0 + b1 * v1 + b2 * v2 return approx, barys, interp_idxs, face_idxs def make_uv_face_index( vt: th.Tensor, vti: th.Tensor, uv_shape: Union[Tuple[int, int], int], flip_uv: bool = True, device: Optional[Union[str, th.device]] = None, ): """Compute a UV-space face index map identifying which mesh face contains each texel. For texels with no assigned triangle, the index will be -1.""" if isinstance(uv_shape, int): uv_shape = (uv_shape, uv_shape) uv_max_shape_ind = uv_shape.index(max(uv_shape)) uv_min_shape_ind = uv_shape.index(min(uv_shape)) uv_ratio = uv_shape[uv_max_shape_ind] / uv_shape[uv_min_shape_ind] if device is not None: if isinstance(device, str): dev = th.device(device) else: dev = device assert dev.type == "cuda" else: dev = th.device("cuda") vt = 1.0 - vt.clone() if flip_uv: vt = vt.clone() vt[:, 1] = 1 - vt[:, 1] vt_pix = 2.0 * vt.to(dev) - 1.0 vt_pix = th.cat([vt_pix, th.ones_like(vt_pix[:, 0:1])], dim=1) vt_pix[:, uv_min_shape_ind] *= uv_ratio meshes = Meshes(vt_pix[np.newaxis], vti[np.newaxis].to(dev)) with th.no_grad(): face_index, _, _, _ = rasterize_meshes( meshes, uv_shape, faces_per_pixel=1, z_clip_value=0.0, bin_size=0 ) face_index = face_index[0, ..., 0] return face_index def make_uv_vert_index( vt: th.Tensor, vi: th.Tensor, vti: th.Tensor, uv_shape: Union[Tuple[int, int], int], flip_uv: bool = True, ): """Compute a UV-space vertex index map identifying which mesh vertices comprise the triangle containing each texel. For texels with no assigned triangle, all indices will be -1. """ face_index_map = make_uv_face_index(vt, vti, uv_shape, flip_uv) vert_index_map = vi[face_index_map.clamp(min=0)] vert_index_map[face_index_map < 0] = -1 return vert_index_map.long() def bary_coords(points: th.Tensor, triangles: th.Tensor, eps: float = 1.0e-6): """Computes barycentric coordinates for a set of 2D query points given coordintes for the 3 vertices of the enclosing triangle for each point.""" x = points[:, 0] - triangles[2, :, 0] x1 = triangles[0, :, 0] - triangles[2, :, 0] x2 = triangles[1, :, 0] - triangles[2, :, 0] y = points[:, 1] - triangles[2, :, 1] y1 = triangles[0, :, 1] - triangles[2, :, 1] y2 = triangles[1, :, 1] - triangles[2, :, 1] denom = y2 * x1 - y1 * x2 n0 = y2 * x - x2 * y n1 = x1 * y - y1 * x # Small epsilon to prevent divide-by-zero error. denom = th.where(denom >= 0, denom.clamp(min=eps), denom.clamp(max=-eps)) bary_0 = n0 / denom bary_1 = n1 / denom bary_2 = 1.0 - bary_0 - bary_1 return th.stack((bary_0, bary_1, bary_2)) def make_uv_barys( vt: th.Tensor, vti: th.Tensor, uv_shape: Union[Tuple[int, int], int], flip_uv: bool = True, ): """Compute a UV-space barycentric map where each texel contains barycentric coordinates for that texel within its enclosing UV triangle. For texels with no assigned triangle, all 3 barycentric coordinates will be 0. """ if isinstance(uv_shape, int): uv_shape = (uv_shape, uv_shape) if flip_uv: # Flip here because texture coordinates in some of our topo files are # stored in OpenGL convention with Y=0 on the bottom of the texture # unlike numpy/torch arrays/tensors. vt = vt.clone() vt[:, 1] = 1 - vt[:, 1] face_index_map = make_uv_face_index(vt, vti, uv_shape, flip_uv=False) vti_map = vti.long()[face_index_map.clamp(min=0)] uv_max_shape_ind = uv_shape.index(max(uv_shape)) uv_min_shape_ind = uv_shape.index(min(uv_shape)) uv_ratio = uv_shape[uv_max_shape_ind] / uv_shape[uv_min_shape_ind] vt = vt.clone() vt = vt * 2 - 1 vt[:, uv_min_shape_ind] *= uv_ratio uv_tri_uvs = vt[vti_map].permute(2, 0, 1, 3) uv_grid = th.meshgrid( th.linspace(0.5, uv_shape[0] - 0.5, uv_shape[0]) / uv_shape[0], th.linspace(0.5, uv_shape[1] - 0.5, uv_shape[1]) / uv_shape[1], ) uv_grid = th.stack(uv_grid[::-1], dim=2).to(uv_tri_uvs) uv_grid = uv_grid * 2 - 1 uv_grid[..., uv_min_shape_ind] *= uv_ratio bary_map = bary_coords(uv_grid.view(-1, 2), uv_tri_uvs.view(3, -1, 2)) bary_map = bary_map.permute(1, 0).view(uv_shape[0], uv_shape[1], 3) bary_map[face_index_map < 0] = 0 return face_index_map, bary_map def index_image_impaint( index_image: th.Tensor, bary_image: Optional[th.Tensor] = None, distance_threshold=100.0, ): # getting the mask around the indexes? if len(index_image.shape) == 3: valid_index = (index_image != -1).any(dim=-1) elif len(index_image.shape) == 2: valid_index = index_image != -1 else: raise ValueError("`index_image` should be a [H,W] or [H,W,C] image") invalid_index = ~valid_index device = index_image.device valid_ij = th.stack(th.where(valid_index), dim=-1) invalid_ij = th.stack(th.where(invalid_index), dim=-1) lookup_valid = KDTree(valid_ij.cpu().numpy()) dists, idxs = lookup_valid.query(invalid_ij.cpu()) # TODO: try average? idxs = th.as_tensor(idxs, device=device)[..., 0] dists = th.as_tensor(dists, device=device)[..., 0] dist_mask = dists < distance_threshold invalid_border = th.zeros_like(invalid_index) invalid_border[invalid_index] = dist_mask invalid_src_ij = valid_ij[idxs][dist_mask] invalid_dst_ij = invalid_ij[dist_mask] index_image_imp = index_image.clone() index_image_imp[invalid_dst_ij[:, 0], invalid_dst_ij[:, 1]] = index_image[ invalid_src_ij[:, 0], invalid_src_ij[:, 1] ] if bary_image is not None: bary_image_imp = bary_image.clone() bary_image_imp[invalid_dst_ij[:, 0], invalid_dst_ij[:, 1]] = bary_image[ invalid_src_ij[:, 0], invalid_src_ij[:, 1] ] return index_image_imp, bary_image_imp return index_image_imp class GeometryModule(nn.Module): def __init__( self, v, vi, vt, vti, uv_size, v2uv: Optional[th.Tensor] = None, flip_uv=False, impaint=False, impaint_threshold=100.0, ): super().__init__() self.register_buffer("v", th.as_tensor(v)) self.register_buffer("vi", th.as_tensor(vi)) self.register_buffer("vt", th.as_tensor(vt)) self.register_buffer("vti", th.as_tensor(vti)) if v2uv is not None: self.register_buffer("v2uv", th.as_tensor(v2uv, dtype=th.int64)) # TODO: should we just pass topology here? # self.n_verts = v2uv.shape[0] self.n_verts = vi.max() + 1 self.uv_size = uv_size # TODO: can't we just index face_index? index_image = make_uv_vert_index( self.vt, self.vi, self.vti, uv_shape=uv_size, flip_uv=flip_uv ).cpu() face_index, bary_image = make_uv_barys( self.vt, self.vti, uv_shape=uv_size, flip_uv=flip_uv ) if impaint: if min(uv_size) >= 1024: logger.info( "impainting index image might take a while for sizes >= 1024" ) index_image, bary_image = index_image_impaint( index_image, bary_image, impaint_threshold ) # TODO: we can avoid doing this 2x face_index = index_image_impaint( face_index, distance_threshold=impaint_threshold ) self.register_buffer("index_image", index_image.cpu()) self.register_buffer("bary_image", bary_image.cpu()) self.register_buffer("face_index_image", face_index.cpu()) def render_index_images(self, uv_size, flip_uv=False, impaint=False): index_image = make_uv_vert_index( self.vt, self.vi, self.vti, uv_shape=uv_size, flip_uv=flip_uv ) face_image, bary_image = make_uv_barys( self.vt, self.vti, uv_shape=uv_size, flip_uv=flip_uv ) if impaint: index_image, bary_image = index_image_impaint( index_image, bary_image, ) return index_image, face_image, bary_image def vn(self, verts): return vert_normals(verts, self.vi[np.newaxis].to(th.long)) def to_uv(self, values): return values_to_uv(values, self.index_image, self.bary_image) def from_uv(self, values_uv): # TODO: we need to sample this return sample_uv(values_uv, self.vt, self.v2uv.to(th.long)) def rand_sample_3d_uv(self, count, uv_img): """ Sample a set of 3D points on the surface of mesh, return corresponding interpolated values in UV space. Args: count - num of 3D points to be sampled uv_img - the image in uv space to be sampled, e.g., texture """ _mesh = Trimesh(vertices=self.v.detach().cpu().numpy(), faces=self.vi.detach().cpu().numpy(), process=False) points, _ = trimesh.sample.sample_surface(_mesh, count) return self.sample_uv_from_3dpts(points, uv_img) def sample_uv_from_3dpts(self, points, uv_img): num_pts = points.shape[0] approx, barys, interp_idxs, face_idxs = closest_point_barycentrics(self.v.detach().cpu().numpy(), self.vi.detach().cpu().numpy(), points) interp_uv_coords = self.vt[interp_idxs, :] # [N, 3, 2] # do bary interp first to get interp_uv_coord in high-reso uv space target_uv_coords = th.sum(interp_uv_coords * th.from_numpy(barys)[..., None], dim=1).float() # then directly sample from uv space sampled_values = sample_uv(values_uv=uv_img.permute(2, 0, 1)[None, ...], uv_coords=target_uv_coords) # [1, count, c] approx_values = sampled_values[0].reshape(num_pts, uv_img.shape[2]) return approx_values.numpy(), points def vert_sample_uv(self, uv_img): count = self.v.shape[0] points = self.v.detach().cpu().numpy() approx_values, _ = self.sample_uv_from_3dpts(points, uv_img) return approx_values def sample_uv( values_uv, uv_coords, v2uv: Optional[th.Tensor] = None, mode: str = "bilinear", align_corners: bool = True, flip_uvs: bool = False, ): batch_size = values_uv.shape[0] if flip_uvs: uv_coords = uv_coords.clone() uv_coords[:, 1] = 1.0 - uv_coords[:, 1] # uv_coords_norm is [1, N, 1, 2] afterwards uv_coords_norm = (uv_coords * 2.0 - 1.0)[np.newaxis, :, np.newaxis].expand( batch_size, -1, -1, -1 ) # uv_shape = values_uv.shape[-2:] # uv_max_shape_ind = uv_shape.index(max(uv_shape)) # uv_min_shape_ind = uv_shape.index(min(uv_shape)) # uv_ratio = uv_shape[uv_max_shape_ind] / uv_shape[uv_min_shape_ind] # uv_coords_norm[..., uv_min_shape_ind] *= uv_ratio values = ( F.grid_sample(values_uv, uv_coords_norm, align_corners=align_corners, mode=mode) .squeeze(-1) .permute((0, 2, 1)) ) if v2uv is not None: values_duplicate = values[:, v2uv] values = values_duplicate.mean(2) return values def values_to_uv(values, index_img, bary_img): uv_size = index_img.shape index_mask = th.all(index_img != -1, dim=-1) idxs_flat = index_img[index_mask].to(th.int64) bary_flat = bary_img[index_mask].to(th.float32) # NOTE: here we assume values_flat = th.sum(values[:, idxs_flat].permute(0, 3, 1, 2) * bary_flat, dim=-1) values_uv = th.zeros( values.shape[0], values.shape[-1], uv_size[0], uv_size[1], dtype=values.dtype, device=values.device, ) values_uv[:, :, index_mask] = values_flat return values_uv def face_normals(v, vi, eps: float = 1e-5): pts = v[:, vi] v0 = pts[:, :, 1] - pts[:, :, 0] v1 = pts[:, :, 2] - pts[:, :, 0] n = th.cross(v0, v1, dim=-1) norm = th.norm(n, dim=-1, keepdim=True) norm[norm < eps] = 1 n /= norm return n def vert_normals(v, vi, eps: float = 1.0e-5): fnorms = face_normals(v, vi) fnorms = fnorms[:, :, None].expand(-1, -1, 3, -1).reshape(fnorms.shape[0], -1, 3) vi_flat = vi.view(1, -1).expand(v.shape[0], -1) vnorms = th.zeros_like(v) for j in range(3): vnorms[..., j].scatter_add_(1, vi_flat, fnorms[..., j]) norm = th.norm(vnorms, dim=-1, keepdim=True) norm[norm < eps] = 1 vnorms /= norm return vnorms def compute_view_cos(verts, faces, camera_pos): vn = F.normalize(vert_normals(verts, faces), dim=-1) v2c = F.normalize(verts - camera_pos[:, np.newaxis], dim=-1) return th.einsum("bnd,bnd->bn", vn, v2c) def compute_tbn(geom, vt, vi, vti): """Computes tangent, bitangent, and normal vectors given a mesh. Args: geom: [N, n_verts, 3] th.Tensor Vertex positions. vt: [n_uv_coords, 2] th.Tensor UV coordinates. vi: [..., 3] th.Tensor Face vertex indices. vti: [..., 3] th.Tensor Face UV indices. Returns: [..., 3] th.Tensors for T, B, N. """ v0 = geom[:, vi[..., 0]] v1 = geom[:, vi[..., 1]] v2 = geom[:, vi[..., 2]] vt0 = vt[vti[..., 0]] vt1 = vt[vti[..., 1]] vt2 = vt[vti[..., 2]] v01 = v1 - v0 v02 = v2 - v0 vt01 = vt1 - vt0 vt02 = vt2 - vt0 f = 1.0 / ( vt01[None, ..., 0] * vt02[None, ..., 1] - vt01[None, ..., 1] * vt02[None, ..., 0] ) tangent = f[..., None] * th.stack( [ v01[..., 0] * vt02[None, ..., 1] - v02[..., 0] * vt01[None, ..., 1], v01[..., 1] * vt02[None, ..., 1] - v02[..., 1] * vt01[None, ..., 1], v01[..., 2] * vt02[None, ..., 1] - v02[..., 2] * vt01[None, ..., 1], ], dim=-1, ) tangent = F.normalize(tangent, dim=-1) normal = F.normalize(th.cross(v01, v02, dim=3), dim=-1) bitangent = F.normalize(th.cross(tangent, normal, dim=3), dim=-1) return tangent, bitangent, normal def compute_v2uv(n_verts, vi, vti, n_max=4): """Computes mapping from vertex indices to texture indices. Args: vi: [F, 3], triangles vti: [F, 3], texture triangles n_max: int, max number of texture locations Returns: [n_verts, n_max], texture indices """ v2uv_dict = {} for i_v, i_uv in zip(vi.reshape(-1), vti.reshape(-1)): v2uv_dict.setdefault(i_v, set()).add(i_uv) assert len(v2uv_dict) == n_verts v2uv = np.zeros((n_verts, n_max), dtype=np.int32) for i in range(n_verts): vals = sorted(list(v2uv_dict[i])) v2uv[i, :] = vals[0] v2uv[i, : len(vals)] = np.array(vals) return v2uv def compute_neighbours(n_verts, vi, n_max_values=10): """Computes first-ring neighbours given vertices and faces.""" n_vi = vi.shape[0] adj = {i: set() for i in range(n_verts)} for i in range(n_vi): for idx in vi[i]: adj[idx] |= set(vi[i]) - set([idx]) nbs_idxs = np.tile(np.arange(n_verts)[:, np.newaxis], (1, n_max_values)) nbs_weights = np.zeros((n_verts, n_max_values), dtype=np.float32) for idx in range(n_verts): n_values = min(len(adj[idx]), n_max_values) nbs_idxs[idx, :n_values] = np.array(list(adj[idx]))[:n_values] nbs_weights[idx, :n_values] = -1.0 / n_values return nbs_idxs, nbs_weights def make_postex(v, idxim, barim): return ( barim[None, :, :, 0, None] * v[:, idxim[:, :, 0]] + barim[None, :, :, 1, None] * v[:, idxim[:, :, 1]] + barim[None, :, :, 2, None] * v[:, idxim[:, :, 2]] ).permute(0, 3, 1, 2) def matrix_to_axisangle(r): th = th.arccos(0.5 * (r[..., 0, 0] + r[..., 1, 1] + r[..., 2, 2] - 1.0))[..., None] vec = ( 0.5 * th.stack( [ r[..., 2, 1] - r[..., 1, 2], r[..., 0, 2] - r[..., 2, 0], r[..., 1, 0] - r[..., 0, 1], ], dim=-1, ) / th.sin(th) ) return th, vec def axisangle_to_matrix(rvec): theta = th.sqrt(1e-5 + th.sum(rvec**2, dim=-1)) rvec = rvec / theta[..., None] costh = th.cos(theta) sinth = th.sin(theta) return th.stack( ( th.stack( ( rvec[..., 0] ** 2 + (1.0 - rvec[..., 0] ** 2) * costh, rvec[..., 0] * rvec[..., 1] * (1.0 - costh) - rvec[..., 2] * sinth, rvec[..., 0] * rvec[..., 2] * (1.0 - costh) + rvec[..., 1] * sinth, ), dim=-1, ), th.stack( ( rvec[..., 0] * rvec[..., 1] * (1.0 - costh) + rvec[..., 2] * sinth, rvec[..., 1] ** 2 + (1.0 - rvec[..., 1] ** 2) * costh, rvec[..., 1] * rvec[..., 2] * (1.0 - costh) - rvec[..., 0] * sinth, ), dim=-1, ), th.stack( ( rvec[..., 0] * rvec[..., 2] * (1.0 - costh) - rvec[..., 1] * sinth, rvec[..., 1] * rvec[..., 2] * (1.0 - costh) + rvec[..., 0] * sinth, rvec[..., 2] ** 2 + (1.0 - rvec[..., 2] ** 2) * costh, ), dim=-1, ), ), dim=-2, ) def rotation_interp(r0, r1, alpha): r0a = r0.view(-1, 3, 3) r1a = r1.view(-1, 3, 3) r = th.bmm(r0a.permute(0, 2, 1), r1a).view_as(r0) th, rvec = matrix_to_axisangle(r) rvec = rvec * (alpha * th) r = axisangle_to_matrix(rvec) return th.bmm(r0a, r.view(-1, 3, 3)).view_as(r0) def convert_camera_parameters(Rt, K): R = Rt[:, :3, :3] t = -R.permute(0, 2, 1).bmm(Rt[:, :3, 3].unsqueeze(2)).squeeze(2) return dict( campos=t, camrot=R, focal=K[:, :2, :2], princpt=K[:, :2, 2], ) def project_points_multi(p, Rt, K, normalize=False, size=None): """Project a set of 3D points into multiple cameras with a pinhole model. Args: p: [B, N, 3], input 3D points in world coordinates Rt: [B, NC, 3, 4], extrinsics (where NC is the number of cameras to project to) K: [B, NC, 3, 3], intrinsics normalize: bool, whether to normalize coordinates to [-1.0, 1.0] Returns: tuple: - [B, NC, N, 2] - projected points - [B, NC, N] - their """ B, N = p.shape[:2] NC = Rt.shape[1] Rt = Rt.reshape(B * NC, 3, 4) K = K.reshape(B * NC, 3, 3) # [B, N, 3] -> [B * NC, N, 3] p = p[:, np.newaxis].expand(-1, NC, -1, -1).reshape(B * NC, -1, 3) p_cam = p @ Rt[:, :3, :3].transpose(-2, -1) + Rt[:, :3, 3][:, np.newaxis] p_pix = p_cam @ K.transpose(-2, -1) p_depth = p_pix[:, :, 2:] p_pix = (p_pix[..., :2] / p_depth).reshape(B, NC, N, 2) p_depth = p_depth.reshape(B, NC, N) if normalize: assert size is not None h, w = size p_pix = ( 2.0 * p_pix / th.as_tensor([w, h], dtype=th.float32, device=p.device) - 1.0 ) return p_pix, p_depth