# from https://gitlab.tuebingen.mpg.de/mkocabas/projects/-/blob/master/pare/pare/utils/diff_renderer.py
import torch
import numpy as np
import torch.nn as nn

from pytorch3d.renderer import (
    PerspectiveCameras,
    RasterizationSettings,
    DirectionalLights,
    BlendParams,
    HardFlatShader,
    MeshRasterizer,
    TexturesVertex,
    TexturesAtlas,
)
from pytorch3d.structures import Meshes

from .image_utils import get_default_camera
from .smpl_uv import get_tenet_texture


class MeshRendererWithDepth(nn.Module):
    """
    A class for rendering a batch of heterogeneous meshes. The class should
    be initialized with a rasterizer and shader class which each have a forward
    function.

    Unlike pytorch3d's stock ``MeshRenderer``, the output images carry two
    extra channels: a foreground mask and a per-image min/max-normalized depth.
    """

    def __init__(self, rasterizer, shader):
        super().__init__()
        self.rasterizer = rasterizer
        self.shader = shader

    def forward(self, meshes_world, **kwargs) -> torch.Tensor:
        """
        Render a batch of images from a batch of meshes by rasterizing and then
        shading.

        Returns:
            (N, H, W, 5) tensor: RGB (3) + foreground mask (1) + depth (1).

        NOTE: If the blur radius for rasterization is > 0.0, some pixels can
        have one or more barycentric coordinates lying outside the range [0, 1].
        For a pixel with out of bounds barycentric coordinates with respect to a
        face f, clipping is required before interpolating the texture uv
        coordinates and z buffer so that the colors and depths are limited to
        the range for the corresponding face.
        """
        fragments = self.rasterizer(meshes_world, **kwargs)
        images = self.shader(fragments, meshes_world, **kwargs)

        # zbuf is -1 where no face covers the pixel, so > -1 is the foreground.
        mask = (fragments.zbuf > -1).float()

        # Normalize depth to [0, 1] independently per image.
        # BUG FIX: torch's keyword is `keepdim`; the original passed numpy's
        # `keepdims`, which raises TypeError. Min/max are also hoisted so each
        # reduction runs once instead of three times.
        zbuf = fragments.zbuf.view(images.shape[0], -1)
        zmin = zbuf.min(-1, keepdim=True).values
        zmax = zbuf.max(-1, keepdim=True).values
        # NOTE(review): zmax == zmin (e.g. a completely empty render) yields
        # NaNs here, matching the original behavior — confirm callers never
        # render an empty batch element.
        depth = (zbuf - zmin) / (zmax - zmin)
        depth = depth.reshape(*images.shape[:3], 1)

        images = torch.cat([images[:, :, :, :3], mask, depth], dim=-1)
        return images


