import os
import tarfile

import h5py
import numpy as np
import torch
import torch.utils.data as data
from PIL import Image
from scipy import io
from torchvision.datasets.utils import download_url

# Download metadata for each supported PASCAL VOC release: archive URL, the
# local filename to save it under, its md5 checksum, and the directory layout
# inside the extracted tarball.
DATASET_YEAR_DICT = {
    "2012": {
        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar",
        "filename": "VOCtrainval_11-May-2012.tar",
        "md5": "6cd6e144f989b92b3379bac3b3de84fd",
        "base_dir": "VOCdevkit/VOC2012",
    },
    "2011": {
        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2011/VOCtrainval_25-May-2011.tar",
        "filename": "VOCtrainval_25-May-2011.tar",
        "md5": "6c3384ef61512963050cb5d687e5bf1e",
        "base_dir": "TrainVal/VOCdevkit/VOC2011",
    },
    "2010": {
        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2010/VOCtrainval_03-May-2010.tar",
        "filename": "VOCtrainval_03-May-2010.tar",
        "md5": "da459979d0c395079b5c75ee67908abb",
        "base_dir": "VOCdevkit/VOC2010",
    },
    "2009": {
        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2009/VOCtrainval_11-May-2009.tar",
        "filename": "VOCtrainval_11-May-2009.tar",
        "md5": "59065e4b188729180974ef6572f6a212",
        "base_dir": "VOCdevkit/VOC2009",
    },
    "2008": {
        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2008/VOCtrainval_14-Jul-2008.tar",
        # Fixed: was "VOCtrainval_11-May-2012.tar" (copy-paste error). The
        # saved archive name now matches the 2008 URL and its md5 checksum.
        "filename": "VOCtrainval_14-Jul-2008.tar",
        "md5": "2629fa636546599198acfcfbfcf1904a",
        "base_dir": "VOCdevkit/VOC2008",
    },
    "2007": {
        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar",
        "filename": "VOCtrainval_06-Nov-2007.tar",
        "md5": "c52e279531787c972589f7e41ab4ae64",
        "base_dir": "VOCdevkit/VOC2007",
    },
}


