Fabrice-TIERCELIN committed
Commit d6c8fe2
Parent: bc69feb

Delete clipseg/datasets/pfe_dataset.py

Files changed (1)
  1. clipseg/datasets/pfe_dataset.py +0 -129
clipseg/datasets/pfe_dataset.py DELETED
@@ -1,129 +0,0 @@
-from os.path import expanduser
-import json
-
-import torch
-
-from general_utils import get_from_repository, log
-from datasets.lvis_oneshot3 import blend_image_segmentation
-
-# Map each Pascal VOC class id to its list of synonym names.
-PASCAL_CLASSES = {a['id']: a['synonyms'] for a in json.load(open('datasets/pascal_classes.json'))}
-
-
-class PFEPascalWrapper(object):
-
-    def __init__(self, mode, split, mask='separate', image_size=473, label_support=None, size=None, p_negative=0, aug=None):
-        from third_party.PFENet.util.dataset import SemData
-
-        get_from_repository('PascalVOC2012', ['Pascal5i.tar'])
-
-        self.p_negative = p_negative
-        self.size = size
-        self.mode = mode
-        self.image_size = image_size
-
-        if label_support in {True, False}:
-            log.warning('label_support argument is deprecated. Use mask instead.')
-
-        self.mask = mask
-
-        # ImageNet statistics, scaled to the 0-255 pixel range expected by the PFENet transforms.
-        value_scale = 255
-        mean = [0.485, 0.456, 0.406]
-        mean = [item * value_scale for item in mean]
-        std = [0.229, 0.224, 0.225]
-        std = [item * value_scale for item in std]
-
-        import third_party.PFENet.util.transform as transform
-
-        if mode == 'val':
-            data_list = expanduser('~/projects/old_one_shot/PFENet/lists/pascal/val.txt')
-
-            data_transform = [transform.test_Resize(size=image_size)] if image_size != 'original' else []
-            data_transform += [
-                transform.ToTensor(),
-                transform.Normalize(mean=mean, std=std)
-            ]
-
-        elif mode == 'train':
-            data_list = expanduser('~/projects/old_one_shot/PFENet/lists/pascal/voc_sbd_merge_noduplicate.txt')
-
-            assert image_size != 'original'
-
-            data_transform = [
-                transform.RandScale([0.9, 1.1]),
-                transform.RandRotate([-10, 10], padding=mean, ignore_label=255),
-                transform.RandomGaussianBlur(),
-                transform.RandomHorizontalFlip(),
-                transform.Crop((image_size, image_size), crop_type='rand', padding=mean, ignore_label=255),
-                transform.ToTensor(),
-                transform.Normalize(mean=mean, std=std)
-            ]
-        else:
-            # Guard against silent misconfiguration: data_list and data_transform
-            # would otherwise be undefined below.
-            raise ValueError('mode must be "train" or "val", got ' + repr(mode))
-
-        data_transform = transform.Compose(data_transform)
-
-        self.dataset = SemData(split=split, mode=mode, data_root=expanduser('~/datasets/PascalVOC2012/VOC2012'),
-                               data_list=data_list, shot=1, transform=data_transform, use_coco=False, use_split_coco=False)
-
-        self.class_list = self.dataset.sub_val_list if mode == 'val' else self.dataset.sub_list
-
-        print('actual length', len(self.dataset.data_list))
-
-    def __len__(self):
-        return len(self.dataset.data_list)
-
-    def __getitem__(self, index):
-        if self.dataset.mode == 'train':
-            image, label, s_x, s_y, subcls_list = self.dataset[index % len(self.dataset.data_list)]
-        elif self.dataset.mode == 'val':
-            image, label, s_x, s_y, subcls_list, ori_label = self.dataset[index % len(self.dataset.data_list)]
-            ori_label = torch.from_numpy(ori_label).unsqueeze(0)
-
-            if self.image_size != 'original':
-                # Pad the full-resolution label to a square canvas filled with the ignore value 255.
-                longerside = max(ori_label.size(1), ori_label.size(2))
-                backmask = torch.ones(ori_label.size(0), longerside, longerside).cuda() * 255
-                backmask[0, :ori_label.size(1), :ori_label.size(2)] = ori_label
-                label = backmask.clone().long()
-            else:
-                label = label.unsqueeze(0)
-
-        if self.p_negative > 0:
-            if torch.rand(1).item() < self.p_negative:
-                # Negative sample: swap in a support pair whose class differs from the query class.
-                while True:
-                    idx = torch.randint(0, len(self.dataset.data_list), (1,)).item()
-                    _, _, s_x, s_y, subcls_list_tmp, _ = self.dataset[idx]
-                    if subcls_list[0] != subcls_list_tmp[0]:
-                        break
-
-        s_x = s_x[0]
-        s_y = (s_y == 1)[0]
-        label_fg = (label == 1).float()
-        val_mask = (label != 255).float()
-
-        class_id = self.class_list[subcls_list[0]]
-
-        label_name = PASCAL_CLASSES[class_id][0]
-        label_add = ()
-        mask = self.mask
-
-        # 'text': the support is a prompt string; 'separate': support image and mask as a pair;
-        # otherwise the support image is blended with its mask (optionally prefixed 'text_and_').
-        if mask == 'text':
-            support = ('a photo of a ' + label_name + '.',)
-        elif mask == 'separate':
-            support = (s_x, s_y)
-        else:
-            if mask.startswith('text_and_'):
-                label_add = (label_name,)
-                mask = mask[len('text_and_'):]
-
-            support = (blend_image_segmentation(s_x, s_y.float(), mask)[0],)
-
-        return (image,) + label_add + support, (label_fg.unsqueeze(0), val_mask.unsqueeze(0), subcls_list[0])
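
For reference, a minimal usage sketch of the wrapper this commit removes (the DataLoader wiring and argument values are illustrative assumptions, not taken from this commit; the class name, constructor signature, and return layout come from the file above):

    from torch.utils.data import DataLoader
    from datasets.pfe_dataset import PFEPascalWrapper  # module deleted by this commit

    # Fold 0 of Pascal-5i in validation mode; mask='separate' returns the
    # support image and support mask as separate tensors next to the query image.
    dataset = PFEPascalWrapper(mode='val', split=0, mask='separate', image_size=473)
    loader = DataLoader(dataset, batch_size=1)

    for (image, s_x, s_y), (label_fg, val_mask, class_idx) in loader:
        pass  # model forward pass / evaluation would go here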