|
|
|
|
|
|
|
|
|
import os |
|
import json |
|
import pdb |
|
import numpy as np |
|
|
|
|
|
class Dataset(object):
    """Base class for a dataset. To be overloaded.

    Subclasses are expected to set ``root``, ``img_dir`` and ``nimg`` and to
    implement ``get_key``; the other methods are built on top of those.
    """

    # Base directory of the dataset (first component of every filename).
    root = ""
    # Image sub-directory, joined after ``root``.
    img_dir = ""
    # Number of images; returned by ``len(dataset)``.
    nimg = 0

    def __len__(self):
        return self.nimg

    def get_key(self, img_idx):
        """Return the path of image ``img_idx`` relative to ``root/img_dir``.

        Must be implemented by subclasses.
        """
        raise NotImplementedError()

    def get_filename(self, img_idx, root=None):
        """Return the full path of image ``img_idx``.

        Args:
            img_idx: image index.
            root: optional directory overriding ``self.root`` (a falsy value
                falls back to ``self.root``).
        """
        return os.path.join(root or self.root, self.img_dir, self.get_key(img_idx))

    def get_image(self, img_idx):
        """Load image ``img_idx`` and return it converted to RGB.

        Raises:
            IOError: if the image cannot be opened or decoded.
        """
        # Imported here so PIL is only required when an image is actually loaded.
        from PIL import Image

        fname = self.get_filename(img_idx)
        try:
            # The context manager closes the file handle deterministically;
            # convert() returns a new, fully loaded image that no longer needs it.
            with Image.open(fname) as img:
                return img.convert("RGB")
        except Exception as e:
            raise IOError("Could not load image %s (reason: %s)" % (fname, str(e)))

    def __repr__(self):
        res = "Dataset: %s\n" % self.__class__.__name__
        res += " %d images" % self.nimg
        res += "\n root: %s...\n" % self.root
        return res
|
|
|
|
|
class CatDataset(Dataset):
    """Concatenation of several datasets.

    Global indices run through the sub-datasets in the order they were given:
    ``[0, datasets[0].nimg)`` maps to ``datasets[0]``, the next
    ``datasets[1].nimg`` indices map to ``datasets[1]``, and so on.
    """

    def __init__(self, *datasets):
        """Concatenate one or more datasets (each must expose ``nimg``)."""
        assert len(datasets) >= 1
        self.datasets = datasets
        # offsets[k] is the global index of the first image of datasets[k];
        # offsets[-1] is the total number of images.
        self.offsets = np.cumsum([0] + [db.nimg for db in datasets])
        self.nimg = int(self.offsets[-1])
        self.root = None

    def which(self, i):
        """Map global index ``i`` to ``(dataset_position, local_index)``.

        Raises:
            AssertionError: if ``i`` is not in ``[0, nimg)``.
        """
        # Validate the global index itself: the position returned by
        # searchsorted is not comparable to nimg, and an out-of-range index
        # would otherwise produce an invalid dataset position.
        assert 0 <= i < self.nimg, "Bad image index %d (expected 0 <= index < %d)" % (
            i,
            self.nimg,
        )
        # side="right" skips over empty sub-datasets (repeated offsets).
        pos = int(np.searchsorted(self.offsets, i, side="right")) - 1
        return pos, int(i - self.offsets[pos])

    def get_key(self, i):
        """Return the key of global image ``i`` from the owning sub-dataset."""
        b, j = self.which(i)
        return self.datasets[b].get_key(j)

    def get_filename(self, i, root=None):
        """Return the filename of global image ``i`` from the owning sub-dataset.

        Args:
            i: global image index.
            root: optional root override forwarded to the sub-dataset
                (matches the ``Dataset.get_filename`` signature).
        """
        b, j = self.which(i)
        db = self.datasets[b]
        # Forward ``root`` only when given, so sub-datasets whose
        # get_filename does not accept it keep working.
        if root is None:
            return db.get_filename(j)
        return db.get_filename(j, root=root)

    def __repr__(self):
        parts = [str(db).replace("\n", " ") for db in self.datasets]
        return "CatDataset(" + ", ".join(parts) + ")"
|
|