Fabrice-TIERCELIN committed
Commit: d53d32d
Parent: d6c8fe2

Delete clipseg/datasets/phrasecut.py

Files changed (1)
  1. clipseg/datasets/phrasecut.py +0 -335
clipseg/datasets/phrasecut.py DELETED
@@ -1,335 +0,0 @@
-
-import torch
-import numpy as np
-import os
-
-from os.path import join, isdir, isfile, expanduser
-from PIL import Image
-
-from torchvision import transforms
-from torchvision.transforms.transforms import Resize
-
-from torch.nn import functional as nnf
-from general_utils import get_from_repository
-
-from skimage.draw import polygon2mask
-
-
-
-def random_crop_slices(origin_size, target_size):
-    """Gets slices of a random crop. """
-    assert origin_size[0] >= target_size[0] and origin_size[1] >= target_size[1], f'actual size: {origin_size}, target size: {target_size}'
-
-    offset_y = torch.randint(0, origin_size[0] - target_size[0] + 1, (1,)).item()  # range: 0 <= value < high
-    offset_x = torch.randint(0, origin_size[1] - target_size[1] + 1, (1,)).item()
-
-    return slice(offset_y, offset_y + target_size[0]), slice(offset_x, offset_x + target_size[1])
-
-
-def find_crop(seg, image_size, iterations=1000, min_frac=None, best_of=None):
-
-
-    best_crops = []
-    best_crop_not_ok = float('-inf'), None, None
-    min_sum = 0
-
-    seg = seg.astype('bool')
-
-    if min_frac is not None:
-        #min_sum = seg.sum() * min_frac
-        min_sum = seg.shape[0] * seg.shape[1] * min_frac
-
-    for iteration in range(iterations):
-        sl_y, sl_x = random_crop_slices(seg.shape, image_size)
-        seg_ = seg[sl_y, sl_x]
-        sum_seg_ = seg_.sum()
-
-        if sum_seg_ > min_sum:
-
-            if best_of is None:
-                return sl_y, sl_x, False
-            else:
-                best_crops += [(sum_seg_, sl_y, sl_x)]
-                if len(best_crops) >= best_of:
-                    best_crops.sort(key=lambda x: x[0], reverse=True)
-                    sl_y, sl_x = best_crops[0][1:]
-
-                    return sl_y, sl_x, False
-
-        else:
-            if sum_seg_ > best_crop_not_ok[0]:
-                best_crop_not_ok = sum_seg_, sl_y, sl_x
-
-    else:
-        # return best segmentation found
-        return best_crop_not_ok[1:] + (best_crop_not_ok[0] <= min_sum,)
-
-
-class PhraseCut(object):
-
-    def __init__(self, split, image_size=400, negative_prob=0, aug=None, aug_color=False, aug_crop=True,
-                 min_size=0, remove_classes=None, with_visual=False, only_visual=False, mask=None):
-        super().__init__()
-
-        self.negative_prob = negative_prob
-        self.image_size = image_size
-        self.with_visual = with_visual
-        self.only_visual = only_visual
-        self.phrase_form = '{}'
-        self.mask = mask
-        self.aug_crop = aug_crop
-
-        if aug_color:
-            self.aug_color = transforms.Compose([
-                transforms.ColorJitter(0.5, 0.5, 0.2, 0.05),
-            ])
-        else:
-            self.aug_color = None
-
-        get_from_repository('PhraseCut', ['PhraseCut.tar'], integrity_check=lambda local_dir: all([
-            isdir(join(local_dir, 'VGPhraseCut_v0')),
-            isdir(join(local_dir, 'VGPhraseCut_v0', 'images')),
-            isfile(join(local_dir, 'VGPhraseCut_v0', 'refer_train.json')),
-            len(os.listdir(join(local_dir, 'VGPhraseCut_v0', 'images'))) in {108250, 108249}
-        ]))
-
-        from third_party.PhraseCutDataset.utils.refvg_loader import RefVGLoader
-        self.refvg_loader = RefVGLoader(split=split)
-
-        # img_ids where the size in the annotations does not match actual size
-        invalid_img_ids = set([150417, 285665, 498246, 61564, 285743, 498269, 498010, 150516, 150344, 286093, 61530,
-                               150333, 286065, 285814, 498187, 285761, 498042])
-
-        mean = [0.485, 0.456, 0.406]
-        std = [0.229, 0.224, 0.225]
-        self.normalize = transforms.Normalize(mean, std)
-
-        self.sample_ids = [(i, j)
-                           for i in self.refvg_loader.img_ids
-                           for j in range(len(self.refvg_loader.get_img_ref_data(i)['phrases']))
-                           if i not in invalid_img_ids]
-
-
-        # self.all_phrases = list(set([p for i in self.refvg_loader.img_ids for p in self.refvg_loader.get_img_ref_data(i)['phrases']]))
-
-        from nltk.stem import WordNetLemmatizer
-        wnl = WordNetLemmatizer()
-
-        # Filter by class (if remove_classes is set)
-        if remove_classes is None:
-            pass
-        else:
-            from datasets.generate_lvis_oneshot import PASCAL_SYNSETS, traverse_lemmas, traverse_lemmas_hypo
-            from nltk.corpus import wordnet
-
-            print('remove pascal classes...')
-
-            get_data = self.refvg_loader.get_img_ref_data  # shortcut
-            keep_sids = None
-
-            if remove_classes[0] == 'pas5i':
-                subset_id = remove_classes[1]
-                from datasets.generate_lvis_oneshot import PASCAL_5I_SYNSETS_ORDERED, PASCAL_5I_CLASS_IDS
-                avoid = [PASCAL_5I_SYNSETS_ORDERED[i] for i in range(20) if i+1 not in PASCAL_5I_CLASS_IDS[subset_id]]
-
-
-            elif remove_classes[0] == 'zs':
-                stop = remove_classes[1]
-
-                from datasets.pascal_zeroshot import PASCAL_VOC_CLASSES_ZS
-
-                avoid = [c for class_set in PASCAL_VOC_CLASSES_ZS[:stop] for c in class_set]
-                print(avoid)
-
-            elif remove_classes[0] == 'aff':
-                # avoid = ['drink.v.01', 'sit.v.01', 'ride.v.02']
-                # all_lemmas = set(['drink', 'sit', 'ride'])
-                avoid = ['drink', 'drinks', 'drinking', 'sit', 'sits', 'sitting',
-                         'ride', 'rides', 'riding',
-                         'fly', 'flies', 'flying', 'drive', 'drives', 'driving', 'driven',
-                         'swim', 'swims', 'swimming',
-                         'wheels', 'wheel', 'legs', 'leg', 'ear', 'ears']
-                keep_sids = [(i, j) for i, j in self.sample_ids if
-                             all(x not in avoid for x in get_data(i)['phrases'][j].split(' '))]
-
-            print('avoid classes:', avoid)
-
-
-            if keep_sids is None:
-                all_lemmas = [s for ps in avoid for s in traverse_lemmas_hypo(wordnet.synset(ps), max_depth=None)]
-                all_lemmas = list(set(all_lemmas))
-                all_lemmas = [h.replace('_', ' ').lower() for h in all_lemmas]
-                all_lemmas = set(all_lemmas)
-
-                # divide into multi word and single word
-                all_lemmas_s = set(l for l in all_lemmas if ' ' not in l)
-                all_lemmas_m = set(l for l in all_lemmas if l not in all_lemmas_s)
-
-                # new3
-                phrases = [get_data(i)['phrases'][j] for i, j in self.sample_ids]
-                remove_sids = set((i,j) for (i,j), phrase in zip(self.sample_ids, phrases)
-                                  if any(l in phrase for l in all_lemmas_m) or
-                                  len(set(wnl.lemmatize(w) for w in phrase.split(' ')).intersection(all_lemmas_s)) > 0
-                )
-                keep_sids = [(i, j) for i, j in self.sample_ids if (i,j) not in remove_sids]
-
-            print(f'Reduced to {len(keep_sids) / len(self.sample_ids):.3f}')
-            removed_ids = set(self.sample_ids) - set(keep_sids)
-
-            print('Examples of removed', len(removed_ids))
-            for i, j in list(removed_ids)[:20]:
-                print(i, get_data(i)['phrases'][j])
-
-            self.sample_ids = keep_sids
-
-        from itertools import groupby
-        samples_by_phrase = [(self.refvg_loader.get_img_ref_data(i)['phrases'][j], (i, j))
-                             for i, j in self.sample_ids]
-        samples_by_phrase = sorted(samples_by_phrase)
-        samples_by_phrase = groupby(samples_by_phrase, key=lambda x: x[0])
-
-        self.samples_by_phrase = {prompt: [s[1] for s in prompt_sample_ids] for prompt, prompt_sample_ids in samples_by_phrase}
-
-        self.all_phrases = list(set(self.samples_by_phrase.keys()))
-
-
-        if self.only_visual:
-            assert self.with_visual
-            self.sample_ids = [(i, j) for i, j in self.sample_ids
-                               if len(self.samples_by_phrase[self.refvg_loader.get_img_ref_data(i)['phrases'][j]]) > 1]
-
-        # Filter by size (if min_size is set)
-        sizes = [self.refvg_loader.get_img_ref_data(i)['gt_boxes'][j] for i, j in self.sample_ids]
-        image_sizes = [self.refvg_loader.get_img_ref_data(i)['width'] * self.refvg_loader.get_img_ref_data(i)['height'] for i, j in self.sample_ids]
-        #self.sizes = [sum([(s[2] - s[0]) * (s[3] - s[1]) for s in size]) for size in sizes]
-        self.sizes = [sum([s[2] * s[3] for s in size]) / img_size for size, img_size in zip(sizes, image_sizes)]
-
-        if min_size:
-            print('filter by size')
-
-            self.sample_ids = [self.sample_ids[i] for i in range(len(self.sample_ids)) if self.sizes[i] > min_size]
-
-        self.base_path = join(expanduser('~/datasets/PhraseCut/VGPhraseCut_v0/images/'))
-
-    def __len__(self):
-        return len(self.sample_ids)
-
-
-    def load_sample(self, sample_i, j):
-
-        img_ref_data = self.refvg_loader.get_img_ref_data(sample_i)
-
-        polys_phrase0 = img_ref_data['gt_Polygons'][j]
-        phrase = img_ref_data['phrases'][j]
-        phrase = self.phrase_form.format(phrase)
-
-        masks = []
-        for polys in polys_phrase0:
-            for poly in polys:
-                poly = [p[::-1] for p in poly]  # swap x,y
-                masks += [polygon2mask((img_ref_data['height'], img_ref_data['width']), poly)]
-
-        seg = np.stack(masks).max(0)
-        img = np.array(Image.open(join(self.base_path, str(img_ref_data['image_id']) + '.jpg')))
-
-        min_shape = min(img.shape[:2])
-
-        if self.aug_crop:
-            sly, slx, exceed = find_crop(seg, (min_shape, min_shape), iterations=50, min_frac=0.05)
-        else:
-            sly, slx = slice(0, None), slice(0, None)
-
-        seg = seg[sly, slx]
-        img = img[sly, slx]
-
-        seg = seg.astype('uint8')
-        seg = torch.from_numpy(seg).view(1, 1, *seg.shape)
-
-        if img.ndim == 2:
-            img = np.dstack([img] * 3)
-
-        img = torch.from_numpy(img).permute(2,0,1).unsqueeze(0).float()
-
-        seg = nnf.interpolate(seg, (self.image_size, self.image_size), mode='nearest')[0,0]
-        img = nnf.interpolate(img, (self.image_size, self.image_size), mode='bilinear', align_corners=True)[0]
-
-        # img = img.permute([2,0, 1])
-        img = img / 255.0
-
-        if self.aug_color is not None:
-            img = self.aug_color(img)
-
-        img = self.normalize(img)
-
-
-
-        return img, seg, phrase
-
-    def __getitem__(self, i):
-
-        sample_i, j = self.sample_ids[i]
-
-        img, seg, phrase = self.load_sample(sample_i, j)
-
-        if self.negative_prob > 0:
-            if torch.rand((1,)).item() < self.negative_prob:
-
-                new_phrase = None
-                while new_phrase is None or new_phrase == phrase:
-                    idx = torch.randint(0, len(self.all_phrases), (1,)).item()
-                    new_phrase = self.all_phrases[idx]
-                phrase = new_phrase
-                seg = torch.zeros_like(seg)
-
-        if self.with_visual:
-            # find a corresponding visual image
-            if phrase in self.samples_by_phrase and len(self.samples_by_phrase[phrase]) > 1:
-                idx = torch.randint(0, len(self.samples_by_phrase[phrase]), (1,)).item()
-                other_sample = self.samples_by_phrase[phrase][idx]
-                #print(other_sample)
-                img_s, seg_s, _ = self.load_sample(*other_sample)
-
-                from datasets.utils import blend_image_segmentation
-
-                if self.mask in {'separate', 'text_and_separate'}:
-                    # assert img.shape[1:] == img_s.shape[1:] == seg_s.shape == seg.shape[1:]
-                    add_phrase = [phrase] if self.mask == 'text_and_separate' else []
-                    vis_s = add_phrase + [img_s, seg_s, True]
-                else:
-                    if self.mask.startswith('text_and_'):
-                        mask_mode = self.mask[9:]
-                        label_add = [phrase]
-                    else:
-                        mask_mode = self.mask
-                        label_add = []
-
-                    masked_img_s = torch.from_numpy(blend_image_segmentation(img_s, seg_s, mode=mask_mode, image_size=self.image_size)[0])
-                    vis_s = label_add + [masked_img_s, True]
-
-            else:
-                # phrase is unique
-                vis_s = torch.zeros_like(img)
-
-                if self.mask in {'separate', 'text_and_separate'}:
-                    add_phrase = [phrase] if self.mask == 'text_and_separate' else []
-                    vis_s = add_phrase + [vis_s, torch.zeros(*vis_s.shape[1:], dtype=torch.uint8), False]
-                elif self.mask.startswith('text_and_'):
-                    vis_s = [phrase, vis_s, False]
-                else:
-                    vis_s = [vis_s, False]
-        else:
-            assert self.mask == 'text'
-            vis_s = [phrase]
-
-        seg = seg.unsqueeze(0).float()
-
-        data_x = (img,) + tuple(vis_s)
-
-        return data_x, (seg, torch.zeros(0), i)
-
-
-class PhraseCutPlus(PhraseCut):
-
-    def __init__(self, split, image_size=400, aug=None, aug_color=False, aug_crop=True, min_size=0, remove_classes=None, only_visual=False, mask=None):
-        super().__init__(split, image_size=image_size, negative_prob=0.2, aug=aug, aug_color=aug_color, aug_crop=aug_crop, min_size=min_size,
-                         remove_classes=remove_classes, with_visual=True, only_visual=only_visual, mask=mask)
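
The deleted file defined the PhraseCut and PhraseCutPlus dataset wrappers, which yield (image, referring phrase, segmentation mask) samples. The snippet below is a minimal usage sketch of how such a loader would typically be consumed, assuming the module were still importable as datasets.phrasecut and the VGPhraseCut_v0 data were available locally; the split name, batch size, and mask='text' setting are illustrative choices, not values prescribed by this commit.

# Hypothetical usage sketch; requires the (now deleted) datasets/phrasecut.py
# and the PhraseCut data fetched by get_from_repository.
import torch
from torch.utils.data import DataLoader

from datasets.phrasecut import PhraseCut

# Text-only supervision: each sample is ((image, phrase), (mask, empty tensor, index)).
dataset = PhraseCut('train', image_size=400, mask='text', negative_prob=0.2)
loader = DataLoader(dataset, batch_size=8, shuffle=True)

data_x, data_y = next(iter(loader))
img, phrases = data_x[0], data_x[1]   # [8, 3, 400, 400] float tensor, batch of 8 phrases
seg = data_y[0]                       # [8, 1, 400, 400] binary target mask
print(img.shape, seg.shape, phrases[0])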