from __future__ import division

import os
from glob import glob
import json
import random

import cv2
from PIL import Image
import numpy as np
import torch
from torch.utils.data import Dataset
import torchvision.transforms as TF

import dataloaders.image_transforms as IT

# Disable OpenCV's internal thread pool. Presumably this avoids CPU
# oversubscription / hangs when images are decoded inside DataLoader worker
# processes — NOTE(review): confirm original intent.
cv2.setNumThreads(0)


def _get_images(sample):
    """Return all frames of ``sample`` as one list: [ref, prev, *curr]."""
    return [sample['ref_img'], sample['prev_img']] + sample['curr_img']


def _get_labels(sample):
    """Return all labels of ``sample`` as one list: [ref, prev, *curr]."""
    return [sample['ref_label'], sample['prev_label']] + sample['curr_label']


def _random_other_index(exclude, n):
    """Draw a uniform random index in ``[0, n)`` that differs from ``exclude``.

    Uses rejection sampling with ``np.random.randint``. The caller must ensure
    ``n > 1``; otherwise no valid index exists and the loop would never end.
    """
    rand_idx = np.random.randint(n)
    while rand_idx == exclude:
        rand_idx = np.random.randint(n)
    return rand_idx


def _merge_sample(sample1, sample2, min_obj_pixels=100, max_obj_n=10):
    """Paste the foreground objects of ``sample2`` on top of ``sample1``.

    Frames and labels are merged pairwise in the order [ref, prev, *curr].
    Wherever ``sample2``'s label is non-zero, its image and (re-numbered)
    object ids replace ``sample1``'s. Object ids of ``sample2`` are shifted
    by ``max_obj_n`` so they cannot collide with ``sample1``'s.

    Which objects survive is decided once, on the first (reference) frame.
    An object is kept only if it covers more than ``min_obj_pixels`` pixels
    there. If more than ``max_obj_n`` objects survive, a random subset of
    ``max_obj_n`` is kept. The survivors are re-numbered consecutively from
    1, and the same selection is applied to every frame.

    Assumptions (not checked here): labels are integer tensors indexed as
    (1, H, W) (``size()[1]``/``size()[2]`` are used as H/W); images are
    tensors that broadcast against a (1, H, W) float mask; object ids in
    each input are at most ``max_obj_n``.

    Args:
        sample1: Base sample dict. Its ``'meta'`` dict is reused (and
            mutated) for the result.
        sample2: Sample whose foreground is pasted on top.
        min_obj_pixels: Minimum reference-frame area for an object to be kept.
        max_obj_n: Maximum number of objects in the merged sample.

    Returns:
        A sample dict with the same keys as the inputs. ``meta['obj_num']``
        is set to the number of kept objects (capped at ``max_obj_n``).
    """
    sample1_images = _get_images(sample1)
    sample2_images = _get_images(sample2)

    sample1_labels = _get_labels(sample1)
    sample2_labels = _get_labels(sample2)

    # One channel per possible merged id: 0 is background, 1..max_obj_n are
    # sample1 objects, max_obj_n+1..2*max_obj_n are shifted sample2 objects.
    obj_idx = torch.arange(0, max_obj_n * 2 + 1).view(max_obj_n * 2 + 1, 1, 1)
    selected_idx = None
    selected_obj = None

    all_img = []
    all_mask = []
    for idx, (s1_img, s2_img, s1_label, s2_label) in enumerate(
            zip(sample1_images, sample2_images, sample1_labels,
                sample2_labels)):
        s2_fg = (s2_label > 0).float()
        s2_bg = 1 - s2_fg
        merged_img = s1_img * s2_bg + s2_img * s2_fg
        merged_mask = s1_label * s2_bg.long() + (
            (s2_label + max_obj_n) * s2_fg.long())
        # One-hot over all candidate ids: (2 * max_obj_n + 1, H, W).
        merged_mask = (merged_mask == obj_idx).float()
        if idx == 0:
            # Decide which object channels to keep, based on the ref frame.
            after_merge_pixels = merged_mask.sum(dim=(1, 2), keepdim=True)
            selected_idx = after_merge_pixels > min_obj_pixels
            selected_idx[0] = True  # background is always kept
            obj_num = selected_idx.sum().int().item() - 1
            selected_idx = selected_idx.expand(-1, s1_label.size()[1],
                                               s1_label.size()[2])
            if obj_num > max_obj_n:
                # Too many objects: keep background plus a random subset.
                selected_obj = list(range(1, obj_num + 1))
                random.shuffle(selected_obj)
                selected_obj = [0] + selected_obj[:max_obj_n]

        merged_mask = merged_mask[selected_idx].view(obj_num + 1,
                                                     s1_label.size()[1],
                                                     s1_label.size()[2])
        if obj_num > max_obj_n:
            merged_mask = merged_mask[selected_obj]
        # Bias background so pixels whose object was dropped (all-zero
        # column) resolve to 0 rather than to an arbitrary tied channel.
        merged_mask[0] += 0.1
        merged_mask = torch.argmax(merged_mask, dim=0, keepdim=True).long()

        all_img.append(merged_img)
        all_mask.append(merged_mask)

    sample = {
        'ref_img': all_img[0],
        'prev_img': all_img[1],
        'curr_img': all_img[2:],
        'ref_label': all_mask[0],
        'prev_label': all_mask[1],
        'curr_label': all_mask[2:]
    }
    sample['meta'] = sample1['meta']
    sample['meta']['obj_num'] = min(obj_num, max_obj_n)
    return sample


