| | import torch |
| | import numpy as np |
| | import math |
| | import torch.nn as nn |
| |
|
| | from pytorch3d.structures import Meshes |
| | from pytorch3d.io import load_obj |
| | from pytorch3d.renderer.mesh import rasterize_meshes |
| | from pytorch3d.ops import mesh_face_areas_normals |
| |
|
| | |
| |
|
def gen_tritex(vt: np.ndarray, vi: np.ndarray, vti: np.ndarray, texsize: int):
    """
    Copied from MVP.
    Create 3 texture maps containing, for every texel covered by a triangle in
    UV space, the triangle's vertex indices, texture-vertex indices, and the
    texel centre's barycentric coordinates within that triangle.

    Parameters
    ----------
    vt: [T, >=2] uv coordinates of texture vertices; only the first two
        columns are used. Lists and CPU tensors are accepted as well.
    vi: [ntris, 3] triangle list mapping into vertex positions
    vti: [ntris, 3] triangle list mapping into texel coordinates (rows of vt)
    texsize: side length of the generated (square) maps

    Returns
    -------
    viim: LongTensor [texsize, texsize, 3], vertex indices per texel
    vtiim: float Tensor [texsize, texsize, 3], texture-vertex indices per texel
    baryim: float Tensor [texsize, texsize, 3], barycentric coordinates per texel

    Uncovered texels are 0 in all three maps. Map rows index v and columns
    index u. Where triangles overlap in UV space, the lowest-indexed triangle
    wins.
    """
    # Convert before slicing so that non-ndarray inputs (e.g. lists) work too.
    vt = np.asarray(vt, dtype=np.float32)[:, :2]
    vi = np.asarray(vi, dtype=np.int32)
    vti = np.asarray(vti, dtype=np.int32)
    ntris = vi.shape[0]

    # UV of every texel centre: texuv[row, col] = (u of col, v of row).
    centers = (np.arange(texsize) + 0.5) / texsize
    texu, texv = np.meshgrid(centers, centers)
    texuv = np.stack((texu, texv), axis=-1)

    tri_uv = vt[vti]  # [ntris, 3, 2]: UVs of each triangle's three corners

    viim = np.zeros((texsize, texsize, 3), dtype=np.int32)
    vtiim = np.zeros((texsize, texsize, 3), dtype=np.int32)
    baryim = np.zeros((texsize, texsize, 3), dtype=np.float32)

    # Paint from the last triangle to the first so that, where triangles
    # overlap, the lowest-indexed one is what remains in the maps.
    for i in reversed(range(ntris)):
        corners = tri_uv[i]
        us, vs = corners[:, 0], corners[:, 1]
        # Texel bounding box of the triangle, padded by one texel and clamped.
        u_lo = max(0, int(us.min() * texsize) - 1)
        u_hi = min(texsize, int(us.max() * texsize) + 2)
        v_lo = max(0, int(vs.min() * texsize) - 1)
        v_hi = min(texsize, int(vs.max() * texsize) + 2)

        # Barycentric coordinates of the texel centres via the dot-product form.
        e0 = corners[1] - corners[0]
        e1 = corners[2] - corners[0]
        rel = texuv[v_lo:v_hi, u_lo:u_hi, :] - corners[0]
        d00 = np.sum(e0 * e0)
        d01 = np.sum(e0 * e1)
        d11 = np.sum(e1 * e1)
        d20 = np.sum(rel * e0, axis=-1)
        d21 = np.sum(rel * e1, axis=-1)
        denom = d00 * d11 - d01 * d01
        if denom == 0.:
            continue  # degenerate (zero-area) triangle in UV space

        baryv = (d11 * d20 - d01 * d21) / denom
        baryw = (d00 * d21 - d01 * d20) / denom
        baryu = 1. - baryv - baryw
        inside = (baryu >= 0.) & (baryv >= 0.) & (baryw >= 0.)

        # Basic slices are views, so these boolean assignments write into the maps.
        baryim[v_lo:v_hi, u_lo:u_hi][inside] = np.stack((baryu, baryv, baryw), axis=-1)[inside]
        viim[v_lo:v_hi, u_lo:u_hi][inside] = vi[i]
        vtiim[v_lo:v_hi, u_lo:u_hi][inside] = vti[i]

    # NOTE(review): vtiim is returned as a float tensor to keep the original
    # return types; callers indexing with it may need `.long()`.
    return torch.LongTensor(viim), torch.Tensor(vtiim), torch.Tensor(baryim)
