Spaces:
Running
Running
import os, numpy, torch, json | |
from .parallelfolder import ParallelImageFolders | |
from torchvision import transforms | |
from torchvision.transforms.functional import to_tensor, normalize | |
class FieldDef(object):
    """Describe how one label field is packed into a segmentation image.

    The constructor arguments mirror one entry of ``labelnames.json``.
    """

    def __init__(self, field, index, bitshift, bitmask, labels):
        # Name of the field (used as a category name by callers).
        self.field = field
        # Which channel of the segmentation image holds this field's bits.
        self.index = index
        # Raw value is decoded as (channel >> bitshift) & bitmask.
        self.bitshift, self.bitmask = bitshift, bitmask
        # Label names for each raw value of this field.
        self.labels = labels
class MultiSegmentDataset(object):
    '''
    Like ClevrMulticlassDataset, but the second stream is a segmentation:
    a per-pixel label map for every field, rather than a flat one-hot
    presence vector.

    MultiSegmentDataset('dataset/clevrseg',
            imgdir='images/train/positive',
            segdir='images/train/segmentation')

    Each item is a tuple ``(img, segout, bincount)``:

    * ``img`` -- the first stream returned by ParallelImageFolders
      (its type depends on ``transform``).
    * ``segout`` -- int64 tensor of shape (len(categories), H, W); plane
      ``i`` holds global label numbers for field ``i``, with 0 meaning
      "no label".
    * ``bincount`` -- numpy array of pixel counts per global label, at
      least ``len(labels)`` long.
    '''

    def __init__(self, directory, transform=None,
                 imgdir='img', segdir='seg', val=False, size=None):
        '''
        Args:
            directory: dataset root; must contain ``labelnames.json`` plus
                the image and segmentation subdirectories.
            transform: passed through to ParallelImageFolders.
            imgdir, segdir: subdirectory names of the two streams.
            val: if True, use the last 25% of items (validation split);
                otherwise use the first 75%.
            size: optional cap on the number of items exposed.
        '''
        self.segdataset = ParallelImageFolders(
            [os.path.join(directory, imgdir),
             os.path.join(directory, segdir)],
            transform=transform)
        # labelnames.json is a list of objects with the keys
        # 'field', 'index', 'bitshift', 'bitmask' and 'label'.
        with open(os.path.join(directory, 'labelnames.json'), 'r',
                  encoding='utf-8') as f:
            self.fields = [
                FieldDef(defn['field'], defn['index'], defn['bitshift'],
                         defn['bitmask'], defn['label'])
                for defn in json.load(f)]
        self.labels = ['-']  # Reserve label 0 to mean "no label"
        self.categories = []
        # Maps each global label number to the index of its field.
        self.label_category = [0]
        for fieldnum, f in enumerate(self.fields):
            self.categories.append(f.field)
            # Each field's labels (excluding labels[0]) occupy a contiguous
            # run of global label numbers starting at firstchannel.
            f.firstchannel = len(self.labels)
            f.channels = len(f.labels) - 1
            for lab in f.labels[1:]:
                self.labels.append(lab)
                self.label_category.append(fieldnum)
        # Reserve 25% of the dataset for validation.
        first_val = int(len(self.segdataset) * 0.75)
        self.val = val
        self.first = first_val if val else 0
        self.length = len(self.segdataset) - first_val if val else first_val
        # Truncate the dataset if requested.
        if size:
            self.length = min(size, self.length)

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        img, segimg = self.segdataset[index + self.first]
        # asarray instead of array(..., copy=False): under NumPy 2,
        # copy=False raises whenever a copy cannot be avoided.
        segin = numpy.asarray(segimg, dtype=numpy.uint8)
        segout = torch.zeros(len(self.categories),
                             segin.shape[0], segin.shape[1], dtype=torch.int64)
        for i, field in enumerate(self.fields):
            # Widen to int64 before any arithmetic; in uint8 the label
            # offset would overflow once it exceeds 255.
            channel = torch.from_numpy(
                segin[:, :, field.index].astype(numpy.int64))
            fielddata = (channel >> field.bitshift) & field.bitmask
            # Raw value 0 means "no label" and stays at global label 0
            # (segout is zero-initialized); raw value v >= 1 maps to
            # global label firstchannel + v - 1.
            labelled = fielddata > 0
            segout[i][labelled] = fielddata[labelled] + (field.firstchannel - 1)
        bincount = numpy.bincount(segout.numpy().ravel(),
                                  minlength=len(self.labels))
        return img, segout, bincount
if __name__ == '__main__':
    # Manual smoke check: load the default dataset, print its first item,
    # then stop in the debugger for interactive inspection.
    dataset = MultiSegmentDataset('dataset/clevrseg')
    print(dataset[0])
    import pdb
    pdb.set_trace()