# table-extraction/components/data_module.py
from typing import Callable, List, Optional, Union

import torch
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, Dataset


class SampleDataset(Dataset):
    """Thin Dataset wrapper that pairs inputs with targets and applies an
    optional transform to each input sample."""

    def __init__(self,
                 x: Union[List, torch.Tensor],
                 y: Union[List, torch.Tensor],
                 transforms: Optional[Callable] = None) -> None:
        super().__init__()
        self.x = x
        self.y = y
        if transforms is None:
            # Default to an identity transform; for image data this could be
            # replaced with e.g. a Resize followed by ToTensor.
            self.transforms = lambda x: x
        else:
            self.transforms = transforms

    def __len__(self):
        return len(self.x)

    def __getitem__(self, index: int):
        x = self.x[index]
        y = self.y[index]
        x = self.transforms(x)
        return x, y


class SampleDataModule(LightningDataModule):
    """LightningDataModule that splits (x, y) into train/validation subsets
    according to ``val_ratio`` and serves them through DataLoaders."""

    def __init__(self,
                 x: Union[List, torch.Tensor],
                 y: Union[List, torch.Tensor],
                 transforms: Optional[Callable] = None,
                 val_ratio: float = 0,
                 batch_size: int = 32) -> None:
        super().__init__()
        assert 0 <= val_ratio < 1
        assert isinstance(batch_size, int)
        self.x = x
        self.y = y
        self.transforms = transforms
        self.val_ratio = val_ratio
        self.batch_size = batch_size
        # Build the datasets eagerly so they are available right after
        # construction, even when the module is used outside a Trainer.
        self.prepare_data()
        self.setup()

    def prepare_data(self) -> None:
        # Nothing to download or preprocess for in-memory data.
        pass

    def setup(self, stage: Optional[str] = None) -> None:
        n_samples: int = len(self.x)
        train_size: int = n_samples - int(n_samples * self.val_ratio)
        self.train_dataset = SampleDataset(x=self.x[:train_size],
                                           y=self.y[:train_size],
                                           transforms=self.transforms)
        if train_size < n_samples:
            self.val_dataset = SampleDataset(x=self.x[train_size:],
                                             y=self.y[train_size:],
                                             transforms=self.transforms)
        else:
            # With val_ratio == 0, fall back to the last batch of training
            # samples so validation hooks still have data to run on.
            self.val_dataset = SampleDataset(x=self.x[-self.batch_size:],
                                             y=self.y[-self.batch_size:],
                                             transforms=self.transforms)

    def train_dataloader(self) -> DataLoader:
        # Consider passing shuffle=True here if training order should be randomized.
        return DataLoader(dataset=self.train_dataset,
                          batch_size=self.batch_size)

    def val_dataloader(self) -> DataLoader:
        return DataLoader(dataset=self.val_dataset, batch_size=self.batch_size)
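

# Usage sketch: a minimal, hypothetical example of driving SampleDataModule
# with random tensors. The shapes, val_ratio and batch_size below are
# assumptions chosen purely for demonstration.
if __name__ == "__main__":
    # 100 samples of 8 features each, with scalar regression-style targets.
    x = torch.randn(100, 8)
    y = torch.randn(100, 1)

    datamodule = SampleDataModule(x=x, y=y, val_ratio=0.2, batch_size=16)

    # Pull one training batch to confirm the pipeline works end to end.
    batch_x, batch_y = next(iter(datamodule.train_dataloader()))
    print(batch_x.shape, batch_y.shape)  # torch.Size([16, 8]) torch.Size([16, 1])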