File size: 911 Bytes
1865436
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
from PIL import Image
from torch.utils.data import Dataset


class Labelizer():
    """Bidirectional mapping between class-name strings and integer ids.

    The label set is fixed to two classes: ``'background'`` (0) and
    ``'bien'`` (1).
    """

    def __init__(self):
        super().__init__()
        self.labels = {'background': 0, 'bien': 1}
        # Derive the inverse map from ``labels`` so the two can never drift apart.
        self.inv_labels = {idx: name for name, idx in self.labels.items()}

    def transform(self, label):
        """Return the integer id for the class name *label*.

        Raises:
            KeyError: if *label* is not a known class name.
        """
        return self.labels[label]

    def inverse_transform(self, ys):
        """Map integer id(s) back to class name(s).

        Args:
            ys: a single id, a list/tuple of ids, or an array-like object
                exposing ``tolist()`` (e.g. a NumPy array or torch tensor).

        Returns:
            The class name for a scalar id, or a list of class names for a
            list/tuple/array of ids.

        Raises:
            KeyError: if an id is not a known class id.
        """
        # Convert array-likes to plain Python values first: torch tensors hash
        # by identity, so using one directly as a dict key would never match.
        if hasattr(ys, 'tolist'):
            ys = ys.tolist()
        if isinstance(ys, (list, tuple)):
            return [self.inv_labels[y] for y in ys]
        return self.inv_labels[ys]

    def num_classes(self):
        """Return the number of known classes."""
        return len(self.labels)


class PoIDataset(Dataset):
    """Image dataset over a sequence of image file paths.

    Each item is an ``(image, target)`` pair, where:

    * ``image`` is the file loaded as an RGB PIL image, optionally passed
      through ``transforms``;
    * ``target`` is currently always an empty dict.

    NOTE(review): the empty target suggests an inference-only dataset; confirm.

    Args:
        data_path: indexable sequence of image file paths, one per sample.
        transforms: optional callable applied to each RGB image.
    """

    def __init__(self,
                 data_path,
                 transforms=None):
        self.data_path = data_path
        self.transforms = transforms

    def __len__(self):
        """Return the number of image paths in the dataset."""
        return len(self.data_path)

    def __getitem__(self, idx):
        """Load sample *idx* and return ``(image, target)``."""
        # Use the context manager so the file handle is released
        # deterministically. ``convert`` returns a new, fully loaded image,
        # which stays valid after the source image is closed.
        with Image.open(self.data_path[idx]) as img:
            image = img.convert('RGB')
        target = {}
        if self.transforms is not None:
            image = self.transforms(image)
        return image, target