class StaticTrain(Dataset):
    """Pseudo-video clips synthesized from static image/mask pairs.

    Each item is built from one image and its palette mask. Every frame of
    the clip is the same image under random augmentation. Non-reference
    frames also get a random horizontal flip and a random affine transform,
    so the clip mimics object motion. Optionally a second random image is
    pasted on top (see ``_merge_sample``).

    Images are looked up under ``<root>/JPEGImages/<dataset>`` and masks
    under ``<root>/Annotations/<dataset>`` for a fixed list of datasets.
    Datasets whose image and mask counts differ are skipped.
    """

    def __init__(self,
                 root,
                 output_size,
                 seq_len=5,
                 max_obj_n=10,
                 dynamic_merge=True,
                 merge_prob=1.0,
                 aug_type='v1'):
        """
        Args:
            root: Dataset root directory.
            output_size: Crop size passed to ``IT.RandomResizedCrop``; indexed
                as ``output_size[0]``/``output_size[1]`` for the aspect ratio.
            seq_len: Number of frames per clip (ref + prev + curr frames).
            max_obj_n: Maximum number of objects per sample.
            dynamic_merge: Whether to paste a second image onto samples.
            merge_prob: Probability of merging a sample that already has
                objects. Samples with zero objects are always merged when
                ``dynamic_merge`` is set.
            aug_type: ``'v1'`` (mild color jitter) or ``'v2'`` (stronger
                jitter plus random grayscale and blur).

        Raises:
            NotImplementedError: If ``aug_type`` is not ``'v1'`` or ``'v2'``.
        """
        self.root = root
        self.clip_n = seq_len
        self.output_size = output_size
        self.max_obj_n = max_obj_n

        self.dynamic_merge = dynamic_merge
        self.merge_prob = merge_prob

        self.img_list = list()
        self.mask_list = list()

        dataset_list = list()
        lines = ['COCO', 'ECSSD', 'MSRA10K', 'PASCAL-S', 'PASCALVOC2012']
        for line in lines:
            dataset_name = line.strip()

            img_dir = os.path.join(root, 'JPEGImages', dataset_name)
            mask_dir = os.path.join(root, 'Annotations', dataset_name)

            img_list = sorted(glob(os.path.join(img_dir, '*.jpg'))) + \
                sorted(glob(os.path.join(img_dir, '*.png')))
            mask_list = sorted(glob(os.path.join(mask_dir, '*.png')))

            if len(img_list) > 0:
                if len(img_list) == len(mask_list):
                    dataset_list.append(dataset_name)
                    self.img_list += img_list
                    self.mask_list += mask_list
                    print(f'\t{dataset_name}: {len(img_list)} imgs.')
                else:
                    print(
                        f'\tPreTrain dataset {dataset_name} has {len(img_list)} imgs and {len(mask_list)} annots. Not match! Skip.'
                    )
            else:
                print(
                    f'\tPreTrain dataset {dataset_name} doesn\'t exist. Skip.')

        print(
            f'{len(self.img_list)} imgs are used for PreTrain. They are from {dataset_list}.'
        )

        self.aug_type = aug_type

        self.pre_random_horizontal_flip = IT.RandomHorizontalFlip(0.5)

        self.random_horizontal_flip = IT.RandomHorizontalFlip(0.3)

        if self.aug_type == 'v1':
            self.color_jitter = TF.ColorJitter(0.1, 0.1, 0.1, 0.03)
        elif self.aug_type == 'v2':
            self.color_jitter = TF.RandomApply(
                [TF.ColorJitter(0.4, 0.4, 0.2, 0.1)], p=0.8)
            self.gray_scale = TF.RandomGrayscale(p=0.2)
            self.blur = TF.RandomApply([IT.GaussianBlur([.1, 2.])], p=0.3)
        else:
            # Previously `assert NotImplementedError`, which is always true
            # and let an unsupported aug_type through silently.
            raise NotImplementedError(
                f'Unsupported aug_type: {self.aug_type!r}')

        self.random_affine = IT.RandomAffine(degrees=20,
                                             translate=(0.1, 0.1),
                                             scale=(0.9, 1.1),
                                             shear=10,
                                             resample=Image.BICUBIC,
                                             fillcolor=(124, 116, 104))
        base_ratio = float(output_size[1]) / output_size[0]
        self.random_resize_crop = IT.RandomResizedCrop(
            output_size, (0.8, 1),
            ratio=(base_ratio * 3. / 4., base_ratio * 4. / 3.),
            interpolation=Image.BICUBIC)
        self.to_tensor = TF.ToTensor()
        self.to_onehot = IT.ToOnehot(max_obj_n, shuffle=True)
        self.normalize = TF.Normalize((0.485, 0.456, 0.406),
                                      (0.229, 0.224, 0.225))

    def __len__(self):
        """Return the number of usable image/mask pairs."""
        return len(self.img_list)

    def load_image_in_PIL(self, path, mode='RGB'):
        """Open ``path``, fully decode it, and return it converted to ``mode``.

        The file handle is closed before returning; ``convert`` always yields
        an in-memory image independent of the file.
        """
        with Image.open(path) as img:
            img.load()  # Very important for loading large image
            return img.convert(mode)

    def sample_sequence(self, idx):
        """Build one augmented pseudo-clip from image ``idx``.

        Returns:
            A dict with ``ref_img``/``prev_img``/``curr_img`` (normalized
            tensors; ``curr_img`` is a list), the matching ``*_label``
            entries (argmax over ``IT.ToOnehot`` channels, ``keepdim=True``),
            and ``meta`` with ``seq_name``, ``frame_num`` (always 1) and
            ``obj_num`` (number of objects found in the first frame).
        """
        img_pil = self.load_image_in_PIL(self.img_list[idx], 'RGB')
        mask_pil = self.load_image_in_PIL(self.mask_list[idx], 'P')

        frames = []
        masks = []

        img_pil, mask_pil = self.pre_random_horizontal_flip(img_pil, mask_pil)

        for i in range(self.clip_n):
            img, mask = img_pil, mask_pil

            # Only non-reference frames get geometric "motion" augmentation.
            if i > 0:
                img, mask = self.random_horizontal_flip(img, mask)
                img, mask = self.random_affine(img, mask)

            img = self.color_jitter(img)

            img, mask = self.random_resize_crop(img, mask)

            if self.aug_type == 'v2':
                img = self.gray_scale(img)
                img = self.blur(img)

            mask = np.array(mask, np.uint8)

            # The object list found on the first frame is reused for all
            # later frames so ids stay consistent across the clip.
            if i == 0:
                mask, obj_list = self.to_onehot(mask)
                obj_num = len(obj_list)
            else:
                mask, _ = self.to_onehot(mask, obj_list)

            mask = torch.argmax(mask, dim=0, keepdim=True)

            frames.append(self.normalize(self.to_tensor(img)))
            masks.append(mask)

        sample = {
            'ref_img': frames[0],
            'prev_img': frames[1],
            'curr_img': frames[2:],
            'ref_label': masks[0],
            'prev_label': masks[1],
            'curr_label': masks[2:]
        }
        sample['meta'] = {
            'seq_name': self.img_list[idx],
            'frame_num': 1,
            'obj_num': obj_num
        }

        return sample

    def __getitem__(self, idx):
        """Return clip ``idx``, possibly merged with another random clip.

        A merge is attempted when ``dynamic_merge`` is set and either the
        clip has no objects or a ``merge_prob`` coin flip succeeds. Merging
        needs at least two images; with a single image the clip is returned
        unmerged instead of looping forever looking for a partner.
        """
        sample1 = self.sample_sequence(idx)

        want_merge = self.dynamic_merge and (
            sample1['meta']['obj_num'] == 0
            or random.random() < self.merge_prob)
        if want_merge and len(self.img_list) > 1:
            rand_idx = _random_other_index(idx, len(self.img_list))
            sample2 = self.sample_sequence(rand_idx)
            sample = self.merge_sample(sample1, sample2)
        else:
            sample = sample1

        return sample

    def merge_sample(self, sample1, sample2, min_obj_pixels=100):
        """Paste ``sample2``'s objects onto ``sample1`` (see ``_merge_sample``)."""
        return _merge_sample(sample1, sample2, min_obj_pixels, self.max_obj_n)


