|
import os, numpy, torch, json |
|
from .parallelfolder import ParallelImageFolders |
|
from torchvision import transforms |
|
from torchvision.transforms.functional import to_tensor, normalize |
|
|
|
class FieldDef(object):
    '''
    Describes one bit-packed label field stored in a segmentation image.

    field    -- name of the field (used as the category name).
    index    -- which channel of the segmentation image holds the field.
    bitshift -- number of bits to shift right before masking.
    bitmask  -- mask applied after shifting to isolate the field value.
    labels   -- list of label names; MultiSegmentDataset skips entry 0
                when building its global label list.

    Instances are plain attribute holders: MultiSegmentDataset later
    attaches `firstchannel` and `channels` to them.
    '''

    def __init__(self, field, index, bitshift, bitmask, labels):
        # Where the field lives and what it is called.
        self.field, self.index = field, index
        # How to decode the field value out of the packed channel.
        self.bitshift, self.bitmask = bitshift, bitmask
        # Human-readable names for the decoded values.
        self.labels = labels
|
|
|
class MultiSegmentDataset(object):
    '''
    Just like ClevrMulticlassDataset, but the second stream is a
    per-pixel segmentation tensor rather than a flat presence vector.

    MultiSegmentDataset('datasets/clevrseg',
        imgdir='images/train/positive',
        segdir='images/train/segmentation')

    The directory must contain a 'labelnames.json' file holding a list
    of objects with the keys 'field', 'index', 'bitshift', 'bitmask'
    and 'label'; each one describes a bit-packed field of the
    segmentation image (see FieldDef).

    Attributes built from that file:
      fields         -- list of FieldDef, one per entry in the json.
      categories     -- list of field names, in file order.
      labels         -- global label list; entry 0 is '-', followed by
                        every field's labels except its first.
      label_category -- for each global label, the index of the field
                        it came from (0 for the '-' entry).

    The first 75% of the underlying items form the training split and
    the rest form the validation split, selected with `val`.  A truthy
    `size` caps the number of items exposed.
    '''

    def __init__(self, directory, transform=None,
                 imgdir='img', segdir='seg', val=False, size=None):
        self.segdataset = ParallelImageFolders(
            [os.path.join(directory, imgdir),
             os.path.join(directory, segdir)],
            transform=transform)

        # Load the field definitions describing the packed label bits.
        self.fields = []
        with open(os.path.join(directory, 'labelnames.json'), 'r') as f:
            for defn in json.load(f):
                self.fields.append(FieldDef(
                    defn['field'], defn['index'], defn['bitshift'],
                    defn['bitmask'], defn['label']))

        # Flatten every field's labels into one global label list.
        # Global label 0 is the shared '-' entry; each field contributes
        # all of its labels except its own first one.
        self.labels = ['-']
        self.categories = []
        self.label_category = [0]
        for fieldnum, fielddef in enumerate(self.fields):
            self.categories.append(fielddef.field)
            fielddef.firstchannel = len(self.labels)
            fielddef.channels = len(fielddef.labels) - 1
            for lab in fielddef.labels[1:]:
                self.labels.append(lab)
                self.label_category.append(fieldnum)

        # Train/validation split: the validation part is the last 25%.
        total = len(self.segdataset)
        first_val = int(total * 0.75)
        self.val = val
        self.first = first_val if val else 0
        self.length = total - first_val if val else first_val

        # A falsy size (None or 0) leaves the split at full length.
        if size:
            self.length = min(size, self.length)

    def __len__(self):
        '''Returns the number of items in the selected split.'''
        return self.length

    def __getitem__(self, index):
        '''
        Returns (img, segout, bincount) for the given item of the split.

        img      -- the first stream, as returned by the underlying
                    ParallelImageFolders.
        segout   -- int64 tensor of shape (len(categories), H, W) with
                    the global label index of each pixel for each field.
        bincount -- numpy array counting how often each global label
                    occurs in segout (at least len(labels) entries).

        Raises IndexError when index falls outside the split, so that
        plain iteration stops at the end of the split instead of
        running into items that belong to the other split.
        '''
        if index < 0:
            # Negative indices count back from the end of this split.
            index += self.length
        if not 0 <= index < self.length:
            raise IndexError('MultiSegmentDataset index out of range')

        img, segimg = self.segdataset[index + self.first]
        # asarray copies only when needed; array(..., copy=False) raises
        # under NumPy 2 whenever a copy cannot be avoided.
        segin = numpy.asarray(segimg, dtype=numpy.uint8)
        segout = torch.zeros(len(self.categories),
                             segin.shape[0], segin.shape[1],
                             dtype=torch.int64)
        for i, field in enumerate(self.fields):
            # Widen to int64 before any arithmetic: in uint8 the global
            # label index would wrap around once it exceeds 255.
            packed = segin[:, :, field.index].astype(numpy.int64)
            fielddata = (packed >> field.bitshift) & field.bitmask
            # NOTE(review): a field value of 0 maps to firstchannel - 1,
            # which for every field after the first is the last label of
            # the preceding field -- confirm this is intended.
            segout[i] = torch.from_numpy(
                field.firstchannel + fielddata - 1)
        bincount = numpy.bincount(segout.flatten().numpy(),
                                  minlength=len(self.labels))
        return img, segout, bincount
|
|
|
if __name__ == '__main__':
    # Manual smoke test: load the default dataset location, show the
    # first item, then drop into the debugger for interactive poking.
    ds = MultiSegmentDataset('datasets/clevrseg')
    first_item = ds[0]
    print(first_item)
    import pdb
    pdb.set_trace()
|
|
|
|