Spaces:
Sleeping
Sleeping
import torch | |
import cv2 | |
import numpy as np | |
from torch.utils.data import Dataset | |
from torchvision.transforms import Normalize | |
from common import constants | |
def mask_split(img, num_parts): | |
if not len(img.shape) == 2: | |
img = img[:, :, 0] | |
mask = np.zeros((img.shape[0], img.shape[1], num_parts)) | |
for i in np.unique(img): | |
mask[:, :, i] = np.where(img == i, 1., 0.) | |
return np.transpose(mask, (2, 0, 1)) | |
class BaseDataset(Dataset): | |
def __init__(self, dataset, mode, model_type='smpl', normalize=False): | |
self.dataset = dataset | |
self.mode = mode | |
print(f'Loading dataset: {constants.DATASET_FILES[mode][dataset]} for mode: {mode}') | |
self.data = np.load(constants.DATASET_FILES[mode][dataset], allow_pickle=True) | |
self.images = self.data['imgname'] | |
# get 3d contact labels, if available | |
try: | |
self.contact_labels_3d = self.data['contact_label'] | |
# make a has_contact_3d numpy array which contains 1 if contact labels are no empty and 0 otherwise | |
self.has_contact_3d = np.array([1 if len(x) > 0 else 0 for x in self.contact_labels_3d]) | |
except KeyError: | |
self.has_contact_3d = np.zeros(len(self.images)) | |
# get 2d polygon contact labels, if available | |
try: | |
self.polygon_contacts_2d = self.data['polygon_2d_contact'] | |
self.has_polygon_contact_2d = np.ones(len(self.images)) | |
except KeyError: | |
self.has_polygon_contact_2d = np.zeros(len(self.images)) | |
# Get camera parameters - only intrinsics for now | |
try: | |
self.cam_k = self.data['cam_k'] | |
except KeyError: | |
self.cam_k = np.zeros((len(self.images), 3, 3)) | |
self.sem_masks = self.data['scene_seg'] | |
self.part_masks = self.data['part_seg'] | |
# Get gt SMPL parameters, if available | |
try: | |
self.pose = self.data['pose'].astype(float) | |
self.betas = self.data['shape'].astype(float) | |
self.transl = self.data['transl'].astype(float) | |
if 'has_smpl' in self.data: | |
self.has_smpl = self.data['has_smpl'] | |
else: | |
self.has_smpl = np.ones(len(self.images)) | |
self.is_smplx = np.ones(len(self.images)) if model_type == 'smplx' else np.zeros(len(self.images)) | |
except KeyError: | |
self.has_smpl = np.zeros(len(self.images)) | |
self.is_smplx = np.zeros(len(self.images)) | |
if model_type == 'smpl': | |
self.n_vertices = 6890 | |
elif model_type == 'smplx': | |
self.n_vertices = 10475 | |
else: | |
raise NotImplementedError | |
self.normalize = normalize | |
self.normalize_img = Normalize(mean=constants.IMG_NORM_MEAN, std=constants.IMG_NORM_STD) | |
def __getitem__(self, index): | |
item = {} | |
# Load image | |
img_path = self.images[index] | |
try: | |
img = cv2.imread(img_path) | |
img_h, img_w, _ = img.shape | |
img = cv2.resize(img, (256, 256), cv2.INTER_CUBIC) | |
img = img.transpose(2, 0, 1) / 255.0 | |
except: | |
print('Img: ', img_path) | |
img_scale_factor = np.array([256 / img_w, 256 / img_h]) | |
# Get SMPL parameters, if available | |
if self.has_smpl[index]: | |
pose = self.pose[index].copy() | |
betas = self.betas[index].copy() | |
transl = self.transl[index].copy() | |
else: | |
pose = np.zeros(72) | |
betas = np.zeros(10) | |
transl = np.zeros(3) | |
# Load vertex_contact | |
if self.has_contact_3d[index]: | |
contact_label_3d = self.contact_labels_3d[index] | |
else: | |
contact_label_3d = np.zeros(self.n_vertices) | |
sem_mask_path = self.sem_masks[index] | |
try: | |
sem_mask = cv2.imread(sem_mask_path) | |
sem_mask = cv2.resize(sem_mask, (256, 256), cv2.INTER_CUBIC) | |
sem_mask = mask_split(sem_mask, 133) | |
except: | |
print('Scene seg: ', sem_mask_path) | |
try: | |
part_mask_path = self.part_masks[index] | |
part_mask = cv2.imread(part_mask_path) | |
part_mask = cv2.resize(part_mask, (256, 256), cv2.INTER_CUBIC) | |
part_mask = mask_split(part_mask, 26) | |
except: | |
print('Part seg: ', part_mask_path) | |
try: | |
if self.has_polygon_contact_2d[index]: | |
polygon_contact_2d_path = self.polygon_contacts_2d[index] | |
polygon_contact_2d = cv2.imread(polygon_contact_2d_path) | |
polygon_contact_2d = cv2.resize(polygon_contact_2d, (256, 256), cv2.INTER_NEAREST) | |
# binarize the part mask | |
polygon_contact_2d = np.where(polygon_contact_2d > 0, 1, 0) | |
else: | |
polygon_contact_2d = np.zeros((256, 256, 3)) | |
except: | |
print('2D polygon contact: ', polygon_contact_2d_path) | |
if self.normalize: | |
img = torch.tensor(img, dtype=torch.float32) | |
item['img'] = self.normalize_img(img) | |
else: | |
item['img'] = torch.tensor(img, dtype=torch.float32) | |
if self.is_smplx[index]: | |
# Add 6 zeros to the end of the pose vector to match with smpl | |
pose = np.concatenate((pose, np.zeros(6))) | |
item['img_path'] = img_path | |
item['pose'] = torch.tensor(pose, dtype=torch.float32) | |
item['betas'] = torch.tensor(betas, dtype=torch.float32) | |
item['transl'] = torch.tensor(transl, dtype=torch.float32) | |
item['cam_k'] = self.cam_k[index] | |
item['img_scale_factor'] = torch.tensor(img_scale_factor, dtype=torch.float32) | |
item['contact_label_3d'] = torch.tensor(contact_label_3d, dtype=torch.float32) | |
item['sem_mask'] = torch.tensor(sem_mask, dtype=torch.float32) | |
item['part_mask'] = torch.tensor(part_mask, dtype=torch.float32) | |
item['polygon_contact_2d'] = torch.tensor(polygon_contact_2d, dtype=torch.float32) | |
item['has_smpl'] = self.has_smpl[index] | |
item['is_smplx'] = self.is_smplx[index] | |
item['has_contact_3d'] = self.has_contact_3d[index] | |
item['has_polygon_contact_2d'] = self.has_polygon_contact_2d[index] | |
return item | |
def __len__(self): | |
return len(self.images) | |