| |
|
| |
|
| | |
class Pytorch3dRasterizer(nn.Module):
    """
    Wrapper around pytorch3d's ``rasterize_meshes`` with fixed settings:
    no blur, one face per pixel, default binning, no perspective correction,
    and back-face culling enabled.
    """

    def __init__(self, image_size=224):
        """
        Args:
            image_size: resolution passed to ``rasterize_meshes`` when
                ``forward`` is called without ``h``/``w``.
        """
        super().__init__()
        self.raster_settings = {
            'image_size': image_size,
            'blur_radius': 0.0,
            'faces_per_pixel': 1,
            'bin_size': None,
            'max_faces_per_bin': None,
            'perspective_correct': False,
            'cull_backfaces': True,
        }

    def forward(self, vertices, faces, h=None, w=None):
        """
        Rasterize a batch of meshes.

        Args:
            vertices: padded per-mesh vertex positions, as accepted by
                ``Meshes(verts=...)``. x and y are negated on a copy before
                rasterization; the input is not modified.
            faces: per-mesh face indices, as accepted by ``Meshes(faces=...)``.
            h, w: optional output height and width. Give both or neither.
                When given, the coordinate along the longer image side is
                scaled by the aspect ratio max(h, w) / min(h, w).

        Returns:
            ``(pix_to_face, bary_coords)`` from ``rasterize_meshes``; the
            z-buffer and distance outputs are discarded.

        Raises:
            ValueError: if exactly one of ``h`` and ``w`` is given.
        """
        if (h is None) != (w is None):
            raise ValueError("h and w must be given together or both omitted")

        fixed_vertices = vertices.clone()
        fixed_vertices[..., :2] = -fixed_vertices[..., :2]
        raster_settings = self.raster_settings
        if h is None:
            image_size = raster_settings['image_size']
        else:
            image_size = [h, w]
            # Stretch the coordinate along the longer image side by the aspect ratio.
            if h > w:
                fixed_vertices[..., 1] = fixed_vertices[..., 1] * h / w
            else:
                fixed_vertices[..., 0] = fixed_vertices[..., 0] * w / h

        meshes_screen = Meshes(verts=fixed_vertices.float(), faces=faces.long())
        pix_to_face, zbuf, bary_coords, dists = rasterize_meshes(
            meshes_screen,
            image_size=image_size,
            blur_radius=raster_settings['blur_radius'],
            faces_per_pixel=raster_settings['faces_per_pixel'],
            bin_size=raster_settings['bin_size'],
            max_faces_per_bin=raster_settings['max_faces_per_bin'],
            perspective_correct=raster_settings['perspective_correct'],
            cull_backfaces=raster_settings['cull_backfaces'],
        )
        return pix_to_face, bary_coords
| | |
| | |
| |
|
| | |
def face_vertices(vertices, faces):
    """
    Index the coordinates of the three vertices of each face, per batch element.

    Args:
        vertices: [bs, V, 3]
        faces: [bs, F, 3] vertex indices into the matching batch element of
            ``vertices``.

    Return:
        face_to_vertices: [bs, F, 3, 3]
    """
    assert vertices.ndimension() == 3
    assert faces.ndimension() == 3
    assert vertices.shape[2] == 3
    assert faces.shape[2] == 3

    bs, nv = vertices.shape[:2]
    assert faces.shape[0] == bs, (
        f"batch size mismatch: vertices has {bs}, faces has {faces.shape[0]}"
    )
    # Shift each batch element's indices so they address one flat [bs*V, 3] table.
    offsets = torch.arange(bs, dtype=torch.long, device=vertices.device) * nv
    flat_faces = faces.long() + offsets[:, None, None]
    return vertices.reshape(bs * nv, 3)[flat_faces]
