"""This module contains simple helper functions and classes for preprocessing """
import numpy as np
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch3d.renderer import (
SfMPerspectiveCameras,
RasterizationSettings,
MeshRenderer,
MeshRasterizer,
SoftPhongShader,
PointLights,
)
from pytorch3d.structures import Meshes
from pytorch3d.renderer.mesh import Textures
DEFAULT_DTYPE = torch.float32
INVALID_TRANS = np.ones(3) * -1
def smpl_to_pose(model_type='smplx', use_hands=True, use_face=True,
use_face_contour=False, openpose_format='coco25'):
''' Returns the indices of the permutation that maps OpenPose to SMPL
Parameters
----------
model_type: str, optional
The type of SMPL-like model that is used. The default mapping
returned is for the SMPLX model
use_hands: bool, optional
Flag for adding to the returned permutation the mapping for the
hand keypoints. Defaults to True
use_face: bool, optional
Flag for adding to the returned permutation the mapping for the
face keypoints. Defaults to True
use_face_contour: bool, optional
Flag for appending the facial contour keypoints. Defaults to False
        openpose_format: str, optional
            The output format of OpenPose. 'coco25', 'coco19' and 'h36' are
            supported. Defaults to 'coco25'
'''
if openpose_format.lower() == 'coco25':
if model_type == 'smpl':
return np.array([24, 12, 17, 19, 21, 16, 18, 20, 0, 2, 5, 8, 1, 4,
7, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34],
dtype=np.int32)
elif model_type == 'smplh':
body_mapping = np.array([52, 12, 17, 19, 21, 16, 18, 20, 0, 2, 5,
8, 1, 4, 7, 53, 54, 55, 56, 57, 58, 59,
60, 61, 62], dtype=np.int32)
mapping = [body_mapping]
if use_hands:
lhand_mapping = np.array([20, 34, 35, 36, 63, 22, 23, 24, 64,
25, 26, 27, 65, 31, 32, 33, 66, 28,
29, 30, 67], dtype=np.int32)
rhand_mapping = np.array([21, 49, 50, 51, 68, 37, 38, 39, 69,
40, 41, 42, 70, 46, 47, 48, 71, 43,
44, 45, 72], dtype=np.int32)
mapping += [lhand_mapping, rhand_mapping]
return np.concatenate(mapping)
# SMPLX
elif model_type == 'smplx':
body_mapping = np.array([55, 12, 17, 19, 21, 16, 18, 20, 0, 2, 5,
8, 1, 4, 7, 56, 57, 58, 59, 60, 61, 62,
63, 64, 65], dtype=np.int32)
mapping = [body_mapping]
if use_hands:
lhand_mapping = np.array([20, 37, 38, 39, 66, 25, 26, 27,
67, 28, 29, 30, 68, 34, 35, 36, 69,
31, 32, 33, 70], dtype=np.int32)
rhand_mapping = np.array([21, 52, 53, 54, 71, 40, 41, 42, 72,
43, 44, 45, 73, 49, 50, 51, 74, 46,
47, 48, 75], dtype=np.int32)
mapping += [lhand_mapping, rhand_mapping]
if use_face:
# end_idx = 127 + 17 * use_face_contour
face_mapping = np.arange(76, 127 + 17 * use_face_contour,
dtype=np.int32)
mapping += [face_mapping]
return np.concatenate(mapping)
else:
raise ValueError('Unknown model type: {}'.format(model_type))
elif openpose_format == 'coco19':
if model_type == 'smpl':
return np.array([24, 12, 17, 19, 21, 16, 18, 20, 2, 5, 8,
1, 4, 7, 25, 26, 27, 28],
dtype=np.int32)
elif model_type == 'smpl_neutral':
            return np.array([14, 12, 8, 7, 6, 9, 10, 11, 2, 1, 0, 3, 4, 5,
                             16, 15, 18, 17], dtype=np.int32)
elif model_type == 'smplh':
body_mapping = np.array([52, 12, 17, 19, 21, 16, 18, 20, 0, 2, 5,
8, 1, 4, 7, 53, 54, 55, 56],
dtype=np.int32)
mapping = [body_mapping]
if use_hands:
lhand_mapping = np.array([20, 34, 35, 36, 57, 22, 23, 24, 58,
25, 26, 27, 59, 31, 32, 33, 60, 28,
29, 30, 61], dtype=np.int32)
rhand_mapping = np.array([21, 49, 50, 51, 62, 37, 38, 39, 63,
40, 41, 42, 64, 46, 47, 48, 65, 43,
44, 45, 66], dtype=np.int32)
mapping += [lhand_mapping, rhand_mapping]
return np.concatenate(mapping)
# SMPLX
elif model_type == 'smplx':
body_mapping = np.array([55, 12, 17, 19, 21, 16, 18, 20, 0, 2, 5,
8, 1, 4, 7, 56, 57, 58, 59],
dtype=np.int32)
mapping = [body_mapping]
if use_hands:
lhand_mapping = np.array([20, 37, 38, 39, 60, 25, 26, 27,
61, 28, 29, 30, 62, 34, 35, 36, 63,
31, 32, 33, 64], dtype=np.int32)
rhand_mapping = np.array([21, 52, 53, 54, 65, 40, 41, 42, 66,
43, 44, 45, 67, 49, 50, 51, 68, 46,
47, 48, 69], dtype=np.int32)
mapping += [lhand_mapping, rhand_mapping]
if use_face:
face_mapping = np.arange(70, 70 + 51 +
17 * use_face_contour,
dtype=np.int32)
mapping += [face_mapping]
return np.concatenate(mapping)
else:
raise ValueError('Unknown model type: {}'.format(model_type))
elif openpose_format == 'h36':
if model_type == 'smpl':
            return np.array([2, 5, 8, 1, 4, 7, 12, 24, 16, 18, 20, 17, 19, 21],
                            dtype=np.int32)
        elif model_type == 'smpl_neutral':
            # return np.array([2,1,0,3,4,5,12,13,9,10,11,8,7,6], dtype=np.int32)
            return np.array([6, 5, 4, 1, 2, 3, 16, 15, 14, 11, 12, 13, 8, 10],
                            dtype=np.int32)
else:
raise ValueError('Unknown joint format: {}'.format(openpose_format))
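# Illustrative use (hypothetical variable names): the returned array indexes
# into the model's joint array, so for SMPL joints of shape (B, J, 3)
#   mapping = smpl_to_pose(model_type='smpl', openpose_format='coco25')
#   openpose_joints = smpl_joints[:, mapping]  # (B, 25, 3), COCO-25 order
# Indices >= 24 assume the extra vertex-picked keypoints (nose, eyes, ears,
# feet) that the smplx package appends after the 24 SMPL skeleton joints.
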
def render_trimesh(renderer, mesh, R, T, mode='np'):
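    """Render a trimesh with the given camera pose.

    A sketch of the contract, inferred from the calls below: `mesh` is a
    trimesh.Trimesh with per-vertex colors, `renderer` is the Renderer
    defined in this module, and `mode` selects the passes of
    render_mesh_recon ('n' = normals, 'p' = Phong-style shading).
    """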
verts = torch.tensor(mesh.vertices).cuda().float()[None]
faces = torch.tensor(mesh.faces).cuda()[None]
colors = torch.tensor(mesh.visual.vertex_colors).float().cuda()[None,...,:3]/255
renderer.set_camera(R,T)
image = renderer.render_mesh_recon(verts, faces, colors=colors, mode=mode)[0]
image = (255*image).data.cpu().numpy().astype(np.uint8)
return image
def estimate_translation_cv2(joints_3d, joints_2d, focal_length=600, img_size=np.array([512.,512.]), proj_mat=None, cam_dist=None):
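    """Estimate the camera translation aligning 3D joints to 2D detections.

    Solves a RANSAC PnP problem (EPnP). When `proj_mat` is None, a pinhole
    intrinsic matrix is built from `focal_length` and `img_size` (assumed to
    be (width, height) in pixels). Returns INVALID_TRANS when RANSAC finds
    no inliers.
    """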
if proj_mat is None:
camK = np.eye(3)
camK[0,0], camK[1,1] = focal_length, focal_length
camK[:2,2] = img_size//2
else:
camK = proj_mat
_, _, tvec,inliers = cv2.solvePnPRansac(joints_3d, joints_2d, camK, cam_dist,\
flags=cv2.SOLVEPNP_EPNP,reprojectionError=20,iterationsCount=100)
if inliers is None:
return INVALID_TRANS
else:
tra_pred = tvec[:,0]
return tra_pred
class JointMapper(nn.Module):
def __init__(self, joint_maps=None):
super(JointMapper, self).__init__()
if joint_maps is None:
self.joint_maps = joint_maps
else:
self.register_buffer('joint_maps',
torch.tensor(joint_maps, dtype=torch.long))
def forward(self, joints, **kwargs):
if self.joint_maps is None:
return joints
else:
return torch.index_select(joints, 1, self.joint_maps)
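# Illustrative use (hypothetical names): wrap the permutation from
# smpl_to_pose so a body model's joints come out in OpenPose order, e.g.
#   joint_mapper = JointMapper(smpl_to_pose('smpl', openpose_format='coco25'))
#   openpose_joints = joint_mapper(smpl_output.joints)
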
def transform_mat(R, t):
''' Creates a batch of transformation matrices
Args:
- R: Bx3x3 array of a batch of rotation matrices
- t: Bx3x1 array of a batch of translation vectors
Returns:
- T: Bx4x4 Transformation matrix
'''
# No padding left or right, only add an extra row
return torch.cat([F.pad(R, [0, 0, 0, 1]),
F.pad(t, [0, 0, 0, 1], value=1)], dim=2)
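# transform_mat stacks R and t into the homogeneous transform
#   T = [[R, t],
#        [0, 1]],
# e.g. transform_mat(torch.eye(3)[None], torch.zeros(1, 3, 1)) yields the
# 4x4 identity with batch size 1.
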
# Transform the SMPL parameters so that the target camera extrinsic is met.
def transform_smpl(curr_extrinsic, target_extrinsic, smpl_pose, smpl_trans, T_hip):
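    """Re-express SMPL pose/translation under a new camera extrinsic.

    Rotates the global orientation and translation (and fills in the
    translation column of `target_extrinsic`) so that rendering under
    `target_extrinsic` matches the view given by `curr_extrinsic`. `T_hip`
    is the root/hip joint location in the canonical SMPL frame. Note that
    `target_extrinsic` and `smpl_pose` are modified in place.
    """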
R_root = cv2.Rodrigues(smpl_pose[:3])[0]
transf_global_ori = np.linalg.inv(target_extrinsic[:3,:3]) @ curr_extrinsic[:3,:3] @ R_root
target_extrinsic[:3, -1] = curr_extrinsic[:3,:3] @ (smpl_trans + T_hip) + curr_extrinsic[:3, -1] - smpl_trans - target_extrinsic[:3,:3] @ T_hip
smpl_pose[:3] = cv2.Rodrigues(transf_global_ori)[0].reshape(3)
smpl_trans = np.linalg.inv(target_extrinsic[:3,:3]) @ smpl_trans # we assume
return target_extrinsic, smpl_pose, smpl_trans
class GMoF(nn.Module):
def __init__(self, rho=1):
super(GMoF, self).__init__()
self.rho = rho
def extra_repr(self):
return 'rho = {}'.format(self.rho)
def forward(self, residual):
squared_res = residual ** 2
dist = torch.div(squared_res, squared_res + self.rho ** 2)
return self.rho ** 2 * dist
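# GMoF is the Geman-McClure robust penalty rho^2 * r^2 / (r^2 + rho^2): close
# to a squared error for small residuals and saturating at rho^2 for large
# ones, which damps the influence of outlier keypoints. Sketch of use:
#   robust = GMoF(rho=100)
#   loss = robust(joints_2d_pred - joints_2d_gt).sum()
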
class PerspectiveCamera(nn.Module):
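    """Batched pinhole camera that projects 3D points to pixel coordinates."""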
    FOCAL_LENGTH = 50 * 128  # default focal length in pixels
def __init__(self, rotation=None, translation=None,
focal_length_x=None, focal_length_y=None,
batch_size=1,
center=None, dtype=torch.float32):
super(PerspectiveCamera, self).__init__()
self.batch_size = batch_size
self.dtype = dtype
# Make a buffer so that PyTorch does not complain when creating
# the camera matrix
self.register_buffer('zero',
torch.zeros([batch_size], dtype=dtype))
        if focal_length_x is None or isinstance(focal_length_x, float):
focal_length_x = torch.full(
[batch_size],
self.FOCAL_LENGTH if focal_length_x is None else
focal_length_x,
dtype=dtype)
        if focal_length_y is None or isinstance(focal_length_y, float):
focal_length_y = torch.full(
[batch_size],
self.FOCAL_LENGTH if focal_length_y is None else
focal_length_y,
dtype=dtype)
self.register_buffer('focal_length_x', focal_length_x)
self.register_buffer('focal_length_y', focal_length_y)
if center is None:
center = torch.zeros([batch_size, 2], dtype=dtype)
self.register_buffer('center', center)
if rotation is None:
rotation = torch.eye(
3, dtype=dtype).unsqueeze(dim=0).repeat(batch_size, 1, 1)
rotation = nn.Parameter(rotation, requires_grad=False)
self.register_parameter('rotation', rotation)
if translation is None:
translation = torch.zeros([batch_size, 3], dtype=dtype)
translation = nn.Parameter(translation,
requires_grad=True)
self.register_parameter('translation', translation)
def forward(self, points):
device = points.device
with torch.no_grad():
camera_mat = torch.zeros([self.batch_size, 2, 2],
dtype=self.dtype, device=points.device)
camera_mat[:, 0, 0] = self.focal_length_x
camera_mat[:, 1, 1] = self.focal_length_y
camera_transform = transform_mat(self.rotation,
self.translation.unsqueeze(dim=-1))
homog_coord = torch.ones(list(points.shape)[:-1] + [1],
dtype=points.dtype,
device=device)
# Convert the points to homogeneous coordinates
points_h = torch.cat([points, homog_coord], dim=-1)
projected_points = torch.einsum('bki,bji->bjk',
[camera_transform, points_h])
img_points = torch.div(projected_points[:, :, :2],
projected_points[:, :, 2].unsqueeze(dim=-1))
img_points = torch.einsum('bki,bji->bjk', [camera_mat, img_points]) \
+ self.center.unsqueeze(dim=1)
return img_points
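# Illustrative projection (hypothetical values): for a 512x512 image,
#   cam = PerspectiveCamera(focal_length_x=1000., focal_length_y=1000.,
#                           center=torch.tensor([[256., 256.]]))
#   uv = cam(points)  # points: (1, N, 3) -> uv: (1, N, 2) pixel coordinates
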
class Renderer():
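    """PyTorch3D mesh renderer driven by an OpenCV-style intrinsic matrix."""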
    def __init__(self, principal_point=None, img_size=None, cam_intrinsic=None):
super().__init__()
self.device = torch.device("cuda:0")
torch.cuda.set_device(self.device)
self.cam_intrinsic = cam_intrinsic
self.image_size = img_size
self.render_img_size = np.max(img_size)
        # NDC principal point derived from the intrinsics (the principal_point
        # argument above is unused and recomputed here)
        principal_point = [-(self.cam_intrinsic[0, 2] - self.image_size[1] / 2.) / (self.image_size[1] / 2.),
                           -(self.cam_intrinsic[1, 2] - self.image_size[0] / 2.) / (self.image_size[0] / 2.)]
self.principal_point = torch.tensor(principal_point, device=self.device).unsqueeze(0)
self.cam_R = torch.from_numpy(np.array([[-1., 0., 0.],
[0., -1., 0.],
[0., 0., 1.]])).cuda().float().unsqueeze(0)
self.cam_T = torch.zeros((1,3)).cuda().float()
half_max_length = max(self.cam_intrinsic[0:2,2])
self.focal_length = torch.tensor([(self.cam_intrinsic[0,0]/half_max_length).astype(np.float32), \
(self.cam_intrinsic[1,1]/half_max_length).astype(np.float32)]).unsqueeze(0)
self.cameras = SfMPerspectiveCameras(focal_length=self.focal_length, principal_point=self.principal_point, R=self.cam_R, T=self.cam_T, device=self.device)
self.lights = PointLights(device=self.device,location=[[0.0, 0.0, 0.0]], ambient_color=((1,1,1),),diffuse_color=((0,0,0),),specular_color=((0,0,0),))
self.raster_settings = RasterizationSettings(image_size=self.render_img_size, faces_per_pixel=10, blur_radius=0, max_faces_per_bin=30000)
self.rasterizer = MeshRasterizer(cameras=self.cameras, raster_settings=self.raster_settings)
self.shader = SoftPhongShader(device=self.device, cameras=self.cameras, lights=self.lights)
self.renderer = MeshRenderer(rasterizer=self.rasterizer, shader=self.shader)
def set_camera(self, R, T):
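        # Flip the x/y axes and transpose R to convert an OpenCV-style
        # world-to-camera extrinsic (column-vector convention, +x right,
        # +y down) to PyTorch3D's row-vector convention (+x left, +y up).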
        self.cam_R = R.clone()  # clone so the caller's tensors are not mutated
        self.cam_T = T.clone()
self.cam_R[:, :2, :] *= -1.0
self.cam_T[:, :2] *= -1.0
self.cam_R = torch.transpose(self.cam_R,1,2)
self.cameras = SfMPerspectiveCameras(focal_length=self.focal_length, principal_point=self.principal_point, R=self.cam_R, T=self.cam_T, device=self.device)
self.rasterizer = MeshRasterizer(cameras=self.cameras, raster_settings=self.raster_settings)
self.shader = SoftPhongShader(device=self.device, cameras=self.cameras, lights=self.lights)
self.renderer = MeshRenderer(rasterizer=self.rasterizer, shader=self.shader)
def render_mesh_recon(self, verts, faces, R=None, T=None, colors=None, mode='npat'):
        '''
        mode: string of render passes to produce: 'n' = vertex normals,
        'p' = Phong-style shading (only these two are handled here)
        '''
with torch.no_grad():
mesh = Meshes(verts, faces)
normals = torch.stack(mesh.verts_normals_list())
            # head-on directional light along +z; clamped Lambertian shading
            front_light = torch.tensor([0., 0., 1.]).to(verts.device)
            shades = (normals * front_light.view(1, 1, 3)).sum(-1).clamp(min=0).unsqueeze(-1).expand(-1, -1, 3)
results = []
# shading
if 'p' in mode:
mesh_shading = Meshes(verts, faces, textures=Textures(verts_rgb=shades))
image_phong = self.renderer(mesh_shading)
results.append(image_phong)
# normal
if 'n' in mode:
                # map normals from [-1, 1] to [0, 1] and flip channel order
                normals_vis = normals * 0.5 + 0.5
                normals_vis = normals_vis[:, :, [2, 1, 0]]
mesh_normal = Meshes(verts, faces, textures=Textures(verts_rgb=normals_vis))
image_normal = self.renderer(mesh_normal)
results.append(image_normal)
            # stack the rendered passes vertically along the image height
            return torch.cat(results, dim=1)
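
# Illustrative end-to-end sketch (hypothetical shapes and values):
#   renderer = Renderer(img_size=(H, W), cam_intrinsic=K)  # K: 3x3 numpy
#   R = torch.from_numpy(cam_R)[None].cuda().float()       # (1, 3, 3)
#   T = torch.from_numpy(cam_t)[None].cuda().float()       # (1, 3)
#   image = render_trimesh(renderer, mesh, R, T, mode='n') # uint8 normal map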