class VOSTrain(Dataset):
    """Training clips sampled from annotated video sequences.

    ``imglistdic`` maps a sequence name to ``(image_filenames,
    label_filenames)``. Images are read from
    ``<image_root>/<seq>/<file>`` with OpenCV, and labels from
    ``<label_root>/<seq>/<file>`` with PIL. Each sample contains a reference
    frame, a "previous" frame and a list of current frames, with optional
    random temporal reversal and dynamic merging of two sequences.
    """

    def __init__(self,
                 image_root,
                 label_root,
                 imglistdic,
                 transform=None,
                 rgb=True,
                 repeat_time=1,
                 rand_gap=3,
                 seq_len=5,
                 rand_reverse=True,
                 dynamic_merge=True,
                 enable_prev_frame=False,
                 merge_prob=0.3,
                 max_obj_n=10):
        """
        Args:
            image_root: Directory containing one image folder per sequence.
            label_root: Directory containing one label folder per sequence.
            imglistdic: ``{seq_name: (image_files, label_files)}``.
            transform: Optional callable applied to each sample dict.
            rgb: If True, reorder OpenCV's BGR channels to RGB.
            repeat_time: Dataset length multiplier (indices wrap around).
            rand_gap: Frame gaps are drawn uniformly from ``1..rand_gap``.
            seq_len: Number of frames per clip (see ``sample_sequence``).
            rand_reverse: Randomly reverse the frame order with p=0.5.
            dynamic_merge: Whether to paste a second sequence onto samples.
            enable_prev_frame: Sample the prev frame independently of the
                reference frame instead of directly after it.
            merge_prob: Probability of merging a sample that has objects.
            max_obj_n: Maximum number of objects per merged sample.
        """
        self.image_root = image_root
        self.label_root = label_root
        self.rand_gap = rand_gap
        self.seq_len = seq_len
        self.rand_reverse = rand_reverse
        self.repeat_time = repeat_time
        self.transform = transform
        self.dynamic_merge = dynamic_merge
        self.merge_prob = merge_prob
        self.enable_prev_frame = enable_prev_frame
        self.max_obj_n = max_obj_n
        self.rgb = rgb
        self.imglistdic = imglistdic
        self.seqs = list(self.imglistdic.keys())
        print('Video Num: {} X {}'.format(len(self.seqs), self.repeat_time))

    def __len__(self):
        """Return ``len(seqs) * repeat_time`` (truncated to int)."""
        return int(len(self.seqs) * self.repeat_time)

    def reverse_seq(self, imagelist, lablist):
        """With probability 0.5 return both lists reversed, else unchanged."""
        if np.random.randint(2) == 1:
            imagelist = imagelist[::-1]
            lablist = lablist[::-1]
        return imagelist, lablist

    def _load_label(self, seqname, filename):
        """Read ``<label_root>/<seqname>/<filename>`` as a uint8 numpy array."""
        with Image.open(os.path.join(self.label_root, seqname,
                                     filename)) as label:
            return np.array(label, dtype=np.uint8)

    def get_ref_index(self,
                      seqname,
                      lablist,
                      objs,
                      min_fg_pixels=200,
                      max_try=5):
        """Pick a random reference frame consistent with ``objs``.

        A frame is accepted when it has more than ``min_fg_pixels``
        non-zero label pixels and every non-zero object id in it appears in
        ``objs``. After ``max_try`` draws without success, the last drawn
        index is returned anyway (best effort).
        """
        bad_indices = []
        for _ in range(max_try):
            ref_index = np.random.randint(len(lablist))
            if ref_index in bad_indices:
                continue
            ref_label = self._load_label(seqname, lablist[ref_index])
            ref_objs = list(np.unique(ref_label))
            is_consistent = True
            for obj in ref_objs:
                if obj == 0:
                    continue
                if obj not in objs:
                    is_consistent = False
            xs, ys = np.nonzero(ref_label)
            if len(xs) > min_fg_pixels and is_consistent:
                break
            bad_indices.append(ref_index)
        return ref_index

    def get_ref_index_v2(self,
                         seqname,
                         lablist,
                         min_fg_pixels=200,
                         max_try=20,
                         total_gap=0):
        """Pick a random reference frame with enough foreground.

        Candidates are drawn from ``[0, len(lablist) - total_gap)``. A frame
        is accepted when it has more than ``min_fg_pixels`` non-zero label
        pixels. Returns 0 if the search range has at most one frame, and the
        last drawn index if no candidate passed within ``max_try`` draws.
        """
        search_range = len(lablist) - total_gap
        if search_range <= 1:
            return 0
        bad_indices = []
        for _ in range(max_try):
            ref_index = np.random.randint(search_range)
            if ref_index in bad_indices:
                continue
            ref_label = self._load_label(seqname, lablist[ref_index])
            xs, ys = np.nonzero(ref_label)
            if len(xs) > min_fg_pixels:
                break
            bad_indices.append(ref_index)
        return ref_index

    def get_curr_gaps(self, seq_len, max_gap=999, max_try=10):
        """Draw ``seq_len`` random frame gaps, each in ``1..rand_gap``.

        Redraws (up to ``max_try`` times) while their sum exceeds
        ``max_gap``; the last draw is returned regardless.

        Returns:
            ``(gaps, total_gap)`` where ``total_gap == sum(gaps)``.
        """
        for _ in range(max_try):
            curr_gaps = []
            total_gap = 0
            for _ in range(seq_len):
                gap = int(np.random.randint(self.rand_gap) + 1)
                total_gap += gap
                curr_gaps.append(gap)
            if total_gap <= max_gap:
                break
        return curr_gaps, total_gap

    def get_prev_index(self, lablist, total_gap):
        """Pick a random start index leaving room for ``total_gap`` frames."""
        search_range = len(lablist) - total_gap
        if search_range > 1:
            prev_index = np.random.randint(search_range)
        else:
            prev_index = 0
        return prev_index

    def check_index(self, total_len, index, allow_reflect=True):
        """Map ``index`` into ``[0, total_len)``.

        With ``allow_reflect`` out-of-range indices are mirrored about the
        first/last frame (``-k -> k``, ``total_len - 1 + k ->
        total_len - 1 - k``), repeatedly until in range. Otherwise they are
        clamped. Returns 0 when ``total_len <= 1``.
        """
        if total_len <= 1:
            return 0

        if index < 0:
            if allow_reflect:
                index = -index
                index = self.check_index(total_len, index, True)
            else:
                index = 0
        elif index >= total_len:
            if allow_reflect:
                index = 2 * (total_len - 1) - index
                index = self.check_index(total_len, index, True)
            else:
                index = total_len - 1

        return index

    def get_curr_indices(self, lablist, prev_index, gaps):
        """Return indices reached by cumulatively adding ``gaps`` to
        ``prev_index``, each reflected into range via ``check_index``."""
        total_len = len(lablist)
        curr_indices = []
        now_index = prev_index
        for gap in gaps:
            now_index += gap
            curr_indices.append(self.check_index(total_len, now_index))
        return curr_indices

    def get_image_label(self, seqname, imagelist, lablist, index):
        """Load frame ``index`` of ``seqname`` and its label.

        Returns:
            ``(image, label)``: a float32 HxWx3 array (channels reordered to
            RGB when ``self.rgb``) and a uint8 label array.

        Raises:
            FileNotFoundError: If OpenCV cannot read the image file.
        """
        image_path = os.path.join(self.image_root, seqname, imagelist[index])
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f'Cannot read image: {image_path}')
        image = np.array(image, dtype=np.float32)

        if self.rgb:
            image = image[:, :, [2, 1, 0]]

        label = self._load_label(seqname, lablist[index])

        return image, label

    def sample_sequence(self, idx):
        """Sample one training clip from sequence ``idx % len(seqs)``.

        Up to 5 attempts are made to find a clip whose non-background object
        ids all appear in the reference frame; the last attempt is used if
        none is consistent.

        With ``enable_prev_frame`` the prev frame is sampled at random and is
        followed by ``seq_len - 1`` current frames. The reference frame is
        resampled (up to 5 times) while it falls inside
        ``(prev_index, max(curr_indices)]``. Otherwise ``seq_len - 1`` frames
        follow the reference frame: the first is used as the prev frame and
        the remaining ``seq_len - 2`` as current frames.

        Returns:
            A dict with ``ref_*``, ``prev_*``, ``curr_*`` images/labels and
            ``meta`` (``seq_name``, ``frame_num``, ``obj_num`` = largest
            object id in the reference label), passed through
            ``self.transform`` if set.
        """
        idx = idx % len(self.seqs)
        seqname = self.seqs[idx]
        imagelist, lablist = self.imglistdic[seqname]
        frame_num = len(imagelist)
        if self.rand_reverse:
            imagelist, lablist = self.reverse_seq(imagelist, lablist)

        is_consistent = False
        max_try = 5
        try_step = 0
        while (is_consistent is False and try_step < max_try):
            try_step += 1

            # generate random gaps
            curr_gaps, total_gap = self.get_curr_gaps(self.seq_len - 1)

            if self.enable_prev_frame:  # prev frame is randomly sampled
                # get prev frame
                prev_index = self.get_prev_index(lablist, total_gap)
                prev_image, prev_label = self.get_image_label(
                    seqname, imagelist, lablist, prev_index)
                prev_objs = list(np.unique(prev_label))

                # get curr frames
                curr_indices = self.get_curr_indices(lablist, prev_index,
                                                     curr_gaps)
                curr_images, curr_labels, curr_objs = [], [], []
                for curr_index in curr_indices:
                    curr_image, curr_label = self.get_image_label(
                        seqname, imagelist, lablist, curr_index)
                    c_objs = list(np.unique(curr_label))
                    curr_images.append(curr_image)
                    curr_labels.append(curr_label)
                    curr_objs.extend(c_objs)

                objs = list(np.unique(prev_objs + curr_objs))

                start_index = prev_index
                end_index = max(curr_indices)
                # get ref frame, preferably outside the sampled prev/curr span
                _try_step = 0
                ref_index = self.get_ref_index_v2(seqname, lablist)
                while (ref_index > start_index and ref_index <= end_index
                       and _try_step < max_try):
                    _try_step += 1
                    ref_index = self.get_ref_index_v2(seqname, lablist)
                ref_image, ref_label = self.get_image_label(
                    seqname, imagelist, lablist, ref_index)
                ref_objs = list(np.unique(ref_label))
            else:  # prev frame is next to ref frame
                # get ref frame
                ref_index = self.get_ref_index_v2(seqname, lablist)
                ref_image, ref_label = self.get_image_label(
                    seqname, imagelist, lablist, ref_index)
                ref_objs = list(np.unique(ref_label))

                # get curr frames
                curr_indices = self.get_curr_indices(lablist, ref_index,
                                                     curr_gaps)
                curr_images, curr_labels, curr_objs = [], [], []
                for curr_index in curr_indices:
                    curr_image, curr_label = self.get_image_label(
                        seqname, imagelist, lablist, curr_index)
                    c_objs = list(np.unique(curr_label))
                    curr_images.append(curr_image)
                    curr_labels.append(curr_label)
                    curr_objs.extend(c_objs)

                objs = list(np.unique(curr_objs))
                # The first frame after the reference serves as "prev".
                prev_image, prev_label = curr_images[0], curr_labels[0]
                curr_images, curr_labels = curr_images[1:], curr_labels[1:]

            # Every object in prev/curr frames must exist in the ref frame.
            is_consistent = True
            for obj in objs:
                if obj == 0:
                    continue
                if obj not in ref_objs:
                    is_consistent = False
                    break

        # get meta info
        obj_num = list(np.sort(ref_objs))[-1]

        sample = {
            'ref_img': ref_image,
            'prev_img': prev_image,
            'curr_img': curr_images,
            'ref_label': ref_label,
            'prev_label': prev_label,
            'curr_label': curr_labels
        }
        sample['meta'] = {
            'seq_name': seqname,
            'frame_num': frame_num,
            'obj_num': obj_num
        }

        if self.transform is not None:
            sample = self.transform(sample)

        return sample

    def __getitem__(self, idx):
        """Return clip ``idx``, possibly merged with another random sequence.

        A merge is attempted when ``dynamic_merge`` is set and either the
        clip has no objects or a ``merge_prob`` coin flip succeeds. Merging
        needs at least two sequences; with a single sequence the clip is
        returned unmerged instead of looping forever looking for a partner.
        """
        sample1 = self.sample_sequence(idx)

        want_merge = self.dynamic_merge and (
            sample1['meta']['obj_num'] == 0
            or random.random() < self.merge_prob)
        if want_merge and len(self.seqs) > 1:
            rand_idx = _random_other_index(idx % len(self.seqs),
                                           len(self.seqs))
            sample2 = self.sample_sequence(rand_idx)
            sample = self.merge_sample(sample1, sample2)
        else:
            sample = sample1

        return sample

    def merge_sample(self, sample1, sample2, min_obj_pixels=100):
        """Paste ``sample2``'s objects onto ``sample1`` (see ``_merge_sample``)."""
        return _merge_sample(sample1, sample2, min_obj_pixels, self.max_obj_n)


