Spaces:
Running
Running
from os import path as osp | |
from typing import Dict | |
from unicodedata import name | |
import numpy as np | |
import torch | |
import torch.utils as utils | |
from numpy.linalg import inv | |
from src.utils.dataset import ( | |
read_scannet_gray, | |
read_scannet_depth, | |
read_scannet_pose, | |
read_scannet_intrinsic, | |
) | |
class ScanNetDataset(utils.data.Dataset):
    """Image-pair dataset covering one scene of ScanNet.

    Each item is a dictionary holding two grayscale images, their depth maps
    (train/val only), the shared depth-camera intrinsics and the relative
    poses between the two views.
    """

    # Modes for which pairs are never filtered by overlap score.
    _EVAL_MODES = ("val", "test")

    def __init__(
        self,
        root_dir,
        npz_path,
        intrinsic_path,
        mode="train",
        min_overlap_score=0.4,
        augment_fn=None,
        pose_dir=None,
        **kwargs,
    ):
        """Manage one scene of ScanNet Dataset.

        Args:
            root_dir (str): ScanNet root directory that contains scene folders.
            npz_path (str): {scene_id}.npz path. This contains image pair
                information of a scene (a "name" array and, optionally, a
                "score" array).
            intrinsic_path (str): path to depth-camera intrinsic file (an
                ``.npz`` archive indexed by scene name).
            mode (str): options are ['train', 'val', 'test'].
            min_overlap_score (float): pairs whose "score" is not strictly
                greater than this value are dropped. Only applied when the
                npz file has a "score" entry and ``mode`` is neither 'val'
                nor 'test'.
            augment_fn (callable, optional): augments images with pre-defined
                visual effects. Only stored when ``mode == "train"``.
            pose_dir (str): ScanNet root directory that contains all poses.
                (we use a separate (optional) pose_dir since we store images
                and poses separately.) Defaults to ``root_dir``.
            **kwargs: accepted for call-site compatibility and ignored.
        """
        super().__init__()
        self.root_dir = root_dir
        self.pose_dir = pose_dir if pose_dir is not None else root_dir
        self.mode = mode

        # prepare data_names, intrinsics and extrinsics(T)
        with np.load(npz_path) as data:
            self.data_names = data["name"]
            # NOTE: the previous check `mode not in ["val" or "test"]` reduced
            # to `mode not in ["val"]`, so 'test' pairs were filtered as well.
            if "score" in data.files and mode not in self._EVAL_MODES:
                kept_mask = data["score"] > min_overlap_score
                self.data_names = self.data_names[kept_mask]

        # Materialise every array while the archive is open so the file
        # handle is released instead of lingering for the dataset's lifetime.
        with np.load(intrinsic_path) as intrinsics_file:
            self.intrinsics = dict(intrinsics_file)

        # for training LoFTR
        self.augment_fn = augment_fn if mode == "train" else None

    def __len__(self):
        """Return the number of image pairs kept for this scene."""
        return len(self.data_names)

    def _read_abs_pose(self, scene_name, name):
        """Read the pose file ``<pose_dir>/<scene_name>/pose/<name>.txt``."""
        pth = osp.join(self.pose_dir, scene_name, "pose", f"{name}.txt")
        return read_scannet_pose(pth)

    def _compute_rel_pose(self, scene_name, name0, name1):
        """Return ``pose1 @ inv(pose0)`` for the two named frames.

        The poses are whatever ``read_scannet_pose`` returns (presumably 4x4
        matrices; confirm against that helper).
        """
        pose0 = self._read_abs_pose(scene_name, name0)
        pose1 = self._read_abs_pose(scene_name, name1)
        return np.matmul(pose1, inv(pose0))  # (4, 4)

    def __getitem__(self, idx):
        """Load the ``idx``-th image pair and its geometry.

        Returns:
            dict: images, depth maps, relative poses in both directions,
            intrinsics and identifying metadata for the pair.
        """
        data_name = self.data_names[idx]
        scene_name, scene_sub_name, stem_name_0, stem_name_1 = data_name
        scene_name = f"scene{scene_name:04d}_{scene_sub_name:02d}"

        # read the grayscale image which will be resized to (1, 480, 640)
        img_name0 = osp.join(self.root_dir, scene_name, "color", f"{stem_name_0}.jpg")
        img_name1 = osp.join(self.root_dir, scene_name, "color", f"{stem_name_1}.jpg")

        # TODO: Support augmentation & handle seeds for each worker correctly.
        # Until then augmentation is disabled here: `self.augment_fn` is not
        # forwarded to the reader.
        image0 = read_scannet_gray(img_name0, resize=(640, 480), augment_fn=None)
        image1 = read_scannet_gray(img_name1, resize=(640, 480), augment_fn=None)

        # read the depthmap which is stored as (480, 640)
        if self.mode in ["train", "val"]:
            depth0 = read_scannet_depth(
                osp.join(self.root_dir, scene_name, "depth", f"{stem_name_0}.png")
            )
            depth1 = read_scannet_depth(
                osp.join(self.root_dir, scene_name, "depth", f"{stem_name_1}.png")
            )
        else:
            # No depth is loaded outside train/val; use an empty placeholder.
            depth0 = depth1 = torch.tensor([])

        # read the intrinsic of depthmap (one matrix per scene, shared by
        # both views)
        K_0 = K_1 = torch.tensor(
            self.intrinsics[scene_name].copy(), dtype=torch.float
        ).reshape(3, 3)

        # read and compute relative poses
        T_0to1 = torch.tensor(
            self._compute_rel_pose(scene_name, stem_name_0, stem_name_1),
            dtype=torch.float32,
        )
        T_1to0 = T_0to1.inverse()

        data = {
            "image0": image0,  # (1, h, w)
            "depth0": depth0,  # (h, w)
            "image1": image1,
            "depth1": depth1,
            "T_0to1": T_0to1,  # (4, 4)
            "T_1to0": T_1to0,
            "K0": K_0,  # (3, 3)
            "K1": K_1,
            "dataset_name": "ScanNet",
            "scene_id": scene_name,
            "pair_id": idx,
            "pair_names": (
                osp.join(scene_name, "color", f"{stem_name_0}.jpg"),
                osp.join(scene_name, "color", f"{stem_name_1}.jpg"),
            ),
        }
        return data