Fabrice-TIERCELIN committed
Commit d53d32d
1 Parent(s): d6c8fe2
Delete clipseg/datasets/phrasecut.py
clipseg/datasets/phrasecut.py  +0  -335
clipseg/datasets/phrasecut.py
DELETED
@@ -1,335 +0,0 @@
-
-import torch
-import numpy as np
-import os
-
-from os.path import join, isdir, isfile, expanduser
-from PIL import Image
-
-from torchvision import transforms
-from torchvision.transforms.transforms import Resize
-
-from torch.nn import functional as nnf
-from general_utils import get_from_repository
-
-from skimage.draw import polygon2mask
-
-
-
-def random_crop_slices(origin_size, target_size):
-    """Gets slices of a random crop. """
-    assert origin_size[0] >= target_size[0] and origin_size[1] >= target_size[1], f'actual size: {origin_size}, target size: {target_size}'
-
-    offset_y = torch.randint(0, origin_size[0] - target_size[0] + 1, (1,)).item()  # range: 0 <= value < high
-    offset_x = torch.randint(0, origin_size[1] - target_size[1] + 1, (1,)).item()
-
-    return slice(offset_y, offset_y + target_size[0]), slice(offset_x, offset_x + target_size[1])
-
-
-def find_crop(seg, image_size, iterations=1000, min_frac=None, best_of=None):
-
-
-    best_crops = []
-    best_crop_not_ok = float('-inf'), None, None
-    min_sum = 0
-
-    seg = seg.astype('bool')
-
-    if min_frac is not None:
-        #min_sum = seg.sum() * min_frac
-        min_sum = seg.shape[0] * seg.shape[1] * min_frac
-
-    for iteration in range(iterations):
-        sl_y, sl_x = random_crop_slices(seg.shape, image_size)
-        seg_ = seg[sl_y, sl_x]
-        sum_seg_ = seg_.sum()
-
-        if sum_seg_ > min_sum:
-
-            if best_of is None:
-                return sl_y, sl_x, False
-            else:
-                best_crops += [(sum_seg_, sl_y, sl_x)]
-                if len(best_crops) >= best_of:
-                    best_crops.sort(key=lambda x:x[0], reverse=True)
-                    sl_y, sl_x = best_crops[0][1:]
-
-                    return sl_y, sl_x, False
-
-        else:
-            if sum_seg_ > best_crop_not_ok[0]:
-                best_crop_not_ok = sum_seg_, sl_y, sl_x
-
-    else:
-        # return best segmentation found
-        return best_crop_not_ok[1:] + (best_crop_not_ok[0] <= min_sum,)
-
-
-class PhraseCut(object):
-
-    def __init__(self, split, image_size=400, negative_prob=0, aug=None, aug_color=False, aug_crop=True,
-                 min_size=0, remove_classes=None, with_visual=False, only_visual=False, mask=None):
-        super().__init__()
-
-        self.negative_prob = negative_prob
-        self.image_size = image_size
-        self.with_visual = with_visual
-        self.only_visual = only_visual
-        self.phrase_form = '{}'
-        self.mask = mask
-        self.aug_crop = aug_crop
-
-        if aug_color:
-            self.aug_color = transforms.Compose([
-                transforms.ColorJitter(0.5, 0.5, 0.2, 0.05),
-            ])
-        else:
-            self.aug_color = None
-
-        get_from_repository('PhraseCut', ['PhraseCut.tar'], integrity_check=lambda local_dir: all([
-            isdir(join(local_dir, 'VGPhraseCut_v0')),
-            isdir(join(local_dir, 'VGPhraseCut_v0', 'images')),
-            isfile(join(local_dir, 'VGPhraseCut_v0', 'refer_train.json')),
-            len(os.listdir(join(local_dir, 'VGPhraseCut_v0', 'images'))) in {108250, 108249}
-        ]))
-
-        from third_party.PhraseCutDataset.utils.refvg_loader import RefVGLoader
-        self.refvg_loader = RefVGLoader(split=split)
-
-        # img_ids where the size in the annotations does not match actual size
-        invalid_img_ids = set([150417, 285665, 498246, 61564, 285743, 498269, 498010, 150516, 150344, 286093, 61530,
-                               150333, 286065, 285814, 498187, 285761, 498042])
-
-        mean = [0.485, 0.456, 0.406]
-        std = [0.229, 0.224, 0.225]
-        self.normalize = transforms.Normalize(mean, std)
-
-        self.sample_ids = [(i, j)
-                           for i in self.refvg_loader.img_ids
-                           for j in range(len(self.refvg_loader.get_img_ref_data(i)['phrases']))
-                           if i not in invalid_img_ids]
-
-
-        # self.all_phrases = list(set([p for i in self.refvg_loader.img_ids for p in self.refvg_loader.get_img_ref_data(i)['phrases']]))
-
-        from nltk.stem import WordNetLemmatizer
-        wnl = WordNetLemmatizer()
-
-        # Filter by class (if remove_classes is set)
-        if remove_classes is None:
-            pass
-        else:
-            from datasets.generate_lvis_oneshot import PASCAL_SYNSETS, traverse_lemmas, traverse_lemmas_hypo
-            from nltk.corpus import wordnet
-
-            print('remove pascal classes...')
-
-            get_data = self.refvg_loader.get_img_ref_data  # shortcut
-            keep_sids = None
-
-            if remove_classes[0] == 'pas5i':
-                subset_id = remove_classes[1]
-                from datasets.generate_lvis_oneshot import PASCAL_5I_SYNSETS_ORDERED, PASCAL_5I_CLASS_IDS
-                avoid = [PASCAL_5I_SYNSETS_ORDERED[i] for i in range(20) if i+1 not in PASCAL_5I_CLASS_IDS[subset_id]]
-
-
-            elif remove_classes[0] == 'zs':
-                stop = remove_classes[1]
-
-                from datasets.pascal_zeroshot import PASCAL_VOC_CLASSES_ZS
-
-                avoid = [c for class_set in PASCAL_VOC_CLASSES_ZS[:stop] for c in class_set]
-                print(avoid)
-
-            elif remove_classes[0] == 'aff':
-                # avoid = ['drink.v.01', 'sit.v.01', 'ride.v.02']
-                # all_lemmas = set(['drink', 'sit', 'ride'])
-                avoid = ['drink', 'drinks', 'drinking', 'sit', 'sits', 'sitting',
-                         'ride', 'rides', 'riding',
-                         'fly', 'flies', 'flying', 'drive', 'drives', 'driving', 'driven',
-                         'swim', 'swims', 'swimming',
-                         'wheels', 'wheel', 'legs', 'leg', 'ear', 'ears']
-                keep_sids = [(i, j) for i, j in self.sample_ids if
-                             all(x not in avoid for x in get_data(i)['phrases'][j].split(' '))]
-
-            print('avoid classes:', avoid)
-
-
-            if keep_sids is None:
-                all_lemmas = [s for ps in avoid for s in traverse_lemmas_hypo(wordnet.synset(ps), max_depth=None)]
-                all_lemmas = list(set(all_lemmas))
-                all_lemmas = [h.replace('_', ' ').lower() for h in all_lemmas]
-                all_lemmas = set(all_lemmas)
-
-                # divide into multi word and single word
-                all_lemmas_s = set(l for l in all_lemmas if ' ' not in l)
-                all_lemmas_m = set(l for l in all_lemmas if l not in all_lemmas_s)
-
-                # new3
-                phrases = [get_data(i)['phrases'][j] for i, j in self.sample_ids]
-                remove_sids = set((i,j) for (i,j), phrase in zip(self.sample_ids, phrases)
-                                  if any(l in phrase for l in all_lemmas_m) or
-                                  len(set(wnl.lemmatize(w) for w in phrase.split(' ')).intersection(all_lemmas_s)) > 0
-                                  )
-                keep_sids = [(i, j) for i, j in self.sample_ids if (i,j) not in remove_sids]
-
-            print(f'Reduced to {len(keep_sids) / len(self.sample_ids):.3f}')
-            removed_ids = set(self.sample_ids) - set(keep_sids)
-
-            print('Examples of removed', len(removed_ids))
-            for i, j in list(removed_ids)[:20]:
-                print(i, get_data(i)['phrases'][j])
-
-            self.sample_ids = keep_sids
-
-        from itertools import groupby
-        samples_by_phrase = [(self.refvg_loader.get_img_ref_data(i)['phrases'][j], (i, j))
-                             for i, j in self.sample_ids]
-        samples_by_phrase = sorted(samples_by_phrase)
-        samples_by_phrase = groupby(samples_by_phrase, key=lambda x: x[0])
-
-        self.samples_by_phrase = {prompt: [s[1] for s in prompt_sample_ids] for prompt, prompt_sample_ids in samples_by_phrase}
-
-        self.all_phrases = list(set(self.samples_by_phrase.keys()))
-
-
-        if self.only_visual:
-            assert self.with_visual
-            self.sample_ids = [(i, j) for i, j in self.sample_ids
-                               if len(self.samples_by_phrase[self.refvg_loader.get_img_ref_data(i)['phrases'][j]]) > 1]
-
-        # Filter by size (if min_size is set)
-        sizes = [self.refvg_loader.get_img_ref_data(i)['gt_boxes'][j] for i, j in self.sample_ids]
-        image_sizes = [self.refvg_loader.get_img_ref_data(i)['width'] * self.refvg_loader.get_img_ref_data(i)['height'] for i, j in self.sample_ids]
-        #self.sizes = [sum([(s[2] - s[0]) * (s[3] - s[1]) for s in size]) for size in sizes]
-        self.sizes = [sum([s[2] * s[3] for s in size]) / img_size for size, img_size in zip(sizes, image_sizes)]
-
-        if min_size:
-            print('filter by size')
-
-            self.sample_ids = [self.sample_ids[i] for i in range(len(self.sample_ids)) if self.sizes[i] > min_size]
-
-        self.base_path = join(expanduser('~/datasets/PhraseCut/VGPhraseCut_v0/images/'))
-
-    def __len__(self):
-        return len(self.sample_ids)
-
-
-    def load_sample(self, sample_i, j):
-
-        img_ref_data = self.refvg_loader.get_img_ref_data(sample_i)
-
-        polys_phrase0 = img_ref_data['gt_Polygons'][j]
-        phrase = img_ref_data['phrases'][j]
-        phrase = self.phrase_form.format(phrase)
-
-        masks = []
-        for polys in polys_phrase0:
-            for poly in polys:
-                poly = [p[::-1] for p in poly]  # swap x,y
-                masks += [polygon2mask((img_ref_data['height'], img_ref_data['width']), poly)]
-
-        seg = np.stack(masks).max(0)
-        img = np.array(Image.open(join(self.base_path, str(img_ref_data['image_id']) + '.jpg')))
-
-        min_shape = min(img.shape[:2])
-
-        if self.aug_crop:
-            sly, slx, exceed = find_crop(seg, (min_shape, min_shape), iterations=50, min_frac=0.05)
-        else:
-            sly, slx = slice(0, None), slice(0, None)
-
-        seg = seg[sly, slx]
-        img = img[sly, slx]
-
-        seg = seg.astype('uint8')
-        seg = torch.from_numpy(seg).view(1, 1, *seg.shape)
-
-        if img.ndim == 2:
-            img = np.dstack([img] * 3)
-
-        img = torch.from_numpy(img).permute(2,0,1).unsqueeze(0).float()
-
-        seg = nnf.interpolate(seg, (self.image_size, self.image_size), mode='nearest')[0,0]
-        img = nnf.interpolate(img, (self.image_size, self.image_size), mode='bilinear', align_corners=True)[0]
-
-        # img = img.permute([2,0, 1])
-        img = img / 255.0
-
-        if self.aug_color is not None:
-            img = self.aug_color(img)
-
-        img = self.normalize(img)
-
-
-
-        return img, seg, phrase
-
-    def __getitem__(self, i):
-
-        sample_i, j = self.sample_ids[i]
-
-        img, seg, phrase = self.load_sample(sample_i, j)
-
-        if self.negative_prob > 0:
-            if torch.rand((1,)).item() < self.negative_prob:
-
-                new_phrase = None
-                while new_phrase is None or new_phrase == phrase:
-                    idx = torch.randint(0, len(self.all_phrases), (1,)).item()
-                    new_phrase = self.all_phrases[idx]
-                phrase = new_phrase
-                seg = torch.zeros_like(seg)
-
-        if self.with_visual:
-            # find a corresponding visual image
-            if phrase in self.samples_by_phrase and len(self.samples_by_phrase[phrase]) > 1:
-                idx = torch.randint(0, len(self.samples_by_phrase[phrase]), (1,)).item()
-                other_sample = self.samples_by_phrase[phrase][idx]
-                #print(other_sample)
-                img_s, seg_s, _ = self.load_sample(*other_sample)
-
-                from datasets.utils import blend_image_segmentation
-
-                if self.mask in {'separate', 'text_and_separate'}:
-                    # assert img.shape[1:] == img_s.shape[1:] == seg_s.shape == seg.shape[1:]
-                    add_phrase = [phrase] if self.mask == 'text_and_separate' else []
-                    vis_s = add_phrase + [img_s, seg_s, True]
-                else:
-                    if self.mask.startswith('text_and_'):
-                        mask_mode = self.mask[9:]
-                        label_add = [phrase]
-                    else:
-                        mask_mode = self.mask
-                        label_add = []
-
-                    masked_img_s = torch.from_numpy(blend_image_segmentation(img_s, seg_s, mode=mask_mode, image_size=self.image_size)[0])
-                    vis_s = label_add + [masked_img_s, True]
-
-            else:
-                # phrase is unique
-                vis_s = torch.zeros_like(img)
-
-                if self.mask in {'separate', 'text_and_separate'}:
-                    add_phrase = [phrase] if self.mask == 'text_and_separate' else []
-                    vis_s = add_phrase + [vis_s, torch.zeros(*vis_s.shape[1:], dtype=torch.uint8), False]
-                elif self.mask.startswith('text_and_'):
-                    vis_s = [phrase, vis_s, False]
-                else:
-                    vis_s = [vis_s, False]
-        else:
-            assert self.mask == 'text'
-            vis_s = [phrase]
-
-        seg = seg.unsqueeze(0).float()
-
-        data_x = (img,) + tuple(vis_s)
-
-        return data_x, (seg, torch.zeros(0), i)
-
-
-class PhraseCutPlus(PhraseCut):
-
-    def __init__(self, split, image_size=400, aug=None, aug_color=False, aug_crop=True, min_size=0, remove_classes=None, only_visual=False, mask=None):
-        super().__init__(split, image_size=image_size, negative_prob=0.2, aug=aug, aug_color=aug_color, aug_crop=aug_crop, min_size=min_size,
-                         remove_classes=remove_classes, with_visual=True, only_visual=only_visual, mask=mask)