class DAVIS2017_Train(VOSTrain):
    """DAVIS training split, listed from ``<root>/ImageSets/<year>/<split>.txt``.

    Uses ``Full-Resolution`` frames when requested and present on disk,
    otherwise ``480p``.
    """

    def __init__(self,
                 split=('train',),
                 root='./DAVIS',
                 transform=None,
                 rgb=True,
                 repeat_time=1,
                 full_resolution=True,
                 year=2017,
                 rand_gap=3,
                 seq_len=5,
                 rand_reverse=True,
                 dynamic_merge=True,
                 enable_prev_frame=False,
                 max_obj_n=10,
                 merge_prob=0.3):
        """
        Args:
            split: Iterable of split names; each must have an
                ``ImageSets/<year>/<name>.txt`` file with one sequence name
                per line (blank lines are ignored).
            root: DAVIS root directory.
            full_resolution: Prefer ``Full-Resolution`` frames if available.
            year: Sub-folder of ``ImageSets`` to read split files from.
            Other arguments are forwarded to ``VOSTrain``.
        """
        if full_resolution:
            resolution = 'Full-Resolution'
            if not os.path.exists(os.path.join(root, 'JPEGImages',
                                               resolution)):
                print('No Full-Resolution, use 480p instead.')
                resolution = '480p'
        else:
            resolution = '480p'
        image_root = os.path.join(root, 'JPEGImages', resolution)
        label_root = os.path.join(root, 'Annotations', resolution)
        seq_names = []
        for spt in split:
            with open(os.path.join(root, 'ImageSets', str(year),
                                   spt + '.txt')) as f:
                # Skip blank lines: an empty name would make the sequence
                # folder resolve to image_root itself.
                seqs_tmp = [line.strip() for line in f if line.strip()]
            seq_names.extend(seqs_tmp)
        imglistdic = {}
        for seq_name in seq_names:
            images = list(
                np.sort(os.listdir(os.path.join(image_root, seq_name))))
            labels = list(
                np.sort(os.listdir(os.path.join(label_root, seq_name))))
            imglistdic[seq_name] = (images, labels)

        super(DAVIS2017_Train, self).__init__(image_root,
                                              label_root,
                                              imglistdic,
                                              transform,
                                              rgb,
                                              repeat_time,
                                              rand_gap,
                                              seq_len,
                                              rand_reverse,
                                              dynamic_merge,
                                              enable_prev_frame,
                                              merge_prob=merge_prob,
                                              max_obj_n=max_obj_n)


