import os, numpy, torch, json from .parallelfolder import ParallelImageFolders from torchvision import transforms from torchvision.transforms.functional import to_tensor, normalize class FieldDef(object): def __init__(self, field, index, bitshift, bitmask, labels): self.field = field self.index = index self.bitshift = bitshift self.bitmask = bitmask self.labels = labels class MultiSegmentDataset(object): ''' Just like ClevrMulticlassDataset, but the second stream is a one-hot segmentation tensor rather than a flat one-hot presence vector. MultiSegmentDataset('dataset/clevrseg', imgdir='images/train/positive', segdir='images/train/segmentation') ''' def __init__(self, directory, transform=None, imgdir='img', segdir='seg', val=False, size=None): self.segdataset = ParallelImageFolders( [os.path.join(directory, imgdir), os.path.join(directory, segdir)], transform=transform) self.fields = [] with open(os.path.join(directory, 'labelnames.json'), 'r') as f: for defn in json.load(f): self.fields.append(FieldDef( defn['field'], defn['index'], defn['bitshift'], defn['bitmask'], defn['label'])) self.labels = ['-'] # Reserve label 0 to mean "no label" self.categories = [] self.label_category = [0] for fieldnum, f in enumerate(self.fields): self.categories.append(f.field) f.firstchannel = len(self.labels) f.channels = len(f.labels) - 1 for lab in f.labels[1:]: self.labels.append(lab) self.label_category.append(fieldnum) # Reserve 25% of the dataset for validation. first_val = int(len(self.segdataset) * 0.75) self.val = val self.first = first_val if val else 0 self.length = len(self.segdataset) - first_val if val else first_val # Truncate the dataset if requested. if size: self.length = min(size, self.length) def __len__(self): return self.length def __getitem__(self, index): img, segimg = self.segdataset[index + self.first] segin = numpy.array(segimg, numpy.uint8, copy=False) segout = torch.zeros(len(self.categories), segin.shape[0], segin.shape[1], dtype=torch.int64) for i, field in enumerate(self.fields): fielddata = ((torch.from_numpy(segin[:, :, field.index]) >> field.bitshift) & field.bitmask) segout[i] = field.firstchannel + fielddata - 1 bincount = numpy.bincount(segout.flatten(), minlength=len(self.labels)) return img, segout, bincount if __name__ == '__main__': ds = MultiSegmentDataset('dataset/clevrseg') print(ds[0]) import pdb; pdb.set_trace()