# AnySplat/src/dataset/data_module.py
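"""Data-loading utilities for AnySplat: a Lightning DataModule that builds
train/val/test dataloaders over one or more datasets, mixing them with
MixedBatchSampler according to per-dataset sampling weights."""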
import random
from dataclasses import dataclass
from typing import Callable
import numpy as np
import torch
from lightning.pytorch import LightningDataModule
from torch import Generator, nn
from torch.utils.data import DataLoader, Dataset, IterableDataset
from src.dataset import *
from src.global_cfg import get_cfg
from ..misc.step_tracker import StepTracker
from ..misc.utils import get_world_size, get_rank
from . import DatasetCfgWrapper, get_dataset
from .types import DataShim, Stage
from .data_sampler import BatchedRandomSampler, MixedBatchSampler, custom_collate_fn
from .validation_wrapper import ValidationWrapper
def get_data_shim(encoder: nn.Module) -> DataShim:
"""Get functions that modify the batch. It's sometimes necessary to modify batches
outside the data loader because GPU computations are required to modify the batch or
because the modification depends on something outside the data loader.
"""
shims: list[DataShim] = []
if hasattr(encoder, "get_data_shim"):
shims.append(encoder.get_data_shim())
def combined_shim(batch):
for shim in shims:
batch = shim(batch)
return batch
return combined_shim
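# Usage sketch: shim = get_data_shim(encoder); batch = shim(batch)
# (applied to each batch after it leaves the data loader, typically in the training step).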
# Relative sampling weights for the training datasets (example values; the
# MixedBatchSampler is expected to handle normalization).
prob_mapping = {
    DatasetScannetpp: 0.5,
    DatasetDL3DV: 0.5,
    DatasetCo3d: 0.5,
}
@dataclass
class DataLoaderStageCfg:
batch_size: int
num_workers: int
persistent_workers: bool
seed: int | None
@dataclass
class DataLoaderCfg:
train: DataLoaderStageCfg
test: DataLoaderStageCfg
val: DataLoaderStageCfg
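# A DatasetShim can wrap or otherwise modify a dataset for a given stage before it
# is handed to a DataLoader (the default shim used by DataModule is the identity).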
DatasetShim = Callable[[Dataset, Stage], Dataset]
def worker_init_fn(worker_id: int) -> None:
    # Seed Python and NumPy RNGs from the per-worker torch seed so that each
    # dataloader worker produces a different but reproducible random stream.
    worker_seed = int(torch.utils.data.get_worker_info().seed) % (2**32 - 1)
    random.seed(worker_seed)
    np.random.seed(worker_seed)
class DataModule(LightningDataModule):
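    """Wires the configured datasets into Lightning train/val/test dataloaders."""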
dataset_cfgs: list[DatasetCfgWrapper]
data_loader_cfg: DataLoaderCfg
step_tracker: StepTracker | None
dataset_shim: DatasetShim
global_rank: int
def __init__(
self,
dataset_cfgs: list[DatasetCfgWrapper],
data_loader_cfg: DataLoaderCfg,
step_tracker: StepTracker | None = None,
dataset_shim: DatasetShim = lambda dataset, _: dataset,
global_rank: int = 0,
) -> None:
super().__init__()
self.dataset_cfgs = dataset_cfgs
self.data_loader_cfg = data_loader_cfg
self.step_tracker = step_tracker
self.dataset_shim = dataset_shim
self.global_rank = global_rank
    def get_persistent(self, loader_cfg: DataLoaderStageCfg) -> bool | None:
        # persistent_workers is only valid when num_workers > 0.
        return None if loader_cfg.num_workers == 0 else loader_cfg.persistent_workers
    def get_generator(self, loader_cfg: DataLoaderStageCfg) -> torch.Generator | None:
        if loader_cfg.seed is None:
            return None
        # Offset the seed by the global rank so each process draws a different
        # (but reproducible) sampling order.
        generator = Generator()
        generator.manual_seed(loader_cfg.seed + self.global_rank)
        return generator
def train_dataloader(self):
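        """Build the training DataLoader, mixing all configured datasets with
        MixedBatchSampler according to prob_mapping."""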
dataset, datasets_ls = get_dataset(self.dataset_cfgs, "train", self.step_tracker, self.dataset_shim)
world_size = get_world_size()
rank = get_rank()
        prob_ls = [prob_mapping[type(ds)] for ds in datasets_ls]
        # We assume all datasets share the same num_context_views setting.
        if len(datasets_ls) > 1:
            prob = prob_ls
            context_num_views = [ds.cfg.view_sampler.num_context_views for ds in datasets_ls]
else:
prob = None
dataset_key = next(iter(get_cfg()["dataset"]))
dataset_cfg = get_cfg()["dataset"][dataset_key]
context_num_views = dataset_cfg['view_sampler']['num_context_views']
        generator = self.get_generator(self.data_loader_cfg.train)
        sampler = MixedBatchSampler(
            datasets_ls,
            batch_size=self.data_loader_cfg.train.batch_size,  # Not used here!
            num_context_views=context_num_views,
            world_size=world_size,
            rank=rank,
            prob=prob,
            generator=generator,
        )
sampler.set_epoch(0)
        self.train_loader = DataLoader(
            dataset,
            # self.data_loader_cfg.train.batch_size,
            # shuffle=not isinstance(dataset, IterableDataset),
            batch_sampler=sampler,
            num_workers=self.data_loader_cfg.train.num_workers,
            generator=generator,
            worker_init_fn=worker_init_fn,
            # collate_fn=custom_collate_fn,
            persistent_workers=self.get_persistent(self.data_loader_cfg.train),
        )
# Set epoch for train and validation loaders (if applicable)
if hasattr(self.train_loader, "dataset") and hasattr(self.train_loader.dataset, "set_epoch"):
print("Training: Set Epoch in DataModule")
self.train_loader.dataset.set_epoch(0)
if hasattr(self.train_loader, "sampler") and hasattr(self.train_loader.sampler, "set_epoch"):
print("Training: Set Epoch in DataModule")
self.train_loader.sampler.set_epoch(0)
return self.train_loader
def val_dataloader(self):
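        """Build the validation DataLoader, mirroring the training mixed-sampler setup."""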
dataset, datasets_ls = get_dataset(self.dataset_cfgs, "val", self.step_tracker, self.dataset_shim)
world_size = get_world_size()
rank = get_rank()
        # Validation draws batches across the available datasets with equal weights.
        dataset_key = next(iter(get_cfg()["dataset"]))
        dataset_cfg = get_cfg()["dataset"][dataset_key]
if len(datasets_ls) > 1:
prob = [0.5] * len(datasets_ls)
else:
prob = None
        generator = self.get_generator(self.data_loader_cfg.val)
        sampler = MixedBatchSampler(
            datasets_ls,
            batch_size=self.data_loader_cfg.val.batch_size,
            num_context_views=dataset_cfg['view_sampler']['num_context_views'],
            world_size=world_size,
            rank=rank,
            prob=prob,
            generator=generator,
        )
        sampler.set_epoch(0)
        self.val_loader = DataLoader(
            dataset,
            # MixedBatchSampler yields whole batches of indices (it is used as
            # batch_sampler in train_dataloader), so pass it as batch_sampler here
            # as well; batch_size cannot be set alongside it.
            batch_sampler=sampler,
            num_workers=self.data_loader_cfg.val.num_workers,
            generator=generator,
            worker_init_fn=worker_init_fn,
            persistent_workers=self.get_persistent(self.data_loader_cfg.val),
        )
if hasattr(self.val_loader, "dataset") and hasattr(self.val_loader.dataset, "set_epoch"):
print("Validation: Set Epoch in DataModule")
self.val_loader.dataset.set_epoch(0)
if hasattr(self.val_loader, "sampler") and hasattr(self.val_loader.sampler, "set_epoch"):
print("Validation: Set Epoch in DataModule")
self.val_loader.sampler.set_epoch(0)
return self.val_loader
def test_dataloader(self):
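        """Build the test DataLoader (sequential, no mixed sampler)."""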
        # get_dataset returns (dataset, datasets_ls), as in train/val; only the
        # combined dataset is needed here.
        dataset, _ = get_dataset(self.dataset_cfgs, "test", self.step_tracker, self.dataset_shim)
data_loader = DataLoader(
dataset,
self.data_loader_cfg.test.batch_size,
num_workers=self.data_loader_cfg.test.num_workers,
generator=self.get_generator(self.data_loader_cfg.test),
worker_init_fn=worker_init_fn,
persistent_workers=self.get_persistent(self.data_loader_cfg.test),
)
return data_loader
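# Usage sketch (assumption: dataset_cfgs and loader_cfg are built elsewhere, e.g. by
# the training entry point from the global config; names below are illustrative):
#
#     data_module = DataModule(
#         dataset_cfgs=dataset_cfgs,      # list[DatasetCfgWrapper]
#         data_loader_cfg=loader_cfg,     # DataLoaderCfg
#         step_tracker=step_tracker,
#         global_rank=get_rank(),
#     )
#     trainer.fit(model, datamodule=data_module)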