# RSPrompter: mmpl/datasets/pl_datamodule.py
# Uploaded by KyanChen ("Upload 159 files", commit 1c3eb47); 2.83 kB.
import copy
from functools import partial

import lightning.pytorch as pl
from mmengine.registry import FUNCTIONS
from torch.utils.data import DataLoader

from mmpl.registry import DATASETS
from .builder import build_dataset
def get_collate_fn(dataloader_cfg):
    """Pop the ``collate_fn`` entry from a dataloader config and resolve it.

    The ``'collate_fn'`` key is removed from ``dataloader_cfg`` so that the
    remaining items can be passed straight to ``DataLoader(**cfg)``. When the
    key is absent, ``dict(type='pseudo_collate')`` is used.

    The ``collate_fn`` value may be:

    * a dict with a ``type`` key, which is either the name of a function
      registered in mmengine's ``FUNCTIONS`` registry or the callable
      itself; the remaining keys are bound as keyword arguments via
      ``functools.partial``;
    * a callable, returned unchanged.

    Args:
        dataloader_cfg (dict): Dataloader config. Modified in place: its
            ``'collate_fn'`` key (if present) is removed. A nested
            ``collate_fn`` dict is NOT modified.

    Returns:
        Callable: The collate function to hand to ``DataLoader``.

    Raises:
        KeyError: If ``type`` names a function that is not registered in
            ``FUNCTIONS``.
        TypeError: If the ``collate_fn`` entry is neither a dict nor a
            callable.
    """
    collate_fn_cfg = dataloader_cfg.pop('collate_fn', dict(type='pseudo_collate'))

    if isinstance(collate_fn_cfg, dict):
        # Copy before popping 'type' so the caller's nested config stays
        # intact and can be reused (e.g. when dataloaders are rebuilt).
        collate_fn_cfg = dict(collate_fn_cfg)
        collate_fn_type = collate_fn_cfg.pop('type')
        if isinstance(collate_fn_type, str):
            collate_fn = FUNCTIONS.get(collate_fn_type)
            # Registry.get returns None for unknown names; fail loudly here
            # rather than with an opaque error from partial().
            if collate_fn is None:
                raise KeyError(
                    f'collate_fn type {collate_fn_type!r} is not registered '
                    f'in the FUNCTIONS registry')
        else:
            collate_fn = collate_fn_type
        return partial(collate_fn, **collate_fn_cfg)  # type: ignore

    if callable(collate_fn_cfg):
        return collate_fn_cfg

    raise TypeError(
        f'collate_fn should be a dict or a callable, but got {collate_fn_cfg!r}')
@DATASETS.register_module()
class PLDataModule(pl.LightningDataModule):
    """Lightning data module driven by mmengine-style dataloader configs.

    Each ``*_loader`` argument is either ``None`` (the split is not used) or
    a dict of ``torch.utils.data.DataLoader`` keyword arguments plus:

    * ``dataset``: config passed to ``build_dataset`` in :meth:`setup`;
    * ``collate_fn`` (optional): resolved by ``get_collate_fn``.

    The stored loader configs are never mutated, so :meth:`setup` and the
    ``*_dataloader`` hooks can safely run more than once (for example
    ``trainer.fit`` followed by ``trainer.validate``, or dataloader reloads).

    Args:
        train_loader (dict | None): Training dataloader config.
        val_loader (dict | None): Validation dataloader config.
        test_loader (dict | None): Test dataloader config.
        predict_loader (dict | None): Prediction dataloader config.
        **kwargs: Accepted and ignored.
    """

    def __init__(self,
                 train_loader=None,
                 val_loader=None,
                 test_loader=None,
                 predict_loader=None,
                 **kwargs
                 ):
        super().__init__()
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.predict_loader = predict_loader
        # Built lazily by setup(); None until the matching stage has run.
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None
        self.predict_dataset = None

    def prepare_data(self):
        # Nothing to download or preprocess; datasets are built in setup().
        pass

    def setup(self, stage: str):
        """Build the datasets needed for ``stage``.

        Lightning passes ``'fit'``, ``'validate'``, ``'test'`` or
        ``'predict'``; ``'val'`` is also accepted for backward compatibility.
        Datasets that have already been built are reused.

        Raises:
            ValueError: If ``stage == 'fit'`` but no ``train_loader`` was
                configured.
        """
        if stage == 'fit':
            if self.train_loader is None:
                raise ValueError("train_loader must be configured to run stage 'fit'")
            if self.train_dataset is None:
                self.train_dataset = self._build_dataset(self.train_loader)
        # 'fit' also runs validation, so it needs the val dataset as well.
        if stage in ('fit', 'validate', 'val') and self.val_dataset is None:
            self.val_dataset = self._build_dataset(self.val_loader)
        if stage == 'test' and self.test_dataset is None:
            self.test_dataset = self._build_dataset(self.test_loader)
        if stage == 'predict' and self.predict_dataset is None:
            self.predict_dataset = self._build_dataset(self.predict_loader)

    @staticmethod
    def _build_dataset(loader_cfg):
        """Build ``loader_cfg['dataset']``; return None when there is no config."""
        if loader_cfg is None:
            return None
        # Deep-copy so build_dataset cannot consume the stored config.
        return build_dataset(copy.deepcopy(loader_cfg['dataset']))

    @staticmethod
    def _build_dataloader(loader_cfg, dataset):
        """Create a ``DataLoader`` over ``dataset`` from ``loader_cfg``.

        Returns an empty list when ``loader_cfg`` is None, so an unconfigured
        split yields no batches instead of crashing.
        NOTE(review): relies on Lightning skipping a stage whose hook returns
        no dataloaders; confirm against the installed Lightning version.
        """
        if loader_cfg is None:
            return []
        # Private copy: get_collate_fn pops keys from what it is given, and
        # 'dataset' is not a DataLoader keyword argument.
        kwargs = copy.deepcopy(
            {key: value for key, value in loader_cfg.items() if key != 'dataset'})
        collate_fn = get_collate_fn(kwargs)
        return DataLoader(dataset, collate_fn=collate_fn, **kwargs)

    def train_dataloader(self):
        return self._build_dataloader(self.train_loader, self.train_dataset)

    def val_dataloader(self):
        return self._build_dataloader(self.val_loader, self.val_dataset)

    def test_dataloader(self):
        return self._build_dataloader(self.test_loader, self.test_dataset)

    def predict_dataloader(self):
        return self._build_dataloader(self.predict_loader, self.predict_dataset)