|
import os |
|
import tarfile |
|
import torch |
|
import torch.utils.data as data |
|
import numpy as np |
|
import h5py |
|
|
|
from PIL import Image |
|
from scipy import io |
|
from torchvision.datasets.utils import download_url |
|
|
|
# Download metadata for each Pascal VOC release: archive URL, the local
# filename to save it as, its expected md5 checksum, and the directory the
# tar unpacks into (relative to the dataset root).
DATASET_YEAR_DICT = {
    '2012': {
        'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar',
        'filename': 'VOCtrainval_11-May-2012.tar',
        'md5': '6cd6e144f989b92b3379bac3b3de84fd',
        'base_dir': 'VOCdevkit/VOC2012'
    },
    '2011': {
        'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2011/VOCtrainval_25-May-2011.tar',
        'filename': 'VOCtrainval_25-May-2011.tar',
        'md5': '6c3384ef61512963050cb5d687e5bf1e',
        'base_dir': 'TrainVal/VOCdevkit/VOC2011'
    },
    '2010': {
        'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2010/VOCtrainval_03-May-2010.tar',
        'filename': 'VOCtrainval_03-May-2010.tar',
        'md5': 'da459979d0c395079b5c75ee67908abb',
        'base_dir': 'VOCdevkit/VOC2010'
    },
    '2009': {
        'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2009/VOCtrainval_11-May-2009.tar',
        'filename': 'VOCtrainval_11-May-2009.tar',
        'md5': '59065e4b188729180974ef6572f6a212',
        'base_dir': 'VOCdevkit/VOC2009'
    },
    '2008': {
        'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2008/VOCtrainval_14-Jul-2008.tar',
        # Fixed: previously this was the 2012 filename ('VOCtrainval_11-May-2012.tar'),
        # so the 2008 download would be saved under the wrong name and fail
        # its md5 check. The filename now matches the URL's basename.
        'filename': 'VOCtrainval_14-Jul-2008.tar',
        'md5': '2629fa636546599198acfcfbfcf1904a',
        'base_dir': 'VOCdevkit/VOC2008'
    },
    '2007': {
        'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar',
        'filename': 'VOCtrainval_06-Nov-2007.tar',
        'md5': 'c52e279531787c972589f7e41ab4ae64',
        'base_dir': 'VOCdevkit/VOC2007'
    }
}
|
|
|
|
|
class VOCSegmentation(data.Dataset): |
|
"""`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Segmentation Dataset. |
|
|
|
Args: |
|
root (string): Root directory of the VOC Dataset. |
|
year (string, optional): The dataset year, supports years 2007 to 2012. |
|
image_set (string, optional): Select the image_set to use, ``train``, ``trainval`` or ``val`` |
|
download (bool, optional): If true, downloads the dataset from the internet and |
|
puts it in root directory. If dataset is already downloaded, it is not |
|
downloaded again. |
|
transform (callable, optional): A function/transform that takes in an PIL image |
|
and returns a transformed version. E.g, ``transforms.RandomCrop`` |
|
target_transform (callable, optional): A function/transform that takes in the |
|
target and transforms it. |
|
""" |
|
|
|
CLASSES = 20 |
|
CLASSES_NAMES = [ |
|
'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', |
|
'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', |
|
'motorbike', 'person', 'potted-plant', 'sheep', 'sofa', 'train', |
|
'tvmonitor', 'ambigious' |
|
] |
|
|
|
def __init__(self, |
|
root, |
|
year='2012', |
|
image_set='train', |
|
download=False, |
|
transform=None, |
|
target_transform=None): |
|
self.root = os.path.expanduser(root) |
|
self.year = year |
|
self.url = DATASET_YEAR_DICT[year]['url'] |
|
self.filename = DATASET_YEAR_DICT[year]['filename'] |
|
self.md5 = DATASET_YEAR_DICT[year]['md5'] |
|
self.transform = transform |
|
self.target_transform = target_transform |
|
self.image_set = image_set |
|
base_dir = DATASET_YEAR_DICT[year]['base_dir'] |
|
voc_root = os.path.join(self.root, base_dir) |
|
image_dir = os.path.join(voc_root, 'JPEGImages') |
|
mask_dir = os.path.join(voc_root, 'SegmentationClass') |
|
|
|
if download: |
|
download_extract(self.url, self.root, self.filename, self.md5) |
|
|
|
if not os.path.isdir(voc_root): |
|
raise RuntimeError('Dataset not found or corrupted.' + |
|
' You can use download=True to download it') |
|
|
|
splits_dir = os.path.join(voc_root, 'ImageSets/Segmentation') |
|
|
|
split_f = os.path.join(splits_dir, image_set.rstrip('\n') + '.txt') |
|
|
|
if not os.path.exists(split_f): |
|
raise ValueError( |
|
'Wrong image_set entered! Please use image_set="train" ' |
|
'or image_set="trainval" or image_set="val"') |
|
|
|
with open(os.path.join(split_f), "r") as f: |
|
file_names = [x.strip() for x in f.readlines()] |
|
|
|
self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names] |
|
self.masks = [os.path.join(mask_dir, x + ".png") for x in file_names] |
|
assert (len(self.images) == len(self.masks)) |
|
|
|
def __getitem__(self, index): |
|
""" |
|
Args: |
|
index (int): Index |
|
|
|
Returns: |
|
tuple: (image, target) where target is the image segmentation. |
|
""" |
|
img = Image.open(self.images[index]).convert('RGB') |
|
target = Image.open(self.masks[index]) |
|
|
|
if self.transform is not None: |
|
img = self.transform(img) |
|
|
|
if self.target_transform is not None: |
|
target = np.array(self.target_transform(target)).astype('int32') |
|
target[target == 255] = -1 |
|
target = torch.from_numpy(target).long() |
|
|
|
return img, target |
|
|
|
@staticmethod |
|
def _mask_transform(mask): |
|
target = np.array(mask).astype('int32') |
|
target[target == 255] = -1 |
|
return torch.from_numpy(target).long() |
|
|
|
def __len__(self): |
|
return len(self.images) |
|
|
|
@property |
|
def pred_offset(self): |
|
return 0 |
|
|
|
|
|
class VOCClassification(data.Dataset): |
|
"""`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Segmentation Dataset. |
|
|
|
Args: |
|
root (string): Root directory of the VOC Dataset. |
|
year (string, optional): The dataset year, supports years 2007 to 2012. |
|
image_set (string, optional): Select the image_set to use, ``train``, ``trainval`` or ``val`` |
|
download (bool, optional): If true, downloads the dataset from the internet and |
|
puts it in root directory. If dataset is already downloaded, it is not |
|
downloaded again. |
|
transform (callable, optional): A function/transform that takes in an PIL image |
|
and returns a transformed version. E.g, ``transforms.RandomCrop`` |
|
""" |
|
CLASSES = 20 |
|
|
|
def __init__(self, |
|
root, |
|
year='2012', |
|
image_set='train', |
|
download=False, |
|
transform=None): |
|
self.root = os.path.expanduser(root) |
|
self.year = year |
|
self.url = DATASET_YEAR_DICT[year]['url'] |
|
self.filename = DATASET_YEAR_DICT[year]['filename'] |
|
self.md5 = DATASET_YEAR_DICT[year]['md5'] |
|
self.transform = transform |
|
self.image_set = image_set |
|
base_dir = DATASET_YEAR_DICT[year]['base_dir'] |
|
voc_root = os.path.join(self.root, base_dir) |
|
image_dir = os.path.join(voc_root, 'JPEGImages') |
|
mask_dir = os.path.join(voc_root, 'SegmentationClass') |
|
|
|
if download: |
|
download_extract(self.url, self.root, self.filename, self.md5) |
|
|
|
if not os.path.isdir(voc_root): |
|
raise RuntimeError('Dataset not found or corrupted.' + |
|
' You can use download=True to download it') |
|
|
|
splits_dir = os.path.join(voc_root, 'ImageSets/Segmentation') |
|
|
|
split_f = os.path.join(splits_dir, image_set.rstrip('\n') + '.txt') |
|
|
|
if not os.path.exists(split_f): |
|
raise ValueError( |
|
'Wrong image_set entered! Please use image_set="train" ' |
|
'or image_set="trainval" or image_set="val"') |
|
|
|
with open(os.path.join(split_f), "r") as f: |
|
file_names = [x.strip() for x in f.readlines()] |
|
|
|
self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names] |
|
self.masks = [os.path.join(mask_dir, x + ".png") for x in file_names] |
|
assert (len(self.images) == len(self.masks)) |
|
|
|
def __getitem__(self, index): |
|
""" |
|
Args: |
|
index (int): Index |
|
|
|
Returns: |
|
tuple: (image, target) where target is the image segmentation. |
|
""" |
|
img = Image.open(self.images[index]).convert('RGB') |
|
target = Image.open(self.masks[index]) |
|
|
|
|
|
|
|
if self.transform is not None: |
|
img, target = self.transform(img, target) |
|
|
|
visible_classes = np.unique(target) |
|
labels = torch.zeros(self.CLASSES) |
|
for id in visible_classes: |
|
if id not in (0, 255): |
|
labels[id - 1].fill_(1) |
|
|
|
return img, labels |
|
|
|
def __len__(self): |
|
return len(self.images) |
|
|
|
|
|
class VOCSBDClassification(data.Dataset): |
|
"""`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Segmentation Dataset. |
|
|
|
Args: |
|
root (string): Root directory of the VOC Dataset. |
|
year (string, optional): The dataset year, supports years 2007 to 2012. |
|
image_set (string, optional): Select the image_set to use, ``train``, ``trainval`` or ``val`` |
|
download (bool, optional): If true, downloads the dataset from the internet and |
|
puts it in root directory. If dataset is already downloaded, it is not |
|
downloaded again. |
|
transform (callable, optional): A function/transform that takes in an PIL image |
|
and returns a transformed version. E.g, ``transforms.RandomCrop`` |
|
""" |
|
CLASSES = 20 |
|
|
|
def __init__(self, |
|
root, |
|
sbd_root, |
|
year='2012', |
|
image_set='train', |
|
download=False, |
|
transform=None): |
|
self.root = os.path.expanduser(root) |
|
self.sbd_root = os.path.expanduser(sbd_root) |
|
self.year = year |
|
self.url = DATASET_YEAR_DICT[year]['url'] |
|
self.filename = DATASET_YEAR_DICT[year]['filename'] |
|
self.md5 = DATASET_YEAR_DICT[year]['md5'] |
|
self.transform = transform |
|
self.image_set = image_set |
|
base_dir = DATASET_YEAR_DICT[year]['base_dir'] |
|
voc_root = os.path.join(self.root, base_dir) |
|
image_dir = os.path.join(voc_root, 'JPEGImages') |
|
mask_dir = os.path.join(voc_root, 'SegmentationClass') |
|
sbd_image_dir = os.path.join(sbd_root, 'img') |
|
sbd_mask_dir = os.path.join(sbd_root, 'cls') |
|
|
|
if download: |
|
download_extract(self.url, self.root, self.filename, self.md5) |
|
|
|
if not os.path.isdir(voc_root): |
|
raise RuntimeError('Dataset not found or corrupted.' + |
|
' You can use download=True to download it') |
|
|
|
splits_dir = os.path.join(voc_root, 'ImageSets/Segmentation') |
|
|
|
split_f = os.path.join(splits_dir, image_set.rstrip('\n') + '.txt') |
|
sbd_split = os.path.join(sbd_root, 'train.txt') |
|
|
|
if not os.path.exists(split_f): |
|
raise ValueError( |
|
'Wrong image_set entered! Please use image_set="train" ' |
|
'or image_set="trainval" or image_set="val"') |
|
|
|
with open(os.path.join(split_f), "r") as f: |
|
voc_file_names = [x.strip() for x in f.readlines()] |
|
|
|
with open(os.path.join(sbd_split), "r") as f: |
|
sbd_file_names = [x.strip() for x in f.readlines()] |
|
|
|
self.images = [os.path.join(image_dir, x + ".jpg") for x in voc_file_names] |
|
self.images += [os.path.join(sbd_image_dir, x + ".jpg") for x in sbd_file_names] |
|
self.masks = [os.path.join(mask_dir, x + ".png") for x in voc_file_names] |
|
self.masks += [os.path.join(sbd_mask_dir, x + ".mat") for x in sbd_file_names] |
|
assert (len(self.images) == len(self.masks)) |
|
|
|
def __getitem__(self, index): |
|
""" |
|
Args: |
|
index (int): Index |
|
|
|
Returns: |
|
tuple: (image, target) where target is the image segmentation. |
|
""" |
|
img = Image.open(self.images[index]).convert('RGB') |
|
mask_path = self.masks[index] |
|
if mask_path[-3:] == 'mat': |
|
target = io.loadmat(mask_path, struct_as_record=False, squeeze_me=True)['GTcls'].Segmentation |
|
target = Image.fromarray(target, mode='P') |
|
else: |
|
target = Image.open(self.masks[index]) |
|
|
|
if self.transform is not None: |
|
img, target = self.transform(img, target) |
|
|
|
visible_classes = np.unique(target) |
|
labels = torch.zeros(self.CLASSES) |
|
for id in visible_classes: |
|
if id not in (0, 255): |
|
labels[id - 1].fill_(1) |
|
|
|
return img, labels |
|
|
|
def __len__(self): |
|
return len(self.images) |
|
|
|
|
|
def download_extract(url, root, filename, md5):
    """Download *url* into *root* as *filename*, verify *md5*, and unpack the tar there."""
    download_url(url, root, filename, md5)
    archive_path = os.path.join(root, filename)
    # NOTE(review): extractall trusts member paths in the archive; for
    # untrusted tars consider the filter='data' argument (Python 3.12+).
    with tarfile.open(archive_path, "r") as tar:
        tar.extractall(path=root)
