# NOTE: removed scraped page-header artifact ("Spaces: / Sleeping / Sleeping") — it was not Python code.
# from https://gitlab.tuebingen.mpg.de/mkocabas/projects/-/blob/master/pare/pare/utils/diff_renderer.py | |
import torch | |
import numpy as np | |
import torch.nn as nn | |
from pytorch3d.renderer import ( | |
PerspectiveCameras, | |
RasterizationSettings, | |
DirectionalLights, | |
BlendParams, | |
HardFlatShader, | |
MeshRasterizer, | |
TexturesVertex, | |
TexturesAtlas | |
) | |
from pytorch3d.structures import Meshes | |
from .image_utils import get_default_camera | |
from .smpl_uv import get_tenet_texture | |
class MeshRendererWithDepth(nn.Module):
    """
    Render a batch of heterogeneous meshes and return RGB, mask and depth.

    The class is initialized with a rasterizer and a shader, each of which
    must provide a forward function (any callables with the same call
    signature also work).
    """
    def __init__(self, rasterizer, shader):
        super().__init__()
        self.rasterizer = rasterizer
        self.shader = shader

    def forward(self, meshes_world, **kwargs) -> torch.Tensor:
        """
        Render a batch of images from a batch of meshes by rasterizing and
        then shading.

        Args:
            meshes_world: meshes in world coordinates, passed straight to the
                rasterizer and shader.
            **kwargs: forwarded to both the rasterizer and the shader
                (e.g. per-call ``R``/``T`` camera overrides).

        Returns:
            (N, H, W, 5) tensor: RGB in channels 0-2, a foreground mask in
            channel 3 (1.0 where a face was rasterized) and a per-image
            min-max normalized depth map in channel 4.

        NOTE: If the blur radius for rasterization is > 0.0, some pixels can
        have one or more barycentric coordinates lying outside the range
        [0, 1]. For a pixel with out of bounds barycentric coordinates with
        respect to a face f, clipping is required before interpolating the
        texture uv coordinates and z buffer so that the colors and depths are
        limited to the range for the corresponding face.
        """
        fragments = self.rasterizer(meshes_world, **kwargs)
        images = self.shader(fragments, meshes_world, **kwargs)

        # zbuf is -1 for background pixels, so > -1 marks rasterized faces.
        mask = (fragments.zbuf > -1).float()

        # Min-max normalize the depth per image. Compute min/max once each
        # (the original recomputed the min twice) and clamp the denominator
        # so a constant z-buffer does not divide by zero.
        zbuf = fragments.zbuf.view(images.shape[0], -1)
        zmin = zbuf.min(-1, keepdim=True).values
        zmax = zbuf.max(-1, keepdim=True).values
        depth = (zbuf - zmin) / (zmax - zmin).clamp(min=1e-8)
        depth = depth.reshape(*images.shape[:3] + (1,))

        images = torch.cat([images[:, :, :, :3], mask, depth], dim=-1)
        return images
