sayakpaul's picture
sayakpaul HF staff
add files
c4b2b37
raw
history blame
1.4 kB
import os
import h5py
import torch
from torch.utils.data import Dataset
class ImagenetResults(Dataset):
    """Dataset over the ``results.hdf5`` file found in a results directory.

    Each item is an ``(image, vis, target)`` tuple of tensors read from the
    ``image``, ``vis`` and ``target`` datasets of the HDF5 file.  The file is
    opened lazily on the first item access instead of in ``__init__``, and
    can be released explicitly with :meth:`close`.
    """

    def __init__(self, path: str) -> None:
        """Record the file location and read the number of stored samples.

        Args:
            path: Directory that contains ``results.hdf5``.
        """
        super().__init__()
        self.path = os.path.join(path, "results.hdf5")
        # Long-lived read handle; created on the first __getitem__ call.
        self.data = None

        print("Reading dataset length...")
        # Open only briefly so that no handle is kept after construction.
        with h5py.File(self.path, "r") as f:
            self.data_length = len(f["/image"])

    def __len__(self) -> int:
        """Return the number of samples counted at construction time."""
        return self.data_length

    def __getitem__(self, item):
        """Return the ``(image, vis, target)`` tensors stored at ``item``.

        ``target`` is converted to a 64-bit integer tensor.
        """
        if self.data is None:
            self.data = h5py.File(self.path, "r")

        image = torch.tensor(self.data["image"][item])
        vis = torch.tensor(self.data["vis"][item])
        target = torch.tensor(self.data["target"][item]).long()
        return image, vis, target

    def __getstate__(self):
        """Return the instance state without the open file handle.

        The handle is dropped so that a pickled copy starts with
        ``data = None`` and reopens the file on its first item access.
        """
        state = self.__dict__.copy()
        state["data"] = None
        return state

    def close(self) -> None:
        """Close the lazily opened file handle, if any.

        Safe to call repeatedly; a later item access reopens the file.
        """
        # getattr guards against instances whose __init__ did not complete.
        handle = getattr(self, "data", None)
        if handle is not None:
            handle.close()
            self.data = None
if __name__ == "__main__":
    # Smoke test: load the first batch of stored results, render the first
    # visualisation of that batch as an RGB heatmap, and print the dataset
    # size.
    import numpy as np

    from utils import render

    ds = ImagenetResults("../visualizations/fullgrad")
    sample_loader = torch.utils.data.DataLoader(ds, batch_size=5, shuffle=False)

    # Only the visualisations of the first batch are used below.
    _, vis, _ = next(iter(sample_loader))

    # The rendered map is scaled by 255 and cast to uint8.  It is not written
    # anywhere; pass `maps` to an image writer to inspect it.
    maps = (
        render.hm_to_rgb(vis[0].detach().cpu().numpy(), scaling=3, sigma=1, cmap="seismic")
        * 255
    ).astype(np.uint8)

    print(len(ds))