| from dataclasses import fields |
| from typing import Callable |
| from torch.utils.data import Dataset, ConcatDataset |
| import bisect |
|
|
| from ..misc.step_tracker import StepTracker |
| from .types import Stage |
| from .view_sampler import get_view_sampler |
| from .dataset_dl3dv import DatasetDL3DV, DatasetDL3DVCfgWrapper |
| from .dataset_scannetpp import DatasetScannetpp, DatasetScannetppCfgWrapper |
| from .dataset_co3d import DatasetCo3d, DatasetCo3dCfgWrapper |
|
|
# Registry of dataset classes, looked up by a dataset config's ``name`` field
# (see ``get_dataset``). Values are classes, not instances: they are called as
# ``DATASETS[cfg.name](cfg, stage, view_sampler)``.
DATASETS: dict[str, type[Dataset]] = {
    "co3d": DatasetCo3d,
    "scannetpp": DatasetScannetpp,
    "dl3dv": DatasetDL3DV,
}


# Union of the single-field config wrapper dataclasses accepted by ``get_dataset``.
DatasetCfgWrapper = DatasetDL3DVCfgWrapper | DatasetScannetppCfgWrapper | DatasetCo3dCfgWrapper
|
|
class TestDatasetWarpper(Dataset):
    """Expose a tuple-keyed dataset through plain integer indexing.

    Each access ``wrapper[idx]`` is forwarded to the wrapped dataset as the
    3-tuple ``(idx, view_sampler.num_context_views,
    cfg.input_image_shape[1] // 14)``. The last two entries are read from the
    wrapped dataset on every access. ``len()`` is delegated unchanged.

    NOTE(review): the ``// 14`` divisor is hard-coded. It presumably matches a
    model patch size; confirm before changing input image shapes.
    """

    def __init__(self, dataset: Dataset):
        # The dataset being adapted. It must expose ``view_sampler`` and ``cfg``.
        self.dataset = dataset

    def __getitem__(self, idx):
        source = self.dataset
        n_context = source.view_sampler.num_context_views
        reduced_dim = source.cfg.input_image_shape[1] // 14
        key = (idx, n_context, reduced_dim)
        return source[key]

    def __len__(self):
        return len(self.dataset)
|
|
| |
| |
class CustomConcatDataset(ConcatDataset):
    """``ConcatDataset`` addressed by a tuple key instead of a bare index.

    The key is ``(global_index, a, b)``. The global index is mapped to a member
    dataset and a local index via ``cumulative_sizes``. That member is then
    indexed with ``(local_index, a, b)``. Negative global indices count from
    the end, as in ``ConcatDataset``.
    """

    def __getitem__(self, idx_tuple):
        # When handed a list, only its first element is used as the key.
        # Presumably this is a batch of keys from a sampler; confirm with caller.
        key = idx_tuple[0] if isinstance(idx_tuple, list) else idx_tuple

        flat_idx = key[0]
        if flat_idx < 0:
            if -flat_idx > len(self):
                raise ValueError("absolute value of index should not exceed dataset length")
            flat_idx += len(self)

        # First cumulative boundary strictly greater than flat_idx -> owning dataset.
        which = bisect.bisect_right(self.cumulative_sizes, flat_idx)
        offset = self.cumulative_sizes[which - 1] if which > 0 else 0
        local_idx = flat_idx - offset
        return self.datasets[which][(local_idx, key[1], key[2])]
|
|
|
|
def _build_dataset(
    cfg_wrapper: DatasetCfgWrapper,
    stage: Stage,
    step_tracker: StepTracker | None,
    dataset_shim: Callable[[Dataset, str], Dataset],
) -> Dataset:
    """Unwrap one config wrapper, build its view sampler and dataset, then apply the shim."""
    # Tuple-unpacking raises ValueError unless the wrapper has exactly one field.
    (field,) = fields(type(cfg_wrapper))
    cfg = getattr(cfg_wrapper, field.name)

    view_sampler = get_view_sampler(
        cfg.view_sampler,
        stage,
        cfg.overfit_to_scene is not None,
        cfg.cameras_are_circular,
        step_tracker,
    )
    dataset = DATASETS[cfg.name](cfg, stage, view_sampler)
    return dataset_shim(dataset, stage)


def get_dataset(
    cfgs: list[DatasetCfgWrapper],
    stage: Stage,
    step_tracker: StepTracker | None,
    dataset_shim: Callable[[Dataset, str], Dataset],
) -> tuple[Dataset, list[Dataset]] | Dataset:
    """Build the dataset(s) for ``stage`` from a list of config wrappers.

    Each entry of ``cfgs`` is a dataclass with a single field holding the actual
    dataset config. That config selects the dataset class via
    ``DATASETS[cfg.name]`` and is paired with a view sampler from
    ``get_view_sampler``. Every constructed dataset is passed through
    ``dataset_shim(dataset, stage)``.

    Args:
        cfgs: Dataset config wrappers. For ``"val"`` only the first is used.
            For ``"test"`` exactly one is required.
        stage: Pipeline stage. ``"test"`` takes the single-dataset path; every
            other value takes the concatenated path.
        step_tracker: Forwarded to ``get_view_sampler``.
        dataset_shim: Post-processing hook applied to each constructed dataset.

    Returns:
        For ``"test"``: a ``TestDatasetWarpper`` around the single dataset.
        Otherwise: ``(CustomConcatDataset(datasets), datasets)``.

    Raises:
        ValueError: if ``stage == "test"`` and ``cfgs`` does not hold exactly
            one entry, or if a config wrapper does not have exactly one field.
    """
    if stage == "test":
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if len(cfgs) != 1:
            raise ValueError(
                f"Stage 'test' expects exactly one dataset config, got {len(cfgs)}"
            )
        dataset = _build_dataset(cfgs[0], stage, step_tracker, dataset_shim)
        return TestDatasetWarpper(dataset)

    # Every non-"test" stage goes through the concatenation path. The previous
    # trailing ``else`` branch was unreachable and never raised its
    # NotImplementedError, so it has been removed without changing behavior.
    if stage == "val":
        # Validation only uses the first configured dataset.
        cfgs = [cfgs[0]]
    datasets = [
        _build_dataset(cfg, stage, step_tracker, dataset_shim) for cfg in cfgs
    ]
    return CustomConcatDataset(datasets), datasets
|
|