| |
|
def uniform_sampling_barycoords(
    num_points: int,
    tex_coord: torch.Tensor,
    uv_faces: torch.Tensor,
    d_size: float = 1.0,
    strict: bool = False,
    use_mask: bool = True,
):
    """
    Uniformly sample barycentric coordinates by rasterizing the UV layout.

    The UV layout is rasterized on a square grid of side
    ``int(sqrt(num_points) * d_size)``. Every covered texel yields one sample:
    the face covering it and the texel's barycentric weights in that face.

    Args:
        num_points: target number of sampling points.
        tex_coord: [V', 2] UV coords for each UV vertex. Presumably in [0, 1],
            since they are mapped to [-1, 1] with ``* 2 - 1``.
        uv_faces: [F, 3] UV faces indexing into ``tex_coord``.
        d_size: scale on the raster side length, controlling sample density.
        strict: if True, pad or truncate so exactly ``num_points`` samples are
            returned. Padding re-uses randomly chosen covered faces with random
            barycentric weights.
        use_mask: if True, keep only covered texels. If False, return the raw
            rasterizer outputs unmasked (``strict`` is then ignored).

    Returns:
        face_index [M]: which face each bary_coords row belongs to
            (M == num_points when ``strict``).
        bary_coords [M, 3]

    Raises:
        ValueError: if ``strict`` padding is needed but no texel is covered.
    """
    uv_size = int(math.sqrt(num_points) * d_size)
    uv_rasterizer = Pytorch3dRasterizer(uv_size)

    tex_coord = tex_coord[None, ...]   # [1, V', 2]
    uv_faces = uv_faces[None, ...]     # [1, F, 3]

    # Append a constant depth channel, map UV [0, 1] -> [-1, 1] and flip v.
    tex_coord_ = torch.cat([tex_coord, torch.ones_like(tex_coord[:, :, 0:1])], -1)
    tex_coord_ = tex_coord_ * 2 - 1
    tex_coord_[..., 1] = -tex_coord_[..., 1]

    pix_to_face, bary_coords = uv_rasterizer(tex_coord_, uv_faces)
    if not use_mask:
        return pix_to_face, bary_coords

    covered = pix_to_face != -1  # -1 marks texels no face covers
    face_index = pix_to_face[covered]
    bary_coords = bary_coords[covered]

    if strict:
        cur_n = face_index.shape[0]
        if cur_n < num_points:
            if cur_n == 0:
                raise ValueError(
                    "uniform_sampling_barycoords: no texel is covered by any "
                    "face, cannot pad to num_points"
                )
            pad_size = num_points - cur_n
            pick = torch.randint(0, cur_n, (pad_size,), device=face_index.device)
            new_bary_coords = torch.rand(
                (pad_size, 3), device=bary_coords.device, dtype=bary_coords.dtype
            )
            new_bary_coords = new_bary_coords / new_bary_coords.sum(dim=-1, keepdim=True)
            face_index = torch.cat([face_index, face_index[pick]], dim=0)
            bary_coords = torch.cat([bary_coords, new_bary_coords], dim=0)
        elif cur_n > num_points:
            face_index = face_index[:num_points]
            bary_coords = bary_coords[:num_points]

    return face_index, bary_coords
| |
|
def random_sampling_barycoords(
    num_points: int,
    vertices: torch.Tensor,
    faces: torch.Tensor
):
    """
    Randomly sample points uniformly over the mesh surface.

    Faces are drawn with replacement, with probability proportional to their
    area. Barycentric coordinates are then drawn uniformly over each chosen
    triangle. All draws use a generator seeded with 0, so repeated calls
    consume the same random stream.

    Args:
        num_points: number of sampling points.
        vertices: [V, 3] (a leading dimension of size 1 is squeezed away).
        faces: [F, 3]
    Returns:
        face_index [num_points]: which face each bary_coords row belongs to
        bary_coords [num_points, 3]: non-negative, each row sums to 1
    """
    areas, _ = mesh_face_areas_normals(vertices.squeeze(0), faces)

    g1 = torch.Generator(device=vertices.device)
    g1.manual_seed(0)

    face_index = areas.multinomial(num_points, replacement=True, generator=g1)

    # The square-root parameterization gives a uniform density over the
    # triangle. Normalizing three U(0, 1) draws would bias samples toward
    # the centroid.
    r = torch.rand((face_index.shape[0], 2), device=vertices.device, generator=g1)
    sqrt_r1 = r[:, 0].sqrt()
    r2 = r[:, 1]
    bary_coords = torch.stack(
        [1.0 - sqrt_r1, sqrt_r1 * (1.0 - r2), sqrt_r1 * r2], dim=-1
    )

    return face_index, bary_coords
