# asadi/datasets/rayan_dataset.py
# -----------------------------------------------------------------------------
# Do Not Alter This File!
# -----------------------------------------------------------------------------
# The following code is part of the logic used for loading and evaluating your
# output scores. Please DO NOT modify this section, as upon your submission,
# the whole evaluation logic will be overwritten by the original code.
# -----------------------------------------------------------------------------
# If you'd like to make modifications, you can create a completely new Dataset
# class or a child class that inherits from this one and use that with your
# data loader.
# -----------------------------------------------------------------------------
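# (See the illustrative subclassing example at the end of this file.)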
import os
from enum import Enum

import PIL.Image  # `import PIL` alone does not reliably expose PIL.Image
import torch
from torchvision import transforms

# Standard ImageNet channel statistics, matching torchvision-pretrained backbones.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
class DatasetSplit(Enum):
TRAIN = "train"
VAL = "val"
TEST = "test"
class RayanDataset(torch.utils.data.Dataset):
def __init__(
self,
source,
classname,
input_size=518,
output_size=224,
split=DatasetSplit.TEST,
external_transform=None,
**kwargs,
):
super().__init__()
self.source = source
self.split = split
self.classnames_to_use = [classname]
self.imgpaths_per_class, self.data_to_iterate = self.get_image_data()
if external_transform is None:
self.transform_img = [
transforms.Resize((input_size, input_size)),
transforms.CenterCrop(input_size),
transforms.ToTensor(),
transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
]
self.transform_img = transforms.Compose(self.transform_img)
else:
self.transform_img = external_transform
        # The ground-truth mask must come out with shape (1, output_size, output_size),
        # i.e. 1 x 224 x 224 with the default output_size.
self.transform_mask = [
transforms.Resize((output_size, output_size)),
transforms.CenterCrop(output_size),
transforms.ToTensor(),
]
self.transform_mask = transforms.Compose(self.transform_mask)
self.output_shape = (1, output_size, output_size)
def __getitem__(self, idx):
classname, anomaly, image_path, mask_path = self.data_to_iterate[idx]
image = PIL.Image.open(image_path).convert("RGB")
image = self.transform_img(image)
        if self.split == DatasetSplit.TEST and mask_path is not None:
            # Anomalous test samples ship with a pixel-level ground-truth mask;
            # binarize it, then cast to float so it collates with the all-zero
            # masks produced in the else-branch.
            mask = PIL.Image.open(mask_path).convert("L")
            mask = (self.transform_mask(mask) > 0).float()
        else:
            # "good" samples (and non-test splits) have no defects: all-zero mask.
            mask = torch.zeros([*self.output_shape])
return {
"image": image,
"mask": mask,
"is_anomaly": int(anomaly != "good"),
"image_path": image_path,
}
def __len__(self):
return len(self.data_to_iterate)
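    # Expected on-disk layout, inferred from the path handling in
    # get_image_data() below (directory and file names are illustrative):
    #
    #   <source>/<classname>/<split>/<anomaly_type>/<image>.png
    #   <source>/<classname>/ground_truth/<anomaly_type>/<mask>.png
    #
    # "good" images carry no ground-truth masks.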
    def get_image_data(self):
        """Walk the dataset directory and build the (class, anomaly, image, mask) index."""
imgpaths_per_class = {}
maskpaths_per_class = {}
for classname in self.classnames_to_use:
classpath = os.path.join(self.source, classname, self.split.value)
maskpath = os.path.join(self.source, classname, "ground_truth")
anomaly_types = os.listdir(classpath)
imgpaths_per_class[classname] = {}
maskpaths_per_class[classname] = {}
for anomaly in anomaly_types:
anomaly_path = os.path.join(classpath, anomaly)
anomaly_files = sorted(os.listdir(anomaly_path))
imgpaths_per_class[classname][anomaly] = [
os.path.join(anomaly_path, x) for x in anomaly_files
]
if self.split == DatasetSplit.TEST and anomaly != "good":
anomaly_mask_path = os.path.join(maskpath, anomaly)
anomaly_mask_files = sorted(os.listdir(anomaly_mask_path))
maskpaths_per_class[classname][anomaly] = [
os.path.join(anomaly_mask_path, x) for x in anomaly_mask_files
]
else:
maskpaths_per_class[classname]["good"] = None
data_to_iterate = []
for classname in sorted(imgpaths_per_class.keys()):
for anomaly in sorted(imgpaths_per_class[classname].keys()):
for i, image_path in enumerate(imgpaths_per_class[classname][anomaly]):
data_tuple = [classname, anomaly, image_path]
if self.split == DatasetSplit.TEST and anomaly != "good":
data_tuple.append(maskpaths_per_class[classname][anomaly][i])
else:
data_tuple.append(None)
data_to_iterate.append(data_tuple)
return imgpaths_per_class, data_to_iterate
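
# -----------------------------------------------------------------------------
# Illustrative example (not part of the evaluation logic). As the header above
# suggests, extend behavior through a child class rather than editing this
# file. Everything below is a hedged sketch: the class name, the transform
# choice, and the path/classname in the smoke test are hypothetical.
# -----------------------------------------------------------------------------
class ExampleRayanDataset(RayanDataset):
    """Minimal child class that swaps in a custom image transform."""

    def __init__(self, source, classname, input_size=518, **kwargs):
        custom_transform = transforms.Compose(
            [
                transforms.Resize((input_size, input_size)),
                transforms.ToTensor(),
                transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
            ]
        )
        super().__init__(
            source, classname, external_transform=custom_transform, **kwargs
        )


if __name__ == "__main__":
    # Hypothetical smoke test; point `source` at a dataset root laid out as
    # documented above get_image_data().
    dataset = ExampleRayanDataset(source="./data", classname="some_class")
    loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=False)
    for batch in loader:
        print(batch["image"].shape, batch["mask"].shape, batch["is_anomaly"])
        break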