# Source snapshot: revision 4c41a36 (2,710 bytes).
from typing import Callable, List, Optional, Union
import torch
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, Dataset
class SampleDataset(Dataset):
    """Map-style dataset pairing input samples ``x`` with targets ``y``.

    Args:
        x: Input samples; anything supporting ``len()`` and integer indexing
            (e.g. a list, or a tensor whose first dimension indexes samples).
        y: Targets aligned element-wise with ``x``.
        transforms: Optional callable applied to each input sample in
            ``__getitem__``. When ``None``, samples are returned unchanged.

    Raises:
        ValueError: If ``x`` and ``y`` do not have the same length.
    """

    def __init__(self,
                 x: Union[List, torch.Tensor],
                 y: Union[List, torch.Tensor],
                 transforms: Optional[Callable] = None) -> None:
        super().__init__()
        # Fail fast: __len__ reports len(x), so a shorter y would otherwise
        # surface as an IndexError mid-epoch (or a longer y silently loses labels).
        if len(x) != len(y):
            raise ValueError(
                f"x and y must have the same length, got {len(x)} and {len(y)}")
        self.x = x
        self.y = y
        # The default is a named, importable function rather than a lambda so the
        # dataset stays picklable (needed e.g. when DataLoader workers are started
        # with the "spawn" method).
        # TODO: replace the identity with real defaults if appropriate,
        # e.g. Resize + ToTensor for images.
        if transforms is None:
            self.transforms = SampleDataset._identity
        else:
            self.transforms = transforms

    @staticmethod
    def _identity(sample):
        """Return *sample* unchanged (the default transform)."""
        return sample

    def __len__(self) -> int:
        return len(self.x)

    def __getitem__(self, index: int):
        """Return ``(transforms(x[index]), y[index])``; only the input is transformed."""
        sample = self.x[index]
        target = self.y[index]
        return self.transforms(sample), target
class SampleDataModule(LightningDataModule):
    """Lightning data module splitting in-memory ``(x, y)`` data into train/val sets.

    The last ``int(len(x) * val_ratio)`` samples form the validation split and the
    remaining leading samples form the training split; data is not shuffled before
    splitting. Both splits are wrapped in :class:`SampleDataset`.

    Args:
        x: Input samples; must support ``len()`` and slicing (list or tensor).
        y: Targets aligned with ``x``.
        transforms: Optional callable forwarded to :class:`SampleDataset` and
            applied to each input sample.
        val_ratio: Fraction of samples held out for validation, in ``[0, 1)``.
        batch_size: Batch size for both dataloaders; must be a positive ``int``.

    Raises:
        ValueError: If ``val_ratio`` is not in ``[0, 1)`` or ``batch_size`` < 1.
        TypeError: If ``batch_size`` is not an ``int``.
    """

    def __init__(self,
                 x: Union[List, torch.Tensor],
                 y: Union[List, torch.Tensor],
                 transforms: Optional[Callable] = None,
                 val_ratio: float = 0,
                 batch_size: int = 32) -> None:
        super().__init__()
        # Explicit checks instead of ``assert`` so validation survives ``python -O``.
        if not 0 <= val_ratio < 1:
            raise ValueError(f"val_ratio must be in [0, 1), got {val_ratio!r}")
        if not isinstance(batch_size, int):
            raise TypeError(
                f"batch_size must be an int, got {type(batch_size).__name__}")
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        self.x = x
        self.y = y
        self.transforms = transforms
        self.val_ratio = val_ratio
        self.batch_size = batch_size
        # Build the splits eagerly so ``train_dataset``/``val_dataset`` exist right
        # after construction; the Trainer may call ``setup`` again, which simply
        # rebuilds the same splits.
        self.setup()

    def setup(self, stage: Optional[str] = None) -> None:
        """Create ``train_dataset`` and ``val_dataset`` from ``x``/``y``.

        Dataset state is assigned here rather than in :meth:`prepare_data`,
        because Lightning may run ``prepare_data`` on a single process only, which
        would leave the datasets undefined on the other processes.

        Args:
            stage: Lightning stage name; ignored, the same splits serve every stage.
        """
        n_samples = len(self.x)
        n_val = int(n_samples * self.val_ratio)  # truncates toward zero
        train_size = n_samples - n_val
        self.train_dataset = SampleDataset(x=self.x[:train_size],
                                           y=self.y[:train_size],
                                           transforms=self.transforms)
        if n_val > 0:
            val_x, val_y = self.x[train_size:], self.y[train_size:]
        else:
            # No held-out samples (val_ratio == 0, or too small to select any):
            # reuse the last ``batch_size`` training samples so the validation loop
            # still has data. These overlap the training set, so validation metrics
            # here are not a held-out estimate.
            val_x, val_y = self.x[-self.batch_size:], self.y[-self.batch_size:]
        self.val_dataset = SampleDataset(x=val_x,
                                         y=val_y,
                                         transforms=self.transforms)

    def prepare_data(self) -> None:
        """No-op: the data is already in memory, so there is nothing to download.

        Kept free of state assignment; dataset construction lives in :meth:`setup`.
        """

    def train_dataloader(self) -> DataLoader:
        """Return a DataLoader over the training split (no shuffling)."""
        return DataLoader(dataset=self.train_dataset,
                          batch_size=self.batch_size)

    def val_dataloader(self) -> DataLoader:
        """Return a DataLoader over the validation split."""
        return DataLoader(dataset=self.val_dataset, batch_size=self.batch_size)
# End of file.