Spaces:
Build error
Build error
File size: 3,676 Bytes
4409449 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 |
import pytorch_lightning as pl
from torch.utils.data import DataLoader
class BASEDataModule(pl.LightningDataModule):
    """Base LightningDataModule with lazily-built train/val/test datasets.

    Subclasses are expected to provide:
      * ``self.Dataset`` — training dataset class,
      * ``self.DatasetEval`` — evaluation dataset class,
      * ``self.cfg`` — config with ``TRAIN``/``EVAL``/``TEST`` sections
        (``SPLIT``, ``BATCH_SIZE``, ``NUM_WORKERS``),
      * ``self.hparams`` — keyword arguments forwarded to the dataset classes.
    """

    def __init__(self, collate_fn):
        """Store dataloader defaults.

        Args:
            collate_fn: collate function forwarded to every DataLoader.
        """
        super().__init__()
        self.dataloader_options = {"collate_fn": collate_fn}
        # Honored by _make_dataloader (previously set but unused: every
        # dataloader hard-coded persistent_workers=True).
        self.persistent_workers = True
        # Multimodal flag: when True, test/predict batch size is forced to 1.
        self.is_mm = False
        # Datasets are built lazily on first property access (see setup()).
        self._train_dataset = None
        self._val_dataset = None
        self._test_dataset = None

    def get_sample_set(self, overrides=None):
        """Build an eval dataset from ``self.hparams`` with optional overrides.

        Args:
            overrides: optional mapping of hparam names to replacement values.
                (Was a mutable default ``{}``; ``None`` avoids the shared-
                default pitfall without changing behavior.)
        """
        sample_params = self.hparams.copy()
        if overrides:
            sample_params.update(overrides)
        return self.DatasetEval(**sample_params)

    def _build_eval_dataset(self, split):
        """Construct a ``DatasetEval`` for *split* with ``code_path`` disabled."""
        params = self.hparams.copy()
        params['code_path'] = None
        params['split'] = split
        return self.DatasetEval(**params)

    @property
    def train_dataset(self):
        """Training dataset, created on first access."""
        if self._train_dataset is None:
            self._train_dataset = self.Dataset(split=self.cfg.TRAIN.SPLIT,
                                               **self.hparams)
        return self._train_dataset

    @property
    def val_dataset(self):
        """Validation dataset, created on first access."""
        if self._val_dataset is None:
            self._val_dataset = self._build_eval_dataset(self.cfg.EVAL.SPLIT)
        return self._val_dataset

    @property
    def test_dataset(self):
        """Test dataset, created on first access."""
        if self._test_dataset is None:
            self._test_dataset = self._build_eval_dataset(self.cfg.TEST.SPLIT)
        return self._test_dataset

    def setup(self, stage=None):
        """Touch the lazy properties so the data is loaded for *stage*."""
        if stage in (None, "fit"):
            _ = self.train_dataset
            _ = self.val_dataset
        if stage in (None, "test"):
            _ = self.test_dataset

    def _make_dataloader(self, dataset, batch_size, num_workers):
        """Wrap *dataset* in a DataLoader using the stored default options.

        shuffle is always False, matching the original behavior of every
        dataloader method (including train_dataloader).
        """
        options = self.dataloader_options.copy()
        options["batch_size"] = batch_size
        options["num_workers"] = num_workers
        return DataLoader(
            dataset,
            shuffle=False,
            persistent_workers=self.persistent_workers,
            **options,
        )

    def train_dataloader(self):
        # NOTE(review): shuffle=False preserved from the original; training
        # loops usually want shuffle=True — confirm this is intentional.
        return self._make_dataloader(
            self.train_dataset,
            batch_size=self.cfg.TRAIN.BATCH_SIZE,
            num_workers=self.cfg.TRAIN.NUM_WORKERS,
        )

    def _eval_batch_size(self):
        """Batch size for test/predict: 1 in multimodal mode, else cfg value."""
        return 1 if self.is_mm else self.cfg.TEST.BATCH_SIZE

    def predict_dataloader(self):
        return self._make_dataloader(
            self.test_dataset,
            batch_size=self._eval_batch_size(),
            num_workers=self.cfg.TEST.NUM_WORKERS,
        )

    def val_dataloader(self):
        return self._make_dataloader(
            self.val_dataset,
            batch_size=self.cfg.EVAL.BATCH_SIZE,
            num_workers=self.cfg.EVAL.NUM_WORKERS,
        )

    def test_dataloader(self):
        return self._make_dataloader(
            self.test_dataset,
            batch_size=self._eval_batch_size(),
            num_workers=self.cfg.TEST.NUM_WORKERS,
        )
|