sayakpaul's picture
sayakpaul HF staff
add files
c4b2b37
raw
history blame
No virus
14.2 kB
import os
import tarfile
import h5py
import numpy as np
import torch
import torch.utils.data as data
from PIL import Image
from scipy import io
from torchvision.datasets.utils import download_url
# Download metadata per Pascal VOC release: archive URL, the local filename to
# save it under, its expected MD5 checksum, and the directory the tarball
# extracts into (2011 nests under an extra "TrainVal/" prefix).
DATASET_YEAR_DICT = {
    "2012": {
        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar",
        "filename": "VOCtrainval_11-May-2012.tar",
        "md5": "6cd6e144f989b92b3379bac3b3de84fd",
        "base_dir": "VOCdevkit/VOC2012",
    },
    "2011": {
        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2011/VOCtrainval_25-May-2011.tar",
        "filename": "VOCtrainval_25-May-2011.tar",
        "md5": "6c3384ef61512963050cb5d687e5bf1e",
        "base_dir": "TrainVal/VOCdevkit/VOC2011",
    },
    "2010": {
        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2010/VOCtrainval_03-May-2010.tar",
        "filename": "VOCtrainval_03-May-2010.tar",
        "md5": "da459979d0c395079b5c75ee67908abb",
        "base_dir": "VOCdevkit/VOC2010",
    },
    "2009": {
        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2009/VOCtrainval_11-May-2009.tar",
        "filename": "VOCtrainval_11-May-2009.tar",
        "md5": "59065e4b188729180974ef6572f6a212",
        "base_dir": "VOCdevkit/VOC2009",
    },
    "2008": {
        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2008/VOCtrainval_14-Jul-2008.tar",
        # Fixed: previously listed the 2012 archive name, so year="2008"
        # downloads were saved/extracted under the wrong filename.
        "filename": "VOCtrainval_14-Jul-2008.tar",
        "md5": "2629fa636546599198acfcfbfcf1904a",
        "base_dir": "VOCdevkit/VOC2008",
    },
    "2007": {
        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar",
        "filename": "VOCtrainval_06-Nov-2007.tar",
        "md5": "c52e279531787c972589f7e41ab4ae64",
        "base_dir": "VOCdevkit/VOC2007",
    },
}
class VOCSegmentation(data.Dataset):
    """`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Segmentation Dataset.

    Pairs each JPEG image with its PNG class-segmentation mask from the
    ``SegmentationClass`` directory.

    Args:
        root (string): Root directory of the VOC Dataset.
        year (string, optional): The dataset year, supports years 2007 to 2012.
        image_set (string, optional): One of ``train``, ``trainval`` or ``val``.
        download (bool, optional): If true, downloads the dataset archive into
            ``root``; an already-downloaded dataset is not fetched again.
        transform (callable, optional): Transform applied to the PIL image.
        target_transform (callable, optional): Transform applied to the mask.
    """

    # 20 foreground classes; mask value 0 is background, 255 is void/border.
    CLASSES = 20
    CLASSES_NAMES = [
        "aeroplane",
        "bicycle",
        "bird",
        "boat",
        "bottle",
        "bus",
        "car",
        "cat",
        "chair",
        "cow",
        "diningtable",
        "dog",
        "horse",
        "motorbike",
        "person",
        "potted-plant",
        "sheep",
        "sofa",
        "train",
        "tvmonitor",
        "ambigious",
    ]

    def __init__(
        self,
        root,
        year="2012",
        image_set="train",
        download=False,
        transform=None,
        target_transform=None,
    ):
        meta = DATASET_YEAR_DICT[year]
        self.root = os.path.expanduser(root)
        self.year = year
        self.url = meta["url"]
        self.filename = meta["filename"]
        self.md5 = meta["md5"]
        self.transform = transform
        self.target_transform = target_transform
        self.image_set = image_set

        voc_root = os.path.join(self.root, meta["base_dir"])
        image_dir = os.path.join(voc_root, "JPEGImages")
        mask_dir = os.path.join(voc_root, "SegmentationClass")

        if download:
            download_extract(self.url, self.root, self.filename, self.md5)
        if not os.path.isdir(voc_root):
            raise RuntimeError(
                "Dataset not found or corrupted."
                " You can use download=True to download it"
            )

        splits_dir = os.path.join(voc_root, "ImageSets/Segmentation")
        split_f = os.path.join(splits_dir, image_set.rstrip("\n") + ".txt")
        if not os.path.exists(split_f):
            raise ValueError(
                'Wrong image_set entered! Please use image_set="train" '
                'or image_set="trainval" or image_set="val"'
            )

        with open(split_f, "r") as handle:
            names = [line.strip() for line in handle]
        self.images = [os.path.join(image_dir, name + ".jpg") for name in names]
        self.masks = [os.path.join(mask_dir, name + ".png") for name in names]
        assert len(self.images) == len(self.masks)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target). The raw PIL mask is returned unless
            ``target_transform`` is set, in which case the transformed mask
            is converted to a long tensor with void pixels (255) set to -1.
        """
        img = Image.open(self.images[index]).convert("RGB")
        target = Image.open(self.masks[index])
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is None:
            return img, target
        target = np.array(self.target_transform(target)).astype("int32")
        target[target == 255] = -1
        return img, torch.from_numpy(target).long()

    @staticmethod
    def _mask_transform(mask):
        # Same remapping as __getitem__: void label 255 becomes ignore index -1.
        remapped = np.array(mask).astype("int32")
        remapped[remapped == 255] = -1
        return torch.from_numpy(remapped).long()

    def __len__(self):
        return len(self.images)

    @property
    def pred_offset(self):
        # Offset added to predicted class ids; zero for this dataset.
        return 0
class VOCClassification(data.Dataset):
    """`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ multi-label classification dataset.

    Built on the segmentation split: the target for each image is a multi-hot
    vector marking which of the 20 classes appear in its segmentation mask.

    Args:
        root (string): Root directory of the VOC Dataset.
        year (string, optional): The dataset year, supports years 2007 to 2012.
        image_set (string, optional): Select the image_set to use, ``train``, ``trainval`` or ``val``
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): Joint transform applied as
            ``transform(img, mask)``, returning the transformed pair.
    """

    # Number of foreground classes; background (0) and void (255) are excluded.
    CLASSES = 20

    def __init__(
        self, root, year="2012", image_set="train", download=False, transform=None
    ):
        self.root = os.path.expanduser(root)
        self.year = year
        self.url = DATASET_YEAR_DICT[year]["url"]
        self.filename = DATASET_YEAR_DICT[year]["filename"]
        self.md5 = DATASET_YEAR_DICT[year]["md5"]
        self.transform = transform
        self.image_set = image_set
        base_dir = DATASET_YEAR_DICT[year]["base_dir"]
        voc_root = os.path.join(self.root, base_dir)
        image_dir = os.path.join(voc_root, "JPEGImages")
        mask_dir = os.path.join(voc_root, "SegmentationClass")
        if download:
            download_extract(self.url, self.root, self.filename, self.md5)
        if not os.path.isdir(voc_root):
            raise RuntimeError(
                "Dataset not found or corrupted."
                + " You can use download=True to download it"
            )
        splits_dir = os.path.join(voc_root, "ImageSets/Segmentation")
        split_f = os.path.join(splits_dir, image_set.rstrip("\n") + ".txt")
        if not os.path.exists(split_f):
            raise ValueError(
                'Wrong image_set entered! Please use image_set="train" '
                'or image_set="trainval" or image_set="val"'
            )
        with open(split_f, "r") as f:
            file_names = [x.strip() for x in f.readlines()]
        self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
        self.masks = [os.path.join(mask_dir, x + ".png") for x in file_names]
        assert len(self.images) == len(self.masks)

    @staticmethod
    def _labels_from_target(target, num_classes):
        """Build a multi-hot label vector from the class ids present in *target*."""
        labels = torch.zeros(num_classes)
        # ``class_id`` instead of the original ``id`` to avoid shadowing the builtin.
        for class_id in np.unique(target):
            # 0 is background and 255 is the void/border label -- neither is a class.
            if class_id not in (0, 255):
                labels[class_id - 1] = 1
        return labels

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, labels) where labels is a ``CLASSES``-dim multi-hot tensor.
        """
        img = Image.open(self.images[index]).convert("RGB")
        target = Image.open(self.masks[index])
        if self.transform is not None:
            # Joint transform so image and mask stay spatially aligned.
            img, target = self.transform(img, target)
        return img, self._labels_from_target(target, self.CLASSES)

    def __len__(self):
        return len(self.images)
class VOCSBDClassification(data.Dataset):
    """`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Segmentation Dataset.

    VOC images merged with the SBD extra annotations (``img``/``cls`` layout);
    targets are multi-hot class-presence vectors derived from the masks.

    Args:
        root (string): Root directory of the VOC Dataset.
        sbd_root (string): Root directory of the SBD dataset.
        year (string, optional): The dataset year, supports years 2007 to 2012.
        image_set (string, optional): One of ``train``, ``trainval`` or ``val``.
        download (bool, optional): If true, downloads the VOC archive into
            ``root``; an already-downloaded dataset is not fetched again.
        transform (callable, optional): Joint transform applied as
            ``transform(img, mask)``.
    """

    CLASSES = 20

    def __init__(
        self,
        root,
        sbd_root,
        year="2012",
        image_set="train",
        download=False,
        transform=None,
    ):
        meta = DATASET_YEAR_DICT[year]
        self.root = os.path.expanduser(root)
        self.sbd_root = os.path.expanduser(sbd_root)
        self.year = year
        self.url = meta["url"]
        self.filename = meta["filename"]
        self.md5 = meta["md5"]
        self.transform = transform
        self.image_set = image_set

        voc_root = os.path.join(self.root, meta["base_dir"])
        image_dir = os.path.join(voc_root, "JPEGImages")
        mask_dir = os.path.join(voc_root, "SegmentationClass")
        # NOTE(review): SBD paths are built from the raw ``sbd_root`` argument,
        # not the expanduser'd ``self.sbd_root`` -- confirm this is intended.
        sbd_image_dir = os.path.join(sbd_root, "img")
        sbd_mask_dir = os.path.join(sbd_root, "cls")

        if download:
            download_extract(self.url, self.root, self.filename, self.md5)
        if not os.path.isdir(voc_root):
            raise RuntimeError(
                "Dataset not found or corrupted."
                " You can use download=True to download it"
            )

        splits_dir = os.path.join(voc_root, "ImageSets/Segmentation")
        split_f = os.path.join(splits_dir, image_set.rstrip("\n") + ".txt")
        sbd_split = os.path.join(sbd_root, "train.txt")
        if not os.path.exists(split_f):
            raise ValueError(
                'Wrong image_set entered! Please use image_set="train" '
                'or image_set="trainval" or image_set="val"'
            )

        with open(split_f, "r") as handle:
            voc_names = [line.strip() for line in handle]
        with open(sbd_split, "r") as handle:
            sbd_names = [line.strip() for line in handle]

        # VOC entries first, then the SBD extras appended in the same order.
        self.images = [os.path.join(image_dir, n + ".jpg") for n in voc_names] + [
            os.path.join(sbd_image_dir, n + ".jpg") for n in sbd_names
        ]
        self.masks = [os.path.join(mask_dir, n + ".png") for n in voc_names] + [
            os.path.join(sbd_mask_dir, n + ".mat") for n in sbd_names
        ]
        assert len(self.images) == len(self.masks)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, labels) where labels is a multi-hot class vector.
        """
        img = Image.open(self.images[index]).convert("RGB")
        mask_path = self.masks[index]
        if mask_path.endswith("mat"):
            # SBD masks are MATLAB structs; pull out the segmentation map.
            mat = io.loadmat(mask_path, struct_as_record=False, squeeze_me=True)
            target = Image.fromarray(mat["GTcls"].Segmentation, mode="P")
        else:
            target = Image.open(mask_path)
        if self.transform is not None:
            img, target = self.transform(img, target)
        labels = torch.zeros(self.CLASSES)
        for cls_id in np.unique(target):
            # Skip background (0) and void/ambiguous (255) mask values.
            if cls_id not in (0, 255):
                labels[cls_id - 1] = 1
        return img, labels

    def __len__(self):
        return len(self.images)
