|
import os |
|
from PIL import Image |
|
from torch.utils.data import Dataset |
|
|
|
|
|
class PrecomputedDataset(Dataset):
    """Dataset pairing each image with a label file and a pseudo-label file.

    The dataset reads three sibling directories under ``root``: ``imgs``,
    ``gts`` and ``pseudos``. Samples are enumerated from ``imgs``; the label
    and pseudo-label paths are derived by reusing the image's file name under
    ``gts`` and ``pseudos`` respectively.
    """

    def __init__(self,
                 root,
                 transforms,
                 student_augs,
                 ):
        """Index the files found under ``root/imgs``.

        Args:
            root: Directory containing the ``imgs``, ``gts`` and ``pseudos``
                sub-directories.
            transforms: Callable invoked as ``transforms(img, label, pseudo)``
                in ``__getitem__``. It must return four values when
                ``student_augs`` is truthy and three values otherwise; the
                label and pseudo-label results must support ``.long()``.
            student_augs: Truthy when ``transforms`` additionally returns an
                augmented image.
        """
        super().__init__()
        self.root = root
        self.transforms = transforms
        self.student_augs = student_augs

        # os.listdir returns entries in arbitrary order; sort them so that an
        # index always refers to the same sample across runs and machines.
        names = sorted(os.listdir(os.path.join(self.root, 'imgs')))

        self.image_files = [os.path.join(self.root, 'imgs', name) for name in names]
        self.label_files = [os.path.join(self.root, 'gts', name) for name in names]
        self.pseudo_files = [os.path.join(self.root, 'pseudos', name) for name in names]

    def __getitem__(self, index):
        """Load and transform the sample stored at ``index``.

        Returns:
            ``(img, label, aimg, pseudo)`` when ``student_augs`` is truthy,
            otherwise ``(img, label, pseudo)``. ``label`` and ``pseudo`` are
            passed through ``.long()`` before being returned.
        """
        image_path = self.image_files[index]
        label_path = self.label_files[index]
        pseudo_path = self.pseudo_files[index]

        # Only the input image is converted to RGB; the label and pseudo-label
        # images are handed to the transforms as opened.
        img = Image.open(image_path).convert("RGB")
        label = Image.open(label_path)
        pseudo = Image.open(pseudo_path)

        if self.student_augs:
            img, label, aimg, pseudo = self.transforms(img, label, pseudo)
            return img, label.long(), aimg, pseudo.long()

        img, label, pseudo = self.transforms(img, label, pseudo)
        return img, label.long(), pseudo.long()

    def __len__(self):
        """Return the number of files found in ``root/imgs``."""
        return len(self.image_files)