class DifferentiableRenderer(nn.Module):
    """
    Base class holding the camera (K, R, T), face/vertex color buffers and SMPL
    faces shared by the concrete renderer backends below.

    Args:
        img_h, img_w: output image size in pixels.
        focal_length: camera focal length (tensor; batched when
            ``is_cam_batch`` is True).
        device: torch device string.
        background_color: RGB background for blending.
        texture_mode: one of 'smplpix', 'partseg', 'deco' — selects how
            face/vertex colors are obtained.
        vertex_colors: per-vertex colors, required for 'partseg'/'deco' modes.
        face_textures: per-face texture atlas, required for 'partseg'/'deco'.
        smpl_faces: (F, 3) face index tensor of the SMPL mesh.
        is_train: whether camera/texture buffers require grad.
        is_cam_batch: whether K/R/focal_length carry a batch dimension.
    """

    def __init__(
        self,
        img_h,
        img_w,
        focal_length,
        device='cuda',
        background_color=(0.0, 0.0, 0.0),
        texture_mode='smplpix',
        vertex_colors=None,
        face_textures=None,
        smpl_faces=None,
        is_train=False,
        is_cam_batch=False,
    ):
        super().__init__()
        self.img_h = img_h
        self.img_w = img_w
        self.device = device
        self.focal_length = focal_length

        K, R = get_default_camera(focal_length, img_h, img_w, is_cam_batch=is_cam_batch)
        K, R = K.to(device), R.to(device)

        # Zero translation; the renderer-specific forward() applies the actual T.
        if is_cam_batch:
            T = torch.zeros((K.shape[0], 3)).to(device)
        else:
            T = torch.tensor([[0.0, 0.0, 0.0]]).to(device)

        self.background_color = background_color
        self.renderer = None

        if texture_mode == 'smplpix':
            face_colors = get_tenet_texture(mode=texture_mode).to(device).float()
            vertex_colors = torch.from_numpy(
                np.load(f'data/smpl/{texture_mode}_vertex_colors.npy')[:, :3]
            ).unsqueeze(0).to(device).float()
        elif texture_mode == 'partseg':
            vertex_colors = vertex_colors[..., :3].unsqueeze(0).to(device)
            face_colors = face_textures.to(device)
        elif texture_mode == 'deco':
            vertex_colors = vertex_colors[..., :3].to(device)
            face_colors = face_textures.to(device)
        else:
            # Fail fast: the original fell through to a NameError on
            # face_colors at register_buffer time for unknown modes.
            raise ValueError(f'unsupported texture_mode: {texture_mode!r}')

        self.register_buffer('K', K)
        self.register_buffer('R', R)
        self.register_buffer('T', T)
        self.register_buffer('face_colors', face_colors)
        self.register_buffer('vertex_colors', vertex_colors)
        self.register_buffer('smpl_faces', smpl_faces)

        self.set_requires_grad(is_train)

    def set_requires_grad(self, val=False):
        """Toggle requires_grad on all camera/texture buffers."""
        self.K.requires_grad_(val)
        self.R.requires_grad_(val)
        self.T.requires_grad_(val)
        self.face_colors.requires_grad_(val)
        self.vertex_colors.requires_grad_(val)
        # requires_grad_ is undefined for integer tensors, so only toggle
        # smpl_faces when it is a float tensor.
        if isinstance(self.smpl_faces, torch.FloatTensor):
            self.smpl_faces.requires_grad_(val)

    def forward(self, vertices, faces=None, R=None, T=None):
        raise NotImplementedError


