Spaces:
Running
on
Zero
Running
on
Zero
File size: 5,313 Bytes
4893ce0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 |
"""
ModelNet40 Dataset
get sampled point clouds of ModelNet40 (XYZ and normal from mesh, 10k points per shape)
at "https://shapenet.cs.stanford.edu/media/modelnet40_normal_resampled.zip"
Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com)
Please cite our work if the code is helpful to you.
"""
import os
import numpy as np
import pointops
import torch
from torch.utils.data import Dataset
from copy import deepcopy
from pointcept.utils.logger import get_root_logger
from .builder import DATASETS
from .transform import Compose
@DATASETS.register_module()
class ModelNetDataset(Dataset):
    """ModelNet40 shape-classification dataset backed by an in-memory record.

    On construction the sample names are read from
    ``<data_root>/modelnet40_<split>.txt``. All samples are then either loaded
    from a previously saved record file or parsed one by one from
    ``<data_root>/<shape>/<name>.txt`` (comma-separated rows) and kept in
    ``self.data``, a dict mapping sample name to
    ``dict(coord=..., normal=..., category=...)``.

    Args:
        split: Split name used to build the list file and record file names.
        data_root: Directory holding the split list, the per-shape folders and
            the record file.
        class_names: Sequence of shape names; the position of a name is its
            label. Required in practice: with the default ``None`` the
            constructor fails with ``TypeError`` on ``len(None)``.
        transform: Transform config handed to ``Compose``.
        num_points: Number of points kept per shape; ``None`` keeps every row.
        uniform_sampling: If true (and ``num_points`` is not ``None``), pick
            points with ``pointops.farthest_point_sampling`` on the GPU;
            otherwise keep the first ``num_points`` rows.
        save_record: Save a freshly prepared record to ``data_root``.
        test_mode: Return voting-style test samples from ``__getitem__``.
        test_cfg: Config providing ``post_transform`` and ``aug_transform``;
            only read when ``test_mode`` is true.
        loop: Number of times the data list is repeated per epoch; forced to 1
            in test mode.
    """

    def __init__(
        self,
        split="train",
        data_root="data/modelnet40",
        class_names=None,
        transform=None,
        num_points=8192,
        uniform_sampling=True,
        save_record=True,
        test_mode=False,
        test_cfg=None,
        loop=1,
    ):
        super().__init__()
        self.data_root = data_root
        # Map every class name to its integer label (its position in the list).
        self.class_names = dict(zip(class_names, range(len(class_names))))
        self.split = split
        self.num_point = num_points
        self.uniform_sampling = uniform_sampling
        self.transform = Compose(transform)
        self.loop = (
            loop if not test_mode else 1
        )  # force make loop = 1 while in test mode
        self.test_mode = test_mode
        self.test_cfg = test_cfg if test_mode else None
        if test_mode:
            self.post_transform = Compose(self.test_cfg.post_transform)
            self.aug_transform = [Compose(aug) for aug in self.test_cfg.aug_transform]

        self.data_list = self.get_data_list()
        logger = get_root_logger()
        logger.info(
            "Totally {} x {} samples in {} set.".format(
                len(self.data_list), self.loop, split
            )
        )

        # check, prepare record
        # The record name encodes the sampling setup so that records prepared
        # with different settings do not overwrite each other.
        record_name = f"modelnet40_{self.split}"
        if num_points is not None:
            record_name += f"_{num_points}points"
            if uniform_sampling:
                record_name += "_uniform"
        record_path = os.path.join(self.data_root, f"{record_name}.pth")
        if os.path.isfile(record_path):
            logger.info(f"Loading record: {record_name} ...")
            # NOTE(review): torch.load unpickles the record file, so only load
            # records you created yourself. Recent PyTorch releases default to
            # ``weights_only=True``, which may reject this dict of numpy
            # arrays -- confirm against the torch version the project pins.
            self.data = torch.load(record_path)
        else:
            logger.info(f"Preparing record: {record_name} ...")
            # Must exist (empty) before get_data() is called, because
            # get_data() consults it for cached entries.
            self.data = {}
            for idx, data_name in enumerate(self.data_list):
                logger.info(f"Parsing data [{idx}/{len(self.data_list)}]: {data_name}")
                self.data[data_name] = self.get_data(idx)
            if save_record:
                torch.save(self.data, record_path)

    def get_data(self, idx):
        """Return the sample at ``idx`` (wrapped modulo the list length).

        Cached samples are returned as a deep copy so that callers -- e.g. the
        ``pop("category")`` in ``prepare_test_data`` or transforms that write
        into the dict or its arrays -- can never alter the record held in
        ``self.data``. Uncached samples are parsed from their text file.

        Returns:
            dict with ``coord`` (columns 0-2), ``normal`` (columns 3-5) and
            ``category`` (one-element array holding the class label).
        """
        data_idx = idx % len(self.data_list)
        data_name = self.data_list[data_idx]
        if data_name in self.data:
            return deepcopy(self.data[data_name])

        # The shape (class) name is the sample name without its trailing
        # "_<suffix>" token; it is also the sub-directory of the sample file.
        data_shape = "_".join(data_name.split("_")[0:-1])
        data_path = os.path.join(
            self.data_root, data_shape, self.data_list[data_idx] + ".txt"
        )
        data = np.loadtxt(data_path, delimiter=",").astype(np.float32)
        if self.num_point is not None:
            if self.uniform_sampling:
                # Requires a CUDA device: all inputs are moved to the GPU.
                # NOTE(review): every column of ``data`` (not only the first
                # three) is handed to the sampler -- confirm that
                # pointops.farthest_point_sampling expects this layout.
                with torch.no_grad():
                    mask = pointops.farthest_point_sampling(
                        torch.tensor(data).float().cuda(),
                        torch.tensor([len(data)]).long().cuda(),
                        torch.tensor([self.num_point]).long().cuda(),
                    )
                # Convert explicitly so numpy indexes with a plain int array.
                data = data[mask.cpu().numpy()]
            else:
                data = data[: self.num_point]
        coord, normal = data[:, 0:3], data[:, 3:6]
        category = np.array([self.class_names[data_shape]])
        return dict(coord=coord, normal=normal, category=category)

    def get_data_list(self):
        """Read the sample names of the current split from its list file."""
        assert isinstance(self.split, str)
        split_path = os.path.join(
            self.data_root, "modelnet40_{}.txt".format(self.split)
        )
        # ndmin=1 keeps a one-line split file a 1-D array, so len() and
        # indexing work for any number of entries.
        data_list = np.loadtxt(split_path, dtype="str", ndmin=1)
        return data_list

    def get_data_name(self, idx):
        """Return the sample name at ``idx`` (wrapped modulo the list length)."""
        data_idx = idx % len(self.data_list)
        return self.data_list[data_idx]

    def __getitem__(self, idx):
        """Return a test sample in test mode, otherwise a training sample."""
        if self.test_mode:
            return self.prepare_test_data(idx)
        else:
            return self.prepare_train_data(idx)

    def __len__(self):
        """Number of samples per epoch: list length times ``loop``."""
        return len(self.data_list) * self.loop

    def prepare_train_data(self, idx):
        """Fetch the sample at ``idx`` and run the main transform on it."""
        data_dict = self.get_data(idx)
        data_dict = self.transform(data_dict)
        return data_dict

    def prepare_test_data(self, idx):
        """Build a test sample holding one augmented view per test-time aug.

        Returns:
            dict with ``voting_list`` (list of post-transformed views),
            ``category`` (label, held back from all transforms) and ``name``.
        """
        assert idx < len(self.data_list)
        data_dict = self.get_data(idx)
        # Safe to pop: get_data() returned a private copy of the record.
        category = data_dict.pop("category")
        data_dict = self.transform(data_dict)
        # Every augmentation works on its own copy of the transformed sample.
        data_dict_list = [aug(deepcopy(data_dict)) for aug in self.aug_transform]
        # Post-transform in a second pass, after all augmentations have run.
        data_dict_list = [self.post_transform(view) for view in data_dict_list]
        data_dict = dict(
            voting_list=data_dict_list,
            category=category,
            name=self.get_data_name(idx),
        )
        return data_dict
|