class DifferentiableRenderer(nn.Module):
    """
    Base class for differentiable mesh renderers.

    Sets up camera intrinsics/extrinsics (``K``, ``R``, ``T``) via
    ``get_default_camera``, resolves per-face and per-vertex colors for the
    selected ``texture_mode``, registers everything as buffers and toggles
    their gradient tracking. Subclasses implement the actual rendering in
    ``forward``.
    """
    def __init__(
            self,
            img_h,
            img_w,
            focal_length,
            device='cuda',
            background_color=(0.0, 0.0, 0.0),
            texture_mode='smplpix',
            vertex_colors=None,
            face_textures=None,
            smpl_faces=None,
            is_train=False,
            is_cam_batch=False,
    ):
        super(DifferentiableRenderer, self).__init__()

        self.img_h = img_h
        self.img_w = img_w
        self.device = device
        self.focal_length = focal_length

        K, R = get_default_camera(focal_length, img_h, img_w, is_cam_batch=is_cam_batch)
        K, R = K.to(device), R.to(device)
        if is_cam_batch:
            # One zero translation per camera in the batch.
            T = torch.zeros((K.shape[0], 3)).to(device)
        else:
            T = torch.tensor([[0.0, 0.0, 0.0]]).to(device)

        self.background_color = background_color
        self.renderer = None

        if texture_mode == 'smplpix':
            # Colors are loaded from disk; the vertex_colors argument is ignored.
            face_colors = get_tenet_texture(mode=texture_mode).to(device).float()
            vertex_colors = torch.from_numpy(
                np.load(f'data/smpl/{texture_mode}_vertex_colors.npy')[:, :3]
            ).unsqueeze(0).to(device).float()
        elif texture_mode == 'partseg':
            vertex_colors = vertex_colors[..., :3].unsqueeze(0).to(device)
            face_colors = face_textures.to(device)
        elif texture_mode == 'deco':
            vertex_colors = vertex_colors[..., :3].to(device)
            face_colors = face_textures.to(device)
        else:
            # The original silently fell through to a NameError on face_colors.
            raise ValueError(f'Unknown texture_mode: {texture_mode}')

        self.register_buffer('K', K)
        self.register_buffer('R', R)
        self.register_buffer('T', T)
        self.register_buffer('face_colors', face_colors)
        self.register_buffer('vertex_colors', vertex_colors)
        self.register_buffer('smpl_faces', smpl_faces)

        self.set_requires_grad(is_train)

    def set_requires_grad(self, val=False):
        """Toggle gradient tracking on the camera and texture buffers."""
        self.K.requires_grad_(val)
        self.R.requires_grad_(val)
        self.T.requires_grad_(val)
        self.face_colors.requires_grad_(val)
        self.vertex_colors.requires_grad_(val)
        # requires_grad_ is only defined for floating-point tensors; the face
        # index tensor is usually a LongTensor and must be skipped. (The old
        # isinstance(..., torch.FloatTensor) check also missed CUDA/double
        # float tensors.)
        if self.smpl_faces is not None and self.smpl_faces.is_floating_point():
            self.smpl_faces.requires_grad_(val)

    def forward(self, vertices, faces=None, R=None, T=None):
        raise NotImplementedError
