|
from __future__ import division, print_function |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import numpy as np |
|
from torch.autograd import Function |
|
|
|
import voxelize_cuda |
|
|
|
|
|
class VoxelizationFunction(Function):
    """
    Autograd function wrapping the CUDA semantic voxelization kernel.

    Only a forward pass is defined in this class. It is implemented only for
    cuda tensors because the work is delegated to the ``voxelize_cuda``
    extension.
    """

    @staticmethod
    def forward(ctx, smpl_vertices, smpl_face_center, smpl_face_normal,
                smpl_vertex_code, smpl_face_code, smpl_tetrahedrons,
                volume_res, sigma, smooth_kernel_size):
        """
        forward pass

        Output format: (batch_size, z_dims, y_dims, x_dims, channel_num)

        Args:
            ctx: autograd context; call parameters are recorded on it.
            smpl_vertices: vertex tensor, batch on dim 0 and vertices on
                dim 1 (dim 1 must match ``smpl_vertex_code``).
            smpl_face_center: per-face tensor; only its dim-1 size is
                validated here, it is not passed to the kernel.
            smpl_face_normal: per-face tensor; only its dim-1 size is
                validated here, it is not passed to the kernel.
            smpl_vertex_code: per-vertex code tensor passed to the kernel.
            smpl_face_code: per-face tensor; only its dim-1 size is
                validated here, it is not passed to the kernel.
            smpl_tetrahedrons: tetrahedron tensor passed to the kernel.
            volume_res: edge length (in voxels) of the cubic output volume.
            sigma: scalar forwarded unchanged to the kernel.
            smooth_kernel_size: recorded on ``ctx`` only; unused here.

        Returns:
            The semantic volume of shape
            (batch_size, volume_res, volume_res, volume_res, 3).
        """
        assert (smpl_vertices.size()[1] == smpl_vertex_code.size()[1])
        assert (smpl_face_center.size()[1] == smpl_face_normal.size()[1])
        assert (smpl_face_center.size()[1] == smpl_face_code.size()[1])

        ctx.batch_size = smpl_vertices.size()[0]
        ctx.volume_res = volume_res
        ctx.sigma = sigma
        ctx.smooth_kernel_size = smooth_kernel_size
        ctx.smpl_vertex_num = smpl_vertices.size()[1]
        ctx.device = smpl_vertices.device

        # Only the tensors handed to the kernel need a contiguous layout.
        smpl_vertices = smpl_vertices.contiguous()
        smpl_vertex_code = smpl_vertex_code.contiguous()
        smpl_tetrahedrons = smpl_tetrahedrons.contiguous()

        # Allocate the output volumes on the same device as the input rather
        # than on whichever CUDA device happens to be current.
        grid_shape = (ctx.batch_size, ctx.volume_res, ctx.volume_res,
                      ctx.volume_res)
        occ_volume = torch.zeros(grid_shape, dtype=torch.float32,
                                 device=ctx.device)
        semantic_volume = torch.zeros(grid_shape + (3,), dtype=torch.float32,
                                      device=ctx.device)
        # Starts at a small positive value instead of zero (presumably to
        # keep a later division by the weight sum well-defined -- confirm
        # against the kernel source).
        weight_sum_volume = torch.full(grid_shape, 1e-3, dtype=torch.float32,
                                       device=ctx.device)

        occ_volume, semantic_volume, weight_sum_volume = \
            voxelize_cuda.forward_semantic_voxelization(
                smpl_vertices, smpl_vertex_code, smpl_tetrahedrons,
                occ_volume, semantic_volume, weight_sum_volume, sigma)

        return semantic_volume
