|
import os
|
|
|
|
import h5py
|
|
import torch
|
|
from torch.utils.data import Dataset
|
|
|
|
|
|
class ImagenetResults(Dataset):
    """Dataset reading precomputed results from ``<path>/results.hdf5``.

    Every item is assembled from the ``image``, ``vis`` and ``target``
    datasets of the HDF5 file. The file is opened briefly at construction
    time to read the number of entries, and then re-opened lazily on the
    first item access (presumably so that each DataLoader worker ends up
    with its own handle -- confirm against how the loader is used).
    """

    def __init__(self, path):
        """Record the results file location and read the dataset length.

        Args:
            path: Directory that contains ``results.hdf5``.
        """
        super().__init__()

        self.path = os.path.join(path, "results.hdf5")
        # Read handle; stays None until the first __getitem__ call.
        self.data = None

        print("Reading dataset length...")
        # Short-lived handle: only the length is needed here, so the file is
        # closed again as soon as the ``with`` block exits.
        with h5py.File(self.path, "r") as f:
            self.data_length = len(f["/image"])

    def __len__(self):
        """Return the number of entries in the ``image`` dataset."""
        return self.data_length

    def __getitem__(self, item):
        """Return the ``(image, vis, target)`` tensors stored at ``item``.

        ``target`` is converted to ``torch.long``; ``image`` and ``vis`` keep
        whatever dtype ``torch.tensor`` infers from the stored data.
        """
        if self.data is None:
            # Lazy open: the handle is created on first use, not in __init__.
            self.data = h5py.File(self.path, "r")

        image = torch.tensor(self.data["image"][item])
        vis = torch.tensor(self.data["vis"][item])
        target = torch.tensor(self.data["target"][item]).long()

        return image, vis, target

    def close(self):
        """Close the lazily opened HDF5 handle, if any.

        Safe to call repeatedly; the next ``__getitem__`` re-opens the file.
        """
        # getattr guards against instances whose __init__ never completed.
        handle = getattr(self, "data", None)
        self.data = None
        if handle is not None:
            handle.close()

    def __del__(self):
        """Release the file handle when the dataset is garbage collected."""
        try:
            self.close()
        except Exception:
            # Best effort only: during interpreter shutdown the underlying
            # library may already be torn down, and raising from __del__
            # would merely produce an "exception ignored" warning.
            pass

    def __getstate__(self):
        """Return picklable state, leaving the open file handle behind.

        The copy of the dataset re-opens the file lazily on first access.
        """
        state = self.__dict__.copy()
        state["data"] = None
        return state
|
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test: load one batch from the stored results, render the
    # first visualization of that batch, and print the dataset size.
    import imageio  # not referenced in the visible part of this script
    import numpy as np

    from utils import render

    ds = ImagenetResults("../visualizations/fullgrad")
    sample_loader = torch.utils.data.DataLoader(ds, batch_size=5, shuffle=False)

    # Only the first batch is pulled from the loader.
    image, vis, target = next(iter(sample_loader))

    # Render the first visualization, multiply by 255 and cast to uint8.
    # NOTE(review): `maps` is not used afterwards in the visible code.
    first_vis = vis[0].data.cpu().numpy()
    heatmap = render.hm_to_rgb(first_vis, scaling=3, sigma=1, cmap="seismic")
    maps = (heatmap * 255).astype(np.uint8)

    print(len(ds))
|
|
|