|
|
|
|
|
class VOCResults(data.Dataset):
    """Dataset over precomputed VOC results stored in a ``results.hdf5`` file.

    Each item is ``(image, vis, target, class_pred)``, read from the
    corresponding HDF5 datasets and wrapped as tensors.
    """

    # Number of foreground classes (background and void excluded).
    CLASSES = 20
    # Index i maps to VOC class id i+1; the trailing 'ambigious' entry
    # corresponds to the 255 void label.
    CLASSES_NAMES = [
        'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
        'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
        'motorbike', 'person', 'potted-plant', 'sheep', 'sofa', 'train',
        'tvmonitor', 'ambigious'
    ]

    def __init__(self, path):
        super(VOCResults, self).__init__()

        self.path = os.path.join(path, 'results.hdf5')
        # Opened lazily in __getitem__ so each DataLoader worker gets its
        # own file handle.
        self.data = None

        print('Reading dataset length...')
        with h5py.File(self.path, 'r') as f:
            self.data_length = len(f['/image'])

    def __len__(self):
        return self.data_length

    def __getitem__(self, item):
        if self.data is None:
            self.data = h5py.File(self.path, 'r')

        fields = ('image', 'vis', 'target', 'class_pred')
        image, vis, target, class_pred = (
            torch.tensor(self.data[name][item]) for name in fields
        )

        return image, vis, target, class_pred
|
|