peekaboo-demo / datasets /datasets.py
hasibzunair's picture
add files
1803579
raw
history blame
16.2 kB
# Code for Peekaboo
# Author: Hasib Zunair
# Modified from https://github.com/NoelShin/selfmask
"""
Dataset functions for applying Normalized Cut.
"""
import os
import glob
import random
from typing import Optional, Tuple, Union
from pycocotools.coco import COCO
import numpy as np
import torch
import torchvision
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms as T
try:
from torchvision.transforms import InterpolationMode
BICUBIC = InterpolationMode.BICUBIC
except ImportError:
BICUBIC = Image.BICUBIC
from datasets.utils import unnormalize
from datasets.geometric_transforms import resize
from datasets.VOC import get_voc_detection_gt, create_gt_masks_if_voc, create_VOC_loader
from datasets.augmentations import geometric_augmentations, photometric_augmentations
from datasets.uod_datasets import UODDataset
# Per-channel mean/std normalization (the widely used ImageNet statistics),
# applied to float image tensors after ToTensor / ConvertImageDtype.
NORMALIZE = T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
def set_dataset_dir(dataset_name, root_dir):
    """Resolve the image, ground-truth and scribble directories of a dataset.

    Args:
        dataset_name: One of "ECSSD", "DUTS-TEST", "DUTS-TR", "DUT-OMRON",
            "VOC07", "VOC12", "COCO17", "ImageNet".
        root_dir: Directory containing all dataset folders and the shared
            "SCRIBBLES" folder.

    Returns:
        Tuple ``(img_dir, gt_dir, scribbles_dir)``.

    Raises:
        ValueError: If ``dataset_name`` is not a known dataset.
    """
    # name -> (dataset folder, image sub-folder, mask sub-folder). A sub-folder
    # of None means the files live directly in the dataset folder (the VOC /
    # COCO / ImageNet loaders handle the internal layout themselves).
    layouts = {
        "ECSSD": ("ECSSD", "images", "ground_truth_mask"),
        "DUTS-TEST": ("DUTS-TE", "DUTS-TE-Image", "DUTS-TE-Mask"),
        "DUTS-TR": ("DUTS-TR", "DUTS-TR-Image", "DUTS-TR-Mask"),
        "DUT-OMRON": ("DUT-OMRON", "DUT-OMRON-image", "pixelwiseGT-new-PNG"),
        "VOC07": ("VOC2007", None, None),
        "VOC12": ("VOC2012", None, None),
        "COCO17": ("COCO", None, None),
        "ImageNet": ("ImageNet", None, None),
    }
    try:
        folder, img_sub, gt_sub = layouts[dataset_name]
    except KeyError:
        raise ValueError(f"Unknown dataset {dataset_name}") from None

    dataset_dir = os.path.join(root_dir, folder)
    img_dir = os.path.join(dataset_dir, img_sub) if img_sub else dataset_dir
    gt_dir = os.path.join(dataset_dir, gt_sub) if gt_sub else dataset_dir
    # Scribbles are shared by every dataset. Previously ImageNet never set this
    # and the return statement raised UnboundLocalError.
    scribbles_dir = os.path.join(root_dir, "SCRIBBLES")
    return img_dir, gt_dir, scribbles_dir
def build_dataset(
    root_dir: str,
    dataset_name: str,
    dataset_set: Optional[str] = None,
    for_eval: bool = False,
    config=None,
    evaluation_type="saliency",  # or "uod"
):
    """Build the dataset for the requested evaluation type.

    Args:
        root_dir: Root directory holding the dataset folders.
        dataset_name: Dataset identifier (see ``set_dataset_dir``).
        dataset_set: Split name forwarded to ``PeekabooDataset``; ignored for
            "uod", where the split is fixed per dataset.
        for_eval: Build the evaluation variant (no augmentation / masking).
        config: Training config forwarded to ``PeekabooDataset``.
        evaluation_type: "saliency" or "uod".

    Returns:
        A ``PeekabooDataset`` for "saliency", a ``UODDataset`` for "uod".

    Raises:
        AssertionError: For "uod" with a dataset other than VOC07/VOC12/COCO20k.
        ValueError: For an unknown ``evaluation_type``.
    """
    if evaluation_type == "saliency":
        # training data loaded from here
        img_dir, gt_dir, scribbles_dir = set_dataset_dir(dataset_name, root_dir)
        return PeekabooDataset(
            name=dataset_name,
            img_dir=img_dir,
            gt_dir=gt_dir,
            scribbles_dir=scribbles_dir,
            dataset_set=dataset_set,
            config=config,
            for_eval=for_eval,
            evaluation_type=evaluation_type,
        )

    if evaluation_type == "uod":
        assert dataset_name in ["VOC07", "VOC12", "COCO20k"], (
            f"Unsupported dataset for uod: {dataset_name}"
        )
        uod_set = "trainval" if dataset_name in ("VOC07", "VOC12") else "train"
        return UODDataset(
            dataset_name,
            uod_set,
            root_dir=root_dir,
            remove_hards=False,
        )

    # Previously this fell through to `return dataset` with `dataset` unbound.
    raise ValueError(
        f"Unknown evaluation_type {evaluation_type!r}; expected 'saliency' or 'uod'"
    )