class Pytorch3D(DifferentiableRenderer):
    """
    PyTorch3D implementation of ``DifferentiableRenderer``.

    Builds a perspective camera, a rasterizer and a hard flat shader, then
    renders meshes to an (N, C, H, W) tensor whose channels are RGB, a
    foreground mask and a normalized depth map.
    """
    def __init__(
            self,
            img_h,
            img_w,
            focal_length,
            device='cuda',
            background_color=(0.0, 0.0, 0.0),
            texture_mode='smplpix',
            vertex_colors=None,
            face_textures=None,
            smpl_faces=None,
            model_type='smpl',
            is_train=False,
            is_cam_batch=False,
    ):
        super(Pytorch3D, self).__init__(
            img_h,
            img_w,
            focal_length,
            device=device,
            background_color=background_color,
            texture_mode=texture_mode,
            vertex_colors=vertex_colors,
            face_textures=face_textures,
            smpl_faces=smpl_faces,
            is_train=is_train,
            is_cam_batch=is_cam_batch,
        )

        # This R converts the camera from pyrender NDC to the OpenGL
        # coordinate frame. It is basically R(180, X) x R(180, Y), defined
        # manually here for convenience.
        self.R = self.R @ torch.tensor(
            [[[-1.0, 0.0, 0.0],
              [0.0, -1.0, 0.0],
              [0.0, 0.0, 1.0]]],
            dtype=self.R.dtype, device=self.R.device,
        )

        if is_cam_batch:
            focal_length = self.focal_length
        else:
            # NOTE(review): assumes self.focal_length is a tensor here —
            # confirm against get_default_camera.
            focal_length = self.focal_length[None, :]

        principal_point = ((self.img_w // 2, self.img_h // 2),)
        image_size = ((self.img_h, self.img_w),)

        cameras = PerspectiveCameras(
            device=self.device,
            focal_length=focal_length,
            principal_point=principal_point,
            R=self.R,
            T=self.T,
            in_ndc=False,
            image_size=image_size,
        )
        # The camera is fixed; never optimize its parameters.
        for param in cameras.parameters():
            param.requires_grad_(False)

        raster_settings = RasterizationSettings(
            image_size=(self.img_h, self.img_w),
            blur_radius=0.0,
            max_faces_per_bin=20000,
            faces_per_pixel=1,
        )

        # Pure-ambient light so the rendered colors equal the texture colors.
        lights = DirectionalLights(
            device=self.device,
            ambient_color=((1.0, 1.0, 1.0),),
            diffuse_color=((0.0, 0.0, 0.0),),
            specular_color=((0.0, 0.0, 0.0),),
            direction=((0, 1, 0),),
        )

        blend_params = BlendParams(background_color=self.background_color)

        shader = HardFlatShader(device=self.device,
                                cameras=cameras,
                                blend_params=blend_params,
                                lights=lights)

        self.textures = TexturesVertex(verts_features=self.vertex_colors)
        self.renderer = MeshRendererWithDepth(
            rasterizer=MeshRasterizer(
                cameras=cameras,
                raster_settings=raster_settings
            ),
            shader=shader,
        )

    def forward(self, vertices, faces=None, R=None, T=None, face_atlas=None):
        """
        Render a batch of meshes.

        Args:
            vertices: (B, V, 3) mesh vertices.
            faces: (B, F, 3) face indices; defaults to the registered SMPL
                topology.
            R, T: optional camera rotation/translation overrides; default to
                the registered buffers.
            face_atlas: optional per-face texture atlas; defaults to the face
                colors chosen by ``texture_mode``.

        Returns:
            (B, C, H, W) tensor (RGB + mask + depth channels, NCHW).
        """
        batch_size = vertices.shape[0]
        if faces is None:
            faces = self.smpl_faces.expand(batch_size, -1, -1)
        if R is None:
            R = self.R.expand(batch_size, -1, -1)
        if T is None:
            T = self.T.expand(batch_size, -1)

        # Convert the camera translation to the pytorch3d coordinate frame.
        T = torch.bmm(R, T.unsqueeze(-1)).squeeze(-1)

        # Face textures are used (rather than vertex textures) because vertex
        # textures cause interpolation at part boundaries. The original also
        # built an unused TexturesVertex here; that dead allocation is removed.
        # BUG FIX: `if face_atlas:` raised "Boolean value of Tensor is
        # ambiguous" when a tensor was passed; test against None explicitly.
        if face_atlas is not None:
            face_textures = TexturesAtlas(atlas=face_atlas)
        else:
            face_textures = TexturesAtlas(atlas=self.face_colors)

        meshes = Meshes(verts=vertices, faces=faces, textures=face_textures)
        images = self.renderer(meshes, R=R, T=T)
        images = images.permute(0, 3, 1, 2)  # NHWC -> NCHW
        return images
class NeuralMeshRenderer(DifferentiableRenderer):
    """
    neural_renderer (NMR) implementation of ``DifferentiableRenderer``.

    Renders meshes with the registered camera and face colors and returns an
    (B, 5, H, W) tensor of RGB, depth and mask channels.
    """
    def __init__(self, *args, **kwargs):
        # Imported lazily so the pytorch3d code path works without
        # neural_renderer installed.
        import neural_renderer as nr

        super(NeuralMeshRenderer, self).__init__(*args, **kwargs)

        # BUG FIX: the base class defines img_h/img_w but no img_size, so the
        # original `self.img_size` raised AttributeError. NMR renders square
        # images, so use the larger side.
        # NOTE(review): confirm with callers that img_h == img_w is the
        # intended use here.
        img_size = max(self.img_h, self.img_w)
        self.neural_renderer = nr.Renderer(
            dist_coeffs=None,
            orig_size=img_size,
            image_size=img_size,
            light_intensity_ambient=1,
            light_intensity_directional=0,
            anti_aliasing=False,
        )

    def forward(self, vertices, faces=None, R=None, T=None):
        """
        Render a batch of meshes with NMR.

        Args:
            vertices: (B, V, 3) mesh vertices.
            faces: (B, F, 3) face indices; defaults to the SMPL topology.
            R, T: optional camera rotation/translation overrides.

        Returns:
            (B, 5, H, W) tensor: RGB (3 channels), depth (1) and mask (1).
        """
        batch_size = vertices.shape[0]
        if faces is None:
            faces = self.smpl_faces.expand(batch_size, -1, -1)
        if R is None:
            R = self.R.expand(batch_size, -1, -1)
        if T is None:
            T = self.T.expand(batch_size, -1)
        rgb, depth, mask = self.neural_renderer(
            vertices,
            faces,
            textures=self.face_colors.expand(batch_size, -1, -1, -1, -1, -1),
            K=self.K.expand(batch_size, -1, -1),
            R=R,
            t=T.unsqueeze(1),
        )
        return torch.cat([rgb, depth.unsqueeze(1), mask.unsqueeze(1)], dim=1)