from typing import Optional

from omegaconf import DictConfig
import pytorch_lightning as L
import torch.utils.data as torchdata

from .torch import collate, worker_init_fn


def get_dataset(name):
    """Return the dataset-specific data-module class registered under *name*.

    Imports are deferred into each branch so that only the selected
    dataset's dependencies need to be installed.

    Args:
        name: Dataset identifier; one of ``"mapillary"``, ``"nuscenes"``,
            ``"kitti"``.

    Returns:
        The LightningDataModule subclass for the requested dataset.

    Raises:
        NotImplementedError: If *name* is not a known dataset.
    """
    if name == "mapillary":
        from .mapillary.data_module import MapillaryDataModule

        return MapillaryDataModule
    elif name == "nuscenes":
        from .nuscenes.data_module import NuScenesData

        return NuScenesData
    elif name == "kitti":
        from .kitti.data_module import BEVKitti360Data

        return BEVKitti360Data
    else:
        raise NotImplementedError(f"Dataset {name} not implemented.")


class GenericDataModule(L.LightningDataModule):
    """Thin wrapper that delegates to the data module selected by ``cfg.name``.

    Dataset preparation and setup are forwarded to the concrete module;
    DataLoaders are built here from the per-stage ``cfg["loading"]`` settings.
    """

    def __init__(self, cfg: DictConfig):
        super().__init__()
        self.cfg = cfg
        # Instantiate the concrete data module (e.g. MapillaryDataModule).
        self.data_module = get_dataset(cfg.name)(cfg)

    def prepare_data(self) -> None:
        """Forward one-time data preparation to the wrapped module."""
        self.data_module.prepare_data()

    def setup(self, stage: Optional[str] = None):
        """Forward per-stage setup to the wrapped module."""
        self.data_module.setup(stage)

    def dataloader(
        self,
        stage: str,
        shuffle: bool = False,
        # Fixed annotation: was `int = None`, which is an invalid hint.
        num_workers: Optional[int] = None,
        sampler: Optional[torchdata.Sampler] = None,
    ):
        """Build a DataLoader for *stage* (``"train"``, ``"val"`` or ``"test"``).

        Args:
            stage: Which split's dataset and loading config to use.
            shuffle: Force shuffling; the train stage always shuffles
                (unless a *sampler* is given, which defines its own order).
            num_workers: Worker-process count; defaults to the value in
                ``cfg["loading"][stage]["num_workers"]``.
            sampler: Optional custom sampler.

        Returns:
            A configured ``torch.utils.data.DataLoader``.
        """
        dataset = self.data_module.dataset(stage)
        cfg = self.cfg["loading"][stage]
        if num_workers is None:
            num_workers = cfg["num_workers"]
        # BUGFIX: DataLoader raises ValueError when shuffle=True is combined
        # with an explicit sampler, so the original crashed for the train
        # stage with a custom sampler. The sampler owns the ordering.
        if sampler is not None:
            do_shuffle = False
        else:
            do_shuffle = shuffle or (stage == "train")
        loader = torchdata.DataLoader(
            dataset,
            batch_size=cfg["batch_size"],
            num_workers=num_workers,
            shuffle=do_shuffle,
            pin_memory=True,
            # Keep workers alive between epochs, but only when there are any.
            persistent_workers=num_workers > 0,
            worker_init_fn=worker_init_fn,
            collate_fn=collate,
            sampler=sampler,
        )
        return loader

    def train_dataloader(self, **kwargs):
        """DataLoader for the training split (shuffled by default)."""
        return self.dataloader("train", **kwargs)

    def val_dataloader(self, **kwargs):
        """DataLoader for the validation split."""
        return self.dataloader("val", **kwargs)

    def test_dataloader(self, **kwargs):
        """DataLoader for the test split."""
        return self.dataloader("test", **kwargs)