# Copyright 2019-present NAVER Corp.
# CC BY-NC-SA 3.0
# Available only for non-commercial use

import os
import json
import pdb

import numpy as np


class Dataset(object):
    ''' Base class for a dataset. To be overloaded.

    Subclasses are expected to set `root`, `img_dir` and `nimg`, and to
    implement `get_key()`.
    '''
    root = ''      # first path component used by get_filename()
    img_dir = ''   # second path component, joined after `root`
    nimg = 0       # number of images; this is what len() reports

    def __len__(self):
        ''' Return the number of images (`self.nimg`). '''
        return self.nimg

    def get_key(self, img_idx):
        ''' Return the key of image `img_idx`; it is used as the last path
        component by get_filename(). Must be implemented by subclasses.
        '''
        raise NotImplementedError()

    def get_filename(self, img_idx, root=None):
        ''' Return the path `<root>/<img_dir>/<key>` of image `img_idx`.

        root: optional replacement for `self.root`; any falsy value
              (None, '') falls back to `self.root`.
        '''
        return os.path.join(root or self.root, self.img_dir,
                            self.get_key(img_idx))

    def get_image(self, img_idx):
        ''' Load image `img_idx` with PIL and return it converted to RGB.

        Raises IOError (mentioning the file name and the underlying reason)
        if the image cannot be opened or converted.
        '''
        # Imported here so that PIL is only needed when images are loaded.
        from PIL import Image
        fname = self.get_filename(img_idx)
        try:
            return Image.open(fname).convert('RGB')
        except Exception as e:
            raise IOError("Could not load image %s (reason: %s)"
                          % (fname, str(e)))

    def __repr__(self):
        res = 'Dataset: %s\n' % self.__class__.__name__
        res += ' %d images' % self.nimg
        res += '\n root: %s...\n' % self.root
        return res


class CatDataset (Dataset):
    ''' Concatenation of several datasets.

    Global image indices run over the datasets in the order they were
    given: the first `datasets[0].nimg` indices belong to the first
    dataset, the following ones to the second dataset, and so on.
    '''

    def __init__(self, *datasets):
        assert len(datasets) >= 1
        self.datasets = datasets

        # offsets[k] is the global index of the first image of datasets[k];
        # offsets[-1] is the total number of images.
        self.offsets = np.cumsum([0] + [db.nimg for db in datasets])
        # Store a plain int rather than a NumPy scalar.
        self.nimg = int(self.offsets[-1])
        self.root = None

    def which(self, i):
        ''' Map a global image index to `(dataset_index, local_index)`.

        Raises IndexError if `i` is not in `[0, nimg)`.
        '''
        # Validate the index itself: the position returned by searchsorted
        # says nothing about whether `i` is in range, and a negative `i`
        # would otherwise wrap around to the last dataset.
        if not 0 <= i < self.nimg:
            raise IndexError('Bad image index %d (dataset has %d images)'
                             % (i, self.nimg))
        # side='right' sends an index equal to an offset to the dataset
        # that starts there (and skips over empty datasets).
        pos = int(np.searchsorted(self.offsets, i, side='right')) - 1
        return pos, int(i - self.offsets[pos])

    def get_key(self, i):
        ''' Return the key of global image `i` from the dataset owning it. '''
        b, local = self.which(i)
        return self.datasets[b].get_key(local)

    def get_filename(self, i, root=None):
        ''' Return the file name of global image `i`, as given by the
        dataset that owns it.

        root: optional root override, forwarded to the owning dataset.
              When None, the owning dataset is called without it.
        '''
        b, local = self.which(i)
        db = self.datasets[b]
        if root is None:
            # Keep the original call form so sub-datasets whose
            # get_filename() takes no `root` argument keep working.
            return db.get_filename(local)
        return db.get_filename(local, root=root)

    def __repr__(self):
        # One line per concatenation: newlines of each member's repr are
        # flattened to spaces.
        parts = [str(db).replace("\n", " ") for db in self.datasets]
        return "CatDataset(" + ', '.join(parts) + ')'