from typing import Optional

from omegaconf import DictConfig
import pytorch_lightning as L
import torch.utils.data as torchdata

from .torch import collate, worker_init_fn


def get_dataset(name):
    """Return the dataset-specific LightningDataModule class, imported lazily."""
    if name == "mapillary":
        from .mapillary.data_module import MapillaryDataModule

        return MapillaryDataModule
    elif name == "nuscenes":
        from .nuscenes.data_module import NuScenesData

        return NuScenesData
    elif name == "kitti":
        from .kitti.data_module import BEVKitti360Data

        return BEVKitti360Data
    else:
        raise NotImplementedError(f"Dataset {name} not implemented.")


class GenericDataModule(L.LightningDataModule):
    """Thin wrapper that delegates to the data module selected by ``cfg.name``."""

    def __init__(self, cfg: DictConfig):
        super().__init__()
        self.cfg = cfg
        self.data_module = get_dataset(cfg.name)(cfg)

    def prepare_data(self) -> None:
        self.data_module.prepare_data()

    def setup(self, stage: Optional[str] = None):
        self.data_module.setup(stage)

    def dataloader(
        self,
        stage: str,
        shuffle: bool = False,
        num_workers: Optional[int] = None,
        sampler: Optional[torchdata.Sampler] = None,
    ):
        dataset = self.data_module.dataset(stage)
        # Per-stage loading options (batch size, workers) come from the config.
        cfg = self.cfg["loading"][stage]
        num_workers = cfg["num_workers"] if num_workers is None else num_workers
        loader = torchdata.DataLoader(
            dataset,
            batch_size=cfg["batch_size"],
            num_workers=num_workers,
            shuffle=shuffle or (stage == "train"),
            pin_memory=True,
            persistent_workers=num_workers > 0,
            worker_init_fn=worker_init_fn,
            collate_fn=collate,
            sampler=sampler,
        )
        return loader

    def train_dataloader(self, **kwargs):
        return self.dataloader("train", **kwargs)

    def val_dataloader(self, **kwargs):
        return self.dataloader("val", **kwargs)

    def test_dataloader(self, **kwargs):
        return self.dataloader("test", **kwargs)