| |
|
def reweight_verts_by_barycoords(
    verts: torch.Tensor,
    faces: torch.Tensor,
    face_index: torch.Tensor,
    bary_coords: torch.Tensor,
):
    """
    Interpolate per-vertex values at sampled surface points.

    Each sample is a (face, barycentric weights) pair. Its value is the
    barycentric-weighted sum of that face's three vertex values.

    Args:
        verts: [bs, V, C] per-vertex values (C = 3 for positions).
        faces: [F, 3] vertex indices of each face.
        face_index: [N] face each sample lies on.
        bary_coords: [N, 3] weights of each sample w.r.t. its face's vertices.

    Returns:
        Reweighted vertex values of shape [bs, N, C].
    """
    # Gather only the N sampled triangles instead of materializing all F.
    sampled_faces = faces[face_index.long()].long()   # [N, 3]
    tri_vals = verts[:, sampled_faces]                # [bs, N, 3, C]
    return (bary_coords[..., None] * tri_vals).sum(dim=-2)
| |
|
def reweight_uvcoords_by_barycoords(
    uvcoords: torch.Tensor,
    uvfaces: torch.Tensor,
    face_index: torch.Tensor,
    bary_coords: torch.Tensor,
):
    """
    Interpolate UV coordinates at sampled points given as (face, barycentric) pairs.

    A constant 1 is appended to every UV coordinate before interpolation. The
    output therefore carries a third channel equal to the sum of each
    sample's barycentric weights.

    Args:
        uvcoords: [V', 2] UV coordinates (no batch dimension).
        uvfaces: [F, 3] indices into ``uvcoords`` for each face.
        face_index: [N] face each sample lies on.
        bary_coords: [N, 3] barycentric weights of each sample.

    Returns:
        [1, N, 3]: interpolated (u, v, weight-sum) for each sample.
    """
    uv1 = torch.cat([uvcoords, uvcoords.new_ones((uvcoords.shape[0], 1))], dim=1)  # [V', 3]
    tri_uv = uv1[uvfaces[face_index.long()].long()][None]  # [1, N, 3, 3]
    return (bary_coords[..., None] * tri_uv).sum(dim=-2)
| |
|
| | |
def get_shell_verts_from_base(
    template_verts: torch.Tensor,
    template_faces: torch.Tensor,
    offset_len: float,
    num_shells: int,
    deflat=False,
):
    """
    Build layered "shell" copies of a mesh by moving its vertices along
    their vertex normals.

    Without ``deflat``, ``num_shells`` shells are spaced evenly from
    ``+offset_len`` (outermost) down to the base surface (offset 0). With
    ``deflat``, the first ``num_shells // 2 + 1`` shells follow that outward
    spacing and the remaining shells step evenly inward towards
    ``-offset_len`` (the inward schedule excludes 0).

    Args:
        template_verts: [bs, V, 3].
        template_faces: [F, 3], shared by every batch element.
        offset_len: positive offset length for generating shells.
        num_shells: number of shells to generate.
        deflat: if True, also generate shells inside the base surface.

    Returns:
        shell verts: [bs, num_shells, V, 3]
    """
    device = template_verts.device
    dtype = template_verts.dtype
    bs = template_verts.shape[0]

    base_mesh = Meshes(
        verts=template_verts,
        faces=template_faces[None].repeat(bs, 1, 1),
    )
    normals = base_mesh.verts_normals_padded()[:, None]  # [bs, 1, V, 3]

    n_outer = num_shells // 2 + 1 if deflat else num_shells
    outer_scales = torch.linspace(offset_len, 0, n_outer, device=device, dtype=dtype)
    offset = outer_scales.reshape(1, n_outer, 1, 1) * normals

    if deflat:
        # Drop the leading 0 so the base surface is not duplicated.
        inner_scales = torch.linspace(
            0, -offset_len, num_shells - n_outer + 1, device=device, dtype=dtype
        )[1:]
        offset = torch.cat([offset, inner_scales.reshape(1, -1, 1, 1) * normals], dim=1)

    shells = template_verts[:, None] + offset
    assert shells.isfinite().all()
    return shells