WwYc committed on
Commit 3d27aee
Parent: 971cce4

Upload 707 files

This view is limited to 50 files because the commit contains too many changes; the raw diff has the full list.
Files changed (50):
  1. .gitattributes +6 -0
  2. ViT_DeiT/.gitignore +11 -0
  3. ViT_DeiT/.ipynb_checkpoints/DeiT_example-checkpoint.ipynb +0 -0
  4. ViT_DeiT/.ipynb_checkpoints/example-checkpoint.ipynb +0 -0
  5. ViT_DeiT/LICENSE +21 -0
  6. ViT_DeiT/VIT-EXPL.py +96 -0
  7. ViT_DeiT/baselines/ViT/ViT_LRP.py +437 -0
  8. ViT_DeiT/baselines/ViT/ViT_explanation_generator.py +83 -0
  9. ViT_DeiT/baselines/ViT/ViT_new.py +238 -0
  10. ViT_DeiT/baselines/ViT/ViT_orig_LRP.py +425 -0
  11. ViT_DeiT/baselines/ViT/__pycache__/ViT_LRP.cpython-38.pyc +0 -0
  12. ViT_DeiT/baselines/ViT/__pycache__/ViT_explanation_generator.cpython-38.pyc +0 -0
  13. ViT_DeiT/baselines/ViT/__pycache__/helpers.cpython-38.pyc +0 -0
  14. ViT_DeiT/baselines/ViT/__pycache__/layer_helpers.cpython-38.pyc +0 -0
  15. ViT_DeiT/baselines/ViT/__pycache__/weight_init.cpython-38.pyc +0 -0
  16. ViT_DeiT/baselines/ViT/generate_visualizations.py +208 -0
  17. ViT_DeiT/baselines/ViT/helpers.py +295 -0
  18. ViT_DeiT/baselines/ViT/imagenet_seg_eval.py +334 -0
  19. ViT_DeiT/baselines/ViT/layer_helpers.py +21 -0
  20. ViT_DeiT/baselines/ViT/misc_functions.py +68 -0
  21. ViT_DeiT/baselines/ViT/pertubation_eval_from_hdf5.py +233 -0
  22. ViT_DeiT/baselines/ViT/weight_init.py +60 -0
  23. ViT_DeiT/data/VOC.py +372 -0
  24. ViT_DeiT/data/__init__.py +0 -0
  25. ViT_DeiT/data/imagenet.py +74 -0
  26. ViT_DeiT/data/imagenet_utils.py +1002 -0
  27. ViT_DeiT/data/transforms.py +442 -0
  28. ViT_DeiT/dataset/expl_hdf5.py +51 -0
  29. ViT_DeiT/modules/__init__.py +0 -0
  30. ViT_DeiT/modules/__pycache__/__init__.cpython-38.pyc +0 -0
  31. ViT_DeiT/modules/__pycache__/layers_ours.cpython-38.pyc +0 -0
  32. ViT_DeiT/modules/layers_lrp.py +261 -0
  33. ViT_DeiT/modules/layers_ours.py +280 -0
  34. ViT_DeiT/requirements.txt +15 -0
  35. ViT_DeiT/samples/CLS2IDX.py +1000 -0
  36. ViT_DeiT/samples/__pycache__/CLS2IDX.cpython-38.pyc +0 -0
  37. ViT_DeiT/samples/catdog.png +0 -0
  38. ViT_DeiT/samples/dogbird.png +0 -0
  39. ViT_DeiT/samples/dogcat2.png +0 -0
  40. ViT_DeiT/samples/el1.png +0 -0
  41. ViT_DeiT/samples/el2.png +0 -0
  42. ViT_DeiT/samples/el3.png +0 -0
  43. ViT_DeiT/samples/el4.png +0 -0
  44. ViT_DeiT/samples/el5.png +0 -0
  45. ViT_DeiT/utils/__init__.py +0 -0
  46. ViT_DeiT/utils/__pycache__/__init__.cpython-38.pyc +0 -0
  47. ViT_DeiT/utils/confusionmatrix.py +88 -0
  48. ViT_DeiT/utils/iou.py +93 -0
  49. ViT_DeiT/utils/metric.py +12 -0
  50. ViT_DeiT/utils/metrices.py +208 -0
.gitattributes CHANGED
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ ViT_DeiT/venv/Scripts/_hashlib.pyd filter=lfs diff=lfs merge=lfs -text
+ ViT_DeiT/venv/Scripts/_ssl.pyd filter=lfs diff=lfs merge=lfs -text
+ ViT_DeiT/venv/Scripts/python36.dll filter=lfs diff=lfs merge=lfs -text
+ ViT_DeiT/venv/Scripts/sqlite3.dll filter=lfs diff=lfs merge=lfs -text
+ ViT_DeiT/venv/Scripts/tcl86t.dll filter=lfs diff=lfs merge=lfs -text
+ ViT_DeiT/venv/Scripts/tk86t.dll filter=lfs diff=lfs merge=lfs -text
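The six new entries route the bundled virtual-environment binaries under ViT_DeiT/venv/Scripts/ through Git LFS. Lines of this form are what `git lfs track "<pattern>"` appends to .gitattributes; a wildcard pattern such as `ViT_DeiT/venv/Scripts/*.dll` (a hypothetical example, not what this commit used) would yield an equivalent `filter=lfs diff=lfs merge=lfs -text` entry.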
ViT_DeiT/.gitignore ADDED
@@ -0,0 +1,11 @@
+ *.pyc
+ all_good_vis/
+ __pycache__
+ *.tar
+ .idea
+ run/
+ baselines/ViT/experiments/
+ baselines/ViT/visualizations/
+ bert_models/
+ data/movies/
+
ViT_DeiT/.ipynb_checkpoints/DeiT_example-checkpoint.ipynb ADDED
The diff for this file is too large to render.
ViT_DeiT/.ipynb_checkpoints/example-checkpoint.ipynb ADDED
The diff for this file is too large to render.
ViT_DeiT/LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2020 Hila Chefer
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
ViT_DeiT/VIT-EXPL.py ADDED
@@ -0,0 +1,96 @@
+ import os
+
+ from PIL import Image
+ import torchvision.transforms as transforms
+ import matplotlib.pyplot as plt
+ import pylab
+ import torch
+ import numpy as np
+ import cv2
+ from samples.CLS2IDX import CLS2IDX
+ from baselines.ViT.ViT_LRP import vit_base_patch16_224 as vit_LRP
+ from baselines.ViT.ViT_explanation_generator import LRP
+
+ normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+ transform = transforms.Compose([
+     transforms.Resize(256),
+     transforms.CenterCrop(224),
+     transforms.ToTensor(),
+     normalize,
+ ])
+ use_thresholding = False
+ def show_cam_on_image(img, mask):
+     heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
+     heatmap = np.float32(heatmap) / 255
+     cam = heatmap + np.float32(img)
+     cam = cam / np.max(cam)
+     return cam
+
+ # initialize ViT pretrained
+ model = vit_LRP(pretrained=True).cuda()
+ model.eval()
+ attribution_generator = LRP(model)
+
+ def generate_visualization(original_image, class_index=None):
+     transformer_attribution = attribution_generator.generate_LRP(original_image.unsqueeze(0).cuda(), method="transformer_attribution", index=class_index).detach()
+     transformer_attribution = transformer_attribution.reshape(1, 1, 14, 14)
+     transformer_attribution = torch.nn.functional.interpolate(transformer_attribution, scale_factor=16, mode='bilinear')
+     transformer_attribution = transformer_attribution.reshape(224, 224).data.cpu().numpy()
+     transformer_attribution = (transformer_attribution - transformer_attribution.min()) / (transformer_attribution.max() - transformer_attribution.min())
+
+     if use_thresholding:
+         transformer_attribution = transformer_attribution * 255
+         transformer_attribution = transformer_attribution.astype(np.uint8)
+         ret, transformer_attribution = cv2.threshold(transformer_attribution, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
+         transformer_attribution[transformer_attribution == 255] = 1
+
+     image_transformer_attribution = original_image.permute(1, 2, 0).data.cpu().numpy()
+     image_transformer_attribution = (image_transformer_attribution - image_transformer_attribution.min()) / (image_transformer_attribution.max() - image_transformer_attribution.min())
+     vis = show_cam_on_image(image_transformer_attribution, transformer_attribution)
+     vis = np.uint8(255 * vis)
+     vis = cv2.cvtColor(np.array(vis), cv2.COLOR_RGB2BGR)
+     return vis
+
+
+ def print_top_classes(predictions, **kwargs):
+     # Print Top-5 predictions
+     prob = torch.softmax(predictions, dim=1)
+     class_indices = predictions.data.topk(5, dim=1)[1][0].tolist()
+     max_str_len = 0
+     class_names = []
+     for cls_idx in class_indices:
+         class_names.append(CLS2IDX[cls_idx])
+         if len(CLS2IDX[cls_idx]) > max_str_len:
+             max_str_len = len(CLS2IDX[cls_idx])
+
+     print('Top 5 classes:')
+     for cls_idx in class_indices:
+         output_string = '\t{} : {}'.format(cls_idx, CLS2IDX[cls_idx])
+         output_string += ' ' * (max_str_len - len(CLS2IDX[cls_idx])) + '\t\t'
+         output_string += 'value = {:.3f}\t prob = {:.1f}%'.format(predictions[0, cls_idx], 100 * prob[0, cls_idx])
+         print(output_string)
+
+
+ image = Image.open('samples/dogcat2.png')
+ dog_cat_image = transform(image)
+
+ fig, axs = plt.subplots(1, 3)
+ axs[0].imshow(image);
+ axs[0].axis('off');
+
+ output = model(dog_cat_image.unsqueeze(0).cuda())
+ print_top_classes(output)
+
+ # cat - the predicted class
+ cat = generate_visualization(dog_cat_image)
+
+ # dog
+ # generate visualization for class 243: 'bull mastiff'
+ dog = generate_visualization(dog_cat_image, class_index=243)
+
+
+ axs[1].imshow(cat);
+ axs[1].axis('off');
+ axs[2].imshow(dog);
+ axs[2].axis('off');
+ pylab.show()
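In outline, the script above loads an LRP-enabled ViT-B/16, runs the LRP explanation generator, reshapes the per-patch relevance onto the 14x14 token grid, and upsamples it to image resolution. A condensed sketch of those steps (same modules as above; CUDA is assumed because `generate_LRP` moves its one-hot target to the GPU):

    model = vit_LRP(pretrained=True).cuda().eval()               # relprop-enabled ViT-B/16
    generator = LRP(model)
    # img: a normalized 3x224x224 tensor produced by `transform` above
    attr = generator.generate_LRP(img.unsqueeze(0).cuda(),
                                  method="transformer_attribution").detach()
    attr = attr.reshape(1, 1, 14, 14)                            # one relevance score per image patch
    attr = torch.nn.functional.interpolate(attr, scale_factor=16, mode='bilinear')  # 224x224 heatmap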
ViT_DeiT/baselines/ViT/ViT_LRP.py ADDED
@@ -0,0 +1,437 @@
1
+ """ Vision Transformer (ViT) in PyTorch
2
+ Hacked together by / Copyright 2020 Ross Wightman
3
+ """
4
+ import torch
5
+ import torch.nn as nn
6
+ from einops import rearrange
7
+ from modules.layers_ours import *
8
+
9
+ from baselines.ViT.helpers import load_pretrained
10
+ from baselines.ViT.weight_init import trunc_normal_
11
+ from baselines.ViT.layer_helpers import to_2tuple
12
+
13
+
14
+ def _cfg(url='', **kwargs):
15
+ return {
16
+ 'url': url,
17
+ 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
18
+ 'crop_pct': .9, 'interpolation': 'bicubic',
19
+ 'first_conv': 'patch_embed.proj', 'classifier': 'head',
20
+ **kwargs
21
+ }
22
+
23
+
24
+ default_cfgs = {
25
+ # patch models
26
+ 'vit_small_patch16_224': _cfg(
27
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth',
28
+ ),
29
+ 'vit_base_patch16_224': _cfg(
30
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth',
31
+ mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
32
+ ),
33
+ 'vit_large_patch16_224': _cfg(
34
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth',
35
+ mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
36
+ }
37
+
38
+ def compute_rollout_attention(all_layer_matrices, start_layer=0):
39
+ # adding residual consideration
40
+ num_tokens = all_layer_matrices[0].shape[1]
41
+ batch_size = all_layer_matrices[0].shape[0]
42
+ eye = torch.eye(num_tokens).expand(batch_size, num_tokens, num_tokens).to(all_layer_matrices[0].device)
43
+ all_layer_matrices = [all_layer_matrices[i] + eye for i in range(len(all_layer_matrices))]
44
+ # all_layer_matrices = [all_layer_matrices[i] / all_layer_matrices[i].sum(dim=-1, keepdim=True)
45
+ # for i in range(len(all_layer_matrices))]
46
+ joint_attention = all_layer_matrices[start_layer]
47
+ for i in range(start_layer+1, len(all_layer_matrices)):
48
+ joint_attention = all_layer_matrices[i].bmm(joint_attention)
49
+ return joint_attention
50
+
51
+ class Mlp(nn.Module):
52
+ def __init__(self, in_features, hidden_features=None, out_features=None, drop=0.):
53
+ super().__init__()
54
+ out_features = out_features or in_features
55
+ hidden_features = hidden_features or in_features
56
+ self.fc1 = Linear(in_features, hidden_features)
57
+ self.act = GELU()
58
+ self.fc2 = Linear(hidden_features, out_features)
59
+ self.drop = Dropout(drop)
60
+
61
+ def forward(self, x):
62
+ x = self.fc1(x)
63
+ x = self.act(x)
64
+ x = self.drop(x)
65
+ x = self.fc2(x)
66
+ x = self.drop(x)
67
+ return x
68
+
69
+ def relprop(self, cam, **kwargs):
70
+ cam = self.drop.relprop(cam, **kwargs)
71
+ cam = self.fc2.relprop(cam, **kwargs)
72
+ cam = self.act.relprop(cam, **kwargs)
73
+ cam = self.fc1.relprop(cam, **kwargs)
74
+ return cam
75
+
76
+
77
+ class Attention(nn.Module):
78
+ def __init__(self, dim, num_heads=8, qkv_bias=False,attn_drop=0., proj_drop=0.):
79
+ super().__init__()
80
+ self.num_heads = num_heads
81
+ head_dim = dim // num_heads
82
+ # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
83
+ self.scale = head_dim ** -0.5
84
+
85
+ # A = Q*K^T
86
+ self.matmul1 = einsum('bhid,bhjd->bhij')
87
+ # attn = A*V
88
+ self.matmul2 = einsum('bhij,bhjd->bhid')
89
+
90
+ self.qkv = Linear(dim, dim * 3, bias=qkv_bias)
91
+ self.attn_drop = Dropout(attn_drop)
92
+ self.proj = Linear(dim, dim)
93
+ self.proj_drop = Dropout(proj_drop)
94
+ self.softmax = Softmax(dim=-1)
95
+
96
+ self.attn_cam = None
97
+ self.attn = None
98
+ self.v = None
99
+ self.v_cam = None
100
+ self.attn_gradients = None
101
+
102
+ def get_attn(self):
103
+ return self.attn
104
+
105
+ def save_attn(self, attn):
106
+ self.attn = attn
107
+
108
+ def save_attn_cam(self, cam):
109
+ self.attn_cam = cam
110
+
111
+ def get_attn_cam(self):
112
+ return self.attn_cam
113
+
114
+ def get_v(self):
115
+ return self.v
116
+
117
+ def save_v(self, v):
118
+ self.v = v
119
+
120
+ def save_v_cam(self, cam):
121
+ self.v_cam = cam
122
+
123
+ def get_v_cam(self):
124
+ return self.v_cam
125
+
126
+ def save_attn_gradients(self, attn_gradients):
127
+ self.attn_gradients = attn_gradients
128
+
129
+ def get_attn_gradients(self):
130
+ return self.attn_gradients
131
+
132
+ def forward(self, x):
133
+ b, n, _, h = *x.shape, self.num_heads
134
+ qkv = self.qkv(x)
135
+ q, k, v = rearrange(qkv, 'b n (qkv h d) -> qkv b h n d', qkv=3, h=h)
136
+
137
+ self.save_v(v)
138
+
139
+ dots = self.matmul1([q, k]) * self.scale
140
+
141
+ attn = self.softmax(dots)
142
+ attn = self.attn_drop(attn)
143
+
144
+ self.save_attn(attn)
145
+ attn.register_hook(self.save_attn_gradients)
146
+
147
+ out = self.matmul2([attn, v])
148
+ out = rearrange(out, 'b h n d -> b n (h d)')
149
+
150
+ out = self.proj(out)
151
+ out = self.proj_drop(out)
152
+ return out
153
+
154
+ def relprop(self, cam, **kwargs):
155
+ cam = self.proj_drop.relprop(cam, **kwargs)
156
+ cam = self.proj.relprop(cam, **kwargs)
157
+ cam = rearrange(cam, 'b n (h d) -> b h n d', h=self.num_heads)
158
+
159
+ # attn = A*V
160
+ (cam1, cam_v)= self.matmul2.relprop(cam, **kwargs)
161
+ cam1 /= 2
162
+ cam_v /= 2
163
+
164
+ self.save_v_cam(cam_v)
165
+ self.save_attn_cam(cam1)
166
+
167
+ cam1 = self.attn_drop.relprop(cam1, **kwargs)
168
+ cam1 = self.softmax.relprop(cam1, **kwargs)
169
+
170
+ # A = Q*K^T
171
+ (cam_q, cam_k) = self.matmul1.relprop(cam1, **kwargs)
172
+ cam_q /= 2
173
+ cam_k /= 2
174
+
175
+ cam_qkv = rearrange([cam_q, cam_k, cam_v], 'qkv b h n d -> b n (qkv h d)', qkv=3, h=self.num_heads)
176
+
177
+ return self.qkv.relprop(cam_qkv, **kwargs)
178
+
179
+
180
+ class Block(nn.Module):
181
+
182
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.):
183
+ super().__init__()
184
+ self.norm1 = LayerNorm(dim, eps=1e-6)
185
+ self.attn = Attention(
186
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
187
+ self.norm2 = LayerNorm(dim, eps=1e-6)
188
+ mlp_hidden_dim = int(dim * mlp_ratio)
189
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, drop=drop)
190
+
191
+ self.add1 = Add()
192
+ self.add2 = Add()
193
+ self.clone1 = Clone()
194
+ self.clone2 = Clone()
195
+
196
+ def forward(self, x):
197
+ x1, x2 = self.clone1(x, 2)
198
+ x = self.add1([x1, self.attn(self.norm1(x2))])
199
+ x1, x2 = self.clone2(x, 2)
200
+ x = self.add2([x1, self.mlp(self.norm2(x2))])
201
+ return x
202
+
203
+ def relprop(self, cam, **kwargs):
204
+ (cam1, cam2) = self.add2.relprop(cam, **kwargs)
205
+ cam2 = self.mlp.relprop(cam2, **kwargs)
206
+ cam2 = self.norm2.relprop(cam2, **kwargs)
207
+ cam = self.clone2.relprop((cam1, cam2), **kwargs)
208
+
209
+ (cam1, cam2) = self.add1.relprop(cam, **kwargs)
210
+ cam2 = self.attn.relprop(cam2, **kwargs)
211
+ cam2 = self.norm1.relprop(cam2, **kwargs)
212
+ cam = self.clone1.relprop((cam1, cam2), **kwargs)
213
+ return cam
214
+
215
+
216
+ class PatchEmbed(nn.Module):
217
+ """ Image to Patch Embedding
218
+ """
219
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
220
+ super().__init__()
221
+ img_size = to_2tuple(img_size)
222
+ patch_size = to_2tuple(patch_size)
223
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
224
+ self.img_size = img_size
225
+ self.patch_size = patch_size
226
+ self.num_patches = num_patches
227
+
228
+ self.proj = Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
229
+
230
+ def forward(self, x):
231
+ B, C, H, W = x.shape
232
+ # FIXME look at relaxing size constraints
233
+ assert H == self.img_size[0] and W == self.img_size[1], \
234
+ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
235
+ x = self.proj(x).flatten(2).transpose(1, 2)
236
+ return x
237
+
238
+ def relprop(self, cam, **kwargs):
239
+ cam = cam.transpose(1,2)
240
+ cam = cam.reshape(cam.shape[0], cam.shape[1],
241
+ (self.img_size[0] // self.patch_size[0]), (self.img_size[1] // self.patch_size[1]))
242
+ return self.proj.relprop(cam, **kwargs)
243
+
244
+
245
+ class VisionTransformer(nn.Module):
246
+ """ Vision Transformer with support for patch or hybrid CNN input stage
247
+ """
248
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
249
+ num_heads=12, mlp_ratio=4., qkv_bias=False, mlp_head=False, drop_rate=0., attn_drop_rate=0.):
250
+ super().__init__()
251
+ self.num_classes = num_classes
252
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
253
+ self.patch_embed = PatchEmbed(
254
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
255
+ num_patches = self.patch_embed.num_patches
256
+
257
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
258
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
259
+
260
+ self.blocks = nn.ModuleList([
261
+ Block(
262
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
263
+ drop=drop_rate, attn_drop=attn_drop_rate)
264
+ for i in range(depth)])
265
+
266
+ self.norm = LayerNorm(embed_dim)
267
+ if mlp_head:
268
+ # paper diagram suggests 'MLP head', but results in 4M extra parameters vs paper
269
+ self.head = Mlp(embed_dim, int(embed_dim * mlp_ratio), num_classes)
270
+ else:
271
+ # with a single Linear layer as head, the param count within rounding of paper
272
+ self.head = Linear(embed_dim, num_classes)
273
+
274
+ # FIXME not quite sure what the proper weight init is supposed to be,
275
+ # normal / trunc normal w/ std == .02 similar to other Bert like transformers
276
+ trunc_normal_(self.pos_embed, std=.02) # embeddings same as weights?
277
+ trunc_normal_(self.cls_token, std=.02)
278
+ self.apply(self._init_weights)
279
+
280
+ self.pool = IndexSelect()
281
+ self.add = Add()
282
+
283
+ self.inp_grad = None
284
+
285
+ def save_inp_grad(self,grad):
286
+ self.inp_grad = grad
287
+
288
+ def get_inp_grad(self):
289
+ return self.inp_grad
290
+
291
+
292
+ def _init_weights(self, m):
293
+ if isinstance(m, nn.Linear):
294
+ trunc_normal_(m.weight, std=.02)
295
+ if isinstance(m, nn.Linear) and m.bias is not None:
296
+ nn.init.constant_(m.bias, 0)
297
+ elif isinstance(m, nn.LayerNorm):
298
+ nn.init.constant_(m.bias, 0)
299
+ nn.init.constant_(m.weight, 1.0)
300
+
301
+ @property
302
+ def no_weight_decay(self):
303
+ return {'pos_embed', 'cls_token'}
304
+
305
+ def forward(self, x):
306
+ B = x.shape[0]
307
+ x = self.patch_embed(x)
308
+
309
+ cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
310
+ x = torch.cat((cls_tokens, x), dim=1)
311
+ x = self.add([x, self.pos_embed])
312
+
313
+ x.register_hook(self.save_inp_grad)
314
+
315
+ for blk in self.blocks:
316
+ x = blk(x)
317
+
318
+ x = self.norm(x)
319
+ x = self.pool(x, dim=1, indices=torch.tensor(0, device=x.device))
320
+ x = x.squeeze(1)
321
+ x = self.head(x)
322
+ return x
323
+
324
+ def relprop(self, cam=None,method="transformer_attribution", is_ablation=False, start_layer=0, **kwargs):
325
+ # print(kwargs)
326
+ # print("conservation 1", cam.sum())
327
+ cam = self.head.relprop(cam, **kwargs)
328
+ cam = cam.unsqueeze(1)
329
+ cam = self.pool.relprop(cam, **kwargs)
330
+ cam = self.norm.relprop(cam, **kwargs)
331
+ for blk in reversed(self.blocks):
332
+ cam = blk.relprop(cam, **kwargs)
333
+
334
+ # print("conservation 2", cam.sum())
335
+ # print("min", cam.min())
336
+
337
+ if method == "full":
338
+ (cam, _) = self.add.relprop(cam, **kwargs)
339
+ cam = cam[:, 1:]
340
+ cam = self.patch_embed.relprop(cam, **kwargs)
341
+ # sum on channels
342
+ cam = cam.sum(dim=1)
343
+ return cam
344
+
345
+ elif method == "rollout":
346
+ # cam rollout
347
+ attn_cams = []
348
+ for blk in self.blocks:
349
+ attn_heads = blk.attn.get_attn_cam().clamp(min=0)
350
+ avg_heads = (attn_heads.sum(dim=1) / attn_heads.shape[1]).detach()
351
+ attn_cams.append(avg_heads)
352
+ cam = compute_rollout_attention(attn_cams, start_layer=start_layer)
353
+ cam = cam[:, 0, 1:]
354
+ return cam
355
+
356
+ # our method, method name grad is legacy
357
+ elif method == "transformer_attribution" or method == "grad":
358
+ cams = []
359
+ for blk in self.blocks:
360
+ grad = blk.attn.get_attn_gradients()
361
+ cam = blk.attn.get_attn_cam()
362
+ cam = cam[0].reshape(-1, cam.shape[-1], cam.shape[-1])
363
+ grad = grad[0].reshape(-1, grad.shape[-1], grad.shape[-1])
364
+ cam = grad * cam
365
+ cam = cam.clamp(min=0).mean(dim=0)
366
+ cams.append(cam.unsqueeze(0))
367
+ rollout = compute_rollout_attention(cams, start_layer=start_layer)
368
+ cam = rollout[:, 0, 1:]
369
+ return cam
370
+
371
+ elif method == "last_layer":
372
+ cam = self.blocks[-1].attn.get_attn_cam()
373
+ cam = cam[0].reshape(-1, cam.shape[-1], cam.shape[-1])
374
+ if is_ablation:
375
+ grad = self.blocks[-1].attn.get_attn_gradients()
376
+ grad = grad[0].reshape(-1, grad.shape[-1], grad.shape[-1])
377
+ cam = grad * cam
378
+ cam = cam.clamp(min=0).mean(dim=0)
379
+ cam = cam[0, 1:]
380
+ return cam
381
+
382
+ elif method == "last_layer_attn":
383
+ cam = self.blocks[-1].attn.get_attn()
384
+ cam = cam[0].reshape(-1, cam.shape[-1], cam.shape[-1])
385
+ cam = cam.clamp(min=0).mean(dim=0)
386
+ cam = cam[0, 1:]
387
+ return cam
388
+
389
+ elif method == "second_layer":
390
+ cam = self.blocks[1].attn.get_attn_cam()
391
+ cam = cam[0].reshape(-1, cam.shape[-1], cam.shape[-1])
392
+ if is_ablation:
393
+ grad = self.blocks[1].attn.get_attn_gradients()
394
+ grad = grad[0].reshape(-1, grad.shape[-1], grad.shape[-1])
395
+ cam = grad * cam
396
+ cam = cam.clamp(min=0).mean(dim=0)
397
+ cam = cam[0, 1:]
398
+ return cam
399
+
400
+
401
+ def _conv_filter(state_dict, patch_size=16):
402
+ """ convert patch embedding weight from manual patchify + linear proj to conv"""
403
+ out_dict = {}
404
+ for k, v in state_dict.items():
405
+ if 'patch_embed.proj.weight' in k:
406
+ v = v.reshape((v.shape[0], 3, patch_size, patch_size))
407
+ out_dict[k] = v
408
+ return out_dict
409
+
410
+ def vit_base_patch16_224(pretrained=False, **kwargs):
411
+ model = VisionTransformer(
412
+ patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, **kwargs)
413
+ model.default_cfg = default_cfgs['vit_base_patch16_224']
414
+ if pretrained:
415
+ load_pretrained(
416
+ model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=_conv_filter)
417
+ return model
418
+
419
+ def vit_large_patch16_224(pretrained=False, **kwargs):
420
+ model = VisionTransformer(
421
+ patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, **kwargs)
422
+ model.default_cfg = default_cfgs['vit_large_patch16_224']
423
+ if pretrained:
424
+ load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
425
+ return model
426
+
427
+ def deit_base_patch16_224(pretrained=False, **kwargs):
428
+ model = VisionTransformer(
429
+ patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, **kwargs)
430
+ model.default_cfg = _cfg()
431
+ if pretrained:
432
+ checkpoint = torch.hub.load_state_dict_from_url(
433
+ url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth",
434
+ map_location="cpu", check_hash=True
435
+ )
436
+ model.load_state_dict(checkpoint["model"])
437
+ return model
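The `compute_rollout_attention` helper above implements attention rollout: each layer's attention map gets an identity matrix added for the residual connection, and the per-layer matrices are chained by matrix multiplication from `start_layer` upward. A toy sketch with made-up 3-token attention maps (not from the repository) illustrates the computation:

    import torch
    # one image, three tokens, two layers; rows sum to 1 like softmaxed attention
    A1 = torch.tensor([[[0.6, 0.2, 0.2], [0.3, 0.5, 0.2], [0.1, 0.1, 0.8]]])
    A2 = torch.tensor([[[0.4, 0.3, 0.3], [0.2, 0.6, 0.2], [0.3, 0.3, 0.4]]])
    eye = torch.eye(3).unsqueeze(0)
    joint = (A2 + eye).bmm(A1 + eye)   # same order as the loop: later layers multiply on the left
    relevance = joint[:, 0, 1:]        # row 0 = [CLS]; columns 1: = patch tokens, as used in relprop()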
ViT_DeiT/baselines/ViT/ViT_explanation_generator.py ADDED
@@ -0,0 +1,83 @@
+ import argparse
+ import torch
+ import numpy as np
+ from numpy import *
+
+ # compute rollout between attention layers
+ def compute_rollout_attention(all_layer_matrices, start_layer=0):
+     # adding residual consideration- code adapted from https://github.com/samiraabnar/attention_flow
+     num_tokens = all_layer_matrices[0].shape[1]
+     batch_size = all_layer_matrices[0].shape[0]
+     eye = torch.eye(num_tokens).expand(batch_size, num_tokens, num_tokens).to(all_layer_matrices[0].device)
+     all_layer_matrices = [all_layer_matrices[i] + eye for i in range(len(all_layer_matrices))]
+     matrices_aug = [all_layer_matrices[i] / all_layer_matrices[i].sum(dim=-1, keepdim=True)
+                     for i in range(len(all_layer_matrices))]
+     joint_attention = matrices_aug[start_layer]
+     for i in range(start_layer+1, len(matrices_aug)):
+         joint_attention = matrices_aug[i].bmm(joint_attention)
+     return joint_attention
+
+ class LRP:
+     def __init__(self, model):
+         self.model = model
+         self.model.eval()
+
+     def generate_LRP(self, input, index=None, method="transformer_attribution", is_ablation=False, start_layer=0):
+         output = self.model(input)
+         kwargs = {"alpha": 1}
+         if index == None:
+             index = np.argmax(output.cpu().data.numpy(), axis=-1)
+
+         one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32)
+         one_hot[0, index] = 1
+         one_hot_vector = one_hot
+         one_hot = torch.from_numpy(one_hot).requires_grad_(True)
+         one_hot = torch.sum(one_hot.cuda() * output)
+
+         self.model.zero_grad()
+         one_hot.backward(retain_graph=True)
+
+         return self.model.relprop(torch.tensor(one_hot_vector).to(input.device), method=method, is_ablation=is_ablation,
+                                   start_layer=start_layer, **kwargs)
+
+
+
+ class Baselines:
+     def __init__(self, model):
+         self.model = model
+         self.model.eval()
+
+     def generate_cam_attn(self, input, index=None):
+         output = self.model(input.cuda(), register_hook=True)
+         if index == None:
+             index = np.argmax(output.cpu().data.numpy())
+
+         one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32)
+         one_hot[0][index] = 1
+         one_hot = torch.from_numpy(one_hot).requires_grad_(True)
+         one_hot = torch.sum(one_hot.cuda() * output)
+
+         self.model.zero_grad()
+         one_hot.backward(retain_graph=True)
+         #################### attn
+         grad = self.model.blocks[-1].attn.get_attn_gradients()
+         cam = self.model.blocks[-1].attn.get_attention_map()
+         cam = cam[0, :, 0, 1:].reshape(-1, 14, 14)
+         grad = grad[0, :, 0, 1:].reshape(-1, 14, 14)
+         grad = grad.mean(dim=[1, 2], keepdim=True)
+         cam = (cam * grad).mean(0).clamp(min=0)
+         cam = (cam - cam.min()) / (cam.max() - cam.min())
+
+         return cam
+         #################### attn
+
+     def generate_rollout(self, input, start_layer=0):
+         self.model(input)
+         blocks = self.model.blocks
+         all_layer_attentions = []
+         for blk in blocks:
+             attn_heads = blk.attn.get_attention_map()
+             avg_heads = (attn_heads.sum(dim=1) / attn_heads.shape[1]).detach()
+             all_layer_attentions.append(avg_heads)
+         rollout = compute_rollout_attention(all_layer_attentions, start_layer=start_layer)
+         return rollout[:,0, 1:]
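Taken together with the model files, typical usage of the two generator classes looks like the following sketch (CUDA assumed, since both classes move tensors to the GPU; `batch` stands for a normalized (B, 3, 224, 224) image tensor):

    from baselines.ViT.ViT_new import vit_base_patch16_224            # plain ViT, for Baselines
    from baselines.ViT.ViT_LRP import vit_base_patch16_224 as vit_LRP # relprop-enabled ViT, for LRP
    from baselines.ViT.ViT_explanation_generator import Baselines, LRP

    baselines = Baselines(vit_base_patch16_224(pretrained=True).cuda())
    rollout_map = baselines.generate_rollout(batch.cuda(), start_layer=1)        # (B, 196) patch scores

    lrp = LRP(vit_LRP(pretrained=True).cuda())
    attr_map = lrp.generate_LRP(batch.cuda(), method="transformer_attribution")  # (B, 196) patch scores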
ViT_DeiT/baselines/ViT/ViT_new.py ADDED
@@ -0,0 +1,238 @@
1
+ """ Vision Transformer (ViT) in PyTorch
2
+ Hacked together by / Copyright 2020 Ross Wightman
3
+ """
4
+ import torch
5
+ import torch.nn as nn
6
+ from functools import partial
7
+ from einops import rearrange
8
+
9
+ from baselines.ViT.helpers import load_pretrained
10
+ from baselines.ViT.weight_init import trunc_normal_
11
+ from baselines.ViT.layer_helpers import to_2tuple
12
+
13
+
14
+ def _cfg(url='', **kwargs):
15
+ return {
16
+ 'url': url,
17
+ 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
18
+ 'crop_pct': .9, 'interpolation': 'bicubic',
19
+ 'first_conv': 'patch_embed.proj', 'classifier': 'head',
20
+ **kwargs
21
+ }
22
+
23
+
24
+ default_cfgs = {
25
+ # patch models
26
+ 'vit_small_patch16_224': _cfg(
27
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth',
28
+ ),
29
+ 'vit_base_patch16_224': _cfg(
30
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth',
31
+ mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
32
+ ),
33
+ 'vit_large_patch16_224': _cfg(
34
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth',
35
+ mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
36
+ }
37
+
38
+ class Mlp(nn.Module):
39
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
40
+ super().__init__()
41
+ out_features = out_features or in_features
42
+ hidden_features = hidden_features or in_features
43
+ self.fc1 = nn.Linear(in_features, hidden_features)
44
+ self.act = act_layer()
45
+ self.fc2 = nn.Linear(hidden_features, out_features)
46
+ self.drop = nn.Dropout(drop)
47
+
48
+ def forward(self, x):
49
+ x = self.fc1(x)
50
+ x = self.act(x)
51
+ x = self.drop(x)
52
+ x = self.fc2(x)
53
+ x = self.drop(x)
54
+ return x
55
+
56
+
57
+ class Attention(nn.Module):
58
+ def __init__(self, dim, num_heads=8, qkv_bias=False,attn_drop=0., proj_drop=0.):
59
+ super().__init__()
60
+ self.num_heads = num_heads
61
+ head_dim = dim // num_heads
62
+ # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
63
+ self.scale = head_dim ** -0.5
64
+
65
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
66
+ self.attn_drop = nn.Dropout(attn_drop)
67
+ self.proj = nn.Linear(dim, dim)
68
+ self.proj_drop = nn.Dropout(proj_drop)
69
+
70
+ self.attn_gradients = None
71
+ self.attention_map = None
72
+
73
+ def save_attn_gradients(self, attn_gradients):
74
+ self.attn_gradients = attn_gradients
75
+
76
+ def get_attn_gradients(self):
77
+ return self.attn_gradients
78
+
79
+ def save_attention_map(self, attention_map):
80
+ self.attention_map = attention_map
81
+
82
+ def get_attention_map(self):
83
+ return self.attention_map
84
+
85
+ def forward(self, x, register_hook=False):
86
+ b, n, _, h = *x.shape, self.num_heads
87
+
88
+ # self.save_output(x)
89
+ # x.register_hook(self.save_output_grad)
90
+
91
+ qkv = self.qkv(x)
92
+ q, k, v = rearrange(qkv, 'b n (qkv h d) -> qkv b h n d', qkv = 3, h = h)
93
+
94
+ dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
95
+
96
+ attn = dots.softmax(dim=-1)
97
+ attn = self.attn_drop(attn)
98
+
99
+ out = torch.einsum('bhij,bhjd->bhid', attn, v)
100
+
101
+ self.save_attention_map(attn)
102
+ if register_hook:
103
+ attn.register_hook(self.save_attn_gradients)
104
+
105
+ out = rearrange(out, 'b h n d -> b n (h d)')
106
+ out = self.proj(out)
107
+ out = self.proj_drop(out)
108
+ return out
109
+
110
+
111
+ class Block(nn.Module):
112
+
113
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
114
+ super().__init__()
115
+ self.norm1 = norm_layer(dim)
116
+ self.attn = Attention(
117
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
118
+ self.norm2 = norm_layer(dim)
119
+ mlp_hidden_dim = int(dim * mlp_ratio)
120
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
121
+
122
+ def forward(self, x, register_hook=False):
123
+ x = x + self.attn(self.norm1(x), register_hook=register_hook)
124
+ x = x + self.mlp(self.norm2(x))
125
+ return x
126
+
127
+
128
+ class PatchEmbed(nn.Module):
129
+ """ Image to Patch Embedding
130
+ """
131
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
132
+ super().__init__()
133
+ img_size = to_2tuple(img_size)
134
+ patch_size = to_2tuple(patch_size)
135
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
136
+ self.img_size = img_size
137
+ self.patch_size = patch_size
138
+ self.num_patches = num_patches
139
+
140
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
141
+
142
+ def forward(self, x):
143
+ B, C, H, W = x.shape
144
+ # FIXME look at relaxing size constraints
145
+ assert H == self.img_size[0] and W == self.img_size[1], \
146
+ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
147
+ x = self.proj(x).flatten(2).transpose(1, 2)
148
+ return x
149
+
150
+ class VisionTransformer(nn.Module):
151
+ """ Vision Transformer
152
+ """
153
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
154
+ num_heads=12, mlp_ratio=4., qkv_bias=False, drop_rate=0., attn_drop_rate=0., norm_layer=nn.LayerNorm):
155
+ super().__init__()
156
+ self.num_classes = num_classes
157
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
158
+ self.patch_embed = PatchEmbed(
159
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
160
+ num_patches = self.patch_embed.num_patches
161
+
162
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
163
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
164
+ self.pos_drop = nn.Dropout(p=drop_rate)
165
+
166
+ self.blocks = nn.ModuleList([
167
+ Block(
168
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
169
+ drop=drop_rate, attn_drop=attn_drop_rate, norm_layer=norm_layer)
170
+ for i in range(depth)])
171
+ self.norm = norm_layer(embed_dim)
172
+
173
+ # Classifier head
174
+ self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
175
+
176
+ trunc_normal_(self.pos_embed, std=.02)
177
+ trunc_normal_(self.cls_token, std=.02)
178
+ self.apply(self._init_weights)
179
+
180
+ def _init_weights(self, m):
181
+ if isinstance(m, nn.Linear):
182
+ trunc_normal_(m.weight, std=.02)
183
+ if isinstance(m, nn.Linear) and m.bias is not None:
184
+ nn.init.constant_(m.bias, 0)
185
+ elif isinstance(m, nn.LayerNorm):
186
+ nn.init.constant_(m.bias, 0)
187
+ nn.init.constant_(m.weight, 1.0)
188
+
189
+ @torch.jit.ignore
190
+ def no_weight_decay(self):
191
+ return {'pos_embed', 'cls_token'}
192
+
193
+ def forward(self, x, register_hook=False):
194
+ B = x.shape[0]
195
+ x = self.patch_embed(x)
196
+
197
+ cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
198
+ x = torch.cat((cls_tokens, x), dim=1)
199
+ x = x + self.pos_embed
200
+ x = self.pos_drop(x)
201
+
202
+ for blk in self.blocks:
203
+ x = blk(x, register_hook=register_hook)
204
+
205
+ x = self.norm(x)
206
+ x = x[:, 0]
207
+ x = self.head(x)
208
+ return x
209
+
210
+
211
+ def _conv_filter(state_dict, patch_size=16):
212
+ """ convert patch embedding weight from manual patchify + linear proj to conv"""
213
+ out_dict = {}
214
+ for k, v in state_dict.items():
215
+ if 'patch_embed.proj.weight' in k:
216
+ v = v.reshape((v.shape[0], 3, patch_size, patch_size))
217
+ out_dict[k] = v
218
+ return out_dict
219
+
220
+
221
+ def vit_base_patch16_224(pretrained=False, **kwargs):
222
+ model = VisionTransformer(
223
+ patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
224
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
225
+ model.default_cfg = default_cfgs['vit_base_patch16_224']
226
+ if pretrained:
227
+ load_pretrained(
228
+ model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=_conv_filter)
229
+ return model
230
+
231
+ def vit_large_patch16_224(pretrained=False, **kwargs):
232
+ model = VisionTransformer(
233
+ patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
234
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
235
+ model.default_cfg = default_cfgs['vit_large_patch16_224']
236
+ if pretrained:
237
+ load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
238
+ return model
ViT_DeiT/baselines/ViT/ViT_orig_LRP.py ADDED
@@ -0,0 +1,425 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Vision Transformer (ViT) in PyTorch
2
+ Hacked together by / Copyright 2020 Ross Wightman
3
+ """
4
+ import torch
5
+ import torch.nn as nn
6
+ from einops import rearrange
7
+ from modules.layers_lrp import *
8
+
9
+ from baselines.ViT.helpers import load_pretrained
10
+ from baselines.ViT.weight_init import trunc_normal_
11
+ from baselines.ViT.layer_helpers import to_2tuple
12
+
13
+
14
+ def _cfg(url='', **kwargs):
15
+ return {
16
+ 'url': url,
17
+ 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
18
+ 'crop_pct': .9, 'interpolation': 'bicubic',
19
+ 'first_conv': 'patch_embed.proj', 'classifier': 'head',
20
+ **kwargs
21
+ }
22
+
23
+
24
+ default_cfgs = {
25
+ # patch models
26
+ 'vit_small_patch16_224': _cfg(
27
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth',
28
+ ),
29
+ 'vit_base_patch16_224': _cfg(
30
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth',
31
+ mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
32
+ ),
33
+ 'vit_large_patch16_224': _cfg(
34
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth',
35
+ mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
36
+ }
37
+
38
+ def compute_rollout_attention(all_layer_matrices, start_layer=0):
39
+ # adding residual consideration
40
+ num_tokens = all_layer_matrices[0].shape[1]
41
+ batch_size = all_layer_matrices[0].shape[0]
42
+ eye = torch.eye(num_tokens).expand(batch_size, num_tokens, num_tokens).to(all_layer_matrices[0].device)
43
+ all_layer_matrices = [all_layer_matrices[i] + eye for i in range(len(all_layer_matrices))]
44
+ # all_layer_matrices = [all_layer_matrices[i] / all_layer_matrices[i].sum(dim=-1, keepdim=True)
45
+ # for i in range(len(all_layer_matrices))]
46
+ joint_attention = all_layer_matrices[start_layer]
47
+ for i in range(start_layer+1, len(all_layer_matrices)):
48
+ joint_attention = all_layer_matrices[i].bmm(joint_attention)
49
+ return joint_attention
50
+
51
+ class Mlp(nn.Module):
52
+ def __init__(self, in_features, hidden_features=None, out_features=None, drop=0.):
53
+ super().__init__()
54
+ out_features = out_features or in_features
55
+ hidden_features = hidden_features or in_features
56
+ self.fc1 = Linear(in_features, hidden_features)
57
+ self.act = GELU()
58
+ self.fc2 = Linear(hidden_features, out_features)
59
+ self.drop = Dropout(drop)
60
+
61
+ def forward(self, x):
62
+ x = self.fc1(x)
63
+ x = self.act(x)
64
+ x = self.drop(x)
65
+ x = self.fc2(x)
66
+ x = self.drop(x)
67
+ return x
68
+
69
+ def relprop(self, cam, **kwargs):
70
+ cam = self.drop.relprop(cam, **kwargs)
71
+ cam = self.fc2.relprop(cam, **kwargs)
72
+ cam = self.act.relprop(cam, **kwargs)
73
+ cam = self.fc1.relprop(cam, **kwargs)
74
+ return cam
75
+
76
+
77
+ class Attention(nn.Module):
78
+ def __init__(self, dim, num_heads=8, qkv_bias=False,attn_drop=0., proj_drop=0.):
79
+ super().__init__()
80
+ self.num_heads = num_heads
81
+ head_dim = dim // num_heads
82
+ # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
83
+ self.scale = head_dim ** -0.5
84
+
85
+ # A = Q*K^T
86
+ self.matmul1 = einsum('bhid,bhjd->bhij')
87
+ # attn = A*V
88
+ self.matmul2 = einsum('bhij,bhjd->bhid')
89
+
90
+ self.qkv = Linear(dim, dim * 3, bias=qkv_bias)
91
+ self.attn_drop = Dropout(attn_drop)
92
+ self.proj = Linear(dim, dim)
93
+ self.proj_drop = Dropout(proj_drop)
94
+ self.softmax = Softmax(dim=-1)
95
+
96
+ self.attn_cam = None
97
+ self.attn = None
98
+ self.v = None
99
+ self.v_cam = None
100
+ self.attn_gradients = None
101
+
102
+ def get_attn(self):
103
+ return self.attn
104
+
105
+ def save_attn(self, attn):
106
+ self.attn = attn
107
+
108
+ def save_attn_cam(self, cam):
109
+ self.attn_cam = cam
110
+
111
+ def get_attn_cam(self):
112
+ return self.attn_cam
113
+
114
+ def get_v(self):
115
+ return self.v
116
+
117
+ def save_v(self, v):
118
+ self.v = v
119
+
120
+ def save_v_cam(self, cam):
121
+ self.v_cam = cam
122
+
123
+ def get_v_cam(self):
124
+ return self.v_cam
125
+
126
+ def save_attn_gradients(self, attn_gradients):
127
+ self.attn_gradients = attn_gradients
128
+
129
+ def get_attn_gradients(self):
130
+ return self.attn_gradients
131
+
132
+ def forward(self, x):
133
+ b, n, _, h = *x.shape, self.num_heads
134
+ qkv = self.qkv(x)
135
+ q, k, v = rearrange(qkv, 'b n (qkv h d) -> qkv b h n d', qkv=3, h=h)
136
+
137
+ self.save_v(v)
138
+
139
+ dots = self.matmul1([q, k]) * self.scale
140
+
141
+ attn = self.softmax(dots)
142
+ attn = self.attn_drop(attn)
143
+
144
+ self.save_attn(attn)
145
+ attn.register_hook(self.save_attn_gradients)
146
+
147
+ out = self.matmul2([attn, v])
148
+ out = rearrange(out, 'b h n d -> b n (h d)')
149
+
150
+ out = self.proj(out)
151
+ out = self.proj_drop(out)
152
+ return out
153
+
154
+ def relprop(self, cam, **kwargs):
155
+ cam = self.proj_drop.relprop(cam, **kwargs)
156
+ cam = self.proj.relprop(cam, **kwargs)
157
+ cam = rearrange(cam, 'b n (h d) -> b h n d', h=self.num_heads)
158
+
159
+ # attn = A*V
160
+ (cam1, cam_v)= self.matmul2.relprop(cam, **kwargs)
161
+ cam1 /= 2
162
+ cam_v /= 2
163
+
164
+ self.save_v_cam(cam_v)
165
+ self.save_attn_cam(cam1)
166
+
167
+ cam1 = self.attn_drop.relprop(cam1, **kwargs)
168
+ cam1 = self.softmax.relprop(cam1, **kwargs)
169
+
170
+ # A = Q*K^T
171
+ (cam_q, cam_k) = self.matmul1.relprop(cam1, **kwargs)
172
+ cam_q /= 2
173
+ cam_k /= 2
174
+
175
+ cam_qkv = rearrange([cam_q, cam_k, cam_v], 'qkv b h n d -> b n (qkv h d)', qkv=3, h=self.num_heads)
176
+
177
+ return self.qkv.relprop(cam_qkv, **kwargs)
178
+
179
+
180
+ class Block(nn.Module):
181
+
182
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.):
183
+ super().__init__()
184
+ self.norm1 = LayerNorm(dim, eps=1e-6)
185
+ self.attn = Attention(
186
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
187
+ self.norm2 = LayerNorm(dim, eps=1e-6)
188
+ mlp_hidden_dim = int(dim * mlp_ratio)
189
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, drop=drop)
190
+
191
+ self.add1 = Add()
192
+ self.add2 = Add()
193
+ self.clone1 = Clone()
194
+ self.clone2 = Clone()
195
+
196
+ def forward(self, x):
197
+ x1, x2 = self.clone1(x, 2)
198
+ x = self.add1([x1, self.attn(self.norm1(x2))])
199
+ x1, x2 = self.clone2(x, 2)
200
+ x = self.add2([x1, self.mlp(self.norm2(x2))])
201
+ return x
202
+
203
+ def relprop(self, cam, **kwargs):
204
+ (cam1, cam2) = self.add2.relprop(cam, **kwargs)
205
+ cam2 = self.mlp.relprop(cam2, **kwargs)
206
+ cam2 = self.norm2.relprop(cam2, **kwargs)
207
+ cam = self.clone2.relprop((cam1, cam2), **kwargs)
208
+
209
+ (cam1, cam2) = self.add1.relprop(cam, **kwargs)
210
+ cam2 = self.attn.relprop(cam2, **kwargs)
211
+ cam2 = self.norm1.relprop(cam2, **kwargs)
212
+ cam = self.clone1.relprop((cam1, cam2), **kwargs)
213
+ return cam
214
+
215
+
216
+ class PatchEmbed(nn.Module):
217
+ """ Image to Patch Embedding
218
+ """
219
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
220
+ super().__init__()
221
+ img_size = to_2tuple(img_size)
222
+ patch_size = to_2tuple(patch_size)
223
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
224
+ self.img_size = img_size
225
+ self.patch_size = patch_size
226
+ self.num_patches = num_patches
227
+
228
+ self.proj = Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
229
+
230
+ def forward(self, x):
231
+ B, C, H, W = x.shape
232
+ # FIXME look at relaxing size constraints
233
+ assert H == self.img_size[0] and W == self.img_size[1], \
234
+ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
235
+ x = self.proj(x).flatten(2).transpose(1, 2)
236
+ return x
237
+
238
+ def relprop(self, cam, **kwargs):
239
+ cam = cam.transpose(1,2)
240
+ cam = cam.reshape(cam.shape[0], cam.shape[1],
241
+ (self.img_size[0] // self.patch_size[0]), (self.img_size[1] // self.patch_size[1]))
242
+ return self.proj.relprop(cam, **kwargs)
243
+
244
+
245
+ class VisionTransformer(nn.Module):
246
+ """ Vision Transformer with support for patch or hybrid CNN input stage
247
+ """
248
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
249
+ num_heads=12, mlp_ratio=4., qkv_bias=False, mlp_head=False, drop_rate=0., attn_drop_rate=0.):
250
+ super().__init__()
251
+ self.num_classes = num_classes
252
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
253
+ self.patch_embed = PatchEmbed(
254
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
255
+ num_patches = self.patch_embed.num_patches
256
+
257
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
258
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
259
+
260
+ self.blocks = nn.ModuleList([
261
+ Block(
262
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
263
+ drop=drop_rate, attn_drop=attn_drop_rate)
264
+ for i in range(depth)])
265
+
266
+ self.norm = LayerNorm(embed_dim)
267
+ if mlp_head:
268
+ # paper diagram suggests 'MLP head', but results in 4M extra parameters vs paper
269
+ self.head = Mlp(embed_dim, int(embed_dim * mlp_ratio), num_classes)
270
+ else:
271
+ # with a single Linear layer as head, the param count within rounding of paper
272
+ self.head = Linear(embed_dim, num_classes)
273
+
274
+ # FIXME not quite sure what the proper weight init is supposed to be,
275
+ # normal / trunc normal w/ std == .02 similar to other Bert like transformers
276
+ trunc_normal_(self.pos_embed, std=.02) # embeddings same as weights?
277
+ trunc_normal_(self.cls_token, std=.02)
278
+ self.apply(self._init_weights)
279
+
280
+ self.pool = IndexSelect()
281
+ self.add = Add()
282
+
283
+ self.inp_grad = None
284
+
285
+ def save_inp_grad(self,grad):
286
+ self.inp_grad = grad
287
+
288
+ def get_inp_grad(self):
289
+ return self.inp_grad
290
+
291
+
292
+ def _init_weights(self, m):
293
+ if isinstance(m, nn.Linear):
294
+ trunc_normal_(m.weight, std=.02)
295
+ if isinstance(m, nn.Linear) and m.bias is not None:
296
+ nn.init.constant_(m.bias, 0)
297
+ elif isinstance(m, nn.LayerNorm):
298
+ nn.init.constant_(m.bias, 0)
299
+ nn.init.constant_(m.weight, 1.0)
300
+
301
+ @property
302
+ def no_weight_decay(self):
303
+ return {'pos_embed', 'cls_token'}
304
+
305
+ def forward(self, x):
306
+ B = x.shape[0]
307
+ x = self.patch_embed(x)
308
+
309
+ cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
310
+ x = torch.cat((cls_tokens, x), dim=1)
311
+ x = self.add([x, self.pos_embed])
312
+
313
+ x.register_hook(self.save_inp_grad)
314
+
315
+ for blk in self.blocks:
316
+ x = blk(x)
317
+
318
+ x = self.norm(x)
319
+ x = self.pool(x, dim=1, indices=torch.tensor(0, device=x.device))
320
+ x = x.squeeze(1)
321
+ x = self.head(x)
322
+ return x
323
+
324
+ def relprop(self, cam=None,method="grad", is_ablation=False, start_layer=0, **kwargs):
325
+ # print(kwargs)
326
+ # print("conservation 1", cam.sum())
327
+ cam = self.head.relprop(cam, **kwargs)
328
+ cam = cam.unsqueeze(1)
329
+ cam = self.pool.relprop(cam, **kwargs)
330
+ cam = self.norm.relprop(cam, **kwargs)
331
+ for blk in reversed(self.blocks):
332
+ cam = blk.relprop(cam, **kwargs)
333
+
334
+ # print("conservation 2", cam.sum())
335
+ # print("min", cam.min())
336
+
337
+ if method == "full":
338
+ (cam, _) = self.add.relprop(cam, **kwargs)
339
+ cam = cam[:, 1:]
340
+ cam = self.patch_embed.relprop(cam, **kwargs)
341
+ # sum on channels
342
+ cam = cam.sum(dim=1)
343
+ return cam
344
+
345
+ elif method == "rollout":
346
+ # cam rollout
347
+ attn_cams = []
348
+ for blk in self.blocks:
349
+ attn_heads = blk.attn.get_attn_cam().clamp(min=0)
350
+ avg_heads = (attn_heads.sum(dim=1) / attn_heads.shape[1]).detach()
351
+ attn_cams.append(avg_heads)
352
+ cam = compute_rollout_attention(attn_cams, start_layer=start_layer)
353
+ cam = cam[:, 0, 1:]
354
+ return cam
355
+
356
+ elif method == "grad":
357
+ cams = []
358
+ for blk in self.blocks:
359
+ grad = blk.attn.get_attn_gradients()
360
+ cam = blk.attn.get_attn_cam()
361
+ cam = cam[0].reshape(-1, cam.shape[-1], cam.shape[-1])
362
+ grad = grad[0].reshape(-1, grad.shape[-1], grad.shape[-1])
363
+ cam = grad * cam
364
+ cam = cam.clamp(min=0).mean(dim=0)
365
+ cams.append(cam.unsqueeze(0))
366
+ rollout = compute_rollout_attention(cams, start_layer=start_layer)
367
+ cam = rollout[:, 0, 1:]
368
+ return cam
369
+
370
+ elif method == "last_layer":
371
+ cam = self.blocks[-1].attn.get_attn_cam()
372
+ cam = cam[0].reshape(-1, cam.shape[-1], cam.shape[-1])
373
+ if is_ablation:
374
+ grad = self.blocks[-1].attn.get_attn_gradients()
375
+ grad = grad[0].reshape(-1, grad.shape[-1], grad.shape[-1])
376
+ cam = grad * cam
377
+ cam = cam.clamp(min=0).mean(dim=0)
378
+ cam = cam[0, 1:]
379
+ return cam
380
+
381
+ elif method == "last_layer_attn":
382
+ cam = self.blocks[-1].attn.get_attn()
383
+ cam = cam[0].reshape(-1, cam.shape[-1], cam.shape[-1])
384
+ cam = cam.clamp(min=0).mean(dim=0)
385
+ cam = cam[0, 1:]
386
+ return cam
387
+
388
+ elif method == "second_layer":
389
+ cam = self.blocks[1].attn.get_attn_cam()
390
+ cam = cam[0].reshape(-1, cam.shape[-1], cam.shape[-1])
391
+ if is_ablation:
392
+ grad = self.blocks[1].attn.get_attn_gradients()
393
+ grad = grad[0].reshape(-1, grad.shape[-1], grad.shape[-1])
394
+ cam = grad * cam
395
+ cam = cam.clamp(min=0).mean(dim=0)
396
+ cam = cam[0, 1:]
397
+ return cam
398
+
399
+
400
+ def _conv_filter(state_dict, patch_size=16):
401
+ """ convert patch embedding weight from manual patchify + linear proj to conv"""
402
+ out_dict = {}
403
+ for k, v in state_dict.items():
404
+ if 'patch_embed.proj.weight' in k:
405
+ v = v.reshape((v.shape[0], 3, patch_size, patch_size))
406
+ out_dict[k] = v
407
+ return out_dict
408
+
409
+
410
+ def vit_base_patch16_224(pretrained=False, **kwargs):
411
+ model = VisionTransformer(
412
+ patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, **kwargs)
413
+ model.default_cfg = default_cfgs['vit_base_patch16_224']
414
+ if pretrained:
415
+ load_pretrained(
416
+ model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=_conv_filter)
417
+ return model
418
+
419
+ def vit_large_patch16_224(pretrained=False, **kwargs):
420
+ model = VisionTransformer(
421
+ patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, **kwargs)
422
+ model.default_cfg = default_cfgs['vit_large_patch16_224']
423
+ if pretrained:
424
+ load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
425
+ return model
ViT_DeiT/baselines/ViT/__pycache__/ViT_LRP.cpython-38.pyc ADDED
Binary file (14 kB).
ViT_DeiT/baselines/ViT/__pycache__/ViT_explanation_generator.cpython-38.pyc ADDED
Binary file (3.44 kB).
ViT_DeiT/baselines/ViT/__pycache__/helpers.cpython-38.pyc ADDED
Binary file (7.25 kB).
ViT_DeiT/baselines/ViT/__pycache__/layer_helpers.cpython-38.pyc ADDED
Binary file (766 Bytes).
ViT_DeiT/baselines/ViT/__pycache__/weight_init.cpython-38.pyc ADDED
Binary file (1.92 kB).
ViT_DeiT/baselines/ViT/generate_visualizations.py ADDED
@@ -0,0 +1,208 @@
1
+ import os
2
+ from tqdm import tqdm
3
+ import h5py
4
+
5
+ import argparse
6
+
7
+ # Import saliency methods and models
8
+ from misc_functions import *
9
+
10
+ from ViT_explanation_generator import Baselines, LRP
11
+ from ViT_new import vit_base_patch16_224
12
+ from ViT_LRP import vit_base_patch16_224 as vit_LRP
13
+ from ViT_orig_LRP import vit_base_patch16_224 as vit_orig_LRP
14
+
15
+ from torchvision.datasets import ImageNet
16
+
17
+
18
+ def normalize(tensor,
19
+ mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]):
20
+ dtype = tensor.dtype
21
+ mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device)
22
+ std = torch.as_tensor(std, dtype=dtype, device=tensor.device)
23
+ tensor.sub_(mean[None, :, None, None]).div_(std[None, :, None, None])
24
+ return tensor
25
+
26
+
27
+ def compute_saliency_and_save(args):
28
+ first = True
29
+ with h5py.File(os.path.join(args.method_dir, 'results.hdf5'), 'a') as f:
30
+ data_cam = f.create_dataset('vis',
31
+ (1, 1, 224, 224),
32
+ maxshape=(None, 1, 224, 224),
33
+ dtype=np.float32,
34
+ compression="gzip")
35
+ data_image = f.create_dataset('image',
36
+ (1, 3, 224, 224),
37
+ maxshape=(None, 3, 224, 224),
38
+ dtype=np.float32,
39
+ compression="gzip")
40
+ data_target = f.create_dataset('target',
41
+ (1,),
42
+ maxshape=(None,),
43
+ dtype=np.int32,
44
+ compression="gzip")
45
+ for batch_idx, (data, target) in enumerate(tqdm(sample_loader)):
46
+ if first:
47
+ first = False
48
+ data_cam.resize(data_cam.shape[0] + data.shape[0] - 1, axis=0)
49
+ data_image.resize(data_image.shape[0] + data.shape[0] - 1, axis=0)
50
+ data_target.resize(data_target.shape[0] + data.shape[0] - 1, axis=0)
51
+ else:
52
+ data_cam.resize(data_cam.shape[0] + data.shape[0], axis=0)
53
+ data_image.resize(data_image.shape[0] + data.shape[0], axis=0)
54
+ data_target.resize(data_target.shape[0] + data.shape[0], axis=0)
55
+
56
+ # Add data
57
+ data_image[-data.shape[0]:] = data.data.cpu().numpy()
58
+ data_target[-data.shape[0]:] = target.data.cpu().numpy()
59
+
60
+ target = target.to(device)
61
+
62
+ data = normalize(data)
63
+ data = data.to(device)
64
+ data.requires_grad_()
65
+
66
+ index = None
67
+ if args.vis_class == 'target':
68
+ index = target
69
+
70
+ if args.method == 'rollout':
71
+ Res = baselines.generate_rollout(data, start_layer=1).reshape(data.shape[0], 1, 14, 14)
72
+ # Res = Res - Res.mean()
73
+
74
+ elif args.method == 'lrp':
75
+ Res = lrp.generate_LRP(data, start_layer=1, index=index).reshape(data.shape[0], 1, 14, 14)
76
+ # Res = Res - Res.mean()
77
+
78
+ elif args.method == 'transformer_attribution':
79
+ Res = lrp.generate_LRP(data, start_layer=1, method="grad", index=index).reshape(data.shape[0], 1, 14, 14)
80
+ # Res = Res - Res.mean()
81
+
82
+ elif args.method == 'full_lrp':
83
+ Res = orig_lrp.generate_LRP(data, method="full", index=index).reshape(data.shape[0], 1, 224, 224)
84
+ # Res = Res - Res.mean()
85
+
86
+ elif args.method == 'lrp_last_layer':
87
+ Res = orig_lrp.generate_LRP(data, method="last_layer", is_ablation=args.is_ablation, index=index) \
88
+ .reshape(data.shape[0], 1, 14, 14)
89
+ # Res = Res - Res.mean()
90
+
91
+ elif args.method == 'attn_last_layer':
92
+ Res = lrp.generate_LRP(data, method="last_layer_attn", is_ablation=args.is_ablation) \
93
+ .reshape(data.shape[0], 1, 14, 14)
94
+
95
+ elif args.method == 'attn_gradcam':
96
+ Res = baselines.generate_cam_attn(data, index=index).reshape(data.shape[0], 1, 14, 14)
97
+
98
+ if args.method != 'full_lrp' and args.method != 'input_grads':
99
+ Res = torch.nn.functional.interpolate(Res, scale_factor=16, mode='bilinear').cuda()
100
+ Res = (Res - Res.min()) / (Res.max() - Res.min())
101
+
102
+ data_cam[-data.shape[0]:] = Res.data.cpu().numpy()
103
+
104
+
105
+ if __name__ == "__main__":
106
+ parser = argparse.ArgumentParser(description='Train a segmentation')
107
+ parser.add_argument('--batch-size', type=int,
108
+ default=1,
109
+ help='')
110
+ parser.add_argument('--method', type=str,
111
+ default='grad_rollout',
112
+ choices=['rollout', 'lrp', 'transformer_attribution', 'full_lrp', 'lrp_last_layer',
113
+ 'attn_last_layer', 'attn_gradcam'],
114
+ help='')
115
+ parser.add_argument('--lmd', type=float,
116
+ default=10,
117
+ help='')
118
+ parser.add_argument('--vis-class', type=str,
119
+ default='top',
120
+ choices=['top', 'target', 'index'],
121
+ help='')
122
+ parser.add_argument('--class-id', type=int,
123
+ default=0,
124
+ help='')
125
+ parser.add_argument('--cls-agn', action='store_true',
126
+ default=False,
127
+ help='')
128
+ parser.add_argument('--no-ia', action='store_true',
129
+ default=False,
130
+ help='')
131
+ parser.add_argument('--no-fx', action='store_true',
132
+ default=False,
133
+ help='')
134
+ parser.add_argument('--no-fgx', action='store_true',
135
+ default=False,
136
+ help='')
137
+ parser.add_argument('--no-m', action='store_true',
138
+ default=False,
139
+ help='')
140
+ parser.add_argument('--no-reg', action='store_true',
141
+ default=False,
142
+ help='')
143
+ parser.add_argument('--is-ablation', type=bool,
144
+ default=False,
145
+ help='')
146
+ parser.add_argument('--imagenet-validation-path', type=str,
147
+ required=True,
148
+ help='')
149
+ args = parser.parse_args()
150
+
151
+ # PATH variables
152
+ PATH = os.path.dirname(os.path.abspath(__file__)) + '/'
153
+ os.makedirs(os.path.join(PATH, 'visualizations'), exist_ok=True)
154
+
155
+ try:
156
+ os.remove(os.path.join(PATH, 'visualizations/{}/{}/results.hdf5'.format(args.method,
157
+ args.vis_class)))
158
+ except OSError:
159
+ pass
160
+
161
+
162
+ os.makedirs(os.path.join(PATH, 'visualizations/{}'.format(args.method)), exist_ok=True)
163
+ if args.vis_class == 'index':
164
+ os.makedirs(os.path.join(PATH, 'visualizations/{}/{}_{}'.format(args.method,
165
+ args.vis_class,
166
+ args.class_id)), exist_ok=True)
167
+ args.method_dir = os.path.join(PATH, 'visualizations/{}/{}_{}'.format(args.method,
168
+ args.vis_class,
169
+ args.class_id))
170
+ else:
171
+ ablation_fold = 'ablation' if args.is_ablation else 'not_ablation'
172
+ os.makedirs(os.path.join(PATH, 'visualizations/{}/{}/{}'.format(args.method,
173
+ args.vis_class, ablation_fold)), exist_ok=True)
174
+ args.method_dir = os.path.join(PATH, 'visualizations/{}/{}/{}'.format(args.method,
175
+ args.vis_class, ablation_fold))
176
+
177
+ cuda = torch.cuda.is_available()
178
+ device = torch.device("cuda" if cuda else "cpu")
179
+
180
+ # Model
181
+ model = vit_base_patch16_224(pretrained=True).cuda()
182
+ baselines = Baselines(model)
183
+
184
+ # LRP
185
+ model_LRP = vit_LRP(pretrained=True).cuda()
186
+ model_LRP.eval()
187
+ lrp = LRP(model_LRP)
188
+
189
+ # orig LRP
190
+ model_orig_LRP = vit_orig_LRP(pretrained=True).cuda()
191
+ model_orig_LRP.eval()
192
+ orig_lrp = LRP(model_orig_LRP)
193
+
194
+ # Dataset loader for sample images
195
+ transform = transforms.Compose([
196
+ transforms.Resize((224, 224)),
197
+ transforms.ToTensor(),
198
+ ])
199
+
200
+ imagenet_ds = ImageNet(args.imagenet_validation_path, split='val', download=False, transform=transform)
201
+ sample_loader = torch.utils.data.DataLoader(
202
+ imagenet_ds,
203
+ batch_size=args.batch_size,
204
+ shuffle=False,
205
+ num_workers=4
206
+ )
207
+
208
+ compute_saliency_and_save(args)
ViT_DeiT/baselines/ViT/helpers.py ADDED
@@ -0,0 +1,295 @@
1
+ """ Model creation / weight loading / state_dict helpers
2
+
3
+ Hacked together by / Copyright 2020 Ross Wightman
4
+ """
5
+ import logging
6
+ import os
7
+ import math
8
+ from collections import OrderedDict
9
+ from copy import deepcopy
10
+ from typing import Callable
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.utils.model_zoo as model_zoo
15
+
16
+ _logger = logging.getLogger(__name__)
17
+
18
+
19
+ def load_state_dict(checkpoint_path, use_ema=False):
20
+ if checkpoint_path and os.path.isfile(checkpoint_path):
21
+ checkpoint = torch.load(checkpoint_path, map_location='cpu')
22
+ state_dict_key = 'state_dict'
23
+ if isinstance(checkpoint, dict):
24
+ if use_ema and 'state_dict_ema' in checkpoint:
25
+ state_dict_key = 'state_dict_ema'
26
+ if state_dict_key and state_dict_key in checkpoint:
27
+ new_state_dict = OrderedDict()
28
+ for k, v in checkpoint[state_dict_key].items():
29
+ # strip `module.` prefix
30
+ name = k[7:] if k.startswith('module') else k
31
+ new_state_dict[name] = v
32
+ state_dict = new_state_dict
33
+ else:
34
+ state_dict = checkpoint
35
+ _logger.info("Loaded {} from checkpoint '{}'".format(state_dict_key, checkpoint_path))
36
+ return state_dict
37
+ else:
38
+ _logger.error("No checkpoint found at '{}'".format(checkpoint_path))
39
+ raise FileNotFoundError()
40
+
41
+
42
+ def load_checkpoint(model, checkpoint_path, use_ema=False, strict=True):
43
+ state_dict = load_state_dict(checkpoint_path, use_ema)
44
+ model.load_state_dict(state_dict, strict=strict)
45
+
46
+
47
+ def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None, log_info=True):
48
+ resume_epoch = None
49
+ if os.path.isfile(checkpoint_path):
50
+ checkpoint = torch.load(checkpoint_path, map_location='cpu')
51
+ if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
52
+ if log_info:
53
+ _logger.info('Restoring model state from checkpoint...')
54
+ new_state_dict = OrderedDict()
55
+ for k, v in checkpoint['state_dict'].items():
56
+ name = k[7:] if k.startswith('module') else k
57
+ new_state_dict[name] = v
58
+ model.load_state_dict(new_state_dict)
59
+
60
+ if optimizer is not None and 'optimizer' in checkpoint:
61
+ if log_info:
62
+ _logger.info('Restoring optimizer state from checkpoint...')
63
+ optimizer.load_state_dict(checkpoint['optimizer'])
64
+
65
+ if loss_scaler is not None and loss_scaler.state_dict_key in checkpoint:
66
+ if log_info:
67
+ _logger.info('Restoring AMP loss scaler state from checkpoint...')
68
+ loss_scaler.load_state_dict(checkpoint[loss_scaler.state_dict_key])
69
+
70
+ if 'epoch' in checkpoint:
71
+ resume_epoch = checkpoint['epoch']
72
+ if 'version' in checkpoint and checkpoint['version'] > 1:
73
+ resume_epoch += 1 # start at the next epoch, old checkpoints incremented before save
74
+
75
+ if log_info:
76
+ _logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch']))
77
+ else:
78
+ model.load_state_dict(checkpoint)
79
+ if log_info:
80
+ _logger.info("Loaded checkpoint '{}'".format(checkpoint_path))
81
+ return resume_epoch
82
+ else:
83
+ _logger.error("No checkpoint found at '{}'".format(checkpoint_path))
84
+ raise FileNotFoundError()
85
+
86
+
87
+ def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True):
88
+ if cfg is None:
89
+ cfg = getattr(model, 'default_cfg')
90
+ if cfg is None or 'url' not in cfg or not cfg['url']:
91
+ _logger.warning("Pretrained model URL is invalid, using random initialization.")
92
+ return
93
+
94
+ state_dict = model_zoo.load_url(cfg['url'], progress=False, map_location='cpu')
95
+
96
+ if filter_fn is not None:
97
+ state_dict = filter_fn(state_dict)
98
+
99
+ if in_chans == 1:
100
+ conv1_name = cfg['first_conv']
101
+ _logger.info('Converting first conv (%s) pretrained weights from 3 to 1 channel' % conv1_name)
102
+ conv1_weight = state_dict[conv1_name + '.weight']
103
+ # Some weights are in torch.half, ensure it's float for sum on CPU
104
+ conv1_type = conv1_weight.dtype
105
+ conv1_weight = conv1_weight.float()
106
+ O, I, J, K = conv1_weight.shape
107
+ if I > 3:
108
+ assert conv1_weight.shape[1] % 3 == 0
109
+ # For models with space2depth stems
110
+ conv1_weight = conv1_weight.reshape(O, I // 3, 3, J, K)
111
+ conv1_weight = conv1_weight.sum(dim=2, keepdim=False)
112
+ else:
113
+ conv1_weight = conv1_weight.sum(dim=1, keepdim=True)
114
+ conv1_weight = conv1_weight.to(conv1_type)
115
+ state_dict[conv1_name + '.weight'] = conv1_weight
116
+ elif in_chans != 3:
117
+ conv1_name = cfg['first_conv']
118
+ conv1_weight = state_dict[conv1_name + '.weight']
119
+ conv1_type = conv1_weight.dtype
120
+ conv1_weight = conv1_weight.float()
121
+ O, I, J, K = conv1_weight.shape
122
+ if I != 3:
123
+ _logger.warning('Deleting first conv (%s) from pretrained weights.' % conv1_name)
124
+ del state_dict[conv1_name + '.weight']
125
+ strict = False
126
+ else:
127
+ # NOTE this strategy should be better than random init, but there could be other combinations of
128
+ # the original RGB input layer weights that'd work better for specific cases.
129
+ _logger.info('Repeating first conv (%s) weights in channel dim.' % conv1_name)
130
+ repeat = int(math.ceil(in_chans / 3))
131
+ conv1_weight = conv1_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :]
132
+ conv1_weight *= (3 / float(in_chans))
133
+ conv1_weight = conv1_weight.to(conv1_type)
134
+ state_dict[conv1_name + '.weight'] = conv1_weight
135
+
136
+ classifier_name = cfg['classifier']
137
+ if num_classes == 1000 and cfg['num_classes'] == 1001:
138
+ # special case for imagenet trained models with extra background class in pretrained weights
139
+ classifier_weight = state_dict[classifier_name + '.weight']
140
+ state_dict[classifier_name + '.weight'] = classifier_weight[1:]
141
+ classifier_bias = state_dict[classifier_name + '.bias']
142
+ state_dict[classifier_name + '.bias'] = classifier_bias[1:]
143
+ elif num_classes != cfg['num_classes']:
144
+ # completely discard fully connected for all other differences between pretrained and created model
145
+ del state_dict[classifier_name + '.weight']
146
+ del state_dict[classifier_name + '.bias']
147
+ strict = False
148
+
149
+ model.load_state_dict(state_dict, strict=strict)
150
+
151
+
152
+ def extract_layer(model, layer):
153
+ layer = layer.split('.')
154
+ module = model
155
+ if hasattr(model, 'module') and layer[0] != 'module':
156
+ module = model.module
157
+ if not hasattr(model, 'module') and layer[0] == 'module':
158
+ layer = layer[1:]
159
+ for l in layer:
160
+ if hasattr(module, l):
161
+ if not l.isdigit():
162
+ module = getattr(module, l)
163
+ else:
164
+ module = module[int(l)]
165
+ else:
166
+ return module
167
+ return module
168
+
169
+
170
+ def set_layer(model, layer, val):
171
+ layer = layer.split('.')
172
+ module = model
173
+ if hasattr(model, 'module') and layer[0] != 'module':
174
+ module = model.module
175
+ lst_index = 0
176
+ module2 = module
177
+ for l in layer:
178
+ if hasattr(module2, l):
179
+ if not l.isdigit():
180
+ module2 = getattr(module2, l)
181
+ else:
182
+ module2 = module2[int(l)]
183
+ lst_index += 1
184
+ lst_index -= 1
185
+ for l in layer[:lst_index]:
186
+ if not l.isdigit():
187
+ module = getattr(module, l)
188
+ else:
189
+ module = module[int(l)]
190
+ l = layer[lst_index]
191
+ setattr(module, l, val)
192
+
193
+
194
+ def adapt_model_from_string(parent_module, model_string):
195
+ separator = '***'
196
+ state_dict = {}
197
+ lst_shape = model_string.split(separator)
198
+ for k in lst_shape:
199
+ k = k.split(':')
200
+ key = k[0]
201
+ shape = k[1][1:-1].split(',')
202
+ if shape[0] != '':
203
+ state_dict[key] = [int(i) for i in shape]
204
+
205
+ new_module = deepcopy(parent_module)
206
+ for n, m in parent_module.named_modules():
207
+ old_module = extract_layer(parent_module, n)
208
+ if isinstance(old_module, nn.Conv2d) or isinstance(old_module, Conv2dSame):
209
+ if isinstance(old_module, Conv2dSame):
210
+ conv = Conv2dSame
211
+ else:
212
+ conv = nn.Conv2d
213
+ s = state_dict[n + '.weight']
214
+ in_channels = s[1]
215
+ out_channels = s[0]
216
+ g = 1
217
+ if old_module.groups > 1:
218
+ in_channels = out_channels
219
+ g = in_channels
220
+ new_conv = conv(
221
+ in_channels=in_channels, out_channels=out_channels, kernel_size=old_module.kernel_size,
222
+ bias=old_module.bias is not None, padding=old_module.padding, dilation=old_module.dilation,
223
+ groups=g, stride=old_module.stride)
224
+ set_layer(new_module, n, new_conv)
225
+ if isinstance(old_module, nn.BatchNorm2d):
226
+ new_bn = nn.BatchNorm2d(
227
+ num_features=state_dict[n + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum,
228
+ affine=old_module.affine, track_running_stats=True)
229
+ set_layer(new_module, n, new_bn)
230
+ if isinstance(old_module, nn.Linear):
231
+ # FIXME extra checks to ensure this is actually the FC classifier layer and not a diff Linear layer?
232
+ num_features = state_dict[n + '.weight'][1]
233
+ new_fc = nn.Linear(
234
+ in_features=num_features, out_features=old_module.out_features, bias=old_module.bias is not None)
235
+ set_layer(new_module, n, new_fc)
236
+ if hasattr(new_module, 'num_features'):
237
+ new_module.num_features = num_features
238
+ new_module.eval()
239
+ parent_module.eval()
240
+
241
+ return new_module
242
+
243
+
244
+ def adapt_model_from_file(parent_module, model_variant):
245
+ adapt_file = os.path.join(os.path.dirname(__file__), 'pruned', model_variant + '.txt')
246
+ with open(adapt_file, 'r') as f:
247
+ return adapt_model_from_string(parent_module, f.read().strip())
248
+
249
+
250
+ def build_model_with_cfg(
251
+ model_cls: Callable,
252
+ variant: str,
253
+ pretrained: bool,
254
+ default_cfg: dict,
255
+ model_cfg: dict = None,
256
+ feature_cfg: dict = None,
257
+ pretrained_strict: bool = True,
258
+ pretrained_filter_fn: Callable = None,
259
+ **kwargs):
260
+ pruned = kwargs.pop('pruned', False)
261
+ features = False
262
+ feature_cfg = feature_cfg or {}
263
+
264
+ if kwargs.pop('features_only', False):
265
+ features = True
266
+ feature_cfg.setdefault('out_indices', (0, 1, 2, 3, 4))
267
+ if 'out_indices' in kwargs:
268
+ feature_cfg['out_indices'] = kwargs.pop('out_indices')
269
+
270
+ model = model_cls(**kwargs) if model_cfg is None else model_cls(cfg=model_cfg, **kwargs)
271
+ model.default_cfg = deepcopy(default_cfg)
272
+
273
+ if pruned:
274
+ model = adapt_model_from_file(model, variant)
275
+
276
+ if pretrained:
277
+ load_pretrained(
278
+ model,
279
+ num_classes=kwargs.get('num_classes', 0),
280
+ in_chans=kwargs.get('in_chans', 3),
281
+ filter_fn=pretrained_filter_fn, strict=pretrained_strict)
282
+
283
+ if features:
284
+ feature_cls = FeatureListNet
285
+ if 'feature_cls' in feature_cfg:
286
+ feature_cls = feature_cfg.pop('feature_cls')
287
+ if isinstance(feature_cls, str):
288
+ feature_cls = feature_cls.lower()
289
+ if 'hook' in feature_cls:
290
+ feature_cls = FeatureHookNet
291
+ else:
292
+ assert False, f'Unknown feature class {feature_cls}'
293
+ model = feature_cls(model, **feature_cfg)
294
+
295
+ return model
ViT_DeiT/baselines/ViT/imagenet_seg_eval.py ADDED
@@ -0,0 +1,334 @@
1
+ import numpy as np
2
+ import torch
3
+ import torchvision.transforms as transforms
4
+ from torch.utils.data import DataLoader
5
+ from numpy import *
6
+ import argparse
7
+ from PIL import Image
8
+ import imageio
9
+ import os
10
+ from tqdm import tqdm
11
+ from utils.metrices import *
12
+
13
+ from utils import render
14
+ from utils.saver import Saver
15
+ from utils.iou import IoU
16
+
17
+ from data.Imagenet import Imagenet_Segmentation
18
+
19
+ from ViT_explanation_generator import Baselines, LRP
20
+ from ViT_new import vit_base_patch16_224
21
+ from ViT_LRP import vit_base_patch16_224 as vit_LRP
22
+ from ViT_orig_LRP import vit_base_patch16_224 as vit_orig_LRP
23
+
24
+ from sklearn.metrics import precision_recall_curve
25
+ import matplotlib.pyplot as plt
26
+
27
+ import torch.nn.functional as F
28
+
29
+ plt.switch_backend('agg')
30
+
31
+
32
+ # hyperparameters
33
+ num_workers = 0
34
+ batch_size = 1
35
+
36
+ cls = ['airplane',
37
+ 'bicycle',
38
+ 'bird',
39
+ 'boat',
40
+ 'bottle',
41
+ 'bus',
42
+ 'car',
43
+ 'cat',
44
+ 'chair',
45
+ 'cow',
46
+ 'dining table',
47
+ 'dog',
48
+ 'horse',
49
+ 'motobike',
50
+ 'person',
51
+ 'potted plant',
52
+ 'sheep',
53
+ 'sofa',
54
+ 'train',
55
+ 'tv'
56
+ ]
57
+
58
+ # Args
59
+ parser = argparse.ArgumentParser(description='Training multi-class classifier')
60
+ parser.add_argument('--arc', type=str, default='vgg', metavar='N',
61
+ help='Model architecture')
62
+ parser.add_argument('--train_dataset', type=str, default='imagenet', metavar='N',
63
+ help='Testing Dataset')
64
+ parser.add_argument('--method', type=str,
65
+ default='grad_rollout',
66
+ choices=[ 'rollout', 'lrp','transformer_attribution', 'full_lrp', 'lrp_last_layer',
67
+ 'attn_last_layer', 'attn_gradcam'],
68
+ help='')
69
+ parser.add_argument('--thr', type=float, default=0.,
70
+ help='threshold')
71
+ parser.add_argument('--K', type=int, default=1,
72
+ help='new - top K results')
73
+ parser.add_argument('--save-img', action='store_true',
74
+ default=False,
75
+ help='')
76
+ parser.add_argument('--no-ia', action='store_true',
77
+ default=False,
78
+ help='')
79
+ parser.add_argument('--no-fx', action='store_true',
80
+ default=False,
81
+ help='')
82
+ parser.add_argument('--no-fgx', action='store_true',
83
+ default=False,
84
+ help='')
85
+ parser.add_argument('--no-m', action='store_true',
86
+ default=False,
87
+ help='')
88
+ parser.add_argument('--no-reg', action='store_true',
89
+ default=False,
90
+ help='')
91
+ parser.add_argument('--is-ablation', type=bool,
92
+ default=False,
93
+ help='')
94
+ parser.add_argument('--imagenet-seg-path', type=str, required=True)
95
+ args = parser.parse_args()
96
+
97
+ args.checkname = args.method + '_' + args.arc
98
+
99
+ alpha = 2
100
+
101
+ cuda = torch.cuda.is_available()
102
+ device = torch.device("cuda" if cuda else "cpu")
103
+
104
+ # Define Saver
105
+ saver = Saver(args)
106
+ saver.results_dir = os.path.join(saver.experiment_dir, 'results')
107
+ if not os.path.exists(saver.results_dir):
108
+ os.makedirs(saver.results_dir)
109
+ if not os.path.exists(os.path.join(saver.results_dir, 'input')):
110
+ os.makedirs(os.path.join(saver.results_dir, 'input'))
111
+ if not os.path.exists(os.path.join(saver.results_dir, 'explain')):
112
+ os.makedirs(os.path.join(saver.results_dir, 'explain'))
113
+
114
+ args.exp_img_path = os.path.join(saver.results_dir, 'explain/img')
115
+ if not os.path.exists(args.exp_img_path):
116
+ os.makedirs(args.exp_img_path)
117
+ args.exp_np_path = os.path.join(saver.results_dir, 'explain/np')
118
+ if not os.path.exists(args.exp_np_path):
119
+ os.makedirs(args.exp_np_path)
120
+
121
+ # Data
122
+ normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
123
+ test_img_trans = transforms.Compose([
124
+ transforms.Resize((224, 224)),
125
+ transforms.ToTensor(),
126
+ normalize,
127
+ ])
128
+ test_lbl_trans = transforms.Compose([
129
+ transforms.Resize((224, 224), Image.NEAREST),
130
+ ])
131
+
132
+ ds = Imagenet_Segmentation(args.imagenet_seg_path,
133
+ transform=test_img_trans, target_transform=test_lbl_trans)
134
+ dl = DataLoader(ds, batch_size=batch_size, shuffle=False, num_workers=1, drop_last=False)
135
+
136
+ # Model
137
+ model = vit_base_patch16_224(pretrained=True).cuda()
138
+ baselines = Baselines(model)
139
+
140
+ # LRP
141
+ model_LRP = vit_LRP(pretrained=True).cuda()
142
+ model_LRP.eval()
143
+ lrp = LRP(model_LRP)
144
+
145
+ # orig LRP
146
+ model_orig_LRP = vit_orig_LRP(pretrained=True).cuda()
147
+ model_orig_LRP.eval()
148
+ orig_lrp = LRP(model_orig_LRP)
149
+
150
+ metric = IoU(2, ignore_index=-1)
151
+
152
+ iterator = tqdm(dl)
153
+
154
+ model.eval()
155
+
156
+
157
+ def compute_pred(output):
158
+ pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability
159
+ # pred[0, 0] = 282
160
+ # print('Pred cls : ' + str(pred))
161
+ T = pred.squeeze().cpu().numpy()
162
+ T = np.expand_dims(T, 0)
163
+ T = (T[:, np.newaxis] == np.arange(1000)) * 1.0
164
+ T = torch.from_numpy(T).type(torch.FloatTensor)
165
+ Tt = T.cuda()
166
+
167
+ return Tt
168
+
169
+
170
+ def eval_batch(image, labels, evaluator, index):
171
+ evaluator.zero_grad()
172
+ # Save input image
173
+ if args.save_img:
174
+ img = image[0].permute(1, 2, 0).data.cpu().numpy()
175
+ img = 255 * (img - img.min()) / (img.max() - img.min())
176
+ img = img.astype('uint8')
177
+ Image.fromarray(img, 'RGB').save(os.path.join(saver.results_dir, 'input/{}_input.png'.format(index)))
178
+ Image.fromarray((labels.repeat(3, 1, 1).permute(1, 2, 0).data.cpu().numpy() * 255).astype('uint8'), 'RGB').save(
179
+ os.path.join(saver.results_dir, 'input/{}_mask.png'.format(index)))
180
+
181
+ image.requires_grad = True
182
+
183
+ image = image.requires_grad_()
184
+ predictions = evaluator(image)
185
+
186
+ # segmentation test for the rollout baseline
187
+ if args.method == 'rollout':
188
+ Res = baselines.generate_rollout(image.cuda(), start_layer=1).reshape(batch_size, 1, 14, 14)
189
+
190
+ # segmentation test for the LRP baseline (this is full LRP, not partial)
191
+ elif args.method == 'full_lrp':
192
+ Res = orig_lrp.generate_LRP(image.cuda(), method="full").reshape(batch_size, 1, 224, 224)
193
+
194
+ # segmentation test for our method
195
+ elif args.method == 'transformer_attribution':
196
+ Res = lrp.generate_LRP(image.cuda(), start_layer=1, method="transformer_attribution").reshape(batch_size, 1, 14, 14)
197
+
198
+ # segmentation test for the partial LRP baseline (last attn layer)
199
+ elif args.method == 'lrp_last_layer':
200
+ Res = orig_lrp.generate_LRP(image.cuda(), method="last_layer", is_ablation=args.is_ablation)\
201
+ .reshape(batch_size, 1, 14, 14)
202
+
203
+ # segmentation test for the raw attention baseline (last attn layer)
204
+ elif args.method == 'attn_last_layer':
205
+ Res = orig_lrp.generate_LRP(image.cuda(), method="last_layer_attn", is_ablation=args.is_ablation)\
206
+ .reshape(batch_size, 1, 14, 14)
207
+
208
+ # segmentation test for the GradCam baseline (last attn layer)
209
+ elif args.method == 'attn_gradcam':
210
+ Res = baselines.generate_cam_attn(image.cuda()).reshape(batch_size, 1, 14, 14)
211
+
212
+ if args.method != 'full_lrp':
213
+ # interpolate to full image size (224,224)
214
+ Res = torch.nn.functional.interpolate(Res, scale_factor=16, mode='bilinear').cuda()
215
+
216
+ # threshold between FG and BG is the mean
217
+ Res = (Res - Res.min()) / (Res.max() - Res.min())
218
+
219
+ ret = Res.mean()
220
+
221
+ Res_1 = Res.gt(ret).type(Res.type())
222
+ Res_0 = Res.le(ret).type(Res.type())
223
+
224
+ Res_1_AP = Res
225
+ Res_0_AP = 1-Res
226
+
227
+ Res_1[Res_1 != Res_1] = 0
228
+ Res_0[Res_0 != Res_0] = 0
229
+ Res_1_AP[Res_1_AP != Res_1_AP] = 0
230
+ Res_0_AP[Res_0_AP != Res_0_AP] = 0
231
+
232
+
233
+ # TEST
234
+ pred = Res.clamp(min=args.thr) / Res.max()
235
+ pred = pred.view(-1).data.cpu().numpy()
236
+ target = labels.view(-1).data.cpu().numpy()
237
+ # print("target", target.shape)
238
+
239
+ output = torch.cat((Res_0, Res_1), 1)
240
+ output_AP = torch.cat((Res_0_AP, Res_1_AP), 1)
241
+
242
+ if args.save_img:
243
+ # Save predicted mask
244
+ mask = F.interpolate(Res_1, [64, 64], mode='bilinear')
245
+ mask = mask[0].squeeze().data.cpu().numpy()
246
+ # mask = Res_1[0].squeeze().data.cpu().numpy()
247
+ mask = 255 * mask
248
+ mask = mask.astype('uint8')
249
+ imageio.imsave(os.path.join(args.exp_img_path, 'mask_' + str(index) + '.jpg'), mask)
250
+
251
+ relevance = F.interpolate(Res, [64, 64], mode='bilinear')
252
+ relevance = relevance[0].permute(1, 2, 0).data.cpu().numpy()
253
+ # relevance = Res[0].permute(1, 2, 0).data.cpu().numpy()
254
+ hm = np.sum(relevance, axis=-1)
255
+ maps = (render.hm_to_rgb(hm, scaling=3, sigma=1, cmap='seismic') * 255).astype(np.uint8)
256
+ imageio.imsave(os.path.join(args.exp_img_path, 'heatmap_' + str(index) + '.jpg'), maps)
257
+
258
+ # Evaluate Segmentation
259
+ batch_inter, batch_union, batch_correct, batch_label = 0, 0, 0, 0
260
+ batch_ap, batch_f1 = 0, 0
261
+
262
+ # Segmentation results
263
+ correct, labeled = batch_pix_accuracy(output[0].data.cpu(), labels[0])
264
+ inter, union = batch_intersection_union(output[0].data.cpu(), labels[0], 2)
265
+ batch_correct += correct
266
+ batch_label += labeled
267
+ batch_inter += inter
268
+ batch_union += union
269
+ # print("output", output.shape)
270
+ # print("ap labels", labels.shape)
271
+ # ap = np.nan_to_num(get_ap_scores(output, labels))
272
+ ap = np.nan_to_num(get_ap_scores(output_AP, labels))
273
+ f1 = np.nan_to_num(get_f1_scores(output[0, 1].data.cpu(), labels[0]))
274
+ batch_ap += ap
275
+ batch_f1 += f1
276
+
277
+ return batch_correct, batch_label, batch_inter, batch_union, batch_ap, batch_f1, pred, target
278
+
279
+
280
+ total_inter, total_union, total_correct, total_label = np.int64(0), np.int64(0), np.int64(0), np.int64(0)
281
+ total_ap, total_f1 = [], []
282
+
283
+ predictions, targets = [], []
284
+ for batch_idx, (image, labels) in enumerate(iterator):
285
+
286
+ if args.method == "blur":
287
+ images = (image[0].cuda(), image[1].cuda())
288
+ else:
289
+ images = image.cuda()
290
+ labels = labels.cuda()
291
+ # print("image", image.shape)
292
+ # print("lables", labels.shape)
293
+
294
+ correct, labeled, inter, union, ap, f1, pred, target = eval_batch(images, labels, model, batch_idx)
295
+
296
+ predictions.append(pred)
297
+ targets.append(target)
298
+
299
+ total_correct += correct.astype('int64')
300
+ total_label += labeled.astype('int64')
301
+ total_inter += inter.astype('int64')
302
+ total_union += union.astype('int64')
303
+ total_ap += [ap]
304
+ total_f1 += [f1]
305
+ pixAcc = np.float64(1.0) * total_correct / (np.spacing(1, dtype=np.float64) + total_label)
306
+ IoU = np.float64(1.0) * total_inter / (np.spacing(1, dtype=np.float64) + total_union)
307
+ mIoU = IoU.mean()
308
+ mAp = np.mean(total_ap)
309
+ mF1 = np.mean(total_f1)
310
+ iterator.set_description('pixAcc: %.4f, mIoU: %.4f, mAP: %.4f, mF1: %.4f' % (pixAcc, mIoU, mAp, mF1))
311
+
312
+ predictions = np.concatenate(predictions)
313
+ targets = np.concatenate(targets)
314
+ pr, rc, thr = precision_recall_curve(targets, predictions)
315
+ np.save(os.path.join(saver.experiment_dir, 'precision.npy'), pr)
316
+ np.save(os.path.join(saver.experiment_dir, 'recall.npy'), rc)
317
+
318
+ plt.figure()
319
+ plt.plot(rc, pr)
320
+ plt.savefig(os.path.join(saver.experiment_dir, 'PR_curve_{}.png'.format(args.method)))
321
+
322
+ txtfile = os.path.join(saver.experiment_dir, 'result_mIoU_%.4f.txt' % mIoU)
323
+ # txtfile = 'result_mIoU_%.4f.txt' % mIoU
324
+ fh = open(txtfile, 'w')
325
+ print("Mean IoU over %d classes: %.4f\n" % (2, mIoU))
326
+ print("Pixel-wise Accuracy: %2.2f%%\n" % (pixAcc * 100))
327
+ print("Mean AP over %d classes: %.4f\n" % (2, mAp))
328
+ print("Mean F1 over %d classes: %.4f\n" % (2, mF1))
329
+
330
+ fh.write("Mean IoU over %d classes: %.4f\n" % (2, mIoU))
331
+ fh.write("Pixel-wise Accuracy: %2.2f%%\n" % (pixAcc * 100))
332
+ fh.write("Mean AP over %d classes: %.4f\n" % (2, mAp))
333
+ fh.write("Mean F1 over %d classes: %.4f\n" % (2, mF1))
334
+ fh.close()
ViT_DeiT/baselines/ViT/layer_helpers.py ADDED
@@ -0,0 +1,21 @@
1
+ """ Layer/Module Helpers
2
+ Hacked together by / Copyright 2020 Ross Wightman
3
+ """
4
+ from itertools import repeat
5
+ import collections.abc
6
+
7
+
8
+ # From PyTorch internals
9
+ def _ntuple(n):
10
+ def parse(x):
11
+ if isinstance(x, collections.abc.Iterable):
12
+ return x
13
+ return tuple(repeat(x, n))
14
+ return parse
15
+
16
+
17
+ to_1tuple = _ntuple(1)
18
+ to_2tuple = _ntuple(2)
19
+ to_3tuple = _ntuple(3)
20
+ to_4tuple = _ntuple(4)
21
+ to_ntuple = _ntuple
ViT_DeiT/baselines/ViT/misc_functions.py ADDED
@@ -0,0 +1,68 @@
1
+ #
2
+ # Copyright (c) 2019 Idiap Research Institute, http://www.idiap.ch/
3
+ # Written by Suraj Srinivas <suraj.srinivas@idiap.ch>
4
+ #
5
+
6
+ """ Misc helper functions """
7
+
8
+ import cv2
9
+ import numpy as np
10
+ import subprocess
11
+
12
+ import torch
13
+ import torchvision.transforms as transforms
14
+
15
+
16
+ class NormalizeInverse(transforms.Normalize):
17
+ # Undo normalization on images
18
+
19
+ def __init__(self, mean, std):
20
+ mean = torch.as_tensor(mean)
21
+ std = torch.as_tensor(std)
22
+ std_inv = 1 / (std + 1e-7)
23
+ mean_inv = -mean * std_inv
24
+ super(NormalizeInverse, self).__init__(mean=mean_inv, std=std_inv)
25
+
26
+ def __call__(self, tensor):
27
+ return super(NormalizeInverse, self).__call__(tensor.clone())
28
+
29
+
30
+ def create_folder(folder_name):
31
+ try:
32
+ subprocess.call(['mkdir', '-p', folder_name])
33
+ except OSError:
34
+ None
35
+
36
+
37
+ def save_saliency_map(image, saliency_map, filename):
38
+ """
39
+ Save saliency map on image.
40
+
41
+ Args:
42
+ image: Tensor of size (3,H,W)
43
+ saliency_map: Tensor of size (1,H,W)
44
+ filename: string with complete path and file extension
45
+
46
+ """
47
+
48
+ image = image.data.cpu().numpy()
49
+ saliency_map = saliency_map.data.cpu().numpy()
50
+
51
+ saliency_map = saliency_map - saliency_map.min()
52
+ saliency_map = saliency_map / saliency_map.max()
53
+ saliency_map = saliency_map.clip(0, 1)
54
+
55
+ saliency_map = np.uint8(saliency_map * 255).transpose(1, 2, 0)
56
+ saliency_map = cv2.resize(saliency_map, (224, 224))
57
+
58
+ image = np.uint8(image * 255).transpose(1, 2, 0)
59
+ image = cv2.resize(image, (224, 224))
60
+
61
+ # Apply JET colormap
62
+ color_heatmap = cv2.applyColorMap(saliency_map, cv2.COLORMAP_JET)
63
+
64
+ # Combine image with heatmap
65
+ img_with_heatmap = np.float32(color_heatmap) + np.float32(image)
66
+ img_with_heatmap = img_with_heatmap / np.max(img_with_heatmap)
67
+
68
+ cv2.imwrite(filename, np.uint8(255 * img_with_heatmap))
ViT_DeiT/baselines/ViT/pertubation_eval_from_hdf5.py ADDED
@@ -0,0 +1,233 @@
1
+
2
+ import torch
3
+ import os
4
+ from tqdm import tqdm
5
+ import numpy as np
6
+ import argparse
7
+
8
+ # Import saliency methods and models
9
+ from ViT_explanation_generator import Baselines
10
+ from ViT_new import vit_base_patch16_224
11
+ # from models.vgg import vgg19
12
+ import glob
13
+
14
+ from dataset.expl_hdf5 import ImagenetResults
15
+
16
+
17
+ def normalize(tensor,
18
+ mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]):
19
+ dtype = tensor.dtype
20
+ mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device)
21
+ std = torch.as_tensor(std, dtype=dtype, device=tensor.device)
22
+ tensor.sub_(mean[None, :, None, None]).div_(std[None, :, None, None])
23
+ return tensor
24
+
25
+
26
+ def eval(args):
27
+ num_samples = 0
28
+ num_correct_model = np.zeros((len(imagenet_ds,)))
29
+ dissimilarity_model = np.zeros((len(imagenet_ds,)))
30
+ model_index = 0
31
+
32
+ if args.scale == 'per':
33
+ base_size = 224 * 224
34
+ perturbation_steps = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
35
+ elif args.scale == '100':
36
+ base_size = 100
37
+ perturbation_steps = [5, 10, 15, 20, 25, 30, 35, 40, 45]
38
+ else:
39
+ raise Exception('scale not valid')
40
+
41
+ num_correct_pertub = np.zeros((9, len(imagenet_ds)))
42
+ dissimilarity_pertub = np.zeros((9, len(imagenet_ds)))
43
+ logit_diff_pertub = np.zeros((9, len(imagenet_ds)))
44
+ prob_diff_pertub = np.zeros((9, len(imagenet_ds)))
45
+ perturb_index = 0
46
+
47
+ for batch_idx, (data, vis, target) in enumerate(tqdm(sample_loader)):
48
+ # Update the number of samples
49
+ num_samples += len(data)
50
+
51
+ data = data.to(device)
52
+ vis = vis.to(device)
53
+ target = target.to(device)
54
+ norm_data = normalize(data.clone())
55
+
56
+ # Compute model accuracy
57
+ pred = model(norm_data)
58
+ pred_probabilities = torch.softmax(pred, dim=1)
59
+ pred_org_logit = pred.data.max(1, keepdim=True)[0].squeeze(1)
60
+ pred_org_prob = pred_probabilities.data.max(1, keepdim=True)[0].squeeze(1)
61
+ pred_class = pred.data.max(1, keepdim=True)[1].squeeze(1)
62
+ tgt_pred = (target == pred_class).type(target.type()).data.cpu().numpy()
63
+ num_correct_model[model_index:model_index+len(tgt_pred)] = tgt_pred
64
+
65
+ probs = torch.softmax(pred, dim=1)
66
+ target_probs = torch.gather(probs, 1, target[:, None])[:, 0]
67
+ second_probs = probs.data.topk(2, dim=1)[0][:, 1]
68
+ temp = torch.log(target_probs / second_probs).data.cpu().numpy()
69
+ dissimilarity_model[model_index:model_index+len(temp)] = temp
70
+
71
+ if args.wrong:
72
+ wid = np.argwhere(tgt_pred == 0).flatten()
73
+ if len(wid) == 0:
74
+ continue
75
+ wid = torch.from_numpy(wid).to(vis.device)
76
+ vis = vis.index_select(0, wid)
77
+ data = data.index_select(0, wid)
78
+ target = target.index_select(0, wid)
79
+
80
+ # Save original shape
81
+ org_shape = data.shape
82
+
83
+ if args.neg:
84
+ vis = -vis
85
+
86
+ vis = vis.reshape(org_shape[0], -1)
87
+
88
+ for i in range(len(perturbation_steps)):
89
+ _data = data.clone()
90
+
91
+ _, idx = torch.topk(vis, int(base_size * perturbation_steps[i]), dim=-1)
92
+ idx = idx.unsqueeze(1).repeat(1, org_shape[1], 1)
93
+ _data = _data.reshape(org_shape[0], org_shape[1], -1)
94
+ _data = _data.scatter_(-1, idx, 0)
95
+ _data = _data.reshape(*org_shape)
96
+
97
+ _norm_data = normalize(_data)
98
+
99
+ out = model(_norm_data)
100
+
101
+ pred_probabilities = torch.softmax(out, dim=1)
102
+ pred_prob = pred_probabilities.data.max(1, keepdim=True)[0].squeeze(1)
103
+ diff = (pred_prob - pred_org_prob).data.cpu().numpy()
104
+ prob_diff_pertub[i, perturb_index:perturb_index+len(diff)] = diff
105
+
106
+ pred_logit = out.data.max(1, keepdim=True)[0].squeeze(1)
107
+ diff = (pred_logit - pred_org_logit).data.cpu().numpy()
108
+ logit_diff_pertub[i, perturb_index:perturb_index+len(diff)] = diff
109
+
110
+ target_class = out.data.max(1, keepdim=True)[1].squeeze(1)
111
+ temp = (target == target_class).type(target.type()).data.cpu().numpy()
112
+ num_correct_pertub[i, perturb_index:perturb_index+len(temp)] = temp
113
+
114
+ probs_pertub = torch.softmax(out, dim=1)
115
+ target_probs = torch.gather(probs_pertub, 1, target[:, None])[:, 0]
116
+ second_probs = probs_pertub.data.topk(2, dim=1)[0][:, 1]
117
+ temp = torch.log(target_probs / second_probs).data.cpu().numpy()
118
+ dissimilarity_pertub[i, perturb_index:perturb_index+len(temp)] = temp
119
+
120
+ model_index += len(target)
121
+ perturb_index += len(target)
122
+
123
+ np.save(os.path.join(args.experiment_dir, 'model_hits.npy'), num_correct_model)
124
+ np.save(os.path.join(args.experiment_dir, 'model_dissimilarities.npy'), dissimilarity_model)
125
+ np.save(os.path.join(args.experiment_dir, 'perturbations_hits.npy'), num_correct_pertub[:, :perturb_index])
126
+ np.save(os.path.join(args.experiment_dir, 'perturbations_dissimilarities.npy'), dissimilarity_pertub[:, :perturb_index])
127
+ np.save(os.path.join(args.experiment_dir, 'perturbations_logit_diff.npy'), logit_diff_pertub[:, :perturb_index])
128
+ np.save(os.path.join(args.experiment_dir, 'perturbations_prob_diff.npy'), prob_diff_pertub[:, :perturb_index])
129
+
130
+ print(np.mean(num_correct_model), np.std(num_correct_model))
131
+ print(np.mean(dissimilarity_model), np.std(dissimilarity_model))
132
+ print(perturbation_steps)
133
+ print(np.mean(num_correct_pertub, axis=1), np.std(num_correct_pertub, axis=1))
134
+ print(np.mean(dissimilarity_pertub, axis=1), np.std(dissimilarity_pertub, axis=1))
135
+
136
+
137
+ if __name__ == "__main__":
138
+ parser = argparse.ArgumentParser(description='Train a segmentation')
139
+ parser.add_argument('--batch-size', type=int,
140
+ default=16,
141
+ help='')
142
+ parser.add_argument('--neg', type=bool,
143
+ default=True,
144
+ help='')
145
+ parser.add_argument('--value', action='store_true',
146
+ default=False,
147
+ help='')
148
+ parser.add_argument('--scale', type=str,
149
+ default='per',
150
+ choices=['per', '100'],
151
+ help='')
152
+ parser.add_argument('--method', type=str,
153
+ default='grad_rollout',
154
+ choices=['rollout', 'lrp', 'transformer_attribution', 'full_lrp', 'v_gradcam', 'lrp_last_layer',
155
+ 'lrp_second_layer', 'gradcam',
156
+ 'attn_last_layer', 'attn_gradcam', 'input_grads'],
157
+ help='')
158
+ parser.add_argument('--vis-class', type=str,
159
+ default='top',
160
+ choices=['top', 'target', 'index'],
161
+ help='')
162
+ parser.add_argument('--wrong', action='store_true',
163
+ default=False,
164
+ help='')
165
+ parser.add_argument('--class-id', type=int,
166
+ default=0,
167
+ help='')
168
+ parser.add_argument('--is-ablation', type=bool,
169
+ default=False,
170
+ help='')
171
+ args = parser.parse_args()
172
+
173
+ torch.multiprocessing.set_start_method('spawn')
174
+
175
+ # PATH variables
176
+ PATH = os.path.dirname(os.path.abspath(__file__)) + '/'
177
+ dataset = PATH + 'dataset/'
178
+ os.makedirs(os.path.join(PATH, 'experiments'), exist_ok=True)
179
+ os.makedirs(os.path.join(PATH, 'experiments/perturbations'), exist_ok=True)
180
+
181
+ exp_name = args.method
182
+ exp_name += '_neg' if args.neg else '_pos'
183
+ print(exp_name)
184
+
185
+ if args.vis_class == 'index':
186
+ args.runs_dir = os.path.join(PATH, 'experiments/perturbations/{}/{}_{}'.format(exp_name,
187
+ args.vis_class,
188
+ args.class_id))
189
+ else:
190
+ ablation_fold = 'ablation' if args.is_ablation else 'not_ablation'
191
+ args.runs_dir = os.path.join(PATH, 'experiments/perturbations/{}/{}/{}'.format(exp_name,
192
+ args.vis_class, ablation_fold))
193
+ # args.runs_dir = os.path.join(PATH, 'experiments/perturbations/{}/{}'.format(exp_name,
194
+ # args.vis_class))
195
+
196
+ if args.wrong:
197
+ args.runs_dir += '_wrong'
198
+
199
+ experiments = sorted(glob.glob(os.path.join(args.runs_dir, 'experiment_*')))
200
+ experiment_id = int(experiments[-1].split('_')[-1]) + 1 if experiments else 0
201
+ args.experiment_dir = os.path.join(args.runs_dir, 'experiment_{}'.format(str(experiment_id)))
202
+ os.makedirs(args.experiment_dir, exist_ok=True)
203
+
204
+ cuda = torch.cuda.is_available()
205
+ device = torch.device("cuda" if cuda else "cpu")
206
+
207
+ if args.vis_class == 'index':
208
+ vis_method_dir = os.path.join(PATH,'visualizations/{}/{}_{}'.format(args.method,
209
+ args.vis_class,
210
+ args.class_id))
211
+ else:
212
+ ablation_fold = 'ablation' if args.is_ablation else 'not_ablation'
213
+ vis_method_dir = os.path.join(PATH,'visualizations/{}/{}/{}'.format(args.method,
214
+ args.vis_class, ablation_fold))
215
+ # vis_method_dir = os.path.join(PATH, 'visualizations/{}/{}'.format(args.method,
216
+ # args.vis_class))
217
+
218
+ # imagenet_ds = ImagenetResults('visualizations/{}'.format(args.method))
219
+ imagenet_ds = ImagenetResults(vis_method_dir)
220
+
221
+ # Model
222
+ model = vit_base_patch16_224(pretrained=True).cuda()
223
+ model.eval()
224
+
225
+ save_path = PATH + 'results/'
226
+
227
+ sample_loader = torch.utils.data.DataLoader(
228
+ imagenet_ds,
229
+ batch_size=args.batch_size,
230
+ num_workers=2,
231
+ shuffle=False)
232
+
233
+ eval(args)
ViT_DeiT/baselines/ViT/weight_init.py ADDED
@@ -0,0 +1,60 @@
1
+ import torch
2
+ import math
3
+ import warnings
4
+
5
+
6
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
7
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
8
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
9
+ def norm_cdf(x):
10
+ # Computes standard normal cumulative distribution function
11
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
12
+
13
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
14
+ warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
15
+ "The distribution of values may be incorrect.",
16
+ stacklevel=2)
17
+
18
+ with torch.no_grad():
19
+ # Values are generated by using a truncated uniform distribution and
20
+ # then using the inverse CDF for the normal distribution.
21
+ # Get upper and lower cdf values
22
+ l = norm_cdf((a - mean) / std)
23
+ u = norm_cdf((b - mean) / std)
24
+
25
+ # Uniformly fill tensor with values from [l, u], then translate to
26
+ # [2l-1, 2u-1].
27
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
28
+
29
+ # Use inverse cdf transform for normal distribution to get truncated
30
+ # standard normal
31
+ tensor.erfinv_()
32
+
33
+ # Transform to proper mean, std
34
+ tensor.mul_(std * math.sqrt(2.))
35
+ tensor.add_(mean)
36
+
37
+ # Clamp to ensure it's in the proper range
38
+ tensor.clamp_(min=a, max=b)
39
+ return tensor
40
+
41
+
42
+ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
43
+ # type: (Tensor, float, float, float, float) -> Tensor
44
+ r"""Fills the input Tensor with values drawn from a truncated
45
+ normal distribution. The values are effectively drawn from the
46
+ normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
47
+ with values outside :math:`[a, b]` redrawn until they are within
48
+ the bounds. The method used for generating the random values works
49
+ best when :math:`a \leq \text{mean} \leq b`.
50
+ Args:
51
+ tensor: an n-dimensional `torch.Tensor`
52
+ mean: the mean of the normal distribution
53
+ std: the standard deviation of the normal distribution
54
+ a: the minimum cutoff value
55
+ b: the maximum cutoff value
56
+ Examples:
57
+ >>> w = torch.empty(3, 5)
58
+ >>> nn.init.trunc_normal_(w)
59
+ """
60
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
ViT_DeiT/data/VOC.py ADDED
@@ -0,0 +1,372 @@
1
+ import os
2
+ import tarfile
3
+ import torch
4
+ import torch.utils.data as data
5
+ import numpy as np
6
+ import h5py
7
+
8
+ from PIL import Image
9
+ from scipy import io
10
+ from torchvision.datasets.utils import download_url
11
+
12
+ DATASET_YEAR_DICT = {
13
+ '2012': {
14
+ 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar',
15
+ 'filename': 'VOCtrainval_11-May-2012.tar',
16
+ 'md5': '6cd6e144f989b92b3379bac3b3de84fd',
17
+ 'base_dir': 'VOCdevkit/VOC2012'
18
+ },
19
+ '2011': {
20
+ 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2011/VOCtrainval_25-May-2011.tar',
21
+ 'filename': 'VOCtrainval_25-May-2011.tar',
22
+ 'md5': '6c3384ef61512963050cb5d687e5bf1e',
23
+ 'base_dir': 'TrainVal/VOCdevkit/VOC2011'
24
+ },
25
+ '2010': {
26
+ 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2010/VOCtrainval_03-May-2010.tar',
27
+ 'filename': 'VOCtrainval_03-May-2010.tar',
28
+ 'md5': 'da459979d0c395079b5c75ee67908abb',
29
+ 'base_dir': 'VOCdevkit/VOC2010'
30
+ },
31
+ '2009': {
32
+ 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2009/VOCtrainval_11-May-2009.tar',
33
+ 'filename': 'VOCtrainval_11-May-2009.tar',
34
+ 'md5': '59065e4b188729180974ef6572f6a212',
35
+ 'base_dir': 'VOCdevkit/VOC2009'
36
+ },
37
+ '2008': {
38
+ 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2008/VOCtrainval_14-Jul-2008.tar',
39
+ 'filename': 'VOCtrainval_11-May-2012.tar',
40
+ 'md5': '2629fa636546599198acfcfbfcf1904a',
41
+ 'base_dir': 'VOCdevkit/VOC2008'
42
+ },
43
+ '2007': {
44
+ 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar',
45
+ 'filename': 'VOCtrainval_06-Nov-2007.tar',
46
+ 'md5': 'c52e279531787c972589f7e41ab4ae64',
47
+ 'base_dir': 'VOCdevkit/VOC2007'
48
+ }
49
+ }
50
+
51
+
52
+ class VOCSegmentation(data.Dataset):
53
+ """`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Segmentation Dataset.
54
+
55
+ Args:
56
+ root (string): Root directory of the VOC Dataset.
57
+ year (string, optional): The dataset year, supports years 2007 to 2012.
58
+ image_set (string, optional): Select the image_set to use, ``train``, ``trainval`` or ``val``
59
+ download (bool, optional): If true, downloads the dataset from the internet and
60
+ puts it in root directory. If dataset is already downloaded, it is not
61
+ downloaded again.
62
+ transform (callable, optional): A function/transform that takes in an PIL image
63
+ and returns a transformed version. E.g, ``transforms.RandomCrop``
64
+ target_transform (callable, optional): A function/transform that takes in the
65
+ target and transforms it.
66
+ """
67
+
68
+ CLASSES = 20
69
+ CLASSES_NAMES = [
70
+ 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
71
+ 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
72
+ 'motorbike', 'person', 'potted-plant', 'sheep', 'sofa', 'train',
73
+ 'tvmonitor', 'ambigious'
74
+ ]
75
+
76
+ def __init__(self,
77
+ root,
78
+ year='2012',
79
+ image_set='train',
80
+ download=False,
81
+ transform=None,
82
+ target_transform=None):
83
+ self.root = os.path.expanduser(root)
84
+ self.year = year
85
+ self.url = DATASET_YEAR_DICT[year]['url']
86
+ self.filename = DATASET_YEAR_DICT[year]['filename']
87
+ self.md5 = DATASET_YEAR_DICT[year]['md5']
88
+ self.transform = transform
89
+ self.target_transform = target_transform
90
+ self.image_set = image_set
91
+ base_dir = DATASET_YEAR_DICT[year]['base_dir']
92
+ voc_root = os.path.join(self.root, base_dir)
93
+ image_dir = os.path.join(voc_root, 'JPEGImages')
94
+ mask_dir = os.path.join(voc_root, 'SegmentationClass')
95
+
96
+ if download:
97
+ download_extract(self.url, self.root, self.filename, self.md5)
98
+
99
+ if not os.path.isdir(voc_root):
100
+ raise RuntimeError('Dataset not found or corrupted.' +
101
+ ' You can use download=True to download it')
102
+
103
+ splits_dir = os.path.join(voc_root, 'ImageSets/Segmentation')
104
+
105
+ split_f = os.path.join(splits_dir, image_set.rstrip('\n') + '.txt')
106
+
107
+ if not os.path.exists(split_f):
108
+ raise ValueError(
109
+ 'Wrong image_set entered! Please use image_set="train" '
110
+ 'or image_set="trainval" or image_set="val"')
111
+
112
+ with open(os.path.join(split_f), "r") as f:
113
+ file_names = [x.strip() for x in f.readlines()]
114
+
115
+ self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
116
+ self.masks = [os.path.join(mask_dir, x + ".png") for x in file_names]
117
+ assert (len(self.images) == len(self.masks))
118
+
119
+ def __getitem__(self, index):
120
+ """
121
+ Args:
122
+ index (int): Index
123
+
124
+ Returns:
125
+ tuple: (image, target) where target is the image segmentation.
126
+ """
127
+ img = Image.open(self.images[index]).convert('RGB')
128
+ target = Image.open(self.masks[index])
129
+
130
+ if self.transform is not None:
131
+ img = self.transform(img)
132
+
133
+ if self.target_transform is not None:
134
+ target = np.array(self.target_transform(target)).astype('int32')
135
+ target[target == 255] = -1
136
+ target = torch.from_numpy(target).long()
137
+
138
+ return img, target
139
+
140
+ @staticmethod
141
+ def _mask_transform(mask):
142
+ target = np.array(mask).astype('int32')
143
+ target[target == 255] = -1
144
+ return torch.from_numpy(target).long()
145
+
146
+ def __len__(self):
147
+ return len(self.images)
148
+
149
+ @property
150
+ def pred_offset(self):
151
+ return 0
152
+
153
+
154
+ class VOCClassification(data.Dataset):
155
+ """`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Segmentation Dataset.
156
+
157
+ Args:
158
+ root (string): Root directory of the VOC Dataset.
159
+ year (string, optional): The dataset year, supports years 2007 to 2012.
160
+ image_set (string, optional): Select the image_set to use, ``train``, ``trainval`` or ``val``
161
+ download (bool, optional): If true, downloads the dataset from the internet and
162
+ puts it in root directory. If dataset is already downloaded, it is not
163
+ downloaded again.
164
+ transform (callable, optional): A function/transform that takes in an PIL image
165
+ and returns a transformed version. E.g, ``transforms.RandomCrop``
166
+ """
167
+ CLASSES = 20
168
+
169
+ def __init__(self,
170
+ root,
171
+ year='2012',
172
+ image_set='train',
173
+ download=False,
174
+ transform=None):
175
+ self.root = os.path.expanduser(root)
176
+ self.year = year
177
+ self.url = DATASET_YEAR_DICT[year]['url']
178
+ self.filename = DATASET_YEAR_DICT[year]['filename']
179
+ self.md5 = DATASET_YEAR_DICT[year]['md5']
180
+ self.transform = transform
181
+ self.image_set = image_set
182
+ base_dir = DATASET_YEAR_DICT[year]['base_dir']
183
+ voc_root = os.path.join(self.root, base_dir)
184
+ image_dir = os.path.join(voc_root, 'JPEGImages')
185
+ mask_dir = os.path.join(voc_root, 'SegmentationClass')
186
+
187
+ if download:
188
+ download_extract(self.url, self.root, self.filename, self.md5)
189
+
190
+ if not os.path.isdir(voc_root):
191
+ raise RuntimeError('Dataset not found or corrupted.' +
192
+ ' You can use download=True to download it')
193
+
194
+ splits_dir = os.path.join(voc_root, 'ImageSets/Segmentation')
195
+
196
+ split_f = os.path.join(splits_dir, image_set.rstrip('\n') + '.txt')
197
+
198
+ if not os.path.exists(split_f):
199
+ raise ValueError(
200
+ 'Wrong image_set entered! Please use image_set="train" '
201
+ 'or image_set="trainval" or image_set="val"')
202
+
203
+ with open(os.path.join(split_f), "r") as f:
204
+ file_names = [x.strip() for x in f.readlines()]
205
+
206
+ self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
207
+ self.masks = [os.path.join(mask_dir, x + ".png") for x in file_names]
208
+ assert (len(self.images) == len(self.masks))
209
+
210
+ def __getitem__(self, index):
211
+ """
212
+ Args:
213
+ index (int): Index
214
+
215
+ Returns:
216
+ tuple: (image, target) where target is the image segmentation.
217
+ """
218
+ img = Image.open(self.images[index]).convert('RGB')
219
+ target = Image.open(self.masks[index])
220
+
221
+ # if self.transform is not None:
222
+ # img = self.transform(img)
223
+ if self.transform is not None:
224
+ img, target = self.transform(img, target)
225
+
226
+ visible_classes = np.unique(target)
227
+ labels = torch.zeros(self.CLASSES)
228
+ for id in visible_classes:
229
+ if id not in (0, 255):
230
+ labels[id - 1].fill_(1)
231
+
232
+ return img, labels
233
+
234
+ def __len__(self):
235
+ return len(self.images)
236
+
237
+
238
+ class VOCSBDClassification(data.Dataset):
239
+ """`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Segmentation Dataset.
240
+
241
+ Args:
242
+ root (string): Root directory of the VOC Dataset.
243
+ year (string, optional): The dataset year, supports years 2007 to 2012.
244
+ image_set (string, optional): Select the image_set to use, ``train``, ``trainval`` or ``val``
245
+ download (bool, optional): If true, downloads the dataset from the internet and
246
+ puts it in root directory. If dataset is already downloaded, it is not
247
+ downloaded again.
248
+ transform (callable, optional): A function/transform that takes in an PIL image
249
+ and returns a transformed version. E.g, ``transforms.RandomCrop``
250
+ """
251
+ CLASSES = 20
252
+
253
+ def __init__(self,
254
+ root,
255
+ sbd_root,
256
+ year='2012',
257
+ image_set='train',
258
+ download=False,
259
+ transform=None):
260
+ self.root = os.path.expanduser(root)
261
+ self.sbd_root = os.path.expanduser(sbd_root)
262
+ self.year = year
263
+ self.url = DATASET_YEAR_DICT[year]['url']
264
+ self.filename = DATASET_YEAR_DICT[year]['filename']
265
+ self.md5 = DATASET_YEAR_DICT[year]['md5']
266
+ self.transform = transform
267
+ self.image_set = image_set
268
+ base_dir = DATASET_YEAR_DICT[year]['base_dir']
269
+ voc_root = os.path.join(self.root, base_dir)
270
+ image_dir = os.path.join(voc_root, 'JPEGImages')
271
+ mask_dir = os.path.join(voc_root, 'SegmentationClass')
272
+ sbd_image_dir = os.path.join(sbd_root, 'img')
273
+ sbd_mask_dir = os.path.join(sbd_root, 'cls')
274
+
275
+ if download:
276
+ download_extract(self.url, self.root, self.filename, self.md5)
277
+
278
+ if not os.path.isdir(voc_root):
279
+ raise RuntimeError('Dataset not found or corrupted.' +
280
+ ' You can use download=True to download it')
281
+
282
+ splits_dir = os.path.join(voc_root, 'ImageSets/Segmentation')
283
+
284
+ split_f = os.path.join(splits_dir, image_set.rstrip('\n') + '.txt')
285
+ sbd_split = os.path.join(sbd_root, 'train.txt')
286
+
287
+ if not os.path.exists(split_f):
288
+ raise ValueError(
289
+ 'Wrong image_set entered! Please use image_set="train" '
290
+ 'or image_set="trainval" or image_set="val"')
291
+
292
+ with open(os.path.join(split_f), "r") as f:
293
+ voc_file_names = [x.strip() for x in f.readlines()]
294
+
295
+ with open(os.path.join(sbd_split), "r") as f:
296
+ sbd_file_names = [x.strip() for x in f.readlines()]
297
+
298
+ self.images = [os.path.join(image_dir, x + ".jpg") for x in voc_file_names]
299
+ self.images += [os.path.join(sbd_image_dir, x + ".jpg") for x in sbd_file_names]
300
+ self.masks = [os.path.join(mask_dir, x + ".png") for x in voc_file_names]
301
+ self.masks += [os.path.join(sbd_mask_dir, x + ".mat") for x in sbd_file_names]
302
+ assert (len(self.images) == len(self.masks))
303
+
304
+ def __getitem__(self, index):
305
+ """
306
+ Args:
307
+ index (int): Index
308
+
309
+ Returns:
310
+ tuple: (image, labels) where labels is a multi-hot vector over the 20 VOC classes present in the image.
311
+ """
312
+ img = Image.open(self.images[index]).convert('RGB')
313
+ mask_path = self.masks[index]
314
+ if mask_path[-3:] == 'mat':
315
+ target = io.loadmat(mask_path, struct_as_record=False, squeeze_me=True)['GTcls'].Segmentation
316
+ target = Image.fromarray(target, mode='P')
317
+ else:
318
+ target = Image.open(self.masks[index])
319
+
320
+ if self.transform is not None:
321
+ img, target = self.transform(img, target)
322
+
323
+ visible_classes = np.unique(target)
324
+ labels = torch.zeros(self.CLASSES)
325
+ for id in visible_classes:
326
+ if id not in (0, 255):
327
+ labels[id - 1].fill_(1)
328
+
329
+ return img, labels
330
+
331
+ def __len__(self):
332
+ return len(self.images)
333
+
334
+
335
+ def download_extract(url, root, filename, md5):
336
+ download_url(url, root, filename, md5)
337
+ with tarfile.open(os.path.join(root, filename), "r") as tar:
338
+ tar.extractall(path=root)
339
+
340
+
341
+ class VOCResults(data.Dataset):
342
+ CLASSES = 20
343
+ CLASSES_NAMES = [
344
+ 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
345
+ 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
346
+ 'motorbike', 'person', 'potted-plant', 'sheep', 'sofa', 'train',
347
+ 'tvmonitor', 'ambigious'
348
+ ]
349
+
350
+ def __init__(self, path):
351
+ super(VOCResults, self).__init__()
352
+
353
+ self.path = os.path.join(path, 'results.hdf5')
354
+ self.data = None
355
+
356
+ print('Reading dataset length...')
357
+ with h5py.File(self.path , 'r') as f:
358
+ self.data_length = len(f['/image'])
359
+
360
+ def __len__(self):
361
+ return self.data_length
362
+
363
+ def __getitem__(self, item):
364
+ if self.data is None:
365
+ self.data = h5py.File(self.path, 'r')
366
+
367
+ image = torch.tensor(self.data['image'][item])
368
+ vis = torch.tensor(self.data['vis'][item])
369
+ target = torch.tensor(self.data['target'][item])
370
+ class_pred = torch.tensor(self.data['class_pred'][item])
371
+
372
+ return image, vis, target, class_pred
ViT_DeiT/data/__init__.py ADDED
File without changes
ViT_DeiT/data/imagenet.py ADDED
@@ -0,0 +1,74 @@
1
+ import os
2
+ import torch
3
+ import torch.utils.data as data
4
+ import numpy as np
5
+
6
+ from PIL import Image
7
+ import h5py
8
+
9
+ __all__ = ['ImagenetResults']
10
+
11
+
12
+ class Imagenet_Segmentation(data.Dataset):
13
+ CLASSES = 2
14
+
15
+ def __init__(self,
16
+ path,
17
+ transform=None,
18
+ target_transform=None):
19
+ self.path = path
20
+ self.transform = transform
21
+ self.target_transform = target_transform
22
+ self.h5py = None
23
+ tmp = h5py.File(path, 'r')
24
+ self.data_length = len(tmp['/value/img'])
25
+ tmp.close()
26
+ del tmp
27
+
28
+ def __getitem__(self, index):
29
+
30
+ if self.h5py is None:
31
+ self.h5py = h5py.File(self.path, 'r')
32
+
33
+ img = np.array(self.h5py[self.h5py['/value/img'][index, 0]]).transpose((2, 1, 0))
34
+ target = np.array(self.h5py[self.h5py[self.h5py['/value/gt'][index, 0]][0, 0]]).transpose((1, 0))
35
+
36
+ img = Image.fromarray(img).convert('RGB')
37
+ target = Image.fromarray(target)
38
+
39
+ if self.transform is not None:
40
+ img = self.transform(img)
41
+
42
+ if self.target_transform is not None:
43
+ target = np.array(self.target_transform(target)).astype('int32')
44
+ target = torch.from_numpy(target).long()
45
+
46
+ return img, target
47
+
48
+ def __len__(self):
49
+ return self.data_length
50
+
51
+
52
+ class ImagenetResults(data.Dataset):
53
+ def __init__(self, path):
54
+ super(ImagenetResults, self).__init__()
55
+
56
+ self.path = os.path.join(path, 'results.hdf5')
57
+ self.data = None
58
+
59
+ print('Reading dataset length...')
60
+ with h5py.File(self.path, 'r') as f:
61
+ self.data_length = len(f['/image'])
62
+
63
+ def __len__(self):
64
+ return self.data_length
65
+
66
+ def __getitem__(self, item):
67
+ if self.data is None:
68
+ self.data = h5py.File(self.path, 'r')
69
+
70
+ image = torch.tensor(self.data['image'][item])
71
+ vis = torch.tensor(self.data['vis'][item])
72
+ target = torch.tensor(self.data['target'][item]).long()
73
+
74
+ return image, vis, target
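
A short consumption sketch (illustration only, with a placeholder path): ImagenetResults streams a precomputed results.hdf5 file and yields the input image, the saved relevance map, and the ground-truth label per sample, ready for downstream evaluation.

import torch
from data.imagenet import ImagenetResults

ds = ImagenetResults('/path/to/visualizations/method_name')   # hypothetical directory containing results.hdf5
loader = torch.utils.data.DataLoader(ds, batch_size=16, shuffle=False)
for image, vis, target in loader:
    # image: input batch, vis: per-pixel relevance maps, target: class indices
    pass
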
ViT_DeiT/data/imagenet_utils.py ADDED
@@ -0,0 +1,1002 @@
1
+ CLS2IDX = {
2
+ 0: 'tench, Tinca tinca',
3
+ 1: 'goldfish, Carassius auratus',
4
+ 2: 'great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias',
5
+ 3: 'tiger shark, Galeocerdo cuvieri',
6
+ 4: 'hammerhead, hammerhead shark',
7
+ 5: 'electric ray, crampfish, numbfish, torpedo',
8
+ 6: 'stingray',
9
+ 7: 'cock',
10
+ 8: 'hen',
11
+ 9: 'ostrich, Struthio camelus',
12
+ 10: 'brambling, Fringilla montifringilla',
13
+ 11: 'goldfinch, Carduelis carduelis',
14
+ 12: 'house finch, linnet, Carpodacus mexicanus',
15
+ 13: 'junco, snowbird',
16
+ 14: 'indigo bunting, indigo finch, indigo bird, Passerina cyanea',
17
+ 15: 'robin, American robin, Turdus migratorius',
18
+ 16: 'bulbul',
19
+ 17: 'jay',
20
+ 18: 'magpie',
21
+ 19: 'chickadee',
22
+ 20: 'water ouzel, dipper',
23
+ 21: 'kite',
24
+ 22: 'bald eagle, American eagle, Haliaeetus leucocephalus',
25
+ 23: 'vulture',
26
+ 24: 'great grey owl, great gray owl, Strix nebulosa',
27
+ 25: 'European fire salamander, Salamandra salamandra',
28
+ 26: 'common newt, Triturus vulgaris',
29
+ 27: 'eft',
30
+ 28: 'spotted salamander, Ambystoma maculatum',
31
+ 29: 'axolotl, mud puppy, Ambystoma mexicanum',
32
+ 30: 'bullfrog, Rana catesbeiana',
33
+ 31: 'tree frog, tree-frog',
34
+ 32: 'tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui',
35
+ 33: 'loggerhead, loggerhead turtle, Caretta caretta',
36
+ 34: 'leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea',
37
+ 35: 'mud turtle',
38
+ 36: 'terrapin',
39
+ 37: 'box turtle, box tortoise',
40
+ 38: 'banded gecko',
41
+ 39: 'common iguana, iguana, Iguana iguana',
42
+ 40: 'American chameleon, anole, Anolis carolinensis',
43
+ 41: 'whiptail, whiptail lizard',
44
+ 42: 'agama',
45
+ 43: 'frilled lizard, Chlamydosaurus kingi',
46
+ 44: 'alligator lizard',
47
+ 45: 'Gila monster, Heloderma suspectum',
48
+ 46: 'green lizard, Lacerta viridis',
49
+ 47: 'African chameleon, Chamaeleo chamaeleon',
50
+ 48: 'Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis',
51
+ 49: 'African crocodile, Nile crocodile, Crocodylus niloticus',
52
+ 50: 'American alligator, Alligator mississipiensis',
53
+ 51: 'triceratops',
54
+ 52: 'thunder snake, worm snake, Carphophis amoenus',
55
+ 53: 'ringneck snake, ring-necked snake, ring snake',
56
+ 54: 'hognose snake, puff adder, sand viper',
57
+ 55: 'green snake, grass snake',
58
+ 56: 'king snake, kingsnake',
59
+ 57: 'garter snake, grass snake',
60
+ 58: 'water snake',
61
+ 59: 'vine snake',
62
+ 60: 'night snake, Hypsiglena torquata',
63
+ 61: 'boa constrictor, Constrictor constrictor',
64
+ 62: 'rock python, rock snake, Python sebae',
65
+ 63: 'Indian cobra, Naja naja',
66
+ 64: 'green mamba',
67
+ 65: 'sea snake',
68
+ 66: 'horned viper, cerastes, sand viper, horned asp, Cerastes cornutus',
69
+ 67: 'diamondback, diamondback rattlesnake, Crotalus adamanteus',
70
+ 68: 'sidewinder, horned rattlesnake, Crotalus cerastes',
71
+ 69: 'trilobite',
72
+ 70: 'harvestman, daddy longlegs, Phalangium opilio',
73
+ 71: 'scorpion',
74
+ 72: 'black and gold garden spider, Argiope aurantia',
75
+ 73: 'barn spider, Araneus cavaticus',
76
+ 74: 'garden spider, Aranea diademata',
77
+ 75: 'black widow, Latrodectus mactans',
78
+ 76: 'tarantula',
79
+ 77: 'wolf spider, hunting spider',
80
+ 78: 'tick',
81
+ 79: 'centipede',
82
+ 80: 'black grouse',
83
+ 81: 'ptarmigan',
84
+ 82: 'ruffed grouse, partridge, Bonasa umbellus',
85
+ 83: 'prairie chicken, prairie grouse, prairie fowl',
86
+ 84: 'peacock',
87
+ 85: 'quail',
88
+ 86: 'partridge',
89
+ 87: 'African grey, African gray, Psittacus erithacus',
90
+ 88: 'macaw',
91
+ 89: 'sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita',
92
+ 90: 'lorikeet',
93
+ 91: 'coucal',
94
+ 92: 'bee eater',
95
+ 93: 'hornbill',
96
+ 94: 'hummingbird',
97
+ 95: 'jacamar',
98
+ 96: 'toucan',
99
+ 97: 'drake',
100
+ 98: 'red-breasted merganser, Mergus serrator',
101
+ 99: 'goose',
102
+ 100: 'black swan, Cygnus atratus',
103
+ 101: 'tusker',
104
+ 102: 'echidna, spiny anteater, anteater',
105
+ 103: 'platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus',
106
+ 104: 'wallaby, brush kangaroo',
107
+ 105: 'koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus',
108
+ 106: 'wombat',
109
+ 107: 'jellyfish',
110
+ 108: 'sea anemone, anemone',
111
+ 109: 'brain coral',
112
+ 110: 'flatworm, platyhelminth',
113
+ 111: 'nematode, nematode worm, roundworm',
114
+ 112: 'conch',
115
+ 113: 'snail',
116
+ 114: 'slug',
117
+ 115: 'sea slug, nudibranch',
118
+ 116: 'chiton, coat-of-mail shell, sea cradle, polyplacophore',
119
+ 117: 'chambered nautilus, pearly nautilus, nautilus',
120
+ 118: 'Dungeness crab, Cancer magister',
121
+ 119: 'rock crab, Cancer irroratus',
122
+ 120: 'fiddler crab',
123
+ 121: 'king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica',
124
+ 122: 'American lobster, Northern lobster, Maine lobster, Homarus americanus',
125
+ 123: 'spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish',
126
+ 124: 'crayfish, crawfish, crawdad, crawdaddy',
127
+ 125: 'hermit crab',
128
+ 126: 'isopod',
129
+ 127: 'white stork, Ciconia ciconia',
130
+ 128: 'black stork, Ciconia nigra',
131
+ 129: 'spoonbill',
132
+ 130: 'flamingo',
133
+ 131: 'little blue heron, Egretta caerulea',
134
+ 132: 'American egret, great white heron, Egretta albus',
135
+ 133: 'bittern',
136
+ 134: 'crane',
137
+ 135: 'limpkin, Aramus pictus',
138
+ 136: 'European gallinule, Porphyrio porphyrio',
139
+ 137: 'American coot, marsh hen, mud hen, water hen, Fulica americana',
140
+ 138: 'bustard',
141
+ 139: 'ruddy turnstone, Arenaria interpres',
142
+ 140: 'red-backed sandpiper, dunlin, Erolia alpina',
143
+ 141: 'redshank, Tringa totanus',
144
+ 142: 'dowitcher',
145
+ 143: 'oystercatcher, oyster catcher',
146
+ 144: 'pelican',
147
+ 145: 'king penguin, Aptenodytes patagonica',
148
+ 146: 'albatross, mollymawk',
149
+ 147: 'grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus',
150
+ 148: 'killer whale, killer, orca, grampus, sea wolf, Orcinus orca',
151
+ 149: 'dugong, Dugong dugon',
152
+ 150: 'sea lion',
153
+ 151: 'Chihuahua',
154
+ 152: 'Japanese spaniel',
155
+ 153: 'Maltese dog, Maltese terrier, Maltese',
156
+ 154: 'Pekinese, Pekingese, Peke',
157
+ 155: 'Shih-Tzu',
158
+ 156: 'Blenheim spaniel',
159
+ 157: 'papillon',
160
+ 158: 'toy terrier',
161
+ 159: 'Rhodesian ridgeback',
162
+ 160: 'Afghan hound, Afghan',
163
+ 161: 'basset, basset hound',
164
+ 162: 'beagle',
165
+ 163: 'bloodhound, sleuthhound',
166
+ 164: 'bluetick',
167
+ 165: 'black-and-tan coonhound',
168
+ 166: 'Walker hound, Walker foxhound',
169
+ 167: 'English foxhound',
170
+ 168: 'redbone',
171
+ 169: 'borzoi, Russian wolfhound',
172
+ 170: 'Irish wolfhound',
173
+ 171: 'Italian greyhound',
174
+ 172: 'whippet',
175
+ 173: 'Ibizan hound, Ibizan Podenco',
176
+ 174: 'Norwegian elkhound, elkhound',
177
+ 175: 'otterhound, otter hound',
178
+ 176: 'Saluki, gazelle hound',
179
+ 177: 'Scottish deerhound, deerhound',
180
+ 178: 'Weimaraner',
181
+ 179: 'Staffordshire bullterrier, Staffordshire bull terrier',
182
+ 180: 'American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier',
183
+ 181: 'Bedlington terrier',
184
+ 182: 'Border terrier',
185
+ 183: 'Kerry blue terrier',
186
+ 184: 'Irish terrier',
187
+ 185: 'Norfolk terrier',
188
+ 186: 'Norwich terrier',
189
+ 187: 'Yorkshire terrier',
190
+ 188: 'wire-haired fox terrier',
191
+ 189: 'Lakeland terrier',
192
+ 190: 'Sealyham terrier, Sealyham',
193
+ 191: 'Airedale, Airedale terrier',
194
+ 192: 'cairn, cairn terrier',
195
+ 193: 'Australian terrier',
196
+ 194: 'Dandie Dinmont, Dandie Dinmont terrier',
197
+ 195: 'Boston bull, Boston terrier',
198
+ 196: 'miniature schnauzer',
199
+ 197: 'giant schnauzer',
200
+ 198: 'standard schnauzer',
201
+ 199: 'Scotch terrier, Scottish terrier, Scottie',
202
+ 200: 'Tibetan terrier, chrysanthemum dog',
203
+ 201: 'silky terrier, Sydney silky',
204
+ 202: 'soft-coated wheaten terrier',
205
+ 203: 'West Highland white terrier',
206
+ 204: 'Lhasa, Lhasa apso',
207
+ 205: 'flat-coated retriever',
208
+ 206: 'curly-coated retriever',
209
+ 207: 'golden retriever',
210
+ 208: 'Labrador retriever',
211
+ 209: 'Chesapeake Bay retriever',
212
+ 210: 'German short-haired pointer',
213
+ 211: 'vizsla, Hungarian pointer',
214
+ 212: 'English setter',
215
+ 213: 'Irish setter, red setter',
216
+ 214: 'Gordon setter',
217
+ 215: 'Brittany spaniel',
218
+ 216: 'clumber, clumber spaniel',
219
+ 217: 'English springer, English springer spaniel',
220
+ 218: 'Welsh springer spaniel',
221
+ 219: 'cocker spaniel, English cocker spaniel, cocker',
222
+ 220: 'Sussex spaniel',
223
+ 221: 'Irish water spaniel',
224
+ 222: 'kuvasz',
225
+ 223: 'schipperke',
226
+ 224: 'groenendael',
227
+ 225: 'malinois',
228
+ 226: 'briard',
229
+ 227: 'kelpie',
230
+ 228: 'komondor',
231
+ 229: 'Old English sheepdog, bobtail',
232
+ 230: 'Shetland sheepdog, Shetland sheep dog, Shetland',
233
+ 231: 'collie',
234
+ 232: 'Border collie',
235
+ 233: 'Bouvier des Flandres, Bouviers des Flandres',
236
+ 234: 'Rottweiler',
237
+ 235: 'German shepherd, German shepherd dog, German police dog, alsatian',
238
+ 236: 'Doberman, Doberman pinscher',
239
+ 237: 'miniature pinscher',
240
+ 238: 'Greater Swiss Mountain dog',
241
+ 239: 'Bernese mountain dog',
242
+ 240: 'Appenzeller',
243
+ 241: 'EntleBucher',
244
+ 242: 'boxer',
245
+ 243: 'bull mastiff',
246
+ 244: 'Tibetan mastiff',
247
+ 245: 'French bulldog',
248
+ 246: 'Great Dane',
249
+ 247: 'Saint Bernard, St Bernard',
250
+ 248: 'Eskimo dog, husky',
251
+ 249: 'malamute, malemute, Alaskan malamute',
252
+ 250: 'Siberian husky',
253
+ 251: 'dalmatian, coach dog, carriage dog',
254
+ 252: 'affenpinscher, monkey pinscher, monkey dog',
255
+ 253: 'basenji',
256
+ 254: 'pug, pug-dog',
257
+ 255: 'Leonberg',
258
+ 256: 'Newfoundland, Newfoundland dog',
259
+ 257: 'Great Pyrenees',
260
+ 258: 'Samoyed, Samoyede',
261
+ 259: 'Pomeranian',
262
+ 260: 'chow, chow chow',
263
+ 261: 'keeshond',
264
+ 262: 'Brabancon griffon',
265
+ 263: 'Pembroke, Pembroke Welsh corgi',
266
+ 264: 'Cardigan, Cardigan Welsh corgi',
267
+ 265: 'toy poodle',
268
+ 266: 'miniature poodle',
269
+ 267: 'standard poodle',
270
+ 268: 'Mexican hairless',
271
+ 269: 'timber wolf, grey wolf, gray wolf, Canis lupus',
272
+ 270: 'white wolf, Arctic wolf, Canis lupus tundrarum',
273
+ 271: 'red wolf, maned wolf, Canis rufus, Canis niger',
274
+ 272: 'coyote, prairie wolf, brush wolf, Canis latrans',
275
+ 273: 'dingo, warrigal, warragal, Canis dingo',
276
+ 274: 'dhole, Cuon alpinus',
277
+ 275: 'African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus',
278
+ 276: 'hyena, hyaena',
279
+ 277: 'red fox, Vulpes vulpes',
280
+ 278: 'kit fox, Vulpes macrotis',
281
+ 279: 'Arctic fox, white fox, Alopex lagopus',
282
+ 280: 'grey fox, gray fox, Urocyon cinereoargenteus',
283
+ 281: 'tabby, tabby cat',
284
+ 282: 'tiger cat',
285
+ 283: 'Persian cat',
286
+ 284: 'Siamese cat, Siamese',
287
+ 285: 'Egyptian cat',
288
+ 286: 'cougar, puma, catamount, mountain lion, painter, panther, Felis concolor',
289
+ 287: 'lynx, catamount',
290
+ 288: 'leopard, Panthera pardus',
291
+ 289: 'snow leopard, ounce, Panthera uncia',
292
+ 290: 'jaguar, panther, Panthera onca, Felis onca',
293
+ 291: 'lion, king of beasts, Panthera leo',
294
+ 292: 'tiger, Panthera tigris',
295
+ 293: 'cheetah, chetah, Acinonyx jubatus',
296
+ 294: 'brown bear, bruin, Ursus arctos',
297
+ 295: 'American black bear, black bear, Ursus americanus, Euarctos americanus',
298
+ 296: 'ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus',
299
+ 297: 'sloth bear, Melursus ursinus, Ursus ursinus',
300
+ 298: 'mongoose',
301
+ 299: 'meerkat, mierkat',
302
+ 300: 'tiger beetle',
303
+ 301: 'ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle',
304
+ 302: 'ground beetle, carabid beetle',
305
+ 303: 'long-horned beetle, longicorn, longicorn beetle',
306
+ 304: 'leaf beetle, chrysomelid',
307
+ 305: 'dung beetle',
308
+ 306: 'rhinoceros beetle',
309
+ 307: 'weevil',
310
+ 308: 'fly',
311
+ 309: 'bee',
312
+ 310: 'ant, emmet, pismire',
313
+ 311: 'grasshopper, hopper',
314
+ 312: 'cricket',
315
+ 313: 'walking stick, walkingstick, stick insect',
316
+ 314: 'cockroach, roach',
317
+ 315: 'mantis, mantid',
318
+ 316: 'cicada, cicala',
319
+ 317: 'leafhopper',
320
+ 318: 'lacewing, lacewing fly',
321
+ 319: "dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk",
322
+ 320: 'damselfly',
323
+ 321: 'admiral',
324
+ 322: 'ringlet, ringlet butterfly',
325
+ 323: 'monarch, monarch butterfly, milkweed butterfly, Danaus plexippus',
326
+ 324: 'cabbage butterfly',
327
+ 325: 'sulphur butterfly, sulfur butterfly',
328
+ 326: 'lycaenid, lycaenid butterfly',
329
+ 327: 'starfish, sea star',
330
+ 328: 'sea urchin',
331
+ 329: 'sea cucumber, holothurian',
332
+ 330: 'wood rabbit, cottontail, cottontail rabbit',
333
+ 331: 'hare',
334
+ 332: 'Angora, Angora rabbit',
335
+ 333: 'hamster',
336
+ 334: 'porcupine, hedgehog',
337
+ 335: 'fox squirrel, eastern fox squirrel, Sciurus niger',
338
+ 336: 'marmot',
339
+ 337: 'beaver',
340
+ 338: 'guinea pig, Cavia cobaya',
341
+ 339: 'sorrel',
342
+ 340: 'zebra',
343
+ 341: 'hog, pig, grunter, squealer, Sus scrofa',
344
+ 342: 'wild boar, boar, Sus scrofa',
345
+ 343: 'warthog',
346
+ 344: 'hippopotamus, hippo, river horse, Hippopotamus amphibius',
347
+ 345: 'ox',
348
+ 346: 'water buffalo, water ox, Asiatic buffalo, Bubalus bubalis',
349
+ 347: 'bison',
350
+ 348: 'ram, tup',
351
+ 349: 'bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis',
352
+ 350: 'ibex, Capra ibex',
353
+ 351: 'hartebeest',
354
+ 352: 'impala, Aepyceros melampus',
355
+ 353: 'gazelle',
356
+ 354: 'Arabian camel, dromedary, Camelus dromedarius',
357
+ 355: 'llama',
358
+ 356: 'weasel',
359
+ 357: 'mink',
360
+ 358: 'polecat, fitch, foulmart, foumart, Mustela putorius',
361
+ 359: 'black-footed ferret, ferret, Mustela nigripes',
362
+ 360: 'otter',
363
+ 361: 'skunk, polecat, wood pussy',
364
+ 362: 'badger',
365
+ 363: 'armadillo',
366
+ 364: 'three-toed sloth, ai, Bradypus tridactylus',
367
+ 365: 'orangutan, orang, orangutang, Pongo pygmaeus',
368
+ 366: 'gorilla, Gorilla gorilla',
369
+ 367: 'chimpanzee, chimp, Pan troglodytes',
370
+ 368: 'gibbon, Hylobates lar',
371
+ 369: 'siamang, Hylobates syndactylus, Symphalangus syndactylus',
372
+ 370: 'guenon, guenon monkey',
373
+ 371: 'patas, hussar monkey, Erythrocebus patas',
374
+ 372: 'baboon',
375
+ 373: 'macaque',
376
+ 374: 'langur',
377
+ 375: 'colobus, colobus monkey',
378
+ 376: 'proboscis monkey, Nasalis larvatus',
379
+ 377: 'marmoset',
380
+ 378: 'capuchin, ringtail, Cebus capucinus',
381
+ 379: 'howler monkey, howler',
382
+ 380: 'titi, titi monkey',
383
+ 381: 'spider monkey, Ateles geoffroyi',
384
+ 382: 'squirrel monkey, Saimiri sciureus',
385
+ 383: 'Madagascar cat, ring-tailed lemur, Lemur catta',
386
+ 384: 'indri, indris, Indri indri, Indri brevicaudatus',
387
+ 385: 'Indian elephant, Elephas maximus',
388
+ 386: 'African elephant, Loxodonta africana',
389
+ 387: 'lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens',
390
+ 388: 'giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca',
391
+ 389: 'barracouta, snoek',
392
+ 390: 'eel',
393
+ 391: 'coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch',
394
+ 392: 'rock beauty, Holocanthus tricolor',
395
+ 393: 'anemone fish',
396
+ 394: 'sturgeon',
397
+ 395: 'gar, garfish, garpike, billfish, Lepisosteus osseus',
398
+ 396: 'lionfish',
399
+ 397: 'puffer, pufferfish, blowfish, globefish',
400
+ 398: 'abacus',
401
+ 399: 'abaya',
402
+ 400: "academic gown, academic robe, judge's robe",
403
+ 401: 'accordion, piano accordion, squeeze box',
404
+ 402: 'acoustic guitar',
405
+ 403: 'aircraft carrier, carrier, flattop, attack aircraft carrier',
406
+ 404: 'airliner',
407
+ 405: 'airship, dirigible',
408
+ 406: 'altar',
409
+ 407: 'ambulance',
410
+ 408: 'amphibian, amphibious vehicle',
411
+ 409: 'analog clock',
412
+ 410: 'apiary, bee house',
413
+ 411: 'apron',
414
+ 412: 'ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin',
415
+ 413: 'assault rifle, assault gun',
416
+ 414: 'backpack, back pack, knapsack, packsack, rucksack, haversack',
417
+ 415: 'bakery, bakeshop, bakehouse',
418
+ 416: 'balance beam, beam',
419
+ 417: 'balloon',
420
+ 418: 'ballpoint, ballpoint pen, ballpen, Biro',
421
+ 419: 'Band Aid',
422
+ 420: 'banjo',
423
+ 421: 'bannister, banister, balustrade, balusters, handrail',
424
+ 422: 'barbell',
425
+ 423: 'barber chair',
426
+ 424: 'barbershop',
427
+ 425: 'barn',
428
+ 426: 'barometer',
429
+ 427: 'barrel, cask',
430
+ 428: 'barrow, garden cart, lawn cart, wheelbarrow',
431
+ 429: 'baseball',
432
+ 430: 'basketball',
433
+ 431: 'bassinet',
434
+ 432: 'bassoon',
435
+ 433: 'bathing cap, swimming cap',
436
+ 434: 'bath towel',
437
+ 435: 'bathtub, bathing tub, bath, tub',
438
+ 436: 'beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon',
439
+ 437: 'beacon, lighthouse, beacon light, pharos',
440
+ 438: 'beaker',
441
+ 439: 'bearskin, busby, shako',
442
+ 440: 'beer bottle',
443
+ 441: 'beer glass',
444
+ 442: 'bell cote, bell cot',
445
+ 443: 'bib',
446
+ 444: 'bicycle-built-for-two, tandem bicycle, tandem',
447
+ 445: 'bikini, two-piece',
448
+ 446: 'binder, ring-binder',
449
+ 447: 'binoculars, field glasses, opera glasses',
450
+ 448: 'birdhouse',
451
+ 449: 'boathouse',
452
+ 450: 'bobsled, bobsleigh, bob',
453
+ 451: 'bolo tie, bolo, bola tie, bola',
454
+ 452: 'bonnet, poke bonnet',
455
+ 453: 'bookcase',
456
+ 454: 'bookshop, bookstore, bookstall',
457
+ 455: 'bottlecap',
458
+ 456: 'bow',
459
+ 457: 'bow tie, bow-tie, bowtie',
460
+ 458: 'brass, memorial tablet, plaque',
461
+ 459: 'brassiere, bra, bandeau',
462
+ 460: 'breakwater, groin, groyne, mole, bulwark, seawall, jetty',
463
+ 461: 'breastplate, aegis, egis',
464
+ 462: 'broom',
465
+ 463: 'bucket, pail',
466
+ 464: 'buckle',
467
+ 465: 'bulletproof vest',
468
+ 466: 'bullet train, bullet',
469
+ 467: 'butcher shop, meat market',
470
+ 468: 'cab, hack, taxi, taxicab',
471
+ 469: 'caldron, cauldron',
472
+ 470: 'candle, taper, wax light',
473
+ 471: 'cannon',
474
+ 472: 'canoe',
475
+ 473: 'can opener, tin opener',
476
+ 474: 'cardigan',
477
+ 475: 'car mirror',
478
+ 476: 'carousel, carrousel, merry-go-round, roundabout, whirligig',
479
+ 477: "carpenter's kit, tool kit",
480
+ 478: 'carton',
481
+ 479: 'car wheel',
482
+ 480: 'cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM',
483
+ 481: 'cassette',
484
+ 482: 'cassette player',
485
+ 483: 'castle',
486
+ 484: 'catamaran',
487
+ 485: 'CD player',
488
+ 486: 'cello, violoncello',
489
+ 487: 'cellular telephone, cellular phone, cellphone, cell, mobile phone',
490
+ 488: 'chain',
491
+ 489: 'chainlink fence',
492
+ 490: 'chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour',
493
+ 491: 'chain saw, chainsaw',
494
+ 492: 'chest',
495
+ 493: 'chiffonier, commode',
496
+ 494: 'chime, bell, gong',
497
+ 495: 'china cabinet, china closet',
498
+ 496: 'Christmas stocking',
499
+ 497: 'church, church building',
500
+ 498: 'cinema, movie theater, movie theatre, movie house, picture palace',
501
+ 499: 'cleaver, meat cleaver, chopper',
502
+ 500: 'cliff dwelling',
503
+ 501: 'cloak',
504
+ 502: 'clog, geta, patten, sabot',
505
+ 503: 'cocktail shaker',
506
+ 504: 'coffee mug',
507
+ 505: 'coffeepot',
508
+ 506: 'coil, spiral, volute, whorl, helix',
509
+ 507: 'combination lock',
510
+ 508: 'computer keyboard, keypad',
511
+ 509: 'confectionery, confectionary, candy store',
512
+ 510: 'container ship, containership, container vessel',
513
+ 511: 'convertible',
514
+ 512: 'corkscrew, bottle screw',
515
+ 513: 'cornet, horn, trumpet, trump',
516
+ 514: 'cowboy boot',
517
+ 515: 'cowboy hat, ten-gallon hat',
518
+ 516: 'cradle',
519
+ 517: 'crane',
520
+ 518: 'crash helmet',
521
+ 519: 'crate',
522
+ 520: 'crib, cot',
523
+ 521: 'Crock Pot',
524
+ 522: 'croquet ball',
525
+ 523: 'crutch',
526
+ 524: 'cuirass',
527
+ 525: 'dam, dike, dyke',
528
+ 526: 'desk',
529
+ 527: 'desktop computer',
530
+ 528: 'dial telephone, dial phone',
531
+ 529: 'diaper, nappy, napkin',
532
+ 530: 'digital clock',
533
+ 531: 'digital watch',
534
+ 532: 'dining table, board',
535
+ 533: 'dishrag, dishcloth',
536
+ 534: 'dishwasher, dish washer, dishwashing machine',
537
+ 535: 'disk brake, disc brake',
538
+ 536: 'dock, dockage, docking facility',
539
+ 537: 'dogsled, dog sled, dog sleigh',
540
+ 538: 'dome',
541
+ 539: 'doormat, welcome mat',
542
+ 540: 'drilling platform, offshore rig',
543
+ 541: 'drum, membranophone, tympan',
544
+ 542: 'drumstick',
545
+ 543: 'dumbbell',
546
+ 544: 'Dutch oven',
547
+ 545: 'electric fan, blower',
548
+ 546: 'electric guitar',
549
+ 547: 'electric locomotive',
550
+ 548: 'entertainment center',
551
+ 549: 'envelope',
552
+ 550: 'espresso maker',
553
+ 551: 'face powder',
554
+ 552: 'feather boa, boa',
555
+ 553: 'file, file cabinet, filing cabinet',
556
+ 554: 'fireboat',
557
+ 555: 'fire engine, fire truck',
558
+ 556: 'fire screen, fireguard',
559
+ 557: 'flagpole, flagstaff',
560
+ 558: 'flute, transverse flute',
561
+ 559: 'folding chair',
562
+ 560: 'football helmet',
563
+ 561: 'forklift',
564
+ 562: 'fountain',
565
+ 563: 'fountain pen',
566
+ 564: 'four-poster',
567
+ 565: 'freight car',
568
+ 566: 'French horn, horn',
569
+ 567: 'frying pan, frypan, skillet',
570
+ 568: 'fur coat',
571
+ 569: 'garbage truck, dustcart',
572
+ 570: 'gasmask, respirator, gas helmet',
573
+ 571: 'gas pump, gasoline pump, petrol pump, island dispenser',
574
+ 572: 'goblet',
575
+ 573: 'go-kart',
576
+ 574: 'golf ball',
577
+ 575: 'golfcart, golf cart',
578
+ 576: 'gondola',
579
+ 577: 'gong, tam-tam',
580
+ 578: 'gown',
581
+ 579: 'grand piano, grand',
582
+ 580: 'greenhouse, nursery, glasshouse',
583
+ 581: 'grille, radiator grille',
584
+ 582: 'grocery store, grocery, food market, market',
585
+ 583: 'guillotine',
586
+ 584: 'hair slide',
587
+ 585: 'hair spray',
588
+ 586: 'half track',
589
+ 587: 'hammer',
590
+ 588: 'hamper',
591
+ 589: 'hand blower, blow dryer, blow drier, hair dryer, hair drier',
592
+ 590: 'hand-held computer, hand-held microcomputer',
593
+ 591: 'handkerchief, hankie, hanky, hankey',
594
+ 592: 'hard disc, hard disk, fixed disk',
595
+ 593: 'harmonica, mouth organ, harp, mouth harp',
596
+ 594: 'harp',
597
+ 595: 'harvester, reaper',
598
+ 596: 'hatchet',
599
+ 597: 'holster',
600
+ 598: 'home theater, home theatre',
601
+ 599: 'honeycomb',
602
+ 600: 'hook, claw',
603
+ 601: 'hoopskirt, crinoline',
604
+ 602: 'horizontal bar, high bar',
605
+ 603: 'horse cart, horse-cart',
606
+ 604: 'hourglass',
607
+ 605: 'iPod',
608
+ 606: 'iron, smoothing iron',
609
+ 607: "jack-o'-lantern",
610
+ 608: 'jean, blue jean, denim',
611
+ 609: 'jeep, landrover',
612
+ 610: 'jersey, T-shirt, tee shirt',
613
+ 611: 'jigsaw puzzle',
614
+ 612: 'jinrikisha, ricksha, rickshaw',
615
+ 613: 'joystick',
616
+ 614: 'kimono',
617
+ 615: 'knee pad',
618
+ 616: 'knot',
619
+ 617: 'lab coat, laboratory coat',
620
+ 618: 'ladle',
621
+ 619: 'lampshade, lamp shade',
622
+ 620: 'laptop, laptop computer',
623
+ 621: 'lawn mower, mower',
624
+ 622: 'lens cap, lens cover',
625
+ 623: 'letter opener, paper knife, paperknife',
626
+ 624: 'library',
627
+ 625: 'lifeboat',
628
+ 626: 'lighter, light, igniter, ignitor',
629
+ 627: 'limousine, limo',
630
+ 628: 'liner, ocean liner',
631
+ 629: 'lipstick, lip rouge',
632
+ 630: 'Loafer',
633
+ 631: 'lotion',
634
+ 632: 'loudspeaker, speaker, speaker unit, loudspeaker system, speaker system',
635
+ 633: "loupe, jeweler's loupe",
636
+ 634: 'lumbermill, sawmill',
637
+ 635: 'magnetic compass',
638
+ 636: 'mailbag, postbag',
639
+ 637: 'mailbox, letter box',
640
+ 638: 'maillot',
641
+ 639: 'maillot, tank suit',
642
+ 640: 'manhole cover',
643
+ 641: 'maraca',
644
+ 642: 'marimba, xylophone',
645
+ 643: 'mask',
646
+ 644: 'matchstick',
647
+ 645: 'maypole',
648
+ 646: 'maze, labyrinth',
649
+ 647: 'measuring cup',
650
+ 648: 'medicine chest, medicine cabinet',
651
+ 649: 'megalith, megalithic structure',
652
+ 650: 'microphone, mike',
653
+ 651: 'microwave, microwave oven',
654
+ 652: 'military uniform',
655
+ 653: 'milk can',
656
+ 654: 'minibus',
657
+ 655: 'miniskirt, mini',
658
+ 656: 'minivan',
659
+ 657: 'missile',
660
+ 658: 'mitten',
661
+ 659: 'mixing bowl',
662
+ 660: 'mobile home, manufactured home',
663
+ 661: 'Model T',
664
+ 662: 'modem',
665
+ 663: 'monastery',
666
+ 664: 'monitor',
667
+ 665: 'moped',
668
+ 666: 'mortar',
669
+ 667: 'mortarboard',
670
+ 668: 'mosque',
671
+ 669: 'mosquito net',
672
+ 670: 'motor scooter, scooter',
673
+ 671: 'mountain bike, all-terrain bike, off-roader',
674
+ 672: 'mountain tent',
675
+ 673: 'mouse, computer mouse',
676
+ 674: 'mousetrap',
677
+ 675: 'moving van',
678
+ 676: 'muzzle',
679
+ 677: 'nail',
680
+ 678: 'neck brace',
681
+ 679: 'necklace',
682
+ 680: 'nipple',
683
+ 681: 'notebook, notebook computer',
684
+ 682: 'obelisk',
685
+ 683: 'oboe, hautboy, hautbois',
686
+ 684: 'ocarina, sweet potato',
687
+ 685: 'odometer, hodometer, mileometer, milometer',
688
+ 686: 'oil filter',
689
+ 687: 'organ, pipe organ',
690
+ 688: 'oscilloscope, scope, cathode-ray oscilloscope, CRO',
691
+ 689: 'overskirt',
692
+ 690: 'oxcart',
693
+ 691: 'oxygen mask',
694
+ 692: 'packet',
695
+ 693: 'paddle, boat paddle',
696
+ 694: 'paddlewheel, paddle wheel',
697
+ 695: 'padlock',
698
+ 696: 'paintbrush',
699
+ 697: "pajama, pyjama, pj's, jammies",
700
+ 698: 'palace',
701
+ 699: 'panpipe, pandean pipe, syrinx',
702
+ 700: 'paper towel',
703
+ 701: 'parachute, chute',
704
+ 702: 'parallel bars, bars',
705
+ 703: 'park bench',
706
+ 704: 'parking meter',
707
+ 705: 'passenger car, coach, carriage',
708
+ 706: 'patio, terrace',
709
+ 707: 'pay-phone, pay-station',
710
+ 708: 'pedestal, plinth, footstall',
711
+ 709: 'pencil box, pencil case',
712
+ 710: 'pencil sharpener',
713
+ 711: 'perfume, essence',
714
+ 712: 'Petri dish',
715
+ 713: 'photocopier',
716
+ 714: 'pick, plectrum, plectron',
717
+ 715: 'pickelhaube',
718
+ 716: 'picket fence, paling',
719
+ 717: 'pickup, pickup truck',
720
+ 718: 'pier',
721
+ 719: 'piggy bank, penny bank',
722
+ 720: 'pill bottle',
723
+ 721: 'pillow',
724
+ 722: 'ping-pong ball',
725
+ 723: 'pinwheel',
726
+ 724: 'pirate, pirate ship',
727
+ 725: 'pitcher, ewer',
728
+ 726: "plane, carpenter's plane, woodworking plane",
729
+ 727: 'planetarium',
730
+ 728: 'plastic bag',
731
+ 729: 'plate rack',
732
+ 730: 'plow, plough',
733
+ 731: "plunger, plumber's helper",
734
+ 732: 'Polaroid camera, Polaroid Land camera',
735
+ 733: 'pole',
736
+ 734: 'police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria',
737
+ 735: 'poncho',
738
+ 736: 'pool table, billiard table, snooker table',
739
+ 737: 'pop bottle, soda bottle',
740
+ 738: 'pot, flowerpot',
741
+ 739: "potter's wheel",
742
+ 740: 'power drill',
743
+ 741: 'prayer rug, prayer mat',
744
+ 742: 'printer',
745
+ 743: 'prison, prison house',
746
+ 744: 'projectile, missile',
747
+ 745: 'projector',
748
+ 746: 'puck, hockey puck',
749
+ 747: 'punching bag, punch bag, punching ball, punchball',
750
+ 748: 'purse',
751
+ 749: 'quill, quill pen',
752
+ 750: 'quilt, comforter, comfort, puff',
753
+ 751: 'racer, race car, racing car',
754
+ 752: 'racket, racquet',
755
+ 753: 'radiator',
756
+ 754: 'radio, wireless',
757
+ 755: 'radio telescope, radio reflector',
758
+ 756: 'rain barrel',
759
+ 757: 'recreational vehicle, RV, R.V.',
760
+ 758: 'reel',
761
+ 759: 'reflex camera',
762
+ 760: 'refrigerator, icebox',
763
+ 761: 'remote control, remote',
764
+ 762: 'restaurant, eating house, eating place, eatery',
765
+ 763: 'revolver, six-gun, six-shooter',
766
+ 764: 'rifle',
767
+ 765: 'rocking chair, rocker',
768
+ 766: 'rotisserie',
769
+ 767: 'rubber eraser, rubber, pencil eraser',
770
+ 768: 'rugby ball',
771
+ 769: 'rule, ruler',
772
+ 770: 'running shoe',
773
+ 771: 'safe',
774
+ 772: 'safety pin',
775
+ 773: 'saltshaker, salt shaker',
776
+ 774: 'sandal',
777
+ 775: 'sarong',
778
+ 776: 'sax, saxophone',
779
+ 777: 'scabbard',
780
+ 778: 'scale, weighing machine',
781
+ 779: 'school bus',
782
+ 780: 'schooner',
783
+ 781: 'scoreboard',
784
+ 782: 'screen, CRT screen',
785
+ 783: 'screw',
786
+ 784: 'screwdriver',
787
+ 785: 'seat belt, seatbelt',
788
+ 786: 'sewing machine',
789
+ 787: 'shield, buckler',
790
+ 788: 'shoe shop, shoe-shop, shoe store',
791
+ 789: 'shoji',
792
+ 790: 'shopping basket',
793
+ 791: 'shopping cart',
794
+ 792: 'shovel',
795
+ 793: 'shower cap',
796
+ 794: 'shower curtain',
797
+ 795: 'ski',
798
+ 796: 'ski mask',
799
+ 797: 'sleeping bag',
800
+ 798: 'slide rule, slipstick',
801
+ 799: 'sliding door',
802
+ 800: 'slot, one-armed bandit',
803
+ 801: 'snorkel',
804
+ 802: 'snowmobile',
805
+ 803: 'snowplow, snowplough',
806
+ 804: 'soap dispenser',
807
+ 805: 'soccer ball',
808
+ 806: 'sock',
809
+ 807: 'solar dish, solar collector, solar furnace',
810
+ 808: 'sombrero',
811
+ 809: 'soup bowl',
812
+ 810: 'space bar',
813
+ 811: 'space heater',
814
+ 812: 'space shuttle',
815
+ 813: 'spatula',
816
+ 814: 'speedboat',
817
+ 815: "spider web, spider's web",
818
+ 816: 'spindle',
819
+ 817: 'sports car, sport car',
820
+ 818: 'spotlight, spot',
821
+ 819: 'stage',
822
+ 820: 'steam locomotive',
823
+ 821: 'steel arch bridge',
824
+ 822: 'steel drum',
825
+ 823: 'stethoscope',
826
+ 824: 'stole',
827
+ 825: 'stone wall',
828
+ 826: 'stopwatch, stop watch',
829
+ 827: 'stove',
830
+ 828: 'strainer',
831
+ 829: 'streetcar, tram, tramcar, trolley, trolley car',
832
+ 830: 'stretcher',
833
+ 831: 'studio couch, day bed',
834
+ 832: 'stupa, tope',
835
+ 833: 'submarine, pigboat, sub, U-boat',
836
+ 834: 'suit, suit of clothes',
837
+ 835: 'sundial',
838
+ 836: 'sunglass',
839
+ 837: 'sunglasses, dark glasses, shades',
840
+ 838: 'sunscreen, sunblock, sun blocker',
841
+ 839: 'suspension bridge',
842
+ 840: 'swab, swob, mop',
843
+ 841: 'sweatshirt',
844
+ 842: 'swimming trunks, bathing trunks',
845
+ 843: 'swing',
846
+ 844: 'switch, electric switch, electrical switch',
847
+ 845: 'syringe',
848
+ 846: 'table lamp',
849
+ 847: 'tank, army tank, armored combat vehicle, armoured combat vehicle',
850
+ 848: 'tape player',
851
+ 849: 'teapot',
852
+ 850: 'teddy, teddy bear',
853
+ 851: 'television, television system',
854
+ 852: 'tennis ball',
855
+ 853: 'thatch, thatched roof',
856
+ 854: 'theater curtain, theatre curtain',
857
+ 855: 'thimble',
858
+ 856: 'thresher, thrasher, threshing machine',
859
+ 857: 'throne',
860
+ 858: 'tile roof',
861
+ 859: 'toaster',
862
+ 860: 'tobacco shop, tobacconist shop, tobacconist',
863
+ 861: 'toilet seat',
864
+ 862: 'torch',
865
+ 863: 'totem pole',
866
+ 864: 'tow truck, tow car, wrecker',
867
+ 865: 'toyshop',
868
+ 866: 'tractor',
869
+ 867: 'trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi',
870
+ 868: 'tray',
871
+ 869: 'trench coat',
872
+ 870: 'tricycle, trike, velocipede',
873
+ 871: 'trimaran',
874
+ 872: 'tripod',
875
+ 873: 'triumphal arch',
876
+ 874: 'trolleybus, trolley coach, trackless trolley',
877
+ 875: 'trombone',
878
+ 876: 'tub, vat',
879
+ 877: 'turnstile',
880
+ 878: 'typewriter keyboard',
881
+ 879: 'umbrella',
882
+ 880: 'unicycle, monocycle',
883
+ 881: 'upright, upright piano',
884
+ 882: 'vacuum, vacuum cleaner',
885
+ 883: 'vase',
886
+ 884: 'vault',
887
+ 885: 'velvet',
888
+ 886: 'vending machine',
889
+ 887: 'vestment',
890
+ 888: 'viaduct',
891
+ 889: 'violin, fiddle',
892
+ 890: 'volleyball',
893
+ 891: 'waffle iron',
894
+ 892: 'wall clock',
895
+ 893: 'wallet, billfold, notecase, pocketbook',
896
+ 894: 'wardrobe, closet, press',
897
+ 895: 'warplane, military plane',
898
+ 896: 'washbasin, handbasin, washbowl, lavabo, wash-hand basin',
899
+ 897: 'washer, automatic washer, washing machine',
900
+ 898: 'water bottle',
901
+ 899: 'water jug',
902
+ 900: 'water tower',
903
+ 901: 'whiskey jug',
904
+ 902: 'whistle',
905
+ 903: 'wig',
906
+ 904: 'window screen',
907
+ 905: 'window shade',
908
+ 906: 'Windsor tie',
909
+ 907: 'wine bottle',
910
+ 908: 'wing',
911
+ 909: 'wok',
912
+ 910: 'wooden spoon',
913
+ 911: 'wool, woolen, woollen',
914
+ 912: 'worm fence, snake fence, snake-rail fence, Virginia fence',
915
+ 913: 'wreck',
916
+ 914: 'yawl',
917
+ 915: 'yurt',
918
+ 916: 'web site, website, internet site, site',
919
+ 917: 'comic book',
920
+ 918: 'crossword puzzle, crossword',
921
+ 919: 'street sign',
922
+ 920: 'traffic light, traffic signal, stoplight',
923
+ 921: 'book jacket, dust cover, dust jacket, dust wrapper',
924
+ 922: 'menu',
925
+ 923: 'plate',
926
+ 924: 'guacamole',
927
+ 925: 'consomme',
928
+ 926: 'hot pot, hotpot',
929
+ 927: 'trifle',
930
+ 928: 'ice cream, icecream',
931
+ 929: 'ice lolly, lolly, lollipop, popsicle',
932
+ 930: 'French loaf',
933
+ 931: 'bagel, beigel',
934
+ 932: 'pretzel',
935
+ 933: 'cheeseburger',
936
+ 934: 'hotdog, hot dog, red hot',
937
+ 935: 'mashed potato',
938
+ 936: 'head cabbage',
939
+ 937: 'broccoli',
940
+ 938: 'cauliflower',
941
+ 939: 'zucchini, courgette',
942
+ 940: 'spaghetti squash',
943
+ 941: 'acorn squash',
944
+ 942: 'butternut squash',
945
+ 943: 'cucumber, cuke',
946
+ 944: 'artichoke, globe artichoke',
947
+ 945: 'bell pepper',
948
+ 946: 'cardoon',
949
+ 947: 'mushroom',
950
+ 948: 'Granny Smith',
951
+ 949: 'strawberry',
952
+ 950: 'orange',
953
+ 951: 'lemon',
954
+ 952: 'fig',
955
+ 953: 'pineapple, ananas',
956
+ 954: 'banana',
957
+ 955: 'jackfruit, jak, jack',
958
+ 956: 'custard apple',
959
+ 957: 'pomegranate',
960
+ 958: 'hay',
961
+ 959: 'carbonara',
962
+ 960: 'chocolate sauce, chocolate syrup',
963
+ 961: 'dough',
964
+ 962: 'meat loaf, meatloaf',
965
+ 963: 'pizza, pizza pie',
966
+ 964: 'potpie',
967
+ 965: 'burrito',
968
+ 966: 'red wine',
969
+ 967: 'espresso',
970
+ 968: 'cup',
971
+ 969: 'eggnog',
972
+ 970: 'alp',
973
+ 971: 'bubble',
974
+ 972: 'cliff, drop, drop-off',
975
+ 973: 'coral reef',
976
+ 974: 'geyser',
977
+ 975: 'lakeside, lakeshore',
978
+ 976: 'promontory, headland, head, foreland',
979
+ 977: 'sandbar, sand bar',
980
+ 978: 'seashore, coast, seacoast, sea-coast',
981
+ 979: 'valley, vale',
982
+ 980: 'volcano',
983
+ 981: 'ballplayer, baseball player',
984
+ 982: 'groom, bridegroom',
985
+ 983: 'scuba diver',
986
+ 984: 'rapeseed',
987
+ 985: 'daisy',
988
+ 986: "yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum",
989
+ 987: 'corn',
990
+ 988: 'acorn',
991
+ 989: 'hip, rose hip, rosehip',
992
+ 990: 'buckeye, horse chestnut, conker',
993
+ 991: 'coral fungus',
994
+ 992: 'agaric',
995
+ 993: 'gyromitra',
996
+ 994: 'stinkhorn, carrion fungus',
997
+ 995: 'earthstar',
998
+ 996: 'hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa',
999
+ 997: 'bolete',
1000
+ 998: 'ear, spike, capitulum',
1001
+ 999: 'toilet tissue, toilet paper, bathroom tissue'
1002
+ }
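
A small illustrative snippet (not in the original file): despite its name, CLS2IDX maps an ImageNet class index to its human-readable label, which is handy when printing predictions.

import torch
from data.imagenet_utils import CLS2IDX

logits = torch.randn(1, 1000)              # stand-in for model(image); (batch, 1000) shape assumed
top_idx = logits.argmax(dim=-1).item()
print(f'predicted class {top_idx}: {CLS2IDX[top_idx]}')
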
ViT_DeiT/data/transforms.py ADDED
@@ -0,0 +1,442 @@
1
+ from __future__ import division
2
+ import sys
3
+ import random
4
+ from PIL import Image
5
+
6
+ try:
7
+ import accimage
8
+ except ImportError:
9
+ accimage = None
10
+ import numbers
11
+ import collections
12
+
13
+ from torchvision.transforms import functional as F
14
+
15
+ if sys.version_info < (3, 3):
16
+ Sequence = collections.Sequence
17
+ Iterable = collections.Iterable
18
+ else:
19
+ Sequence = collections.abc.Sequence
20
+ Iterable = collections.abc.Iterable
21
+
22
+ _pil_interpolation_to_str = {
23
+ Image.NEAREST: 'PIL.Image.NEAREST',
24
+ Image.BILINEAR: 'PIL.Image.BILINEAR',
25
+ Image.BICUBIC: 'PIL.Image.BICUBIC',
26
+ Image.LANCZOS: 'PIL.Image.LANCZOS',
27
+ Image.HAMMING: 'PIL.Image.HAMMING',
28
+ Image.BOX: 'PIL.Image.BOX',
29
+ }
30
+
31
+
32
+ class Compose(object):
33
+ """Composes several transforms together.
34
+
35
+ Args:
36
+ transforms (list of ``Transform`` objects): list of transforms to compose.
37
+
38
+ Example:
39
+ >>> transforms.Compose([
40
+ >>> transforms.CenterCrop(10),
41
+ >>> transforms.ToTensor(),
42
+ >>> ])
43
+ """
44
+
45
+ def __init__(self, transforms):
46
+ self.transforms = transforms
47
+
48
+ def __call__(self, img, tgt):
49
+ for t in self.transforms:
50
+ img, tgt = t(img, tgt)
51
+ return img, tgt
52
+
53
+ def __repr__(self):
54
+ format_string = self.__class__.__name__ + '('
55
+ for t in self.transforms:
56
+ format_string += '\n'
57
+ format_string += ' {0}'.format(t)
58
+ format_string += '\n)'
59
+ return format_string
60
+
61
+
62
+ class Resize(object):
63
+ """Resize the input PIL Image to the given size.
64
+
65
+ Args:
66
+ size (sequence or int): Desired output size. If size is a sequence like
67
+ (h, w), output size will be matched to this. If size is an int,
68
+ smaller edge of the image will be matched to this number.
69
+ i.e., if height > width, then image will be rescaled to
70
+ (size * height / width, size)
71
+ interpolation (int, optional): Desired interpolation. Default is
72
+ ``PIL.Image.BILINEAR``
73
+ """
74
+
75
+ def __init__(self, size, interpolation=Image.BILINEAR):
76
+ assert isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)
77
+ self.size = size
78
+ self.interpolation = interpolation
79
+
80
+ def __call__(self, img, tgt):
81
+ """
82
+ Args:
83
+ img (PIL Image): Image to be scaled.
84
+
85
+ Returns:
86
+ PIL Image: Rescaled image.
87
+ """
88
+ return F.resize(img, self.size, self.interpolation), F.resize(tgt, self.size, Image.NEAREST)
89
+
90
+ def __repr__(self):
91
+ interpolate_str = _pil_interpolation_to_str[self.interpolation]
92
+ return self.__class__.__name__ + '(size={0}, interpolation={1})'.format(self.size, interpolate_str)
93
+
94
+
95
+ class CenterCrop(object):
96
+ """Crops the given PIL Image at the center.
97
+
98
+ Args:
99
+ size (sequence or int): Desired output size of the crop. If size is an
100
+ int instead of sequence like (h, w), a square crop (size, size) is
101
+ made.
102
+ """
103
+
104
+ def __init__(self, size):
105
+ if isinstance(size, numbers.Number):
106
+ self.size = (int(size), int(size))
107
+ else:
108
+ self.size = size
109
+
110
+ def __call__(self, img, tgt):
111
+ """
112
+ Args:
113
+ img (PIL Image): Image to be cropped.
114
+
115
+ Returns:
116
+ PIL Image: Cropped image.
117
+ """
118
+ return F.center_crop(img, self.size), F.center_crop(tgt, self.size)
119
+
120
+ def __repr__(self):
121
+ return self.__class__.__name__ + '(size={0})'.format(self.size)
122
+
123
+
124
+ class RandomCrop(object):
125
+ """Crop the given PIL Image at a random location.
126
+
127
+ Args:
128
+ size (sequence or int): Desired output size of the crop. If size is an
129
+ int instead of sequence like (h, w), a square crop (size, size) is
130
+ made.
131
+ padding (int or sequence, optional): Optional padding on each border
132
+ of the image. Default is None, i.e. no padding. If a sequence of length
133
+ 4 is provided, it is used to pad left, top, right, bottom borders
134
+ respectively. If a sequence of length 2 is provided, it is used to
135
+ pad left/right, top/bottom borders, respectively.
136
+ pad_if_needed (boolean): It will pad the image if smaller than the
137
+ desired size to avoid raising an exception.
138
+ fill: Pixel fill value for constant fill. Default is 0. If a tuple of
139
+ length 3, it is used to fill R, G, B channels respectively.
140
+ This value is only used when the padding_mode is constant
141
+ padding_mode: Type of padding. Should be: constant, edge, reflect or symmetric. Default is constant.
142
+
143
+ - constant: pads with a constant value, this value is specified with fill
144
+
145
+ - edge: pads with the last value on the edge of the image
146
+
147
+ - reflect: pads with reflection of image (without repeating the last value on the edge)
148
+
149
+ padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
150
+ will result in [3, 2, 1, 2, 3, 4, 3, 2]
151
+
152
+ - symmetric: pads with reflection of image (repeating the last value on the edge)
153
+
154
+ padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
155
+ will result in [2, 1, 1, 2, 3, 4, 4, 3]
156
+
157
+ """
158
+
159
+ def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode='constant'):
160
+ if isinstance(size, numbers.Number):
161
+ self.size = (int(size), int(size))
162
+ else:
163
+ self.size = size
164
+ self.padding = padding
165
+ self.pad_if_needed = pad_if_needed
166
+ self.fill = fill
167
+ self.padding_mode = padding_mode
168
+
169
+ @staticmethod
170
+ def get_params(img, output_size):
171
+ """Get parameters for ``crop`` for a random crop.
172
+
173
+ Args:
174
+ img (PIL Image): Image to be cropped.
175
+ output_size (tuple): Expected output size of the crop.
176
+
177
+ Returns:
178
+ tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.
179
+ """
180
+ w, h = img.size
181
+ th, tw = output_size
182
+ if w == tw and h == th:
183
+ return 0, 0, h, w
184
+
185
+ i = random.randint(0, h - th)
186
+ j = random.randint(0, w - tw)
187
+ return i, j, th, tw
188
+
189
+ def __call__(self, img, tgt):
190
+ """
191
+ Args:
192
+ img (PIL Image): Image to be cropped.
193
+
194
+ Returns:
195
+ PIL Image: Cropped image.
196
+ """
197
+ if self.padding is not None:
198
+ img = F.pad(img, self.padding, self.fill, self.padding_mode)
199
+ tgt = F.pad(tgt, self.padding, self.fill, self.padding_mode)
200
+
201
+ # pad the width if needed
202
+ if self.pad_if_needed and img.size[0] < self.size[1]:
203
+ img = F.pad(img, (self.size[1] - img.size[0], 0), self.fill, self.padding_mode)
204
+ tgt = F.pad(tgt, (self.size[1] - img.size[0], 0), self.fill, self.padding_mode)
205
+ # pad the height if needed
206
+ if self.pad_if_needed and img.size[1] < self.size[0]:
207
+ img = F.pad(img, (0, self.size[0] - img.size[1]), self.fill, self.padding_mode)
208
+ tgt = F.pad(tgt, (0, self.size[0] - img.size[1]), self.fill, self.padding_mode)
209
+
210
+ i, j, h, w = self.get_params(img, self.size)
211
+
212
+ return F.crop(img, i, j, h, w), F.crop(tgt, i, j, h, w)
213
+
214
+ def __repr__(self):
215
+ return self.__class__.__name__ + '(size={0}, padding={1})'.format(self.size, self.padding)
216
+
217
+
218
+ class RandomHorizontalFlip(object):
219
+ """Horizontally flip the given PIL Image randomly with a given probability.
220
+
221
+ Args:
222
+ p (float): probability of the image being flipped. Default value is 0.5
223
+ """
224
+
225
+ def __init__(self, p=0.5):
226
+ self.p = p
227
+
228
+ def __call__(self, img, tgt):
229
+ """
230
+ Args:
231
+ img (PIL Image): Image to be flipped.
232
+
233
+ Returns:
234
+ PIL Image: Randomly flipped image.
235
+ """
236
+ if random.random() < self.p:
237
+ return F.hflip(img), F.hflip(tgt)
238
+
239
+ return img, tgt
240
+
241
+ def __repr__(self):
242
+ return self.__class__.__name__ + '(p={})'.format(self.p)
243
+
244
+
245
+ class RandomVerticalFlip(object):
246
+ """Vertically flip the given PIL Image randomly with a given probability.
247
+
248
+ Args:
249
+ p (float): probability of the image being flipped. Default value is 0.5
250
+ """
251
+
252
+ def __init__(self, p=0.5):
253
+ self.p = p
254
+
255
+ def __call__(self, img, tgt):
256
+ """
257
+ Args:
258
+ img (PIL Image): Image to be flipped.
259
+
260
+ Returns:
261
+ PIL Image: Randomly flipped image.
262
+ """
263
+ if random.random() < self.p:
264
+ return F.vflip(img), F.vflip(tgt)
265
+ return img, tgt
266
+
267
+ def __repr__(self):
268
+ return self.__class__.__name__ + '(p={})'.format(self.p)
269
+
270
+
271
+ class Lambda(object):
272
+ """Apply a user-defined lambda as a transform.
273
+
274
+ Args:
275
+ lambd (function): Lambda/function to be used for transform.
276
+ """
277
+
278
+ def __init__(self, lambd):
279
+ assert callable(lambd), repr(type(lambd).__name__) + " object is not callable"
280
+ self.lambd = lambd
281
+
282
+ def __call__(self, img, tgt):
283
+ return self.lambd(img, tgt)
284
+
285
+ def __repr__(self):
286
+ return self.__class__.__name__ + '()'
287
+
288
+
289
+ class ColorJitter(object):
290
+ """Randomly change the brightness, contrast and saturation of an image.
291
+
292
+ Args:
293
+ brightness (float or tuple of float (min, max)): How much to jitter brightness.
294
+ brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]
295
+ or the given [min, max]. Should be non negative numbers.
296
+ contrast (float or tuple of float (min, max)): How much to jitter contrast.
297
+ contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]
298
+ or the given [min, max]. Should be non negative numbers.
299
+ saturation (float or tuple of float (min, max)): How much to jitter saturation.
300
+ saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]
301
+ or the given [min, max]. Should be non negative numbers.
302
+ hue (float or tuple of float (min, max)): How much to jitter hue.
303
+ hue_factor is chosen uniformly from [-hue, hue] or the given [min, max].
304
+ Should have 0 <= hue <= 0.5 or -0.5 <= min <= max <= 0.5.
305
+ """
306
+ def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
307
+ self.brightness = self._check_input(brightness, 'brightness')
308
+ self.contrast = self._check_input(contrast, 'contrast')
309
+ self.saturation = self._check_input(saturation, 'saturation')
310
+ self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5),
311
+ clip_first_on_zero=False)
312
+
313
+ def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True):
314
+ if isinstance(value, numbers.Number):
315
+ if value < 0:
316
+ raise ValueError("If {} is a single number, it must be non negative.".format(name))
317
+ value = [center - value, center + value]
318
+ if clip_first_on_zero:
319
+ value[0] = max(value[0], 0)
320
+ elif isinstance(value, (tuple, list)) and len(value) == 2:
321
+ if not bound[0] <= value[0] <= value[1] <= bound[1]:
322
+ raise ValueError("{} values should be between {}".format(name, bound))
323
+ else:
324
+ raise TypeError("{} should be a single number or a list/tuple with length 2.".format(name))
325
+
326
+ # if value is 0 or (1., 1.) for brightness/contrast/saturation
327
+ # or (0., 0.) for hue, do nothing
328
+ if value[0] == value[1] == center:
329
+ value = None
330
+ return value
331
+
332
+ @staticmethod
333
+ def get_params(brightness, contrast, saturation, hue):
334
+ """Get a randomized transform to be applied on image.
335
+
336
+ Arguments are same as that of __init__.
337
+
338
+ Returns:
339
+ Transform which randomly adjusts brightness, contrast and
340
+ saturation in a random order.
341
+ """
342
+ transforms = []
343
+
344
+ if brightness is not None:
345
+ brightness_factor = random.uniform(brightness[0], brightness[1])
346
+ transforms.append(Lambda(lambda img, tgt: (F.adjust_brightness(img, brightness_factor), tgt)))
347
+
348
+ if contrast is not None:
349
+ contrast_factor = random.uniform(contrast[0], contrast[1])
350
+ transforms.append(Lambda(lambda img, tgt: (F.adjust_contrast(img, contrast_factor), tgt)))
351
+
352
+ if saturation is not None:
353
+ saturation_factor = random.uniform(saturation[0], saturation[1])
354
+ transforms.append(Lambda(lambda img, tgt: (F.adjust_saturation(img, saturation_factor), tgt)))
355
+
356
+ if hue is not None:
357
+ hue_factor = random.uniform(hue[0], hue[1])
358
+ transforms.append(Lambda(lambda img, tgt: (F.adjust_hue(img, hue_factor), tgt)))
359
+
360
+ random.shuffle(transforms)
361
+ transform = Compose(transforms)
362
+
363
+ return transform
364
+
365
+ def __call__(self, img, tgt):
366
+ """
367
+ Args:
368
+ img (PIL Image): Input image.
369
+
370
+ Returns:
371
+ PIL Image: Color jittered image.
372
+ """
373
+ transform = self.get_params(self.brightness, self.contrast,
374
+ self.saturation, self.hue)
375
+ return transform(img, tgt)
376
+
377
+ def __repr__(self):
378
+ format_string = self.__class__.__name__ + '('
379
+ format_string += 'brightness={0}'.format(self.brightness)
380
+ format_string += ', contrast={0}'.format(self.contrast)
381
+ format_string += ', saturation={0}'.format(self.saturation)
382
+ format_string += ', hue={0})'.format(self.hue)
383
+ return format_string
384
+
385
+
386
+ class Normalize(object):
387
+ """Normalize a tensor image with mean and standard deviation.
388
+ Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels, this transform
389
+ will normalize each channel of the input ``torch.*Tensor`` i.e.
390
+ ``input[channel] = (input[channel] - mean[channel]) / std[channel]``
391
+
392
+ .. note::
393
+ This transform acts out of place, i.e., it does not mutate the input tensor.
394
+
395
+ Args:
396
+ mean (sequence): Sequence of means for each channel.
397
+ std (sequence): Sequence of standard deviations for each channel.
398
+ """
399
+
400
+ def __init__(self, mean, std, inplace=False):
401
+ self.mean = mean
402
+ self.std = std
403
+ self.inplace = inplace
404
+
405
+ def __call__(self, img, tgt):
406
+ """
407
+ Args:
408
+ img (Tensor): Tensor image of size (C, H, W) to be normalized.
409
+
410
+ Returns:
411
+ Tensor: Normalized Tensor image.
412
+ """
413
+ # return F.normalize(img, self.mean, self.std, self.inplace), tgt
414
+ return F.normalize(img, self.mean, self.std), tgt
415
+
416
+ def __repr__(self):
417
+ return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)
418
+
419
+
420
+ class ToTensor(object):
421
+ """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
422
+
423
+ Converts a PIL Image or numpy.ndarray (H x W x C) in the range
424
+ [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]
425
+ if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1)
426
+ or if the numpy.ndarray has dtype = np.uint8
427
+
428
+ In the other cases, tensors are returned without scaling.
429
+ """
430
+
431
+ def __call__(self, img, tgt):
432
+ """
433
+ Args:
434
+ img (PIL Image or numpy.ndarray): Image to be converted to tensor.
435
+
436
+ Returns:
437
+ Tensor: Converted image.
438
+ """
439
+ return F.to_tensor(img), tgt
440
+
441
+ def __repr__(self):
442
+ return self.__class__.__name__ + '()'
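The transforms above all operate on (img, tgt) pairs, so the segmentation target stays registered with the image through every augmentation. A minimal usage sketch, assuming the paired Compose defined earlier in this file, ViT_DeiT/ as the working directory, and illustrative sample paths and ImageNet normalization stats:

from PIL import Image
from data import transforms as T   # the paired transforms defined in this file

joint = T.Compose([
    T.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

img = Image.open('samples/catdog.png').convert('RGB')
tgt = Image.open('samples/catdog.png').convert('L')   # placeholder segmentation target
img_t, tgt_t = joint(img, tgt)                        # the target is passed through unchanged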
ViT_DeiT/dataset/expl_hdf5.py ADDED
@@ -0,0 +1,51 @@
1
+ import torch
2
+ from torch.utils.data import Dataset
3
+ import h5py
4
+ import os
5
+
6
+
7
+ class ImagenetResults(Dataset):
8
+ def __init__(self, path):
9
+ super(ImagenetResults, self).__init__()
10
+
11
+ self.path = os.path.join(path, 'results.hdf5')
12
+ self.data = None
13
+
14
+ print('Reading dataset length...')
15
+ with h5py.File(self.path, 'r') as f:
16
+ # tmp = h5py.File(self.path , 'r')
17
+ self.data_length = len(f['/image'])
18
+
19
+ def __len__(self):
20
+ return self.data_length
21
+
22
+ def __getitem__(self, item):
23
+ if self.data is None:
24
+ self.data = h5py.File(self.path, 'r')
25
+
26
+ image = torch.tensor(self.data['image'][item])
27
+ vis = torch.tensor(self.data['vis'][item])
28
+ target = torch.tensor(self.data['target'][item]).long()
29
+
30
+ return image, vis, target
31
+
32
+
33
+ if __name__ == '__main__':
34
+ from utils import render
35
+ import imageio
36
+ import numpy as np
37
+
38
+ ds = ImagenetResults('../visualizations/fullgrad')
39
+ sample_loader = torch.utils.data.DataLoader(
40
+ ds,
41
+ batch_size=5,
42
+ shuffle=False)
43
+
44
+ iterator = iter(sample_loader)
45
+ image, vis, target = next(iterator)
46
+
47
+ maps = (render.hm_to_rgb(vis[0].data.cpu().numpy(), scaling=3, sigma=1, cmap='seismic') * 255).astype(np.uint8)
48
+
49
+ # imageio.imsave('../delete_hm.jpg', maps)
50
+
51
+ print(len(ds))
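Note that results.hdf5 is opened lazily in __getitem__ rather than in __init__, which lets each DataLoader worker process open its own file handle instead of inheriting one across the fork. A minimal loading sketch, assuming ViT_DeiT/ as the working directory and a hypothetical results directory produced beforehand by the visualization script:

import torch
from dataset.expl_hdf5 import ImagenetResults

ds = ImagenetResults('visualizations/transformer_attribution')   # hypothetical output dir
loader = torch.utils.data.DataLoader(ds, batch_size=32, shuffle=False, num_workers=4)

images, vis, targets = next(iter(loader))
print(images.shape, vis.shape, targets.shape)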
ViT_DeiT/modules/__init__.py ADDED
File without changes
ViT_DeiT/modules/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (165 Bytes). View file
 
ViT_DeiT/modules/__pycache__/layers_ours.cpython-38.pyc ADDED
Binary file (9.94 kB). View file
 
ViT_DeiT/modules/layers_lrp.py ADDED
@@ -0,0 +1,261 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ __all__ = ['forward_hook', 'Clone', 'Add', 'Cat', 'ReLU', 'GELU', 'Dropout', 'BatchNorm2d', 'Linear', 'MaxPool2d',
6
+ 'AdaptiveAvgPool2d', 'AvgPool2d', 'Conv2d', 'Sequential', 'safe_divide', 'einsum', 'Softmax', 'IndexSelect',
7
+ 'LayerNorm', 'AddEye']
8
+
9
+
10
+ def safe_divide(a, b):
11
+ den = b.clamp(min=1e-9) + b.clamp(max=1e-9)
12
+ den = den + den.eq(0).type(den.type()) * 1e-9
13
+ return a / den * b.ne(0).type(b.type())
14
+
15
+
16
+ def forward_hook(self, input, output):
17
+ if type(input[0]) in (list, tuple):
18
+ self.X = []
19
+ for i in input[0]:
20
+ x = i.detach()
21
+ x.requires_grad = True
22
+ self.X.append(x)
23
+ else:
24
+ self.X = input[0].detach()
25
+ self.X.requires_grad = True
26
+
27
+ self.Y = output
28
+
29
+
30
+ def backward_hook(self, grad_input, grad_output):
31
+ self.grad_input = grad_input
32
+ self.grad_output = grad_output
33
+
34
+
35
+ class RelProp(nn.Module):
36
+ def __init__(self):
37
+ super(RelProp, self).__init__()
38
+ # if not self.training:
39
+ self.register_forward_hook(forward_hook)
40
+
41
+ def gradprop(self, Z, X, S):
42
+ C = torch.autograd.grad(Z, X, S, retain_graph=True)
43
+ return C
44
+
45
+ def relprop(self, R, alpha):
46
+ return R
47
+
48
+
49
+ class RelPropSimple(RelProp):
50
+ def relprop(self, R, alpha):
51
+ Z = self.forward(self.X)
52
+ S = safe_divide(R, Z)
53
+ C = self.gradprop(Z, self.X, S)
54
+
55
+ if torch.is_tensor(self.X) == False:
56
+ outputs = []
57
+ outputs.append(self.X[0] * C[0])
58
+ outputs.append(self.X[1] * C[1])
59
+ else:
60
+ outputs = self.X * (C[0])
61
+ return outputs
62
+
63
+ class AddEye(RelPropSimple):
64
+ # input of shape B, C, seq_len, seq_len
65
+ def forward(self, input):
66
+ return input + torch.eye(input.shape[2]).expand_as(input).to(input.device)
67
+
68
+ class ReLU(nn.ReLU, RelProp):
69
+ pass
70
+
71
+ class GELU(nn.GELU, RelProp):
72
+ pass
73
+
74
+ class Softmax(nn.Softmax, RelProp):
75
+ pass
76
+
77
+ class LayerNorm(nn.LayerNorm, RelProp):
78
+ pass
79
+
80
+ class Dropout(nn.Dropout, RelProp):
81
+ pass
82
+
83
+
84
+ class MaxPool2d(nn.MaxPool2d, RelPropSimple):
85
+ pass
86
+
87
+ class LayerNorm(nn.LayerNorm, RelProp):
88
+ pass
89
+
90
+ class AdaptiveAvgPool2d(nn.AdaptiveAvgPool2d, RelPropSimple):
91
+ pass
92
+
93
+
94
+ class AvgPool2d(nn.AvgPool2d, RelPropSimple):
95
+ pass
96
+
97
+
98
+ class Add(RelPropSimple):
99
+ def forward(self, inputs):
100
+ return torch.add(*inputs)
101
+
102
+ class einsum(RelPropSimple):
103
+ def __init__(self, equation):
104
+ super().__init__()
105
+ self.equation = equation
106
+ def forward(self, *operands):
107
+ return torch.einsum(self.equation, *operands)
108
+
109
+ class IndexSelect(RelProp):
110
+ def forward(self, inputs, dim, indices):
111
+ self.__setattr__('dim', dim)
112
+ self.__setattr__('indices', indices)
113
+
114
+ return torch.index_select(inputs, dim, indices)
115
+
116
+ def relprop(self, R, alpha):
117
+ Z = self.forward(self.X, self.dim, self.indices)
118
+ S = safe_divide(R, Z)
119
+ C = self.gradprop(Z, self.X, S)
120
+
121
+ if torch.is_tensor(self.X) == False:
122
+ outputs = []
123
+ outputs.append(self.X[0] * C[0])
124
+ outputs.append(self.X[1] * C[1])
125
+ else:
126
+ outputs = self.X * (C[0])
127
+ return outputs
128
+
129
+
130
+
131
+ class Clone(RelProp):
132
+ def forward(self, input, num):
133
+ self.__setattr__('num', num)
134
+ outputs = []
135
+ for _ in range(num):
136
+ outputs.append(input)
137
+
138
+ return outputs
139
+
140
+ def relprop(self, R, alpha):
141
+ Z = []
142
+ for _ in range(self.num):
143
+ Z.append(self.X)
144
+ S = [safe_divide(r, z) for r, z in zip(R, Z)]
145
+ C = self.gradprop(Z, self.X, S)[0]
146
+
147
+ R = self.X * C
148
+
149
+ return R
150
+
151
+ class Cat(RelProp):
152
+ def forward(self, inputs, dim):
153
+ self.__setattr__('dim', dim)
154
+ return torch.cat(inputs, dim)
155
+
156
+ def relprop(self, R, alpha):
157
+ Z = self.forward(self.X, self.dim)
158
+ S = safe_divide(R, Z)
159
+ C = self.gradprop(Z, self.X, S)
160
+
161
+ outputs = []
162
+ for x, c in zip(self.X, C):
163
+ outputs.append(x * c)
164
+
165
+ return outputs
166
+
167
+
168
+ class Sequential(nn.Sequential):
169
+ def relprop(self, R, alpha):
170
+ for m in reversed(self._modules.values()):
171
+ R = m.relprop(R, alpha)
172
+ return R
173
+
174
+
175
+ class BatchNorm2d(nn.BatchNorm2d, RelProp):
176
+ def relprop(self, R, alpha):
177
+ X = self.X
178
+ beta = 1 - alpha
179
+ weight = self.weight.unsqueeze(0).unsqueeze(2).unsqueeze(3) / (
180
+ (self.running_var.unsqueeze(0).unsqueeze(2).unsqueeze(3).pow(2) + self.eps).pow(0.5))
181
+ Z = X * weight + 1e-9
182
+ S = R / Z
183
+ Ca = S * weight
184
+ R = self.X * (Ca)
185
+ return R
186
+
187
+
188
+ class Linear(nn.Linear, RelProp):
189
+ def relprop(self, R, alpha):
190
+ beta = alpha - 1
191
+ pw = torch.clamp(self.weight, min=0)
192
+ nw = torch.clamp(self.weight, max=0)
193
+ px = torch.clamp(self.X, min=0)
194
+ nx = torch.clamp(self.X, max=0)
195
+
196
+ def f(w1, w2, x1, x2):
197
+ Z1 = F.linear(x1, w1)
198
+ Z2 = F.linear(x2, w2)
199
+ S1 = safe_divide(R, Z1)
200
+ S2 = safe_divide(R, Z2)
201
+ C1 = x1 * torch.autograd.grad(Z1, x1, S1)[0]
202
+ C2 = x2 * torch.autograd.grad(Z2, x2, S2)[0]
203
+
204
+ return C1 + C2
205
+
206
+ activator_relevances = f(pw, nw, px, nx)
207
+ inhibitor_relevances = f(nw, pw, px, nx)
208
+
209
+ R = alpha * activator_relevances - beta * inhibitor_relevances
210
+
211
+ return R
212
+
213
+
214
+ class Conv2d(nn.Conv2d, RelProp):
215
+ def gradprop2(self, DY, weight):
216
+ Z = self.forward(self.X)
217
+
218
+ output_padding = self.X.size()[2] - (
219
+ (Z.size()[2] - 1) * self.stride[0] - 2 * self.padding[0] + self.kernel_size[0])
220
+
221
+ return F.conv_transpose2d(DY, weight, stride=self.stride, padding=self.padding, output_padding=output_padding)
222
+
223
+ def relprop(self, R, alpha):
224
+ if self.X.shape[1] == 3:
225
+ pw = torch.clamp(self.weight, min=0)
226
+ nw = torch.clamp(self.weight, max=0)
227
+ X = self.X
228
+ L = self.X * 0 + \
229
+ torch.min(torch.min(torch.min(self.X, dim=1, keepdim=True)[0], dim=2, keepdim=True)[0], dim=3,
230
+ keepdim=True)[0]
231
+ H = self.X * 0 + \
232
+ torch.max(torch.max(torch.max(self.X, dim=1, keepdim=True)[0], dim=2, keepdim=True)[0], dim=3,
233
+ keepdim=True)[0]
234
+ Za = torch.conv2d(X, self.weight, bias=None, stride=self.stride, padding=self.padding) - \
235
+ torch.conv2d(L, pw, bias=None, stride=self.stride, padding=self.padding) - \
236
+ torch.conv2d(H, nw, bias=None, stride=self.stride, padding=self.padding) + 1e-9
237
+
238
+ S = R / Za
239
+ C = X * self.gradprop2(S, self.weight) - L * self.gradprop2(S, pw) - H * self.gradprop2(S, nw)
240
+ R = C
241
+ else:
242
+ beta = alpha - 1
243
+ pw = torch.clamp(self.weight, min=0)
244
+ nw = torch.clamp(self.weight, max=0)
245
+ px = torch.clamp(self.X, min=0)
246
+ nx = torch.clamp(self.X, max=0)
247
+
248
+ def f(w1, w2, x1, x2):
249
+ Z1 = F.conv2d(x1, w1, bias=None, stride=self.stride, padding=self.padding)
250
+ Z2 = F.conv2d(x2, w2, bias=None, stride=self.stride, padding=self.padding)
251
+ S1 = safe_divide(R, Z1)
252
+ S2 = safe_divide(R, Z2)
253
+ C1 = x1 * self.gradprop(Z1, x1, S1)[0]
254
+ C2 = x2 * self.gradprop(Z2, x2, S2)[0]
255
+ return C1 + C2
256
+
257
+ activator_relevances = f(pw, nw, px, nx)
258
+ inhibitor_relevances = f(nw, pw, px, nx)
259
+
260
+ R = alpha * activator_relevances - beta * inhibitor_relevances
261
+ return R
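Each wrapper caches its forward inputs through the registered forward_hook, so after an ordinary forward pass relevance can be pushed back layer by layer with relprop; the Sequential subclass simply walks its children in reverse. A minimal sketch of driving the alpha-beta rule on a toy stack (illustrative only, assuming ViT_DeiT/ is on the import path):

import torch
from modules.layers_lrp import Sequential, Linear, ReLU

model = Sequential(Linear(16, 32), ReLU(), Linear(32, 10))
x = torch.randn(4, 16)
logits = model(x)                          # forward hooks cache each layer's input in .X

R = torch.zeros_like(logits)               # one-hot relevance at the predicted class
R[torch.arange(4), logits.argmax(dim=1)] = 1.0
R_input = model.relprop(R, alpha=1)        # alpha=1 / beta=0 variant of the rule
print(R_input.shape)                       # torch.Size([4, 16])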
ViT_DeiT/modules/layers_ours.py ADDED
@@ -0,0 +1,280 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ __all__ = ['forward_hook', 'Clone', 'Add', 'Cat', 'ReLU', 'GELU', 'Dropout', 'BatchNorm2d', 'Linear', 'MaxPool2d',
6
+ 'AdaptiveAvgPool2d', 'AvgPool2d', 'Conv2d', 'Sequential', 'safe_divide', 'einsum', 'Softmax', 'IndexSelect',
7
+ 'LayerNorm', 'AddEye']
8
+
9
+
10
+ def safe_divide(a, b):
11
+ den = b.clamp(min=1e-9) + b.clamp(max=1e-9)
12
+ den = den + den.eq(0).type(den.type()) * 1e-9
13
+ return a / den * b.ne(0).type(b.type())
14
+
15
+
16
+ def forward_hook(self, input, output):
17
+ if type(input[0]) in (list, tuple):
18
+ self.X = []
19
+ for i in input[0]:
20
+ x = i.detach()
21
+ x.requires_grad = True
22
+ self.X.append(x)
23
+ else:
24
+ self.X = input[0].detach()
25
+ self.X.requires_grad = True
26
+
27
+ self.Y = output
28
+
29
+
30
+ def backward_hook(self, grad_input, grad_output):
31
+ self.grad_input = grad_input
32
+ self.grad_output = grad_output
33
+
34
+
35
+ class RelProp(nn.Module):
36
+ def __init__(self):
37
+ super(RelProp, self).__init__()
38
+ # if not self.training:
39
+ self.register_forward_hook(forward_hook)
40
+
41
+ def gradprop(self, Z, X, S):
42
+ C = torch.autograd.grad(Z, X, S, retain_graph=True)
43
+ return C
44
+
45
+ def relprop(self, R, alpha):
46
+ return R
47
+
48
+ class RelPropSimple(RelProp):
49
+ def relprop(self, R, alpha):
50
+ Z = self.forward(self.X)
51
+ S = safe_divide(R, Z)
52
+ C = self.gradprop(Z, self.X, S)
53
+
54
+ if torch.is_tensor(self.X) == False:
55
+ outputs = []
56
+ outputs.append(self.X[0] * C[0])
57
+ outputs.append(self.X[1] * C[1])
58
+ else:
59
+ outputs = self.X * (C[0])
60
+ return outputs
61
+
62
+ class AddEye(RelPropSimple):
63
+ # input of shape B, C, seq_len, seq_len
64
+ def forward(self, input):
65
+ return input + torch.eye(input.shape[2]).expand_as(input).to(input.device)
66
+
67
+ class ReLU(nn.ReLU, RelProp):
68
+ pass
69
+
70
+ class GELU(nn.GELU, RelProp):
71
+ pass
72
+
73
+ class Softmax(nn.Softmax, RelProp):
74
+ pass
75
+
76
+ class LayerNorm(nn.LayerNorm, RelProp):
77
+ pass
78
+
79
+ class Dropout(nn.Dropout, RelProp):
80
+ pass
81
+
82
+
83
+ class MaxPool2d(nn.MaxPool2d, RelPropSimple):
84
+ pass
85
+
86
+ class LayerNorm(nn.LayerNorm, RelProp):
87
+ pass
88
+
89
+ class AdaptiveAvgPool2d(nn.AdaptiveAvgPool2d, RelPropSimple):
90
+ pass
91
+
92
+
93
+ class AvgPool2d(nn.AvgPool2d, RelPropSimple):
94
+ pass
95
+
96
+
97
+ class Add(RelPropSimple):
98
+ def forward(self, inputs):
99
+ return torch.add(*inputs)
100
+
101
+ def relprop(self, R, alpha):
102
+ Z = self.forward(self.X)
103
+ S = safe_divide(R, Z)
104
+ C = self.gradprop(Z, self.X, S)
105
+
106
+ a = self.X[0] * C[0]
107
+ b = self.X[1] * C[1]
108
+
109
+ a_sum = a.sum()
110
+ b_sum = b.sum()
111
+
112
+ a_fact = safe_divide(a_sum.abs(), a_sum.abs() + b_sum.abs()) * R.sum()
113
+ b_fact = safe_divide(b_sum.abs(), a_sum.abs() + b_sum.abs()) * R.sum()
114
+
115
+ a = a * safe_divide(a_fact, a.sum())
116
+ b = b * safe_divide(b_fact, b.sum())
117
+
118
+ outputs = [a, b]
119
+
120
+ return outputs
121
+
122
+ class einsum(RelPropSimple):
123
+ def __init__(self, equation):
124
+ super().__init__()
125
+ self.equation = equation
126
+ def forward(self, *operands):
127
+ return torch.einsum(self.equation, *operands)
128
+
129
+ class IndexSelect(RelProp):
130
+ def forward(self, inputs, dim, indices):
131
+ self.__setattr__('dim', dim)
132
+ self.__setattr__('indices', indices)
133
+
134
+ return torch.index_select(inputs, dim, indices)
135
+
136
+ def relprop(self, R, alpha):
137
+ Z = self.forward(self.X, self.dim, self.indices)
138
+ S = safe_divide(R, Z)
139
+ C = self.gradprop(Z, self.X, S)
140
+
141
+ if torch.is_tensor(self.X) == False:
142
+ outputs = []
143
+ outputs.append(self.X[0] * C[0])
144
+ outputs.append(self.X[1] * C[1])
145
+ else:
146
+ outputs = self.X * (C[0])
147
+ return outputs
148
+
149
+
150
+
151
+ class Clone(RelProp):
152
+ def forward(self, input, num):
153
+ self.__setattr__('num', num)
154
+ outputs = []
155
+ for _ in range(num):
156
+ outputs.append(input)
157
+
158
+ return outputs
159
+
160
+ def relprop(self, R, alpha):
161
+ Z = []
162
+ for _ in range(self.num):
163
+ Z.append(self.X)
164
+ S = [safe_divide(r, z) for r, z in zip(R, Z)]
165
+ C = self.gradprop(Z, self.X, S)[0]
166
+
167
+ R = self.X * C
168
+
169
+ return R
170
+
171
+ class Cat(RelProp):
172
+ def forward(self, inputs, dim):
173
+ self.__setattr__('dim', dim)
174
+ return torch.cat(inputs, dim)
175
+
176
+ def relprop(self, R, alpha):
177
+ Z = self.forward(self.X, self.dim)
178
+ S = safe_divide(R, Z)
179
+ C = self.gradprop(Z, self.X, S)
180
+
181
+ outputs = []
182
+ for x, c in zip(self.X, C):
183
+ outputs.append(x * c)
184
+
185
+ return outputs
186
+
187
+
188
+ class Sequential(nn.Sequential):
189
+ def relprop(self, R, alpha):
190
+ for m in reversed(self._modules.values()):
191
+ R = m.relprop(R, alpha)
192
+ return R
193
+
194
+ class BatchNorm2d(nn.BatchNorm2d, RelProp):
195
+ def relprop(self, R, alpha):
196
+ X = self.X
197
+ beta = 1 - alpha
198
+ weight = self.weight.unsqueeze(0).unsqueeze(2).unsqueeze(3) / (
199
+ (self.running_var.unsqueeze(0).unsqueeze(2).unsqueeze(3).pow(2) + self.eps).pow(0.5))
200
+ Z = X * weight + 1e-9
201
+ S = R / Z
202
+ Ca = S * weight
203
+ R = self.X * (Ca)
204
+ return R
205
+
206
+
207
+ class Linear(nn.Linear, RelProp):
208
+ def relprop(self, R, alpha):
209
+ beta = alpha - 1
210
+ pw = torch.clamp(self.weight, min=0)
211
+ nw = torch.clamp(self.weight, max=0)
212
+ px = torch.clamp(self.X, min=0)
213
+ nx = torch.clamp(self.X, max=0)
214
+
215
+ def f(w1, w2, x1, x2):
216
+ Z1 = F.linear(x1, w1)
217
+ Z2 = F.linear(x2, w2)
218
+ S1 = safe_divide(R, Z1 + Z2)
219
+ S2 = safe_divide(R, Z1 + Z2)
220
+ C1 = x1 * torch.autograd.grad(Z1, x1, S1)[0]
221
+ C2 = x2 * torch.autograd.grad(Z2, x2, S2)[0]
222
+
223
+ return C1 + C2
224
+
225
+ activator_relevances = f(pw, nw, px, nx)
226
+ inhibitor_relevances = f(nw, pw, px, nx)
227
+
228
+ R = alpha * activator_relevances - beta * inhibitor_relevances
229
+
230
+ return R
231
+
232
+
233
+ class Conv2d(nn.Conv2d, RelProp):
234
+ def gradprop2(self, DY, weight):
235
+ Z = self.forward(self.X)
236
+
237
+ output_padding = self.X.size()[2] - (
238
+ (Z.size()[2] - 1) * self.stride[0] - 2 * self.padding[0] + self.kernel_size[0])
239
+
240
+ return F.conv_transpose2d(DY, weight, stride=self.stride, padding=self.padding, output_padding=output_padding)
241
+
242
+ def relprop(self, R, alpha):
243
+ if self.X.shape[1] == 3:
244
+ pw = torch.clamp(self.weight, min=0)
245
+ nw = torch.clamp(self.weight, max=0)
246
+ X = self.X
247
+ L = self.X * 0 + \
248
+ torch.min(torch.min(torch.min(self.X, dim=1, keepdim=True)[0], dim=2, keepdim=True)[0], dim=3,
249
+ keepdim=True)[0]
250
+ H = self.X * 0 + \
251
+ torch.max(torch.max(torch.max(self.X, dim=1, keepdim=True)[0], dim=2, keepdim=True)[0], dim=3,
252
+ keepdim=True)[0]
253
+ Za = torch.conv2d(X, self.weight, bias=None, stride=self.stride, padding=self.padding) - \
254
+ torch.conv2d(L, pw, bias=None, stride=self.stride, padding=self.padding) - \
255
+ torch.conv2d(H, nw, bias=None, stride=self.stride, padding=self.padding) + 1e-9
256
+
257
+ S = R / Za
258
+ C = X * self.gradprop2(S, self.weight) - L * self.gradprop2(S, pw) - H * self.gradprop2(S, nw)
259
+ R = C
260
+ else:
261
+ beta = alpha - 1
262
+ pw = torch.clamp(self.weight, min=0)
263
+ nw = torch.clamp(self.weight, max=0)
264
+ px = torch.clamp(self.X, min=0)
265
+ nx = torch.clamp(self.X, max=0)
266
+
267
+ def f(w1, w2, x1, x2):
268
+ Z1 = F.conv2d(x1, w1, bias=None, stride=self.stride, padding=self.padding)
269
+ Z2 = F.conv2d(x2, w2, bias=None, stride=self.stride, padding=self.padding)
270
+ S1 = safe_divide(R, Z1)
271
+ S2 = safe_divide(R, Z2)
272
+ C1 = x1 * self.gradprop(Z1, x1, S1)[0]
273
+ C2 = x2 * self.gradprop(Z2, x2, S2)[0]
274
+ return C1 + C2
275
+
276
+ activator_relevances = f(pw, nw, px, nx)
277
+ inhibitor_relevances = f(nw, pw, px, nx)
278
+
279
+ R = alpha * activator_relevances - beta * inhibitor_relevances
280
+ return R
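layers_ours.py mirrors layers_lrp.py but changes two propagation rules: Add renormalizes the relevance of the two summands so that the total relevance flowing through a residual connection is conserved, and Linear divides by the combined pre-activation Z1 + Z2 rather than by each term separately. A small illustrative check of the conservation property of the Add rule, assuming ViT_DeiT/ is on the import path:

import torch
from modules.layers_ours import Add

add = Add()
a, b = torch.randn(2, 5), torch.randn(2, 5)
out = add([a, b])                    # the forward hook caches both inputs in add.X
R = torch.rand(2, 5)
Ra, Rb = add.relprop(R, alpha=1)
print(torch.allclose(Ra.sum() + Rb.sum(), R.sum()))   # True: the relevance sum is preserved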
ViT_DeiT/requirements.txt ADDED
@@ -0,0 +1,15 @@
1
+ Pillow>=8.1.1
2
+ einops == 0.3.0
3
+ h5py == 2.8.0
4
+ imageio == 2.9.0
5
+ matplotlib == 3.3.2
6
+ opencv_python
7
+ scikit_image == 0.17.2
8
+ scipy == 1.5.2
9
+ sklearn
10
+ torch == 1.7.0
11
+ torchvision == 0.8.1
12
+ tqdm == 4.51.0
13
+ transformers == 3.5.1
14
+ utils == 1.0.1
15
+ Pygments>=2.7.4
ViT_DeiT/samples/CLS2IDX.py ADDED
@@ -0,0 +1,1000 @@
1
+ CLS2IDX = {0: 'tench, Tinca tinca',
2
+ 1: 'goldfish, Carassius auratus',
3
+ 2: 'great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias',
4
+ 3: 'tiger shark, Galeocerdo cuvieri',
5
+ 4: 'hammerhead, hammerhead shark',
6
+ 5: 'electric ray, crampfish, numbfish, torpedo',
7
+ 6: 'stingray',
8
+ 7: 'cock',
9
+ 8: 'hen',
10
+ 9: 'ostrich, Struthio camelus',
11
+ 10: 'brambling, Fringilla montifringilla',
12
+ 11: 'goldfinch, Carduelis carduelis',
13
+ 12: 'house finch, linnet, Carpodacus mexicanus',
14
+ 13: 'junco, snowbird',
15
+ 14: 'indigo bunting, indigo finch, indigo bird, Passerina cyanea',
16
+ 15: 'robin, American robin, Turdus migratorius',
17
+ 16: 'bulbul',
18
+ 17: 'jay',
19
+ 18: 'magpie',
20
+ 19: 'chickadee',
21
+ 20: 'water ouzel, dipper',
22
+ 21: 'kite',
23
+ 22: 'bald eagle, American eagle, Haliaeetus leucocephalus',
24
+ 23: 'vulture',
25
+ 24: 'great grey owl, great gray owl, Strix nebulosa',
26
+ 25: 'European fire salamander, Salamandra salamandra',
27
+ 26: 'common newt, Triturus vulgaris',
28
+ 27: 'eft',
29
+ 28: 'spotted salamander, Ambystoma maculatum',
30
+ 29: 'axolotl, mud puppy, Ambystoma mexicanum',
31
+ 30: 'bullfrog, Rana catesbeiana',
32
+ 31: 'tree frog, tree-frog',
33
+ 32: 'tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui',
34
+ 33: 'loggerhead, loggerhead turtle, Caretta caretta',
35
+ 34: 'leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea',
36
+ 35: 'mud turtle',
37
+ 36: 'terrapin',
38
+ 37: 'box turtle, box tortoise',
39
+ 38: 'banded gecko',
40
+ 39: 'common iguana, iguana, Iguana iguana',
41
+ 40: 'American chameleon, anole, Anolis carolinensis',
42
+ 41: 'whiptail, whiptail lizard',
43
+ 42: 'agama',
44
+ 43: 'frilled lizard, Chlamydosaurus kingi',
45
+ 44: 'alligator lizard',
46
+ 45: 'Gila monster, Heloderma suspectum',
47
+ 46: 'green lizard, Lacerta viridis',
48
+ 47: 'African chameleon, Chamaeleo chamaeleon',
49
+ 48: 'Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis',
50
+ 49: 'African crocodile, Nile crocodile, Crocodylus niloticus',
51
+ 50: 'American alligator, Alligator mississipiensis',
52
+ 51: 'triceratops',
53
+ 52: 'thunder snake, worm snake, Carphophis amoenus',
54
+ 53: 'ringneck snake, ring-necked snake, ring snake',
55
+ 54: 'hognose snake, puff adder, sand viper',
56
+ 55: 'green snake, grass snake',
57
+ 56: 'king snake, kingsnake',
58
+ 57: 'garter snake, grass snake',
59
+ 58: 'water snake',
60
+ 59: 'vine snake',
61
+ 60: 'night snake, Hypsiglena torquata',
62
+ 61: 'boa constrictor, Constrictor constrictor',
63
+ 62: 'rock python, rock snake, Python sebae',
64
+ 63: 'Indian cobra, Naja naja',
65
+ 64: 'green mamba',
66
+ 65: 'sea snake',
67
+ 66: 'horned viper, cerastes, sand viper, horned asp, Cerastes cornutus',
68
+ 67: 'diamondback, diamondback rattlesnake, Crotalus adamanteus',
69
+ 68: 'sidewinder, horned rattlesnake, Crotalus cerastes',
70
+ 69: 'trilobite',
71
+ 70: 'harvestman, daddy longlegs, Phalangium opilio',
72
+ 71: 'scorpion',
73
+ 72: 'black and gold garden spider, Argiope aurantia',
74
+ 73: 'barn spider, Araneus cavaticus',
75
+ 74: 'garden spider, Aranea diademata',
76
+ 75: 'black widow, Latrodectus mactans',
77
+ 76: 'tarantula',
78
+ 77: 'wolf spider, hunting spider',
79
+ 78: 'tick',
80
+ 79: 'centipede',
81
+ 80: 'black grouse',
82
+ 81: 'ptarmigan',
83
+ 82: 'ruffed grouse, partridge, Bonasa umbellus',
84
+ 83: 'prairie chicken, prairie grouse, prairie fowl',
85
+ 84: 'peacock',
86
+ 85: 'quail',
87
+ 86: 'partridge',
88
+ 87: 'African grey, African gray, Psittacus erithacus',
89
+ 88: 'macaw',
90
+ 89: 'sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita',
91
+ 90: 'lorikeet',
92
+ 91: 'coucal',
93
+ 92: 'bee eater',
94
+ 93: 'hornbill',
95
+ 94: 'hummingbird',
96
+ 95: 'jacamar',
97
+ 96: 'toucan',
98
+ 97: 'drake',
99
+ 98: 'red-breasted merganser, Mergus serrator',
100
+ 99: 'goose',
101
+ 100: 'black swan, Cygnus atratus',
102
+ 101: 'tusker',
103
+ 102: 'echidna, spiny anteater, anteater',
104
+ 103: 'platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus',
105
+ 104: 'wallaby, brush kangaroo',
106
+ 105: 'koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus',
107
+ 106: 'wombat',
108
+ 107: 'jellyfish',
109
+ 108: 'sea anemone, anemone',
110
+ 109: 'brain coral',
111
+ 110: 'flatworm, platyhelminth',
112
+ 111: 'nematode, nematode worm, roundworm',
113
+ 112: 'conch',
114
+ 113: 'snail',
115
+ 114: 'slug',
116
+ 115: 'sea slug, nudibranch',
117
+ 116: 'chiton, coat-of-mail shell, sea cradle, polyplacophore',
118
+ 117: 'chambered nautilus, pearly nautilus, nautilus',
119
+ 118: 'Dungeness crab, Cancer magister',
120
+ 119: 'rock crab, Cancer irroratus',
121
+ 120: 'fiddler crab',
122
+ 121: 'king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica',
123
+ 122: 'American lobster, Northern lobster, Maine lobster, Homarus americanus',
124
+ 123: 'spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish',
125
+ 124: 'crayfish, crawfish, crawdad, crawdaddy',
126
+ 125: 'hermit crab',
127
+ 126: 'isopod',
128
+ 127: 'white stork, Ciconia ciconia',
129
+ 128: 'black stork, Ciconia nigra',
130
+ 129: 'spoonbill',
131
+ 130: 'flamingo',
132
+ 131: 'little blue heron, Egretta caerulea',
133
+ 132: 'American egret, great white heron, Egretta albus',
134
+ 133: 'bittern',
135
+ 134: 'crane',
136
+ 135: 'limpkin, Aramus pictus',
137
+ 136: 'European gallinule, Porphyrio porphyrio',
138
+ 137: 'American coot, marsh hen, mud hen, water hen, Fulica americana',
139
+ 138: 'bustard',
140
+ 139: 'ruddy turnstone, Arenaria interpres',
141
+ 140: 'red-backed sandpiper, dunlin, Erolia alpina',
142
+ 141: 'redshank, Tringa totanus',
143
+ 142: 'dowitcher',
144
+ 143: 'oystercatcher, oyster catcher',
145
+ 144: 'pelican',
146
+ 145: 'king penguin, Aptenodytes patagonica',
147
+ 146: 'albatross, mollymawk',
148
+ 147: 'grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus',
149
+ 148: 'killer whale, killer, orca, grampus, sea wolf, Orcinus orca',
150
+ 149: 'dugong, Dugong dugon',
151
+ 150: 'sea lion',
152
+ 151: 'Chihuahua',
153
+ 152: 'Japanese spaniel',
154
+ 153: 'Maltese dog, Maltese terrier, Maltese',
155
+ 154: 'Pekinese, Pekingese, Peke',
156
+ 155: 'Shih-Tzu',
157
+ 156: 'Blenheim spaniel',
158
+ 157: 'papillon',
159
+ 158: 'toy terrier',
160
+ 159: 'Rhodesian ridgeback',
161
+ 160: 'Afghan hound, Afghan',
162
+ 161: 'basset, basset hound',
163
+ 162: 'beagle',
164
+ 163: 'bloodhound, sleuthhound',
165
+ 164: 'bluetick',
166
+ 165: 'black-and-tan coonhound',
167
+ 166: 'Walker hound, Walker foxhound',
168
+ 167: 'English foxhound',
169
+ 168: 'redbone',
170
+ 169: 'borzoi, Russian wolfhound',
171
+ 170: 'Irish wolfhound',
172
+ 171: 'Italian greyhound',
173
+ 172: 'whippet',
174
+ 173: 'Ibizan hound, Ibizan Podenco',
175
+ 174: 'Norwegian elkhound, elkhound',
176
+ 175: 'otterhound, otter hound',
177
+ 176: 'Saluki, gazelle hound',
178
+ 177: 'Scottish deerhound, deerhound',
179
+ 178: 'Weimaraner',
180
+ 179: 'Staffordshire bullterrier, Staffordshire bull terrier',
181
+ 180: 'American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier',
182
+ 181: 'Bedlington terrier',
183
+ 182: 'Border terrier',
184
+ 183: 'Kerry blue terrier',
185
+ 184: 'Irish terrier',
186
+ 185: 'Norfolk terrier',
187
+ 186: 'Norwich terrier',
188
+ 187: 'Yorkshire terrier',
189
+ 188: 'wire-haired fox terrier',
190
+ 189: 'Lakeland terrier',
191
+ 190: 'Sealyham terrier, Sealyham',
192
+ 191: 'Airedale, Airedale terrier',
193
+ 192: 'cairn, cairn terrier',
194
+ 193: 'Australian terrier',
195
+ 194: 'Dandie Dinmont, Dandie Dinmont terrier',
196
+ 195: 'Boston bull, Boston terrier',
197
+ 196: 'miniature schnauzer',
198
+ 197: 'giant schnauzer',
199
+ 198: 'standard schnauzer',
200
+ 199: 'Scotch terrier, Scottish terrier, Scottie',
201
+ 200: 'Tibetan terrier, chrysanthemum dog',
202
+ 201: 'silky terrier, Sydney silky',
203
+ 202: 'soft-coated wheaten terrier',
204
+ 203: 'West Highland white terrier',
205
+ 204: 'Lhasa, Lhasa apso',
206
+ 205: 'flat-coated retriever',
207
+ 206: 'curly-coated retriever',
208
+ 207: 'golden retriever',
209
+ 208: 'Labrador retriever',
210
+ 209: 'Chesapeake Bay retriever',
211
+ 210: 'German short-haired pointer',
212
+ 211: 'vizsla, Hungarian pointer',
213
+ 212: 'English setter',
214
+ 213: 'Irish setter, red setter',
215
+ 214: 'Gordon setter',
216
+ 215: 'Brittany spaniel',
217
+ 216: 'clumber, clumber spaniel',
218
+ 217: 'English springer, English springer spaniel',
219
+ 218: 'Welsh springer spaniel',
220
+ 219: 'cocker spaniel, English cocker spaniel, cocker',
221
+ 220: 'Sussex spaniel',
222
+ 221: 'Irish water spaniel',
223
+ 222: 'kuvasz',
224
+ 223: 'schipperke',
225
+ 224: 'groenendael',
226
+ 225: 'malinois',
227
+ 226: 'briard',
228
+ 227: 'kelpie',
229
+ 228: 'komondor',
230
+ 229: 'Old English sheepdog, bobtail',
231
+ 230: 'Shetland sheepdog, Shetland sheep dog, Shetland',
232
+ 231: 'collie',
233
+ 232: 'Border collie',
234
+ 233: 'Bouvier des Flandres, Bouviers des Flandres',
235
+ 234: 'Rottweiler',
236
+ 235: 'German shepherd, German shepherd dog, German police dog, alsatian',
237
+ 236: 'Doberman, Doberman pinscher',
238
+ 237: 'miniature pinscher',
239
+ 238: 'Greater Swiss Mountain dog',
240
+ 239: 'Bernese mountain dog',
241
+ 240: 'Appenzeller',
242
+ 241: 'EntleBucher',
243
+ 242: 'boxer',
244
+ 243: 'bull mastiff',
245
+ 244: 'Tibetan mastiff',
246
+ 245: 'French bulldog',
247
+ 246: 'Great Dane',
248
+ 247: 'Saint Bernard, St Bernard',
249
+ 248: 'Eskimo dog, husky',
250
+ 249: 'malamute, malemute, Alaskan malamute',
251
+ 250: 'Siberian husky',
252
+ 251: 'dalmatian, coach dog, carriage dog',
253
+ 252: 'affenpinscher, monkey pinscher, monkey dog',
254
+ 253: 'basenji',
255
+ 254: 'pug, pug-dog',
256
+ 255: 'Leonberg',
257
+ 256: 'Newfoundland, Newfoundland dog',
258
+ 257: 'Great Pyrenees',
259
+ 258: 'Samoyed, Samoyede',
260
+ 259: 'Pomeranian',
261
+ 260: 'chow, chow chow',
262
+ 261: 'keeshond',
263
+ 262: 'Brabancon griffon',
264
+ 263: 'Pembroke, Pembroke Welsh corgi',
265
+ 264: 'Cardigan, Cardigan Welsh corgi',
266
+ 265: 'toy poodle',
267
+ 266: 'miniature poodle',
268
+ 267: 'standard poodle',
269
+ 268: 'Mexican hairless',
270
+ 269: 'timber wolf, grey wolf, gray wolf, Canis lupus',
271
+ 270: 'white wolf, Arctic wolf, Canis lupus tundrarum',
272
+ 271: 'red wolf, maned wolf, Canis rufus, Canis niger',
273
+ 272: 'coyote, prairie wolf, brush wolf, Canis latrans',
274
+ 273: 'dingo, warrigal, warragal, Canis dingo',
275
+ 274: 'dhole, Cuon alpinus',
276
+ 275: 'African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus',
277
+ 276: 'hyena, hyaena',
278
+ 277: 'red fox, Vulpes vulpes',
279
+ 278: 'kit fox, Vulpes macrotis',
280
+ 279: 'Arctic fox, white fox, Alopex lagopus',
281
+ 280: 'grey fox, gray fox, Urocyon cinereoargenteus',
282
+ 281: 'tabby, tabby cat',
283
+ 282: 'tiger cat',
284
+ 283: 'Persian cat',
285
+ 284: 'Siamese cat, Siamese',
286
+ 285: 'Egyptian cat',
287
+ 286: 'cougar, puma, catamount, mountain lion, painter, panther, Felis concolor',
288
+ 287: 'lynx, catamount',
289
+ 288: 'leopard, Panthera pardus',
290
+ 289: 'snow leopard, ounce, Panthera uncia',
291
+ 290: 'jaguar, panther, Panthera onca, Felis onca',
292
+ 291: 'lion, king of beasts, Panthera leo',
293
+ 292: 'tiger, Panthera tigris',
294
+ 293: 'cheetah, chetah, Acinonyx jubatus',
295
+ 294: 'brown bear, bruin, Ursus arctos',
296
+ 295: 'American black bear, black bear, Ursus americanus, Euarctos americanus',
297
+ 296: 'ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus',
298
+ 297: 'sloth bear, Melursus ursinus, Ursus ursinus',
299
+ 298: 'mongoose',
300
+ 299: 'meerkat, mierkat',
301
+ 300: 'tiger beetle',
302
+ 301: 'ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle',
303
+ 302: 'ground beetle, carabid beetle',
304
+ 303: 'long-horned beetle, longicorn, longicorn beetle',
305
+ 304: 'leaf beetle, chrysomelid',
306
+ 305: 'dung beetle',
307
+ 306: 'rhinoceros beetle',
308
+ 307: 'weevil',
309
+ 308: 'fly',
310
+ 309: 'bee',
311
+ 310: 'ant, emmet, pismire',
312
+ 311: 'grasshopper, hopper',
313
+ 312: 'cricket',
314
+ 313: 'walking stick, walkingstick, stick insect',
315
+ 314: 'cockroach, roach',
316
+ 315: 'mantis, mantid',
317
+ 316: 'cicada, cicala',
318
+ 317: 'leafhopper',
319
+ 318: 'lacewing, lacewing fly',
320
+ 319: "dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk",
321
+ 320: 'damselfly',
322
+ 321: 'admiral',
323
+ 322: 'ringlet, ringlet butterfly',
324
+ 323: 'monarch, monarch butterfly, milkweed butterfly, Danaus plexippus',
325
+ 324: 'cabbage butterfly',
326
+ 325: 'sulphur butterfly, sulfur butterfly',
327
+ 326: 'lycaenid, lycaenid butterfly',
328
+ 327: 'starfish, sea star',
329
+ 328: 'sea urchin',
330
+ 329: 'sea cucumber, holothurian',
331
+ 330: 'wood rabbit, cottontail, cottontail rabbit',
332
+ 331: 'hare',
333
+ 332: 'Angora, Angora rabbit',
334
+ 333: 'hamster',
335
+ 334: 'porcupine, hedgehog',
336
+ 335: 'fox squirrel, eastern fox squirrel, Sciurus niger',
337
+ 336: 'marmot',
338
+ 337: 'beaver',
339
+ 338: 'guinea pig, Cavia cobaya',
340
+ 339: 'sorrel',
341
+ 340: 'zebra',
342
+ 341: 'hog, pig, grunter, squealer, Sus scrofa',
343
+ 342: 'wild boar, boar, Sus scrofa',
344
+ 343: 'warthog',
345
+ 344: 'hippopotamus, hippo, river horse, Hippopotamus amphibius',
346
+ 345: 'ox',
347
+ 346: 'water buffalo, water ox, Asiatic buffalo, Bubalus bubalis',
348
+ 347: 'bison',
349
+ 348: 'ram, tup',
350
+ 349: 'bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis',
351
+ 350: 'ibex, Capra ibex',
352
+ 351: 'hartebeest',
353
+ 352: 'impala, Aepyceros melampus',
354
+ 353: 'gazelle',
355
+ 354: 'Arabian camel, dromedary, Camelus dromedarius',
356
+ 355: 'llama',
357
+ 356: 'weasel',
358
+ 357: 'mink',
359
+ 358: 'polecat, fitch, foulmart, foumart, Mustela putorius',
360
+ 359: 'black-footed ferret, ferret, Mustela nigripes',
361
+ 360: 'otter',
362
+ 361: 'skunk, polecat, wood pussy',
363
+ 362: 'badger',
364
+ 363: 'armadillo',
365
+ 364: 'three-toed sloth, ai, Bradypus tridactylus',
366
+ 365: 'orangutan, orang, orangutang, Pongo pygmaeus',
367
+ 366: 'gorilla, Gorilla gorilla',
368
+ 367: 'chimpanzee, chimp, Pan troglodytes',
369
+ 368: 'gibbon, Hylobates lar',
370
+ 369: 'siamang, Hylobates syndactylus, Symphalangus syndactylus',
371
+ 370: 'guenon, guenon monkey',
372
+ 371: 'patas, hussar monkey, Erythrocebus patas',
373
+ 372: 'baboon',
374
+ 373: 'macaque',
375
+ 374: 'langur',
376
+ 375: 'colobus, colobus monkey',
377
+ 376: 'proboscis monkey, Nasalis larvatus',
378
+ 377: 'marmoset',
379
+ 378: 'capuchin, ringtail, Cebus capucinus',
380
+ 379: 'howler monkey, howler',
381
+ 380: 'titi, titi monkey',
382
+ 381: 'spider monkey, Ateles geoffroyi',
383
+ 382: 'squirrel monkey, Saimiri sciureus',
384
+ 383: 'Madagascar cat, ring-tailed lemur, Lemur catta',
385
+ 384: 'indri, indris, Indri indri, Indri brevicaudatus',
386
+ 385: 'Indian elephant, Elephas maximus',
387
+ 386: 'African elephant, Loxodonta africana',
388
+ 387: 'lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens',
389
+ 388: 'giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca',
390
+ 389: 'barracouta, snoek',
391
+ 390: 'eel',
392
+ 391: 'coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch',
393
+ 392: 'rock beauty, Holocanthus tricolor',
394
+ 393: 'anemone fish',
395
+ 394: 'sturgeon',
396
+ 395: 'gar, garfish, garpike, billfish, Lepisosteus osseus',
397
+ 396: 'lionfish',
398
+ 397: 'puffer, pufferfish, blowfish, globefish',
399
+ 398: 'abacus',
400
+ 399: 'abaya',
401
+ 400: "academic gown, academic robe, judge's robe",
402
+ 401: 'accordion, piano accordion, squeeze box',
403
+ 402: 'acoustic guitar',
404
+ 403: 'aircraft carrier, carrier, flattop, attack aircraft carrier',
405
+ 404: 'airliner',
406
+ 405: 'airship, dirigible',
407
+ 406: 'altar',
408
+ 407: 'ambulance',
409
+ 408: 'amphibian, amphibious vehicle',
410
+ 409: 'analog clock',
411
+ 410: 'apiary, bee house',
412
+ 411: 'apron',
413
+ 412: 'ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin',
414
+ 413: 'assault rifle, assault gun',
415
+ 414: 'backpack, back pack, knapsack, packsack, rucksack, haversack',
416
+ 415: 'bakery, bakeshop, bakehouse',
417
+ 416: 'balance beam, beam',
418
+ 417: 'balloon',
419
+ 418: 'ballpoint, ballpoint pen, ballpen, Biro',
420
+ 419: 'Band Aid',
421
+ 420: 'banjo',
422
+ 421: 'bannister, banister, balustrade, balusters, handrail',
423
+ 422: 'barbell',
424
+ 423: 'barber chair',
425
+ 424: 'barbershop',
426
+ 425: 'barn',
427
+ 426: 'barometer',
428
+ 427: 'barrel, cask',
429
+ 428: 'barrow, garden cart, lawn cart, wheelbarrow',
430
+ 429: 'baseball',
431
+ 430: 'basketball',
432
+ 431: 'bassinet',
433
+ 432: 'bassoon',
434
+ 433: 'bathing cap, swimming cap',
435
+ 434: 'bath towel',
436
+ 435: 'bathtub, bathing tub, bath, tub',
437
+ 436: 'beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon',
438
+ 437: 'beacon, lighthouse, beacon light, pharos',
439
+ 438: 'beaker',
440
+ 439: 'bearskin, busby, shako',
441
+ 440: 'beer bottle',
442
+ 441: 'beer glass',
443
+ 442: 'bell cote, bell cot',
444
+ 443: 'bib',
445
+ 444: 'bicycle-built-for-two, tandem bicycle, tandem',
446
+ 445: 'bikini, two-piece',
447
+ 446: 'binder, ring-binder',
448
+ 447: 'binoculars, field glasses, opera glasses',
449
+ 448: 'birdhouse',
450
+ 449: 'boathouse',
451
+ 450: 'bobsled, bobsleigh, bob',
452
+ 451: 'bolo tie, bolo, bola tie, bola',
453
+ 452: 'bonnet, poke bonnet',
454
+ 453: 'bookcase',
455
+ 454: 'bookshop, bookstore, bookstall',
456
+ 455: 'bottlecap',
457
+ 456: 'bow',
458
+ 457: 'bow tie, bow-tie, bowtie',
459
+ 458: 'brass, memorial tablet, plaque',
460
+ 459: 'brassiere, bra, bandeau',
461
+ 460: 'breakwater, groin, groyne, mole, bulwark, seawall, jetty',
462
+ 461: 'breastplate, aegis, egis',
463
+ 462: 'broom',
464
+ 463: 'bucket, pail',
465
+ 464: 'buckle',
466
+ 465: 'bulletproof vest',
467
+ 466: 'bullet train, bullet',
468
+ 467: 'butcher shop, meat market',
469
+ 468: 'cab, hack, taxi, taxicab',
470
+ 469: 'caldron, cauldron',
471
+ 470: 'candle, taper, wax light',
472
+ 471: 'cannon',
473
+ 472: 'canoe',
474
+ 473: 'can opener, tin opener',
475
+ 474: 'cardigan',
476
+ 475: 'car mirror',
477
+ 476: 'carousel, carrousel, merry-go-round, roundabout, whirligig',
478
+ 477: "carpenter's kit, tool kit",
479
+ 478: 'carton',
480
+ 479: 'car wheel',
481
+ 480: 'cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM',
482
+ 481: 'cassette',
483
+ 482: 'cassette player',
484
+ 483: 'castle',
485
+ 484: 'catamaran',
486
+ 485: 'CD player',
487
+ 486: 'cello, violoncello',
488
+ 487: 'cellular telephone, cellular phone, cellphone, cell, mobile phone',
489
+ 488: 'chain',
490
+ 489: 'chainlink fence',
491
+ 490: 'chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour',
492
+ 491: 'chain saw, chainsaw',
493
+ 492: 'chest',
494
+ 493: 'chiffonier, commode',
495
+ 494: 'chime, bell, gong',
496
+ 495: 'china cabinet, china closet',
497
+ 496: 'Christmas stocking',
498
+ 497: 'church, church building',
499
+ 498: 'cinema, movie theater, movie theatre, movie house, picture palace',
500
+ 499: 'cleaver, meat cleaver, chopper',
501
+ 500: 'cliff dwelling',
502
+ 501: 'cloak',
503
+ 502: 'clog, geta, patten, sabot',
504
+ 503: 'cocktail shaker',
505
+ 504: 'coffee mug',
506
+ 505: 'coffeepot',
507
+ 506: 'coil, spiral, volute, whorl, helix',
508
+ 507: 'combination lock',
509
+ 508: 'computer keyboard, keypad',
510
+ 509: 'confectionery, confectionary, candy store',
511
+ 510: 'container ship, containership, container vessel',
512
+ 511: 'convertible',
513
+ 512: 'corkscrew, bottle screw',
514
+ 513: 'cornet, horn, trumpet, trump',
515
+ 514: 'cowboy boot',
516
+ 515: 'cowboy hat, ten-gallon hat',
517
+ 516: 'cradle',
518
+ 517: 'crane',
519
+ 518: 'crash helmet',
520
+ 519: 'crate',
521
+ 520: 'crib, cot',
522
+ 521: 'Crock Pot',
523
+ 522: 'croquet ball',
524
+ 523: 'crutch',
525
+ 524: 'cuirass',
526
+ 525: 'dam, dike, dyke',
527
+ 526: 'desk',
528
+ 527: 'desktop computer',
529
+ 528: 'dial telephone, dial phone',
530
+ 529: 'diaper, nappy, napkin',
531
+ 530: 'digital clock',
532
+ 531: 'digital watch',
533
+ 532: 'dining table, board',
534
+ 533: 'dishrag, dishcloth',
535
+ 534: 'dishwasher, dish washer, dishwashing machine',
536
+ 535: 'disk brake, disc brake',
537
+ 536: 'dock, dockage, docking facility',
538
+ 537: 'dogsled, dog sled, dog sleigh',
539
+ 538: 'dome',
540
+ 539: 'doormat, welcome mat',
541
+ 540: 'drilling platform, offshore rig',
542
+ 541: 'drum, membranophone, tympan',
543
+ 542: 'drumstick',
544
+ 543: 'dumbbell',
545
+ 544: 'Dutch oven',
546
+ 545: 'electric fan, blower',
547
+ 546: 'electric guitar',
548
+ 547: 'electric locomotive',
549
+ 548: 'entertainment center',
550
+ 549: 'envelope',
551
+ 550: 'espresso maker',
552
+ 551: 'face powder',
553
+ 552: 'feather boa, boa',
554
+ 553: 'file, file cabinet, filing cabinet',
555
+ 554: 'fireboat',
556
+ 555: 'fire engine, fire truck',
557
+ 556: 'fire screen, fireguard',
558
+ 557: 'flagpole, flagstaff',
559
+ 558: 'flute, transverse flute',
560
+ 559: 'folding chair',
561
+ 560: 'football helmet',
562
+ 561: 'forklift',
563
+ 562: 'fountain',
564
+ 563: 'fountain pen',
565
+ 564: 'four-poster',
566
+ 565: 'freight car',
567
+ 566: 'French horn, horn',
568
+ 567: 'frying pan, frypan, skillet',
569
+ 568: 'fur coat',
570
+ 569: 'garbage truck, dustcart',
571
+ 570: 'gasmask, respirator, gas helmet',
572
+ 571: 'gas pump, gasoline pump, petrol pump, island dispenser',
573
+ 572: 'goblet',
574
+ 573: 'go-kart',
575
+ 574: 'golf ball',
576
+ 575: 'golfcart, golf cart',
577
+ 576: 'gondola',
578
+ 577: 'gong, tam-tam',
579
+ 578: 'gown',
580
+ 579: 'grand piano, grand',
581
+ 580: 'greenhouse, nursery, glasshouse',
582
+ 581: 'grille, radiator grille',
583
+ 582: 'grocery store, grocery, food market, market',
584
+ 583: 'guillotine',
585
+ 584: 'hair slide',
586
+ 585: 'hair spray',
587
+ 586: 'half track',
588
+ 587: 'hammer',
589
+ 588: 'hamper',
590
+ 589: 'hand blower, blow dryer, blow drier, hair dryer, hair drier',
591
+ 590: 'hand-held computer, hand-held microcomputer',
592
+ 591: 'handkerchief, hankie, hanky, hankey',
593
+ 592: 'hard disc, hard disk, fixed disk',
594
+ 593: 'harmonica, mouth organ, harp, mouth harp',
595
+ 594: 'harp',
596
+ 595: 'harvester, reaper',
597
+ 596: 'hatchet',
598
+ 597: 'holster',
599
+ 598: 'home theater, home theatre',
600
+ 599: 'honeycomb',
601
+ 600: 'hook, claw',
602
+ 601: 'hoopskirt, crinoline',
603
+ 602: 'horizontal bar, high bar',
604
+ 603: 'horse cart, horse-cart',
605
+ 604: 'hourglass',
606
+ 605: 'iPod',
607
+ 606: 'iron, smoothing iron',
608
+ 607: "jack-o'-lantern",
609
+ 608: 'jean, blue jean, denim',
610
+ 609: 'jeep, landrover',
611
+ 610: 'jersey, T-shirt, tee shirt',
612
+ 611: 'jigsaw puzzle',
613
+ 612: 'jinrikisha, ricksha, rickshaw',
614
+ 613: 'joystick',
615
+ 614: 'kimono',
616
+ 615: 'knee pad',
617
+ 616: 'knot',
618
+ 617: 'lab coat, laboratory coat',
619
+ 618: 'ladle',
620
+ 619: 'lampshade, lamp shade',
621
+ 620: 'laptop, laptop computer',
622
+ 621: 'lawn mower, mower',
623
+ 622: 'lens cap, lens cover',
624
+ 623: 'letter opener, paper knife, paperknife',
625
+ 624: 'library',
626
+ 625: 'lifeboat',
627
+ 626: 'lighter, light, igniter, ignitor',
628
+ 627: 'limousine, limo',
629
+ 628: 'liner, ocean liner',
630
+ 629: 'lipstick, lip rouge',
631
+ 630: 'Loafer',
632
+ 631: 'lotion',
633
+ 632: 'loudspeaker, speaker, speaker unit, loudspeaker system, speaker system',
634
+ 633: "loupe, jeweler's loupe",
635
+ 634: 'lumbermill, sawmill',
636
+ 635: 'magnetic compass',
637
+ 636: 'mailbag, postbag',
638
+ 637: 'mailbox, letter box',
639
+ 638: 'maillot',
640
+ 639: 'maillot, tank suit',
641
+ 640: 'manhole cover',
642
+ 641: 'maraca',
643
+ 642: 'marimba, xylophone',
644
+ 643: 'mask',
645
+ 644: 'matchstick',
646
+ 645: 'maypole',
647
+ 646: 'maze, labyrinth',
648
+ 647: 'measuring cup',
649
+ 648: 'medicine chest, medicine cabinet',
650
+ 649: 'megalith, megalithic structure',
651
+ 650: 'microphone, mike',
652
+ 651: 'microwave, microwave oven',
653
+ 652: 'military uniform',
654
+ 653: 'milk can',
655
+ 654: 'minibus',
656
+ 655: 'miniskirt, mini',
657
+ 656: 'minivan',
658
+ 657: 'missile',
659
+ 658: 'mitten',
660
+ 659: 'mixing bowl',
661
+ 660: 'mobile home, manufactured home',
662
+ 661: 'Model T',
663
+ 662: 'modem',
664
+ 663: 'monastery',
665
+ 664: 'monitor',
666
+ 665: 'moped',
667
+ 666: 'mortar',
668
+ 667: 'mortarboard',
669
+ 668: 'mosque',
670
+ 669: 'mosquito net',
671
+ 670: 'motor scooter, scooter',
672
+ 671: 'mountain bike, all-terrain bike, off-roader',
673
+ 672: 'mountain tent',
674
+ 673: 'mouse, computer mouse',
675
+ 674: 'mousetrap',
676
+ 675: 'moving van',
677
+ 676: 'muzzle',
678
+ 677: 'nail',
679
+ 678: 'neck brace',
680
+ 679: 'necklace',
681
+ 680: 'nipple',
682
+ 681: 'notebook, notebook computer',
683
+ 682: 'obelisk',
684
+ 683: 'oboe, hautboy, hautbois',
685
+ 684: 'ocarina, sweet potato',
686
+ 685: 'odometer, hodometer, mileometer, milometer',
687
+ 686: 'oil filter',
688
+ 687: 'organ, pipe organ',
689
+ 688: 'oscilloscope, scope, cathode-ray oscilloscope, CRO',
690
+ 689: 'overskirt',
691
+ 690: 'oxcart',
692
+ 691: 'oxygen mask',
693
+ 692: 'packet',
694
+ 693: 'paddle, boat paddle',
695
+ 694: 'paddlewheel, paddle wheel',
696
+ 695: 'padlock',
697
+ 696: 'paintbrush',
698
+ 697: "pajama, pyjama, pj's, jammies",
699
+ 698: 'palace',
700
+ 699: 'panpipe, pandean pipe, syrinx',
701
+ 700: 'paper towel',
702
+ 701: 'parachute, chute',
703
+ 702: 'parallel bars, bars',
704
+ 703: 'park bench',
705
+ 704: 'parking meter',
706
+ 705: 'passenger car, coach, carriage',
707
+ 706: 'patio, terrace',
708
+ 707: 'pay-phone, pay-station',
709
+ 708: 'pedestal, plinth, footstall',
710
+ 709: 'pencil box, pencil case',
711
+ 710: 'pencil sharpener',
712
+ 711: 'perfume, essence',
713
+ 712: 'Petri dish',
714
+ 713: 'photocopier',
715
+ 714: 'pick, plectrum, plectron',
716
+ 715: 'pickelhaube',
717
+ 716: 'picket fence, paling',
718
+ 717: 'pickup, pickup truck',
719
+ 718: 'pier',
720
+ 719: 'piggy bank, penny bank',
721
+ 720: 'pill bottle',
722
+ 721: 'pillow',
723
+ 722: 'ping-pong ball',
724
+ 723: 'pinwheel',
725
+ 724: 'pirate, pirate ship',
726
+ 725: 'pitcher, ewer',
727
+ 726: "plane, carpenter's plane, woodworking plane",
728
+ 727: 'planetarium',
729
+ 728: 'plastic bag',
730
+ 729: 'plate rack',
731
+ 730: 'plow, plough',
732
+ 731: "plunger, plumber's helper",
733
+ 732: 'Polaroid camera, Polaroid Land camera',
734
+ 733: 'pole',
735
+ 734: 'police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria',
736
+ 735: 'poncho',
737
+ 736: 'pool table, billiard table, snooker table',
738
+ 737: 'pop bottle, soda bottle',
739
+ 738: 'pot, flowerpot',
740
+ 739: "potter's wheel",
741
+ 740: 'power drill',
742
+ 741: 'prayer rug, prayer mat',
743
+ 742: 'printer',
744
+ 743: 'prison, prison house',
745
+ 744: 'projectile, missile',
746
+ 745: 'projector',
747
+ 746: 'puck, hockey puck',
748
+ 747: 'punching bag, punch bag, punching ball, punchball',
749
+ 748: 'purse',
750
+ 749: 'quill, quill pen',
751
+ 750: 'quilt, comforter, comfort, puff',
752
+ 751: 'racer, race car, racing car',
753
+ 752: 'racket, racquet',
754
+ 753: 'radiator',
755
+ 754: 'radio, wireless',
756
+ 755: 'radio telescope, radio reflector',
757
+ 756: 'rain barrel',
758
+ 757: 'recreational vehicle, RV, R.V.',
759
+ 758: 'reel',
760
+ 759: 'reflex camera',
761
+ 760: 'refrigerator, icebox',
762
+ 761: 'remote control, remote',
763
+ 762: 'restaurant, eating house, eating place, eatery',
764
+ 763: 'revolver, six-gun, six-shooter',
765
+ 764: 'rifle',
766
+ 765: 'rocking chair, rocker',
767
+ 766: 'rotisserie',
768
+ 767: 'rubber eraser, rubber, pencil eraser',
769
+ 768: 'rugby ball',
770
+ 769: 'rule, ruler',
771
+ 770: 'running shoe',
772
+ 771: 'safe',
773
+ 772: 'safety pin',
774
+ 773: 'saltshaker, salt shaker',
775
+ 774: 'sandal',
776
+ 775: 'sarong',
777
+ 776: 'sax, saxophone',
778
+ 777: 'scabbard',
779
+ 778: 'scale, weighing machine',
780
+ 779: 'school bus',
781
+ 780: 'schooner',
782
+ 781: 'scoreboard',
783
+ 782: 'screen, CRT screen',
784
+ 783: 'screw',
785
+ 784: 'screwdriver',
786
+ 785: 'seat belt, seatbelt',
787
+ 786: 'sewing machine',
788
+ 787: 'shield, buckler',
789
+ 788: 'shoe shop, shoe-shop, shoe store',
790
+ 789: 'shoji',
791
+ 790: 'shopping basket',
792
+ 791: 'shopping cart',
793
+ 792: 'shovel',
794
+ 793: 'shower cap',
795
+ 794: 'shower curtain',
796
+ 795: 'ski',
797
+ 796: 'ski mask',
798
+ 797: 'sleeping bag',
799
+ 798: 'slide rule, slipstick',
800
+ 799: 'sliding door',
801
+ 800: 'slot, one-armed bandit',
802
+ 801: 'snorkel',
803
+ 802: 'snowmobile',
804
+ 803: 'snowplow, snowplough',
805
+ 804: 'soap dispenser',
806
+ 805: 'soccer ball',
807
+ 806: 'sock',
808
+ 807: 'solar dish, solar collector, solar furnace',
809
+ 808: 'sombrero',
810
+ 809: 'soup bowl',
811
+ 810: 'space bar',
812
+ 811: 'space heater',
813
+ 812: 'space shuttle',
814
+ 813: 'spatula',
815
+ 814: 'speedboat',
816
+ 815: "spider web, spider's web",
817
+ 816: 'spindle',
818
+ 817: 'sports car, sport car',
819
+ 818: 'spotlight, spot',
820
+ 819: 'stage',
821
+ 820: 'steam locomotive',
822
+ 821: 'steel arch bridge',
823
+ 822: 'steel drum',
824
+ 823: 'stethoscope',
825
+ 824: 'stole',
826
+ 825: 'stone wall',
827
+ 826: 'stopwatch, stop watch',
828
+ 827: 'stove',
829
+ 828: 'strainer',
830
+ 829: 'streetcar, tram, tramcar, trolley, trolley car',
831
+ 830: 'stretcher',
832
+ 831: 'studio couch, day bed',
833
+ 832: 'stupa, tope',
834
+ 833: 'submarine, pigboat, sub, U-boat',
835
+ 834: 'suit, suit of clothes',
836
+ 835: 'sundial',
837
+ 836: 'sunglass',
838
+ 837: 'sunglasses, dark glasses, shades',
839
+ 838: 'sunscreen, sunblock, sun blocker',
840
+ 839: 'suspension bridge',
841
+ 840: 'swab, swob, mop',
842
+ 841: 'sweatshirt',
843
+ 842: 'swimming trunks, bathing trunks',
844
+ 843: 'swing',
845
+ 844: 'switch, electric switch, electrical switch',
846
+ 845: 'syringe',
847
+ 846: 'table lamp',
848
+ 847: 'tank, army tank, armored combat vehicle, armoured combat vehicle',
849
+ 848: 'tape player',
850
+ 849: 'teapot',
851
+ 850: 'teddy, teddy bear',
852
+ 851: 'television, television system',
853
+ 852: 'tennis ball',
854
+ 853: 'thatch, thatched roof',
855
+ 854: 'theater curtain, theatre curtain',
856
+ 855: 'thimble',
857
+ 856: 'thresher, thrasher, threshing machine',
858
+ 857: 'throne',
859
+ 858: 'tile roof',
860
+ 859: 'toaster',
861
+ 860: 'tobacco shop, tobacconist shop, tobacconist',
862
+ 861: 'toilet seat',
863
+ 862: 'torch',
864
+ 863: 'totem pole',
865
+ 864: 'tow truck, tow car, wrecker',
866
+ 865: 'toyshop',
867
+ 866: 'tractor',
868
+ 867: 'trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi',
869
+ 868: 'tray',
870
+ 869: 'trench coat',
871
+ 870: 'tricycle, trike, velocipede',
872
+ 871: 'trimaran',
873
+ 872: 'tripod',
874
+ 873: 'triumphal arch',
875
+ 874: 'trolleybus, trolley coach, trackless trolley',
876
+ 875: 'trombone',
877
+ 876: 'tub, vat',
878
+ 877: 'turnstile',
879
+ 878: 'typewriter keyboard',
880
+ 879: 'umbrella',
881
+ 880: 'unicycle, monocycle',
882
+ 881: 'upright, upright piano',
883
+ 882: 'vacuum, vacuum cleaner',
884
+ 883: 'vase',
885
+ 884: 'vault',
886
+ 885: 'velvet',
887
+ 886: 'vending machine',
888
+ 887: 'vestment',
889
+ 888: 'viaduct',
890
+ 889: 'violin, fiddle',
891
+ 890: 'volleyball',
892
+ 891: 'waffle iron',
893
+ 892: 'wall clock',
894
+ 893: 'wallet, billfold, notecase, pocketbook',
895
+ 894: 'wardrobe, closet, press',
896
+ 895: 'warplane, military plane',
897
+ 896: 'washbasin, handbasin, washbowl, lavabo, wash-hand basin',
898
+ 897: 'washer, automatic washer, washing machine',
899
+ 898: 'water bottle',
900
+ 899: 'water jug',
901
+ 900: 'water tower',
902
+ 901: 'whiskey jug',
903
+ 902: 'whistle',
904
+ 903: 'wig',
905
+ 904: 'window screen',
906
+ 905: 'window shade',
907
+ 906: 'Windsor tie',
908
+ 907: 'wine bottle',
909
+ 908: 'wing',
910
+ 909: 'wok',
911
+ 910: 'wooden spoon',
912
+ 911: 'wool, woolen, woollen',
913
+ 912: 'worm fence, snake fence, snake-rail fence, Virginia fence',
914
+ 913: 'wreck',
915
+ 914: 'yawl',
916
+ 915: 'yurt',
917
+ 916: 'web site, website, internet site, site',
918
+ 917: 'comic book',
919
+ 918: 'crossword puzzle, crossword',
920
+ 919: 'street sign',
921
+ 920: 'traffic light, traffic signal, stoplight',
922
+ 921: 'book jacket, dust cover, dust jacket, dust wrapper',
923
+ 922: 'menu',
924
+ 923: 'plate',
925
+ 924: 'guacamole',
926
+ 925: 'consomme',
927
+ 926: 'hot pot, hotpot',
928
+ 927: 'trifle',
929
+ 928: 'ice cream, icecream',
930
+ 929: 'ice lolly, lolly, lollipop, popsicle',
931
+ 930: 'French loaf',
932
+ 931: 'bagel, beigel',
933
+ 932: 'pretzel',
934
+ 933: 'cheeseburger',
935
+ 934: 'hotdog, hot dog, red hot',
936
+ 935: 'mashed potato',
937
+ 936: 'head cabbage',
938
+ 937: 'broccoli',
939
+ 938: 'cauliflower',
940
+ 939: 'zucchini, courgette',
941
+ 940: 'spaghetti squash',
942
+ 941: 'acorn squash',
943
+ 942: 'butternut squash',
944
+ 943: 'cucumber, cuke',
945
+ 944: 'artichoke, globe artichoke',
946
+ 945: 'bell pepper',
947
+ 946: 'cardoon',
948
+ 947: 'mushroom',
949
+ 948: 'Granny Smith',
950
+ 949: 'strawberry',
951
+ 950: 'orange',
952
+ 951: 'lemon',
953
+ 952: 'fig',
954
+ 953: 'pineapple, ananas',
955
+ 954: 'banana',
956
+ 955: 'jackfruit, jak, jack',
957
+ 956: 'custard apple',
958
+ 957: 'pomegranate',
959
+ 958: 'hay',
960
+ 959: 'carbonara',
961
+ 960: 'chocolate sauce, chocolate syrup',
962
+ 961: 'dough',
963
+ 962: 'meat loaf, meatloaf',
964
+ 963: 'pizza, pizza pie',
965
+ 964: 'potpie',
966
+ 965: 'burrito',
967
+ 966: 'red wine',
968
+ 967: 'espresso',
969
+ 968: 'cup',
970
+ 969: 'eggnog',
971
+ 970: 'alp',
972
+ 971: 'bubble',
973
+ 972: 'cliff, drop, drop-off',
974
+ 973: 'coral reef',
975
+ 974: 'geyser',
976
+ 975: 'lakeside, lakeshore',
977
+ 976: 'promontory, headland, head, foreland',
978
+ 977: 'sandbar, sand bar',
979
+ 978: 'seashore, coast, seacoast, sea-coast',
980
+ 979: 'valley, vale',
981
+ 980: 'volcano',
982
+ 981: 'ballplayer, baseball player',
983
+ 982: 'groom, bridegroom',
984
+ 983: 'scuba diver',
985
+ 984: 'rapeseed',
986
+ 985: 'daisy',
987
+ 986: "yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum",
988
+ 987: 'corn',
989
+ 988: 'acorn',
990
+ 989: 'hip, rose hip, rosehip',
991
+ 990: 'buckeye, horse chestnut, conker',
992
+ 991: 'coral fungus',
993
+ 992: 'agaric',
994
+ 993: 'gyromitra',
995
+ 994: 'stinkhorn, carrion fungus',
996
+ 995: 'earthstar',
997
+ 996: 'hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa',
998
+ 997: 'bolete',
999
+ 998: 'ear, spike, capitulum',
1000
+ 999: 'toilet tissue, toilet paper, bathroom tissue'}
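A minimal sketch of how this index-to-label mapping is typically consumed (the classifier, preprocessing, and import path here are illustrative assumptions, not part of this commit):

    # Hypothetical usage: turn a classifier's top-1 logit index into a readable label.
    import torch
    from samples.CLS2IDX import CLS2IDX  # import path assumed from this repo layout

    def top1_label(logits: torch.Tensor) -> str:
        # logits: tensor of shape (1, 1000) from any ImageNet-1k classifier
        idx = logits.argmax(dim=-1).item()
        return CLS2IDX[idx]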
ViT_DeiT/samples/__pycache__/CLS2IDX.cpython-38.pyc ADDED
Binary file (33.4 kB).
ViT_DeiT/samples/catdog.png ADDED
ViT_DeiT/samples/dogbird.png ADDED
ViT_DeiT/samples/dogcat2.png ADDED
ViT_DeiT/samples/el1.png ADDED
ViT_DeiT/samples/el2.png ADDED
ViT_DeiT/samples/el3.png ADDED
ViT_DeiT/samples/el4.png ADDED
ViT_DeiT/samples/el5.png ADDED
ViT_DeiT/utils/__init__.py ADDED
File without changes
ViT_DeiT/utils/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (163 Bytes).
ViT_DeiT/utils/confusionmatrix.py ADDED
@@ -0,0 +1,88 @@
1
+ import numpy as np
2
+ import torch
3
+ from . import metric
4
+
5
+
6
+ class ConfusionMatrix(metric.Metric):
7
+ """Constructs a confusion matrix for a multi-class classification problems.
8
+ Does not support multi-label, multi-class problems.
9
+ Keyword arguments:
10
+ - num_classes (int): number of classes in the classification problem.
11
+ - normalized (boolean, optional): Determines whether or not the confusion
12
+ matrix is normalized. Default: False.
13
+ Modified from: https://github.com/pytorch/tnt/blob/master/torchnet/meter/confusionmeter.py
14
+ """
15
+
16
+ def __init__(self, num_classes, normalized=False):
17
+ super().__init__()
18
+
19
+ self.conf = np.ndarray((num_classes, num_classes), dtype=np.int32)
20
+ self.normalized = normalized
21
+ self.num_classes = num_classes
22
+ self.reset()
23
+
24
+ def reset(self):
25
+ self.conf.fill(0)
26
+
27
+ def add(self, predicted, target):
28
+ """Computes the confusion matrix
29
+ The shape of the confusion matrix is K x K, where K is the number
30
+ of classes.
31
+ Keyword arguments:
32
+ - predicted (Tensor or numpy.ndarray): Can be an N x K tensor/array of
33
+ predicted scores obtained from the model for N examples and K classes,
34
+ or an N-tensor/array of integer values between 0 and K-1.
35
+ - target (Tensor or numpy.ndarray): Can be an N x K tensor/array of
36
+ ground-truth classes for N examples and K classes, or an N-tensor/array
37
+ of integer values between 0 and K-1.
38
+ """
39
+ # If target and/or predicted are tensors, convert them to numpy arrays
40
+ if torch.is_tensor(predicted):
41
+ predicted = predicted.cpu().numpy()
42
+ if torch.is_tensor(target):
43
+ target = target.cpu().numpy()
44
+
45
+ assert predicted.shape[0] == target.shape[0], \
46
+ 'number of targets and predicted outputs do not match'
47
+
48
+ if np.ndim(predicted) != 1:
49
+ assert predicted.shape[1] == self.num_classes, \
50
+ 'number of predictions does not match size of confusion matrix'
51
+ predicted = np.argmax(predicted, 1)
52
+ else:
53
+ assert (predicted.max() < self.num_classes) and (predicted.min() >= 0), \
54
+ 'predicted values are not between 0 and k-1'
55
+
56
+ if np.ndim(target) != 1:
57
+ assert target.shape[1] == self.num_classes, \
58
+ 'Onehot target does not match size of confusion matrix'
59
+ assert (target >= 0).all() and (target <= 1).all(), \
60
+ 'in one-hot encoding, target values should be 0 or 1'
61
+ assert (target.sum(1) == 1).all(), \
62
+ 'multi-label setting is not supported'
63
+ target = np.argmax(target, 1)
64
+ else:
65
+ assert (target.max() < self.num_classes) and (target.min() >= 0), \
66
+ 'target values are not between 0 and k-1'
67
+
68
+ # hack for bincounting 2 arrays together
69
+ x = predicted + self.num_classes * target
70
+ bincount_2d = np.bincount(
71
+ x.astype(np.int32), minlength=self.num_classes**2)
72
+ assert bincount_2d.size == self.num_classes**2
73
+ conf = bincount_2d.reshape((self.num_classes, self.num_classes))
74
+
75
+ self.conf += conf
76
+
77
+ def value(self):
78
+ """
79
+ Returns:
80
+ Confusion matrix of K rows and K columns, where rows correspond
81
+ to ground-truth targets and columns correspond to predicted
82
+ targets.
83
+ """
84
+ if self.normalized:
85
+ conf = self.conf.astype(np.float32)
86
+ return conf / conf.sum(1).clip(min=1e-12)[:, None]
87
+ else:
88
+ return self.conf
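A minimal usage sketch for the ConfusionMatrix metric above (the data and import path are illustrative assumptions):

    import torch
    from utils.confusionmatrix import ConfusionMatrix  # import path assumed

    cm = ConfusionMatrix(num_classes=3)

    # N x K score tensors are argmax-ed internally; N-tensors of integer labels also work.
    scores = torch.tensor([[0.9, 0.05, 0.05],
                           [0.1, 0.80, 0.10],
                           [0.2, 0.20, 0.60]])
    targets = torch.tensor([0, 1, 2])
    cm.add(scores, targets)
    cm.add(torch.tensor([2, 1]), torch.tensor([0, 1]))  # integer predictions

    conf = cm.value()  # 3 x 3 numpy array; rows are ground truth, columns are predictions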
ViT_DeiT/utils/iou.py ADDED
@@ -0,0 +1,93 @@
1
+ import torch
2
+ import numpy as np
3
+ from . import metric
4
+ from .confusionmatrix import ConfusionMatrix
5
+
6
+
7
+ class IoU(metric.Metric):
8
+ """Computes the intersection over union (IoU) per class and corresponding
9
+ mean (mIoU).
10
+
11
+ Intersection over union (IoU) is a common evaluation metric for semantic
12
+ segmentation. The predictions are first accumulated in a confusion matrix
13
+ and the IoU is computed from it as follows:
14
+
15
+ IoU = true_positive / (true_positive + false_positive + false_negative).
16
+
17
+ Keyword arguments:
18
+ - num_classes (int): number of classes in the classification problem
19
+ - normalized (boolean, optional): Determines whether or not the confusion
20
+ matrix is normalized. Default: False.
21
+ - ignore_index (int or iterable, optional): Index of the classes to ignore
22
+ when computing the IoU. Can be an int, or any iterable of ints.
23
+ """
24
+
25
+ def __init__(self, num_classes, normalized=False, ignore_index=None):
26
+ super().__init__()
27
+ self.conf_metric = ConfusionMatrix(num_classes, normalized)
28
+
29
+ if ignore_index is None:
30
+ self.ignore_index = None
31
+ elif isinstance(ignore_index, int):
32
+ self.ignore_index = (ignore_index,)
33
+ else:
34
+ try:
35
+ self.ignore_index = tuple(ignore_index)
36
+ except TypeError:
37
+ raise ValueError("'ignore_index' must be an int or iterable")
38
+
39
+ def reset(self):
40
+ self.conf_metric.reset()
41
+
42
+ def add(self, predicted, target):
43
+ """Adds the predicted and target pair to the IoU metric.
44
+
45
+ Keyword arguments:
46
+ - predicted (Tensor): Can be a (N, K, H, W) tensor of
47
+ predicted scores obtained from the model for N examples and K classes,
48
+ or (N, H, W) tensor of integer values between 0 and K-1.
49
+ - target (Tensor): Can be a (N, K, H, W) tensor of
50
+ target scores for N examples and K classes, or (N, H, W) tensor of
51
+ integer values between 0 and K-1.
52
+
53
+ """
54
+ # Dimensions check
55
+ assert predicted.size(0) == target.size(0), \
56
+ 'number of targets and predicted outputs do not match'
57
+ assert predicted.dim() == 3 or predicted.dim() == 4, \
58
+ "predictions must be of dimension (N, H, W) or (N, K, H, W)"
59
+ assert target.dim() == 3 or target.dim() == 4, \
60
+ "targets must be of dimension (N, H, W) or (N, K, H, W)"
61
+
62
+ # If the tensor is in categorical format convert it to integer format
63
+ if predicted.dim() == 4:
64
+ _, predicted = predicted.max(1)
65
+ if target.dim() == 4:
66
+ _, target = target.max(1)
67
+
68
+ self.conf_metric.add(predicted.view(-1), target.view(-1))
69
+
70
+ def value(self):
71
+ """Computes the IoU and mean IoU.
72
+
73
+ The mean computation ignores NaN elements of the IoU array.
74
+
75
+ Returns:
76
+ Tuple: (IoU, mIoU). The first output is the per class IoU,
77
+ for K classes it's numpy.ndarray with K elements. The second output,
78
+ is the mean IoU.
79
+ """
80
+ conf_matrix = self.conf_metric.value()
81
+ if self.ignore_index is not None:
82
+ for index in self.ignore_index:
83
+ conf_matrix[:, index] = 0
84
+ conf_matrix[index, :] = 0
85
+ true_positive = np.diag(conf_matrix)
86
+ false_positive = np.sum(conf_matrix, 0) - true_positive
87
+ false_negative = np.sum(conf_matrix, 1) - true_positive
88
+
89
+ # Just in case we get a division by 0, ignore/hide the error
90
+ with np.errstate(divide='ignore', invalid='ignore'):
91
+ iou = true_positive / (true_positive + false_positive + false_negative)
92
+
93
+ return iou, np.nanmean(iou)
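A minimal usage sketch for the IoU metric above on a toy segmentation batch (shapes follow the docstring; the class count, ignore index, and import path are illustrative assumptions):

    import torch
    from utils.iou import IoU  # import path assumed

    metric = IoU(num_classes=3, ignore_index=0)  # treat class 0 as background here

    scores = torch.rand(2, 3, 4, 4)           # (N, K, H, W) predicted scores
    targets = torch.randint(0, 3, (2, 4, 4))  # (N, H, W) integer labels in [0, K-1]
    metric.add(scores, targets)

    per_class_iou, mean_iou = metric.value()  # per-class IoU (NaN where undefined) and its nanmean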
ViT_DeiT/utils/metric.py ADDED
@@ -0,0 +1,12 @@
1
+ class Metric(object):
2
+ """Base class for all metrics.
3
+ From: https://github.com/pytorch/tnt/blob/master/torchnet/meter/meter.py
4
+ """
5
+ def reset(self):
6
+ pass
7
+
8
+ def add(self):
9
+ pass
10
+
11
+ def value(self):
12
+ pass
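Since Metric only fixes the reset/add/value interface, a new metric is just a subclass that fills in those three methods. A hypothetical example for illustration (not part of this commit), assuming integer label tensors:

    class PixelAccuracy(Metric):  # hypothetical subclass; Metric imported from utils.metric
        def __init__(self):
            self.reset()

        def reset(self):
            self.correct = 0
            self.total = 0

        def add(self, predicted, target):
            # predicted, target: integer label tensors of the same shape
            self.correct += (predicted == target).sum().item()
            self.total += target.numel()

        def value(self):
            return self.correct / max(self.total, 1)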
ViT_DeiT/utils/metrices.py ADDED
@@ -0,0 +1,208 @@
1
+ import numpy as np
2
+ import torch
3
+ from sklearn.metrics import f1_score, average_precision_score
4
+ from sklearn.metrics import precision_recall_curve, roc_curve
5
+
6
+ SMOOTH = 1e-6
7
+ __all__ = ['get_f1_scores', 'get_ap_scores', 'batch_pix_accuracy', 'batch_intersection_union', 'get_iou', 'get_pr',
8
+ 'get_roc', 'get_ap_multiclass']
9
+
10
+
11
+ def get_iou(outputs: torch.Tensor, labels: torch.Tensor):
12
+ # You can comment out this line if you are passing tensors of equal shape
13
+ # But if you are passing output from UNet or something it will most probably
14
+ # have the BATCH x 1 x H x W shape
15
+ outputs = outputs.squeeze(1) # BATCH x 1 x H x W => BATCH x H x W
16
+ labels = labels.squeeze(1) # BATCH x 1 x H x W => BATCH x H x W
17
+
18
+ intersection = (outputs & labels).float().sum((1, 2)) # Will be zero if Truth=0 or Prediction=0
19
+ union = (outputs | labels).float().sum((1, 2)) # Will be zero if both are 0
20
+
21
+ iou = (intersection + SMOOTH) / (union + SMOOTH) # We smooth our division to avoid 0/0
22
+
23
+ return iou.cpu().numpy()
24
+
25
+
26
+ def get_f1_scores(predict, target, ignore_index=-1):
27
+ # Tensor process
28
+ batch_size = predict.shape[0]
29
+ predict = predict.data.cpu().numpy().reshape(-1)
30
+ target = target.data.cpu().numpy().reshape(-1)
31
+ pb = predict[target != ignore_index].reshape(batch_size, -1)
32
+ tb = target[target != ignore_index].reshape(batch_size, -1)
33
+
34
+ total = []
35
+ for p, t in zip(pb, tb):
36
+ total.append(np.nan_to_num(f1_score(t, p)))
37
+
38
+ return total
39
+
40
+
41
+ def get_roc(predict, target, ignore_index=-1):
42
+ target_expand = target.unsqueeze(1).expand_as(predict)
43
+ target_expand_numpy = target_expand.data.cpu().numpy().reshape(-1)
44
+ # Tensor process
45
+ x = torch.zeros_like(target_expand)
46
+ t = target.unsqueeze(1).clamp(min=0)
47
+ target_1hot = x.scatter_(1, t, 1)
48
+ batch_size = predict.shape[0]
49
+ predict = predict.data.cpu().numpy().reshape(-1)
50
+ target = target_1hot.data.cpu().numpy().reshape(-1)
51
+ pb = predict[target_expand_numpy != ignore_index].reshape(batch_size, -1)
52
+ tb = target[target_expand_numpy != ignore_index].reshape(batch_size, -1)
53
+
54
+ total = []
55
+ for p, t in zip(pb, tb):
56
+ total.append(roc_curve(t, p))
57
+
58
+ return total
59
+
60
+
61
+ def get_pr(predict, target, ignore_index=-1):
62
+ target_expand = target.unsqueeze(1).expand_as(predict)
63
+ target_expand_numpy = target_expand.data.cpu().numpy().reshape(-1)
64
+ # Tensor process
65
+ x = torch.zeros_like(target_expand)
66
+ t = target.unsqueeze(1).clamp(min=0)
67
+ target_1hot = x.scatter_(1, t, 1)
68
+ batch_size = predict.shape[0]
69
+ predict = predict.data.cpu().numpy().reshape(-1)
70
+ target = target_1hot.data.cpu().numpy().reshape(-1)
71
+ pb = predict[target_expand_numpy != ignore_index].reshape(batch_size, -1)
72
+ tb = target[target_expand_numpy != ignore_index].reshape(batch_size, -1)
73
+
74
+ total = []
75
+ for p, t in zip(pb, tb):
76
+ total.append(precision_recall_curve(t, p))
77
+
78
+ return total
79
+
80
+
81
+ def get_ap_scores(predict, target, ignore_index=-1):
82
+ total = []
83
+ for pred, tgt in zip(predict, target):
84
+ target_expand = tgt.unsqueeze(0).expand_as(pred)
85
+ target_expand_numpy = target_expand.data.cpu().numpy().reshape(-1)
86
+
87
+ # Tensor process
88
+ x = torch.zeros_like(target_expand)
89
+ t = tgt.unsqueeze(0).clamp(min=0).long()
90
+ target_1hot = x.scatter_(0, t, 1)
91
+ predict_flat = pred.data.cpu().numpy().reshape(-1)
92
+ target_flat = target_1hot.data.cpu().numpy().reshape(-1)
93
+
94
+ p = predict_flat[target_expand_numpy != ignore_index]
95
+ t = target_flat[target_expand_numpy != ignore_index]
96
+
97
+ total.append(np.nan_to_num(average_precision_score(t, p)))
98
+
99
+ return total
100
+
101
+
102
+ def get_ap_multiclass(predict, target):
103
+ total = []
104
+ for pred, tgt in zip(predict, target):
105
+ predict_flat = pred.data.cpu().numpy().reshape(-1)
106
+ target_flat = tgt.data.cpu().numpy().reshape(-1)
107
+
108
+ total.append(np.nan_to_num(average_precision_score(target_flat, predict_flat)))
109
+
110
+ return total
111
+
112
+
113
+ def batch_precision_recall(predict, target, thr=0.5):
114
+ """Batch Precision Recall
115
+ Args:
116
+ predict: input 4D tensor
117
+ target: label 4D tensor
118
+ """
119
+ # _, predict = torch.max(predict, 1)
120
+
121
+ predict = predict > thr
122
+ predict = predict.data.cpu().numpy() + 1
123
+ target = target.data.cpu().numpy() + 1
124
+
125
+ tp = np.sum(((predict == 2) * (target == 2)) * (target > 0))
126
+ fp = np.sum(((predict == 2) * (target == 1)) * (target > 0))
127
+ fn = np.sum(((predict == 1) * (target == 2)) * (target > 0))
128
+
129
+ precision = float(np.nan_to_num(tp / (tp + fp)))
130
+ recall = float(np.nan_to_num(tp / (tp + fn)))
131
+
132
+ return precision, recall
133
+
134
+
135
+ def batch_pix_accuracy(predict, target):
136
+ """Batch Pixel Accuracy
137
+ Args:
138
+ predict: input 3D tensor
139
+ target: label 3D tensor
140
+ """
141
+
142
+ # for thr in np.linspace(0, 1, slices):
143
+
144
+ _, predict = torch.max(predict, 0)
145
+ predict = predict.cpu().numpy() + 1
146
+ target = target.cpu().numpy() + 1
147
+ pixel_labeled = np.sum(target > 0)
148
+ pixel_correct = np.sum((predict == target) * (target > 0))
149
+ assert pixel_correct <= pixel_labeled, \
150
+ "Correct area should be smaller than Labeled"
151
+ return pixel_correct, pixel_labeled
152
+
153
+
154
+ def batch_intersection_union(predict, target, nclass):
155
+ """Batch Intersection of Union
156
+ Args:
157
+ predict: input 3D tensor
158
+ target: label 3D tensor
159
+ nclass: number of categories (int)
160
+ """
161
+ _, predict = torch.max(predict, 0)
162
+ mini = 1
163
+ maxi = nclass
164
+ nbins = nclass
165
+ predict = predict.cpu().numpy() + 1
166
+ target = target.cpu().numpy() + 1
167
+
168
+ predict = predict * (target > 0).astype(predict.dtype)
169
+ intersection = predict * (predict == target)
170
+ # areas of intersection and union
171
+ area_inter, _ = np.histogram(intersection, bins=nbins, range=(mini, maxi))
172
+ area_pred, _ = np.histogram(predict, bins=nbins, range=(mini, maxi))
173
+ area_lab, _ = np.histogram(target, bins=nbins, range=(mini, maxi))
174
+ area_union = area_pred + area_lab - area_inter
175
+ assert (area_inter <= area_union).all(), \
176
+ "Intersection area should be smaller than Union area"
177
+ return area_inter, area_union
178
+
179
+
180
+ # ref https://github.com/CSAILVision/sceneparsing/blob/master/evaluationCode/utils_eval.py
181
+ def pixel_accuracy(im_pred, im_lab):
182
+ im_pred = np.asarray(im_pred)
183
+ im_lab = np.asarray(im_lab)
184
+
185
+ # Remove classes from unlabeled pixels in gt image.
186
+ # We should not penalize detections in unlabeled portions of the image.
187
+ pixel_labeled = np.sum(im_lab > 0)
188
+ pixel_correct = np.sum((im_pred == im_lab) * (im_lab > 0))
189
+ # pixel_accuracy = 1.0 * pixel_correct / pixel_labeled
190
+ return pixel_correct, pixel_labeled
191
+
192
+
193
+ def intersection_and_union(im_pred, im_lab, num_class):
194
+ im_pred = np.asarray(im_pred)
195
+ im_lab = np.asarray(im_lab)
196
+ # Remove classes from unlabeled pixels in gt image.
197
+ im_pred = im_pred * (im_lab > 0)
198
+ # Compute area intersection:
199
+ intersection = im_pred * (im_pred == im_lab)
200
+ area_inter, _ = np.histogram(intersection, bins=num_class - 1,
201
+ range=(1, num_class - 1))
202
+ # Compute area union:
203
+ area_pred, _ = np.histogram(im_pred, bins=num_class - 1,
204
+ range=(1, num_class - 1))
205
+ area_lab, _ = np.histogram(im_lab, bins=num_class - 1,
206
+ range=(1, num_class - 1))
207
+ area_union = area_pred + area_lab - area_inter
208
+ return area_inter, area_union
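A minimal usage sketch for a few of the helpers above on a toy batch (shapes follow the docstrings; the data and import path are illustrative assumptions):

    import torch
    from utils.metrices import get_iou, batch_pix_accuracy, batch_intersection_union  # path assumed

    # get_iou expects binary masks shaped (N, 1, H, W); it returns one IoU per image.
    pred_mask = (torch.rand(2, 1, 8, 8) > 0.5).int()
    gt_mask = (torch.rand(2, 1, 8, 8) > 0.5).int()
    per_image_iou = get_iou(pred_mask, gt_mask)

    # batch_pix_accuracy / batch_intersection_union take a (K, H, W) score map and an (H, W) label map.
    scores = torch.rand(2, 8, 8)          # K = 2 classes
    labels = torch.randint(0, 2, (8, 8))
    correct, labeled = batch_pix_accuracy(scores, labels)
    inter, union = batch_intersection_union(scores, labels, nclass=2)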