import os

import h5py
import numpy as np
import torch
import torch.utils.data as data
from PIL import Image

# NOTE(review): only ImagenetResults is exported; Imagenet_Segmentation is
# still importable by name but is skipped by `from ... import *` — confirm
# this is intentional.
__all__ = ['ImagenetResults']


class Imagenet_Segmentation(data.Dataset):
    """Dataset of (image, segmentation mask) pairs read from one HDF5 file.

    The file must contain ``/value/img`` and ``/value/gt`` datasets whose
    entries are HDF5 object references that are resolved against the same
    file (presumably a MATLAB v7.3 export — TODO confirm).

    Args:
        path: Path to the HDF5 file.
        transform: Optional callable applied to the RGB ``PIL.Image``.
        target_transform: Optional callable applied to the mask
            ``PIL.Image``; its result is converted to an int array and then
            to a ``torch.LongTensor``.
    """

    # Number of segmentation classes (presumably background/foreground —
    # confirm against the evaluation code).
    CLASSES = 2

    def __init__(self, path, transform=None, target_transform=None):
        super(Imagenet_Segmentation, self).__init__()
        self.path = path
        self.transform = transform
        self.target_transform = target_transform
        # The long-lived handle is opened lazily in __getitem__, presumably
        # so each DataLoader worker opens its own handle instead of
        # inheriting one from the parent process.
        self.h5py = None
        # Open only long enough to read the sample count; `with` guarantees
        # the handle is closed even if the lookup raises.
        with h5py.File(path, 'r') as f:
            self.data_length = len(f['/value/img'])

    def __getitem__(self, index):
        """Return ``(img, target)`` for sample ``index``.

        ``img`` is the RGB image after ``transform`` (a ``PIL.Image`` when no
        transform was given). ``target`` is always a ``torch.LongTensor``
        holding the mask, after ``target_transform`` if one was supplied.
        """
        if self.h5py is None:
            self.h5py = h5py.File(self.path, 'r')

        # Each /value/img entry is a reference; indexing the file with it
        # yields the image array. The transpose reverses the stored axis
        # order (presumably column-major storage -> H x W x C; confirm).
        img_ref = self.h5py['/value/img'][index, 0]
        img = np.array(self.h5py[img_ref]).transpose((2, 1, 0))

        # Each /value/gt entry references an array of references; the mask
        # itself is the one stored at [0, 0] of that array.
        gt_ref = self.h5py[self.h5py['/value/gt'][index, 0]][0, 0]
        target = np.array(self.h5py[gt_ref]).transpose((1, 0))

        img = Image.fromarray(img).convert('RGB')
        target = Image.fromarray(target)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)
        # Convert unconditionally: torch.from_numpy rejects a PIL image, so
        # without this the mask could not become a tensor when no
        # target_transform was supplied.
        target = torch.from_numpy(np.array(target).astype('int32')).long()

        return img, target

    def __len__(self):
        """Return the number of entries in ``/value/img``."""
        return self.data_length


class ImagenetResults(data.Dataset):
    """Dataset over ``<path>/results.hdf5`` with ``image``/``vis``/``target``.

    Each item is ``(image, vis, target)`` as tensors built from the
    ``image``, ``vis`` and ``target`` datasets at the same index, with
    ``target`` cast to ``long``.

    Args:
        path: Directory containing ``results.hdf5``.
    """

    def __init__(self, path):
        super(ImagenetResults, self).__init__()
        self.path = os.path.join(path, 'results.hdf5')
        # Long-lived handle, opened lazily on first __getitem__.
        self.data = None

        print('Reading dataset length...')
        with h5py.File(self.path, 'r') as f:
            self.data_length = len(f['/image'])

    def __len__(self):
        """Return the number of entries in the ``image`` dataset."""
        return self.data_length

    def __getitem__(self, item):
        """Return ``(image, vis, target)`` tensors for index ``item``."""
        if self.data is None:
            self.data = h5py.File(self.path, 'r')

        image = torch.tensor(self.data['image'][item])
        vis = torch.tensor(self.data['vis'][item])
        target = torch.tensor(self.data['target'][item]).long()

        return image, vis, target