class YOUTUBEVOS_Train(VOSTrain):
    """YouTube-VOS training split driven by ``<root>/<year>/train/meta.json``.

    For each video, the frames listed for its objects are unioned. Objects
    with fewer than 2 listed frames are ignored, and videos left with fewer
    than 2 frames are skipped.
    """

    def __init__(self,
                 root='./datasets/YTB',
                 year=2019,
                 transform=None,
                 rgb=True,
                 rand_gap=3,
                 seq_len=3,
                 rand_reverse=True,
                 dynamic_merge=True,
                 enable_prev_frame=False,
                 max_obj_n=10,
                 merge_prob=0.3):
        """
        Args:
            root: YouTube-VOS root; data is read from ``<root>/<year>/train``.
            year: Dataset release sub-folder.
            Other arguments are forwarded to ``VOSTrain`` (``repeat_time=1``).

        Raises:
            FileNotFoundError: If ``meta.json`` does not exist.
        """
        root = os.path.join(root, str(year), 'train')
        image_root = os.path.join(root, 'JPEGImages')
        label_root = os.path.join(root, 'Annotations')
        self.seq_list_file = os.path.join(root, 'meta.json')
        if not self._check_preprocess():
            # Previously this fell through to an AttributeError on ann_f.
            raise FileNotFoundError(self.seq_list_file)
        seq_names = list(self.ann_f.keys())

        imglistdic = {}
        for seq_name in seq_names:
            data = self.ann_f[seq_name]['objects']
            obj_names = list(data.keys())
            images = []
            labels = []
            for obj_n in obj_names:
                frames = list(data[obj_n]["frames"])
                if len(frames) < 2:
                    print("Short object: " + seq_name + '-' + obj_n)
                    continue
                images += [frame + '.jpg' for frame in frames]
                labels += [frame + '.png' for frame in frames]
            images = np.sort(np.unique(images))
            labels = np.sort(np.unique(labels))
            if len(images) < 2:
                print("Short video: " + seq_name)
                continue
            imglistdic[seq_name] = (images, labels)

        super(YOUTUBEVOS_Train, self).__init__(image_root,
                                               label_root,
                                               imglistdic,
                                               transform,
                                               rgb,
                                               1,
                                               rand_gap,
                                               seq_len,
                                               rand_reverse,
                                               dynamic_merge,
                                               enable_prev_frame,
                                               merge_prob=merge_prob,
                                               max_obj_n=max_obj_n)

    def _check_preprocess(self):
        """Load ``meta.json``'s ``'videos'`` mapping into ``self.ann_f``.

        Returns:
            True on success. False (after printing a message) if the file
            does not exist.
        """
        if not os.path.isfile(self.seq_list_file):
            print('No such file: {}.'.format(self.seq_list_file))
            return False
        with open(self.seq_list_file, 'r') as f:
            self.ann_f = json.load(f)['videos']
        return True


