# barc_gradio/src/smal_pytorch/renderer/differentiable_renderer.py
# Author: Nadine Rueegg
# parts of this code are adapted from
# https://github.com/benjiebob/SMALify/blob/master/smal_fitter/p3d_renderer.py
import torch
import torch.nn.functional as F
from scipy.io import loadmat
import numpy as np
# import config
import pytorch3d
from pytorch3d.structures import Meshes
from pytorch3d.renderer import (
PerspectiveCameras, look_at_view_transform, look_at_rotation,
RasterizationSettings, MeshRenderer, MeshRasterizer, BlendParams,
PointLights, HardPhongShader, SoftSilhouetteShader, Materials, Textures,
DirectionalLights
)
from pytorch3d.renderer import TexturesVertex, SoftPhongShader
from pytorch3d.io import load_objs_as_meshes
MESH_COLOR_0 = [0, 172, 223]
MESH_COLOR_1 = [172, 223, 0]
'''
Explanation of the shift between projection results from opendr and pytorch3d:
(0, 0, ?) will be projected to 127.5 (pytorch3d) instead of 128 (opendr)
imagine an image of size 4:
    the center of the first pixel is at 0
    the center of the last pixel is at 3
    => the center of the image is at 1.5, not 2!
so, to go from pytorch3d predictions to opendr, we calculate: p_odr = p_p3d * (128/127.5)
To reproject points (p3d) by hand according to this pytorch3d renderer we would do the following steps:
1.) build camera matrix
     K = np.array([[flength, 0,       c_x],
                   [0,       flength, c_y],
                   [0,       0,       1]], dtype=float)
2.) we don't need to add extrinsics, as the mesh comes with translation (which is
added within smal_pytorch). all 3d points are already in the camera coordinate system.
-> projection reduces to p2d_proj = K*p3d
3.) convert to pytorch3d conventions (0 in the middle of the first pixel)
p2d_proj_pytorch3d = p2d_proj / image_size * (image_size-1.)
renderer.py - project_points_p3d shows an example of the steps described above,
but with the same focal length for the whole batch (see also the sketch below)
'''
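# The function below is a minimal sketch of the manual projection described in the
# docstring above. It is NOT the original project_points_p3d from renderer.py (that
# file is not part of this module); the shared scalar focal length and the omission
# of the axis flips encoded in self.R are simplifying assumptions for illustration.
def project_points_p3d_sketch(points, flength, image_size):
    # points: (bs, n_points, 3), already in camera coordinates (no extrinsics needed)
    # flength: scalar focal length, shared across the whole batch (assumption)
    # image_size: int, square image
    c_x = c_y = image_size / 2.
    K = torch.tensor([[flength, 0., c_x],
                      [0., flength, c_y],
                      [0., 0., 1.]], dtype=points.dtype, device=points.device)
    # 1.) + 2.) perspective projection p2d_proj = K @ p3d, then divide by depth
    p_hom = points @ K.t()                       # (bs, n_points, 3)
    p2d_proj = p_hom[..., :2] / p_hom[..., 2:3]  # (bs, n_points, 2)
    # 3.) convert to the pytorch3d convention (0 at the center of the first pixel)
    return p2d_proj / image_size * (image_size - 1.)
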
class SilhRenderer(torch.nn.Module):
def __init__(self, image_size, adapt_R_wldo=False):
super(SilhRenderer, self).__init__()
# see: https://pytorch3d.org/files/fit_textured_mesh.py, line 315
        # adapt_R_wldo=False is used for all of my own experiments
# image_size: one number, integer
# -----
# set mesh color
self.register_buffer('mesh_color_0', torch.FloatTensor(MESH_COLOR_0))
self.register_buffer('mesh_color_1', torch.FloatTensor(MESH_COLOR_1))
# prepare extrinsics, which in our case don't change
R = torch.Tensor(np.eye(3)).float()[None, :, :]
T = torch.Tensor(np.zeros((1, 3))).float()
if adapt_R_wldo:
R[0, 0, 0] = -1
else: # used for all my own experiments
R[0, 0, 0] = -1
R[0, 1, 1] = -1
self.register_buffer('R', R)
self.register_buffer('T', T)
# prepare that part of the intrinsics which does not change either
# principal_point_prep = torch.Tensor([self.image_size / 2., self.image_size / 2.]).float()[None, :].float().to(device)
# image_size_prep = torch.Tensor([self.image_size, self.image_size]).float()[None, :].float().to(device)
self.img_size_scalar = image_size
        self.register_buffer('image_size', torch.Tensor([image_size, image_size]).float()[None, :])
        self.register_buffer('principal_point', torch.Tensor([image_size / 2., image_size / 2.]).float()[None, :])
# Rasterization settings for differentiable rendering, where the blur_radius
# initialization is based on Liu et al, 'Soft Rasterizer: A Differentiable
# Renderer for Image-based 3D Reasoning', ICCV 2019
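        # with sigma = 1e-4, blur_radius = log(1/1e-4 - 1) * 1e-4 ~= 9.2e-4, so each
        # face softly influences pixels in a small screen-space neighborhood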
self.blend_params = BlendParams(sigma=1e-4, gamma=1e-4)
self.raster_settings_soft = RasterizationSettings(
image_size=image_size, # 128
blur_radius=np.log(1. / 1e-4 - 1.)*self.blend_params.sigma,
faces_per_pixel=100) #50,
        # rasterization settings for body part segmentation (same Soft Rasterizer initialization)
self.blend_params_parts = BlendParams(sigma=2*1e-4, gamma=1e-4)
self.raster_settings_soft_parts = RasterizationSettings(
image_size=image_size, # 128
blur_radius=np.log(1. / 1e-4 - 1.)*self.blend_params_parts.sigma,
faces_per_pixel=60) #50,
# settings for visualization renderer
self.raster_settings_vis = RasterizationSettings(
image_size=image_size,
blur_radius=0.0,
faces_per_pixel=1)
def _get_cam(self, focal_lengths):
device = focal_lengths.device
bs = focal_lengths.shape[0]
if pytorch3d.__version__ == '0.2.5':
cameras = PerspectiveCameras(device=device,
focal_length=focal_lengths.repeat((1, 2)),
principal_point=self.principal_point.repeat((bs, 1)),
R=self.R.repeat((bs, 1, 1)), T=self.T.repeat((bs, 1)),
image_size=self.image_size.repeat((bs, 1)))
elif pytorch3d.__version__ == '0.6.1':
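            # in_ndc=False keeps focal length and principal point in pixel (screen)
            # units rather than NDC (this argument exists from pytorch3d 0.5 on)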
cameras = PerspectiveCameras(device=device, in_ndc=False,
focal_length=focal_lengths.repeat((1, 2)),
principal_point=self.principal_point.repeat((bs, 1)),
R=self.R.repeat((bs, 1, 1)), T=self.T.repeat((bs, 1)),
image_size=self.image_size.repeat((bs, 1)))
        else:
            raise ValueError('this part depends on the pytorch3d version; the code was developed with 0.2.5')
return cameras
def _get_visualization_from_mesh(self, mesh, cameras, lights=None):
# color renderer for visualization
with torch.no_grad():
device = mesh.device
# renderer for visualization
if lights is None:
lights = PointLights(device=device, location=[[0.0, 0.0, 3.0]])
vis_renderer = MeshRenderer(
rasterizer=MeshRasterizer(
cameras=cameras,
raster_settings=self.raster_settings_vis),
shader=HardPhongShader(
device=device,
cameras=cameras,
lights=lights))
# render image:
visualization = vis_renderer(mesh).permute(0, 3, 1, 2)[:, :3, :, :]
return visualization
def calculate_vertex_visibility(self, vertices, faces, focal_lengths, soft=False):
        tex = torch.ones_like(vertices) * self.mesh_color_0 # (bs, V, 3)
textures = Textures(verts_rgb=tex)
mesh = Meshes(verts=vertices, faces=faces, textures=textures)
cameras = self._get_cam(focal_lengths)
# NEW: use the rasterizer to check vertex visibility
# see: https://github.com/facebookresearch/pytorch3d/issues/126
# Get a rasterizer
if soft:
rasterizer = MeshRasterizer(cameras=cameras,
raster_settings=self.raster_settings_soft)
else:
rasterizer = MeshRasterizer(cameras=cameras,
raster_settings=self.raster_settings_vis)
# Get the output from rasterization
fragments = rasterizer(mesh)
# pix_to_face is of shape (N, H, W, 1)
pix_to_face = fragments.pix_to_face
# (F, 3) where F is the total number of faces across all the meshes in the batch
packed_faces = mesh.faces_packed()
# (V, 3) where V is the total number of verts across all the meshes in the batch
packed_verts = mesh.verts_packed()
        vertex_visibility_map = torch.zeros(packed_verts.shape[0], device=packed_verts.device) # (V,)
# Indices of unique visible faces
        visible_faces = pix_to_face.unique() # (num_visible_faces,), includes -1 for background pixels
        visible_faces = visible_faces[visible_faces >= 0] # drop the background index
# Get Indices of unique visible verts using the vertex indices in the faces
visible_verts_idx = packed_faces[visible_faces] # (num_visible_faces, 3)
unique_visible_verts_idx = torch.unique(visible_verts_idx) # (num_visible_verts, )
# Update visibility indicator to 1 for all visible vertices
vertex_visibility_map[unique_visible_verts_idx] = 1.0
# since all meshes have the same amount of vertices, we can reshape the result
bs = vertices.shape[0]
vertex_visibility_map_resh = vertex_visibility_map.reshape((bs, -1))
return pix_to_face, vertex_visibility_map_resh
def get_torch_meshes(self, vertices, faces, color=0):
# create pytorch mesh
if color == 0:
mesh_color = self.mesh_color_0
else:
mesh_color = self.mesh_color_1
        tex = torch.ones_like(vertices) * mesh_color # (bs, V, 3)
textures = Textures(verts_rgb=tex)
mesh = Meshes(verts=vertices, faces=faces, textures=textures)
return mesh
def get_visualization_nograd(self, vertices, faces, focal_lengths, color=0):
# vertices: torch.Size([bs, 3889, 3])
# faces: torch.Size([bs, 7774, 3]), int
# focal_lengths: torch.Size([bs, 1])
device = vertices.device
# create cameras
cameras = self._get_cam(focal_lengths)
# create pytorch mesh
if color == 0:
mesh_color = self.mesh_color_0 # blue
elif color == 1:
mesh_color = self.mesh_color_1
elif color == 2:
MESH_COLOR_2 = [240, 250, 240] # white
mesh_color = torch.FloatTensor(MESH_COLOR_2).to(device)
elif color == 3:
# MESH_COLOR_3 = [223, 0, 172] # pink
# MESH_COLOR_3 = [245, 245, 220] # beige
MESH_COLOR_3 = [166, 173, 164]
mesh_color = torch.FloatTensor(MESH_COLOR_3).to(device)
else:
MESH_COLOR_2 = [240, 250, 240]
mesh_color = torch.FloatTensor(MESH_COLOR_2).to(device)
        tex = torch.ones_like(vertices) * mesh_color # (bs, V, 3)
textures = Textures(verts_rgb=tex)
mesh = Meshes(verts=vertices, faces=faces, textures=textures)
# render mesh (no gradients)
# lights = PointLights(device=device, location=[[0.0, 0.0, 3.0]])
# lights = PointLights(device=device, location=[[2.0, 2.0, -2.0]])
lights = DirectionalLights(device=device, direction=[[0.0, -5.0, -10.0]])
visualization = self._get_visualization_from_mesh(mesh, cameras, lights=lights)
return visualization
def project_points(self, points, focal_lengths=None, cameras=None):
# points: torch.Size([bs, n_points, 3])
# either focal_lengths or cameras is needed:
        #   focal_lengths: torch.Size([bs, 1])
# cameras: pytorch camera, for example PerspectiveCameras()
bs = points.shape[0]
device = points.device
screen_size = self.image_size.repeat((bs, 1))
if cameras is None:
cameras = self._get_cam(focal_lengths)
if pytorch3d.__version__ == '0.2.5':
            proj_points_orig = cameras.transform_points_screen(points, screen_size)[:, :, [1, 0]] # used in the original virtual environment (for the BARC CVPR submission)
elif pytorch3d.__version__ == '0.6.1':
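            # from pytorch3d 0.5 on, transform_points_screen takes the image size
            # from the camera itself, so screen_size is no longer passed explicitly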
proj_points_orig = cameras.transform_points_screen(points)[:, :, [1, 0]]
        else:
            raise ValueError('this part depends on the pytorch3d version; the code was developed with 0.2.5')
# flip, otherwise the 1st and 2nd row are exchanged compared to the ground truth
proj_points = torch.flip(proj_points_orig, [2])
# --- project points 'manually'
# j_proj = project_points_p3d(image_size, focal_length, points, device)
return proj_points
def forward(self, vertices, points, faces, focal_lengths, color=None):
# vertices: torch.Size([bs, 3889, 3])
# points: torch.Size([bs, n_points, 3]) (or None)
# faces: torch.Size([bs, 7774, 3]), int
# focal_lengths: torch.Size([bs, 1])
# color: if None we don't render a visualization, else it should
# either be 0 or 1
        #   ---> important: results are around 0.5 pixels off compared to chumpy!
        #        see the explanation at the top of this file
# create cameras
cameras = self._get_cam(focal_lengths)
# create pytorch mesh
if color is None or color == 0:
mesh_color = self.mesh_color_0
else:
mesh_color = self.mesh_color_1
        tex = torch.ones_like(vertices) * mesh_color # (bs, V, 3)
textures = Textures(verts_rgb=tex)
mesh = Meshes(verts=vertices, faces=faces, textures=textures)
# silhouette renderer
renderer_silh = MeshRenderer(
rasterizer=MeshRasterizer(
cameras=cameras,
raster_settings=self.raster_settings_soft),
shader=SoftSilhouetteShader(blend_params=self.blend_params))
# project silhouette
silh_images = renderer_silh(mesh)[..., -1].unsqueeze(1)
# project points
if points is None:
proj_points = None
else:
proj_points = self.project_points(points=points, cameras=cameras)
if color is not None:
# color renderer for visualization (no gradients)
visualization = self._get_visualization_from_mesh(mesh, cameras)
return silh_images, proj_points, visualization
else:
return silh_images, proj_points
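
# Minimal usage sketch (an illustration, not part of the original pipeline): the
# vertex/face counts match the shape comments above; the random geometry and the
# arbitrary focal length are assumptions, used only to show the call signature.
if __name__ == '__main__':
    bs, n_verts, n_faces = 2, 3889, 7774
    renderer = SilhRenderer(image_size=128)
    # place the random vertices in front of the camera (positive z)
    vertices = torch.rand(bs, n_verts, 3) + torch.tensor([0., 0., 10.])
    faces = torch.randint(0, n_verts, (bs, n_faces, 3))
    focal_lengths = torch.full((bs, 1), 1000.)
    silh, proj, vis = renderer(vertices, vertices[:, :10, :], faces, focal_lengths, color=0)
    print(silh.shape, proj.shape, vis.shape)  # (bs, 1, 128, 128), (bs, 10, 2), (bs, 3, 128, 128)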