import os
import tarfile
import h5py
import numpy as np
import torch
import torch.utils.data as data
from PIL import Image
from scipy import io
from torchvision.datasets.utils import download_url
# Download metadata for each Pascal VOC release: archive URL, the local
# filename to save it under, its MD5 checksum, and the directory the
# archive creates when extracted.
DATASET_YEAR_DICT = {
    "2012": {
        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar",
        "filename": "VOCtrainval_11-May-2012.tar",
        "md5": "6cd6e144f989b92b3379bac3b3de84fd",
        "base_dir": "VOCdevkit/VOC2012",
    },
    "2011": {
        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2011/VOCtrainval_25-May-2011.tar",
        "filename": "VOCtrainval_25-May-2011.tar",
        "md5": "6c3384ef61512963050cb5d687e5bf1e",
        "base_dir": "TrainVal/VOCdevkit/VOC2011",
    },
    "2010": {
        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2010/VOCtrainval_03-May-2010.tar",
        "filename": "VOCtrainval_03-May-2010.tar",
        "md5": "da459979d0c395079b5c75ee67908abb",
        "base_dir": "VOCdevkit/VOC2010",
    },
    "2009": {
        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2009/VOCtrainval_11-May-2009.tar",
        "filename": "VOCtrainval_11-May-2009.tar",
        "md5": "59065e4b188729180974ef6572f6a212",
        "base_dir": "VOCdevkit/VOC2009",
    },
    "2008": {
        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2008/VOCtrainval_14-Jul-2008.tar",
        # BUGFIX: filename previously said "VOCtrainval_11-May-2012.tar"
        # (copy-paste error); it must match the 2008 archive name above,
        # otherwise the 2008 tar is saved under the 2012 name and the
        # md5 check / later 2012 download collide.
        "filename": "VOCtrainval_14-Jul-2008.tar",
        "md5": "2629fa636546599198acfcfbfcf1904a",
        "base_dir": "VOCdevkit/VOC2008",
    },
    "2007": {
        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar",
        "filename": "VOCtrainval_06-Nov-2007.tar",
        "md5": "c52e279531787c972589f7e41ab4ae64",
        "base_dir": "VOCdevkit/VOC2007",
    },
}
class VOCSegmentation(data.Dataset):
    """`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Segmentation Dataset.

    Args:
        root (string): Root directory of the VOC Dataset.
        year (string, optional): The dataset year, supports years 2007 to 2012.
        image_set (string, optional): Select the image_set to use, ``train``,
            ``trainval`` or ``val``.
        download (bool, optional): If true, downloads the dataset from the
            internet and puts it in root directory. If dataset is already
            downloaded, it is not downloaded again.
        transform (callable, optional): A function/transform that takes in an
            PIL image and returns a transformed version.
            E.g, ``transforms.RandomCrop``.
        target_transform (callable, optional): A function/transform that takes
            in the target and transforms it.
    """

    # Number of foreground object classes; CLASSES_NAMES additionally lists
    # the "ambigious" entry corresponding to the void/255 label.
    CLASSES = 20
    CLASSES_NAMES = [
        "aeroplane",
        "bicycle",
        "bird",
        "boat",
        "bottle",
        "bus",
        "car",
        "cat",
        "chair",
        "cow",
        "diningtable",
        "dog",
        "horse",
        "motorbike",
        "person",
        "potted-plant",
        "sheep",
        "sofa",
        "train",
        "tvmonitor",
        "ambigious",
    ]

    def __init__(
        self,
        root,
        year="2012",
        image_set="train",
        download=False,
        transform=None,
        target_transform=None,
    ):
        self.root = os.path.expanduser(root)
        self.year = year
        self.url = DATASET_YEAR_DICT[year]["url"]
        self.filename = DATASET_YEAR_DICT[year]["filename"]
        self.md5 = DATASET_YEAR_DICT[year]["md5"]
        self.transform = transform
        self.target_transform = target_transform
        self.image_set = image_set
        base_dir = DATASET_YEAR_DICT[year]["base_dir"]
        voc_root = os.path.join(self.root, base_dir)
        image_dir = os.path.join(voc_root, "JPEGImages")
        mask_dir = os.path.join(voc_root, "SegmentationClass")
        if download:
            download_extract(self.url, self.root, self.filename, self.md5)
        if not os.path.isdir(voc_root):
            raise RuntimeError(
                "Dataset not found or corrupted."
                + " You can use download=True to download it"
            )
        # Join components individually for portability (was a hard-coded "/").
        splits_dir = os.path.join(voc_root, "ImageSets", "Segmentation")
        split_f = os.path.join(splits_dir, image_set.rstrip("\n") + ".txt")
        if not os.path.exists(split_f):
            raise ValueError(
                'Wrong image_set entered! Please use image_set="train" '
                'or image_set="trainval" or image_set="val"'
            )
        # split_f is already a full path; the redundant os.path.join() wrapper
        # was removed.
        with open(split_f, "r") as f:
            file_names = [x.strip() for x in f.readlines()]
        self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
        self.masks = [os.path.join(mask_dir, x + ".png") for x in file_names]
        assert len(self.images) == len(self.masks)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is the image segmentation.
        """
        img = Image.open(self.images[index]).convert("RGB")
        target = Image.open(self.masks[index])
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            # Map the 255 "void" label to -1 (conventionally ignored by the
            # loss) and return a LongTensor.  When target_transform is None
            # the PIL mask is returned unchanged.
            target = np.array(self.target_transform(target)).astype("int32")
            target[target == 255] = -1
            target = torch.from_numpy(target).long()
        return img, target

    @staticmethod
    def _mask_transform(mask):
        # Same 255 -> -1 void remapping as __getitem__, usable without a
        # target_transform.
        target = np.array(mask).astype("int32")
        target[target == 255] = -1
        return torch.from_numpy(target).long()

    def __len__(self):
        return len(self.images)

    @property
    def pred_offset(self):
        # Predictions map directly onto label ids (no index shift).
        return 0
class VOCClassification(data.Dataset):
    """`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ multi-label
    classification dataset derived from the segmentation annotations.

    ``__getitem__`` returns a binary label vector of length ``CLASSES``:
    entry ``c - 1`` is 1 iff class id ``c`` appears in the sample's
    segmentation mask (background 0 and void 255 are skipped).

    Args:
        root (string): Root directory of the VOC Dataset.
        year (string, optional): The dataset year, supports years 2007 to 2012.
        image_set (string, optional): Select the image_set to use, ``train``,
            ``trainval`` or ``val``.
        download (bool, optional): If true, downloads the dataset from the
            internet and puts it in root directory. If dataset is already
            downloaded, it is not downloaded again.
        transform (callable, optional): A *joint* transform applied as
            ``img, target = transform(img, target)`` — note this differs from
            the usual torchvision image-only ``transform``.
    """

    CLASSES = 20

    def __init__(
        self, root, year="2012", image_set="train", download=False, transform=None
    ):
        self.root = os.path.expanduser(root)
        self.year = year
        self.url = DATASET_YEAR_DICT[year]["url"]
        self.filename = DATASET_YEAR_DICT[year]["filename"]
        self.md5 = DATASET_YEAR_DICT[year]["md5"]
        self.transform = transform
        self.image_set = image_set
        base_dir = DATASET_YEAR_DICT[year]["base_dir"]
        voc_root = os.path.join(self.root, base_dir)
        image_dir = os.path.join(voc_root, "JPEGImages")
        mask_dir = os.path.join(voc_root, "SegmentationClass")
        if download:
            download_extract(self.url, self.root, self.filename, self.md5)
        if not os.path.isdir(voc_root):
            raise RuntimeError(
                "Dataset not found or corrupted."
                + " You can use download=True to download it"
            )
        # Join components individually for portability (was a hard-coded "/").
        splits_dir = os.path.join(voc_root, "ImageSets", "Segmentation")
        split_f = os.path.join(splits_dir, image_set.rstrip("\n") + ".txt")
        if not os.path.exists(split_f):
            raise ValueError(
                'Wrong image_set entered! Please use image_set="train" '
                'or image_set="trainval" or image_set="val"'
            )
        # split_f is already a full path; redundant os.path.join() removed.
        with open(split_f, "r") as f:
            file_names = [x.strip() for x in f.readlines()]
        self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
        self.masks = [os.path.join(mask_dir, x + ".png") for x in file_names]
        assert len(self.images) == len(self.masks)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, labels) where labels is a binary vector marking the
            classes visible in the sample's segmentation mask.
        """
        img = Image.open(self.images[index]).convert("RGB")
        target = Image.open(self.masks[index])
        if self.transform is not None:
            # Joint transform: receives and returns both image and mask.
            img, target = self.transform(img, target)
        # Mask values are class ids 1..20; 0 = background, 255 = void.
        # (renamed loop variable: `id` shadowed the builtin)
        visible_classes = np.unique(target)
        labels = torch.zeros(self.CLASSES)
        for class_id in visible_classes:
            if class_id not in (0, 255):
                labels[class_id - 1].fill_(1)
        return img, labels

    def __len__(self):
        return len(self.images)
class VOCSBDClassification(data.Dataset):
    """Multi-label classification dataset combining `Pascal VOC
    <http://host.robots.ox.ac.uk/pascal/VOC/>`_ with the SBD extra
    annotations.

    The sample list is the VOC ``image_set`` split concatenated with the SBD
    ``train.txt`` split; VOC masks are PNGs while SBD masks are MATLAB
    ``.mat`` files.  ``__getitem__`` returns a binary label vector of length
    ``CLASSES`` (class id ``c`` present in the mask -> entry ``c - 1`` is 1).

    Args:
        root (string): Root directory of the VOC Dataset.
        sbd_root (string): Root directory of the SBD dataset (contains
            ``img/``, ``cls/`` and ``train.txt``).
        year (string, optional): The dataset year, supports years 2007 to 2012.
        image_set (string, optional): Select the image_set to use, ``train``,
            ``trainval`` or ``val``.
        download (bool, optional): If true, downloads the VOC dataset from the
            internet and puts it in root directory. If dataset is already
            downloaded, it is not downloaded again.  (SBD is never downloaded.)
        transform (callable, optional): A *joint* transform applied as
            ``img, target = transform(img, target)``.
    """

    CLASSES = 20

    def __init__(
        self,
        root,
        sbd_root,
        year="2012",
        image_set="train",
        download=False,
        transform=None,
    ):
        self.root = os.path.expanduser(root)
        self.sbd_root = os.path.expanduser(sbd_root)
        self.year = year
        self.url = DATASET_YEAR_DICT[year]["url"]
        self.filename = DATASET_YEAR_DICT[year]["filename"]
        self.md5 = DATASET_YEAR_DICT[year]["md5"]
        self.transform = transform
        self.image_set = image_set
        base_dir = DATASET_YEAR_DICT[year]["base_dir"]
        voc_root = os.path.join(self.root, base_dir)
        image_dir = os.path.join(voc_root, "JPEGImages")
        mask_dir = os.path.join(voc_root, "SegmentationClass")
        sbd_image_dir = os.path.join(sbd_root, "img")
        sbd_mask_dir = os.path.join(sbd_root, "cls")
        if download:
            download_extract(self.url, self.root, self.filename, self.md5)
        if not os.path.isdir(voc_root):
            raise RuntimeError(
                "Dataset not found or corrupted."
                + " You can use download=True to download it"
            )
        # Join components individually for portability (was a hard-coded "/").
        splits_dir = os.path.join(voc_root, "ImageSets", "Segmentation")
        split_f = os.path.join(splits_dir, image_set.rstrip("\n") + ".txt")
        # SBD always contributes its training split, regardless of image_set.
        sbd_split = os.path.join(sbd_root, "train.txt")
        if not os.path.exists(split_f):
            raise ValueError(
                'Wrong image_set entered! Please use image_set="train" '
                'or image_set="trainval" or image_set="val"'
            )
        # Both paths are already fully joined; redundant os.path.join() removed.
        with open(split_f, "r") as f:
            voc_file_names = [x.strip() for x in f.readlines()]
        with open(sbd_split, "r") as f:
            sbd_file_names = [x.strip() for x in f.readlines()]
        self.images = [os.path.join(image_dir, x + ".jpg") for x in voc_file_names]
        self.images += [os.path.join(sbd_image_dir, x + ".jpg") for x in sbd_file_names]
        self.masks = [os.path.join(mask_dir, x + ".png") for x in voc_file_names]
        self.masks += [os.path.join(sbd_mask_dir, x + ".mat") for x in sbd_file_names]
        assert len(self.images) == len(self.masks)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, labels) where labels is a binary vector marking the
            classes visible in the sample's segmentation mask.
        """
        img = Image.open(self.images[index]).convert("RGB")
        mask_path = self.masks[index]
        # Use a proper suffix test (was the fragile `mask_path[-3:] == "mat"`).
        if mask_path.endswith(".mat"):
            # SBD stores the mask inside a MATLAB struct; extract the class
            # segmentation array and wrap it as a palette image so both mask
            # sources feed the transform identically.
            target = io.loadmat(mask_path, struct_as_record=False, squeeze_me=True)[
                "GTcls"
            ].Segmentation
            target = Image.fromarray(target, mode="P")
        else:
            # Reuse mask_path instead of re-indexing self.masks.
            target = Image.open(mask_path)
        if self.transform is not None:
            # Joint transform: receives and returns both image and mask.
            img, target = self.transform(img, target)
        # Mask values are class ids 1..20; 0 = background, 255 = void.
        visible_classes = np.unique(target)
        labels = torch.zeros(self.CLASSES)
        for class_id in visible_classes:
            if class_id not in (0, 255):
                labels[class_id - 1].fill_(1)
        return img, labels

    def __len__(self):
        return len(self.images)
