import os
import math
from collections import abc
from loguru import logger
from torch.utils.data.dataset import Dataset
from tqdm import tqdm
from os import path as osp
from pathlib import Path
from joblib import Parallel, delayed

import pytorch_lightning as pl
from torch import distributed as dist
from torch.utils.data import (
    Dataset,
    DataLoader,
    ConcatDataset,
    DistributedSampler,
    RandomSampler,
    dataloader
)

from src.utils.augment import build_augmentor
from src.utils.dataloader import get_local_split
from src.utils.misc import tqdm_joblib
from src.utils import comm
from src.datasets.megadepth import MegaDepthDataset
from src.datasets.scannet import ScanNetDataset
from src.datasets.sampler import RandomConcatSampler


class MultiSceneDataModule(pl.LightningDataModule):
    """
    For distributed training, each training process is assigned
    only a part of the training scenes to reduce memory overhead.
    """
    def __init__(self, args, config):
        super().__init__()

        # 1. data config
        # Train and Val should be from the same data source
        self.trainval_data_source = config.DATASET.TRAINVAL_DATA_SOURCE
        self.test_data_source = config.DATASET.TEST_DATA_SOURCE
        # training and validating
        self.train_data_root = config.DATASET.TRAIN_DATA_ROOT
        self.train_pose_root = config.DATASET.TRAIN_POSE_ROOT  # (optional)
        self.train_npz_root = config.DATASET.TRAIN_NPZ_ROOT
        self.train_list_path = config.DATASET.TRAIN_LIST_PATH
        self.train_intrinsic_path = config.DATASET.TRAIN_INTRINSIC_PATH
        self.val_data_root = config.DATASET.VAL_DATA_ROOT
        self.val_pose_root = config.DATASET.VAL_POSE_ROOT  # (optional)
        self.val_npz_root = config.DATASET.VAL_NPZ_ROOT
        self.val_list_path = config.DATASET.VAL_LIST_PATH
        self.val_intrinsic_path = config.DATASET.VAL_INTRINSIC_PATH
        # testing
        self.test_data_root = config.DATASET.TEST_DATA_ROOT
        self.test_pose_root = config.DATASET.TEST_POSE_ROOT  # (optional)
        self.test_npz_root = config.DATASET.TEST_NPZ_ROOT
        self.test_list_path = config.DATASET.TEST_LIST_PATH
        self.test_intrinsic_path = config.DATASET.TEST_INTRINSIC_PATH

        # 2. dataset config
        # general options
        self.min_overlap_score_test = config.DATASET.MIN_OVERLAP_SCORE_TEST  # 0.4, omit data with overlap_score < min_overlap_score
        self.min_overlap_score_train = config.DATASET.MIN_OVERLAP_SCORE_TRAIN
        self.augment_fn = build_augmentor(config.DATASET.AUGMENTATION_TYPE)  # None, options: [None, 'dark', 'mobile']

        # MegaDepth options
        self.mgdpt_img_resize = config.DATASET.MGDPT_IMG_RESIZE  # 840
        self.mgdpt_img_pad = config.DATASET.MGDPT_IMG_PAD  # True
        self.mgdpt_depth_pad = config.DATASET.MGDPT_DEPTH_PAD  # True
        self.mgdpt_df = config.DATASET.MGDPT_DF  # 8
        self.coarse_scale = 1 / config.MODEL.RESOLUTION[0]  # 0.125. for training loftr.

        # 3. loader parameters
        self.train_loader_params = {
            'batch_size': args.batch_size,
            'num_workers': args.num_workers,
            'pin_memory': getattr(args, 'pin_memory', True)
        }
        self.val_loader_params = {
            'batch_size': 1,
            'shuffle': False,
            'num_workers': args.num_workers,
            'pin_memory': getattr(args, 'pin_memory', True)
        }
        self.test_loader_params = {
            'batch_size': 1,
            'shuffle': False,
            'num_workers': args.num_workers,
            'pin_memory': True
        }

        # 4. sampler
        self.data_sampler = config.TRAINER.DATA_SAMPLER
        self.n_samples_per_subset = config.TRAINER.N_SAMPLES_PER_SUBSET
        self.subset_replacement = config.TRAINER.SB_SUBSET_SAMPLE_REPLACEMENT
        self.shuffle = config.TRAINER.SB_SUBSET_SHUFFLE
        self.repeat = config.TRAINER.SB_REPEAT
        # (optional) RandomSampler for debugging

        # misc configurations
        self.parallel_load_data = getattr(args, 'parallel_load_data', False)
        self.seed = config.TRAINER.SEED  # 66
    def setup(self, stage=None):
        """
        Setup train / val / test dataset. This method will be called by PL automatically.
        Args:
            stage (str): 'fit' in training phase, and 'test' in testing phase.
        """
        assert stage in ['fit', 'test'], "stage must be either fit or test"

        try:
            self.world_size = dist.get_world_size()
            self.rank = dist.get_rank()
            logger.info(f"[rank:{self.rank}] world_size: {self.world_size}")
        except AssertionError as ae:
            self.world_size = 1
            self.rank = 0
            logger.warning(str(ae) + " (set world_size=1 and rank=0)")

        if stage == 'fit':
            self.train_dataset = self._setup_dataset(
                self.train_data_root,
                self.train_npz_root,
                self.train_list_path,
                self.train_intrinsic_path,
                mode='train',
                min_overlap_score=self.min_overlap_score_train,
                pose_dir=self.train_pose_root)
            # setup multiple (optional) validation subsets
            if isinstance(self.val_list_path, (list, tuple)):
                self.val_dataset = []
                if not isinstance(self.val_npz_root, (list, tuple)):
                    self.val_npz_root = [self.val_npz_root for _ in range(len(self.val_list_path))]
                for npz_list, npz_root in zip(self.val_list_path, self.val_npz_root):
                    self.val_dataset.append(self._setup_dataset(
                        self.val_data_root,
                        npz_root,
                        npz_list,
                        self.val_intrinsic_path,
                        mode='val',
                        min_overlap_score=self.min_overlap_score_test,
                        pose_dir=self.val_pose_root))
            else:
                self.val_dataset = self._setup_dataset(
                    self.val_data_root,
                    self.val_npz_root,
                    self.val_list_path,
                    self.val_intrinsic_path,
                    mode='val',
                    min_overlap_score=self.min_overlap_score_test,
                    pose_dir=self.val_pose_root)
            logger.info(f'[rank:{self.rank}] Train & Val Dataset loaded!')
        else:  # stage == 'test'
            self.test_dataset = self._setup_dataset(
                self.test_data_root,
                self.test_npz_root,
                self.test_list_path,
                self.test_intrinsic_path,
                mode='test',
                min_overlap_score=self.min_overlap_score_test,
                pose_dir=self.test_pose_root)
            logger.info(f'[rank:{self.rank}]: Test Dataset loaded!')
    def _setup_dataset(self,
                       data_root,
                       split_npz_root,
                       scene_list_path,
                       intri_path,
                       mode='train',
                       min_overlap_score=0.,
                       pose_dir=None):
        """ Setup train / val / test set"""
        with open(scene_list_path, 'r') as f:
            npz_names = [name.split()[0] for name in f.readlines()]

        if mode == 'train':
            local_npz_names = get_local_split(npz_names, self.world_size, self.rank, self.seed)
        else:
            local_npz_names = npz_names
        logger.info(f'[rank {self.rank}]: {len(local_npz_names)} scene(s) assigned.')

        dataset_builder = self._build_concat_dataset_parallel \
            if self.parallel_load_data \
            else self._build_concat_dataset
        return dataset_builder(data_root, local_npz_names, split_npz_root, intri_path,
                               mode=mode, min_overlap_score=min_overlap_score, pose_dir=pose_dir)
    def _build_concat_dataset(
        self,
        data_root,
        npz_names,
        npz_dir,
        intrinsic_path,
        mode,
        min_overlap_score=0.,
        pose_dir=None
    ):
        datasets = []
        augment_fn = self.augment_fn if mode == 'train' else None
        data_source = self.trainval_data_source if mode in ['train', 'val'] else self.test_data_source
        if str(data_source).lower() == 'megadepth':
            npz_names = [f'{n}.npz' for n in npz_names]
        for npz_name in tqdm(npz_names,
                             desc=f'[rank:{self.rank}] loading {mode} datasets',
                             disable=int(self.rank) != 0):
            # `ScanNetDataset`/`MegaDepthDataset` load all data from npz_path when initialized, which might take time.
            npz_path = osp.join(npz_dir, npz_name)
            if data_source == 'ScanNet':
                datasets.append(
                    ScanNetDataset(data_root,
                                   npz_path,
                                   intrinsic_path,
                                   mode=mode,
                                   min_overlap_score=min_overlap_score,
                                   augment_fn=augment_fn,
                                   pose_dir=pose_dir))
            elif data_source == 'MegaDepth':
                datasets.append(
                    MegaDepthDataset(data_root,
                                     npz_path,
                                     mode=mode,
                                     min_overlap_score=min_overlap_score,
                                     img_resize=self.mgdpt_img_resize,
                                     df=self.mgdpt_df,
                                     img_padding=self.mgdpt_img_pad,
                                     depth_padding=self.mgdpt_depth_pad,
                                     augment_fn=augment_fn,
                                     coarse_scale=self.coarse_scale))
            else:
                raise NotImplementedError()
        return ConcatDataset(datasets)
    def _build_concat_dataset_parallel(
        self,
        data_root,
        npz_names,
        npz_dir,
        intrinsic_path,
        mode,
        min_overlap_score=0.,
        pose_dir=None,
    ):
        augment_fn = self.augment_fn if mode == 'train' else None
        data_source = self.trainval_data_source if mode in ['train', 'val'] else self.test_data_source
        if str(data_source).lower() == 'megadepth':
            npz_names = [f'{n}.npz' for n in npz_names]
        with tqdm_joblib(tqdm(desc=f'[rank:{self.rank}] loading {mode} datasets',
                              total=len(npz_names), disable=int(self.rank) != 0)):
            if data_source == 'ScanNet':
                datasets = Parallel(n_jobs=math.floor(len(os.sched_getaffinity(0)) * 0.9 / comm.get_local_size()))(
                    delayed(lambda x: _build_dataset(
                        ScanNetDataset,
                        data_root,
                        osp.join(npz_dir, x),
                        intrinsic_path,
                        mode=mode,
                        min_overlap_score=min_overlap_score,
                        augment_fn=augment_fn,
                        pose_dir=pose_dir))(name)
                    for name in npz_names)
            elif data_source == 'MegaDepth':
                # TODO: _pickle.PicklingError: Could not pickle the task to send it to the workers.
                raise NotImplementedError()
                datasets = Parallel(n_jobs=math.floor(len(os.sched_getaffinity(0)) * 0.9 / comm.get_local_size()))(
                    delayed(lambda x: _build_dataset(
                        MegaDepthDataset,
                        data_root,
                        osp.join(npz_dir, x),
                        mode=mode,
                        min_overlap_score=min_overlap_score,
                        img_resize=self.mgdpt_img_resize,
                        df=self.mgdpt_df,
                        img_padding=self.mgdpt_img_pad,
                        depth_padding=self.mgdpt_depth_pad,
                        augment_fn=augment_fn,
                        coarse_scale=self.coarse_scale))(name)
                    for name in npz_names)
            else:
                raise ValueError(f'Unknown dataset: {data_source}')
        return ConcatDataset(datasets)
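    # Note (added for clarity, not part of the original source): the `n_jobs`
    # heuristic above splits roughly 90% of the CPUs visible to this process
    # (`os.sched_getaffinity(0)`) across the local workers reported by
    # `comm.get_local_size()`, so parallel scene loading on a multi-GPU node
    # should not oversubscribe the CPU.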
    def train_dataloader(self):
        """ Build training dataloader for ScanNet / MegaDepth. """
        assert self.data_sampler in ['scene_balance']
        logger.info(f'[rank:{self.rank}/{self.world_size}]: Train Sampler and DataLoader re-init (should not re-init between epochs!).')
        if self.data_sampler == 'scene_balance':
            sampler = RandomConcatSampler(self.train_dataset,
                                          self.n_samples_per_subset,
                                          self.subset_replacement,
                                          self.shuffle, self.repeat, self.seed)
        else:
            sampler = None
        dataloader = DataLoader(self.train_dataset, sampler=sampler, **self.train_loader_params)
        return dataloader
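    # Note (added for clarity, not part of the original source): with the
    # `scene_balance` sampler, each scene-level subset of the ConcatDataset is
    # expected to contribute `n_samples_per_subset` pairs per epoch (scaled by
    # `repeat`), so the per-rank epoch length is roughly
    # n_assigned_scenes * n_samples_per_subset, independent of how many pairs
    # each scene actually contains.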
    def val_dataloader(self):
        """ Build validation dataloader for ScanNet / MegaDepth. """
        logger.info(f'[rank:{self.rank}/{self.world_size}]: Val Sampler and DataLoader re-init.')
        if not isinstance(self.val_dataset, abc.Sequence):
            sampler = DistributedSampler(self.val_dataset, shuffle=False)
            return DataLoader(self.val_dataset, sampler=sampler, **self.val_loader_params)
        else:
            dataloaders = []
            for dataset in self.val_dataset:
                sampler = DistributedSampler(dataset, shuffle=False)
                dataloaders.append(DataLoader(dataset, sampler=sampler, **self.val_loader_params))
            return dataloaders

    def test_dataloader(self, *args, **kwargs):
        logger.info(f'[rank:{self.rank}/{self.world_size}]: Test Sampler and DataLoader re-init.')
        sampler = DistributedSampler(self.test_dataset, shuffle=False)
        return DataLoader(self.test_dataset, sampler=sampler, **self.test_loader_params)


def _build_dataset(dataset: Dataset, *args, **kwargs):
    return dataset(*args, **kwargs)
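

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module). It assumes
# a YACS-style `config` exposing the DATASET/TRAINER/MODEL keys read in
# `__init__`, an argparse-style `args` namespace, and a project-specific
# LightningModule `model`; the config loader `get_cfg_defaults` is named here
# purely as an assumption.
#
#   from argparse import Namespace
#   import pytorch_lightning as pl
#
#   args = Namespace(batch_size=1, num_workers=4, pin_memory=True,
#                    parallel_load_data=False)
#   config = get_cfg_defaults()          # hypothetical project config loader
#   dm = MultiSceneDataModule(args, config)
#   trainer = pl.Trainer(gpus=1, max_epochs=30)
#   trainer.fit(model, datamodule=dm)    # `model`: the project's LightningModule
# ---------------------------------------------------------------------------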