class PeekabooDataset(Dataset):
    """Dataset yielding images, scribble-masked images and saliency masks.

    The loading path is selected by substring match on ``name``:
    a VOC loader ("VOC"), ``torchvision.datasets.ImageNet`` ("ImageNet"),
    pycocotools ("COCO"), or, for any other name, a flat image directory with a
    parallel directory of ``.png`` ground-truth masks.
    """

    # Scribbles are drawn uniformly from indices [0, _MAX_SCRIBBLE_INDEX],
    # clamped to the number of scribble files actually found.
    _MAX_SCRIBBLE_INDEX = 950

    def __init__(
        self,
        name: str,
        img_dir: str,
        gt_dir: str,
        scribbles_dir: str,
        dataset_set: Optional[str] = None,
        config=None,
        for_eval: bool = False,
        evaluation_type: str = "saliency",
    ) -> None:
        """
        Args:
            name: Dataset name (e.g. "ECSSD", "DUTS-TR", "VOC07", "COCO17",
                "ImageNet"); selects how images and masks are loaded.
            img_dir: Directory with the images (or the dataset root for the
                VOC / ImageNet loaders).
            gt_dir: Directory with the ground-truth masks.
            scribbles_dir: Directory with the ``.png`` scribble masks.
            dataset_set: Split name forwarded to the VOC / ImageNet / COCO loaders.
            config: Training config; ``config.training`` drives augmentation
                when ``for_eval`` is False.
            for_eval: If True, use full-image transforms and skip augmentation
                and scribble masking.
            evaluation_type: Only "saliency" is supported.
        """
        self.for_eval = for_eval
        self.use_aug = not for_eval
        self.evaluation_type = evaluation_type
        assert evaluation_type in ["saliency"]
        self.name = name
        self.dataset_set = dataset_set
        self.img_dir = img_dir
        self.gt_dir = gt_dir
        self.scribbles_dir = scribbles_dir
        self.loader = None
        self.cocoGt = None
        self.config = config

        if "VOC" in self.name:
            self.loader = create_VOC_loader(self.img_dir, dataset_set, evaluation_type)
        elif "ImageNet" in self.name:
            self.loader = torchvision.datasets.ImageNet(
                self.img_dir,
                split=dataset_set,
                transform=None,
                target_transform=None,
            )
        elif "COCO" in self.name:
            year = int("20" + self.name[-2:])
            # NOTE(review): COCO paths are hard-coded and ignore img_dir/gt_dir.
            annFile = f"/datasets_local/COCO/annotations/instances_{dataset_set}{str(year)}.json"
            self.cocoGt = COCO(annFile)
            self.img_ids = list(sorted(self.cocoGt.getImgIds()))
            self.img_dir = f"/datasets_local/COCO/images/{dataset_set}{str(year)}/"

        # Full-image transforms used in eval mode.
        if self.for_eval:
            (
                full_img_transform,
                no_norm_full_img_transform,
            ) = self.get_init_transformation(isVOC="VOC" in name)
            self.full_img_transform = full_img_transform
            self.no_norm_full_img_transform = no_norm_full_img_transform

        # Flat-directory image list; VOC/COCO index through their own loaders.
        self.list_images = None
        if "VOC" not in self.name and "COCO" not in self.name:
            self.list_images = [
                os.path.join(img_dir, i) for i in sorted(os.listdir(img_dir))
            ]
        # Scribbles are needed for every dataset, because __getitem__ always
        # returns one. High masks are used, see
        # https://github.com/hasibzunair/msl-recognition. The reversed order
        # ([::-1]) selects the heavy masks.
        self.list_scribbles = sorted(
            glob.glob(os.path.join(scribbles_dir, "*.png"))
        )[::-1][:1000]

        self.ignore_index = -1
        self.mean = NORMALIZE.mean
        self.std = NORMALIZE.std
        self.to_tensor_and_normalize = T.Compose([T.ToTensor(), NORMALIZE])
        self.normalize = NORMALIZE
        if config is not None and self.use_aug:
            self._set_aug(config)

    def get_init_transformation(self, isVOC: bool = False):
        """Return ``(normalizing transform, non-normalizing transform)`` for full images."""
        if isVOC:
            t = T.Compose(
                [T.PILToTensor(), T.ConvertImageDtype(torch.float), NORMALIZE]
            )
            t_nonorm = T.Compose([T.PILToTensor(), T.ConvertImageDtype(torch.float)])
            return t, t_nonorm
        t = T.Compose([T.ToTensor(), NORMALIZE])
        t_nonorm = T.Compose([T.ToTensor()])
        return t, t_nonorm

    def _set_aug(self, config):
        """Set cropping strategy and photometric augmentation from ``config.training``."""
        photometric_aug = config.training["photometric_aug"]
        self.cropping_strategy = config.training["cropping_strategy"]
        if self.cropping_strategy == "center_crop":
            self.use_aug = False  # default strategy, not considered to be a data aug
        self.scale_range = config.training["scale_range"]
        self.crop_size = config.training["crop_size"]
        self.center_crop_transforms = T.Compose(
            [
                T.CenterCrop((self.crop_size, self.crop_size)),
                T.ToTensor(),
            ]
        )
        self.center_crop_only_transforms = T.Compose(
            [T.CenterCrop((self.crop_size, self.crop_size)), T.PILToTensor()]
        )
        self.proba_photometric_aug = config.training["proba_photometric_aug"]
        self.random_color_jitter = photometric_aug == "color_jitter"
        self.random_grayscale = photometric_aug == "grayscale"
        self.random_gaussian_blur = photometric_aug == "gaussian_blur"

    def _preprocess_data_aug(
        self,
        image: "Image.Image",
        mask: "Image.Image",
        ignore_index: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Resize, then apply geometric and photometric augmentation for training.

        Returns the normalized image tensor and the mask tensor.
        """
        # resize to base size
        image = resize(
            image,
            size=self.crop_size,
            edge="shorter",
            interpolation="bilinear",
        )
        # NOTE(review): bilinear on a binary mask creates intermediate values;
        # __getitem__ keeps only pixels == 255 afterwards.
        mask = resize(
            mask,
            size=self.crop_size,
            edge="shorter",
            interpolation="bilinear",
        )
        if not isinstance(mask, torch.Tensor):
            mask: torch.Tensor = torch.tensor(np.array(mask))

        random_scale_range = None
        random_crop_size = None
        random_hflip_p = None
        if self.cropping_strategy == "random_scale":
            random_scale_range = self.scale_range
        elif self.cropping_strategy == "random_crop":
            random_crop_size = self.crop_size
        elif self.cropping_strategy == "random_hflip":
            random_hflip_p = 0.5
        elif self.cropping_strategy == "random_crop_and_hflip":
            random_hflip_p = 0.5
            random_crop_size = self.crop_size

        if random_crop_size or random_hflip_p or random_scale_range:
            image, mask = geometric_augmentations(
                image=image,
                mask=mask,
                random_scale_range=random_scale_range,
                random_crop_size=random_crop_size,
                ignore_index=ignore_index,
                random_hflip_p=random_hflip_p,
            )
            if random_scale_range:
                # resize back to (self.crop_size, self.crop_size)
                # NOTE(review): image gets an int size, mask a (h, w) tuple;
                # confirm `resize` yields matching shapes.
                image = resize(
                    image,
                    size=self.crop_size,
                    interpolation="bilinear",
                )
                mask = resize(
                    mask,
                    size=(self.crop_size, self.crop_size),
                    interpolation="bilinear",
                )

        image = photometric_augmentations(
            image,
            random_color_jitter=self.random_color_jitter,
            random_grayscale=self.random_grayscale,
            random_gaussian_blur=self.random_gaussian_blur,
            proba_photometric_aug=self.proba_photometric_aug,
        )
        # to tensor + normalize image
        image = self.to_tensor_and_normalize(image)
        return image, mask

    def __len__(self) -> int:
        if "VOC" in self.name or "ImageNet" in self.name:
            return len(self.loader)
        if "COCO" in self.name:
            return len(self.img_ids)
        return len(self.list_images)

    def _apply_center_crop(
        self, image: "Image.Image", mask: Union["Image.Image", np.ndarray, torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Center-crop image and mask; normalize the image, squeeze the mask."""
        img_t = self.center_crop_transforms(image)
        # need to normalize image
        img_t = self.normalize(img_t)
        mask_gt = self.center_crop_transforms(mask).squeeze()
        return img_t, mask_gt

    def _preprocess_scribble(self, img, img_size):
        """Resize (bicubic) and center-crop a scribble to ``img_size``, as a tensor."""
        transform = T.Compose(
            [
                T.Resize(img_size, BICUBIC),
                T.CenterCrop(img_size),
                T.ToTensor(),
            ]
        )
        return transform(img)

    def _sample_scribble_path(self) -> str:
        """Pick a random scribble file path.

        Raises:
            FileNotFoundError: If no scribble ``.png`` files were found.
        """
        if not self.list_scribbles:
            raise FileNotFoundError(
                f"No scribble masks (*.png) found in {self.scribbles_dir}"
            )
        # Same draw as the former randint(0, 950) whenever at least 951
        # scribbles exist; clamped so smaller folders no longer IndexError.
        upper = min(self._MAX_SCRIBBLE_INDEX, len(self.list_scribbles) - 1)
        return self.list_scribbles[random.randint(0, upper)]

    def __getitem__(self, idx, get_mask_gt=True):
        """Return one sample as a 7-tuple.

        ``(img_t, masked_img_t, scribble, img_init, masked_img_init, mask_gt,
        img_path)``: normalized image, scribble-masked normalized image,
        inverted binary scribble, un-normalized image, un-normalized masked
        image, boolean ground-truth mask (None if ``get_mask_gt`` is False),
        and the image path (the image id for COCO). In eval mode the "masked"
        entries are the unmasked images.
        """
        gt_labels = None
        if "VOC" in self.name:
            img, gt_labels = self.loader[idx]
            if self.evaluation_type == "uod":
                gt_labels, _ = get_voc_detection_gt(gt_labels, remove_hards=False)
            elif self.evaluation_type == "saliency":
                mask_gt = create_gt_masks_if_voc(gt_labels)
            img_path = self.loader.images[idx]
        elif "ImageNet" in self.name:
            img, _ = self.loader[idx]
            img_path = self.loader.imgs[idx][0]
            # empty mask since no gt mask, only class label
            mask_gt = Image.fromarray(np.zeros(np.array(img).shape[:2]))
        elif "COCO" in self.name:
            img_id = self.img_ids[idx]
            path = self.cocoGt.loadImgs(img_id)[0]["file_name"]
            with Image.open(os.path.join(self.img_dir, path)) as raw_img:
                img = raw_img.convert("RGB")
            # COCO annotations are not used (the mask is empty), so they are
            # not loaded. The former getAnnIds(id) passed the builtin `id`.
            img_path = img_id  # What matters most is the id for eval
            # empty mask since no gt mask, only class label
            mask_gt = Image.fromarray(np.zeros(np.array(img).shape[:2]))
        else:
            img_path = self.list_images[idx]
            with open(img_path, "rb") as f:
                img = Image.open(f).convert("RGB")
            im_name = os.path.basename(img_path)
            mask_path = os.path.join(self.gt_dir, im_name.replace(".jpg", ".png"))
            with Image.open(mask_path) as raw_mask:
                mask_gt = raw_mask.convert("L")

        # Drawn before any augmentation, so the random-number order matches
        # the original for flat-directory datasets.
        scribble_path = self._sample_scribble_path()

        if self.for_eval:
            img_t = self.full_img_transform(img)
            img_init = self.no_norm_full_img_transform(img)
            if self.evaluation_type == "saliency":
                mask_gt = torch.tensor(np.array(mask_gt).squeeze() == 255)
        else:
            if self.use_aug:
                img_t, mask_gt = self._preprocess_data_aug(
                    image=img, mask=mask_gt, ignore_index=self.ignore_index
                )
                mask_gt = torch.tensor(np.array(mask_gt) == 255)
            else:
                # no data aug
                img_t, mask_gt = self._apply_center_crop(image=img, mask=mask_gt)
                # gt_labels is only defined (and only used) on the "uod" path;
                # cropping it unconditionally raised NameError for non-VOC data.
                if self.evaluation_type == "uod":
                    gt_labels = self.center_crop_only_transforms(gt_labels).squeeze()
                # ToTensor scaled the mask to [0, 1]; keep fully-on pixels.
                mask_gt = torch.tensor(np.asarray(mask_gt, np.int64) == 1)
            img_init = unnormalize(img_t)

        if not get_mask_gt:
            mask_gt = None

        if self.evaluation_type == "uod":
            mask_gt = torch.tensor(gt_labels)

        # read scribble, binarize it, then invert it so the scribble strokes become 0
        with open(scribble_path, "rb") as f:
            scribble = Image.open(f).convert("P")
        scribble = self._preprocess_scribble(scribble, img_t.shape[1])
        scribble = (scribble > 0).float()  # threshold to [0,1]
        scribble = torch.max(scribble) - scribble  # inverted scribble

        # create masked input image with scribble when training
        if not self.for_eval:
            masked_img_t = img_t * scribble
            masked_img_init = unnormalize(masked_img_t)
        else:
            masked_img_t = img_t
            masked_img_init = img_init

        return (
            img_t,
            masked_img_t,
            scribble,
            img_init,
            masked_img_init,
            mask_gt,
            img_path,
        )

    def fullimg_mode(self):
        # Sets a flag that is not read anywhere in this class.
        self.val_full_image = True

    def training_mode(self):
        # Sets a flag that is not read anywhere in this class.
        self.val_full_image = False