""" | |
ScanNet20 / ScanNet200 / ScanNet Data Efficient Dataset | |
Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com) | |
Please cite our work if the code is helpful to you. | |
""" | |
import os
import glob
import numpy as np
import torch
from copy import deepcopy
from torch.utils.data import Dataset
from collections.abc import Sequence

from pointcept.utils.logger import get_root_logger
from pointcept.utils.cache import shared_dict

from .builder import DATASETS
from .defaults import DefaultDataset
from .transform import Compose, TRANSFORMS
from .preprocessing.scannet.meta_data.scannet200_constants import (
    VALID_CLASS_IDS_20,
    VALID_CLASS_IDS_200,
)


@DATASETS.register_module()
class ScanNetDataset(DefaultDataset):
    VALID_ASSETS = [
        "coord",
        "color",
        "normal",
        "segment20",
        "instance",
    ]
    class2id = np.array(VALID_CLASS_IDS_20)

    def __init__(
        self,
        lr_file=None,
        la_file=None,
        **kwargs,
    ):
        # lr_file: optional text file listing the scene names used in the
        # limited-reconstruction (data efficient) setting.
        # la_file: optional torch-saved dict mapping scene names to the indices
        # of labeled points in the limited-annotation setting.
        self.lr = np.loadtxt(lr_file, dtype=str) if lr_file is not None else None
        self.la = torch.load(la_file) if la_file is not None else None
        super().__init__(**kwargs)

    def get_data_list(self):
        if self.lr is None:
            data_list = super().get_data_list()
        else:
            # Limited-reconstruction setting: only the scenes named in lr_file
            # are used for training.
            data_list = [
                os.path.join(self.data_root, "train", name) for name in self.lr
            ]
        return data_list

    def get_data(self, idx):
        data_path = self.data_list[idx % len(self.data_list)]
        name = self.get_data_name(idx)
        if self.cache:
            cache_name = f"pointcept-{name}"
            return shared_dict(cache_name)

        # Load every recognized .npy asset stored in the scene folder.
        data_dict = {}
        assets = os.listdir(data_path)
        for asset in assets:
            if not asset.endswith(".npy"):
                continue
            if asset[:-4] not in self.VALID_ASSETS:
                continue
            data_dict[asset[:-4]] = np.load(os.path.join(data_path, asset))
        data_dict["name"] = name
        data_dict["coord"] = data_dict["coord"].astype(np.float32)
        data_dict["color"] = data_dict["color"].astype(np.float32)
        data_dict["normal"] = data_dict["normal"].astype(np.float32)

        # Unify the semantic label key to "segment"; fall back to -1 (ignored)
        # when no label file is present.
        if "segment20" in data_dict.keys():
            data_dict["segment"] = (
                data_dict.pop("segment20").reshape([-1]).astype(np.int32)
            )
        elif "segment200" in data_dict.keys():
            data_dict["segment"] = (
                data_dict.pop("segment200").reshape([-1]).astype(np.int32)
            )
        else:
            data_dict["segment"] = (
                np.ones(data_dict["coord"].shape[0], dtype=np.int32) * -1
            )

        if "instance" in data_dict.keys():
            data_dict["instance"] = (
                data_dict.pop("instance").reshape([-1]).astype(np.int32)
            )
        else:
            data_dict["instance"] = (
                np.ones(data_dict["coord"].shape[0], dtype=np.int32) * -1
            )

        if self.la:
            # Limited-annotation setting: keep labels only at the sampled
            # indices and mark every other point with ignore_index.
            sampled_index = self.la[self.get_data_name(idx)]
            mask = np.ones_like(data_dict["segment"], dtype=bool)
            mask[sampled_index] = False
            data_dict["segment"][mask] = self.ignore_index
            data_dict["sampled_index"] = sampled_index
        return data_dict


@DATASETS.register_module()
class ScanNet200Dataset(ScanNetDataset):
    VALID_ASSETS = [
        "coord",
        "color",
        "normal",
        "segment200",
        "instance",
    ]
    class2id = np.array(VALID_CLASS_IDS_200)
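

# ------------------------------------------------------------------
# Minimal usage sketch (an assumption-laden example, not part of the dataset
# definitions above): it assumes a preprocessed ScanNet layout under
# "data/scannet" with per-scene folders of coord/color/normal/segment20/
# instance .npy files, and that keyword arguments such as `split` and
# `data_root` are accepted by DefaultDataset. Adjust paths and kwargs to
# your own setup.
if __name__ == "__main__":
    dataset = ScanNetDataset(
        split="train",
        data_root="data/scannet",  # assumed location of the preprocessed scenes
    )
    # Use get_data() directly to inspect a raw sample without any transforms.
    sample = dataset.get_data(0)
    print(sample["name"], sample["coord"].shape, sample["segment"].shape)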