class VOCSegmentation(data.Dataset):
    """`Pascal VOC `_ Segmentation Dataset.

    Returns (image, mask) pairs where the mask is a per-pixel class-index
    tensor (255 "ambiguous" pixels are remapped to -1).

    Args:
        root (string): Root directory of the VOC Dataset.
        year (string, optional): The dataset year, supports years 2007 to 2012.
        image_set (string, optional): Select the image_set to use, ``train``,
            ``trainval`` or ``val``.
        download (bool, optional): If true, downloads the dataset from the
            internet and puts it in root directory. If dataset is already
            downloaded, it is not downloaded again.
        transform (callable, optional): A function/transform that takes in an
            PIL image and returns a transformed version.
            E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes
            in the target and transforms it.
    """

    CLASSES = 20
    CLASSES_NAMES = [
        "aeroplane",
        "bicycle",
        "bird",
        "boat",
        "bottle",
        "bus",
        "car",
        "cat",
        "chair",
        "cow",
        "diningtable",
        "dog",
        "horse",
        "motorbike",
        "person",
        "potted-plant",
        "sheep",
        "sofa",
        "train",
        "tvmonitor",
        "ambigious",
    ]

    def __init__(
        self,
        root,
        year="2012",
        image_set="train",
        download=False,
        transform=None,
        target_transform=None,
    ):
        self.root = os.path.expanduser(root)
        self.year = year
        self.url = DATASET_YEAR_DICT[year]["url"]
        self.filename = DATASET_YEAR_DICT[year]["filename"]
        self.md5 = DATASET_YEAR_DICT[year]["md5"]
        self.transform = transform
        self.target_transform = target_transform
        self.image_set = image_set

        base_dir = DATASET_YEAR_DICT[year]["base_dir"]
        voc_root = os.path.join(self.root, base_dir)
        image_dir = os.path.join(voc_root, "JPEGImages")
        mask_dir = os.path.join(voc_root, "SegmentationClass")

        if download:
            download_extract(self.url, self.root, self.filename, self.md5)

        if not os.path.isdir(voc_root):
            raise RuntimeError(
                "Dataset not found or corrupted."
                + " You can use download=True to download it"
            )

        splits_dir = os.path.join(voc_root, "ImageSets/Segmentation")
        split_f = os.path.join(splits_dir, image_set.rstrip("\n") + ".txt")

        if not os.path.exists(split_f):
            raise ValueError(
                'Wrong image_set entered! Please use image_set="train" '
                'or image_set="trainval" or image_set="val"'
            )

        with open(split_f, "r") as f:
            file_names = [x.strip() for x in f.readlines()]

        self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
        self.masks = [os.path.join(mask_dir, x + ".png") for x in file_names]
        # One mask per image, by construction from the same split file.
        assert len(self.images) == len(self.masks)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is the image segmentation.
        """
        img = Image.open(self.images[index]).convert("RGB")
        target = Image.open(self.masks[index])
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            # Reuse the shared mask conversion (int32 array, 255 -> -1, long
            # tensor). When target_transform is None the raw PIL mask is
            # returned unchanged, matching the original behavior.
            target = self._mask_transform(self.target_transform(target))
        return img, target

    @staticmethod
    def _mask_transform(mask):
        # PIL "P"-mode mask -> long tensor; VOC's 255 "ambiguous" label
        # becomes -1 so it can be ignored by losses.
        target = np.array(mask).astype("int32")
        target[target == 255] = -1
        return torch.from_numpy(target).long()

    def __len__(self):
        return len(self.images)

    @property
    def pred_offset(self):
        # Predicted class indices map directly onto VOC label ids.
        return 0


class VOCClassification(data.Dataset):
    """`Pascal VOC `_ multi-label Classification Dataset.

    Built from the segmentation splits: the multi-hot label vector marks every
    class whose pixels appear in the image's segmentation mask.

    Args:
        root (string): Root directory of the VOC Dataset.
        year (string, optional): The dataset year, supports years 2007 to 2012.
        image_set (string, optional): Select the image_set to use, ``train``,
            ``trainval`` or ``val``.
        download (bool, optional): If true, downloads the dataset from the
            internet and puts it in root directory. If dataset is already
            downloaded, it is not downloaded again.
        transform (callable, optional): A joint transform taking (image, mask)
            and returning transformed versions of both.
    """

    CLASSES = 20

    def __init__(
        self, root, year="2012", image_set="train", download=False, transform=None
    ):
        self.root = os.path.expanduser(root)
        self.year = year
        self.url = DATASET_YEAR_DICT[year]["url"]
        self.filename = DATASET_YEAR_DICT[year]["filename"]
        self.md5 = DATASET_YEAR_DICT[year]["md5"]
        self.transform = transform
        self.image_set = image_set

        base_dir = DATASET_YEAR_DICT[year]["base_dir"]
        voc_root = os.path.join(self.root, base_dir)
        image_dir = os.path.join(voc_root, "JPEGImages")
        mask_dir = os.path.join(voc_root, "SegmentationClass")

        if download:
            download_extract(self.url, self.root, self.filename, self.md5)

        if not os.path.isdir(voc_root):
            raise RuntimeError(
                "Dataset not found or corrupted."
                + " You can use download=True to download it"
            )

        splits_dir = os.path.join(voc_root, "ImageSets/Segmentation")
        split_f = os.path.join(splits_dir, image_set.rstrip("\n") + ".txt")

        if not os.path.exists(split_f):
            raise ValueError(
                'Wrong image_set entered! Please use image_set="train" '
                'or image_set="trainval" or image_set="val"'
            )

        with open(split_f, "r") as f:
            file_names = [x.strip() for x in f.readlines()]

        self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
        self.masks = [os.path.join(mask_dir, x + ".png") for x in file_names]
        assert len(self.images) == len(self.masks)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, labels) where labels is a multi-hot vector of the
            classes present in the image's segmentation mask.
        """
        img = Image.open(self.images[index]).convert("RGB")
        target = Image.open(self.masks[index])
        if self.transform is not None:
            # NOTE: unlike VOCSegmentation, the transform here is joint —
            # it receives and returns both image and mask.
            img, target = self.transform(img, target)

        # Mask values: 0 = background, 255 = ambiguous, 1..20 = classes.
        visible_classes = np.unique(target)
        labels = torch.zeros(self.CLASSES)
        for class_id in visible_classes:
            if class_id not in (0, 255):
                labels[class_id - 1].fill_(1)

        return img, labels

    def __len__(self):
        return len(self.images)