class Pytorch3D(DifferentiableRenderer):
    """Differentiable renderer backed by pytorch3d's rasterizer + flat shader."""

    def __init__(
        self,
        img_h,
        img_w,
        focal_length,
        device='cuda',
        background_color=(0.0, 0.0, 0.0),
        texture_mode='smplpix',
        vertex_colors=None,
        face_textures=None,
        smpl_faces=None,
        model_type='smpl',
        is_train=False,
        is_cam_batch=False,
    ):
        super().__init__(
            img_h,
            img_w,
            focal_length,
            device=device,
            background_color=background_color,
            texture_mode=texture_mode,
            vertex_colors=vertex_colors,
            face_textures=face_textures,
            smpl_faces=smpl_faces,
            is_train=is_train,
            is_cam_batch=is_cam_batch,
        )
        # This R converts the camera from pyrender NDC to the OpenGL coordinate
        # frame. It is basically R(180, X) x R(180, Y); defined inline for
        # convenience.
        self.R = self.R @ torch.tensor(
            [[[-1.0, 0.0, 0.0],
              [0.0, -1.0, 0.0],
              [0.0, 0.0, 1.0]]],
            dtype=self.R.dtype,
            device=self.R.device,
        )

        if is_cam_batch:
            focal_length = self.focal_length
        else:
            focal_length = self.focal_length[None, :]

        principal_point = ((self.img_w // 2, self.img_h // 2),)
        image_size = ((self.img_h, self.img_w),)

        cameras = PerspectiveCameras(
            device=self.device,
            focal_length=focal_length,
            principal_point=principal_point,
            R=self.R,
            T=self.T,
            in_ndc=False,
            image_size=image_size,
        )
        # Camera intrinsics/extrinsics are fixed; never optimized.
        for param in cameras.parameters():
            param.requires_grad_(False)

        raster_settings = RasterizationSettings(
            image_size=(self.img_h, self.img_w),
            blur_radius=0.0,
            max_faces_per_bin=20000,
            faces_per_pixel=1,
        )

        # Pure ambient light so the output is the unshaded texture color.
        lights = DirectionalLights(
            device=self.device,
            ambient_color=((1.0, 1.0, 1.0),),
            diffuse_color=((0.0, 0.0, 0.0),),
            specular_color=((0.0, 0.0, 0.0),),
            direction=((0, 1, 0),),
        )

        blend_params = BlendParams(background_color=self.background_color)

        shader = HardFlatShader(
            device=self.device,
            cameras=cameras,
            blend_params=blend_params,
            lights=lights,
        )

        self.textures = TexturesVertex(verts_features=self.vertex_colors)
        self.renderer = MeshRendererWithDepth(
            rasterizer=MeshRasterizer(
                cameras=cameras,
                raster_settings=raster_settings,
            ),
            shader=shader,
        )

    def forward(self, vertices, faces=None, R=None, T=None, face_atlas=None):
        """
        Render ``vertices`` with the configured camera.

        Args:
            vertices: (B, V, 3) mesh vertices.
            faces/R/T: optional overrides; default to the registered buffers.
            face_atlas: optional per-face texture atlas overriding the default
                face colors.

        Returns:
            (B, 5, H, W) tensor (RGB + mask + depth, channels first).
        """
        batch_size = vertices.shape[0]
        if faces is None:
            faces = self.smpl_faces.expand(batch_size, -1, -1)
        if R is None:
            R = self.R.expand(batch_size, -1, -1)
        if T is None:
            T = self.T.expand(batch_size, -1)

        # convert camera translation to pytorch3d coordinate frame
        T = torch.bmm(R, T.unsqueeze(-1)).squeeze(-1)

        # Face textures are used (rather than vertex textures) because vertex
        # textures cause interpolation at part boundaries.
        # BUG FIX: the original tested `if face_atlas:` — truthiness of a
        # multi-element tensor raises a RuntimeError. Test presence explicitly.
        if face_atlas is not None:
            face_textures = TexturesAtlas(atlas=face_atlas)
        else:
            face_textures = TexturesAtlas(atlas=self.face_colors)

        # we may need to rotate the mesh
        meshes = Meshes(verts=vertices, faces=faces, textures=face_textures)
        images = self.renderer(meshes, R=R, T=T)
        images = images.permute(0, 3, 1, 2)  # NHWC -> NCHW
        return images


class NeuralMeshRenderer(DifferentiableRenderer):
    """Differentiable renderer backed by the ``neural_renderer`` package."""

    def __init__(self, *args, **kwargs):
        # Imported lazily so pytorch3d-only users need not install it.
        import neural_renderer as nr

        super().__init__(*args, **kwargs)

        # BUG FIX: the base class defines img_h/img_w, never img_size, so the
        # original `self.img_size` raised AttributeError. nr.Renderer takes a
        # single square size; assumes img_h == img_w — TODO confirm callers.
        img_size = max(self.img_h, self.img_w)
        self.neural_renderer = nr.Renderer(
            dist_coeffs=None,
            orig_size=img_size,
            image_size=img_size,
            light_intensity_ambient=1,
            light_intensity_directional=0,
            anti_aliasing=False,
        )

    def forward(self, vertices, faces=None, R=None, T=None):
        """
        Render ``vertices``; returns a (B, 5, H, W) tensor laid out as
        RGB (3) + depth (1) + mask (1), channels first.
        """
        batch_size = vertices.shape[0]
        if faces is None:
            faces = self.smpl_faces.expand(batch_size, -1, -1)
        if R is None:
            R = self.R.expand(batch_size, -1, -1)
        if T is None:
            T = self.T.expand(batch_size, -1)

        rgb, depth, mask = self.neural_renderer(
            vertices,
            faces,
            textures=self.face_colors.expand(batch_size, -1, -1, -1, -1, -1),
            K=self.K.expand(batch_size, -1, -1),
            R=R,
            t=T.unsqueeze(1),
        )
        return torch.cat([rgb, depth.unsqueeze(1), mask.unsqueeze(1)], dim=1)