|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os
from enum import Enum

import PIL
import PIL.Image
import torch
from torchvision import transforms
|
|
|
# Per-channel (R, G, B) normalization statistics passed to transforms.Normalize
# in RayanDataset's default image transform (the standard ImageNet values).
IMAGENET_MEAN = [0.485, 0.456, 0.406]

IMAGENET_STD = [0.229, 0.224, 0.225]
|
|
|
|
|
class DatasetSplit(Enum):
    """Dataset split selector.

    Each member's value is used by RayanDataset as the name of the split's
    sub-directory under ``<source>/<classname>/``.
    """

    TRAIN = "train"

    VAL = "val"

    TEST = "test"
|
|
|
|
|
class RayanDataset(torch.utils.data.Dataset):
    """Single-class image dataset read from a directory tree.

    Expected layout::

        <source>/<classname>/<split>/<anomaly_type>/<image files>
        <source>/<classname>/ground_truth/<anomaly_type>/<mask files>

    Every sub-directory of ``<source>/<classname>/<split>`` is an anomaly type;
    the one named ``"good"`` holds normal images. For the test split, masks of
    each non-``"good"`` anomaly type are paired with its images by sorted file
    order.
    """

    def __init__(
        self,
        source,
        classname,
        input_size=518,
        output_size=224,
        split=DatasetSplit.TEST,
        external_transform=None,
        **kwargs,
    ):
        """Index the images on disk and build the image/mask transforms.

        Args:
            source: Dataset root directory.
            classname: Name of the class sub-directory under ``source``.
            input_size: Side length images are resized to when no
                ``external_transform`` is given.
            output_size: Side length masks are resized to; also defines
                ``self.output_shape``.
            split: A ``DatasetSplit`` member or its string value
                (``"train"``, ``"val"`` or ``"test"``).
            external_transform: Optional callable applied to the RGB PIL image
                instead of the default resize/normalize pipeline.
            **kwargs: Accepted for interface compatibility; ignored.

        Raises:
            ValueError: If ``split`` is not a valid ``DatasetSplit`` value, or
                if an anomaly type has fewer mask files than images.
        """
        super().__init__()
        self.source = source
        # Accept both DatasetSplit.TEST and "test"; a plain string previously
        # crashed on `.value` inside get_image_data.
        self.split = DatasetSplit(split)
        self.classnames_to_use = [classname]
        self.imgpaths_per_class, self.data_to_iterate = self.get_image_data()

        if external_transform is None:
            # Resize((s, s)) already yields an s x s image, so CenterCrop(s)
            # is a no-op; it is kept so the pipeline stays unchanged.
            self.transform_img = transforms.Compose(
                [
                    transforms.Resize((input_size, input_size)),
                    transforms.CenterCrop(input_size),
                    transforms.ToTensor(),
                    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
                ]
            )
        else:
            self.transform_img = external_transform

        self.transform_mask = transforms.Compose(
            [
                transforms.Resize((output_size, output_size)),
                transforms.CenterCrop(output_size),
                transforms.ToTensor(),
            ]
        )
        self.output_shape = (1, output_size, output_size)

    def __getitem__(self, idx):
        """Load and transform sample ``idx``.

        Returns:
            dict with keys:
                ``"image"``: output of ``self.transform_img`` on the RGB image.
                ``"mask"``: bool tensor; the thresholded (> 0) ground-truth
                    mask for test images that have one, otherwise all-False
                    with shape ``self.output_shape``.
                ``"is_anomaly"``: ``1`` if the anomaly type is not ``"good"``,
                    else ``0``.
                ``"image_path"``: path of the source image file.
        """
        _, anomaly, image_path, mask_path = self.data_to_iterate[idx]

        # Context manager closes the file handle deterministically; convert()
        # returns an independent in-memory image that outlives the `with`.
        with PIL.Image.open(image_path) as raw_image:
            rgb_image = raw_image.convert("RGB")
        image = self.transform_img(rgb_image)

        if self.split == DatasetSplit.TEST and mask_path is not None:
            with PIL.Image.open(mask_path) as raw_mask:
                gray_mask = raw_mask.convert("L")
            mask = self.transform_mask(gray_mask) > 0
        else:
            # Use the same bool dtype as the thresholded mask above so that
            # default_collate can stack batches mixing good and anomalous
            # samples (a float/bool mix can fail when stacking into a
            # preallocated output tensor in worker processes).
            mask = torch.zeros(self.output_shape, dtype=torch.bool)

        return {
            "image": image,
            "mask": mask,
            "is_anomaly": int(anomaly != "good"),
            "image_path": image_path,
        }

    def __len__(self):
        """Return the number of indexed samples."""
        return len(self.data_to_iterate)

    def get_image_data(self):
        """Scan the dataset directory and pair every image with its mask path.

        Returns:
            ``(imgpaths_per_class, data_to_iterate)`` where
            ``imgpaths_per_class`` maps classname -> anomaly type -> sorted
            list of image paths, and ``data_to_iterate`` is a list of
            ``[classname, anomaly, image_path, mask_path]`` entries ordered by
            classname, anomaly type and file name. ``mask_path`` is ``None``
            outside the test split and for ``"good"`` images.

        Raises:
            ValueError: If a test-split anomaly type has fewer mask files than
                image files (masks are paired with images by sorted order).
        """
        imgpaths_per_class = {}
        maskpaths_per_class = {}

        for classname in self.classnames_to_use:
            classpath = os.path.join(self.source, classname, self.split.value)
            maskpath = os.path.join(self.source, classname, "ground_truth")
            imgpaths_per_class[classname] = {}
            maskpaths_per_class[classname] = {}

            for anomaly in os.listdir(classpath):
                anomaly_path = os.path.join(classpath, anomaly)
                # Ignore stray files (e.g. .DS_Store) next to the anomaly
                # directories; listing them would raise NotADirectoryError.
                if not os.path.isdir(anomaly_path):
                    continue

                image_paths = [
                    os.path.join(anomaly_path, x)
                    for x in sorted(os.listdir(anomaly_path))
                ]
                imgpaths_per_class[classname][anomaly] = image_paths

                if self.split == DatasetSplit.TEST and anomaly != "good":
                    anomaly_mask_path = os.path.join(maskpath, anomaly)
                    mask_paths = [
                        os.path.join(anomaly_mask_path, x)
                        for x in sorted(os.listdir(anomaly_mask_path))
                    ]
                    # Pairing is positional, so every image needs a mask.
                    if len(mask_paths) < len(image_paths):
                        raise ValueError(
                            f"{anomaly_mask_path} has {len(mask_paths)} mask "
                            f"files but {anomaly_path} has {len(image_paths)} "
                            "images"
                        )
                    maskpaths_per_class[classname][anomaly] = mask_paths
                else:
                    maskpaths_per_class[classname][anomaly] = None

        data_to_iterate = []
        for classname in sorted(imgpaths_per_class):
            for anomaly in sorted(imgpaths_per_class[classname]):
                mask_paths = maskpaths_per_class[classname][anomaly]
                for i, image_path in enumerate(imgpaths_per_class[classname][anomaly]):
                    mask_path = None if mask_paths is None else mask_paths[i]
                    data_to_iterate.append([classname, anomaly, image_path, mask_path])

        return imgpaths_per_class, data_to_iterate
|
|