| |
|
| | import torch |
| | import numpy as np |
| | import os |
| |
|
| | from os.path import join, isdir, isfile, expanduser |
| | from PIL import Image |
| |
|
| | from torchvision import transforms |
| | from torchvision.transforms.transforms import Resize |
| |
|
| | from torch.nn import functional as nnf |
| | from general_utils import get_from_repository |
| |
|
| | from skimage.draw import polygon2mask |
| |
|
| |
|
| |
|
def random_crop_slices(origin_size, target_size):
    """Draw a uniformly random crop window of ``target_size`` inside ``origin_size``.

    Args:
        origin_size: (height, width) of the full image.
        target_size: (height, width) of the desired crop; must fit inside origin.

    Returns:
        ``(slice_y, slice_x)`` selecting the crop region.
    """
    full_h, full_w = origin_size[0], origin_size[1]
    crop_h, crop_w = target_size[0], target_size[1]
    assert full_h >= crop_h and full_w >= crop_w, f'actual size: {origin_size}, target size: {target_size}'

    # torch.randint's upper bound is exclusive, hence +1 so the crop may touch
    # the bottom/right edge. y is drawn before x (keeps RNG sequence identical).
    top = torch.randint(0, full_h - crop_h + 1, (1,)).item()
    left = torch.randint(0, full_w - crop_w + 1, (1,)).item()

    return slice(top, top + crop_h), slice(left, left + crop_w)
| |
|
| |
|
def find_crop(seg, image_size, iterations=1000, min_frac=None, best_of=None):
    """Search for a random crop of ``image_size`` that covers segmentation pixels.

    Repeatedly samples random crops of ``seg`` and accepts one whose mask sum
    strictly exceeds ``min_frac`` of the total area (0 if ``min_frac`` is None).
    With ``best_of`` set, collects that many acceptable crops and returns the
    one with the largest coverage.

    Returns:
        ``(slice_y, slice_x, exceeded)`` — ``exceeded`` is falsy on success;
        on failure the best below-threshold crop seen is returned instead,
        with the flag ``best_sum <= threshold``.
    """
    accepted = []                            # (mask_sum, sl_y, sl_x) candidates when best_of is set
    fallback = (float('-inf'), None, None)   # best crop that missed the threshold
    threshold = 0

    seg = seg.astype('bool')

    if min_frac is not None:
        threshold = seg.shape[0] * seg.shape[1] * min_frac

    for _ in range(iterations):
        sl_y, sl_x = random_crop_slices(seg.shape, image_size)
        crop_sum = seg[sl_y, sl_x].sum()

        if crop_sum > threshold:
            if best_of is None:
                # first acceptable crop wins
                return sl_y, sl_x, False
            accepted.append((crop_sum, sl_y, sl_x))
            if len(accepted) >= best_of:
                # pick the candidate with the largest mask coverage
                # (first among ties, like a stable reverse sort)
                _, sl_y, sl_x = max(accepted, key=lambda c: c[0])
                return sl_y, sl_x, False
        elif crop_sum > fallback[0]:
            fallback = (crop_sum, sl_y, sl_x)

    # No crop reached the threshold within the iteration budget: return the
    # best one encountered and flag the failure.
    return fallback[1], fallback[2], fallback[0] <= threshold
