DECO / data /base_dataset.py
ac5113's picture
added files
99a05f0
raw
history blame
6.34 kB
import torch
import cv2
import numpy as np
from torch.utils.data import Dataset
from torchvision.transforms import Normalize
from common import constants
def mask_split(img, num_parts):
    """Expand an integer label map into one binary mask per label.

    Args:
        img: Label image. If it is not 2-D, only channel 0 (``img[:, :, 0]``)
            is used.
        num_parts: Number of output channels.

    Returns:
        Float array of shape ``(num_parts, H, W)``; channel ``k`` is 1.0 where
        the label map equals ``k`` and 0.0 elsewhere.

    Raises:
        IndexError: If the label map contains a value that is not a valid
            channel index for ``num_parts`` channels.
    """
    # Multi-channel input: the labels are taken from the first channel only.
    label_map = img if len(img.shape) == 2 else img[:, :, 0]
    height, width = label_map.shape[0], label_map.shape[1]

    channels_last = np.zeros((height, width, num_parts))
    for label in np.unique(label_map):
        channels_last[:, :, label] = (label_map == label).astype(np.float64)

    # Move the label axis to the front: (H, W, C) -> (C, H, W).
    return channels_last.transpose(2, 0, 1)
class BaseDataset(Dataset):
    """Dataset serving images with contact, segmentation and SMPL annotations.

    The annotation file is looked up in ``constants.DATASET_FILES[mode][dataset]``
    and loaded with ``np.load``. ``imgname``, ``scene_seg`` and ``part_seg`` are
    required keys; ``contact_label``, ``polygon_2d_contact``, ``cam_k`` and the
    SMPL keys (``pose``, ``shape``, ``transl``, ``has_smpl``) are optional and
    replaced by zero placeholders plus a ``has_*`` flag when missing.

    Args:
        dataset: Dataset name, used as key into ``constants.DATASET_FILES[mode]``.
        mode: Split/mode name, used as key into ``constants.DATASET_FILES``.
        model_type: ``'smpl'`` or ``'smplx'``; selects the number of vertices
            used for the placeholder 3D contact label and whether poses read
            from the file are zero-padded.
        normalize: If True, the image tensor is normalized with
            ``constants.IMG_NORM_MEAN`` / ``constants.IMG_NORM_STD``.

    Raises:
        NotImplementedError: If ``model_type`` is neither ``'smpl'`` nor ``'smplx'``.
    """

    # Side length (in pixels) every image and mask is resized to.
    IMG_RES = 256
    # Number of channels produced for the scene / body-part segmentation masks.
    NUM_SEM_CLASSES = 133
    NUM_PART_CLASSES = 26
    # Sizes of the zero placeholders used when no SMPL parameters are available.
    POSE_DIM = 72
    BETAS_DIM = 10
    TRANSL_DIM = 3
    # Zeros appended to a pose read from the file when the model type is smplx.
    SMPLX_POSE_PAD = 6

    def __init__(self, dataset, mode, model_type='smpl', normalize=False):
        self.dataset = dataset
        self.mode = mode

        print(f'Loading dataset: {constants.DATASET_FILES[mode][dataset]} for mode: {mode}')
        # NOTE(review): allow_pickle=True unpickles object arrays, which can
        # execute arbitrary code. Only load annotation files from trusted sources.
        self.data = np.load(constants.DATASET_FILES[mode][dataset], allow_pickle=True)

        self.images = self.data['imgname']
        num_items = len(self.images)

        # get 3d contact labels, if available
        try:
            self.contact_labels_3d = self.data['contact_label']
            # has_contact_3d is 1 where the contact label is non-empty, 0 otherwise
            self.has_contact_3d = np.array([1 if len(x) > 0 else 0 for x in self.contact_labels_3d])
        except KeyError:
            self.has_contact_3d = np.zeros(num_items)

        # get 2d polygon contact labels, if available
        try:
            self.polygon_contacts_2d = self.data['polygon_2d_contact']
            self.has_polygon_contact_2d = np.ones(num_items)
        except KeyError:
            self.has_polygon_contact_2d = np.zeros(num_items)

        # Get camera parameters - only intrinsics for now
        try:
            self.cam_k = self.data['cam_k']
        except KeyError:
            self.cam_k = np.zeros((num_items, 3, 3))

        self.sem_masks = self.data['scene_seg']
        self.part_masks = self.data['part_seg']

        # Get gt SMPL parameters, if available
        try:
            self.pose = self.data['pose'].astype(float)
            self.betas = self.data['shape'].astype(float)
            self.transl = self.data['transl'].astype(float)
            if 'has_smpl' in self.data:
                self.has_smpl = self.data['has_smpl']
            else:
                self.has_smpl = np.ones(num_items)
            self.is_smplx = np.ones(num_items) if model_type == 'smplx' else np.zeros(num_items)
        except KeyError:
            self.has_smpl = np.zeros(num_items)
            self.is_smplx = np.zeros(num_items)

        if model_type == 'smpl':
            self.n_vertices = 6890
        elif model_type == 'smplx':
            self.n_vertices = 10475
        else:
            raise NotImplementedError

        self.normalize = normalize
        self.normalize_img = Normalize(mean=constants.IMG_NORM_MEAN, std=constants.IMG_NORM_STD)

    @staticmethod
    def _read_image(path, what):
        """Read the image at ``path``; raise ``OSError`` if it cannot be read."""
        img = cv2.imread(path)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise OSError(f'{what}: could not read image file {path}')
        return img

    def _resize(self, img, interpolation):
        """Resize ``img`` to ``IMG_RES`` x ``IMG_RES`` with the given interpolation."""
        # The flag must be passed by keyword: the third positional parameter of
        # cv2.resize is ``dst``, so a flag passed positionally is not applied.
        return cv2.resize(img, (self.IMG_RES, self.IMG_RES), interpolation=interpolation)

    def _load_label_mask(self, path, num_parts, what):
        """Load a label image and split it into ``num_parts`` binary channels."""
        mask = self._read_image(path, what)
        # Nearest-neighbour keeps label ids intact; other modes blend
        # neighbouring ids into values that are not valid labels.
        mask = self._resize(mask, cv2.INTER_NEAREST)
        return mask_split(mask, num_parts)

    def _smpl_params(self, index):
        """Return ``(pose, betas, transl)`` for ``index``, zeros if unavailable."""
        if not self.has_smpl[index]:
            # Placeholders already have the final size, so they are never padded.
            return (np.zeros(self.POSE_DIM),
                    np.zeros(self.BETAS_DIM),
                    np.zeros(self.TRANSL_DIM))

        pose = self.pose[index].copy()
        betas = self.betas[index].copy()
        transl = self.transl[index].copy()
        if self.is_smplx[index]:
            # Add 6 zeros to the end of the pose vector to match with smpl
            pose = np.concatenate((pose, np.zeros(self.SMPLX_POSE_PAD)))
        return pose, betas, transl

    def __getitem__(self, index):
        """Load and return the sample at ``index`` as a dict.

        Keys: ``img``, ``img_path``, ``pose``, ``betas``, ``transl``, ``cam_k``,
        ``img_scale_factor``, ``contact_label_3d``, ``sem_mask``, ``part_mask``,
        ``polygon_contact_2d``, ``has_smpl``, ``is_smplx``, ``has_contact_3d``,
        ``has_polygon_contact_2d``.

        Raises:
            OSError: If the image or one of its mask files cannot be read.
        """
        item = {}

        # Load image (channel order is left exactly as cv2.imread returns it).
        img_path = self.images[index]
        img = self._read_image(img_path, 'Img')
        img_h, img_w, _ = img.shape
        img = self._resize(img, cv2.INTER_CUBIC)
        img = img.transpose(2, 0, 1) / 255.0
        # Per-axis (x, y) factor mapping original pixel coordinates to the resized image.
        img_scale_factor = np.array([self.IMG_RES / img_w, self.IMG_RES / img_h])

        # Get SMPL parameters, if available
        pose, betas, transl = self._smpl_params(index)

        # Load vertex_contact
        if self.has_contact_3d[index]:
            contact_label_3d = self.contact_labels_3d[index]
        else:
            contact_label_3d = np.zeros(self.n_vertices)

        sem_mask = self._load_label_mask(
            self.sem_masks[index], self.NUM_SEM_CLASSES, 'Scene seg')
        part_mask = self._load_label_mask(
            self.part_masks[index], self.NUM_PART_CLASSES, 'Part seg')

        if self.has_polygon_contact_2d[index]:
            polygon_contact_2d = self._read_image(
                self.polygon_contacts_2d[index], '2D polygon contact')
            polygon_contact_2d = self._resize(polygon_contact_2d, cv2.INTER_NEAREST)
            # binarize the contact mask
            polygon_contact_2d = np.where(polygon_contact_2d > 0, 1, 0)
        else:
            polygon_contact_2d = np.zeros((self.IMG_RES, self.IMG_RES, 3))

        img = torch.tensor(img, dtype=torch.float32)
        item['img'] = self.normalize_img(img) if self.normalize else img

        item['img_path'] = img_path
        item['pose'] = torch.tensor(pose, dtype=torch.float32)
        item['betas'] = torch.tensor(betas, dtype=torch.float32)
        item['transl'] = torch.tensor(transl, dtype=torch.float32)
        item['cam_k'] = self.cam_k[index]
        item['img_scale_factor'] = torch.tensor(img_scale_factor, dtype=torch.float32)
        item['contact_label_3d'] = torch.tensor(contact_label_3d, dtype=torch.float32)
        item['sem_mask'] = torch.tensor(sem_mask, dtype=torch.float32)
        item['part_mask'] = torch.tensor(part_mask, dtype=torch.float32)
        item['polygon_contact_2d'] = torch.tensor(polygon_contact_2d, dtype=torch.float32)
        item['has_smpl'] = self.has_smpl[index]
        item['is_smplx'] = self.is_smplx[index]
        item['has_contact_3d'] = self.has_contact_3d[index]
        item['has_polygon_contact_2d'] = self.has_polygon_contact_2d[index]

        return item

    def __len__(self):
        """Return the number of samples (one per entry in ``imgname``)."""
        return len(self.images)