class VOCSBDClassification(data.Dataset):
    """Pascal VOC + SBD multi-label Classification Dataset.

    Concatenates the VOC segmentation split with the SBD (Semantic Boundaries
    Dataset) train split. VOC masks are PNG images; SBD masks are MATLAB
    ``.mat`` files whose ``GTcls.Segmentation`` field holds the label map.

    Args:
        root (string): Root directory of the VOC Dataset.
        sbd_root (string): Root directory of the SBD dataset (containing
            ``img``, ``cls`` and ``train.txt``).
        year (string, optional): The dataset year, supports years 2007 to 2012.
        image_set (string, optional): Select the image_set to use, ``train``,
            ``trainval`` or ``val``.
        download (bool, optional): If true, downloads the VOC dataset from the
            internet and puts it in root directory. If dataset is already
            downloaded, it is not downloaded again.
        transform (callable, optional): A joint transform taking (image, mask)
            and returning transformed versions of both.
    """

    CLASSES = 20

    def __init__(
        self,
        root,
        sbd_root,
        year="2012",
        image_set="train",
        download=False,
        transform=None,
    ):
        self.root = os.path.expanduser(root)
        self.sbd_root = os.path.expanduser(sbd_root)
        self.year = year
        self.url = DATASET_YEAR_DICT[year]["url"]
        self.filename = DATASET_YEAR_DICT[year]["filename"]
        self.md5 = DATASET_YEAR_DICT[year]["md5"]
        self.transform = transform
        self.image_set = image_set

        base_dir = DATASET_YEAR_DICT[year]["base_dir"]
        voc_root = os.path.join(self.root, base_dir)
        image_dir = os.path.join(voc_root, "JPEGImages")
        mask_dir = os.path.join(voc_root, "SegmentationClass")
        sbd_image_dir = os.path.join(sbd_root, "img")
        sbd_mask_dir = os.path.join(sbd_root, "cls")

        if download:
            download_extract(self.url, self.root, self.filename, self.md5)

        if not os.path.isdir(voc_root):
            raise RuntimeError(
                "Dataset not found or corrupted."
                + " You can use download=True to download it"
            )

        splits_dir = os.path.join(voc_root, "ImageSets/Segmentation")
        split_f = os.path.join(splits_dir, image_set.rstrip("\n") + ".txt")
        sbd_split = os.path.join(sbd_root, "train.txt")

        if not os.path.exists(split_f):
            raise ValueError(
                'Wrong image_set entered! Please use image_set="train" '
                'or image_set="trainval" or image_set="val"'
            )

        with open(split_f, "r") as f:
            voc_file_names = [x.strip() for x in f.readlines()]
        with open(sbd_split, "r") as f:
            sbd_file_names = [x.strip() for x in f.readlines()]

        # VOC samples first, then SBD; images and masks stay index-aligned.
        self.images = [os.path.join(image_dir, x + ".jpg") for x in voc_file_names]
        self.images += [os.path.join(sbd_image_dir, x + ".jpg") for x in sbd_file_names]
        self.masks = [os.path.join(mask_dir, x + ".png") for x in voc_file_names]
        self.masks += [os.path.join(sbd_mask_dir, x + ".mat") for x in sbd_file_names]
        assert len(self.images) == len(self.masks)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, labels) where labels is a multi-hot vector of the
            classes present in the sample's segmentation mask.
        """
        img = Image.open(self.images[index]).convert("RGB")
        mask_path = self.masks[index]
        if mask_path[-3:] == "mat":
            # SBD stores masks as .mat files; the label map lives in
            # GTcls.Segmentation.
            target = io.loadmat(mask_path, struct_as_record=False, squeeze_me=True)[
                "GTcls"
            ].Segmentation
            target = Image.fromarray(target, mode="P")
        else:
            target = Image.open(self.masks[index])
        if self.transform is not None:
            img, target = self.transform(img, target)

        # Mask values: 0 = background, 255 = ambiguous, 1..20 = classes.
        visible_classes = np.unique(target)
        labels = torch.zeros(self.CLASSES)
        for class_id in visible_classes:
            if class_id not in (0, 255):
                labels[class_id - 1].fill_(1)

        return img, labels

    def __len__(self):
        return len(self.images)


def download_extract(url, root, filename, md5):
    """Download ``url`` into ``root`` as ``filename`` (md5-verified) and
    extract the tarball in place.

    NOTE(review): ``extractall`` trusts archive member paths; fine for these
    well-known official archives, but do not reuse for untrusted tarballs.
    """
    download_url(url, root, filename, md5)
    with tarfile.open(os.path.join(root, filename), "r") as tar:
        tar.extractall(path=root)


class VOCResults(data.Dataset):
    """Reads precomputed results from a ``results.hdf5`` file under ``path``.

    Each item is (image, vis, target, class_pred), read from the
    corresponding HDF5 datasets.
    """

    CLASSES = 20
    CLASSES_NAMES = [
        "aeroplane",
        "bicycle",
        "bird",
        "boat",
        "bottle",
        "bus",
        "car",
        "cat",
        "chair",
        "cow",
        "diningtable",
        "dog",
        "horse",
        "motorbike",
        "person",
        "potted-plant",
        "sheep",
        "sofa",
        "train",
        "tvmonitor",
        "ambigious",
    ]

    def __init__(self, path):
        super(VOCResults, self).__init__()

        self.path = os.path.join(path, "results.hdf5")
        # File handle is opened lazily in __getitem__ — presumably so each
        # DataLoader worker opens its own handle after fork; verify if the
        # loading setup changes.
        self.data = None

        print("Reading dataset length...")
        with h5py.File(self.path, "r") as f:
            self.data_length = len(f["/image"])

    def __len__(self):
        return self.data_length

    def __getitem__(self, item):
        if self.data is None:
            self.data = h5py.File(self.path, "r")

        image = torch.tensor(self.data["image"][item])
        vis = torch.tensor(self.data["vis"][item])
        target = torch.tensor(self.data["target"][item])
        class_pred = torch.tensor(self.data["class_pred"][item])

        return image, vis, target, class_pred