Spaces:
Runtime error
Runtime error
File size: 5,011 Bytes
e0b74e6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
from os.path import expanduser
import torch
import json
from general_utils import get_from_repository
from datasets.lvis_oneshot3 import blend_image_segmentation
from general_utils import log
# Lookup table from Pascal class id to that class's list of synonyms, built
# from the 'id' and 'synonyms' fields of every entry in the JSON file.
# The path is relative, so it is resolved against the current working directory.
# A context manager is used so the file handle is closed deterministically
# instead of being left for the garbage collector.
with open('datasets/pascal_classes.json') as _pascal_classes_file:
    PASCAL_CLASSES = {a['id']: a['synonyms'] for a in json.load(_pascal_classes_file)}
class PFEPascalWrapper(object):
    """Wrap PFENet's ``SemData`` Pascal dataset for one-shot segmentation.

    Every item pairs a query image with a "support" whose form is selected by
    the ``mask`` argument:

    * ``'text'``: a text prompt built from the class name.
    * ``'separate'``: the support image and its binary mask as two entries.
    * any other value: forwarded to ``blend_image_segmentation`` as the
      blending mode; a ``'text_and_'`` prefix is stripped first and makes the
      class name an additional input.
    """

    def __init__(self, mode, split, mask='separate', image_size=473, label_support=None, size=None,
                 p_negative=0, aug=None):
        """Build the underlying ``SemData`` dataset and its transform pipeline.

        Args:
            mode: ``'train'`` or ``'val'``.
            split: forwarded unchanged to ``SemData``.
            mask: support representation, see the class docstring.
            image_size: target size used for resizing (val) or cropping
                (train). ``'original'`` skips resizing and is only accepted
                in ``'val'`` mode.
            label_support: deprecated; a warning is logged when it is
                ``True`` or ``False``. Use ``mask`` instead.
            size: stored on the instance but not used by this class.
            p_negative: probability with which ``__getitem__`` swaps the
                support for one drawn from a different class.
            aug: accepted for interface compatibility; not used here.

        Raises:
            ValueError: if ``mode`` is neither ``'train'`` nor ``'val'``.
        """
        # Fail fast, before the third-party import and the dataset download,
        # instead of hitting an UnboundLocalError further down.
        if mode not in ('train', 'val'):
            raise ValueError("mode must be 'train' or 'val', got {!r}".format(mode))

        from third_party.PFENet.util.dataset import SemData

        get_from_repository('PascalVOC2012', ['Pascal5i.tar'])

        self.p_negative = p_negative
        self.size = size
        self.mode = mode
        self.image_size = image_size

        if label_support in {True, False}:
            log.warning('label_support argument is deprecated. Use mask instead.')

        self.mask = mask

        # Per-channel normalisation statistics, scaled to the 0-255 value range.
        value_scale = 255
        mean = [item * value_scale for item in [0.485, 0.456, 0.406]]
        std = [item * value_scale for item in [0.229, 0.224, 0.225]]

        import third_party.PFENet.util.transform as transform

        # NOTE(review): the list files below are hard-coded paths under the
        # user's home directory - confirm they exist on the target machine.
        if mode == 'val':
            data_list = expanduser('~/projects/old_one_shot/PFENet/lists/pascal/val.txt')

            data_transform = [transform.test_Resize(size=image_size)] if image_size != 'original' else []
            data_transform += [
                transform.ToTensor(),
                transform.Normalize(mean=mean, std=std)
            ]
        else:  # mode == 'train'
            data_list = expanduser('~/projects/old_one_shot/PFENet/lists/pascal/voc_sbd_merge_noduplicate.txt')

            # The random crop below needs a numeric target size.
            assert image_size != 'original'

            data_transform = [
                transform.RandScale([0.9, 1.1]),
                transform.RandRotate([-10, 10], padding=mean, ignore_label=255),
                transform.RandomGaussianBlur(),
                transform.RandomHorizontalFlip(),
                transform.Crop((image_size, image_size), crop_type='rand', padding=mean, ignore_label=255),
                transform.ToTensor(),
                transform.Normalize(mean=mean, std=std)
            ]

        data_transform = transform.Compose(data_transform)

        self.dataset = SemData(split=split, mode=mode, data_root=expanduser('~/datasets/PascalVOC2012/VOC2012'),
                               data_list=data_list, shot=1, transform=data_transform, use_coco=False,
                               use_split_coco=False)

        self.class_list = self.dataset.sub_val_list if mode == 'val' else self.dataset.sub_list

        print('actual length', len(self.dataset.data_list))

    def __len__(self):
        """Return the number of entries in the underlying dataset's data list."""
        return len(self.dataset.data_list)

    def __getitem__(self, index):
        """Return one ``(inputs, targets)`` pair.

        ``index`` wraps around the length of the underlying data list.

        Returns:
            A 2-tuple. The first element is ``(image,)``, optionally followed
            by the class name (``'text_and_*'`` masks), followed by the
            support entries. The second element is
            ``(foreground_mask, valid_mask, class_index)``, where the valid
            mask is zero wherever the label equals the ignore value 255.

        Raises:
            ValueError: if the underlying dataset reports an unknown mode.
        """
        n_samples = len(self.dataset.data_list)

        if self.dataset.mode == 'train':
            image, label, s_x, s_y, subcls_list = self.dataset[index % n_samples]
        elif self.dataset.mode == 'val':
            image, label, s_x, s_y, subcls_list, ori_label = self.dataset[index % n_samples]
            ori_label = torch.from_numpy(ori_label).unsqueeze(0)

            if self.image_size != 'original':
                # Pad the original-resolution label to a square (side = longer
                # edge), filling the padding with the ignore value 255.
                longerside = max(ori_label.size(1), ori_label.size(2))
                # Keep the label on the GPU when one exists (previous
                # behaviour) but do not crash on CPU-only hosts.
                device = 'cuda' if torch.cuda.is_available() else 'cpu'
                backmask = torch.ones(ori_label.size(0), longerside, longerside, device=device) * 255
                backmask[0, :ori_label.size(1), :ori_label.size(2)] = ori_label
                label = backmask.clone().long()
            else:
                label = label.unsqueeze(0)
        else:
            raise ValueError('unsupported dataset mode: {!r}'.format(self.dataset.mode))

        if self.p_negative > 0 and torch.rand(1).item() < self.p_negative:
            # Negative sample: redraw until the support comes from a class
            # other than the query's. The star-unpack accepts both the
            # 5-tuple (train) and 6-tuple (val) item layouts.
            # NOTE(review): this loops forever if every sample shares one class.
            while True:
                idx = torch.randint(0, n_samples, (1,)).item()
                _, _, s_x, s_y, subcls_list_tmp, *_ = self.dataset[idx]
                if subcls_list[0] != subcls_list_tmp[0]:
                    break

        # The dataset is built with shot=1, so take the first support entry.
        s_x = s_x[0]
        s_y = (s_y == 1)[0]
        label_fg = (label == 1).float()
        val_mask = (label != 255).float()

        class_id = self.class_list[subcls_list[0]]
        label_name = PASCAL_CLASSES[class_id][0]

        label_add = ()
        mask = self.mask

        if mask == 'text':
            support = ('a photo of a ' + label_name + '.',)
        elif mask == 'separate':
            support = (s_x, s_y)
        else:
            if mask.startswith('text_and_'):
                label_add = (label_name,)
                mask = mask[len('text_and_'):]

            support = (blend_image_segmentation(s_x, s_y.float(), mask)[0],)

        return (image,) + label_add + support, (label_fg.unsqueeze(0), val_mask.unsqueeze(0), subcls_list[0])
|