from __future__ import print_function, division

import os
import random

import numpy as np
from PIL import Image, ImageFile
from torch.utils.data import Dataset

from .mypath_cihp import Path
from .mypath_pascal import Path as PP
from .mypath_atr import Path as PA

# Some dataset images are truncated on disk; allow PIL to load them anyway.
ImageFile.LOAD_TRUNCATED_IMAGES = True


class VOCSegmentation(Dataset):
    """
    Combined human-parsing dataset that serves samples from CIHP,
    PASCAL-Person-Part, and ATR under a single index space.
    """
    def __init__(self,
                 cihp_dir=Path.db_root_dir('cihp'),
                 split='train',
                 transform=None,
                 flip=False,
                 pascal_dir=PP.db_root_dir('pascal'),
                 atr_dir=PA.db_root_dir('atr'),
                 ):
        """
        :param cihp_dir: path to the CIHP dataset directory
        :param split: train/val
        :param transform: transform to apply to each sample
        :param flip: if True, randomly mirror images (and labels) at load time
        :param pascal_dir: path to the PASCAL-Person-Part dataset directory
        :param atr_dir: path to the ATR dataset directory
        """
        super(VOCSegmentation, self).__init__()
        # CIHP
        self._flip_flag = flip
        self._base_dir = cihp_dir
        self._image_dir = os.path.join(self._base_dir, 'Images')
        self._cat_dir = os.path.join(self._base_dir, 'Category_ids')
        self._flip_dir = os.path.join(self._base_dir, 'Category_rev_ids')
        # PASCAL-Person-Part (no pre-flipped label directory)
        self._base_dir_pascal = pascal_dir
        self._image_dir_pascal = os.path.join(self._base_dir_pascal, 'JPEGImages')
        self._cat_dir_pascal = os.path.join(self._base_dir_pascal, 'SegmentationPart')
        # ATR
        self._base_dir_atr = atr_dir
        self._image_dir_atr = os.path.join(self._base_dir_atr, 'JPEGImages')
        self._cat_dir_atr = os.path.join(self._base_dir_atr, 'SegmentationClassAug')
        self._flip_dir_atr = os.path.join(self._base_dir_atr, 'SegmentationClassAug_rev')
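        # The *_rev directories hold horizontally mirrored label maps in which
        # left/right part labels are swapped, so a flipped image can be paired
        # with a semantically correct target. PASCAL has no such directory, so
        # its labels are mirrored on the fly in _make_img_gt_point_pair().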
        if isinstance(split, str):
            self.split = [split]
        else:
            split.sort()
            self.split = split
        self.transform = transform
        _splits_dir = os.path.join(self._base_dir, 'lists')
        _splits_dir_pascal = os.path.join(self._base_dir_pascal, 'list')
        _splits_dir_atr = os.path.join(self._base_dir_atr, 'list')
        self.im_ids = []
        self.images = []
        self.categories = []
        self.flip_categories = []
        self.datasets_lbl = []
        # per-dataset sample counts
        self.num_cihp = 0
        self.num_pascal = 0
        self.num_atr = 0
        # CIHP samples carry dataset label 0
        for splt in self.split:
            with open(os.path.join(_splits_dir, splt + '_id.txt'), "r") as f:
                lines = f.read().splitlines()
            self.num_cihp += len(lines)
            for ii, line in enumerate(lines):
                _image = os.path.join(self._image_dir, line + '.jpg')
                _cat = os.path.join(self._cat_dir, line + '.png')
                _flip = os.path.join(self._flip_dir, line + '.png')
                assert os.path.isfile(_image)
                assert os.path.isfile(_cat)
                assert os.path.isfile(_flip)
                self.im_ids.append(line)
                self.images.append(_image)
                self.categories.append(_cat)
                self.flip_categories.append(_flip)
                self.datasets_lbl.append(0)
        # PASCAL-Person-Part samples carry dataset label 1
        for splt in self.split:
            if splt == 'test':
                splt = 'val'  # PASCAL lists: reuse val for test
            with open(os.path.join(_splits_dir_pascal, splt + '_id.txt'), "r") as f:
                lines = f.read().splitlines()
            self.num_pascal += len(lines)
            for ii, line in enumerate(lines):
                _image = os.path.join(self._image_dir_pascal, line + '.jpg')
                _cat = os.path.join(self._cat_dir_pascal, line + '.png')
                assert os.path.isfile(_image)
                assert os.path.isfile(_cat)
                self.im_ids.append(line)
                self.images.append(_image)
                self.categories.append(_cat)
                self.flip_categories.append([])  # placeholder: no pre-flipped labels
                self.datasets_lbl.append(1)
        # ATR samples carry dataset label 2
        for splt in self.split:
            with open(os.path.join(_splits_dir_atr, splt + '_id.txt'), "r") as f:
                lines = f.read().splitlines()
            self.num_atr += len(lines)
            for ii, line in enumerate(lines):
                _image = os.path.join(self._image_dir_atr, line + '.jpg')
                _cat = os.path.join(self._cat_dir_atr, line + '.png')
                _flip = os.path.join(self._flip_dir_atr, line + '.png')
                assert os.path.isfile(_image)
                assert os.path.isfile(_cat)
                assert os.path.isfile(_flip)
                self.im_ids.append(line)
                self.images.append(_image)
                self.categories.append(_cat)
                self.flip_categories.append(_flip)
                self.datasets_lbl.append(2)
        assert len(self.images) == len(self.categories)
        # Display stats
        print('Number of images in {}: {:d}'.format(split, len(self.images)))

    def __len__(self):
        return len(self.images)

    def get_class_num(self):
        """Return per-dataset sample counts as (num_cihp, num_pascal, num_atr)."""
        return self.num_cihp, self.num_pascal, self.num_atr

    def __getitem__(self, index):
        _img, _target, _lbl = self._make_img_gt_point_pair(index)
        sample = {'image': _img, 'label': _target}
        if self.transform is not None:
            sample = self.transform(sample)
        sample['pascal'] = _lbl  # dataset label: 0 = CIHP, 1 = PASCAL, 2 = ATR
        return sample

    def _make_img_gt_point_pair(self, index):
        # Read image (as RGB) and its target label map
        _img = Image.open(self.images[index]).convert('RGB')
        type_lbl = self.datasets_lbl[index]
        if self._flip_flag and random.random() < 0.5:
            _img = _img.transpose(Image.FLIP_LEFT_RIGHT)
            if type_lbl == 0 or type_lbl == 2:
                # CIHP/ATR: load the pre-flipped label map (left/right labels swapped)
                _target = Image.open(self.flip_categories[index])
            else:
                # PASCAL: mirror the label map on the fly
                _target = Image.open(self.categories[index])
                _target = _target.transpose(Image.FLIP_LEFT_RIGHT)
        else:
            _target = Image.open(self.categories[index])
        return _img, _target, type_lbl

    def __str__(self):
        return 'datasets(split=' + str(self.split) + ')'


if __name__ == '__main__':
    from dataloaders import custom_transforms as tr
    from dataloaders.utils import decode_segmap
    from torch.utils.data import DataLoader
    from torchvision import transforms
    import matplotlib.pyplot as plt

    composed_transforms_tr = transforms.Compose([
        tr.RandomSized_new(512),
        tr.RandomRotate(15),
        tr.ToTensor_()])

    voc_train = VOCSegmentation(split='train',
                                transform=composed_transforms_tr)

    dataloader = DataLoader(voc_train, batch_size=5, shuffle=True, num_workers=1)
    for ii, sample in enumerate(dataloader):
        if ii > 10:
            break
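
    # Minimal inspection sketch (an addition, not part of the original loop):
    # decode and display the last fetched batch. It assumes decode_segmap takes
    # (label_mask, dataset) as in this repo's dataloaders.utils, and that
    # ToTensor_ leaves image values in the 0-255 range; adjust if either differs.
    img = sample['image'].numpy()
    gt = sample['label'].numpy()
    for jj in range(img.shape[0]):
        segmap = decode_segmap(np.squeeze(gt[jj]).astype(np.uint8), dataset='cihp')
        plt.figure()
        plt.subplot(211)
        plt.imshow(np.transpose(img[jj], (1, 2, 0)).astype(np.uint8))
        plt.subplot(212)
        plt.imshow(segmap)
    plt.show(block=True)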