import torch
import cv2
import numpy as np
from torch.utils.data import Dataset
from torchvision.transforms import Normalize

from common import constants

# Side length (pixels) that every image and mask is resized to.
_IMG_RES = 256
# Number of one-hot channels produced for the scene segmentation map.
_NUM_SCENE_SEG_CLASSES = 133
# Number of one-hot channels produced for the body-part segmentation map.
_NUM_PART_SEG_CLASSES = 26
# Mesh vertex count per supported body model; used for the all-zero
# contact-label fallback.
_N_VERTICES = {'smpl': 6890, 'smplx': 10475}


def mask_split(img, num_parts):
    """Split an integer label map into a stack of one-hot binary masks.

    Args:
        img: label map of shape (H, W) or (H, W, C). For multi-channel input
            only channel 0 is used (e.g. a label PNG that cv2.imread loaded
            as a 3-channel image).
        num_parts: number of output channels; every label present in ``img``
            must lie in ``[0, num_parts)``.

    Returns:
        float64 array of shape (num_parts, H, W) whose channel ``k`` is 1.0
        where ``img == k`` and 0.0 elsewhere.

    Raises:
        IndexError: if ``img`` contains a label outside ``[0, num_parts)``.
    """
    if img.ndim != 2:
        img = img[:, :, 0]
    labels = np.unique(img)
    # Check explicitly: a negative label would otherwise silently index
    # from the end of the channel axis instead of failing.
    if labels.size and (labels.min() < 0 or labels.max() >= num_parts):
        raise IndexError(
            f'label values {labels.tolist()} out of range for num_parts={num_parts}')
    mask = np.zeros((num_parts, img.shape[0], img.shape[1]))
    for label in labels:
        mask[label] = img == label
    return mask


def _imread(path):
    """Read an image with cv2.imread, raising if it cannot be read.

    cv2.imread reports failure by returning None instead of raising, which
    would otherwise surface later as an unrelated AttributeError.

    Raises:
        FileNotFoundError: if the file is missing or not a decodable image.
    """
    img = cv2.imread(path)
    if img is None:
        raise FileNotFoundError(f'cv2.imread could not read file: {path}')
    return img


