from typing import *

import torch
import nvdiffrast.torch as dr

from . import utils, transforms, mesh
from ._helpers import batched

__all__ = [
    'RastContext',
    'rasterize_triangle_faces',
    'warp_image_by_depth',
    'warp_image_by_forward_flow',
]


class RastContext:
    """
    Create a rasterization context. Nothing but a wrapper of nvdiffrast.torch.RasterizeCudaContext or nvdiffrast.torch.RasterizeGLContext.
    """
    def __init__(self, nvd_ctx: Union[dr.RasterizeCudaContext, dr.RasterizeGLContext] = None, *, backend: Literal['cuda', 'gl'] = 'gl', device: Union[str, torch.device] = None):
        if nvd_ctx is not None:
            self.nvd_ctx = nvd_ctx
            return

        if backend == 'gl':
            self.nvd_ctx = dr.RasterizeGLContext(device=device)
        elif backend == 'cuda':
            self.nvd_ctx = dr.RasterizeCudaContext(device=device)
        else:
            raise ValueError(f'Unknown backend: {backend}')


def rasterize_triangle_faces(
    ctx: RastContext,
    vertices: torch.Tensor,
    faces: torch.Tensor,
    attr: torch.Tensor,
    width: int,
    height: int,
    model: torch.Tensor = None,
    view: torch.Tensor = None,
    projection: torch.Tensor = None,
    antialiasing: Union[bool, List[int]] = True,
    diff_attrs: Union[None, List[int]] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """
    Rasterize a mesh with vertex attributes.

    Args:
        ctx (RastContext): rasterizer context
        vertices (torch.Tensor): (B, N, 2 or 3 or 4)
        faces (torch.Tensor): (T, 3)
        attr (torch.Tensor): (B, N, C)
        width (int): width of the output image
        height (int): height of the output image
        model (torch.Tensor, optional): ([B,] 4, 4) model matrix. Defaults to None (identity).
        view (torch.Tensor, optional): ([B,] 4, 4) view matrix. Defaults to None (identity).
        projection (torch.Tensor, optional): ([B,] 4, 4) projection matrix. Defaults to None (identity).
        antialiasing (Union[bool, List[int]], optional): whether to perform antialiasing. Defaults to True. If a list of indices is provided, only those channels will be antialiased.
        diff_attrs (Union[None, List[int]], optional): indices of attributes to compute screen-space derivatives for. Defaults to None.

    Returns:
        image (torch.Tensor): (B, C, H, W) rendered attributes
        depth (torch.Tensor): (B, H, W) screen-space depth, ranging from 0 (near) to 1 (far).
            NOTE: Empty pixels will have depth 1., i.e. the far plane.
        image_dr (torch.Tensor, optional): (B, 2 * len(diff_attrs), H, W) screen-space derivatives of the selected attributes. Only returned if diff_attrs is not None.
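    Example:
        A minimal sketch; the triangle, per-vertex colors, and CUDA device below are made up for illustration:

        >>> ctx = RastContext(backend='cuda', device='cuda')
        >>> vertices = torch.tensor([[[-0.5, -0.5, 0.0], [0.5, -0.5, 0.0], [0.0, 0.5, 0.0]]], device='cuda')
        >>> faces = torch.tensor([[0, 1, 2]], dtype=torch.int32, device='cuda')
        >>> attr = torch.eye(3, device='cuda').unsqueeze(0)  # one RGB color per vertex
        >>> image, depth = rasterize_triangle_faces(ctx, vertices, faces, attr, 256, 256)
        >>> image.shape, depth.shape
        (torch.Size([1, 3, 256, 256]), torch.Size([1, 256, 256]))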
""" assert vertices.ndim == 3 assert faces.ndim == 2 if vertices.shape[-1] == 2: vertices = torch.cat([vertices, torch.zeros_like(vertices[..., :1]), torch.ones_like(vertices[..., :1])], dim=-1) elif vertices.shape[-1] == 3: vertices = torch.cat([vertices, torch.ones_like(vertices[..., :1])], dim=-1) elif vertices.shape[-1] == 4: pass else: raise ValueError(f'Wrong shape of vertices: {vertices.shape}') mvp = projection if projection is not None else torch.eye(4).to(vertices) if view is not None: mvp = mvp @ view if model is not None: mvp = mvp @ model pos_clip = vertices @ mvp.transpose(-1, -2) faces = faces.contiguous() attr = attr.contiguous() rast_out, rast_db = dr.rasterize(ctx.nvd_ctx, pos_clip, faces, resolution=[height, width], grad_db=True) image, image_dr = dr.interpolate(attr, rast_out, faces, rast_db, diff_attrs=diff_attrs) if antialiasing == True: image = dr.antialias(image, rast_out, pos_clip, faces) elif isinstance(antialiasing, list): aa_image = dr.antialias(image[..., antialiasing], rast_out, pos_clip, faces) image[..., antialiasing] = aa_image image = image.flip(1).permute(0, 3, 1, 2) depth = rast_out[..., 2].flip(1) depth = (depth * 0.5 + 0.5) * (depth > 0).float() + (depth == 0).float() if diff_attrs is not None: image_dr = image_dr.flip(1).permute(0, 3, 1, 2) return image, depth, image_dr return image, depth def texture( ctx: RastContext, uv: torch.Tensor, uv_da: torch.Tensor, texture: torch.Tensor, ) -> torch.Tensor: dr.texture(ctx.nvd_ctx, uv, texture) def warp_image_by_depth( ctx: RastContext, depth: torch.FloatTensor, image: torch.FloatTensor = None, mask: torch.BoolTensor = None, width: int = None, height: int = None, *, extrinsics_src: torch.FloatTensor = None, extrinsics_tgt: torch.FloatTensor = None, intrinsics_src: torch.FloatTensor = None, intrinsics_tgt: torch.FloatTensor = None, near: float = 0.1, far: float = 100.0, antialiasing: bool = True, backslash: bool = False, padding: int = 0, return_uv: bool = False, return_dr: bool = False, ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.BoolTensor, Optional[torch.FloatTensor], Optional[torch.FloatTensor]]: """ Warp image by depth. NOTE: if batch size is 1, image mesh will be triangulated aware of the depth, yielding less distorted results. Otherwise, image mesh will be triangulated simply for batch rendering. Args: ctx (Union[dr.RasterizeCudaContext, dr.RasterizeGLContext]): rasterization context depth (torch.Tensor): (B, H, W) linear depth image (torch.Tensor): (B, C, H, W). None to use image space uv. Defaults to None. width (int, optional): width of the output image. None to use the same as depth. Defaults to None. height (int, optional): height of the output image. Defaults the same as depth.. extrinsics_src (torch.Tensor, optional): (B, 4, 4) extrinsics matrix for source. None to use identity. Defaults to None. extrinsics_tgt (torch.Tensor, optional): (B, 4, 4) extrinsics matrix for target. None to use identity. Defaults to None. intrinsics_src (torch.Tensor, optional): (B, 3, 3) intrinsics matrix for source. None to use the same as target. Defaults to None. intrinsics_tgt (torch.Tensor, optional): (B, 3, 3) intrinsics matrix for target. None to use the same as source. Defaults to None. near (float, optional): near plane. Defaults to 0.1. far (float, optional): far plane. Defaults to 100.0. antialiasing (bool, optional): whether to perform antialiasing. Defaults to True. backslash (bool, optional): whether to use backslash triangulation. Defaults to False. 
    """
    assert depth.ndim == 3
    batch_size = depth.shape[0]

    if width is None:
        width = depth.shape[-1]
    if height is None:
        height = depth.shape[-2]
    if image is not None:
        assert image.shape[-2:] == depth.shape[-2:], f'Shape of image {image.shape} does not match shape of depth {depth.shape}'

    if extrinsics_src is None:
        extrinsics_src = torch.eye(4).to(depth)
    if extrinsics_tgt is None:
        extrinsics_tgt = torch.eye(4).to(depth)
    if intrinsics_src is None:
        intrinsics_src = intrinsics_tgt
    if intrinsics_tgt is None:
        intrinsics_tgt = intrinsics_src

    assert all(x is not None for x in [extrinsics_src, extrinsics_tgt, intrinsics_src, intrinsics_tgt]), "Make sure you have provided all the necessary camera parameters."

    view_tgt = transforms.extrinsics_to_view(extrinsics_tgt)
    perspective_tgt = transforms.intrinsics_to_perspective(intrinsics_tgt, near=near, far=far)

    if padding > 0:
        uv, faces = utils.image_mesh(width=width+2, height=height+2)
        # Remap the padded (width+2, height+2) grid so that its inner pixels align with the original image; each axis uses its own size.
        uv = (uv - torch.tensor([1 / (width + 2), 1 / (height + 2)])) * torch.tensor([(width + 2) / width, (height + 2) / height])
        uv_ = uv.clone().reshape(height+2, width+2, 2)
        uv_[0, :, 1] -= padding / height
        uv_[-1, :, 1] += padding / height
        uv_[:, 0, 0] -= padding / width
        uv_[:, -1, 0] += padding / width
        uv_ = uv_.reshape(-1, 2)

        depth = torch.nn.functional.pad(depth, [1, 1, 1, 1], mode='replicate')
        if image is not None:
            image = torch.nn.functional.pad(image, [1, 1, 1, 1], mode='replicate')
        if mask is not None:
            # Pad the mask alongside depth and image so the mask channel appended below matches the padded mesh.
            mask = torch.nn.functional.pad(mask.float(), [1, 1, 1, 1], mode='replicate') > 0.5
            depth = torch.where(mask, depth, torch.full_like(depth, far))

        uv, uv_, faces = uv.to(depth.device), uv_.to(depth.device), faces.to(depth.device)
        pts = transforms.unproject_cv(
            uv_,
            depth.flatten(-2, -1),
            extrinsics_src,
            intrinsics_src,
        )
    else:
        uv, faces = utils.image_mesh(width=depth.shape[-1], height=depth.shape[-2])
        if mask is not None:
            depth = torch.where(mask, depth, torch.tensor(far, dtype=depth.dtype, device=depth.device))
        uv, faces = uv.to(depth.device), faces.to(depth.device)
        pts = transforms.unproject_cv(
            uv,
            depth.flatten(-2, -1),
            extrinsics_src,
            intrinsics_src,
        )

    # triangulate
    if batch_size == 1:
        faces = mesh.triangulate(faces, vertices=pts[0])
    else:
        faces = mesh.triangulate(faces, backslash=backslash)

    # rasterize attributes
    diff_attrs = None
    if image is not None:
        attr = image.permute(0, 2, 3, 1).flatten(1, 2)
        if return_dr or return_uv:
            if return_dr:
                diff_attrs = [image.shape[1], image.shape[1]+1]
            if return_uv and antialiasing:
                # Antialias only the image channels; antialiasing uv would distort it at silhouettes.
                antialiasing = list(range(image.shape[1]))
            attr = torch.cat([attr, uv.expand(batch_size, -1, -1)], dim=-1)
    else:
        attr = uv.expand(batch_size, -1, -1)
        if antialiasing:
            print("\033[93mWarning: you are performing antialiasing on uv. This may cause artifacts.\033[0m")
        if return_uv:
            return_uv = False
            print("\033[93mWarning: image is None, return_uv is ignored.\033[0m")
        if return_dr:
            diff_attrs = [0, 1]
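    # Per-vertex attribute layout at this point:
    #   [image channels (if image is given)] + [uv (if return_uv or return_dr)] + [mask (appended below)].
    # `diff_attrs` indexes the uv channels, so nvdiffrast computes screen-space derivatives for uv only.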
    if mask is not None:
        attr = torch.cat([attr, mask.float().flatten(1, 2).unsqueeze(-1)], dim=-1)

    rast = rasterize_triangle_faces(
        ctx,
        pts,
        faces,
        attr,
        width,
        height,
        view=view_tgt,
        projection=perspective_tgt,
        antialiasing=antialiasing,
        diff_attrs=diff_attrs,
    )
    if return_dr:
        output_image, screen_depth, output_dr = rast
    else:
        output_image, screen_depth = rast
    output_mask = screen_depth < 1.0

    if mask is not None:
        output_image, rast_mask = output_image[..., :-1, :, :], output_image[..., -1, :, :]
        output_mask &= (rast_mask > 0.9999).reshape(-1, height, width)

    if (return_dr or return_uv) and image is not None:
        output_image, output_uv = output_image[..., :-2, :, :], output_image[..., -2:, :, :]

    output_depth = transforms.depth_buffer_to_linear(screen_depth, near=near, far=far) * output_mask
    output_image = output_image * output_mask.unsqueeze(1)

    outs = [output_image, output_depth, output_mask]
    if return_uv:
        outs.append(output_uv)
    if return_dr:
        outs.append(output_dr)
    return tuple(outs)


def warp_image_by_forward_flow(
    ctx: RastContext,
    image: torch.FloatTensor,
    flow: torch.FloatTensor,
    depth: torch.FloatTensor = None,
    *,
    antialiasing: bool = True,
    backslash: bool = False,
) -> Tuple[torch.FloatTensor, torch.BoolTensor]:
    """
    Warp image by forward flow.

    NOTE: if batch size is 1, the image mesh will be triangulated aware of the depth, yielding less distorted results.
    Otherwise, the image mesh will be triangulated simply for batch rendering.

    Args:
        ctx (RastContext): rasterization context
        image (torch.Tensor): (B, C, H, W) image
        flow (torch.Tensor): (B, 2, H, W) forward flow
        depth (torch.Tensor, optional): (B, H, W) linear depth. If None, will use the same for all pixels. Defaults to None.
        antialiasing (bool, optional): whether to perform antialiasing. Defaults to True.
        backslash (bool, optional): whether to use backslash triangulation. Defaults to False.

    Returns:
        image (torch.FloatTensor): (B, C, H, W) rendered image
        mask (torch.BoolTensor): (B, H, W) mask of valid pixels
    """
    assert image.ndim == 4, f'Wrong shape of image: {image.shape}'
    batch_size, _, height, width = image.shape

    if depth is None:
        depth = torch.ones_like(flow[:, 0])

    extrinsics = torch.eye(4).to(image)
    fov = torch.deg2rad(torch.tensor([45.0], device=image.device))
    intrinsics = transforms.intrinsics_from_fov(fov, width, height, normalize=True)[0]

    view = transforms.extrinsics_to_view(extrinsics)
    perspective = transforms.intrinsics_to_perspective(intrinsics, near=0.1, far=100)

    uv, faces = utils.image_mesh(width=width, height=height)
    uv, faces = uv.to(image.device), faces.to(image.device)
    uv = uv + flow.permute(0, 2, 3, 1).flatten(1, 2)
    pts = transforms.unproject_cv(
        uv,
        depth.flatten(-2, -1),
        extrinsics,
        intrinsics,
    )

    # triangulate
    if batch_size == 1:
        faces = mesh.triangulate(faces, vertices=pts[0])
    else:
        faces = mesh.triangulate(faces, backslash=backslash)

    # rasterize attributes
    attr = image.permute(0, 2, 3, 1).flatten(1, 2)
    rast = rasterize_triangle_faces(
        ctx,
        pts,
        faces,
        attr,
        width,
        height,
        view=view,
        projection=perspective,
        antialiasing=antialiasing,
    )
    output_image, screen_depth = rast
    output_mask = screen_depth < 1.0
    output_image = output_image * output_mask.unsqueeze(1)

    return output_image, output_mask
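
if __name__ == '__main__':
    # Minimal usage sketch (illustrative values, not part of the module API):
    # warp a random image by a constant forward flow. Assumes nvdiffrast is
    # installed, a CUDA device is available, and that flow is expressed in the
    # same normalized uv units as utils.image_mesh.
    device = torch.device('cuda')
    ctx = RastContext(backend='cuda', device=device)
    image = torch.rand(1, 3, 64, 64, device=device)
    flow = torch.zeros(1, 2, 64, 64, device=device)
    flow[:, 0] = 0.1  # shift every pixel right by 10% of the image width
    warped, valid = warp_image_by_forward_flow(ctx, image, flow)
    print(warped.shape, valid.shape)  # torch.Size([1, 3, 64, 64]) torch.Size([1, 64, 64])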