File size: 2,831 Bytes
1c3eb47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
from mmpl.registry import DATASETS
import lightning.pytorch as pl
from torch.utils.data import DataLoader
from .builder import build_dataset
from mmengine.registry import FUNCTIONS
from functools import partial


def get_collate_fn(dataloader_cfg):
    """Resolve the collate function described by a dataloader config.

    The ``'collate_fn'`` entry is removed from ``dataloader_cfg`` (callers
    rely on this so the remaining items can be splatted into ``DataLoader``
    without passing ``collate_fn`` twice). When the entry is absent,
    ``dict(type='pseudo_collate')`` is used.

    Args:
        dataloader_cfg: Mutable mapping of dataloader options. Its optional
            ``'collate_fn'`` value is a mapping with a ``'type'`` key naming
            a function registered in ``FUNCTIONS``; any other items are
            bound to that function as keyword arguments.

    Returns:
        A ``functools.partial`` wrapping the registered function with the
        configured keyword arguments.

    Raises:
        KeyError: If ``'type'`` is missing from the collate config, or if
            the named function is not registered in ``FUNCTIONS``.
    """
    collate_fn_cfg = dataloader_cfg.pop('collate_fn', dict(type='pseudo_collate'))
    # Work on a shallow copy so popping 'type' does not strip it from the
    # nested config object the caller may still hold (and reuse later).
    collate_fn_cfg = collate_fn_cfg.copy()
    collate_fn_type = collate_fn_cfg.pop('type')
    collate_fn = FUNCTIONS.get(collate_fn_type)
    if collate_fn is None:
        # Fail here with a readable message instead of letting partial()
        # raise an opaque "the first argument must be callable".
        raise KeyError(
            f'{collate_fn_type} is not registered in the FUNCTIONS registry')
    return partial(collate_fn, **collate_fn_cfg)  # type: ignore


@DATASETS.register_module()
class PLDataModule(pl.LightningDataModule):
    """Lightning data module built from per-split dataloader configs.

    Each ``*_loader`` argument is a mapping that holds a ``'dataset'`` entry
    (handed to ``build_dataset``), an optional ``'collate_fn'`` entry
    (resolved by ``get_collate_fn``) and any further items, which are passed
    to ``torch.utils.data.DataLoader`` as keyword arguments.

    The config mappings are never modified: datasets are built at most once
    per split and every dataloader is created from a fresh copy of its
    config, so ``setup`` and the ``*_dataloader`` hooks can be called
    repeatedly.

    Args:
        train_loader: Dataloader config for training, or ``None``.
        val_loader: Dataloader config for validation, or ``None``.
        test_loader: Dataloader config for testing, or ``None``.
        predict_loader: Dataloader config for prediction, or ``None``.
        **kwargs: Accepted and ignored.
    """

    # Splits whose datasets are needed for each ``setup`` stage. Lightning
    # passes 'validate'; 'val' is kept for callers that used the old name.
    _STAGE_SPLITS = {
        'fit': ('train', 'val'),
        'val': ('val',),
        'validate': ('val',),
        'test': ('test',),
        'predict': ('predict',),
    }

    def __init__(self,
                 train_loader=None,
                 val_loader=None,
                 test_loader=None,
                 predict_loader=None,
                 **kwargs
                 ):
        super().__init__()
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.predict_loader = predict_loader
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None
        self.predict_dataset = None

    def prepare_data(self):
        """No-op: this module performs no download or one-off preparation."""
        pass

    def _ensure_dataset(self, split):
        """Build the dataset for ``split`` once and return it.

        Reads ``'dataset'`` from the split's loader config without removing
        it, so the config stays intact for later calls.

        Raises:
            ValueError: If no loader config was given for ``split``.
            KeyError: If the loader config has no ``'dataset'`` entry.
        """
        dataset = getattr(self, f'{split}_dataset')
        if dataset is not None:
            return dataset
        loader_cfg = getattr(self, f'{split}_loader')
        if loader_cfg is None:
            raise ValueError(
                f"'{split}_loader' must be configured to use the "
                f"'{split}' split")
        dataset = build_dataset(loader_cfg['dataset'])
        setattr(self, f'{split}_dataset', dataset)
        return dataset

    def setup(self, stage: str):
        """Build the datasets required for ``stage``.

        Args:
            stage: One of ``'fit'``, ``'validate'`` (or ``'val'``),
                ``'test'`` or ``'predict'``. Other values build nothing.

        Raises:
            ValueError: If ``stage`` is ``'fit'`` and no ``train_loader``
                was configured.
        """
        for split in self._STAGE_SPLITS.get(stage, ()):
            # Only the training split is mandatory (for 'fit'); the other
            # splits are skipped when no loader config was given.
            if split != 'train' and getattr(self, f'{split}_loader') is None:
                continue
            self._ensure_dataset(split)

    def _build_dataloader(self, split):
        """Create a ``DataLoader`` for ``split`` from a copy of its config.

        The dataset is built on demand if ``setup`` has not produced it yet.

        Raises:
            ValueError: If no loader config was given for ``split``.
        """
        dataset = self._ensure_dataset(split)
        loader_kwargs = getattr(self, f'{split}_loader').copy()
        loader_kwargs.pop('dataset', None)
        if 'collate_fn' in loader_kwargs:
            # get_collate_fn pops from the nested collate config as well, so
            # give it a copy to keep the stored config reusable.
            loader_kwargs['collate_fn'] = loader_kwargs['collate_fn'].copy()
        # get_collate_fn removes 'collate_fn' from loader_kwargs, leaving
        # only DataLoader keyword arguments behind.
        collate_fn = get_collate_fn(loader_kwargs)
        return DataLoader(dataset, collate_fn=collate_fn, **loader_kwargs)

    def train_dataloader(self):
        """Return the training ``DataLoader``."""
        return self._build_dataloader('train')

    def val_dataloader(self):
        """Return the validation ``DataLoader``."""
        return self._build_dataloader('val')

    def test_dataloader(self):
        """Return the test ``DataLoader``."""
        return self._build_dataloader('test')

    def predict_dataloader(self):
        """Return the prediction ``DataLoader``."""
        return self._build_dataloader('predict')