WwYc commited on
Commit
b18b78b
1 Parent(s): 2e6234b

Delete Transformer-Explainability/dataset

Browse files
Transformer-Explainability/dataset/expl_hdf5.py DELETED
@@ -1,52 +0,0 @@
1
- import os
2
-
3
- import h5py
4
- import torch
5
- from torch.utils.data import Dataset
6
-
7
-
8
- class ImagenetResults(Dataset):
9
- def __init__(self, path):
10
- super(ImagenetResults, self).__init__()
11
-
12
- self.path = os.path.join(path, "results.hdf5")
13
- self.data = None
14
-
15
- print("Reading dataset length...")
16
- with h5py.File(self.path, "r") as f:
17
- # tmp = h5py.File(self.path , 'r')
18
- self.data_length = len(f["/image"])
19
-
20
- def __len__(self):
21
- return self.data_length
22
-
23
- def __getitem__(self, item):
24
- if self.data is None:
25
- self.data = h5py.File(self.path, "r")
26
-
27
- image = torch.tensor(self.data["image"][item])
28
- vis = torch.tensor(self.data["vis"][item])
29
- target = torch.tensor(self.data["target"][item]).long()
30
-
31
- return image, vis, target
32
-
33
-
34
- if __name__ == "__main__":
35
- import imageio
36
- import numpy as np
37
- from utils import render
38
-
39
- ds = ImagenetResults("../visualizations/fullgrad")
40
- sample_loader = torch.utils.data.DataLoader(ds, batch_size=5, shuffle=False)
41
-
42
- iterator = iter(sample_loader)
43
- image, vis, target = next(iterator)
44
-
45
- maps = (
46
- render.hm_to_rgb(vis[0].data.cpu().numpy(), scaling=3, sigma=1, cmap="seismic")
47
- * 255
48
- ).astype(np.uint8)
49
-
50
- # imageio.imsave('../delete_hm.jpg', maps)
51
-
52
- print(len(ds))