Spaces:
Sleeping
Sleeping
File size: 8,847 Bytes
b807ddb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 |
import torch
import torch.nn as nn
from common import constants
from models.smpl import SMPL
from smplx import SMPLX
import pickle as pkl
import numpy as np
from utils.mesh_utils import save_results_mesh
from utils.diff_renderer import Pytorch3D
import os
import cv2
class sem_loss_function(nn.Module):
    """Binary cross-entropy loss called with (target, prediction) order.

    ``forward`` receives the ground truth first and hands both tensors to
    ``nn.BCELoss`` in that loss's own (input, target) order.
    """

    def __init__(self):
        super().__init__()
        self.ce = nn.BCELoss()

    def forward(self, y_true, y_pred):
        """Return the BCE between predicted probabilities ``y_pred`` and ``y_true``."""
        bce = self.ce
        return bce(y_pred, y_true)
class class_loss_function(nn.Module):
    """Binary cross-entropy restricted to the samples flagged by ``valid_mask``."""

    def __init__(self):
        super().__init__()
        self.ce_loss = nn.BCELoss()

    def forward(self, y_true, y_pred, valid_mask):
        """Compute BCE over the valid samples of the batch.

        Args:
            y_true: ground-truth targets; first dimension is the batch.
            y_pred: predicted probabilities, same shape as ``y_true``.
            valid_mask: per-sample flags; samples equal to 1 are kept.

        Returns:
            The BCE over the kept samples, or a zero tensor on ``y_pred``'s
            device when no sample is kept.
        """
        # Masking only happens for batches larger than one: a single-sample
        # batch is always scored, whatever its mask value.
        if y_true.shape[0] != 1:
            keep = valid_mask == 1
            y_true, y_pred = y_true[keep], y_pred[keep]
        if len(y_pred) == 0:
            return torch.tensor(0.0).to(y_pred.device)
        return self.ce_loss(y_pred, y_true)
