Spaces:
Running
on
Zero
Running
on
Zero
""" | |
ArkitScenes Dataset | |
Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com) | |
Please cite our work if the code is helpful to you. | |
""" | |
import os | |
import glob | |
import numpy as np | |
import torch | |
from copy import deepcopy | |
from torch.utils.data import Dataset | |
from pointcept.utils.logger import get_root_logger | |
from .builder import DATASETS | |
from .transform import Compose, TRANSFORMS | |
from .preprocessing.scannet.meta_data.scannet200_constants import VALID_CLASS_IDS_200 | |
# NOTE(review): DATASETS is imported in this module, but the class below has no
# @DATASETS.register_module() decorator. Unless it is registered elsewhere, a
# registry-based builder will not find it -- confirm.
class ArkitScenesDataset(Dataset):
    """Point-cloud dataset over preprocessed ARKitScenes ``.pth`` scene files.

    Each file under ``<data_root>/<split>/`` is loaded with ``torch.load`` and
    must provide ``"coord"``, ``"color"`` and ``"normal"`` entries. No labels
    are read: ``segment`` is a zero placeholder with one entry per point.

    Args:
        split: Sub-directory name, or a list/tuple of names, under
            ``data_root`` whose ``*.pth`` files make up the dataset.
        data_root: Root directory containing the split directories.
        transform: Transform config passed to ``Compose``; applied to every
            sample.
        test_mode: If True, ``__getitem__`` returns test-time augmented,
            voxelized and cropped fragments instead of a single sample.
        test_cfg: Test-time config. In test mode it must provide ``voxelize``,
            ``crop``, ``post_transform`` and ``aug_transform``. It is
            discarded when ``test_mode`` is False.
        loop: Number of times the file list is repeated per epoch. Forced to
            1 in test mode.
    """

    def __init__(
        self,
        split="Training",
        data_root="data/ARKitScenesMesh",
        transform=None,
        test_mode=False,
        test_cfg=None,
        loop=1,
    ):
        super().__init__()
        self.data_root = data_root
        self.split = split
        self.transform = Compose(transform)
        # Test mode always walks the file list exactly once.
        self.loop = loop if not test_mode else 1
        self.test_mode = test_mode
        self.test_cfg = test_cfg if test_mode else None
        # Not referenced by any other method of this class.
        self.class2id = np.array(VALID_CLASS_IDS_200)

        if test_mode:
            self.test_voxelize = TRANSFORMS.build(self.test_cfg.voxelize)
            self.test_crop = TRANSFORMS.build(self.test_cfg.crop)
            self.post_transform = Compose(self.test_cfg.post_transform)
            self.aug_transform = [Compose(aug) for aug in self.test_cfg.aug_transform]

        self.data_list = self.get_data_list()
        logger = get_root_logger()
        logger.info(
            "Totally {} x {} samples in {} set.".format(
                len(self.data_list), self.loop, split
            )
        )

    def get_data_list(self):
        """Collect the ``*.pth`` scene files of the configured split(s).

        Paths are sorted within each split, because ``glob`` order depends on
        the filesystem. Splits are concatenated in the order given. Together
        these make the index -> file mapping reproducible across machines.

        Returns:
            list[str]: Paths to the scene files.

        Raises:
            NotImplementedError: If ``split`` is neither a string nor a
                list/tuple of strings.
        """
        if isinstance(self.split, str):
            splits = [self.split]
        elif isinstance(self.split, (list, tuple)):
            splits = self.split
        else:
            raise NotImplementedError
        data_list = []
        for split in splits:
            data_list += sorted(
                glob.glob(os.path.join(self.data_root, split, "*.pth"))
            )
        return data_list

    def get_data(self, idx):
        """Load one scene. ``idx`` wraps around the file list, so ``loop > 1`` works.

        Returns:
            dict: ``coord``, ``normal`` and ``color`` as stored in the file,
            plus a zero-filled ``segment`` of length ``coord.shape[0]``.
        """
        # NOTE(review): torch.load unpickles the file, so only load trusted
        # data. Recent torch releases default to weights_only=True, which may
        # reject files that store numpy arrays -- confirm against the
        # installed version.
        data = torch.load(self.data_list[idx % len(self.data_list)])
        coord = data["coord"]
        color = data["color"]
        normal = data["normal"]
        # No labels are loaded for this dataset; use a zero placeholder.
        segment = np.zeros(coord.shape[0])
        return dict(coord=coord, normal=normal, color=color, segment=segment)

    def get_data_name(self, idx):
        """Return the scene file's base name, up to its first dot, for ``idx``."""
        # Wrap around the file list in the same way as get_data.
        data_path = self.data_list[idx % len(self.data_list)]
        return os.path.basename(data_path).split(".")[0]

    def prepare_train_data(self, idx):
        """Load scene ``idx`` and apply the training ``transform``."""
        data_dict = self.get_data(idx)
        return self.transform(data_dict)

    def prepare_test_data(self, idx):
        """Build the test-time fragments for one scene.

        The pipeline runs in this order:
        1. Apply ``transform``.
        2. Apply each test-time augmentation to its own deep copy.
        3. Voxelize each augmented sample and crop each voxelized part; both
           steps return a list of parts.
        4. Apply ``post_transform`` to every resulting fragment.

        Returns:
            tuple: (list of fragment dicts, the scene's ``segment`` array).
        """
        data_dict = self.get_data(idx)
        # Keep the per-point segment out of the transforms and return it as-is.
        segment = data_dict.pop("segment")
        data_dict = self.transform(data_dict)

        # Each augmentation gets an independent copy so they cannot interfere.
        aug_dict_list = [aug(deepcopy(data_dict)) for aug in self.aug_transform]

        fragment_list = []
        for aug_dict in aug_dict_list:
            for voxel_part in self.test_voxelize(aug_dict):
                fragment_list += self.test_crop(voxel_part)

        fragment_list = [self.post_transform(fragment) for fragment in fragment_list]
        return fragment_list, segment

    def __getitem__(self, idx):
        if self.test_mode:
            return self.prepare_test_data(idx)
        return self.prepare_train_data(idx)

    def __len__(self):
        # The file list is repeated ``loop`` times per epoch.
        return len(self.data_list) * self.loop