class TEST(Dataset):
    """Synthetic constant dataset for smoke-testing the training pipeline.

    Every item is an 800x800x3 all-zero float32 image with an all-ones uint8
    800x800 label, used for the ref, prev and ``seq_len - 2`` current frames.
    """

    def __init__(
        self,
        seq_len=3,
        obj_num=3,
        transform=None,
    ):
        """
        Args:
            seq_len: Total frames per sample (ref + prev + curr).
            obj_num: Value reported in ``meta['obj_num']``.
            transform: Optional callable applied to each sample dict.
        """
        self.seq_len = seq_len
        self.obj_num = obj_num
        self.transform = transform

    def __len__(self):
        """Return the fixed dataset length (3000)."""
        return 3000

    def __getitem__(self, idx):
        """Return the same constant sample regardless of ``idx``."""
        img = np.zeros((800, 800, 3)).astype(np.float32)
        label = np.ones((800, 800)).astype(np.uint8)
        sample = {
            'ref_img': img,
            'prev_img': img,
            'curr_img': [img] * (self.seq_len - 2),
            'ref_label': label,
            'prev_label': label,
            'curr_label': [label] * (self.seq_len - 2)
        }
        sample['meta'] = {
            'seq_name': 'test',
            'frame_num': 100,
            'obj_num': self.obj_num
        }

        if self.transform is not None:
            sample = self.transform(sample)
        return sample