"""This script is the differentiable renderer for Deep3DFaceRecon_pytorch, implemented with PyTorch3D.

Attention: the antialiasing step is missing in the current version.
"""
import pytorch3d.ops
import torch
import torch.nn.functional as F
import kornia
from kornia.geometry.camera import pixel2cam
import numpy as np
from typing import List
from scipy.io import loadmat
from torch import nn

from pytorch3d.structures import Meshes
from pytorch3d.renderer import (
    look_at_view_transform,
    FoVPerspectiveCameras,
    DirectionalLights,
    RasterizationSettings,
    MeshRenderer,
    MeshRasterizer,
    SoftPhongShader,
    TexturesUV,
)


class MeshRenderer(nn.Module):
    """Rasterize a batch of triangle meshes with PyTorch3D.

    Produces a foreground mask, a depth map and, optionally, per-pixel
    interpolated vertex features.

    Note: this class shadows the ``MeshRenderer`` imported from
    ``pytorch3d.renderer`` above; only ``MeshRasterizer`` is used here.
    """

    def __init__(self, rasterize_fov, znear=0.1, zfar=10, rasterize_size=224):
        """
        Parameters:
            rasterize_fov  -- field of view, passed to FoVPerspectiveCameras as ``fov``
            znear          -- near clipping plane distance
            zfar           -- far clipping plane distance
            rasterize_size -- output image side length in pixels (H == W)
        """
        super(MeshRenderer, self).__init__()
        self.rasterize_size = rasterize_size
        self.fov = rasterize_fov
        self.znear = znear
        self.zfar = zfar
        # Created lazily on the first forward() call.
        self.rasterizer = None

    def forward(self, vertex, tri, feat=None):
        """
        Return:
            mask               -- torch.tensor, size (B, 1, H, W), 1.0 where a face covers the pixel
            depth              -- torch.tensor, size (B, 1, H, W), zero on background pixels
            features(optional) -- torch.tensor, size (B, C, H, W) if feat is not None, otherwise None

        Parameters:
            vertex         -- torch.tensor, size (B, N, 3) or (B, N, 4); only the first
                              three components are used and the tensor is not modified
            tri            -- torch.tensor, size (B, M, 3) or (M, 3), triangles
            feat(optional) -- torch.tensor, size (B, N, C), features
        """
        device = vertex.device
        rsize = int(self.rasterize_size)
        batch_size = vertex.shape[0]

        # Negate x out of place so the caller's tensor is never mutated (the
        # old in-place write leaked into (B, N, 4) inputs). Presumably this maps
        # into PyTorch3D's +X-left view convention -- TODO confirm.
        # Any homogeneous 4th component is dropped: Meshes only takes xyz.
        verts = torch.cat([-vertex[..., :1], vertex[..., 1:3]], dim=-1).contiguous()

        if self.rasterizer is None:
            self.rasterizer = MeshRasterizer()
            # Format the device itself: device.index is None on CPU, which made
            # the previous "%d" format raise a TypeError.
            print("create rasterizer on device %s" % device)

        # Meshes takes padded int64 faces of shape (B, M, 3). A single (M, 3)
        # topology is shared by every mesh in the batch.
        faces = tri.to(device=device, dtype=torch.int64)
        if faces.dim() == 2:
            faces = faces.unsqueeze(0).repeat(batch_size, 1, 1)
        faces = faces.contiguous()

        # rasterize
        cameras = FoVPerspectiveCameras(
            device=device,
            fov=self.fov,
            znear=self.znear,
            zfar=self.zfar,
        )
        raster_settings = RasterizationSettings(image_size=rsize)

        mesh = Meshes(verts, faces)
        fragments = self.rasterizer(mesh, cameras=cameras, raster_settings=raster_settings)

        # pix_to_face holds packed face indices with -1 for background. Face 0
        # is a real face, so test ">= 0" (the previous "> 0" dropped it).
        rast_out = fragments.pix_to_face.squeeze(-1)
        mask = (rast_out >= 0).float().unsqueeze(1)

        # render depth: zbuf (B, H, W, 1) -> (B, 1, H, W), background zeroed.
        depth = fragments.zbuf.permute(0, 3, 1, 2)
        depth = mask * depth

        image = None
        if feat is not None:
            # Flatten to packed (B*N, C) vertex order so faces_packed() indices
            # line up, giving per-face vertex features of shape (F, 3, C).
            attributes = feat.reshape(-1, feat.shape[-1])[mesh.faces_packed()]
            image = pytorch3d.ops.interpolate_face_attributes(
                fragments.pix_to_face, fragments.bary_coords, attributes
            )
            # (B, H, W, 1, C) -> (B, C, H, W)
            image = image.squeeze(-2).permute(0, 3, 1, 2)
            image = mask * image
        return mask, depth, image