import os
import os.path as osp

import cv2
import kornia.augmentation as K
import numpy as np
import torch
from PIL import Image
from torch.utils.data import ConcatDataset, DataLoader
from tqdm import tqdm

from dkm.utils import get_depth_tuple_transform_ops, get_tuple_transform_ops
from dkm.utils.transforms import GeometricSequential


class ScanNetScene:
    def __init__(
        self,
        data_root,
        scene_info,
        ht=384,
        wt=512,
        min_overlap=0.0,  # accepted for API compatibility; not used for filtering here
        shake_t=0,
        rot_prob=0.0,
    ) -> None:
        self.scene_root = osp.join(data_root, "scans", "scans_train")
        self.data_names = scene_info["name"]
        self.overlaps = scene_info["score"]
        # Only sample every 10th frame: keep pairs where both frame ids are divisible by 10
        valid = (self.data_names[:, -2:] % 10).sum(axis=-1) == 0
        self.overlaps = self.overlaps[valid]
        self.data_names = self.data_names[valid]
        # Cap each scene at 10000 randomly sampled pairs
        if len(self.data_names) > 10000:
            pairinds = np.random.choice(
                np.arange(0, len(self.data_names)), 10000, replace=False
            )
            self.data_names = self.data_names[pairinds]
            self.overlaps = self.overlaps[pairinds]
        self.im_transform_ops = get_tuple_transform_ops(
            resize=(ht, wt), normalize=True
        )
        self.depth_transform_ops = get_depth_tuple_transform_ops(
            resize=(ht, wt), normalize=False
        )
        self.wt, self.ht = wt, ht
        self.shake_t = shake_t
        self.H_generator = GeometricSequential(K.RandomAffine(degrees=90, p=rot_prob))

    def load_im(self, im_ref, crop=None):
        im = Image.open(im_ref)
        return im

    def load_depth(self, depth_ref, crop=None):
        # ScanNet stores depth as 16-bit PNGs in millimeters; convert to meters
        depth = cv2.imread(str(depth_ref), cv2.IMREAD_UNCHANGED)
        depth = depth / 1000
        depth = torch.from_numpy(depth).float()  # (h, w)
        return depth

    def __len__(self):
        return len(self.data_names)

    def scale_intrinsic(self, K, wi, hi):
        # Rescale the intrinsic matrix from the original size (wi, hi) to (wt, ht)
        sx, sy = self.wt / wi, self.ht / hi
        sK = torch.tensor([[sx, 0, 0], [0, sy, 0], [0, 0, 1]])
        return sK @ K

    def read_scannet_pose(self, path):
        """Read ScanNet's Camera2World pose and transform it to World2Camera.

        Returns:
            pose_w2c (np.ndarray): (4, 4)
        """
        cam2world = np.loadtxt(path, delimiter=" ")
        world2cam = np.linalg.inv(cam2world)
        return world2cam

    def read_scannet_intrinsic(self, path):
        """Read ScanNet's intrinsic matrix and return the 3x3 matrix."""
        intrinsic = np.loadtxt(path, delimiter=" ")
        return intrinsic[:-1, :-1]

    def __getitem__(self, pair_idx):
        # Read intrinsics of the original image size
        data_name = self.data_names[pair_idx]
        scene_name, scene_sub_name, stem_name_1, stem_name_2 = data_name
        scene_name = f"scene{scene_name:04d}_{scene_sub_name:02d}"

        # Read the intrinsics of the color images; the depth intrinsics differ
        # slightly, but that doesn't really matter here
        K1 = K2 = self.read_scannet_intrinsic(
            osp.join(self.scene_root, scene_name, "intrinsic", "intrinsic_color.txt")
        )

        # Read world-to-camera poses and compute the relative pose
        T1 = self.read_scannet_pose(
            osp.join(self.scene_root, scene_name, "pose", f"{stem_name_1}.txt")
        )
        T2 = self.read_scannet_pose(
            osp.join(self.scene_root, scene_name, "pose", f"{stem_name_2}.txt")
        )
        T_1to2 = torch.tensor(np.matmul(T2, np.linalg.inv(T1)), dtype=torch.float)[
            :4, :4
        ]  # (4, 4)

        # Load positive pair data
        im_src_ref = os.path.join(
            self.scene_root, scene_name, "color", f"{stem_name_1}.jpg"
        )
        im_pos_ref = os.path.join(
            self.scene_root, scene_name, "color", f"{stem_name_2}.jpg"
        )
        depth_src_ref = os.path.join(
            self.scene_root, scene_name, "depth", f"{stem_name_1}.png"
        )
        depth_pos_ref = os.path.join(
            self.scene_root, scene_name, "depth", f"{stem_name_2}.png"
        )
        im_src = self.load_im(im_src_ref)
        im_pos = self.load_im(im_pos_ref)
        depth_src = self.load_depth(depth_src_ref)
        depth_pos = self.load_depth(depth_pos_ref)

        # Recompute the camera intrinsic matrices to account for the resize
        K1 = self.scale_intrinsic(K1, im_src.width, im_src.height)
        K2 = self.scale_intrinsic(K2, im_pos.width, im_pos.height)

        # Process images and depth maps
        im_src, im_pos = self.im_transform_ops((im_src, im_pos))
        depth_src, depth_pos = self.depth_transform_ops(
            (depth_src[None, None], depth_pos[None, None])
        )

        data_dict = {
            "query": im_src,
            "support": im_pos,
            "query_depth": depth_src[0, 0],
            "support_depth": depth_pos[0, 0],
            "K1": K1,
            "K2": K2,
            "T_1to2": T_1to2,
        }
        return data_dict


class ScanNetBuilder:
    def __init__(self, data_root="data/scannet") -> None:
        self.data_root = data_root
        self.scene_info_root = os.path.join(data_root, "scannet_indices")
        self.all_scenes = os.listdir(self.scene_info_root)

    def build_scenes(self, split="train", min_overlap=0.0, **kwargs):
        # Note: split doesn't matter here, as we always use the same scannet_train scenes
        scene_names = self.all_scenes
        scenes = []
        for scene_name in tqdm(scene_names):
            scene_info = np.load(
                os.path.join(self.scene_info_root, scene_name), allow_pickle=True
            )
            scenes.append(
                ScanNetScene(
                    self.data_root, scene_info, min_overlap=min_overlap, **kwargs
                )
            )
        return scenes

    def weight_scenes(self, concat_dataset, alpha=0.5):
        # Weight each sample inversely by its scene size (tempered by alpha),
        # so that large scenes do not dominate the sampling distribution
        ns = [len(d) for d in concat_dataset.datasets]
        ws = torch.cat([torch.ones(n) / n**alpha for n in ns])
        return ws


if __name__ == "__main__":
    mega_test = ConcatDataset(
        ScanNetBuilder("data/scannet").build_scenes(split="train")
    )
    mega_test[0]
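
    # A minimal sketch of how the per-sample weights from weight_scenes can
    # drive a DataLoader via torch's WeightedRandomSampler, so that samples
    # are drawn with probability proportional to 1 / n**alpha of their scene.
    # This usage is an assumption, not part of the original file, and the
    # batch_size / num_samples / num_workers values below are illustrative.
    builder = ScanNetBuilder("data/scannet")
    weights = builder.weight_scenes(mega_test, alpha=0.5)
    sampler = torch.utils.data.WeightedRandomSampler(
        weights, num_samples=8000, replacement=True
    )
    loader = DataLoader(mega_test, batch_size=4, sampler=sampler, num_workers=8)
    for batch in loader:
        # Each batch holds "query"/"support" images, their depths, intrinsics,
        # and the relative pose "T_1to2", as assembled in __getitem__ above
        break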