import os

import numpy as np
import pytorch_lightning as pl
from torch.utils.data import Dataset, DataLoader, Subset

import robust_detection.transforms as T
from robust_detection.data_utils.rcnn_data_utils import *

# Data paths are resolved relative to this file's directory.
DATA_FOLDER = os.path.dirname(__file__)
def get_transform():
    # Minimal transform pipeline: only converts PIL images to tensors.
    transforms = []
    transforms.append(T.ToTensor())
    return T.Compose(transforms)

class Objects_Smiles(pl.LightningDataModule):
    """Data module that exposes the same dataset through every split (no fold splitting)."""

    def __init__(self, data_path, **kwargs):
        super().__init__()
        self.batch_size = 1
        self.num_workers = 4
        self.data_path = data_path
        self.transforms = get_transform()
        self.base_class = Objects_Detection_Predictor_Dataset

    def prepare_data(self):
        # A single dataset instance backs train, val, test and the OOD test split.
        dataset = self.base_class(os.path.join(DATA_FOLDER, self.data_path), self.transforms)
        self.train = dataset
        self.val = dataset
        self.test = dataset
        self.test_ood = dataset

    def train_dataloader(self):
        return DataLoader(
            self.train,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            drop_last=False,
            pin_memory=True,
            collate_fn=collate_tuple
        )

    def val_dataloader(self):
        return DataLoader(
            self.val,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            drop_last=False,
            pin_memory=True,
            collate_fn=collate_tuple
        )

    def test_dataloader(self):
        return DataLoader(
            self.test,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            drop_last=False,
            pin_memory=True,
            collate_fn=collate_tuple
        )

    def test_ood_dataloader(self, shuffle=False):
        return DataLoader(
            self.test_ood,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            drop_last=False,
            pin_memory=True,
            collate_fn=collate_tuple
        )
   
    @classmethod
    def add_dataset_specific_args(cls, parent):
        import argparse
        parser = argparse.ArgumentParser(parents=[parent], add_help=False)
        parser.add_argument('--data_path', type=str,
                            default="mnist/alldigits/")
        return parser


class Objects_fold_Smiles(pl.LightningDataModule):
    """Data module that splits the dataset into train/val subsets using pre-computed fold indices."""

    def __init__(self, data_path, fold, **kwargs):
        super().__init__()
        self.batch_size = 1
        self.num_workers = 4
        self.data_path = data_path
        self.fold = fold
        self.transforms = get_transform()
        # self.base_class = Objects_Detection_Predictor_Dataset
        self.base_class = Objects_Detection_Dataset

    def prepare_data(self):
        dataset = self.base_class(os.path.join(DATA_FOLDER, self.data_path), self.transforms)
        if self.fold > -1:
            # Restrict train/val to the index files stored alongside the data under ../folds/<fold>/.
            fold_dir = os.path.join(DATA_FOLDER, self.data_path, "../folds", str(self.fold))
            train_idx = np.load(os.path.join(fold_dir, "train_idx.npy"))
            val_idx = np.load(os.path.join(fold_dir, "val_idx.npy"))
            self.train = Subset(dataset, train_idx)
            self.val = Subset(dataset, val_idx)
        else:
            # fold == -1 disables cross-validation: use the full dataset everywhere.
            self.train = dataset
            self.val = dataset
        self.test = self.val
        self.test_ood = self.test

    def train_dataloader(self):
        return DataLoader(
            self.train,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            drop_last=False,
            pin_memory=True,
            collate_fn=collate_tuple
        )

    def val_dataloader(self):
        return DataLoader(
            self.val,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            drop_last=False,
            pin_memory=True,
            collate_fn=collate_tuple
        )

    def test_dataloader(self):
        return DataLoader(
            self.test,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            drop_last=False,
            pin_memory=True,
            collate_fn=collate_tuple
        )
    
    def test_ood_dataloader(self, shuffle=False):
        return DataLoader(
            self.test_ood,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            drop_last=False,
            pin_memory=True,
            collate_fn=collate_tuple
        )

    @classmethod
    def add_dataset_specific_args(cls, parent):
        import argparse
        parser = argparse.ArgumentParser(parents=[parent], add_help=False)
        parser.add_argument('--data_path', type=str,
                            default="mnist/alldigits/")
        parser.add_argument('--fold', type=int,
                            default=0)
        return parser
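

# ---------------------------------------------------------------------------
# Usage sketch (not part of the module API): how these data modules are
# typically consumed. The data_path below is a placeholder and assumes the
# directory layout expected by Objects_Detection_Dataset and the fold index
# files under ../folds/<fold>/.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    dm = Objects_fold_Smiles(data_path="mnist/alldigits/", fold=0)
    dm.prepare_data()
    # Batches are assembled by collate_tuple; inspect one batch from the train split.
    batch = next(iter(dm.train_dataloader()))
    print(type(batch), len(batch))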