File size: 5,011 Bytes
e0b74e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
from os.path import expanduser
import torch
import json
from general_utils import get_from_repository
from datasets.lvis_oneshot3 import blend_image_segmentation
from general_utils import log

# Maps each Pascal class id to its list of synonyms, as read from the JSON
# file (a list of objects with 'id' and 'synonyms' keys).
# NOTE: the path is relative, so it is resolved against the current working
# directory at import time.
# The context manager guarantees the file handle is closed after parsing.
with open('datasets/pascal_classes.json', encoding='utf-8') as _classes_file:
    PASCAL_CLASSES = {entry['id']: entry['synonyms'] for entry in json.load(_classes_file)}


class PFEPascalWrapper(object):
    """Wrap PFENet's ``SemData`` Pascal dataset and reformat its samples.

    Each item is returned as ``(inputs, targets)`` where ``inputs`` starts
    with the query image followed by the support information selected via
    ``mask``, and ``targets`` is
    ``(foreground_mask, valid_mask, class_index)``.
    """

    # Modes for which a data list and transform pipeline are defined below.
    _VALID_MODES = ('train', 'val')

    def __init__(self, mode, split, mask='separate', image_size=473, label_support=None, size=None, p_negative=0, aug=None):
        """Build the underlying ``SemData`` dataset.

        Args:
            mode: ``'train'`` or ``'val'``.
            split: forwarded to ``SemData`` as ``split``.
            mask: how the support is delivered by ``__getitem__``:
                ``'text'`` (a prompt string), ``'separate'`` (support image
                and support mask as two entries), or any other value, which
                is passed to ``blend_image_segmentation``. A ``'text_and_'``
                prefix additionally adds the class name to the inputs.
            image_size: target size used by the resize/crop transforms, or
                ``'original'`` (val mode only) to skip resizing.
            label_support: deprecated; a warning is logged when supplied.
            size: stored on the instance; not otherwise used in this class.
            p_negative: probability of replacing the support sample by one
                drawn from a different class.
            aug: accepted for interface compatibility; not used here.

        Raises:
            ValueError: if ``mode`` is unknown, or if ``mode == 'train'``
                is combined with ``image_size == 'original'``.
        """
        # Validate up front so that misuse fails before any heavy import or
        # repository download happens.
        if mode not in self._VALID_MODES:
            raise ValueError("mode must be one of %r, got %r" % (self._VALID_MODES, mode))
        if mode == 'train' and image_size == 'original':
            raise ValueError("image_size='original' is not supported in train mode")

        # Imported lazily: third_party code is only needed once an instance
        # is actually constructed.
        from third_party.PFENet.util.dataset import SemData
        import third_party.PFENet.util.transform as transform

        get_from_repository('PascalVOC2012', ['Pascal5i.tar'])

        self.p_negative = p_negative
        self.size = size
        self.mode = mode
        self.image_size = image_size

        if label_support is not None:
            log.warning('label_support argument is deprecated. Use mask instead.')

        self.mask = mask

        # Per-channel statistics, scaled by value_scale before being handed
        # to the PFENet transforms.
        value_scale = 255
        mean = [item * value_scale for item in [0.485, 0.456, 0.406]]
        std = [item * value_scale for item in [0.229, 0.224, 0.225]]

        if mode == 'val':
            data_list = expanduser('~/projects/old_one_shot/PFENet/lists/pascal/val.txt')

            data_transform = [transform.test_Resize(size=image_size)] if image_size != 'original' else []
            data_transform += [
                transform.ToTensor(),
                transform.Normalize(mean=mean, std=std)
            ]
        else:  # mode == 'train' (guaranteed by the validation above)
            data_list = expanduser('~/projects/old_one_shot/PFENet/lists/pascal/voc_sbd_merge_noduplicate.txt')

            data_transform = [
                transform.RandScale([0.9, 1.1]),
                transform.RandRotate([-10, 10], padding=mean, ignore_label=255),
                transform.RandomGaussianBlur(),
                transform.RandomHorizontalFlip(),
                transform.Crop((image_size, image_size), crop_type='rand', padding=mean, ignore_label=255),
                transform.ToTensor(),
                transform.Normalize(mean=mean, std=std)
            ]

        data_transform = transform.Compose(data_transform)

        self.dataset = SemData(split=split, mode=mode, data_root=expanduser('~/datasets/PascalVOC2012/VOC2012'),
                               data_list=data_list, shot=1, transform=data_transform, use_coco=False, use_split_coco=False)

        self.class_list = self.dataset.sub_val_list if mode == 'val' else self.dataset.sub_list

        print('actual length', len(self.dataset.data_list))

    def __len__(self):
        """Return the number of entries in the wrapped dataset's data list."""
        return len(self.dataset.data_list)

    def _sample_negative_support(self, positive_class):
        """Draw a random support ``(s_x, s_y)`` whose class differs from ``positive_class``.

        NOTE: loops until a sample of a different class is found, so it does
        not terminate if every sample shares ``positive_class``.
        """
        n_samples = len(self.dataset.data_list)
        while True:
            idx = torch.randint(0, n_samples, (1,)).item()
            # __getitem__ unpacks 5 fields in train mode and 6 in val mode;
            # slicing the shared fields avoids an unpacking error in either.
            s_x, s_y, candidate_classes = self.dataset[idx][2:5]
            if candidate_classes[0] != positive_class:
                return s_x, s_y

    def __getitem__(self, index):
        """Return ``(inputs, targets)`` for the sample at ``index``.

        ``index`` wraps around the length of the underlying data list.

        Returns:
            A pair ``(inputs, targets)``:
            ``inputs`` is ``(image,) + optional (class_name,) + support``;
            ``targets`` is ``(label_fg, val_mask, class_index)`` where
            ``label_fg`` marks pixels whose label equals 1 and ``val_mask``
            marks pixels whose label differs from 255.
        """
        sample = self.dataset[index % len(self.dataset.data_list)]

        if self.dataset.mode == 'train':
            image, label, s_x, s_y, subcls_list = sample
        elif self.dataset.mode == 'val':
            image, label, s_x, s_y, subcls_list, ori_label = sample
            ori_label = torch.from_numpy(ori_label).unsqueeze(0)

            if self.image_size != 'original':
                # Place the original-resolution label in the top-left corner
                # of a square canvas filled with 255.
                longerside = max(ori_label.size(1), ori_label.size(2))
                # Keep the tensor on the GPU when one is present, but do not
                # crash on CPU-only machines.
                device = 'cuda' if torch.cuda.is_available() else 'cpu'
                backmask = torch.full((ori_label.size(0), longerside, longerside), 255.0, device=device)
                backmask[0, :ori_label.size(1), :ori_label.size(2)] = ori_label
                label = backmask.long()
            else:
                label = label.unsqueeze(0)
        else:
            raise ValueError('unsupported dataset mode: %r' % (self.dataset.mode,))

        # With probability p_negative, swap in a support sample from another
        # class. The query image and its label are left unchanged.
        if self.p_negative > 0 and torch.rand(1).item() < self.p_negative:
            s_x, s_y = self._sample_negative_support(subcls_list[0])

        # Keep only the first support entry (the dataset is built with shot=1).
        s_x = s_x[0]
        s_y = (s_y == 1)[0]
        label_fg = (label == 1).float()
        val_mask = (label != 255).float()

        class_id = self.class_list[subcls_list[0]]

        # First synonym serves as the class name.
        label_name = PASCAL_CLASSES[class_id][0]
        label_add = ()
        mask = self.mask

        if mask == 'text':
            support = ('a photo of a ' + label_name + '.',)
        elif mask == 'separate':
            support = (s_x, s_y)
        else:
            prefix = 'text_and_'
            if mask.startswith(prefix):
                # Provide the class name in addition to the blended support
                # image; the remainder selects the blending mode.
                label_add = (label_name,)
                mask = mask[len(prefix):]

            support = (blend_image_segmentation(s_x, s_y.float(), mask)[0],)

        return (image,) + label_add + support, (label_fg.unsqueeze(0), val_mask.unsqueeze(0), subcls_list[0])