import os
from glob import glob
from collections import defaultdict

import numpy as np
from PIL import Image


class DAVIS(object):
    """Reader for the DAVIS video object segmentation dataset.

    After construction, ``self.sequences`` maps each sequence name to a dict with
    ``'images'`` (sorted list of ``*.jpg`` paths) and ``'masks'`` (sorted list of
    ``*.png`` paths, padded with ``MISSING_MASK`` so it is as long as ``'images'``).
    """

    SUBSET_OPTIONS = ['train', 'val', 'test-dev', 'test-challenge']
    TASKS = ['semi-supervised', 'unsupervised']
    DATASET_WEB = 'https://davischallenge.org/davis2017/code.html'
    # Annotation value treated as void by get_all_masks.
    VOID_LABEL = 255
    # Placeholder stored in sequences[seq]['masks'] for frames with no annotation file.
    MISSING_MASK = -1

    def __init__(self, root, task='unsupervised', subset='val', sequences='all', resolution='480p', codalab=False):
        """Index the image and annotation files of a DAVIS subset.

        :param root: Path to the DAVIS folder that contains JPEGImages, Annotations, etc. folders.
        :param task: Task to load the annotations, choose between semi-supervised or unsupervised.
        :param subset: Set to load the annotations, one of ``SUBSET_OPTIONS``.
        :param sequences: Sequences to consider: 'all' to use every sequence listed for the subset,
            a single sequence name, or a list of sequence names.
        :param resolution: Resolution sub-folder to read from, e.g. '480p' or 'Full-Resolution'.
        :param codalab: If True, sequences without images on disk are accepted instead of raising.
        :raises ValueError: If ``subset`` or ``task`` is not supported.
        :raises FileNotFoundError: If the dataset root, subset list or annotations folder is missing,
            or (unless ``codalab``) a sequence has no images.
        """
        if subset not in self.SUBSET_OPTIONS:
            raise ValueError(f'Subset should be in {self.SUBSET_OPTIONS}')
        if task not in self.TASKS:
            raise ValueError(f'The only tasks that are supported are {self.TASKS}')

        self.task = task
        self.subset = subset
        self.root = root
        self.img_path = os.path.join(self.root, 'JPEGImages', resolution)
        annotations_folder = 'Annotations' if task == 'semi-supervised' else 'Annotations_unsupervised'
        self.mask_path = os.path.join(self.root, annotations_folder, resolution)
        # Unsupervised test-dev/test-challenge lists live under ImageSets/2019; all others under 2017.
        year = '2019' if task == 'unsupervised' and subset in ('test-dev', 'test-challenge') else '2017'
        self.imagesets_path = os.path.join(self.root, 'ImageSets', year)

        self._check_directories()

        if sequences == 'all':
            with open(os.path.join(self.imagesets_path, f'{self.subset}.txt'), 'r') as f:
                # Skip blank lines (e.g. a trailing newline): an empty name would glob the
                # resolution folder itself and make a non-existent sequence look missing.
                sequences_names = [line.strip() for line in f if line.strip()]
        else:
            sequences_names = sequences if isinstance(sequences, list) else [sequences]

        self.sequences = defaultdict(dict)
        for seq in sequences_names:
            images = sorted(glob(os.path.join(self.img_path, seq, '*.jpg')))
            if len(images) == 0 and not codalab:
                raise FileNotFoundError(f'Images for sequence {seq} not found.')
            self.sequences[seq]['images'] = images
            masks = sorted(glob(os.path.join(self.mask_path, seq, '*.png')))
            # Pad so masks align one-to-one with images; unannotated frames get a placeholder.
            masks.extend([self.MISSING_MASK] * (len(images) - len(masks)))
            self.sequences[seq]['masks'] = masks

    def _check_directories(self):
        """Raise FileNotFoundError if the dataset root, subset list or required annotations are missing."""
        if not os.path.exists(self.root):
            raise FileNotFoundError(f'DAVIS not found in the specified directory, download it from {self.DATASET_WEB}')
        if not os.path.exists(os.path.join(self.imagesets_path, f'{self.subset}.txt')):
            raise FileNotFoundError(f'Subset sequences list for {self.subset} not found, download the missing subset '
                                    f'for the {self.task} task from {self.DATASET_WEB}')
        # Only train/val are required to ship annotations.
        if self.subset in ['train', 'val'] and not os.path.exists(self.mask_path):
            raise FileNotFoundError(f'Annotations folder for the {self.task} task not found, download it from {self.DATASET_WEB}')

    @staticmethod
    def _read_image(path):
        """Read the image file at ``path`` into a numpy array, closing the file afterwards."""
        with Image.open(path) as im:
            return np.array(im)

    def get_frames(self, sequence):
        """Yield ``(image, mask)`` numpy arrays for every frame of ``sequence``.

        ``mask`` is None for frames that have no annotation file.
        """
        for img, msk in zip(self.sequences[sequence]['images'], self.sequences[sequence]['masks']):
            image = self._read_image(img)
            # __init__ pads missing annotations with MISSING_MASK; a bare ``is None`` check
            # would let that placeholder reach Image.open and crash.
            mask = None if msk is None or msk == self.MISSING_MASK else self._read_image(msk)
            yield image, mask

    def _get_all_elements(self, sequence, obj_type):
        """Stack every ``obj_type`` ('images' or 'masks') file of ``sequence`` into one array.

        :return: ``(all_objs, obj_id)``: a float64 array of shape ``(num_frames, *frame_shape)``
            (frame shape taken from the first file) and the file names without extension.
        """
        paths = self.sequences[sequence][obj_type]
        first = self._read_image(paths[0])
        all_objs = np.zeros((len(paths), *first.shape))
        all_objs[0, ...] = first
        for i, path in enumerate(paths[1:], start=1):
            all_objs[i, ...] = self._read_image(path)
        # basename/splitext is portable across path separators and keeps dots inside the stem.
        obj_id = [os.path.splitext(os.path.basename(path))[0] for path in paths]
        return all_objs, obj_id

    def get_all_images(self, sequence):
        """Return ``(images, ids)`` for every frame of ``sequence`` (see ``_get_all_elements``)."""
        return self._get_all_elements(sequence, 'images')

    def get_all_masks(self, sequence, separate_objects_masks=False):
        """Load all annotations of ``sequence`` and split off the void region.

        :param separate_objects_masks: If True, return one boolean mask per object id ``1..N``,
            where N is the largest label in the first frame; the result has shape ``(N, *masks.shape)``.
        :return: ``(masks, masks_void, masks_id)``: boolean object masks, a float array that is 1
            where the annotation equals ``VOID_LABEL`` and 0 elsewhere, and the frame ids.
        """
        masks, masks_id = self._get_all_elements(sequence, 'masks')

        # Separate void and object masks.
        void = masks == self.VOID_LABEL
        masks_void = void.astype(masks.dtype)
        masks[void] = 0

        if separate_objects_masks:
            num_objects = int(np.max(masks[0, ...]))
            # Broadcast ids (N, 1, ..., 1) against masks (1, *masks.shape) -> (N, *masks.shape).
            object_ids = np.arange(1, num_objects + 1).reshape((-1,) + (1,) * masks.ndim)
            masks = object_ids == masks[None, ...]
        masks = masks > 0
        return masks, masks_void, masks_id

    def get_sequences(self):
        """Yield the names of the indexed sequences in insertion order."""
        yield from self.sequences


if __name__ == '__main__':
    import sys

    from matplotlib import pyplot as plt

    # Dataset root may be passed as the first CLI argument; defaults to the original author's path.
    davis_root = sys.argv[1] if len(sys.argv) > 1 else '/home/csergi/scratch2/Databases/DAVIS2017_private'
    subsets = ['train', 'val']

    for s in subsets:
        dataset = DAVIS(root=davis_root, subset=s)
        for seq in dataset.get_sequences():
            # Show only the first frame of each sequence.
            img, mask = next(dataset.get_frames(seq))
            plt.subplot(2, 1, 1)
            plt.title(seq)
            plt.imshow(img)
            plt.subplot(2, 1, 2)
            if mask is not None:
                plt.imshow(mask)
            plt.show(block=True)