File size: 1,439 Bytes
71268b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import os
from PIL import Image
from torch.utils.data import Dataset


class PrecomputedDataset(Dataset):
    """Dataset of images paired with ground-truth and precomputed pseudo labels.

    ``root`` must contain three sibling directories: ``imgs``, ``gts`` and
    ``pseudos``. Every regular file in ``imgs`` is paired with the file of
    the same name in ``gts`` and ``pseudos``. Samples are ordered by file
    name, so indices are stable across runs and machines.

    Args:
        root: Dataset root directory.
        transforms: Joint transform called as ``transforms(img, label, pseudo)``.
            When ``student_augs`` is truthy it must return
            ``(img, label, aimg, pseudo)``; otherwise ``(img, label, pseudo)``.
            The returned ``label`` and ``pseudo`` must support ``.long()``
            (presumably torch tensors -- confirm against the transform used).
        student_augs: Whether ``transforms`` also yields an extra augmented
            image ``aimg``, which is then returned as the third sample item.
    """

    def __init__(self,
                 root,
                 transforms,
                 student_augs,
                 ):
        super().__init__()
        self.root = root
        self.transforms = transforms
        self.student_augs = student_augs

        image_dir = os.path.join(self.root, 'imgs')
        # os.listdir order is arbitrary and filesystem-dependent; sort so the
        # index -> sample mapping is reproducible. Subdirectories are skipped
        # because they can never be opened as images.
        names = sorted(
            name for name in os.listdir(image_dir)
            if os.path.isfile(os.path.join(image_dir, name))
        )
        self.image_files = [os.path.join(image_dir, name) for name in names]
        self.label_files = [os.path.join(self.root, 'gts', name) for name in names]
        self.pseudo_files = [os.path.join(self.root, 'pseudos', name) for name in names]

    @staticmethod
    def _load_image(path, mode=None):
        """Read the image at ``path`` fully into memory and close its file.

        If ``mode`` is given the image is converted to it; otherwise an
        in-memory copy is returned with its mode and palette unchanged.
        """
        # The context manager releases the file handle deterministically,
        # even if the caller (e.g. a transform) later raises.
        with Image.open(path) as im:
            return im.convert(mode) if mode is not None else im.copy()

    def __getitem__(self, index):
        """Load and transform sample ``index``.

        Returns ``(img, label, aimg, pseudo)`` when ``student_augs`` is truthy,
        otherwise ``(img, label, pseudo)``; ``label`` and ``pseudo`` are cast
        with ``.long()``.
        """
        img = self._load_image(self.image_files[index], "RGB")
        label = self._load_image(self.label_files[index])
        pseudo = self._load_image(self.pseudo_files[index])

        if self.student_augs:
            img, label, aimg, pseudo = self.transforms(img, label, pseudo)
            return img, label.long(), aimg, pseudo.long()
        img, label, pseudo = self.transforms(img, label, pseudo)
        return img, label.long(), pseudo.long()

    def __len__(self):
        """Return the number of image files found under ``root/imgs``."""
        return len(self.image_files)