import torch import torch.nn as nn from common import constants from models.smpl import SMPL from smplx import SMPLX import pickle as pkl import numpy as np from utils.mesh_utils import save_results_mesh from utils.diff_renderer import Pytorch3D import os import cv2 class sem_loss_function(nn.Module): def __init__(self): super(sem_loss_function, self).__init__() self.ce = nn.BCELoss() def forward(self, y_true, y_pred): loss = self.ce(y_pred, y_true) return loss class class_loss_function(nn.Module): def __init__(self): super(class_loss_function, self).__init__() self.ce_loss = nn.BCELoss() # self.ce_loss = nn.MultiLabelSoftMarginLoss() # self.ce_loss = nn.MultiLabelMarginLoss() def forward(self, y_true, y_pred, valid_mask): # y_true = torch.squeeze(y_true, 1).long() # y_true = torch.squeeze(y_true, 1) # y_pred = torch.squeeze(y_pred, 1) bs = y_true.shape[0] if bs != 1: y_pred = y_pred[valid_mask == 1] y_true = y_true[valid_mask == 1] if len(y_pred) > 0: return self.ce_loss(y_pred, y_true) else: return torch.tensor(0.0).to(y_pred.device) class pixel_anchoring_function(nn.Module): def __init__(self, model_type, device='cuda'): super(pixel_anchoring_function, self).__init__() self.device = device self.model_type = model_type if self.model_type == 'smplx': # load mapping from smpl vertices to smplx vertices mapping_pkl = os.path.join(constants.CONTACT_MAPPING_PATH, "smpl_to_smplx.pkl") with open(mapping_pkl, 'rb') as f: smpl_to_smplx_mapping = pkl.load(f) smpl_to_smplx_mapping = smpl_to_smplx_mapping["matrix"] self.smpl_to_smplx_mapping = torch.from_numpy(smpl_to_smplx_mapping).float().to(self.device) # Setup the SMPL model if self.model_type == 'smpl': self.n_vertices = 6890 self.body_model = SMPL(constants.SMPL_MODEL_DIR).to(self.device) if self.model_type == 'smplx': self.n_vertices = 10475 self.body_model = SMPLX(constants.SMPLX_MODEL_DIR, num_betas=10, use_pca=False).to(self.device) self.body_faces = torch.LongTensor(self.body_model.faces.astype(np.int32)).to(self.device) self.ce_loss = nn.BCELoss() def get_posed_mesh(self, body_params, debug=False): betas = body_params['betas'] pose = body_params['pose'] transl = body_params['transl'] # extra smplx params extra_args = {'jaw_pose': torch.zeros((betas.shape[0], 3)).float().to(self.device), 'leye_pose': torch.zeros((betas.shape[0], 3)).float().to(self.device), 'reye_pose': torch.zeros((betas.shape[0], 3)).float().to(self.device), 'expression': torch.zeros((betas.shape[0], 10)).float().to(self.device), 'left_hand_pose': torch.zeros((betas.shape[0], 45)).float().to(self.device), 'right_hand_pose': torch.zeros((betas.shape[0], 45)).float().to(self.device)} smpl_output = self.body_model(betas=betas, body_pose=pose[:, 3:], global_orient=pose[:, :3], pose2rot=True, transl=transl, **extra_args) smpl_verts = smpl_output.vertices smpl_joints = smpl_output.joints if debug: for mesh_i in range(smpl_verts.shape[0]): out_dir = 'temp_meshes' os.makedirs(out_dir, exist_ok=True) out_file = os.path.join(out_dir, f'temp_mesh_{mesh_i:04d}.obj') save_results_mesh(smpl_verts[mesh_i], self.body_model.faces, out_file) return smpl_verts, smpl_joints def render_batch(self, smpl_verts, cam_k, img_scale_factor, vertex_colors=None, face_textures=None, debug=False): bs = smpl_verts.shape[0] # Incorporate resizing factor into the camera img_w = 256 # TODO: Remove hardcoding img_h = 256 # TODO: Remove hardcoding focal_length_x = cam_k[:, 0, 0] * img_scale_factor[:, 0] focal_length_y = cam_k[:, 1, 1] * img_scale_factor[:, 1] # convert to float for pytorch3d focal_length_x, focal_length_y = focal_length_x.float(), focal_length_y.float() # concatenate focal length focal_length = torch.stack([focal_length_x, focal_length_y], dim=1) # Setup renderer renderer = Pytorch3D(img_h=img_h, img_w=img_w, focal_length=focal_length, smpl_faces=self.body_faces, texture_mode='deco', vertex_colors=vertex_colors, face_textures=face_textures, is_train=True, is_cam_batch=True) front_view = renderer(smpl_verts) if debug: # visualize the front view as images in a temp_image folder for i in range(bs): front_view_rgb = front_view[i, :3, :, :].permute(1, 2, 0).detach().cpu() front_view_mask = front_view[i, 3, :, :].detach().cpu() out_dir = 'temp_images' os.makedirs(out_dir, exist_ok=True) out_file_rgb = os.path.join(out_dir, f'{i:04d}_rgb.png') out_file_mask = os.path.join(out_dir, f'{i:04d}_mask.png') cv2.imwrite(out_file_rgb, front_view_rgb.numpy()*255) cv2.imwrite(out_file_mask, front_view_mask.numpy()*255) return front_view def paint_contact(self, pred_contact): """ Paints the contact vertices on the SMPL mesh Args: pred_contact: prbabilities of contact vertices Returns: pred_rgb: RGB colors for the contact vertices """ bs = pred_contact.shape[0] # initialize black and while colors colors = torch.tensor([[0, 0, 0], [1, 1, 1]]).float().to(self.device) colors = torch.unsqueeze(colors, 0).expand(bs, -1, -1) # add another dimension to the contact probabilities for inverse probabilities pred_contact = torch.unsqueeze(pred_contact, 2) pred_contact = torch.cat((1 - pred_contact, pred_contact), 2) # get pred_rgb colors pred_vert_rgb = torch.bmm(pred_contact, colors) pred_face_rgb = pred_vert_rgb[:, self.body_faces, :][:, :, 0, :] # take the first vertex color pred_face_texture = torch.zeros((bs, self.body_faces.shape[0], 1, 1, 3), dtype=torch.float32).to(self.device) pred_face_texture[:, :, 0, 0, :] = pred_face_rgb return pred_vert_rgb, pred_face_texture def forward(self, pred_contact, body_params, cam_k, img_scale_factor, gt_contact_polygon, valid_mask): """ Takes predicted contact labels (probabilities), transfers them to the posed mesh and renders to the image. Loss is computed between the rendered contact and the ground truth polygons from HOT. Args: pred_contact: predicted contact labels (probabilities) body_params: SMPL parameters in camera coords cam_k: camera intrinsics gt_contact_polygon: ground truth polygons from HOT """ # convert pred_contact to smplx bs = pred_contact.shape[0] if self.model_type == 'smplx': smpl_to_smplx_mapping = self.smpl_to_smplx_mapping[None].expand(bs, -1, -1) pred_contact = torch.bmm(smpl_to_smplx_mapping, pred_contact[..., None]) pred_contact = pred_contact.squeeze() # get the posed mesh smpl_verts, smpl_joints = self.get_posed_mesh(body_params) # paint the contact vertices on the mesh vertex_colors, face_textures = self.paint_contact(pred_contact) # render the mesh front_view = self.render_batch(smpl_verts, cam_k, img_scale_factor, vertex_colors, face_textures) front_view_rgb = front_view[:, :3, :, :].permute(0, 2, 3, 1) front_view_mask = front_view[:, 3, :, :] # compute segmentation loss between rendered contact mask and ground truth contact mask front_view_rgb = front_view_rgb[valid_mask == 1] gt_contact_polygon = gt_contact_polygon[valid_mask == 1] loss = self.ce_loss(front_view_rgb, gt_contact_polygon) return loss, front_view_rgb, front_view_mask