def download_extract(url, root, filename, md5):
    """Download the archive at ``url`` into ``root`` as ``filename`` and
    extract it in place.

    ``download_url`` verifies ``md5`` and skips the download when a file with
    a matching checksum already exists.
    """
    download_url(url, root, filename, md5)
    # NOTE(review): extractall() trusts member paths inside the archive; for
    # untrusted archives consider tarfile's ``filter="data"`` (Python 3.12+,
    # default in 3.14) to guard against path traversal — confirm the minimum
    # supported Python version before adding it.
    with tarfile.open(os.path.join(root, filename), "r") as tar:
        tar.extractall(path=root)
class VOCResults(data.Dataset):
    """Dataset that reads precomputed per-sample results back from a
    ``results.hdf5`` file located under ``path``.

    Each item is a 4-tuple of tensors: (image, vis, target, class_pred).
    """

    CLASSES = 20
    CLASSES_NAMES = [
        "aeroplane",
        "bicycle",
        "bird",
        "boat",
        "bottle",
        "bus",
        "car",
        "cat",
        "chair",
        "cow",
        "diningtable",
        "dog",
        "horse",
        "motorbike",
        "person",
        "potted-plant",
        "sheep",
        "sofa",
        "train",
        "tvmonitor",
        "ambigious",
    ]

    def __init__(self, path):
        super(VOCResults, self).__init__()
        self.path = os.path.join(path, "results.hdf5")
        # The handle is opened lazily in __getitem__ so that each DataLoader
        # worker process ends up with its own h5py file object.
        self.data = None
        print("Reading dataset length...")
        # Open read-only just long enough to learn how many samples exist.
        with h5py.File(self.path, "r") as f:
            self.data_length = len(f["/image"])

    def __len__(self):
        return self.data_length

    def __getitem__(self, item):
        if self.data is None:
            self.data = h5py.File(self.path, "r")
        record = self.data
        image, vis, target, class_pred = (
            torch.tensor(record[key][item])
            for key in ("image", "vis", "target", "class_pred")
        )
        return image, vis, target, class_pred