class pixel_anchoring_function(nn.Module):
    """Pixel-anchoring loss between rendered per-vertex contact and 2D contact polygons.

    Predicted contact probabilities are painted onto the posed body mesh as
    colors, rendered into the image with the batch camera intrinsics, and
    compared against ground-truth contact polygons (from HOT) with BCE.

    Args:
        model_type: 'smpl' or 'smplx'; selects the body model used for posing
            and rendering. For 'smplx', SMPL-topology contact predictions are
            transferred to SMPL-X vertices through a mapping matrix loaded from
            ``CONTACT_MAPPING_PATH/smpl_to_smplx.pkl``.
        device: device holding the body model, faces and mapping matrix.

    Raises:
        ValueError: if ``model_type`` is neither 'smpl' nor 'smplx'.
    """

    def __init__(self, model_type, device='cuda'):
        super(pixel_anchoring_function, self).__init__()
        if model_type not in ('smpl', 'smplx'):
            # Fail fast: otherwise no body model is created and the problem
            # only surfaces later as an AttributeError on self.body_model.
            raise ValueError(
                f"model_type must be 'smpl' or 'smplx', got {model_type!r}")
        self.device = device
        self.model_type = model_type
        if self.model_type == 'smplx':
            # load mapping from smpl vertices to smplx vertices
            mapping_pkl = os.path.join(constants.CONTACT_MAPPING_PATH, "smpl_to_smplx.pkl")
            with open(mapping_pkl, 'rb') as f:
                smpl_to_smplx_mapping = pkl.load(f)
            smpl_to_smplx_mapping = smpl_to_smplx_mapping["matrix"]
            self.smpl_to_smplx_mapping = torch.from_numpy(smpl_to_smplx_mapping).float().to(self.device)
        # Setup the body model
        if self.model_type == 'smpl':
            self.n_vertices = 6890
            self.body_model = SMPL(constants.SMPL_MODEL_DIR).to(self.device)
        else:
            self.n_vertices = 10475
            self.body_model = SMPLX(constants.SMPLX_MODEL_DIR,
                                    num_betas=10,
                                    use_pca=False).to(self.device)
        self.body_faces = torch.LongTensor(self.body_model.faces.astype(np.int32)).to(self.device)
        self.ce_loss = nn.BCELoss()

    def get_posed_mesh(self, body_params, debug=False):
        """Run the body model and return the posed vertices and joints.

        Args:
            body_params: dict with 'betas', 'pose' and 'transl'. Columns 0-2 of
                'pose' are used as the global orientation and the remaining
                columns as the body pose (``pose2rot=True``, so rotations are
                given as axis-angle vectors).
            debug: if True, also save every posed mesh to temp_meshes/*.obj.

        Returns:
            (vertices, joints) of the body model output.
        """
        betas = body_params['betas']
        pose = body_params['pose']
        transl = body_params['transl']
        bs = betas.shape[0]
        # Jaw, eye and hand poses and the expression are not predicted; they
        # are fixed to zeros (these kwargs are passed for SMPL as well).
        extra_args = {'jaw_pose': torch.zeros((bs, 3)).float().to(self.device),
                      'leye_pose': torch.zeros((bs, 3)).float().to(self.device),
                      'reye_pose': torch.zeros((bs, 3)).float().to(self.device),
                      'expression': torch.zeros((bs, 10)).float().to(self.device),
                      'left_hand_pose': torch.zeros((bs, 45)).float().to(self.device),
                      'right_hand_pose': torch.zeros((bs, 45)).float().to(self.device)}
        smpl_output = self.body_model(betas=betas,
                                      body_pose=pose[:, 3:],
                                      global_orient=pose[:, :3],
                                      pose2rot=True,
                                      transl=transl,
                                      **extra_args)
        smpl_verts = smpl_output.vertices
        smpl_joints = smpl_output.joints

        if debug:
            out_dir = 'temp_meshes'
            os.makedirs(out_dir, exist_ok=True)
            for mesh_i in range(smpl_verts.shape[0]):
                out_file = os.path.join(out_dir, f'temp_mesh_{mesh_i:04d}.obj')
                save_results_mesh(smpl_verts[mesh_i], self.body_model.faces, out_file)
        return smpl_verts, smpl_joints

    def render_batch(self, smpl_verts, cam_k, img_scale_factor, vertex_colors=None, face_textures=None,
                     debug=False, img_w=256, img_h=256):
        """Render a batch of posed meshes with per-sample focal lengths.

        Args:
            smpl_verts: posed vertices; first dimension is the batch.
            cam_k: batched intrinsics; fx/fy are read from [:, 0, 0] / [:, 1, 1].
            img_scale_factor: per-sample (x, y) factors folding the image
                resize into the focal lengths.
            vertex_colors, face_textures: colors forwarded to the renderer.
            debug: if True, write each rendered RGB image and mask to temp_images/.
            img_w, img_h: output image size (default 256x256).

        Returns:
            The renderer output, indexed as (batch, channel, H, W) with
            channels 0-2 the RGB image and channel 3 the mask.
        """
        bs = smpl_verts.shape[0]

        # Incorporate resizing factor into the camera
        focal_length_x = cam_k[:, 0, 0] * img_scale_factor[:, 0]
        focal_length_y = cam_k[:, 1, 1] * img_scale_factor[:, 1]
        # convert to float for pytorch3d
        focal_length = torch.stack([focal_length_x.float(), focal_length_y.float()], dim=1)

        # A new renderer is built per call because the focal lengths change per batch.
        renderer = Pytorch3D(img_h=img_h,
                             img_w=img_w,
                             focal_length=focal_length,
                             smpl_faces=self.body_faces,
                             texture_mode='deco',
                             vertex_colors=vertex_colors,
                             face_textures=face_textures,
                             is_train=True,
                             is_cam_batch=True)
        front_view = renderer(smpl_verts)

        if debug:
            out_dir = 'temp_images'
            os.makedirs(out_dir, exist_ok=True)
            for i in range(bs):
                rgb = front_view[i, :3, :, :].permute(1, 2, 0).detach().cpu().numpy()
                mask = front_view[i, 3, :, :].detach().cpu().numpy()
                # cv2.imwrite expects 8-bit BGR images: convert explicitly
                # instead of relying on implicit depth conversion.
                rgb_u8 = (np.clip(rgb, 0, 1) * 255).round().astype(np.uint8)
                mask_u8 = (np.clip(mask, 0, 1) * 255).round().astype(np.uint8)
                cv2.imwrite(os.path.join(out_dir, f'{i:04d}_rgb.png'),
                            np.ascontiguousarray(rgb_u8[:, :, ::-1]))
                cv2.imwrite(os.path.join(out_dir, f'{i:04d}_mask.png'), mask_u8)
        return front_view

    def paint_contact(self, pred_contact):
        """Convert per-vertex contact probabilities into vertex and face colors.

        Each vertex color is a blend of black (probability 0) and white
        (probability 1), i.e. the probability replicated over the 3 channels.

        Args:
            pred_contact: (bs, n_vertices) contact probabilities.

        Returns:
            pred_vert_rgb: (bs, n_vertices, 3) vertex colors.
            pred_face_texture: (bs, n_faces, 1, 1, 3) face colors, each face
                taking the color of its first vertex.
        """
        bs = pred_contact.shape[0]
        # black and white endpoint colors, one copy per sample for bmm
        colors = torch.tensor([[0, 0, 0], [1, 1, 1]]).float().to(self.device)
        colors = torch.unsqueeze(colors, 0).expand(bs, -1, -1)
        # stack (1 - p, p) as blend weights for the two colors
        pred_contact = torch.unsqueeze(pred_contact, 2)
        pred_contact = torch.cat((1 - pred_contact, pred_contact), 2)
        pred_vert_rgb = torch.bmm(pred_contact, colors)
        pred_face_rgb = pred_vert_rgb[:, self.body_faces, :][:, :, 0, :]  # take the first vertex color
        pred_face_texture = torch.zeros((bs, self.body_faces.shape[0], 1, 1, 3), dtype=torch.float32).to(self.device)
        pred_face_texture[:, :, 0, 0, :] = pred_face_rgb
        return pred_vert_rgb, pred_face_texture

    def _smpl_contact_to_smplx(self, pred_contact):
        """Map (bs, n_smpl_verts) contact probabilities to (bs, n_smplx_verts)."""
        bs = pred_contact.shape[0]
        mapping = self.smpl_to_smplx_mapping[None].expand(bs, -1, -1)
        # Squeeze only the trailing singleton: a bare .squeeze() would also drop
        # the batch dimension when bs == 1 and break paint_contact.
        return torch.bmm(mapping, pred_contact[..., None]).squeeze(-1)

    def forward(self, pred_contact, body_params, cam_k, img_scale_factor, gt_contact_polygon, valid_mask):
        """Render predicted contact into the image and score it against GT polygons.

        Args:
            pred_contact: predicted per-vertex contact probabilities (SMPL topology).
            body_params: body model parameters in camera coordinates (see get_posed_mesh).
            cam_k: camera intrinsics.
            img_scale_factor: per-sample image resize factors (see render_batch).
            gt_contact_polygon: ground-truth contact rendering from HOT polygons.
            valid_mask: per-sample flags; only samples equal to 1 enter the loss.

        Returns:
            (loss, front_view_rgb, front_view_mask). ``front_view_rgb`` is
            channels-last and holds only the valid samples, whereas
            ``front_view_mask`` covers the whole batch. The loss is 0 when
            no sample is valid.
        """
        if self.model_type == 'smplx':
            pred_contact = self._smpl_contact_to_smplx(pred_contact)

        smpl_verts, _ = self.get_posed_mesh(body_params)
        vertex_colors, face_textures = self.paint_contact(pred_contact)
        front_view = self.render_batch(smpl_verts, cam_k, img_scale_factor, vertex_colors, face_textures)
        front_view_rgb = front_view[:, :3, :, :].permute(0, 2, 3, 1)
        front_view_mask = front_view[:, 3, :, :]

        # segmentation loss between rendered contact and ground-truth contact, valid samples only
        keep = valid_mask == 1
        front_view_rgb = front_view_rgb[keep]
        gt_contact_polygon = gt_contact_polygon[keep]
        if front_view_rgb.shape[0] == 0:
            # nn.BCELoss's mean over an empty tensor is NaN; return 0 instead.
            loss = torch.tensor(0.0).to(front_view_rgb.device)
        else:
            # NOTE(review): nn.BCELoss requires inputs in [0, 1]; this assumes the
            # 'deco' texture mode renders unshaded colors — confirm.
            loss = self.ce_loss(front_view_rgb, gt_contact_polygon)
        return loss, front_view_rgb, front_view_mask
|