Spaces:
Runtime error
Runtime error
| from os.path import expanduser | |
| import torch | |
| import json | |
| from general_utils import get_from_repository | |
| from datasets.lvis_oneshot3 import blend_image_segmentation | |
| from general_utils import log | |
# Lookup table from PASCAL class id to that class's list of synonyms, built
# from the entries of datasets/pascal_classes.json (each entry provides an
# 'id' and a 'synonyms' field). The path is relative to the working directory.
# The file is read inside a context manager so the handle is closed right away.
with open('datasets/pascal_classes.json') as _pascal_classes_file:
    PASCAL_CLASSES = {a['id']: a['synonyms'] for a in json.load(_pascal_classes_file)}
class PFEPascalWrapper(object):
    """One-shot segmentation dataset built on top of PFENet's ``SemData``.

    Each item pairs a query image with a "support" description of the target
    class. The form of the support is selected by ``mask``:

    * ``'text'``      -- a prompt string ``'a photo of a <label>.'``
    * ``'separate'``  -- the support image and its binary mask as two entries
    * anything else   -- forwarded to ``blend_image_segmentation``; a
      ``'text_and_'`` prefix additionally inserts the label name and is
      stripped before forwarding.
    """

    def __init__(self, mode, split, mask='separate', image_size=473, label_support=None, size=None, p_negative=0, aug=None):
        """Build the underlying ``SemData`` dataset and its transform pipeline.

        Args:
            mode: ``'train'`` or ``'val'``.
            split: forwarded to ``SemData`` as ``split``.
            mask: support representation (see class docstring).
            image_size: side length used for resize/crop, or ``'original'``
                (validation only) to skip resizing.
            label_support: deprecated; a boolean value only triggers a warning.
            size: stored on the instance; not read anywhere in this class.
            p_negative: probability of replacing the support sample with one
                drawn from a different class.
            aug: accepted for interface compatibility; unused here.

        Raises:
            ValueError: if ``mode`` is neither ``'train'`` nor ``'val'``.
        """
        # Fail fast: without this check an unknown mode only surfaced later as
        # an unbound-local error, after the repository fetch had already run.
        if mode not in ('train', 'val'):
            raise ValueError("mode must be 'train' or 'val', got %r" % (mode,))

        from third_party.PFENet.util.dataset import SemData
        import third_party.PFENet.util.transform as transform

        get_from_repository('PascalVOC2012', ['Pascal5i.tar'])

        self.p_negative = p_negative
        self.size = size
        self.mode = mode
        self.image_size = image_size

        if label_support in {True, False}:
            log.warning('label_support argument is deprecated. Use mask instead.')

        self.mask = mask

        # Normalisation statistics expressed on a 0-255 value scale.
        value_scale = 255
        mean = [item * value_scale for item in [0.485, 0.456, 0.406]]
        std = [item * value_scale for item in [0.229, 0.224, 0.225]]

        # NOTE(review): the list files below are machine-specific absolute paths.
        if mode == 'val':
            data_list = expanduser('~/projects/old_one_shot/PFENet/lists/pascal/val.txt')

            data_transform = [transform.test_Resize(size=image_size)] if image_size != 'original' else []
            data_transform += [
                transform.ToTensor(),
                transform.Normalize(mean=mean, std=std)
            ]
        else:  # mode == 'train'
            data_list = expanduser('~/projects/old_one_shot/PFENet/lists/pascal/voc_sbd_merge_noduplicate.txt')

            # The random crop below needs a numeric size.
            assert image_size != 'original'

            data_transform = [
                transform.RandScale([0.9, 1.1]),
                transform.RandRotate([-10, 10], padding=mean, ignore_label=255),
                transform.RandomGaussianBlur(),
                transform.RandomHorizontalFlip(),
                transform.Crop((image_size, image_size), crop_type='rand', padding=mean, ignore_label=255),
                transform.ToTensor(),
                transform.Normalize(mean=mean, std=std)
            ]

        data_transform = transform.Compose(data_transform)

        self.dataset = SemData(split=split, mode=mode, data_root=expanduser('~/datasets/PascalVOC2012/VOC2012'),
                               data_list=data_list, shot=1, transform=data_transform, use_coco=False, use_split_coco=False)

        self.class_list = self.dataset.sub_val_list if mode == 'val' else self.dataset.sub_list

        print('actual length', len(self.dataset.data_list))

    def __len__(self):
        """Return the number of entries in the wrapped dataset's ``data_list``."""
        # Both modes report the same length.
        return len(self.dataset.data_list)

    def __getitem__(self, index):
        """Return ``(inputs, targets)`` for the sample at ``index``.

        ``inputs`` is ``(image,) + optional (label_name,) + support`` and
        ``targets`` is ``(foreground_mask, valid_mask, class_index)``, where the
        foreground mask marks ``label == 1`` and the valid mask marks
        ``label != 255``. ``index`` wraps around the length of ``data_list``.
        """
        n_samples = len(self.dataset.data_list)

        if self.dataset.mode == 'train':
            image, label, s_x, s_y, subcls_list = self.dataset[index % n_samples]
        elif self.dataset.mode == 'val':
            image, label, s_x, s_y, subcls_list, ori_label = self.dataset[index % n_samples]
            ori_label = torch.from_numpy(ori_label).unsqueeze(0)

            if self.image_size != 'original':
                # Embed the original-resolution label in the top-left corner of
                # a square canvas pre-filled with 255.
                longerside = max(ori_label.size(1), ori_label.size(2))
                # Keep the mask on the GPU when one exists (as before), but fall
                # back to the CPU instead of crashing on CUDA-less machines.
                device = 'cuda' if torch.cuda.is_available() else 'cpu'
                backmask = torch.ones(ori_label.size(0), longerside, longerside, device=device) * 255
                backmask[0, :ori_label.size(1), :ori_label.size(2)] = ori_label
                label = backmask.long()
            else:
                label = label.unsqueeze(0)

        if self.p_negative > 0 and torch.rand(1).item() < self.p_negative:
            # Swap in a support sample from a different class. Samples are read
            # by position because the train mode yields five values and the val
            # mode six; unpacking a fixed six values broke negative sampling
            # during training.
            # NOTE(review): loops forever if every sample shares one class.
            while True:
                idx = torch.randint(0, n_samples, (1,)).item()
                negative = self.dataset[idx]
                s_x, s_y, subcls_list_tmp = negative[2], negative[3], negative[4]
                if subcls_list[0] != subcls_list_tmp[0]:
                    break

        # Use the first (only, since shot=1) support entry.
        s_x = s_x[0]
        s_y = (s_y == 1)[0]
        label_fg = (label == 1).float()
        val_mask = (label != 255).float()

        class_id = self.class_list[subcls_list[0]]
        label_name = PASCAL_CLASSES[class_id][0]
        label_add = ()
        mask = self.mask

        if mask == 'text':
            support = ('a photo of a ' + label_name + '.',)
        elif mask == 'separate':
            support = (s_x, s_y)
        else:
            if mask.startswith('text_and_'):
                label_add = (label_name,)
                mask = mask[9:]  # drop the 'text_and_' prefix

            support = (blend_image_segmentation(s_x, s_y.float(), mask)[0],)

        return (image,) + label_add + support, (label_fg.unsqueeze(0), val_mask.unsqueeze(0), subcls_list[0])