class BaseDataset(Dataset):
    """Image dataset with optional body-model, 3D-contact and 2D-mask labels.

    Annotations are read with ``np.load`` from
    ``constants.DATASET_FILES[mode][dataset]``. Required keys: ``imgname``,
    ``scene_seg``, ``part_seg``. Optional keys: ``contact_label``,
    ``polygon_2d_contact``, ``cam_k``, ``pose``/``shape``/``transl``
    (plus ``has_smpl``). A missing optional key yields zero-filled defaults
    and a ``has_*`` flag of 0 for every sample.
    """

    def __init__(self, dataset, mode, model_type='smpl', normalize=False):
        """
        Args:
            dataset: key into ``constants.DATASET_FILES[mode]``.
            mode: first key into ``constants.DATASET_FILES``.
            model_type: 'smpl' or 'smplx'. Selects the vertex count of the
                zero contact-label fallback and the ``is_smplx`` flag.
            normalize: if True, the image tensor is passed through
                ``Normalize(constants.IMG_NORM_MEAN, constants.IMG_NORM_STD)``.

        Raises:
            NotImplementedError: if ``model_type`` is not supported.
        """
        # Validate before doing any I/O.
        if model_type not in _N_VERTICES:
            raise NotImplementedError(f'Unsupported model_type: {model_type}')
        self.n_vertices = _N_VERTICES[model_type]

        self.dataset = dataset
        self.mode = mode
        print(f'Loading dataset: {constants.DATASET_FILES[mode][dataset]} for mode: {mode}')
        # NOTE(review): allow_pickle=True lets np.load unpickle object arrays,
        # which can execute arbitrary code - only load trusted dataset files.
        self.data = np.load(constants.DATASET_FILES[mode][dataset], allow_pickle=True)
        self.images = self.data['imgname']
        n_items = len(self.images)

        # get 3d contact labels, if available
        try:
            self.contact_labels_3d = self.data['contact_label']
            # 1 where the sample's contact label is non-empty, 0 otherwise.
            self.has_contact_3d = np.array([1 if len(x) > 0 else 0 for x in self.contact_labels_3d])
        except KeyError:
            self.has_contact_3d = np.zeros(n_items)

        # get 2d polygon contact labels, if available
        try:
            self.polygon_contacts_2d = self.data['polygon_2d_contact']
            self.has_polygon_contact_2d = np.ones(n_items)
        except KeyError:
            self.has_polygon_contact_2d = np.zeros(n_items)

        # Get camera parameters - only intrinsics for now
        try:
            self.cam_k = self.data['cam_k']
        except KeyError:
            self.cam_k = np.zeros((n_items, 3, 3))

        self.sem_masks = self.data['scene_seg']
        self.part_masks = self.data['part_seg']

        # Get gt SMPL parameters, if available
        try:
            self.pose = self.data['pose'].astype(float)
            self.betas = self.data['shape'].astype(float)
            self.transl = self.data['transl'].astype(float)
            if 'has_smpl' in self.data:
                self.has_smpl = self.data['has_smpl']
            else:
                self.has_smpl = np.ones(n_items)
            self.is_smplx = np.ones(n_items) if model_type == 'smplx' else np.zeros(n_items)
        except KeyError:
            self.has_smpl = np.zeros(n_items)
            self.is_smplx = np.zeros(n_items)

        self.normalize = normalize
        self.normalize_img = Normalize(mean=constants.IMG_NORM_MEAN, std=constants.IMG_NORM_STD)

    def __getitem__(self, index):
        """Load, resize and package sample ``index`` as a dict of tensors/flags.

        Raises:
            FileNotFoundError: if an image or mask file cannot be read.
            Any error from resizing or ``mask_split`` is re-raised after the
            offending path is printed, instead of being swallowed.
        """
        item = {}

        # Load image
        img_path = self.images[index]
        try:
            img = _imread(img_path)
            img_h, img_w, _ = img.shape
            # Pass the flag by keyword: cv2.resize's 3rd positional arg is `dst`.
            img = cv2.resize(img, (_IMG_RES, _IMG_RES), interpolation=cv2.INTER_CUBIC)
            # NOTE(review): cv2.imread yields BGR channel order; confirm
            # constants.IMG_NORM_MEAN/STD are specified in the same order.
            img = img.transpose(2, 0, 1) / 255.0
        except Exception:
            print('Img: ', img_path)
            raise
        # (x-scale, y-scale) from original to resized resolution.
        img_scale_factor = np.array([_IMG_RES / img_w, _IMG_RES / img_h])

        # Get SMPL parameters, if available
        if self.has_smpl[index]:
            pose = self.pose[index].copy()
            betas = self.betas[index].copy()
            transl = self.transl[index].copy()
        else:
            pose = np.zeros(72)
            betas = np.zeros(10)
            transl = np.zeros(3)

        # Load vertex_contact
        if self.has_contact_3d[index]:
            contact_label_3d = self.contact_labels_3d[index]
        else:
            contact_label_3d = np.zeros(self.n_vertices)

        # Segmentation maps hold integer labels: they must be resized with
        # nearest-neighbour, since any blending interpolation invents label
        # values at region boundaries.
        sem_mask_path = self.sem_masks[index]
        try:
            sem_mask = _imread(sem_mask_path)
            sem_mask = cv2.resize(sem_mask, (_IMG_RES, _IMG_RES), interpolation=cv2.INTER_NEAREST)
            sem_mask = mask_split(sem_mask, _NUM_SCENE_SEG_CLASSES)
        except Exception:
            print('Scene seg: ', sem_mask_path)
            raise

        part_mask_path = self.part_masks[index]
        try:
            part_mask = _imread(part_mask_path)
            part_mask = cv2.resize(part_mask, (_IMG_RES, _IMG_RES), interpolation=cv2.INTER_NEAREST)
            part_mask = mask_split(part_mask, _NUM_PART_SEG_CLASSES)
        except Exception:
            print('Part seg: ', part_mask_path)
            raise

        if self.has_polygon_contact_2d[index]:
            polygon_contact_2d_path = self.polygon_contacts_2d[index]
            try:
                polygon_contact_2d = _imread(polygon_contact_2d_path)
                polygon_contact_2d = cv2.resize(
                    polygon_contact_2d, (_IMG_RES, _IMG_RES), interpolation=cv2.INTER_NEAREST)
            except Exception:
                print('2D polygon contact: ', polygon_contact_2d_path)
                raise
            # binarize the contact mask
            polygon_contact_2d = np.where(polygon_contact_2d > 0, 1, 0)
        else:
            polygon_contact_2d = np.zeros((_IMG_RES, _IMG_RES, 3))

        img = torch.tensor(img, dtype=torch.float32)
        item['img'] = self.normalize_img(img) if self.normalize else img

        if self.is_smplx[index]:
            # Add 6 zeros to the end of the pose vector to match with smpl
            pose = np.concatenate((pose, np.zeros(6)))

        item['img_path'] = img_path
        item['pose'] = torch.tensor(pose, dtype=torch.float32)
        item['betas'] = torch.tensor(betas, dtype=torch.float32)
        item['transl'] = torch.tensor(transl, dtype=torch.float32)
        item['cam_k'] = self.cam_k[index]
        item['img_scale_factor'] = torch.tensor(img_scale_factor, dtype=torch.float32)
        item['contact_label_3d'] = torch.tensor(contact_label_3d, dtype=torch.float32)
        item['sem_mask'] = torch.tensor(sem_mask, dtype=torch.float32)
        item['part_mask'] = torch.tensor(part_mask, dtype=torch.float32)
        item['polygon_contact_2d'] = torch.tensor(polygon_contact_2d, dtype=torch.float32)
        item['has_smpl'] = self.has_smpl[index]
        item['is_smplx'] = self.is_smplx[index]
        item['has_contact_3d'] = self.has_contact_3d[index]
        item['has_polygon_contact_2d'] = self.has_polygon_contact_2d[index]
        return item

    def __len__(self):
        """Return the number of samples (one per entry in ``imgname``)."""
        return len(self.images)