# DECO/data/base_dataset.py
# Hosting-page residue kept for provenance: "ac5113's picture"
# Commit message: "added files"
# Commit: 99a05f0
import torch
import cv2
import numpy as np
from torch.utils.data import Dataset
from torchvision.transforms import Normalize
from common import constants
def mask_split(img, num_parts):
    """Expand an integer label image into a stack of binary channel masks.

    Args:
        img: label image. If it is not 2-D, only channel 0 (``img[:, :, 0]``)
            is used.
        num_parts: number of output channels. Every label value present in
            ``img`` is used directly as a channel index.

    Returns:
        Float array of shape ``(num_parts, H, W)`` where channel ``k`` is 1.0
        at the pixels whose label equals ``k`` and 0.0 elsewhere.
    """
    labels = img if len(img.shape) == 2 else img[:, :, 0]
    height, width = labels.shape[0], labels.shape[1]

    # Built channel-last first, then moved to channel-first on return.
    one_hot = np.zeros((height, width, num_parts))
    for label in np.unique(labels):
        one_hot[:, :, label] = (labels == label).astype(one_hot.dtype)
    return one_hot.transpose(2, 0, 1)
class BaseDataset(Dataset):
    """Dataset of images with optional SMPL, 3D-contact and 2D-contact labels.

    The annotations are read from the file registered under
    ``constants.DATASET_FILES[mode][dataset]``. The keys ``imgname``,
    ``scene_seg`` and ``part_seg`` are required; ``contact_label``,
    ``polygon_2d_contact``, ``cam_k``, ``pose``/``shape``/``transl`` and
    ``has_smpl`` are optional and replaced by zero placeholders (with the
    matching ``has_*`` flag set to 0) when absent.

    Args:
        dataset: dataset name, used as the inner key of
            ``constants.DATASET_FILES``.
        mode: split/mode name, used as the outer key of
            ``constants.DATASET_FILES``.
        model_type: ``'smpl'`` (6890 vertices) or ``'smplx'`` (10475 vertices).
        normalize: if true, ``item['img']`` is normalised with
            ``constants.IMG_NORM_MEAN`` / ``constants.IMG_NORM_STD``.

    Raises:
        NotImplementedError: if ``model_type`` is neither ``'smpl'`` nor
            ``'smplx'``.
    """

    # Side length, in pixels, that every image and mask is resized to.
    _IMG_RES = 256
    # Channel counts passed to ``mask_split`` for the two segmentation maps.
    _NUM_SEM_CLASSES = 133
    _NUM_PART_CLASSES = 26

    def __init__(self, dataset, mode, model_type='smpl', normalize=False):
        self.dataset = dataset
        self.mode = mode

        dataset_file = constants.DATASET_FILES[mode][dataset]
        print(f'Loading dataset: {dataset_file} for mode: {mode}')
        # NOTE(review): allow_pickle=True lets the file execute arbitrary code
        # while loading; only point this at trusted annotation files.
        self.data = np.load(dataset_file, allow_pickle=True)

        self.images = self.data['imgname']
        num_samples = len(self.images)

        # get 3d contact labels, if available
        try:
            self.contact_labels_3d = self.data['contact_label']
            # 1 where the contact label of a sample is non-empty, 0 otherwise
            self.has_contact_3d = np.array(
                [1 if len(x) > 0 else 0 for x in self.contact_labels_3d]
            )
        except KeyError:
            self.has_contact_3d = np.zeros(num_samples)

        # get 2d polygon contact labels, if available
        try:
            self.polygon_contacts_2d = self.data['polygon_2d_contact']
            self.has_polygon_contact_2d = np.ones(num_samples)
        except KeyError:
            self.has_polygon_contact_2d = np.zeros(num_samples)

        # Get camera parameters - only intrinsics for now
        try:
            self.cam_k = self.data['cam_k']
        except KeyError:
            self.cam_k = np.zeros((num_samples, 3, 3))

        self.sem_masks = self.data['scene_seg']
        self.part_masks = self.data['part_seg']

        # Get gt SMPL parameters, if available
        try:
            self.pose = self.data['pose'].astype(float)
            self.betas = self.data['shape'].astype(float)
            self.transl = self.data['transl'].astype(float)
            if 'has_smpl' in self.data:
                self.has_smpl = self.data['has_smpl']
            else:
                self.has_smpl = np.ones(num_samples)
            if model_type == 'smplx':
                self.is_smplx = np.ones(num_samples)
            else:
                self.is_smplx = np.zeros(num_samples)
        except KeyError:
            self.has_smpl = np.zeros(num_samples)
            self.is_smplx = np.zeros(num_samples)

        if model_type == 'smpl':
            self.n_vertices = 6890
        elif model_type == 'smplx':
            self.n_vertices = 10475
        else:
            raise NotImplementedError

        self.normalize = normalize
        self.normalize_img = Normalize(mean=constants.IMG_NORM_MEAN,
                                       std=constants.IMG_NORM_STD)

    @staticmethod
    def _read_image(path, what):
        """Read ``path`` with ``cv2.imread``; raise ``OSError`` if it fails.

        ``cv2.imread`` signals failure by returning ``None`` instead of
        raising, so the check is done here to report the offending file.
        ``what`` is a short label (e.g. ``'Img'``) prefixed to the message.
        """
        image = cv2.imread(path)
        if image is None:
            raise OSError(f'{what}: could not read image file {path}')
        return image

    def _read_resized(self, path, what, interpolation):
        """Read an image and resize it to ``_IMG_RES`` x ``_IMG_RES``."""
        image = self._read_image(path, what)
        # ``interpolation`` must be a keyword: the third positional parameter
        # of cv2.resize is ``dst``, so a positional flag is silently ignored.
        return cv2.resize(image, (self._IMG_RES, self._IMG_RES),
                          interpolation=interpolation)

    def __getitem__(self, index):
        """Return the sample at ``index`` as a dict.

        Keys: ``img``, ``img_path``, ``pose``, ``betas``, ``transl``,
        ``cam_k``, ``img_scale_factor``, ``contact_label_3d``, ``sem_mask``,
        ``part_mask``, ``polygon_contact_2d``, ``has_smpl``, ``is_smplx``,
        ``has_contact_3d`` and ``has_polygon_contact_2d``.

        Raises:
            OSError: if the image or one of its mask files cannot be read.
        """
        res = self._IMG_RES
        item = {}

        # Load image
        img_path = self.images[index]
        img = self._read_image(img_path, 'Img')
        img_h, img_w = img.shape[:2]
        img = cv2.resize(img, (res, res), interpolation=cv2.INTER_CUBIC)
        # NOTE(review): cv2.imread yields BGR channel order and no conversion
        # to RGB is done here - confirm the normalisation constants and the
        # model expect that.
        img = img.transpose(2, 0, 1) / 255.0
        # (x, y) factors mapping original pixel coordinates to the resized image
        img_scale_factor = np.array([res / img_w, res / img_h])

        # Get SMPL parameters, if available
        if self.has_smpl[index]:
            pose = self.pose[index].copy()
            betas = self.betas[index].copy()
            transl = self.transl[index].copy()
        else:
            pose = np.zeros(72)
            betas = np.zeros(10)
            transl = np.zeros(3)

        # Load vertex_contact
        if self.has_contact_3d[index]:
            contact_label_3d = self.contact_labels_3d[index]
        else:
            contact_label_3d = np.zeros(self.n_vertices)

        # Label maps are resized with nearest-neighbour so that no blended,
        # non-existent class ids are created at region boundaries.
        sem_mask = self._read_resized(self.sem_masks[index], 'Scene seg',
                                      cv2.INTER_NEAREST)
        sem_mask = mask_split(sem_mask, self._NUM_SEM_CLASSES)

        part_mask = self._read_resized(self.part_masks[index], 'Part seg',
                                       cv2.INTER_NEAREST)
        part_mask = mask_split(part_mask, self._NUM_PART_CLASSES)

        if self.has_polygon_contact_2d[index]:
            polygon_contact_2d = self._read_resized(
                self.polygon_contacts_2d[index], '2D polygon contact',
                cv2.INTER_NEAREST,
            )
            # binarize the polygon mask
            polygon_contact_2d = np.where(polygon_contact_2d > 0, 1, 0)
        else:
            polygon_contact_2d = np.zeros((res, res, 3))

        img = torch.tensor(img, dtype=torch.float32)
        item['img'] = self.normalize_img(img) if self.normalize else img

        if self.is_smplx[index]:
            # Add 6 zeros to the end of the pose vector to match with smpl
            pose = np.concatenate((pose, np.zeros(6)))

        item['img_path'] = img_path
        item['pose'] = torch.tensor(pose, dtype=torch.float32)
        item['betas'] = torch.tensor(betas, dtype=torch.float32)
        item['transl'] = torch.tensor(transl, dtype=torch.float32)
        item['cam_k'] = self.cam_k[index]
        item['img_scale_factor'] = torch.tensor(img_scale_factor, dtype=torch.float32)
        item['contact_label_3d'] = torch.tensor(contact_label_3d, dtype=torch.float32)
        item['sem_mask'] = torch.tensor(sem_mask, dtype=torch.float32)
        item['part_mask'] = torch.tensor(part_mask, dtype=torch.float32)
        item['polygon_contact_2d'] = torch.tensor(polygon_contact_2d, dtype=torch.float32)
        item['has_smpl'] = self.has_smpl[index]
        item['is_smplx'] = self.is_smplx[index]
        item['has_contact_3d'] = self.has_contact_3d[index]
        item['has_polygon_contact_2d'] = self.has_polygon_contact_2d[index]
        return item

    def __len__(self):
        """Return the number of samples (entries of ``imgname``)."""
        return len(self.images)