def download_extract(url, root, filename, md5):
    """Download the archive at *url* into *root* (verifying *md5*) and unpack it there."""
    download_url(url, root, filename, md5)
    archive = os.path.join(root, filename)
    # NOTE(review): extractall trusts member paths inside the tarball; for
    # untrusted archives consider tarfile's extraction filters.
    with tarfile.open(archive, "r") as tar:
        tar.extractall(path=root)
class VOCResults(data.Dataset):
    """Dataset over precomputed results stored in ``<path>/results.hdf5``.

    Each item yields the stored ``image``, ``vis``, ``target`` and
    ``class_pred`` entries as tensors.
    """

    CLASSES = 20
    CLASSES_NAMES = [
        "aeroplane",
        "bicycle",
        "bird",
        "boat",
        "bottle",
        "bus",
        "car",
        "cat",
        "chair",
        "cow",
        "diningtable",
        "dog",
        "horse",
        "motorbike",
        "person",
        "potted-plant",
        "sheep",
        "sofa",
        "train",
        "tvmonitor",
        "ambigious",
    ]

    def __init__(self, path):
        super(VOCResults, self).__init__()
        self.path = os.path.join(path, "results.hdf5")
        # File handle is opened lazily in __getitem__.
        self.data = None
        print("Reading dataset length...")
        with h5py.File(self.path, "r") as handle:
            self.data_length = len(handle["/image"])

    def __len__(self):
        return self.data_length

    def __getitem__(self, item):
        if self.data is None:
            # Lazy open on first access -- presumably so forked DataLoader
            # workers don't share one HDF5 handle; verify against usage.
            self.data = h5py.File(self.path, "r")
        image, vis, target, class_pred = (
            torch.tensor(self.data[key][item])
            for key in ("image", "vis", "target", "class_pred")
        )
        return image, vis, target, class_pred