|
|
|
|
|
class Voxelization(nn.Module):
    """
    Wrapper around the autograd function VoxelizationFunction

    The per-sample SMPL arrays given to the constructor are kept as-is; the
    batched tensors used by ``forward`` are only created (and registered as
    buffers) by ``update_param``, so ``update_param`` has to be called
    before the first ``forward``.
    """

    def __init__(self, smpl_vertex_code, smpl_face_code, smpl_face_indices,
                 smpl_tetraderon_indices, volume_res, sigma,
                 smooth_kernel_size, batch_size, device):
        """
        Store the voxelization settings and per-sample SMPL arrays.

        Args:
            smpl_vertex_code: per-vertex code array (tiled by
                ``update_param`` with ``np.tile``).
            smpl_face_code: per-face code array.
            smpl_face_indices: 2-D array with 3 columns (vertex indices of
                each face).
            smpl_tetraderon_indices: 2-D array with 4 columns (vertex
                indices of each tetrahedron).
            volume_res: edge length of the output volume, in voxels.
            sigma: scalar forwarded to the voxelization function.
            smooth_kernel_size: forwarded to the voxelization function.
            batch_size: expected batch size of ``forward`` inputs.
            device: device the batched buffers are moved to.
        """
        super(Voxelization, self).__init__()
        assert (len(smpl_face_indices.shape) == 2)
        assert (len(smpl_tetraderon_indices.shape) == 2)
        assert (smpl_face_indices.shape[1] == 3)
        assert (smpl_tetraderon_indices.shape[1] == 4)

        self.volume_res = volume_res
        self.sigma = sigma
        self.smooth_kernel_size = smooth_kernel_size
        self.batch_size = batch_size
        self.device = device

        self.smpl_vertex_code = smpl_vertex_code
        self.smpl_face_code = smpl_face_code
        self.smpl_face_indices = smpl_face_indices
        self.smpl_tetraderon_indices = smpl_tetraderon_indices

    def update_param(self, batch_size, smpl_tetra):
        """
        Set a new batch size / tetrahedron index array and rebuild buffers.

        Each per-sample array is repeated ``batch_size`` times along a new
        leading axis, converted to a tensor on ``self.device`` and
        registered as a buffer named ``<array name>_batch``.

        Args:
            batch_size: number of samples per batch from now on.
            smpl_tetra: replacement for ``smpl_tetraderon_indices``.
        """
        self.batch_size = batch_size
        self.smpl_tetraderon_indices = smpl_tetra

        per_sample = (
            ('smpl_vertex_code_batch', self.smpl_vertex_code),
            ('smpl_face_code_batch', self.smpl_face_code),
            ('smpl_face_indices_batch', self.smpl_face_indices),
            ('smpl_tetraderon_indices_batch', self.smpl_tetraderon_indices),
        )
        reps = (self.batch_size, 1, 1)
        # Build every tensor before registering any of them, so a failure
        # while converting one array leaves the existing buffers untouched.
        batched = [
            (name,
             torch.from_numpy(np.tile(array, reps)).contiguous().to(
                 self.device))
            for name, array in per_sample
        ]
        for name, tensor in batched:
            self.register_buffer(name, tensor)

    def forward(self, smpl_vertices):
        """
        Generate semantic volumes from SMPL vertices

        Args:
            smpl_vertices: float32 cuda tensor whose dim 0 equals
                ``self.batch_size``; indexed as (batch, vertex, 3).

        Returns:
            The volume produced by ``VoxelizationFunction`` with its last
            axis moved to position 1 (``permute((0, 4, 1, 2, 3))``).
        """
        assert (smpl_vertices.size()[0] == self.batch_size)
        self.check_input(smpl_vertices)
        smpl_faces = self.vertices_to_faces(smpl_vertices)
        smpl_tetrahedrons = self.vertices_to_tetrahedrons(smpl_vertices)
        smpl_face_center = self.calc_face_centers(smpl_faces)
        smpl_face_normal = self.calc_face_normals(smpl_faces)
        # Only the leading vertices that have a vertex code are passed on as
        # the vertex set; the gathers above index the full vertex tensor.
        smpl_surface_vertex_num = self.smpl_vertex_code_batch.size()[1]
        smpl_vertices_surface = smpl_vertices[:, :smpl_surface_vertex_num, :]
        vol = VoxelizationFunction.apply(smpl_vertices_surface,
                                         smpl_face_center, smpl_face_normal,
                                         self.smpl_vertex_code_batch,
                                         self.smpl_face_code_batch,
                                         smpl_tetrahedrons, self.volume_res,
                                         self.sigma, self.smooth_kernel_size)
        return vol.permute((0, 4, 1, 2, 3))

    @staticmethod
    def _gather_by_index(vertices, index_batch):
        """Gather ``vertices`` (bs, nv, 3) by per-sample ``index_batch``."""
        assert (vertices.ndimension() == 3)
        bs, nv = vertices.shape[:2]
        # Offset each sample's indices by its start row in the flattened
        # (bs * nv, 3) vertex table.
        offsets = torch.arange(bs, dtype=torch.int32,
                               device=vertices.device) * nv
        flat_index = index_batch + offsets[:, None, None]
        flat_vertices = vertices.reshape((bs * nv, 3))
        return flat_vertices[flat_index.long()]

    def vertices_to_faces(self, vertices):
        """
        Look up the three corner positions of every face.

        Args:
            vertices: tensor of shape (bs, nv, 3).

        Returns:
            Tensor of shape (bs, num_faces, 3, 3).
        """
        return self._gather_by_index(vertices, self.smpl_face_indices_batch)

    def vertices_to_tetrahedrons(self, vertices):
        """
        Look up the four corner positions of every tetrahedron.

        Args:
            vertices: tensor of shape (bs, nv, 3).

        Returns:
            Tensor of shape (bs, num_tetrahedrons, 4, 3).
        """
        return self._gather_by_index(vertices,
                                     self.smpl_tetraderon_indices_batch)

    def calc_face_centers(self, face_verts):
        """
        Return the mean of the three corners of each face.

        Args:
            face_verts: tensor of shape (bs, nf, 3, 3).

        Returns:
            Tensor of shape (bs, nf, 3).
        """
        assert len(face_verts.shape) == 4
        assert face_verts.shape[2] == 3
        assert face_verts.shape[3] == 3
        bs, nf = face_verts.shape[:2]
        face_centers = (face_verts[:, :, 0, :] + face_verts[:, :, 1, :] +
                        face_verts[:, :, 2, :]) / 3.0
        face_centers = face_centers.reshape((bs, nf, 3))
        return face_centers

    def calc_face_normals(self, face_verts):
        """
        Return the unit normal of each face.

        The normal is ``normalize((v0 - v1) x (v2 - v1))`` for the corners
        ``v0, v1, v2`` of a face.

        Args:
            face_verts: tensor of shape (bs, nf, 3, 3).

        Returns:
            Tensor of shape (bs, nf, 3).
        """
        assert len(face_verts.shape) == 4
        assert face_verts.shape[2] == 3
        assert face_verts.shape[3] == 3
        bs, nf = face_verts.shape[:2]
        face_verts = face_verts.reshape((bs * nf, 3, 3))
        v10 = face_verts[:, 0] - face_verts[:, 1]
        v12 = face_verts[:, 2] - face_verts[:, 1]
        # The axis is given explicitly: without ``dim`` torch.cross uses the
        # first axis of size 3, which is the face axis when bs * nf == 3.
        normals = F.normalize(torch.cross(v10, v12, dim=1), dim=1, eps=1e-5)
        normals = normals.reshape((bs, nf, 3))
        return normals

    def check_input(self, x):
        """
        Validate that ``x`` is a float32 cuda tensor.

        Raises:
            TypeError: if ``x`` is not on a cuda device, or is not float32.
        """
        # ``is_cuda`` is used instead of comparing ``x.device`` with a
        # string, so non-cuda inputs get the cuda error message.
        if not x.is_cuda:
            raise TypeError('Voxelization module supports only cuda tensors')
        if x.dtype != torch.float32:
            raise TypeError(
                'Voxelization module supports only float32 tensors')
|
|