File size: 1,398 Bytes
c4b2b37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import os

import h5py
import torch
from torch.utils.data import Dataset


class ImagenetResults(Dataset):
    """Map-style dataset reading saved results from ``<path>/results.hdf5``.

    Item ``i`` is the tuple ``(image, vis, target)`` built from the HDF5
    datasets ``image``, ``vis`` and ``target`` at index ``i``; ``target`` is
    cast to ``torch.long``. The dataset length is the length of ``/image``.

    The HDF5 handle is opened lazily on the first ``__getitem__`` call rather
    than in ``__init__`` -- presumably so each DataLoader worker process opens
    its own handle. Call :meth:`close` to release it explicitly.
    """

    def __init__(self, path):
        super().__init__()

        self.path = os.path.join(path, "results.hdf5")
        # Opened on first item access; see __getitem__ / close / __getstate__.
        self.data = None

        print("Reading dataset length...")
        # Short-lived handle used only to read the length, closed immediately.
        with h5py.File(self.path, "r") as f:
            self.data_length = len(f["/image"])

    def __len__(self):
        return self.data_length

    def __getitem__(self, item):
        """Return ``(image, vis, target)`` tensors for index ``item``."""
        if self.data is None:
            self.data = h5py.File(self.path, "r")

        image = torch.tensor(self.data["image"][item])
        vis = torch.tensor(self.data["vis"][item])
        target = torch.tensor(self.data["target"][item]).long()

        return image, vis, target

    def close(self):
        """Close the lazily opened HDF5 handle, if any.

        Safe to call more than once; a later ``__getitem__`` reopens the file.
        """
        # Detach first so the dataset is left consistent even if close() raises.
        handle, self.data = self.data, None
        if handle is not None:
            handle.close()

    def __getstate__(self):
        # An open h5py.File cannot be pickled (e.g. when DataLoader workers are
        # started with the "spawn" method). Drop it so the copy reopens lazily.
        state = self.__dict__.copy()
        state["data"] = None
        return state


if __name__ == "__main__":
    # Smoke test: load the first batch (5 items, unshuffled) of the results
    # stored under ../visualizations/fullgrad, convert the first "vis" entry
    # with render.hm_to_rgb, and print the dataset length.
    import imageio
    import numpy as np
    from utils import render

    dataset = ImagenetResults("../visualizations/fullgrad")
    loader = torch.utils.data.DataLoader(dataset, batch_size=5, shuffle=False)

    image, vis, target = next(iter(loader))

    first_vis = vis[0].data.cpu().numpy()
    heatmap = render.hm_to_rgb(first_vis, scaling=3, sigma=1, cmap="seismic")
    # Scale to 0-255 and convert to 8-bit for image writing.
    maps = (heatmap * 255).astype(np.uint8)

    # To inspect the result on disk, uncomment:
    # imageio.imsave('../delete_hm.jpg', maps)

    print(len(dataset))