| |
|
| |
|
class PhraseCut(object):
    """PhraseCut referring-expression segmentation dataset (VGPhraseCut v0).

    Each item pairs an (optionally cropped/augmented) image with a free-form
    phrase and the binary mask of the region the phrase refers to. Depending
    on the configuration, negative phrases (with an all-zero mask) and/or a
    visual "support" sample for the same phrase can be attached.
    """

    def __init__(self, split, image_size=400, negative_prob=0, aug=None, aug_color=False, aug_crop=True,
                 min_size=0, remove_classes=None, with_visual=False, only_visual=False, mask=None):
        """Load the PhraseCut annotations and build the sample index.

        Args:
            split: split name forwarded to RefVGLoader (e.g. 'train').
            image_size: side length of the square output image and mask.
            negative_prob: probability of replacing the phrase by a random
                different phrase and zeroing the target mask.
            aug: NOTE(review): accepted but never read in this class.
            aug_color: if True, apply random ColorJitter to the image.
            aug_crop: if True, sample a random square crop that tries to keep
                mask pixels (see find_crop); otherwise use the whole image.
            min_size: if non-zero, keep only samples whose relative object
                size (sum of box areas / image area) exceeds this fraction.
            remove_classes: optional held-out-class filter spec:
                ('pas5i', subset_id), ('zs', stop) or ('aff',).
            with_visual: also return a visual support sample for the phrase.
            only_visual: restrict to phrases occurring in more than one
                sample (requires with_visual).
            mask: encoding of the support sample: 'text', 'separate',
                'text_and_separate', or a blend mode (optionally prefixed
                'text_and_') understood by blend_image_segmentation.
        """
        super().__init__()

        self.negative_prob = negative_prob
        self.image_size = image_size
        self.with_visual = with_visual
        self.only_visual = only_visual
        self.phrase_form = '{}'  # format template applied to every phrase
        self.mask = mask
        self.aug_crop = aug_crop

        if aug_color:
            self.aug_color = transforms.Compose([
                transforms.ColorJitter(0.5, 0.5, 0.2, 0.05),
            ])
        else:
            self.aug_color = None

        # Fetch/verify the dataset archive. The image folder is accepted with
        # either 108249 or 108250 files (two known archive variants).
        get_from_repository('PhraseCut', ['PhraseCut.tar'], integrity_check=lambda local_dir: all([
            isdir(join(local_dir, 'VGPhraseCut_v0')),
            isdir(join(local_dir, 'VGPhraseCut_v0', 'images')),
            isfile(join(local_dir, 'VGPhraseCut_v0', 'refer_train.json')),
            len(os.listdir(join(local_dir, 'VGPhraseCut_v0', 'images'))) in {108250, 108249}
        ]))

        from third_party.PhraseCutDataset.utils.refvg_loader import RefVGLoader
        self.refvg_loader = RefVGLoader(split=split)

        # Image ids that are always skipped — presumably broken or missing
        # images; TODO confirm the reason for this exclusion list.
        invalid_img_ids = set([150417, 285665, 498246, 61564, 285743, 498269, 498010, 150516, 150344, 286093, 61530,
                               150333, 286065, 285814, 498187, 285761, 498042])

        # ImageNet channel statistics used for input normalization.
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        self.normalize = transforms.Normalize(mean, std)

        # Flat sample index: one (image_id, phrase_index) pair per phrase.
        self.sample_ids = [(i, j)
                           for i in self.refvg_loader.img_ids
                           for j in range(len(self.refvg_loader.get_img_ref_data(i)['phrases']))
                           if i not in invalid_img_ids]

        from nltk.stem import WordNetLemmatizer
        wnl = WordNetLemmatizer()

        # Optionally remove samples whose phrase mentions a held-out class.
        if remove_classes is None:
            pass
        else:
            from datasets.generate_lvis_oneshot import PASCAL_SYNSETS, traverse_lemmas, traverse_lemmas_hypo
            from nltk.corpus import wordnet

            print('remove pascal classes...')

            get_data = self.refvg_loader.get_img_ref_data  # shortcut
            keep_sids = None

            if remove_classes[0] == 'pas5i':
                # Pascal-5i: avoid the synsets of the held-out fold.
                subset_id = remove_classes[1]
                from datasets.generate_lvis_oneshot import PASCAL_5I_SYNSETS_ORDERED, PASCAL_5I_CLASS_IDS
                avoid = [PASCAL_5I_SYNSETS_ORDERED[i] for i in range(20) if i+1 not in PASCAL_5I_CLASS_IDS[subset_id]]

            elif remove_classes[0] == 'zs':
                # Zero-shot Pascal VOC: avoid all class sets up to `stop`.
                stop = remove_classes[1]

                from datasets.pascal_zeroshot import PASCAL_VOC_CLASSES_ZS

                avoid = [c for class_set in PASCAL_VOC_CLASSES_ZS[:stop] for c in class_set]
                print(avoid)

            elif remove_classes[0] == 'aff':
                # Affordance words: filtered by exact word match below, not
                # via the synset expansion branch.
                avoid = ['drink', 'drinks', 'drinking', 'sit', 'sits', 'sitting',
                         'ride', 'rides', 'riding',
                         'fly', 'flies', 'flying', 'drive', 'drives', 'driving', 'driven',
                         'swim', 'swims', 'swimming',
                         'wheels', 'wheel', 'legs', 'leg', 'ear', 'ears']
                keep_sids = [(i, j) for i, j in self.sample_ids if
                             all(x not in avoid for x in get_data(i)['phrases'][j].split(' '))]

            print('avoid classes:', avoid)

            if keep_sids is None:
                # Expand the avoided synsets to all hyponym lemmas, then drop
                # every sample whose phrase matches one of them.
                all_lemmas = [s for ps in avoid for s in traverse_lemmas_hypo(wordnet.synset(ps), max_depth=None)]
                all_lemmas = list(set(all_lemmas))
                all_lemmas = [h.replace('_', ' ').lower() for h in all_lemmas]
                all_lemmas = set(all_lemmas)

                # Single-word lemmas are matched against lemmatized phrase
                # words; multi-word lemmas by substring containment.
                all_lemmas_s = set(l for l in all_lemmas if ' ' not in l)
                all_lemmas_m = set(l for l in all_lemmas if l not in all_lemmas_s)

                phrases = [get_data(i)['phrases'][j] for i, j in self.sample_ids]
                remove_sids = set((i,j) for (i,j), phrase in zip(self.sample_ids, phrases)
                                  if any(l in phrase for l in all_lemmas_m) or
                                  len(set(wnl.lemmatize(w) for w in phrase.split(' ')).intersection(all_lemmas_s)) > 0
                                  )
                keep_sids = [(i, j) for i, j in self.sample_ids if (i,j) not in remove_sids]

            print(f'Reduced to {len(keep_sids) / len(self.sample_ids):.3f}')
            removed_ids = set(self.sample_ids) - set(keep_sids)

            print('Examples of removed', len(removed_ids))
            for i, j in list(removed_ids)[:20]:
                print(i, get_data(i)['phrases'][j])

            self.sample_ids = keep_sids

        # Group sample ids by phrase text (groupby requires sorted input).
        from itertools import groupby
        samples_by_phrase = [(self.refvg_loader.get_img_ref_data(i)['phrases'][j], (i, j))
                             for i, j in self.sample_ids]
        samples_by_phrase = sorted(samples_by_phrase)
        samples_by_phrase = groupby(samples_by_phrase, key=lambda x: x[0])

        self.samples_by_phrase = {prompt: [s[1] for s in prompt_sample_ids] for prompt, prompt_sample_ids in samples_by_phrase}

        self.all_phrases = list(set(self.samples_by_phrase.keys()))

        if self.only_visual:
            # Keep only phrases that have at least one other sample to draw
            # the visual support sample from.
            assert self.with_visual
            self.sample_ids = [(i, j) for i, j in self.sample_ids
                               if len(self.samples_by_phrase[self.refvg_loader.get_img_ref_data(i)['phrases'][j]]) > 1]

        # Relative object size per sample: sum of gt box areas / image area.
        sizes = [self.refvg_loader.get_img_ref_data(i)['gt_boxes'][j] for i, j in self.sample_ids]
        image_sizes = [self.refvg_loader.get_img_ref_data(i)['width'] * self.refvg_loader.get_img_ref_data(i)['height'] for i, j in self.sample_ids]

        self.sizes = [sum([s[2] * s[3] for s in size]) / img_size for size, img_size in zip(sizes, image_sizes)]

        if min_size:
            print('filter by size')

            self.sample_ids = [self.sample_ids[i] for i in range(len(self.sample_ids)) if self.sizes[i] > min_size]

        self.base_path = join(expanduser('~/datasets/PhraseCut/VGPhraseCut_v0/images/'))

    def __len__(self):
        """Number of (image, phrase) samples."""
        return len(self.sample_ids)

    def load_sample(self, sample_i, j):
        """Load image, mask and phrase for image ``sample_i``, phrase index ``j``.

        Returns:
            (img, seg, phrase): normalized float image tensor of shape
            (3, image_size, image_size), mask tensor of shape
            (image_size, image_size), and the formatted phrase string.
        """
        img_ref_data = self.refvg_loader.get_img_ref_data(sample_i)

        polys_phrase0 = img_ref_data['gt_Polygons'][j]
        phrase = img_ref_data['phrases'][j]
        phrase = self.phrase_form.format(phrase)

        # Rasterize every polygon of the referred instances and take their
        # union. Each point is reversed because polygon2mask expects
        # (row, col) coordinates — assumes annotations store (x, y); verify.
        masks = []
        for polys in polys_phrase0:
            for poly in polys:
                poly = [p[::-1] for p in poly]
                masks += [polygon2mask((img_ref_data['height'], img_ref_data['width']), poly)]

        seg = np.stack(masks).max(0)
        img = np.array(Image.open(join(self.base_path, str(img_ref_data['image_id']) + '.jpg')))

        min_shape = min(img.shape[:2])

        if self.aug_crop:
            # Random square crop that tries to retain >5% mask coverage.
            sly, slx, exceed = find_crop(seg, (min_shape, min_shape), iterations=50, min_frac=0.05)
        else:
            sly, slx = slice(0, None), slice(0, None)

        seg = seg[sly, slx]
        img = img[sly, slx]

        seg = seg.astype('uint8')
        seg = torch.from_numpy(seg).view(1, 1, *seg.shape)

        # Grayscale images are replicated to three channels.
        if img.ndim == 2:
            img = np.dstack([img] * 3)

        img = torch.from_numpy(img).permute(2,0,1).unsqueeze(0).float()

        # Resize to the target size: nearest for the mask, bilinear for the image.
        seg = nnf.interpolate(seg, (self.image_size, self.image_size), mode='nearest')[0,0]
        img = nnf.interpolate(img, (self.image_size, self.image_size), mode='bilinear', align_corners=True)[0]

        # Scale to [0, 1] before color augmentation and normalization.
        img = img / 255.0

        if self.aug_color is not None:
            img = self.aug_color(img)

        img = self.normalize(img)

        return img, seg, phrase

    def __getitem__(self, i):
        """Return ``((img, *conditional), (seg, dummy, i))`` for sample ``i``."""
        sample_i, j = self.sample_ids[i]

        img, seg, phrase = self.load_sample(sample_i, j)

        # Negative sampling: swap in a different random phrase and zero the
        # target mask.
        if self.negative_prob > 0:
            if torch.rand((1,)).item() < self.negative_prob:

                new_phrase = None
                while new_phrase is None or new_phrase == phrase:
                    idx = torch.randint(0, len(self.all_phrases), (1,)).item()
                    new_phrase = self.all_phrases[idx]
                phrase = new_phrase
                seg = torch.zeros_like(seg)

        if self.with_visual:
            # Attach a visual support sample: another (image, mask) pair for
            # the same phrase, encoded according to self.mask. The trailing
            # True/False flag marks whether a real support sample is present.
            if phrase in self.samples_by_phrase and len(self.samples_by_phrase[phrase]) > 1:
                # NOTE(review): the current sample itself can be drawn as its
                # own support sample — the draw is over all entries.
                idx = torch.randint(0, len(self.samples_by_phrase[phrase]), (1,)).item()
                other_sample = self.samples_by_phrase[phrase][idx]

                img_s, seg_s, _ = self.load_sample(*other_sample)

                from datasets.utils import blend_image_segmentation

                if self.mask in {'separate', 'text_and_separate'}:
                    # Support image and mask are passed as separate tensors.
                    add_phrase = [phrase] if self.mask == 'text_and_separate' else []
                    vis_s = add_phrase + [img_s, seg_s, True]
                else:
                    # Support image and mask are blended into a single tensor.
                    if self.mask.startswith('text_and_'):
                        mask_mode = self.mask[9:]  # strip 'text_and_' prefix
                        label_add = [phrase]
                    else:
                        mask_mode = self.mask
                        label_add = []

                    masked_img_s = torch.from_numpy(blend_image_segmentation(img_s, seg_s, mode=mask_mode, image_size=self.image_size)[0])
                    vis_s = label_add + [masked_img_s, True]

            else:
                # No support sample available: pass zero tensors and False.
                vis_s = torch.zeros_like(img)

                if self.mask in {'separate', 'text_and_separate'}:
                    add_phrase = [phrase] if self.mask == 'text_and_separate' else []
                    vis_s = add_phrase + [vis_s, torch.zeros(*vis_s.shape[1:], dtype=torch.uint8), False]
                elif self.mask.startswith('text_and_'):
                    vis_s = [phrase, vis_s, False]
                else:
                    vis_s = [vis_s, False]
        else:
            assert self.mask == 'text'
            vis_s = [phrase]

        seg = seg.unsqueeze(0).float()

        data_x = (img,) + tuple(vis_s)

        return data_x, (seg, torch.zeros(0), i)
| |
|
| |
|
class PhraseCutPlus(PhraseCut):
    """PhraseCut preset that always attaches a visual support sample and
    draws negative phrases with probability 0.2."""

    def __init__(self, split, image_size=400, aug=None, aug_color=False, aug_crop=True, min_size=0, remove_classes=None, only_visual=False, mask=None):
        # Fixed choices of this variant: visual conditioning on, 20% negatives.
        super().__init__(
            split,
            image_size=image_size,
            negative_prob=0.2,
            aug=aug,
            aug_color=aug_color,
            aug_crop=aug_crop,
            min_size=min_size,
            remove_classes=remove_classes,
            with_visual=True,
            only_visual=only_visual,
            mask=mask,
        )