Spaces:
Sleeping
Sleeping
File size: 6,342 Bytes
99a05f0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 |
import torch
import cv2
import numpy as np
from torch.utils.data import Dataset
from torchvision.transforms import Normalize
from common import constants
def mask_split(img, num_parts):
    """Expand an integer label image into a stack of binary channel masks.

    Args:
        img: Label image as a numpy array. If it is not 2-D, only the first
            slice along the third axis (``img[:, :, 0]``) is used.
        num_parts: Number of output channels.

    Returns:
        Float array of shape ``(num_parts, H, W)`` in which channel ``i`` is
        1.0 where the label image equals ``i`` and 0.0 elsewhere. Channels
        whose label does not occur in the image stay all-zero.

    Raises:
        IndexError: If the image contains a label that is not a valid index
            into the ``num_parts`` channels.
    """
    labels = img if img.ndim == 2 else img[:, :, 0]
    height, width = labels.shape[0], labels.shape[1]
    channels_last = np.zeros((height, width, num_parts))
    # Only visit labels that actually occur in the image.
    for label in np.unique(labels):
        channels_last[..., label] = (labels == label).astype(np.float64)
    # Move the channel axis to the front: (H, W, C) -> (C, H, W).
    return channels_last.transpose(2, 0, 1)
class BaseDataset(Dataset):
    """Dataset of images with contact, segmentation and SMPL annotations.

    The annotation archive path is looked up in
    ``constants.DATASET_FILES[mode][dataset]`` and loaded with ``np.load``.
    The keys ``imgname``, ``scene_seg`` and ``part_seg`` are required. The
    keys ``contact_label``, ``polygon_2d_contact``, ``cam_k`` and the SMPL
    keys (``pose``, ``shape``, ``transl``, optionally ``has_smpl``) are
    optional; when missing, the matching ``has_*`` flag is set to zero and
    ``__getitem__`` returns zero placeholders for them.

    Args:
        dataset: Key selecting the dataset inside
            ``constants.DATASET_FILES[mode]``.
        mode: Key selecting the split inside ``constants.DATASET_FILES``.
        model_type: ``'smpl'`` (6890 vertices) or ``'smplx'`` (10475
            vertices).
        normalize: If true, images are normalized with
            ``constants.IMG_NORM_MEAN`` / ``constants.IMG_NORM_STD``.

    Raises:
        NotImplementedError: If ``model_type`` is neither ``'smpl'`` nor
            ``'smplx'``.
    """

    # Side length (pixels) every image and mask is resized to.
    IMG_RES = 256
    # Number of channels produced when splitting the scene / part label masks.
    SEM_MASK_CHANNELS = 133
    PART_MASK_CHANNELS = 26

    def __init__(self, dataset, mode, model_type='smpl', normalize=False):
        self.dataset = dataset
        self.mode = mode

        dataset_file = constants.DATASET_FILES[mode][dataset]
        print(f'Loading dataset: {dataset_file} for mode: {mode}')
        # NOTE(review): allow_pickle=True unpickles arbitrary objects; only
        # load annotation files from trusted sources.
        self.data = np.load(dataset_file, allow_pickle=True)

        self.images = self.data['imgname']
        num_samples = len(self.images)

        # 3D contact labels, if available.
        try:
            self.contact_labels_3d = self.data['contact_label']
            # 1 where the contact label of a sample is not empty, 0 otherwise.
            self.has_contact_3d = np.array(
                [1 if len(x) > 0 else 0 for x in self.contact_labels_3d]
            )
        except KeyError:
            self.has_contact_3d = np.zeros(num_samples)

        # 2D polygon contact labels, if available.
        try:
            self.polygon_contacts_2d = self.data['polygon_2d_contact']
            self.has_polygon_contact_2d = np.ones(num_samples)
        except KeyError:
            self.has_polygon_contact_2d = np.zeros(num_samples)

        # Camera parameters - only intrinsics for now.
        try:
            self.cam_k = self.data['cam_k']
        except KeyError:
            self.cam_k = np.zeros((num_samples, 3, 3))

        self.sem_masks = self.data['scene_seg']
        self.part_masks = self.data['part_seg']

        # Ground-truth SMPL parameters, if available.
        try:
            self.pose = self.data['pose'].astype(float)
            self.betas = self.data['shape'].astype(float)
            self.transl = self.data['transl'].astype(float)
            if 'has_smpl' in self.data:
                self.has_smpl = self.data['has_smpl']
            else:
                self.has_smpl = np.ones(num_samples)
            if model_type == 'smplx':
                self.is_smplx = np.ones(num_samples)
            else:
                self.is_smplx = np.zeros(num_samples)
        except KeyError:
            self.has_smpl = np.zeros(num_samples)
            self.is_smplx = np.zeros(num_samples)

        if model_type == 'smpl':
            self.n_vertices = 6890
        elif model_type == 'smplx':
            self.n_vertices = 10475
        else:
            raise NotImplementedError(f'Unsupported model_type: {model_type}')

        self.normalize = normalize
        self.normalize_img = Normalize(
            mean=constants.IMG_NORM_MEAN, std=constants.IMG_NORM_STD
        )

    def _read_and_resize(self, path, interpolation, description):
        """Read an image file and resize it to ``IMG_RES`` x ``IMG_RES``.

        Args:
            path: Image file path passed to ``cv2.imread``.
            interpolation: OpenCV interpolation flag.
            description: Short label used in the error message.

        Returns:
            Tuple ``(resized, (height, width))`` where ``height`` and
            ``width`` are the dimensions of the image before resizing.

        Raises:
            OSError: If the file cannot be read as an image.
        """
        image = cv2.imread(path)
        # cv2.imread signals failure by returning None rather than raising.
        if image is None:
            raise OSError(f'{description}: could not read image file {path}')
        height, width = image.shape[:2]
        # The interpolation flag must be passed by keyword: the third
        # positional parameter of cv2.resize is `dst`, not `interpolation`.
        resized = cv2.resize(
            image, (self.IMG_RES, self.IMG_RES), interpolation=interpolation
        )
        return resized, (height, width)

    def _load_label_mask(self, path, num_channels, description):
        """Load a label image and split it into ``num_channels`` binary masks.

        Nearest-neighbour resizing is used so that resizing cannot blend
        neighbouring label ids into ids that are not in the source mask.

        Raises:
            OSError: If the file cannot be read as an image.
            IndexError: If the mask cannot be split into ``num_channels``
                channels (e.g. it contains a label id >= ``num_channels``).
        """
        resized, _ = self._read_and_resize(path, cv2.INTER_NEAREST, description)
        try:
            return mask_split(resized, num_channels)
        except IndexError as err:
            raise IndexError(
                f'{description}: cannot split {path} into '
                f'{num_channels} channels'
            ) from err

    def __getitem__(self, index):
        """Load and assemble the sample at ``index``.

        Returns:
            Dict with the resized image (``img``), its path, SMPL parameters
            (``pose``, ``betas``, ``transl``), camera intrinsics (``cam_k``),
            the resize scale factor, 3D contact labels, the scene / part
            masks, the 2D polygon contact mask and the ``has_*`` /
            ``is_smplx`` flags of the sample.

        Raises:
            OSError: If the image or one of its mask files cannot be read.
            IndexError: If a segmentation mask cannot be split into its
                channels.
        """
        item = {}

        # Load image. cv2.imread returns channels in BGR order and no
        # conversion is done here.
        # NOTE(review): confirm IMG_NORM_MEAN / IMG_NORM_STD match that order.
        img_path = self.images[index]
        img, (img_h, img_w) = self._read_and_resize(
            img_path, cv2.INTER_CUBIC, 'Img'
        )
        img = img.transpose(2, 0, 1) / 255.0
        img_scale_factor = np.array(
            [self.IMG_RES / img_w, self.IMG_RES / img_h]
        )

        # SMPL parameters, or SMPL-sized zero placeholders if unavailable.
        if self.has_smpl[index]:
            pose = self.pose[index].copy()
            betas = self.betas[index].copy()
            transl = self.transl[index].copy()
            if self.is_smplx[index]:
                # Add 6 zeros to the end of the pose vector to match with
                # smpl. Only poses loaded from the dataset are padded; the
                # placeholder below already has its final length.
                pose = np.concatenate((pose, np.zeros(6)))
        else:
            pose = np.zeros(72)
            betas = np.zeros(10)
            transl = np.zeros(3)

        # Per-vertex 3D contact labels.
        if self.has_contact_3d[index]:
            contact_label_3d = self.contact_labels_3d[index]
        else:
            contact_label_3d = np.zeros(self.n_vertices)

        # Scene and body-part segmentation masks.
        sem_mask = self._load_label_mask(
            self.sem_masks[index], self.SEM_MASK_CHANNELS, 'Scene seg'
        )
        part_mask = self._load_label_mask(
            self.part_masks[index], self.PART_MASK_CHANNELS, 'Part seg'
        )

        # 2D polygon contact mask, binarized to {0, 1}.
        if self.has_polygon_contact_2d[index]:
            polygon_contact_2d, _ = self._read_and_resize(
                self.polygon_contacts_2d[index],
                cv2.INTER_NEAREST,
                '2D polygon contact',
            )
            polygon_contact_2d = np.where(polygon_contact_2d > 0, 1, 0)
        else:
            polygon_contact_2d = np.zeros((self.IMG_RES, self.IMG_RES, 3))

        img_tensor = torch.tensor(img, dtype=torch.float32)
        if self.normalize:
            img_tensor = self.normalize_img(img_tensor)
        item['img'] = img_tensor

        item['img_path'] = img_path
        item['pose'] = torch.tensor(pose, dtype=torch.float32)
        item['betas'] = torch.tensor(betas, dtype=torch.float32)
        item['transl'] = torch.tensor(transl, dtype=torch.float32)
        item['cam_k'] = self.cam_k[index]
        item['img_scale_factor'] = torch.tensor(img_scale_factor, dtype=torch.float32)
        item['contact_label_3d'] = torch.tensor(contact_label_3d, dtype=torch.float32)
        item['sem_mask'] = torch.tensor(sem_mask, dtype=torch.float32)
        item['part_mask'] = torch.tensor(part_mask, dtype=torch.float32)
        item['polygon_contact_2d'] = torch.tensor(polygon_contact_2d, dtype=torch.float32)
        item['has_smpl'] = self.has_smpl[index]
        item['is_smplx'] = self.is_smplx[index]
        item['has_contact_3d'] = self.has_contact_3d[index]
        item['has_polygon_contact_2d'] = self.has_polygon_contact_2d[index]
        return item

    def __len__(self):
        """Return the number of samples (one per entry in ``self.images``)."""
        return len(self.images)
|