mueller-franzes's picture
init
f85e212
raw
history blame
2.56 kB
import pytorch_lightning as pl
import torch
from torch.utils.data.dataloader import DataLoader
import torch.multiprocessing as mp
from torch.utils.data.sampler import WeightedRandomSampler, RandomSampler
class SimpleDataModule(pl.LightningDataModule):
    """Lightning data module wrapping up to three pre-built datasets.

    The module stores the datasets and loader settings given to the
    constructor and builds a ``DataLoader`` for each split on request.
    Every loader receives a ``torch.Generator`` seeded with ``seed``.
    """

    def __init__(self,
                 ds_train: object,
                 ds_val: object = None,
                 ds_test: object = None,
                 batch_size: int = 1,
                 num_workers: int = mp.cpu_count(),
                 seed: int = 0,
                 pin_memory: bool = False,
                 weights: list = None
                 ):
        """Store the datasets and the loader configuration.

        Args:
            ds_train: Dataset used by ``train_dataloader``.
            ds_val: Dataset used by ``val_dataloader``; may be ``None``.
            ds_test: Dataset used by ``test_dataloader``; may be ``None``.
            batch_size: Batch size passed to every ``DataLoader``.
            num_workers: Worker count passed to every ``DataLoader``. The
                default is the CPU count evaluated once, when the class is
                defined.
            seed: Seed applied to the generator created for each loader.
            pin_memory: Forwarded to every ``DataLoader``.
            weights: Optional per-sample weights. When given, training
                samples are drawn with a ``WeightedRandomSampler``.
        """
        super().__init__()
        # Explicit snapshot of the constructor arguments. Listing the keys
        # avoids introspecting locals(), whose '__class__' entry only exists
        # as a side effect of the zero-argument super() call.
        self.hyperparameters = {
            'ds_train': ds_train,
            'ds_val': ds_val,
            'ds_test': ds_test,
            'batch_size': batch_size,
            'num_workers': num_workers,
            'seed': seed,
            'pin_memory': pin_memory,
            'weights': weights,
        }
        self.ds_train = ds_train
        self.ds_val = ds_val
        self.ds_test = ds_test
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.seed = seed
        self.pin_memory = pin_memory
        self.weights = weights

    def _make_generator(self):
        """Return a fresh ``torch.Generator`` seeded with ``self.seed``."""
        generator = torch.Generator()
        generator.manual_seed(self.seed)
        return generator

    def _eval_dataloader(self, dataset):
        """Build a non-shuffling loader that keeps the last partial batch."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
            generator=self._make_generator(),
            drop_last=False,
            pin_memory=self.pin_memory,
        )

    def train_dataloader(self):
        """Return the training ``DataLoader``.

        With ``weights`` set, a ``WeightedRandomSampler`` draws
        ``len(weights)`` samples per pass (using the sampler's default
        replacement setting); otherwise a ``RandomSampler`` without
        replacement is used. The last incomplete batch is dropped.
        """
        generator = self._make_generator()
        if self.weights is not None:
            sampler = WeightedRandomSampler(
                self.weights, len(self.weights), generator=generator
            )
        else:
            sampler = RandomSampler(
                self.ds_train, replacement=False, generator=generator
            )
        # The same generator object is shared by the sampler and the loader,
        # exactly as a single seeded source for this dataloader.
        return DataLoader(
            self.ds_train,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            sampler=sampler,
            generator=generator,
            drop_last=True,
            pin_memory=self.pin_memory,
        )

    def val_dataloader(self):
        """Return the validation ``DataLoader``.

        Raises:
            AssertionError: If no validation dataset was supplied.
        """
        if self.ds_val is None:
            raise AssertionError("A validation set was not initialized.")
        return self._eval_dataloader(self.ds_val)

    def test_dataloader(self):
        """Return the test ``DataLoader``.

        Raises:
            AssertionError: If no test dataset was supplied.
        """
        if self.ds_test is None:
            raise AssertionError("A test set was not initialized.")
        return self._eval_dataloader(self.ds_test)