# Code for Peekaboo # Author: Hasib Zunair # Modified from https://github.com/NoelShin/selfmask """ Dataset functions for applying Normalized Cut. """ import os import glob import random from typing import Optional, Tuple, Union from pycocotools.coco import COCO import numpy as np import torch import torchvision from PIL import Image from torch.utils.data import Dataset from torchvision import transforms as T try: from torchvision.transforms import InterpolationMode BICUBIC = InterpolationMode.BICUBIC except ImportError: BICUBIC = Image.BICUBIC from datasets.utils import unnormalize from datasets.geometric_transforms import resize from datasets.VOC import get_voc_detection_gt, create_gt_masks_if_voc, create_VOC_loader from datasets.augmentations import geometric_augmentations, photometric_augmentations from datasets.uod_datasets import UODDataset NORMALIZE = T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) def set_dataset_dir(dataset_name, root_dir): if dataset_name == "ECSSD": dataset_dir = os.path.join(root_dir, "ECSSD") img_dir = os.path.join(dataset_dir, "images") gt_dir = os.path.join(dataset_dir, "ground_truth_mask") scribbles_dir = os.path.join(root_dir, "SCRIBBLES") elif dataset_name == "DUTS-TEST": dataset_dir = os.path.join(root_dir, "DUTS-TE") img_dir = os.path.join(dataset_dir, "DUTS-TE-Image") gt_dir = os.path.join(dataset_dir, "DUTS-TE-Mask") scribbles_dir = os.path.join(root_dir, "SCRIBBLES") elif dataset_name == "DUTS-TR": dataset_dir = os.path.join(root_dir, "DUTS-TR") img_dir = os.path.join(dataset_dir, "DUTS-TR-Image") gt_dir = os.path.join(dataset_dir, "DUTS-TR-Mask") scribbles_dir = os.path.join(root_dir, "SCRIBBLES") elif dataset_name == "DUT-OMRON": dataset_dir = os.path.join(root_dir, "DUT-OMRON") img_dir = os.path.join(dataset_dir, "DUT-OMRON-image") gt_dir = os.path.join(dataset_dir, "pixelwiseGT-new-PNG") scribbles_dir = os.path.join(root_dir, "SCRIBBLES") elif dataset_name == "VOC07": dataset_dir = os.path.join(root_dir, "VOC2007") img_dir = dataset_dir gt_dir = dataset_dir scribbles_dir = os.path.join(root_dir, "SCRIBBLES") elif dataset_name == "VOC12": dataset_dir = os.path.join(root_dir, "VOC2012") img_dir = dataset_dir gt_dir = dataset_dir scribbles_dir = os.path.join(root_dir, "SCRIBBLES") elif dataset_name == "COCO17": dataset_dir = os.path.join(root_dir, "COCO") img_dir = dataset_dir gt_dir = dataset_dir scribbles_dir = os.path.join(root_dir, "SCRIBBLES") elif dataset_name == "ImageNet": dataset_dir = os.path.join(root_dir, "ImageNet") img_dir = dataset_dir gt_dir = dataset_dir else: raise ValueError(f"Unknown dataset {dataset_name}") return img_dir, gt_dir, scribbles_dir def build_dataset( root_dir: str, dataset_name: str, dataset_set: Optional[str] = None, for_eval: bool = False, config=None, evaluation_type="saliency", # uod, ): """ Build dataset """ if evaluation_type == "saliency": # training data loaded from here img_dir, gt_dir, scribbles_dir = set_dataset_dir(dataset_name, root_dir) dataset = PeekabooDataset( name=dataset_name, img_dir=img_dir, gt_dir=gt_dir, scribbles_dir=scribbles_dir, dataset_set=dataset_set, config=config, for_eval=for_eval, evaluation_type=evaluation_type, ) elif evaluation_type == "uod": assert dataset_name in ["VOC07", "VOC12", "COCO20k"] dataset_set = "trainval" if dataset_name in ["VOC07", "VOC12"] else "train" no_hards = False dataset = UODDataset( dataset_name, dataset_set, root_dir=root_dir, remove_hards=no_hards, ) return dataset class PeekabooDataset(Dataset): def __init__( self, name: str, img_dir: str, gt_dir: str, scribbles_dir: str, dataset_set: Optional[str] = None, config=None, for_eval: bool = False, evaluation_type: str = "saliency", ) -> None: """ Args: root_dir (string): Directory with all the images. transform (callable, optional): Optional transform to be applied on a sample. """ self.for_eval = for_eval self.use_aug = not for_eval self.evaluation_type = evaluation_type assert evaluation_type in ["saliency"] self.name = name self.dataset_set = dataset_set self.img_dir = img_dir self.gt_dir = gt_dir self.scribbles_dir = scribbles_dir # if VOC dataset self.loader = None self.cocoGt = None self.config = config if "VOC" in self.name: self.loader = create_VOC_loader(self.img_dir, dataset_set, evaluation_type) # if ImageNet dataset elif "ImageNet" in self.name: self.loader = torchvision.datasets.ImageNet( self.img_dir, split=dataset_set, transform=None, target_transform=None, ) elif "COCO" in self.name: year = int("20" + self.name[-2:]) annFile = f"/datasets_local/COCO/annotations/instances_{dataset_set}{str(year)}.json" self.cocoGt = COCO(annFile) self.img_ids = list(sorted(self.cocoGt.getImgIds())) self.img_dir = f"/datasets_local/COCO/images/{dataset_set}{str(year)}/" # Transformations if self.for_eval: ( full_img_transform, no_norm_full_img_transform, ) = self.get_init_transformation(isVOC="VOC" in name) self.full_img_transform = full_img_transform self.no_norm_full_img_transform = no_norm_full_img_transform # Images self.list_images = None self.list_scribbles = None if not "VOC" in self.name and not "COCO" in self.name: self.list_images = [ os.path.join(img_dir, i) for i in sorted(os.listdir(img_dir)) ] # get path to scribbles, high masks are used, see https://github.com/hasibzunair/msl-recognition self.list_scribbles = sorted(glob.glob(scribbles_dir + "/*.png"))[::-1][ :1000 ] # For heavy masking [::-1] self.ignore_index = -1 self.mean = NORMALIZE.mean self.std = NORMALIZE.std self.to_tensor_and_normalize = T.Compose([T.ToTensor(), NORMALIZE]) self.normalize = NORMALIZE if config is not None and self.use_aug: self._set_aug(config) def get_init_transformation(self, isVOC: bool = False): if isVOC: t = T.Compose( [T.PILToTensor(), T.ConvertImageDtype(torch.float), NORMALIZE] ) t_nonorm = T.Compose([T.PILToTensor(), T.ConvertImageDtype(torch.float)]) return t, t_nonorm else: t = T.Compose([T.ToTensor(), NORMALIZE]) t_nonorm = T.Compose([T.ToTensor()]) return t, t_nonorm def _set_aug(self, config): """ Set augmentation based on config. """ photometric_aug = config.training["photometric_aug"] self.cropping_strategy = config.training["cropping_strategy"] if self.cropping_strategy == "center_crop": self.use_aug = False # default strategy, not considered to be a data aug self.scale_range = config.training["scale_range"] self.crop_size = config.training["crop_size"] self.center_crop_transforms = T.Compose( [ T.CenterCrop((self.crop_size, self.crop_size)), T.ToTensor(), ] ) self.center_crop_only_transforms = T.Compose( [T.CenterCrop((self.crop_size, self.crop_size)), T.PILToTensor()] ) self.proba_photometric_aug = config.training["proba_photometric_aug"] self.random_color_jitter = False self.random_grayscale = False self.random_gaussian_blur = False if photometric_aug == "color_jitter": self.random_color_jitter = True elif photometric_aug == "grayscale": self.random_grayscale = True elif photometric_aug == "gaussian_blur": self.random_gaussian_blur = True def _preprocess_data_aug( self, image: Image.Image, mask: Image.Image, ignore_index: Optional[int] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """Prepare data in a proper form for either training (data augmentation) or validation.""" # resize to base size image = resize( image, size=self.crop_size, edge="shorter", interpolation="bilinear", ) mask = resize( mask, size=self.crop_size, edge="shorter", interpolation="bilinear", ) if not isinstance(mask, torch.Tensor): mask: torch.Tensor = torch.tensor(np.array(mask)) random_scale_range = None random_crop_size = None random_hflip_p = None if self.cropping_strategy == "random_scale": random_scale_range = self.scale_range elif self.cropping_strategy == "random_crop": random_crop_size = self.crop_size elif self.cropping_strategy == "random_hflip": random_hflip_p = 0.5 elif self.cropping_strategy == "random_crop_and_hflip": random_hflip_p = 0.5 random_crop_size = self.crop_size if random_crop_size or random_hflip_p or random_scale_range: image, mask = geometric_augmentations( image=image, mask=mask, random_scale_range=random_scale_range, random_crop_size=random_crop_size, ignore_index=ignore_index, random_hflip_p=random_hflip_p, ) if random_scale_range: # resize to (self.crop_size, self.crop_size) image = resize( image, size=self.crop_size, interpolation="bilinear", ) mask = resize( mask, size=(self.crop_size, self.crop_size), interpolation="bilinear", ) image = photometric_augmentations( image, random_color_jitter=self.random_color_jitter, random_grayscale=self.random_grayscale, random_gaussian_blur=self.random_gaussian_blur, proba_photometric_aug=self.proba_photometric_aug, ) # to tensor + normalize image image = self.to_tensor_and_normalize(image) return image, mask def __len__(self) -> int: if "VOC" in self.name: return len(self.loader) elif "ImageNet" in self.name: return len(self.loader) elif "COCO" in self.name: return len(self.img_ids) return len(self.list_images) def _apply_center_crop( self, image: Image.Image, mask: Union[Image.Image, np.ndarray, torch.Tensor] ) -> Tuple[torch.Tensor, torch.Tensor]: img_t = self.center_crop_transforms(image) # need to normalize image img_t = self.normalize(img_t) mask_gt = self.center_crop_transforms(mask).squeeze() return img_t, mask_gt def _preprocess_scribble(self, img, img_size): transform = T.Compose( [ T.Resize(img_size, BICUBIC), T.CenterCrop(img_size), T.ToTensor(), ] ) return transform(img) def __getitem__(self, idx, get_mask_gt=True): if "VOC" in self.name: img, gt_labels = self.loader[idx] if self.evaluation_type == "uod": gt_labels, _ = get_voc_detection_gt(gt_labels, remove_hards=False) elif self.evaluation_type == "saliency": mask_gt = create_gt_masks_if_voc(gt_labels) img_path = self.loader.images[idx] elif "ImageNet" in self.name: img, _ = self.loader[idx] img_path = self.loader.imgs[idx][0] # empty mask since no gt mask, only class label zeros = np.zeros(np.array(img).shape[:2]) mask_gt = Image.fromarray(zeros) elif "COCO" in self.name: img_id = self.img_ids[idx] path = self.cocoGt.loadImgs(img_id)[0]["file_name"] img = Image.open(os.path.join(self.img_dir, path)).convert("RGB") _ = self.cocoGt.loadAnns(self.cocoGt.getAnnIds(id)) img_path = self.img_ids[idx] # What matters most is the id for eval # empty mask since no gt mask, only class label zeros = np.zeros(np.array(img).shape[:2]) mask_gt = Image.fromarray(zeros) # For all others else: img_path = self.list_images[idx] scribble_path = self.list_scribbles[random.randint(0, 950)] # read image with open(img_path, "rb") as f: img = Image.open(f) img = img.convert("RGB") im_name = img_path.split("/")[-1] mask_gt = Image.open( os.path.join(self.gt_dir, im_name.replace(".jpg", ".png")) ).convert("L") if self.for_eval: img_t = self.full_img_transform(img) img_init = self.no_norm_full_img_transform(img) if self.evaluation_type == "saliency": mask_gt = torch.tensor(np.array(mask_gt)).squeeze() mask_gt = np.array(mask_gt) mask_gt = mask_gt == 255 mask_gt = torch.tensor(mask_gt) else: if self.use_aug: img_t, mask_gt = self._preprocess_data_aug( image=img, mask=mask_gt, ignore_index=self.ignore_index ) mask_gt = np.array(mask_gt) mask_gt = mask_gt == 255 mask_gt = torch.tensor(mask_gt) else: # no data aug img_t, mask_gt = self._apply_center_crop(image=img, mask=mask_gt) gt_labels = self.center_crop_only_transforms(gt_labels).squeeze() mask_gt = np.asarray(mask_gt, np.int64) mask_gt = mask_gt == 1 mask_gt = torch.tensor(mask_gt) img_init = unnormalize(img_t) if not get_mask_gt: mask_gt = None if self.evaluation_type == "uod": gt_labels = torch.tensor(gt_labels) mask_gt = gt_labels # read scribble with open(scribble_path, "rb") as f: scribble = Image.open(f).convert("P") scribble = self._preprocess_scribble(scribble, img_t.shape[1]) scribble = (scribble > 0).float() # threshold to [0,1] scribble = torch.max(scribble) - scribble # inverted scribble # create masked input image with scribble when training if not self.for_eval: masked_img_t = img_t * scribble masked_img_init = unnormalize(masked_img_t) else: masked_img_t = img_t masked_img_init = img_init # returns the # image, masked image, scribble, # un-normalized image, un-normalized masked image # ground truth mask, image path return ( img_t, masked_img_t, scribble, img_init, masked_img_init, mask_gt, img_path, ) def fullimg_mode(self): self.val_full_image = True def training_mode(self): self.val_full_image = False