Spaces:
Runtime error
Runtime error
from mmpl.registry import DATASETS | |
import lightning.pytorch as pl | |
from torch.utils.data import DataLoader | |
from .builder import build_dataset | |
from mmengine.registry import FUNCTIONS | |
from functools import partial | |
def get_collate_fn(dataloader_cfg):
    """Resolve the ``collate_fn`` entry of a dataloader config into a callable.

    The ``collate_fn`` key is popped from ``dataloader_cfg``, so the remaining
    entries can be forwarded to ``torch.utils.data.DataLoader`` as keyword
    arguments. When the key is absent, ``dict(type='pseudo_collate')`` is used.

    Supported forms of the ``collate_fn`` entry:

    * a dict whose ``type`` is a name registered in ``FUNCTIONS``. The other
      keys are bound as keyword arguments with ``functools.partial``.
    * a dict whose ``type`` is itself a callable, bound the same way.
    * a bare callable, which is returned unchanged.

    Args:
        dataloader_cfg: Mutable mapping of dataloader settings. Only its
            ``collate_fn`` key is removed. The nested collate config is
            copied before use, so it is left intact.

    Returns:
        The collate callable.

    Raises:
        KeyError: If the collate config has no ``type`` key, or names a
            function that is not registered in ``FUNCTIONS``.
        TypeError: If the ``collate_fn`` entry is neither a dict nor a
            callable.
    """
    collate_fn_cfg = dataloader_cfg.pop('collate_fn', dict(type='pseudo_collate'))

    if isinstance(collate_fn_cfg, dict):
        # Work on a copy: popping 'type' from the caller's nested config would
        # break any later reuse of the same config (KeyError: 'type').
        collate_fn_cfg = dict(collate_fn_cfg)
        collate_fn_type = collate_fn_cfg.pop('type')
        if isinstance(collate_fn_type, str):
            collate_fn = FUNCTIONS.get(collate_fn_type)
            if collate_fn is None:
                # Registry.get returns None for unknown keys; fail with a clear
                # message instead of partial()'s "must be callable" error.
                raise KeyError(
                    f'collate_fn type {collate_fn_type!r} is not registered '
                    f'in the FUNCTIONS registry')
        else:
            collate_fn = collate_fn_type
        return partial(collate_fn, **collate_fn_cfg)  # type: ignore

    if callable(collate_fn_cfg):
        return collate_fn_cfg

    raise TypeError(
        'collate_fn should be a dict or callable, '
        f'but got {type(collate_fn_cfg).__name__}')
class PLDataModule(pl.LightningDataModule):
    """LightningDataModule that builds datasets and DataLoaders from config dicts.

    Each ``*_loader`` argument is either ``None`` or a dict of
    ``torch.utils.data.DataLoader`` keyword arguments with two extra keys:

    * ``dataset``: config handed to ``build_dataset`` in :meth:`setup`.
    * ``collate_fn`` (optional): config resolved by ``get_collate_fn``.

    This class never pops or modifies entries of the stored loader configs,
    so :meth:`setup` and the ``*_dataloader`` hooks may run more than once.
    Examples are ``trainer.fit`` followed by ``trainer.validate``, or
    dataloaders reloaded every N epochs.
    """

    def __init__(self,
                 train_loader=None,
                 val_loader=None,
                 test_loader=None,
                 predict_loader=None,
                 **kwargs
                 ):
        """Store the loader configs; datasets are built later in :meth:`setup`.

        Extra keyword arguments are accepted and ignored.
        """
        super().__init__()
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.predict_loader = predict_loader
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None
        self.predict_dataset = None

    def prepare_data(self):
        """No data preparation is performed by this module."""
        pass

    @staticmethod
    def _dataset_from_cfg(loader_cfg, name):
        """Build the dataset described by ``loader_cfg['dataset']``.

        The loader config is read, not popped, so it stays reusable.

        Raises:
            ValueError: If ``loader_cfg`` is ``None``.
        """
        if loader_cfg is None:
            raise ValueError(f"'{name}' is not configured, cannot build its dataset")
        return build_dataset(loader_cfg['dataset'])

    def setup(self, stage: str):
        """Build the dataset(s) required for ``stage``.

        Args:
            stage: Lightning stage name: ``"fit"``, ``"validate"``, ``"test"``
                or ``"predict"``. ``"val"`` is also accepted as an alias of
                ``"validate"`` for backward compatibility.
        """
        if stage == "fit":
            self.train_dataset = self._dataset_from_cfg(self.train_loader, 'train_loader')
        # Lightning calls setup("validate") for trainer.validate(). "fit" also
        # builds the validation set, which is used for validation during training.
        if stage in ("fit", "validate", "val") and self.val_loader is not None:
            self.val_dataset = self._dataset_from_cfg(self.val_loader, 'val_loader')
        if stage == "test" and self.test_loader is not None:
            self.test_dataset = self._dataset_from_cfg(self.test_loader, 'test_loader')
        if stage == "predict" and self.predict_loader is not None:
            self.predict_dataset = self._dataset_from_cfg(self.predict_loader, 'predict_loader')

    @staticmethod
    def _make_dataloader(loader_cfg, dataset, name):
        """Create a DataLoader for ``dataset`` from a copy of ``loader_cfg``.

        The ``dataset`` and ``collate_fn`` entries are stripped from the copy.
        Every other entry is forwarded to ``DataLoader`` as a keyword argument.

        Raises:
            ValueError: If ``loader_cfg`` is ``None``.
        """
        if loader_cfg is None:
            raise ValueError(f"'{name}' is not configured, cannot create its DataLoader")
        loader_kwargs = {key: value for key, value in loader_cfg.items() if key != 'dataset'}
        if isinstance(loader_kwargs.get('collate_fn'), dict):
            # get_collate_fn may pop keys from the nested collate config, so
            # hand it a copy and keep the stored config intact.
            loader_kwargs['collate_fn'] = dict(loader_kwargs['collate_fn'])
        # get_collate_fn pops 'collate_fn' from loader_kwargs, leaving only
        # DataLoader keyword arguments behind.
        collate_fn = get_collate_fn(loader_kwargs)
        return DataLoader(dataset, collate_fn=collate_fn, **loader_kwargs)

    def train_dataloader(self):
        """Return the training DataLoader built from ``train_loader``."""
        return self._make_dataloader(self.train_loader, self.train_dataset, 'train_loader')

    def val_dataloader(self):
        """Return the validation DataLoader built from ``val_loader``."""
        return self._make_dataloader(self.val_loader, self.val_dataset, 'val_loader')

    def test_dataloader(self):
        """Return the test DataLoader built from ``test_loader``."""
        return self._make_dataloader(self.test_loader, self.test_dataset, 'test_loader')

    def predict_dataloader(self):
        """Return the prediction DataLoader built from ``predict_loader``."""
        return self._make_dataloader(self.predict_loader, self.predict_dataset, 'predict_loader')